# Source code for zea.models.preset_utils

"""Mostly from keras_hub.src.models import preset_utils"""

import collections
import datetime
import json
import os
from pathlib import Path

import huggingface_hub
import keras
from huggingface_hub.utils import EntryNotFoundError, HFValidationError

import zea
from zea.internal.cache import ZEA_CACHE_DIR
from zea.internal.registry import model_registry

# Prefix/scheme marking Hugging Face Hub handles, e.g. "hf://org/model".
HF_PREFIX = "hf://"

HF_SCHEME = "hf"

# Directory name for extra preset assets (tokenizers, vocabularies, ...).
ASSET_DIR = "assets"

# Config file names.
CONFIG_FILE = "config.json"
IMAGE_CONVERTER_CONFIG_FILE = "image_converter.json"
PREPROCESSOR_CONFIG_FILE = "preprocessor.json"
METADATA_FILE = "metadata.json"

# Weight file names.
MODEL_WEIGHTS_FILE = "model.weights.h5"

# HuggingFace filenames.
README_FILE = "README.md"
HF_CONFIG_FILE = "config.json"

# Local download cache for Hugging Face models; created eagerly at import
# time so later `hf_hub_download` calls can rely on it existing.
HF_MODELS_DIR = ZEA_CACHE_DIR / "huggingface" / "models"
HF_MODELS_DIR.mkdir(parents=True, exist_ok=True)

# Global state for preset registry.
# Maps preset identifier -> preset metadata (all registered presets).
BUILTIN_PRESETS = {}
# Maps model class -> {preset identifier -> preset metadata}.
BUILTIN_PRESETS_FOR_MODEL = collections.defaultdict(dict)


def register_presets(presets, model_cls):
    """Register built-in presets for a set of classes.

    Note that this is intended only for models and presets shipped in the
    library itself.

    Args:
        presets: dict mapping preset identifier to preset metadata.
        model_cls: the model class the presets belong to.
    """
    # Record each preset both globally and per model class, so lookups can
    # go by name alone or by (class, name). Iterate items() instead of
    # re-indexing the dict by key.
    for name, metadata in presets.items():
        BUILTIN_PRESETS[name] = metadata
        BUILTIN_PRESETS_FOR_MODEL[model_cls][name] = metadata
def builtin_presets(cls):
    """Find all registered built-in presets for a class.

    Returns a fresh dict so callers cannot mutate the registry's state.
    """
    if cls not in BUILTIN_PRESETS_FOR_MODEL:
        return {}
    return dict(BUILTIN_PRESETS_FOR_MODEL[cls])
def get_file(preset, path):
    """Download a preset file if necessary and return the local path.

    Args:
        preset: a built-in preset identifier, a Hugging Face handle
            (``hf://{org}/{model}``), or a local preset directory path.
        path: the file name to resolve inside the preset.

    Returns:
        str: local filesystem path of the requested file.

    Raises:
        ValueError: if the preset identifier is malformed or unknown.
        FileNotFoundError: if `path` does not exist in the preset.
    """
    if not isinstance(preset, str):
        raise ValueError(f"A preset identifier must be a string. Received: preset={preset}")
    if preset in BUILTIN_PRESETS:
        # Built-in presets are aliases for a Hugging Face handle or a local path.
        if "hf_handle" in BUILTIN_PRESETS[preset]:
            preset = BUILTIN_PRESETS[preset]["hf_handle"]
        else:
            preset = BUILTIN_PRESETS[preset]["path"]
    scheme = None
    if "://" in preset:
        scheme = preset.split("://")[0].lower()
    if scheme == HF_SCHEME:
        if huggingface_hub is None:
            raise ImportError(
                f"`from_preset()` requires the `huggingface_hub` package to load from '{preset}'. "
                "Please install with `pip install huggingface_hub`."
            )
        hf_handle = preset.removeprefix(HF_SCHEME + "://")

        def _download_from_hf(repo_id, filename):
            # Download into zea's own cache dir rather than the HF default.
            return huggingface_hub.hf_hub_download(
                repo_id=repo_id,
                filename=filename,
                cache_dir=HF_MODELS_DIR,
            )

        try:
            # Try without login first.
            return _download_from_hf(hf_handle, path)
        except huggingface_hub.utils.RepositoryNotFoundError:
            # Repo may be private or gated: log in and retry once.
            huggingface_hub.login(new_session=False)
            return _download_from_hf(hf_handle, path)
        except HFValidationError as e:
            raise ValueError(
                "Unexpected Hugging Face preset. Hugging Face model handles "
                "should have the form 'hf://{org}/{model}'. For example, "
                f"'hf://username/bert_base_en'. Received: preset={preset}."
            ) from e
        except EntryNotFoundError as e:
            message = str(e)
            # BUG FIX: `str.find` returns -1 (truthy) when the substring is
            # absent, so `if message.find("403 Client Error"):` was true for
            # nearly every message. Use a membership test instead.
            if "403 Client Error" in message:
                raise FileNotFoundError(
                    f"`{path}` doesn't exist in preset directory `{preset}`."
                ) from e
            raise ValueError(message) from e
    elif Path(preset).exists():
        # Assume a local filepath.
        local_path = Path(preset) / path
        if not local_path.exists():
            raise FileNotFoundError(f"`{path}` doesn't exist in preset directory `{preset}`.")
        return str(local_path)
    else:
        # BUG FIX: message typos corrected ("a one of" -> "one of",
        # missing closing quote on the directory example).
        raise ValueError(
            "Unknown preset identifier. A preset must be one of:\n"
            "1) a built-in preset identifier like `'taesdxl'`\n"
            "2) a Hugging Face handle like `'hf://zea/taesdxl'`\n"
            "3) a path to a local preset directory like `'./taesdxl'`\n"
            "Use `print(cls.presets.keys())` to view all built-in presets for "
            "API symbol `cls`.\n"
            f"Received: preset='{preset}'"
        )
