Source code for zea.models.layers

"""Layers used in zea.models"""

import math

import keras
from keras import layers, ops


@keras.saving.register_keras_serializable()
def sinusoidal_embedding(x, embedding_min_frequency, embedding_max_frequency, embedding_dims):
    """Embed ``x`` with sines and cosines at log-spaced frequencies.

    ``embedding_dims // 2`` frequencies are spaced evenly in log-space between
    ``embedding_min_frequency`` and ``embedding_max_frequency`` (both
    inclusive), converted to angular speeds (``2 * pi * f``) and multiplied
    with ``x``. The sine and cosine of that product are concatenated along
    the last axis.

    Args:
        x: Value(s) to embed. Must broadcast against a 1D tensor of length
            ``embedding_dims // 2`` (presumably a trailing axis of size 1 --
            TODO confirm against callers).
        embedding_min_frequency: Lowest frequency of the embedding.
        embedding_max_frequency: Highest frequency of the embedding.
        embedding_dims: Target embedding size; half of it is used for the
            sine part and half for the cosine part.

    Returns:
        Tensor holding ``sin`` features followed by ``cos`` features on the
        last axis.
    """
    log_min_frequency = ops.log(embedding_min_frequency)
    log_max_frequency = ops.log(embedding_max_frequency)
    num_frequencies = embedding_dims // 2

    # Interpolating linearly in log-space gives geometrically spaced frequencies.
    log_frequencies = ops.linspace(log_min_frequency, log_max_frequency, num_frequencies)
    frequencies = ops.exp(log_frequencies)
    angular_speeds = ops.cast(2.0 * math.pi * frequencies, "float32")

    sin_features = ops.sin(angular_speeds * x)
    cos_features = ops.cos(angular_speeds * x)
    return ops.concatenate([sin_features, cos_features], axis=-1)
def ResidualBlock(width):
    """Build a residual block with a swish-activated convolution.

    Args:
        width: Number of filters of the convolutions, i.e. the size of
            axis 3 of the block output.

    Returns:
        A function ``apply(x)`` that creates fresh layers on every call and
        returns ``conv(conv(batch_norm(x))) + shortcut(x)``.
    """

    def apply(x):
        # Axis 3 is compared with the filter count, so a 4D input with the
        # channels on the last axis is assumed.
        input_channels = ops.shape(x)[3]

        # The shortcut is the identity when the channel count already matches;
        # otherwise a 1x1 convolution projects the input to `width` channels.
        shortcut = x if input_channels == width else layers.Conv2D(width, kernel_size=1)(x)

        normalized = layers.BatchNormalization(center=False, scale=False)(x)
        activated = layers.Conv2D(width, kernel_size=3, padding="same", activation="swish")(
            normalized
        )
        projected = layers.Conv2D(width, kernel_size=3, padding="same")(activated)
        return layers.Add()([projected, shortcut])

    return apply
def DownBlock(width, block_depth):
    """Build a downsampling stage made of residual blocks and average pooling.

    Args:
        width: Filter count passed to every ``ResidualBlock``.
        block_depth: Number of residual blocks applied before pooling.

    Returns:
        A function ``apply((x, skips))``. It appends the output of each
        residual block to ``skips`` (mutating it in place) and returns the
        result of 2x2 average pooling applied to the last output.
    """

    def apply(x):
        features, skips = x
        for _ in range(block_depth):
            features = ResidualBlock(width)(features)
            # Keep every intermediate output so a matching up-stage can reuse it.
            skips.append(features)
        return layers.AveragePooling2D(pool_size=2)(features)

    return apply
def UpBlock(width, block_depth):
    """Build an upsampling stage that consumes stored skip connections.

    Args:
        width: Filter count passed to every ``ResidualBlock``.
        block_depth: Number of (concatenate, residual block) steps; one entry
            is popped from ``skips`` per step.

    Returns:
        A function ``apply((x, skips))``. It upsamples ``x`` by a factor of 2
        with bilinear interpolation, then repeatedly concatenates the most
        recently stored skip tensor (removing it from ``skips``) and applies
        a residual block.
    """

    def apply(x):
        features, skips = x
        features = layers.UpSampling2D(size=2, interpolation="bilinear")(features)
        for _ in range(block_depth):
            concatenate = layers.Concatenate()
            # Skips are consumed last-in, first-out.
            features = concatenate([features, skips.pop()])
            features = ResidualBlock(width)(features)
        return features

    return apply