zea.models.unet

UNet models and architectures

Functions

get_time_conditional_unetwork(image_shape[, ...])

Get a basic UNet architecture with time-conditional sinusoidal embeddings

get_unetwork(image_shape, widths, block_depth)

Get a basic UNet architecture

Classes

UNet(*args, **kwargs)

UNet model

UNetTimeConditional(*args, **kwargs)

UNet model with time-conditional sinusoidal embedding

class zea.models.unet.UNet(*args, **kwargs)

Bases: BaseModel

UNet model

call(*args, **kwargs)

get_config()

Returns the config of the object.

An object config is a Python dictionary (serializable) containing the information needed to re-instantiate it.
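The get_config contract follows the usual Keras convention: the returned dictionary should contain only serializable values, and passing it back to from_config should rebuild an equivalent object. The sketch below illustrates that round trip with a hypothetical stand-in class, ToyUNetConfig. It is not zea's UNet, and the config keys it uses (image_shape, widths, block_depth) are assumptions borrowed from the builder functions on this page, not the documented contents of UNet.get_config().

```python
import json


class ToyUNetConfig:
    """Hypothetical stand-in illustrating the get_config / from_config contract."""

    def __init__(self, image_shape, widths, block_depth):
        self.image_shape = tuple(image_shape)
        self.widths = list(widths)
        self.block_depth = block_depth

    def get_config(self):
        # Only plain, JSON-serializable values: enough to re-instantiate the object.
        return {
            "image_shape": list(self.image_shape),
            "widths": list(self.widths),
            "block_depth": self.block_depth,
        }

    @classmethod
    def from_config(cls, config):
        # Rebuild the object from the dictionary produced by get_config().
        return cls(**config)


original = ToyUNetConfig((64, 64, 1), [32, 64, 96, 128], 2)
# Serialize to JSON text and back, as a saved model file would.
config = json.loads(json.dumps(original.get_config()))
restored = ToyUNetConfig.from_config(config)
print(restored.get_config() == original.get_config())
```

Keeping the config to plain lists, ints and floats is what allows it to be written to JSON and read back unchanged.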

class zea.models.unet.UNetTimeConditional(*args, **kwargs)

Bases: BaseModel

UNet model with time-conditional sinusoidal embedding

call(*args, **kwargs)

get_config()

Returns the config of the object.

An object config is a Python dictionary (serializable) containing the information needed to re-instantiate it.

zea.models.unet.get_time_conditional_unetwork(image_shape, widths=None, block_depth=None, embedding_min_frequency=1.0, embedding_max_frequency=1000.0, embedding_dims=32)

Get a basic UNet architecture with time-conditional sinusoidal embeddings

Used in Diffusion Models.

Parameters:
  • image_shape – tuple, (height, width, channels)

  • widths – list, number of filters in each layer

  • block_depth – int, number of residual blocks in each down/up block

  • embedding_min_frequency – float, minimum frequency for sinusoidal embeddings

  • embedding_max_frequency – float, maximum frequency for sinusoidal embeddings

  • embedding_dims – int, number of dimensions for sinusoidal embeddings

Returns:

keras.Model
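The three embedding_* parameters control how a scalar time (or noise-level) input is turned into a feature vector. The sketch below shows one common sinusoidal scheme, assuming it matches the one in the Keras DDIM example that this function's signature resembles: embedding_dims // 2 frequencies spaced evenly in log space between embedding_min_frequency and embedding_max_frequency, with sines in the first half of the vector and cosines in the second. The function name sinusoidal_embedding is hypothetical, and zea's actual implementation may differ in ordering or scaling.

```python
import math


def sinusoidal_embedding(t, embedding_min_frequency=1.0,
                         embedding_max_frequency=1000.0, embedding_dims=32):
    """Hypothetical sketch: embed a scalar t into embedding_dims sinusoidal features."""
    half = embedding_dims // 2
    log_min = math.log(embedding_min_frequency)
    log_max = math.log(embedding_max_frequency)
    # Frequencies evenly spaced in log space, endpoints included (like linspace).
    if half == 1:
        freqs = [embedding_min_frequency]
    else:
        step = (log_max - log_min) / (half - 1)
        freqs = [math.exp(log_min + i * step) for i in range(half)]
    angular_speeds = [2.0 * math.pi * f for f in freqs]
    # First half: sines, second half: cosines of the same angular speeds.
    return ([math.sin(w * t) for w in angular_speeds]
            + [math.cos(w * t) for w in angular_speeds])


emb = sinusoidal_embedding(0.25)
print(len(emb))
```

Low frequencies vary slowly with t and high frequencies vary quickly, so nearby time values get similar but distinguishable embeddings. In the Keras DDIM example, this vector is broadcast to the image's spatial size and concatenated with the image features before the first down block.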

zea.models.unet.get_unetwork(image_shape, widths, block_depth)

Get a basic UNet architecture

Parameters:
  • image_shape – tuple, (height, width, channels)

  • widths – list, number of filters in each layer

  • block_depth – int, number of residual blocks in each down/up block

Returns:

keras.Model
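The choice of widths constrains image_shape. The sketch below uses a hypothetical helper, unet_level_shapes, to compute the encoder feature-map shape at each resolution level. It assumes the layout of the Keras DDIM example: one 2x downsampling after every entry of widths except the last, which acts as the bottleneck. Under that assumption, height and width must be divisible by 2 ** (len(widths) - 1). block_depth only sets how many residual blocks run at each level, so it does not change these shapes. zea's actual layout is not confirmed on this page.

```python
def unet_level_shapes(image_shape, widths):
    """Hypothetical helper: (height, width, filters) at each encoder level.

    Assumes a 2x downsampling after every width except the last (bottleneck).
    """
    height, width, _channels = image_shape
    factor = 2 ** (len(widths) - 1)
    if height % factor or width % factor:
        raise ValueError(
            f"height and width must be divisible by {factor} "
            f"when using {len(widths)} widths"
        )
    shapes = []
    for level, filters in enumerate(widths):
        scale = 2 ** level  # spatial reduction reached at this level
        shapes.append((height // scale, width // scale, filters))
    return shapes


for shape in unet_level_shapes((64, 64, 1), [32, 64, 96, 128]):
    print(shape)
```

With four widths, a 64x64 input reaches an 8x8 bottleneck, while a 60x60 input would be rejected because 60 is not divisible by 8.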