def load_json(preset, config_file=CONFIG_FILE):
    """Load a JSON file from a preset.

    Args:
        preset: preset identifier (built-in, hf:// handle, or local path).
        config_file: name of the JSON file inside the preset.

    Returns:
        The parsed JSON contents.
    """
    local_path = get_file(preset, config_file)
    with open(local_path, encoding="utf-8") as fp:
        return json.load(fp)
def load_serialized_object(config, **kwargs):
    """Load a serialized Keras object from a config.

    Any remaining ``kwargs`` override entries of the serialized config.
    """
    # `dtype` in config might be a serialized `DTypePolicy` or
    # `DTypePolicyMap`; make sure it ends up configured correctly.
    requested_dtype = kwargs.pop("dtype", None)
    config = set_dtype_in_config(config, requested_dtype)
    config["config"] = {**config["config"], **kwargs}
    return zea.models.base.deserialize_zea_object(config)
def check_config_class(config):
    """Validate a preset is being loaded on the correct class.

    Returns:
        The zea-registered class named by the serialized config, or
        ``keras.Model`` for plain functional/sequential models.
    """
    keras_name = config["registered_name"]
    # Plain Keras functional/sequential models need no registry lookup.
    if keras_name in ("Functional", "Sequential"):
        return keras.Model
    cls = model_registry[keras_to_zea_registry(keras_name, model_registry)]
    if cls is not None:
        return cls
    raise ValueError(
        f"Attempting to load class {keras_name} with "
        "`from_preset()`, but there is no class registered with zea "
        f"for {keras_name}. Make sure to register any custom "
        "classes with `zea.registry.model_registry()`."
    )
def jax_memory_cleanup(layer):
    """Cleanup memory for JAX models.

    For jax, delete all previously allocated memory to avoid temporarily
    duplicating variable allocations. torch and tensorflow have stateful
    variable types and do not need this fix.
    """
    if keras.config.backend() != "jax":
        return
    for weight in layer.weights:
        value = getattr(weight, "_value", None)
        if value is not None:
            value.delete()
def set_dtype_in_config(config, dtype=None):
    """Set the `dtype` in a serialized Keras config.

    Args:
        config: serialized object config (``{"config": {...}, ...}``).
        dtype: dtype to forward; if ``None`` the config is returned as-is.

    Returns:
        The (shallow-copied) config with ``dtype`` applied.
    """
    if dtype is None:
        return config
    # Shallow copy: nested dicts (including any policy map) are shared and
    # mutated in place, matching the original semantics.
    config = config.copy()
    inner = config["config"]
    if "dtype" not in inner:
        # No dtype serialized at all: forward the requested one.
        inner["dtype"] = dtype
        return config
    serialized = inner["dtype"]
    if isinstance(serialized, dict) and "DTypePolicyMap" in serialized["class_name"]:
        # A `DTypePolicyMap`: forward `dtype` as its default policy and as
        # the source dtype of every mapped policy.
        policy_map_config = serialized["config"]
        policy_map_config["default_policy"] = dtype
        for policy in policy_map_config["policy_map"].values():
            policy["config"]["source_name"] = dtype
    return config
