Source code for zea.models.taesd

"""Tiny Autoencoder (TAESD) model converted to Tensorflow.

For the original implementation, see the `TAESD repository <https://github.com/madebyollin/taesd>`_.

You can see an example of how to use this model in the example notebook:
:doc:`../notebooks/models/taesd_autoencoder_example`.

"""

from pathlib import Path

import keras
from keras import backend, ops

from zea.backend import _import_tf
from zea.internal.registry import model_registry
from zea.models.base import BaseModel
from zea.models.preset_utils import get_preset_loader, register_presets
from zea.models.presets import taesdxl_decoder_presets, taesdxl_encoder_presets, taesdxl_presets

tf = _import_tf()


@model_registry(name="taesdxl")
class TinyAutoencoder(BaseModel):
    """`TAESD <https://github.com/madebyollin/taesd>`_ model in TensorFlow.

    Combines a :class:`TinyEncoder` and a :class:`TinyDecoder`. Single-channel
    inputs are tiled to three channels before encoding, and the decoded output
    is converted back to a single channel.
    """

    def __init__(self, **kwargs):
        """Initializes the TAESD model with the given parameters.

        Args:
            **kwargs: Additional keyword arguments to pass to the superclass initializer.

        Raises:
            NotImplementedError: If the Keras backend is neither TensorFlow nor Jax.
            ImportError: If TensorFlow could not be imported.
        """
        if backend.backend() not in ["tensorflow", "jax"]:
            raise NotImplementedError(
                "TinyAutoencoder is only currently supported with the TensorFlow or Jax backend."
            )

        # Explicit raise instead of `assert`: asserts are stripped under `python -O`,
        # which would defer the failure to a much less readable error later on.
        if tf is None:
            raise ImportError(
                "TensorFlow is not installed. Please install TensorFlow to use TinyAutoencoder. "
                "This is required even if you are using the Jax backend, the model is built "
                "using TensorFlow."
            )

        _fix_tf_to_jax_resize_nearest_neighbor()

        super().__init__(**kwargs)

        self.encoder = TinyEncoder()
        self.decoder = TinyDecoder()

        # Whether the most recent `encode` call received single-channel input.
        self._grayscale = False

    def encode(self, inputs):
        """Encode the input images.

        Args:
            inputs (tensor): Input images of shape (batch_size, height, width, channels).
                Inputs with a single channel are repeated three times along the
                channel axis before being passed to the encoder.

        Returns:
            The output of the encoder for the (possibly channel-expanded) inputs.

        Raises:
            ValueError: If the encoder or decoder network has not been loaded yet.
        """
        if self.encoder.network is None or self.decoder.network is None:
            raise ValueError(
                "Please load model using `TinyAutoencoder.from_preset()` before calling."
            )

        # Recompute the flag on every call so that an earlier grayscale input does
        # not force later RGB inputs to be decoded as grayscale.
        self._grayscale = bool(ops.shape(inputs)[-1] == 1)
        if self._grayscale:
            inputs = ops.concatenate([inputs, inputs, inputs], axis=-1)  # grayscale to RGB
        return self.encoder(inputs)

    def decode(self, inputs):
        """Decode the encoded images.

        Args:
            inputs (tensor): Input images of shape (batch_size, height, width, 4).

        Returns:
            The decoder output. It is converted to grayscale when the most recent
            `encode` call received single-channel input.
        """
        decoded = self.decoder(inputs)
        if self._grayscale:
            decoded = ops.image.rgb_to_grayscale(decoded, data_format="channels_last")
        return decoded

    def call(self, inputs):
        """Applies the full autoencoder to the input."""
        encoded = self.encode(inputs)
        # NOTE: Here you can compress the encoding a little bit more by going
        # to uint8 like in the original model
        # https://github.com/huggingface/diffusers/blob/cd30820/src/diffusers/models/autoencoders/autoencoder_tiny.py?plain=1#L336-L342 # noqa: E501
        decoded = self.decode(encoded)
        return decoded

    def custom_load_weights(self, preset, **kwargs):
        """Load the weights for the encoder and decoder.

        Args:
            preset: Preset identifier forwarded to both sub-models.
            **kwargs: Accepted but not used.
        """
        self.encoder.custom_load_weights(preset)
        self.decoder.custom_load_weights(preset)
