zea.models.base

Base model class for all zea Keras models.

This module provides the BaseModel class, from which all zea Keras models inherit, and the deserialize_zea_object() helper for deserializing custom zea objects.

Functions

deserialize_zea_object(config)

Retrieve the object by deserializing the config dict.

Classes

BaseModel(*args, **kwargs)

Base class for all zea Keras models.

class zea.models.base.BaseModel(*args, **kwargs)

Bases: Model

Base class for all zea Keras models.

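BaseModel extends the Keras Model class with loading from and saving to presets, and it overrides from_config() so that deserialization returns an instance of the subclass.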

classmethod from_config(config)

Create a model instance from a configuration dictionary.

For functional models, the default from_config() returns a plain keras.Model instead of the subclass. This override ensures that an instance of the calling subclass is returned.

Parameters:

config (dict) – Configuration dictionary.

Returns:

An instance of the model subclass.

Return type:

BaseModel
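The override works because from_config is a classmethod that constructs `cls(...)` instead of a hard-coded base class. The sketch below illustrates only that mechanism with plain Python; `FakeModel` and `MyZeaModel` are hypothetical stand-ins, not Keras or zea classes, and the real zea override also has to rebuild the functional graph.

```python
class FakeModel:
    """Hypothetical stand-in for keras.Model (not the real Keras class)."""

    def __init__(self, name="model"):
        self.name = name

    def get_config(self):
        return {"name": self.name}

    @classmethod
    def from_config_default(cls, config):
        # Mimics the problem: always builds the base class, ignoring `cls`.
        return FakeModel(**config)

    @classmethod
    def from_config(cls, config):
        # Using `cls` means the subclass the method was called on is built.
        return cls(**config)


class MyZeaModel(FakeModel):
    """Hypothetical subclass playing the role of a zea model."""


original = MyZeaModel(name="denoiser")
config = original.get_config()

vanilla = MyZeaModel.from_config_default(config)
restored = MyZeaModel.from_config(config)

print(type(vanilla).__name__)   # FakeModel
print(type(restored).__name__)  # MyZeaModel
```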

classmethod from_preset(preset, load_weights=True, **kwargs)

Instantiate a model from a preset.

A preset is a directory of configs, weights, and other file assets used to save and load a pre-trained model. The preset can be passed as one of:

  1. a built-in preset identifier like 'bert_base_en'

  2. a Kaggle Models handle like 'kaggle://user/bert/keras/bert_base_en'

  3. a Hugging Face handle like 'hf://user/bert_base_en'

  4. a path to a local preset directory like './bert_base_en'
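The four accepted preset forms can be told apart by prefix and by lookup in the built-in table. The helper below is a hypothetical sketch of that dispatch for illustration only; it is not part of zea or keras_hub, and the actual resolution order in the library may differ.

```python
def classify_preset(preset, builtin_presets):
    """Return which of the four preset forms `preset` refers to.

    Hypothetical helper for illustration; not part of zea or keras_hub.
    """
    if preset.startswith("kaggle://"):
        return "kaggle"
    if preset.startswith("hf://"):
        return "huggingface"
    if preset in builtin_presets:
        return "builtin"
    # Anything else is treated as a path to a local preset directory.
    return "local"


# Hypothetical built-in preset table (the real one lives on each class).
BUILTIN = {"bert_base_en": {}}

results = {
    p: classify_preset(p, BUILTIN)
    for p in [
        "bert_base_en",
        "kaggle://user/bert/keras/bert_base_en",
        "hf://user/bert_base_en",
        "./bert_base_en",
    ]
}
for p, kind in results.items():
    print(p, "->", kind)
```

Checking the remote prefixes first keeps a handle such as `hf://...` from ever being mistaken for a local path.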

This constructor can be called in one of two ways: either from the base class like keras_hub.models.Backbone.from_preset(), or from a model class like keras_hub.models.GemmaBackbone.from_preset(). If calling from the base class, the subclass of the returning object will be inferred from the config in the preset directory.

For any Backbone subclass, you can run cls.presets.keys() to list all built-in presets available on the class.
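Calling from the base class and inferring the subclass from the preset config can be pictured as a name-to-class registry lookup, with each subclass carrying its own `presets` dict. This is a minimal sketch under stated assumptions: `_REGISTRY`, `register`, `Base`, `GemmaLike`, and `BertLike` are hypothetical, the preset config is passed in directly instead of being read from disk, and the subclass-mismatch check is an assumed design choice.

```python
# Hypothetical registry mapping class names (as stored in a preset config)
# to classes. The real lookup in zea/keras_hub may work differently.
_REGISTRY = {}


def register(cls):
    _REGISTRY[cls.__name__] = cls
    return cls


class Base:
    presets = {}  # per-class table of built-in presets

    @classmethod
    def from_preset(cls, preset_config):
        # `preset_config` stands in for the config read from a preset directory.
        target = _REGISTRY[preset_config["class_name"]]
        if not issubclass(target, cls):
            raise ValueError(
                f"Preset is for {target.__name__}, "
                f"which is not a subclass of {cls.__name__}"
            )
        return target()


@register
class GemmaLike(Base):
    presets = {"gemma_2b_en": {"class_name": "GemmaLike"}}


@register
class BertLike(Base):
    presets = {"bert_base_en": {"class_name": "BertLike"}}


# Calling from the base class infers the subclass from the config.
model = Base.from_preset(GemmaLike.presets["gemma_2b_en"])
print(type(model).__name__)           # GemmaLike
print(list(BertLike.presets.keys()))  # ['bert_base_en']
```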

Parameters:
  • preset (str) – A built-in preset identifier, a Kaggle Models handle, a Hugging Face handle, or a path to a local directory.

  • load_weights (bool) – If True, the weights will be loaded into the model architecture. If False, the weights will be randomly initialized.

  • **kwargs – Additional keyword arguments.

Examples

import keras_hub

# Load a Gemma backbone with pre-trained weights.
model = keras_hub.models.Backbone.from_preset(
    "gemma_2b_en",
)

# Load a Bert backbone with a pre-trained config and random weights.
model = keras_hub.models.Backbone.from_preset(
    "bert_base_en",
    load_weights=False,
)

Returns:

The loaded model instance.

Return type:

BaseModel

presets = {}
save_to_preset(preset_dir)

Save the model to a preset directory.

Parameters:

preset_dir – The path to the local model preset directory.
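A local preset directory is one that from_preset() can later load from a path. The sketch below shows only the config half of such a round trip using JSON; `save_config_to_preset`, `load_config_from_preset`, and the `config.json` file name are hypothetical assumptions, and a real preset also stores weights and other assets.

```python
import json
import os
import tempfile


def save_config_to_preset(config, preset_dir):
    """Write a model config into `preset_dir` as config.json.

    Hypothetical helper: the file name and layout are assumptions, and a real
    preset also stores weights and other assets.
    """
    os.makedirs(preset_dir, exist_ok=True)
    path = os.path.join(preset_dir, "config.json")
    with open(path, "w", encoding="utf-8") as f:
        json.dump(config, f, indent=2)
    return path


def load_config_from_preset(preset_dir):
    """Read the config back from a local preset directory."""
    with open(os.path.join(preset_dir, "config.json"), encoding="utf-8") as f:
        return json.load(f)


# Round trip through a temporary directory.
with tempfile.TemporaryDirectory() as tmp:
    preset_dir = os.path.join(tmp, "my_model")
    save_config_to_preset(
        {"class_name": "MyZeaModel", "config": {"units": 8}}, preset_dir
    )
    restored = load_config_from_preset(preset_dir)

print(restored)
```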

zea.models.base.deserialize_zea_object(config)

Retrieve the object by deserializing the config dict.

This function is adapted from keras.utils.deserialize_keras_object() and customized to allow deserialization of custom (zea) objects. The adaptation is needed because, since the following Keras PR, the original function no longer works on non-Keras objects: https://github.com/keras-team/keras/pull/20751

Parameters:

config (dict) – The configuration dictionary.

Returns:

The deserialized object.

Return type:

obj (Object)
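Keras serializes objects as dicts that include keys such as "class_name" and "config". A minimal sketch of resolving such a dict against a custom-object registry, the kind of lookup that lets non-Keras objects be restored, is shown below. `Scaler`, `ZEA_OBJECTS`, `deserialize_object`, and the simplified two-key format are hypothetical; the actual deserialize_zea_object handles more cases.

```python
class Scaler:
    """Hypothetical custom (non-Keras) object."""

    def __init__(self, factor=1.0):
        self.factor = factor

    def get_config(self):
        return {"factor": self.factor}

    @classmethod
    def from_config(cls, config):
        return cls(**config)


# Hypothetical registry of custom objects, keyed by class name.
ZEA_OBJECTS = {"Scaler": Scaler}


def deserialize_object(config, custom_objects=None):
    """Rebuild an object from {"class_name": ..., "config": {...}}.

    Simplified sketch, not the actual deserialize_zea_object.
    """
    if custom_objects is None:
        custom_objects = ZEA_OBJECTS
    if not isinstance(config, dict) or "class_name" not in config:
        # Plain values (numbers, strings, lists) pass through unchanged.
        return config
    cls = custom_objects.get(config["class_name"])
    if cls is None:
        raise ValueError(f"Unknown object: {config['class_name']!r}")
    inner = config.get("config", {})
    if hasattr(cls, "from_config"):
        return cls.from_config(inner)
    return cls(**inner)


obj = deserialize_object({"class_name": "Scaler", "config": {"factor": 2.5}})
print(type(obj).__name__, obj.factor)  # Scaler 2.5
```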