"""UNet models and architectures"""
import keras
from keras import layers
from zea import log
from zea.internal.registry import model_registry
from zea.models.base import BaseModel
from zea.models.layers import DownBlock, ResidualBlock, UpBlock, sinusoidal_embedding
from zea.models.preset_utils import register_presets
from zea.models.presets import unet_presets
[docs]
@model_registry(name="unet")
class UNet(BaseModel):
"""UNet model"""
def __init__(
self,
input_shape,
widths,
block_depth,
input_range,
name="unet",
**kwargs,
):
"""Initializes a UNet model"""
super().__init__(name=name, **kwargs)
self.input_shape = input_shape
self.input_range = input_range
self.widths = widths
self.block_depth = block_depth
self.network = get_unetwork(self.input_shape, self.widths, self.block_depth)
[docs]
def get_config(self):
config = super().get_config()
config.update(
{
"input_shape": self.input_shape,
"input_range": self.input_range,
"widths": self.widths,
"block_depth": self.block_depth,
}
)
return config
[docs]
def call(self, *args, **kwargs):
return self.network(*args, **kwargs)
[docs]
def get_unetwork(
image_shape,
widths,
block_depth,
):
"""Get a basic UNet architecture
Args:
image_shape: tuple, (height, width, channels)
widths: list, number of filters in each layer
block_depth: int, number of residual blocks in each down/up block
Returns:
keras.Model
"""
assert len(image_shape) == 3, "image_shape must be a tuple of (height, width, channels)"
image_height, image_width, n_channels = image_shape
noisy_images = keras.Input(shape=(image_height, image_width, n_channels))
x = layers.Conv2D(widths[0], kernel_size=1)(noisy_images)
skips = []
for width in widths[:-1]:
x = DownBlock(width, block_depth)([x, skips])
for _ in range(block_depth):
x = ResidualBlock(widths[-1])(x)
for width in reversed(widths[:-1]):
x = UpBlock(width, block_depth)([x, skips])
x = layers.Conv2D(n_channels, kernel_size=1, kernel_initializer="zeros")(x)
return keras.Model(noisy_images, x, name="residual_unet")
[docs]
@model_registry(name="unet_time_conditional")
class UNetTimeConditional(BaseModel):
"""UNet model with time-conditional sinusoidal embedding"""
def __init__(
self,
image_shape,
widths,
block_depth,
image_range,
embedding_min_frequency=1.0,
embedding_max_frequency=1000.0,
embedding_dims=32,
name="unet_time_conditional",
**kwargs,
):
super().__init__(name=name, **kwargs)
self.image_shape = image_shape
self.image_range = image_range
self.widths = widths
self.block_depth = block_depth
self.embedding_min_frequency = embedding_min_frequency
self.embedding_max_frequency = embedding_max_frequency
self.embedding_dims = embedding_dims
self.network = get_time_conditional_unetwork(
self.image_shape,
self.widths,
self.block_depth,
self.embedding_min_frequency,
self.embedding_max_frequency,
self.embedding_dims,
)
[docs]
def get_config(self):
config = super().get_config()
config.update(
{
"image_shape": self.image_shape,
"image_range": self.image_range,
"widths": self.widths,
"block_depth": self.block_depth,
"embedding_min_frequency": self.embedding_min_frequency,
"embedding_max_frequency": self.embedding_max_frequency,
"embedding_dims": self.embedding_dims,
}
)
return config
[docs]
def call(self, *args, **kwargs):
return self.network(*args, **kwargs)
[docs]
def get_time_conditional_unetwork(
image_shape,
widths=None,
block_depth=None,
embedding_min_frequency=1.0,
embedding_max_frequency=1000.0,
embedding_dims=32,
):
"""Get a basic UNet architecture with time-conditional sinusoidal embeddings
Used in Diffusion Models.
Args:
image_shape: tuple, (height, width, channels)
widths: list, number of filters in each layer
block_depth: int, number of residual blocks in each down/up block
embedding_min_frequency: float, minimum frequency for sinusoidal embeddings
embedding_max_frequency: float, maximum frequency for sinusoidal embeddings
embedding_dims: int, number of dimensions for sinusoidal embeddings
Returns:
keras.Model
"""
assert len(image_shape) == 3, "image_shape must be a tuple of (height, width, channels)"
if widths is None:
log.warning("No widths provided, using default widths [32, 64, 96, 128]")
widths = [32, 64, 96, 128]
if block_depth is None:
log.warning("No block_depth provided, using default block_depth 2")
block_depth = 2
image_height, image_width, n_channels = image_shape
noisy_images = keras.Input(shape=(image_height, image_width, n_channels))
noise_variances = keras.Input(shape=(1, 1, 1))
@keras.saving.register_keras_serializable()
def _sinusoidal_embedding(x):
return sinusoidal_embedding(
x, embedding_min_frequency, embedding_max_frequency, embedding_dims
)
e = layers.Lambda(_sinusoidal_embedding, output_shape=(1, 1, 32))(noise_variances)
e = layers.UpSampling2D(size=(image_height, image_width), interpolation="nearest")(e)
x = layers.Conv2D(widths[0], kernel_size=1)(noisy_images)
x = layers.Concatenate()([x, e])
skips = []
for width in widths[:-1]:
x = DownBlock(width, block_depth)([x, skips])
for _ in range(block_depth):
x = ResidualBlock(widths[-1])(x)
for width in reversed(widths[:-1]):
x = UpBlock(width, block_depth)([x, skips])
x = layers.Conv2D(n_channels, kernel_size=1, kernel_initializer="zeros")(x)
return keras.Model([noisy_images, noise_variances], x, name="residual_unet")
register_presets(unet_presets, UNet)
register_presets(unet_presets, UNetTimeConditional)