class TinyBase(BaseModel):
    """Common functionality shared by the TAESD encoder and decoder."""

    def __init__(self, tiny_type=None, **kwargs):
        """Validate the requested sub-model and record the files it needs.

        Args:
            tiny_type (str): Either ``"encoder"`` or ``"decoder"``.
            **kwargs: Additional keyword arguments passed to the superclass initializer.

        Raises:
            NotImplementedError: If the Keras backend is neither TensorFlow nor Jax.
        """
        assert tiny_type in ("encoder", "decoder"), "Type must be either 'encoder' or 'decoder'."

        active_backend = backend.backend()
        if active_backend not in ("tensorflow", "jax"):
            raise NotImplementedError(
                f"{self.__class__.__name__} is only currently supported with the "
                "TensorFlow or Jax backend."
            )

        super().__init__(**kwargs)

        # Populated by `custom_load_weights`; `call` refuses to run while this is None.
        self.network = None

        # Files fetched through the preset loader in `custom_load_weights`.
        self.download_files = [
            f"{tiny_type}/variables/variables.data-00000-of-00001",
            f"{tiny_type}/variables/variables.index",
            f"{tiny_type}/saved_model.pb",
            f"{tiny_type}/fingerprint.pb",
        ]

    def build(self, input_shape):
        """Builds the network (converting it to Jax when running on the Jax backend)."""
        self.maybe_convert_to_jax(input_shape)

    def maybe_convert_to_jax(self, input_shape):
        """Replace ``self.network`` with a Jax layer when the backend is Jax.

        Does nothing on any other backend.

        Args:
            input_shape: Shape of the all-zeros example input handed to the converter.
        """
        if backend.backend() != "jax":
            return

        example_inputs = ops.zeros(input_shape)

        from zea.backend import tf2jax

        converted_fn, converted_params = tf2jax.convert(
            tf.function(self.network), example_inputs
        )

        # NOTE: keep these argument names as they are; `JaxLayer` identifies the
        # arguments of `call_fn` by name.
        def call_fn(params, state, rng, inputs, training):
            return converted_fn(state, inputs)

        self.network = keras.layers.JaxLayer(call_fn, state=converted_params)

    def _load_layer(self, path: Path | str):
        """Load the saved model found at ``path`` in the form the active backend needs."""
        active_backend = backend.backend()
        if active_backend == "tensorflow":
            return keras.layers.TFSMLayer(path, call_endpoint="serving_default")
        if active_backend == "jax":
            return tf.saved_model.load(path)
        raise NotImplementedError(
            f"{self.__class__.__name__} is only currently supported with the "
            f"TensorFlow or Jax backend. You are using {backend.backend()}."
        )

    def custom_load_weights(self, preset, **kwargs):
        """Load the weights for the encoder or decoder.

        Args:
            preset: Preset identifier handed to `get_preset_loader`.
            **kwargs: Accepted but not used.
        """
        preset_loader = get_preset_loader(preset)

        # Every listed file is requested from the loader; the directory holding
        # the last one is the directory the model is loaded from.
        for relative_file in self.download_files:
            local_file = preset_loader.get_file(relative_file)

        model_dir = Path(local_file).parent
        self.network = self._load_layer(model_dir)

    def call(self, inputs):
        """Applies the network to the input.

        Raises:
            ValueError: If no network has been loaded yet.
        """
        if self.network is None:
            raise ValueError(
                f"Please load model using `{self.__class__.__name__}.from_preset()` before calling."
            )

        result = self.network(inputs)
        if backend.backend() != "tensorflow":
            return result

        # On the TensorFlow backend the result is a dict; return its first entry.
        first_key = next(iter(result))
        return result[first_key]
@model_registry(name="taesdxl_encoder")
class TinyEncoder(TinyBase):
    """Encoder half of the TAESD model."""

    def __init__(self, **kwargs):
        """Create the encoder by configuring the base class with ``tiny_type="encoder"``.

        Args:
            **kwargs: Additional keyword arguments passed to the superclass initializer.
        """
        super().__init__(tiny_type="encoder", **kwargs)
@model_registry(name="taesdxl_decoder")
class TinyDecoder(TinyBase):
    """Decoder half of the TAESD model."""

    def __init__(self, **kwargs):
        """Create the decoder by configuring the base class with ``tiny_type="decoder"``.

        Args:
            **kwargs: Additional keyword arguments passed to the superclass initializer.
        """
        super().__init__(tiny_type="decoder", **kwargs)
def _fix_tf_to_jax_resize_nearest_neighbor():
    """Let the Jax backend work with TAESD by overriding a tf2jax op.

    Replaces the ``ResizeNearestNeighbor`` op parser registered in tf2jax with one
    that allows ``align_corners=True`` and ``half_pixel_centers=True``. This means
    outputs of the Jax model might not be a 100% match to the TensorFlow model.

    Does nothing unless the active backend is Jax.
    """
    if backend.backend() != "jax":
        return

    import jax
    import jax.numpy as jnp

    from zea.backend import tf2jax

    tf2jax_ops = tf2jax._src.ops

    def _resize_nearest_neighbor(proto):
        """Parse a ResizeNearestNeighbor op."""
        allowed_attrs = {"T", "align_corners", "half_pixel_centers"}
        tf2jax_ops._check_attrs(proto, allowed_attrs)

        def _func(images: jnp.ndarray, size: jnp.ndarray) -> jnp.ndarray:
            rank = len(images.shape)
            if rank != 4:
                raise ValueError(
                    "Expected A 4D tensor with shape [batch, height, width, channels], "
                    f"found {images.shape}"
                )

            inp_batch, _, _, inp_channels = images.shape
            out_height, out_width = size.tolist()
            target_shape = (inp_batch, out_height, out_width, inp_channels)

            return jax.image.resize(
                images,
                shape=target_shape,
                method=jax.image.ResizeMethod.NEAREST,
            )

        return _func

    # hack to allow align_corners=True and half_pixel_centers=True
    tf2jax_ops._jax_ops["ResizeNearestNeighbor"] = _resize_nearest_neighbor


register_presets(taesdxl_presets, TinyAutoencoder)
register_presets(taesdxl_encoder_presets, TinyEncoder)
register_presets(taesdxl_decoder_presets, TinyDecoder)