Source code for zea.config

"""Config module for managing configuration settings.

This module provides the :class:`Config` class for managing configuration settings,
with support for loading from YAML files or the Hugging Face Hub, and for dot notation access.

Features
--------

- Dot notation access to dictionary keys.
- Recursive conversion of nested dictionaries/lists to Config objects.
- Attribute access logging and suggestion of similar attribute names.
- Freezing/unfreezing to prevent/allow new attributes.
- Serialization to YAML/JSON.
- Integration with Hugging Face Hub.

Example Usage
^^^^^^^^^^^^^

.. code-block:: python

    from zea import Config

    # Load from YAML
    config = Config.from_yaml("config.yaml")
    # Load from HuggingFace Hub
    config = Config.from_hf("zea/diffusion-echonet-dynamic", "train_config.yaml")

    # Access attributes with dot notation
    print(config.model.name)

    # Update recursively
    config.update_recursive({"model": {"name": "new_model"}})

    # Save to YAML
    config.save_to_yaml("new_config.yaml")

"""

import copy
import difflib
import inspect
import json
from pathlib import Path
from typing import Union

import yaml
from huggingface_hub import hf_hub_download

from zea import log
from zea.data.preset_utils import HF_PREFIX, _hf_resolve_path
from zea.internal.config.validation import config_schema
from zea.internal.core import dict_to_tensor


