Source code for zea.models.dense

"""Dense models and architectures"""

import keras
from keras import layers

from zea.internal.registry import model_registry
from zea.models.base import BaseModel
from zea.models.layers import sinusoidal_embedding
from zea.models.preset_utils import register_presets
from zea.models.presets import dense_presets


[docs] @model_registry(name="dense") class DenseNet(BaseModel): """Simple dense model""" def __init__( self, input_dim, widths, output_dim, name="dense", **kwargs, ): super().__init__(name=name, **kwargs) self.input_dim = input_dim self.widths = widths self.output_dim = output_dim self.network = get_dense_network(self.input_dim, self.widths, self.output_dim)
[docs] def get_config(self): config = super().get_config() config.update( { "input_dim": self.input_dim, "widths": self.widths, "output_dim": self.output_dim, } ) return config
[docs] def call(self, *args, **kwargs): return self.network(*args, **kwargs)
[docs] def get_dense_network(input_dim, widths, output_dim): """Simple feedforward network""" inputs = keras.Input(shape=(input_dim,)) x = inputs for width in widths: x = layers.Dense(width, activation="relu")(x) outputs = layers.Dense(output_dim, kernel_initializer="zeros")(x) return keras.Model(inputs, outputs, name="dense_net")
[docs] @model_registry(name="dense_time_conditional") class DenseTimeConditionalNet(BaseModel): """Dense model with time-conditional sinusoidal embedding""" def __init__( self, input_dim, widths, output_dim, embedding_min_frequency=1.0, embedding_max_frequency=1000.0, embedding_dims=32, name="dense_time_conditional", **kwargs, ): super().__init__(name=name, **kwargs) self.input_dim = input_dim self.widths = widths self.output_dim = output_dim self.embedding_min_frequency = embedding_min_frequency self.embedding_max_frequency = embedding_max_frequency self.embedding_dims = embedding_dims self.network = get_time_conditional_dense_network( self.input_dim, self.widths, self.output_dim, self.embedding_min_frequency, self.embedding_max_frequency, self.embedding_dims, )
[docs] def get_config(self): config = super().get_config() config.update( { "input_dim": self.input_dim, "widths": self.widths, "output_dim": self.output_dim, "embedding_min_frequency": self.embedding_min_frequency, "embedding_max_frequency": self.embedding_max_frequency, "embedding_dims": self.embedding_dims, } ) return config
[docs] def call(self, *args, **kwargs): return self.network(*args, **kwargs)
[docs] def get_time_conditional_dense_network( input_dim, widths, output_dim, embedding_min_frequency=1.0, embedding_max_frequency=1000.0, embedding_dims=32, ): """Dense network with time-conditional sinusoidal embedding""" inputs = keras.Input(shape=(input_dim,)) time_input = keras.Input(shape=(1,)) @keras.saving.register_keras_serializable() def _sinusoidal_embedding(x): return sinusoidal_embedding( x, embedding_min_frequency, embedding_max_frequency, embedding_dims ) e = layers.Lambda(_sinusoidal_embedding, output_shape=(embedding_dims,))(time_input) x = layers.Concatenate()([inputs, e]) for width in widths: x = layers.Dense(width, activation="relu")(x) outputs = layers.Dense(output_dim, kernel_initializer="zeros")(x) return keras.Model([inputs, time_input], outputs, name="dense_time_conditional_net")
register_presets(dense_presets, DenseNet) register_presets(dense_presets, DenseTimeConditionalNet)