def check_file_exists(preset, path):
    """Check if a file exists in a preset.

    Only `FileNotFoundError` means "absent"; other errors (e.g. a bad
    preset identifier) still propagate.
    """
    try:
        get_file(preset, path)
        return True
    except FileNotFoundError:
        return False
def _assert_file_exists(preset, path):
    """Raise `ValueError` if `path` is missing from `preset`.

    Converts the `FileNotFoundError` from `get_file` into a `ValueError`
    with a more actionable message.
    """
    try:
        get_file(preset, path)
    except FileNotFoundError as e:
        # BUG FIX: the original message contained a duplicated "and"
        # ("preset and and that you have permissions").
        raise ValueError(
            f"Preset {preset} has no {path}. Make sure the URL or "
            "directory you are trying to load is a valid KerasHub preset "
            "and that you have permissions to read/download from this location."
        ) from e
def keras_to_zea_registry(keras_name, zea_registry):
    """Convert a Keras class name to a zea registry name.

    The zea registry maps registry names -> classes; this inverts that
    lookup by matching on each class's ``__name__``.
    """
    match = next(
        (
            registry_name
            for registry_name, entry in zea_registry.registry.items()
            if entry.__name__ == keras_name
        ),
        None,
    )
    if match is not None:
        return match
    raise ValueError(
        f"Class {keras_name} not found in `zea` registry. "
        "Make sure to register any custom classes with `zea.registry.model_registry()`. "
        "Currently, the `zea` registry contains: "
        f"{zea_registry.registry.items()}"
    )
class PresetLoader:
    """Base class for loading a model from a preset."""

    def __init__(self, preset, config):
        """Initialize a preset loader.

        Args:
            preset: the preset identifier being loaded.
            config: the deserialized ``config.json`` contents.
        """
        self.config = config
        self.preset = preset

    def get_model_kwargs(self, **kwargs):
        """Extract model kwargs from the preset.

        Returns:
            tuple: ``(model_kwargs, remaining_kwargs)``.
        """
        # `dtype` is always forwarded to the model (possibly None);
        # `image_shape` only when the caller explicitly provided it.
        model_kwargs = {"dtype": kwargs.pop("dtype", None)}
        if "image_shape" in kwargs:
            model_kwargs["image_shape"] = kwargs.pop("image_shape", None)
        return model_kwargs, kwargs

    def load_model(self, cls, load_weights, **kwargs):
        """Load the backbone model from the preset."""
        raise NotImplementedError

    def load_preprocessor(self, cls, config_file=PREPROCESSOR_CONFIG_FILE, **kwargs):
        """Load a preprocessor layer from the preset."""
        # NOTE(review): `_add_missing_kwargs` is called unbound with the
        # loader instance as its first argument — presumably intentional;
        # confirm against the `cls` API.
        kwargs = cls._add_missing_kwargs(self, kwargs)
        return cls(**kwargs)
class KerasPresetLoader(PresetLoader):
    """Loader for Keras serialized presets."""

    def check_model_class(self):
        """Check the model class is correct for the preset.

        Returns the zea-registered class named in ``self.config``.
        """
        return check_config_class(self.config)

    def load_model(self, cls, load_weights, **kwargs):
        """Load a model from a serialized Keras config.

        Args:
            cls: unused here; the class is derived from the serialized config.
            load_weights: if False, return the architecture without weights.
            **kwargs: merged into the serialized config before
                deserialization (e.g. ``dtype``).
        """
        model = load_serialized_object(self.config, **kwargs)
        if not load_weights:
            return model
        jax_memory_cleanup(model)
        # if model has a custom load_weights method, call it and skip the
        # standard h5 weight loading below
        if hasattr(model, "custom_load_weights"):
            model.custom_load_weights(self.preset)
            return model
        # try to build with image_shape or input_shape if not built yet ->
        # but preferred way to build is to have a build_config in the json!
        if not model.built:
            if hasattr(model, "image_shape"):
                model.build(input_shape=model.image_shape)
            elif hasattr(model, "input_shape"):
                model.build(input_shape=model.input_shape)
            else:
                raise ValueError(
                    "Model could not be built. Make sure to add a build_config to the json "
                    "or set the input_shape or image_shape attribute before loading weights."
                )
        model.load_weights(get_file(self.preset, MODEL_WEIGHTS_FILE))
        return model

    def load_image_converter(self, cls, **kwargs):
        """Load an image converter from the preset."""
        converter_config = load_json(self.preset, IMAGE_CONVERTER_CONFIG_FILE)
        return load_serialized_object(converter_config, **kwargs)

    def get_file(self, path):
        """Get a file from the preset (downloading it if necessary)."""
        return get_file(self.preset, path)

    def load_preprocessor(self, cls, config_file=PREPROCESSOR_CONFIG_FILE, **kwargs):
        """Load a preprocessor from the preset."""
        # If there is no `preprocessing.json` or it's for the wrong class,
        # delegate to the super class loader.
        if not check_file_exists(self.preset, config_file):
            return super().load_preprocessor(cls, **kwargs)
        preprocessor_json = load_json(self.preset, config_file)
        if not issubclass(check_config_class(preprocessor_json), cls):
            return super().load_preprocessor(cls, **kwargs)
        # We found a `preprocessing.json` with a complete config for our class.
        preprocessor = load_serialized_object(preprocessor_json, **kwargs)
        if hasattr(preprocessor, "load_preset_assets"):
            preprocessor.load_preset_assets(self.preset)
        return preprocessor