[docs] class Config(dict): """Config class. This Config class extends a normal dictionary with dot notation access. Features: - `Config.from_yaml` method to load a config from a yaml file. - `Config.from_hf` method to load a config from a huggingface hub. - `save_to_yaml` method to save the config to a yaml file. - `copy` method to create a deep copy of the config. - Normal dictionary methods such as `keys`, `values`, `items`, `pop`, `update`, `get`. - Propose similar attribute names if a non-existing attribute is accessed. - Freeze the config object to prevent new attributes from being added. - Load config object from yaml file. - Logs all accessed attributes such that you can check if all attributes have been accessed. We took inspiration from the following sources: - `EasyDict <https://pypi.org/project/easydict/>`_ - `keras.utils.Config <https://keras.io/api/utils/experiment_management_utils/#config-class>`_ But this implementation is superior :) """ # noqa: E501 __frozen__ = False def __init__(self, dictionary=None, __parent__=None, **kwargs): """ Initializes a Config object. Args: dictionary (dict, optional): A dictionary containing key-value pairs to initialize the Config object. Defaults to None. **kwargs: Additional key-value pairs to initialize the Config object. Will override values in the dictionary if they have the same key. """ # Get all methods of the Config class and store them in a list as protected attributes super().__setattr__( "__protected__", [x[0] for x in inspect.getmembers(Config, predicate=inspect.isroutine)] + ["__protected__", "__accessed__", "__parent__"], ) super().__setattr__("__accessed__", {}) super().__setattr__("__parent__", __parent__) if dictionary is None: dictionary = {} if kwargs: dictionary.update(**kwargs) for k, v in dictionary.items(): setattr(self, k, v)
[docs] def clear(self): """Clear the config object.""" super().clear() self._reset_accessed()
[docs] def fromkeys(self, keys, value=None): """Returns a config with the specified keys and value""" super().fromkeys(keys, value)
[docs] def get(self, key, default=None): """Returns the value of the specified key""" self._mark_accessed(key) return super().get(key, default)
[docs] def items(self): """Returns a list containing a tuple for each key value pair""" # Use a generator that calls __getitem__ for every key return [(key, self[key]) for key in self.keys()]
[docs] def keys(self): """Returns a list containing the config's keys""" return super().keys()
[docs] def pop(self, key, default=None): """Removes the element with the specified key""" self._mark_accessed(key) return super().pop(key, default)
[docs] def popitem(self): """Removes the last inserted key-value pair""" key, value = super().popitem() self._mark_accessed(key) return key, value
[docs] def setdefault(self, key, default=None): """Returns the value of the specified key. If the key does not exist: insert the key, with the specified value""" # Use __getitem__ to get values and __setitem__ to set values if key not in self: self[key] = default return self[key]
[docs] def update(self, dictionary: dict | None = None, **kwargs): """Updates the config with the specified key-value pairs""" # Use __setitem__ to set values if dictionary is None: dictionary = {} dictionary.update(kwargs) for key, value in dictionary.items(): self[key] = value
[docs] def update_recursive(self, dictionary: dict | None = None, **kwargs): """Recursively update the config with the provided dictionary and keyword arguments. If a key corresponds to another Config object, the update_recursive method is called recursively on that object. This makes it possible to update nested Config objects without replacing them. If a value is a list and the corresponding config value is also a list, each element is updated recursively if it is a Config, otherwise replaced. Example: .. code-block:: python config = Config({"a": 1, "b": {"c": 2, "d": 3}}) config.update_recursive({"a": 4, "b": {"c": 5}}) print(config) # <Config {'a': 4, 'b': {'c': 5, 'd': 3}}> # Notice how "d" is kept and only "c" is updated. Args: dictionary (dict, optional): Dictionary to update from. **kwargs: Additional key-value pairs to update. """ if dictionary is None: dictionary = {} dictionary.update(kwargs) for key, value in dictionary.items(): if key in self and isinstance(self[key], Config): self[key].update_recursive(value) elif key in self and isinstance(value, list): for i, v in enumerate(value): if isinstance(v, Config): self[key][i].update_recursive(v) else: self[key][i] = v else: self[key] = value
[docs] def values(self): """Returns a list of all the values in the config""" # Use __getitem__ to get values return (self[key] for key in self.keys())
def __or__(self, other): """ `self | other` operator. Returns a new config object with the contents of both configs. """ return Config(super().__or__(other)) def __ror__(self, other): """ `other | self` operator. Returns a new config object with the contents of both configs. """ return Config(super().__ror__(other)) def __ior__(self, other): """ `|=` operator. Updates the config with the contents of the other config. """ self.update(other) return self def __iter__(self): """Returns an iterator that iterates through the keys of the config""" # Overwritten to ensure iteration respects our logic return iter(self.keys()) def __contains__(self, key): """Returns True if the specified key exists in the config""" return super().__contains__(key) def __setattr__(self, name, value): # Check if attribute is a method of the Config class, this cannot be overridden if hasattr(self, "__protected__") and name in self.__protected__: raise AttributeError(f"Cannot set attribute `{name}`. It is used by the Config class.") # Check if config is frozen if self.__frozen__ and not hasattr(self, name): raise TypeError( f"Config is a frozen, no new attributes can be added. Tried to add: `{name}`" ) # If overriding an existing attribute, mark it as unaccessed self._mark_unaccessed(name) # Convert tuple to list to allow for item assignment if isinstance(value, tuple): value = list(value) # Ensures lists and tuples of dictionaries are converted to Config objects as well if isinstance(value, list): value = [ self.__class__(x, __parent__=self) if isinstance(x, dict) else x for x in value ] # Ensures dictionaries are converted to Config objects as well elif isinstance(value, dict): value = self.__class__(value, __parent__=self) super().__setitem__(name, value) def __setitem__(self, key, value): return self.__setattr__(key, value) def _unknown_attr(self, name): msg = f"Unknown attribute: '{name}'." 
if "difflib" in globals(): closest_matches = difflib.get_close_matches(name, self.keys(), n=1, cutoff=0.7) if closest_matches: msg += f" Did you mean '{closest_matches[0]}'?" return msg def _reset_accessed(self): """Reset accessed attributes.""" self._recursive_setattr("__accessed__", {}) def _mark_accessed(self, name): """Mark an attribute as accessed.""" if name in self: self.__accessed__[name] = True def _mark_unaccessed(self, name): """Mark an attribute as unaccessed.""" if name in self.__accessed__: del self.__accessed__[name] def _mark_accessed_recursive(self): """Mark an attribute and all its children as accessed.""" def mark_accessed(config, key, value): config._mark_accessed(key) return key, value self.as_dict(mark_accessed) def _dict_items(self): """Return the items of the config object. Only used for internal purposes.""" return super().items() def _trace_through_ancestors(self, key_trace=None): """Find the root ancestor of the config object.""" if key_trace is None: key_trace = [] if self.__parent__ is None: return self, key_trace for key, value in self.__parent__._dict_items(): if isinstance(value, list): for i, v in enumerate(value): if v == self: return self.__parent__._trace_through_ancestors([key + f"_{i}"] + key_trace) if value == self: return self.__parent__._trace_through_ancestors([key] + key_trace) raise ValueError("Parent not found in ancestors. Report to zea developers.") @staticmethod def _assert_key_accessed(config, key, value, _assert=True): """Assert that a key has been accessed.""" if key not in config.__accessed__: key_trace = config._trace_through_ancestors()[1] msg = f"Attribute '{key}'='{value}' has not been accessed." if key_trace: msg += f" Has ancestors through '{key_trace}'" if _assert: raise AssertionError(msg) log.warning(msg) return key, value def _assert_all_accessed(self): """Assert that all attributes have been accessed.""" # Temporary remove parent to avoid recursion if not being called from ancestor. 
self._all_unaccessed(_assert=True) def _log_all_unaccessed(self): """Log all unaccessed attributes.""" self._all_unaccessed(_assert=False) def _all_unaccessed(self, _assert=False): """Assert or log all unaccessed attributes.""" # Temporary remove parent to avoid recursion if not being called from ancestor. parent = self.__parent__ super().__setattr__("__parent__", None) self.as_dict(lambda *args: self._assert_key_accessed(*args, _assert=_assert)) super().__setattr__("__parent__", parent) def __getattr__(self, name): if name in self: self._mark_accessed(name) return super().__getitem__(name) msg = self._unknown_attr(name) raise AttributeError(msg) def __getitem__(self, key): if key in self: self._mark_accessed(key) return super().__getitem__(key) msg = self._unknown_attr(key) raise KeyError(msg) def __delattr__(self, name): del self[name] def __repr__(self): return f"<Config {self.as_dict()}>"
[docs] def to_json(self): """Return the config as a json string.""" return json.dumps(self)
[docs] def as_dict(self, func_on_leaves=None): """Convert the config to a normal dictionary (recursively). Args: func_on_leaves (callable, optional): Function to apply to each leaf node. The function should take three arguments: the config object, the key, and the value. You can change the key and value inside the function. Defaults to None. """ dictionary = {} for key, value in self._dict_items(): if isinstance(value, Config): value = value.as_dict(func_on_leaves) elif isinstance(value, list): value = [v.as_dict(func_on_leaves) if isinstance(v, Config) else v for v in value] # a dict does not exist inside a Config object, because it is a Config object itself if func_on_leaves: key, value = func_on_leaves(self, key, value) dictionary[key] = value return dictionary
[docs] def serialize(self): """Return a dict of this config object with all Path objects converted to strings.""" return self.as_dict(lambda _, key, value: (key, _path_to_str(value)))
[docs] def copy(self): """Deep copy the config object. This is useful when you want to modify the config object without changing the original. Does not preserve the access history or frozen state! """ return Config(copy.deepcopy(self.as_dict()))
[docs] def save_to_yaml(self, path): """Save config contents to yaml""" with open(Path(path), "w", encoding="utf-8") as save_file: yaml.dump( self.serialize(), save_file, default_flow_style=False, sort_keys=False, )
[docs] def freeze(self): """Freeze config object. This means that no new attributes can be added. Only existing attributes can be modified. """ self._recursive_setattr("__frozen__", True)
[docs] def unfreeze(self): """Unfreeze config object. This means that new attributes can be added.""" self._recursive_setattr("__frozen__", False)
def _recursive_setattr(self, set_key, set_value): """Helper function to recursively set an attribute on all nested configs.""" super().__setattr__(set_key, set_value) for _, value in self._dict_items(): if isinstance(value, Config): value._recursive_setattr(set_key, set_value) elif isinstance(value, list): for v in value: if isinstance(v, Config): v._recursive_setattr(set_key, set_value)
[docs] @classmethod def from_path(cls, path, **kwargs): """Load config object from a file path. Args: path (str or Path): The path to the config file. Can be a string or a Path object. Additionally can be a string with the prefix 'hf://', in which case it will be resolved to a huggingface path. Returns: Config: config object. """ if str(path).startswith(HF_PREFIX): path = _hf_resolve_path(str(path)) if isinstance(path, str): path = Path(path) return _load_config_from_yaml(path, config_class=cls, **kwargs)
[docs] @classmethod def from_hf(cls, repo_id, path, **kwargs): """Load config object from huggingface hub. Example: .. code-block:: python config = Config.from_hf("zeahub/configs", "config_camus.yaml", repo_type="dataset") Args: repo_id (str): huggingface hub repo id. For example: "zeahub/configs" path (str): path to the config file in the repo. For example: "train_config.yaml" **kwargs: additional arguments to pass to the `hf_hub_download` function. For example, use repo_type="dataset" to download from a dataset repo, or revision="main" to download from a specific branch. Returns: Config: config object. """ local_path = hf_hub_download(repo_id, path, **kwargs) return _load_config_from_yaml(local_path, config_class=cls)
[docs] @classmethod def from_yaml(cls, path, **kwargs): """Load config object from yaml file.""" return cls.from_path(path, **kwargs)
[docs] def to_tensor(self, keep_as_is=None): """Convert the attributes in the object to keras tensors""" return dict_to_tensor(self.serialize(), keep_as_is=keep_as_is)
def check_config(config: Union[dict, Config], verbose: bool = False):
    """Validate a config against `config_schema`.

    Args:
        config (dict or Config): The config to validate. Subclasses are accepted.
        verbose (bool, optional): If True, log a success message after validation.
            Defaults to False.

    Returns:
        The validated config. When a Config is passed, the validation result is
        wrapped in a new, frozen Config; when a plain dictionary is passed, the
        result of `config_schema.validate` is returned as is.

    Raises:
        AssertionError: If `config` is neither a dictionary nor a Config.
        Exception: Any error raised by `config_schema.validate` is logged and re-raised.
    """

    def _try_validate_config(config):
        try:
            return config_schema.validate(config)
        except Exception as e:
            log.error(f"Config is not valid: {e}")
            raise

    # Raised explicitly so the check also runs under `python -O`. The exception
    # type is kept as AssertionError for callers that already catch it. Config is
    # a dict subclass, so this single isinstance check covers both accepted types.
    if not isinstance(config, dict):
        raise AssertionError(
            f"Config must be a dictionary or Config object, not {type(config)}"
        )

    if isinstance(config, Config):
        config = config.serialize()
        config = _try_validate_config(config)
        config = Config(config)
        config.freeze()  # freeze because schema will add all defaults
    else:
        config = _try_validate_config(config)

    if verbose:
        log.success("Config is correct")
    return config
def _load_config_from_yaml(path, config_class=Config, loader=yaml.FullLoader):
    """Read a yaml file and wrap its contents in a config object.

    Args:
        path (str): Location of the yaml file.
        config_class (type, optional): Class used to build the result.
            Defaults to Config.
        loader (yaml.Loader, optional): Loader handed to `yaml.load`.
            Defaults to yaml.FullLoader. For custom objects, you might want to
            use yaml.UnsafeLoader (only on files you trust).

    Returns:
        Config: An instance of `config_class`; empty when the file yields no content.
    """
    yaml_path = Path(path)
    with yaml_path.open("r", encoding="utf-8") as stream:
        contents = yaml.load(stream, Loader=loader)

    # An empty file parses to a falsy value; build an empty config in that case.
    if not contents:
        return config_class()
    return config_class(contents)


def _path_to_str(path):
    """Return `path.as_posix()` for path-like objects, any other value unchanged."""
    # Anything exposing `as_posix` is treated as a Path object.
    if not hasattr(path, "as_posix"):
        return path
    return path.as_posix()