class KerasPresetSaver:
    """Saver for Keras serialized presets."""

    def __init__(self, preset_dir):
        """Initialize a preset saver.

        Args:
            preset_dir: directory to write preset files into (created if
                missing).
        """
        os.makedirs(preset_dir, exist_ok=True)
        self.preset_dir = preset_dir

    def save_model(self, model):
        """Save a model to a preset: config, weights and metadata."""
        self._save_serialized_object(model, config_file=CONFIG_FILE)
        model_weight_path = os.path.join(self.preset_dir, MODEL_WEIGHTS_FILE)
        model.save_weights(model_weight_path)
        self._save_metadata(model)

    def save_image_converter(self, converter):
        """Save an image converter to a preset."""
        self._save_serialized_object(converter, IMAGE_CONVERTER_CONFIG_FILE)

    def save_preprocessor(self, preprocessor):
        """Save a preprocessor to a preset."""
        config_file = PREPROCESSOR_CONFIG_FILE
        if hasattr(preprocessor, "config_file"):
            config_file = preprocessor.config_file
        self._save_serialized_object(preprocessor, config_file)
        # Let nested layers persist their own assets (tokenizers etc.).
        for layer in preprocessor._flatten_layers(include_self=False):
            if hasattr(layer, "save_to_preset"):
                layer.save_to_preset(self.preset_dir)

    def _recursive_pop(self, config, key):
        """Remove a key from a nested config object (in place)."""
        config.pop(key, None)
        for value in config.values():
            if isinstance(value, dict):
                self._recursive_pop(value, key)

    def _save_serialized_object(self, layer, config_file):
        """Serialize `layer` and write it as JSON to `config_file`."""
        config_path = os.path.join(self.preset_dir, config_file)
        config = keras.saving.serialize_keras_object(layer)
        # compile/build configs are runtime state, not part of the preset.
        config_to_skip = ["compile_config", "build_config"]
        for key in config_to_skip:
            self._recursive_pop(config, key)
        # FIX: the file handle used to shadow the `config_file` parameter;
        # renamed to `f` for clarity.
        with open(config_path, "w", encoding="utf-8") as f:
            f.write(json.dumps(config, indent=4))

    def _save_metadata(self, layer):
        """Write metadata.json with versions, parameter count and date."""
        zea_version = zea.__version__
        keras_version = keras.version() if hasattr(keras, "version") else None
        metadata = {
            "keras_version": keras_version,
            "parameter_count": layer.count_params(),
            "zea_version": zea_version,
            "date_saved": datetime.datetime.now().strftime("%Y-%m-%d@%H:%M:%S"),
        }
        metadata_path = os.path.join(self.preset_dir, METADATA_FILE)
        with open(metadata_path, "w", encoding="utf-8") as f:
            f.write(json.dumps(metadata, indent=4))
def get_preset_saver(preset):
    """Get a preset saver.

    Only one form of saving is supported: Keras serialized configs plus
    saved weights.
    """
    return KerasPresetSaver(preset)
def get_preset_loader(preset):
    """Get a preset loader.

    All supported formats (Keras, Transformers, timm) ship a `config.json`;
    the on-disk format is inferred from its contents.
    """
    _assert_file_exists(preset, CONFIG_FILE)
    config = load_json(preset, CONFIG_FILE)
    if "registered_name" not in config:
        contents = json.dumps(config, indent=4)
        raise ValueError(
            f"Unrecognized format for {CONFIG_FILE} in {preset}. "
            "Create a preset with the `save_to_preset` utility on KerasHub "
            f"models. Contents of {CONFIG_FILE}:\n{contents}"
        )
    # `registered_name` implies a serialized Keras object.
    return KerasPresetLoader(preset, config)