"""Backend-specific utilities.This subpackage provides backend-specific utilities for the ``zea`` library. Most backend logic is handled by Keras 3, but a few features require custom wrappers to ensure compatibility and performance across JAX, TensorFlow, and PyTorch... note:: Most backend-specific logic is handled by Keras 3, so this subpackage is intentionally minimal. Only features not natively supported by Keras (such as JIT and autograd) are implemented here.Key Features------------- **JIT Compilation** (:func:`zea.backend.jit`): Provides a unified interface for just-in-time (JIT) compilation of functions, dispatching to the appropriate backend (JAX or TensorFlow) as needed. This enables accelerated execution of computationally intensive routines. Note that jit compilation is not yet supported when using the `torch` backend.- **Automatic Differentiation** (:class:`zea.backend.AutoGrad`): Offers a backend-agnostic wrapper for automatic differentiation, allowing gradient computation regardless of the underlying ML library.- **Backend Submodules:** - :mod:`zea.backend.jax` -- JAX-specific utilities and device management. - :mod:`zea.backend.torch` -- PyTorch-specific utilities and device management. - :mod:`zea.backend.tensorflow` -- TensorFlow-specific utilities, and device management, as well as data loading utilities.- **Data Loading** (:func:`zea.backend.tensorflow.make_dataloader`): This function is implemented using TensorFlow's efficient data pipeline utilities. It provides a convenient way to load and preprocess data for machine learning workflows, leveraging TensorFlow's ``tf.data.Dataset`` API."""importkerasfromzeaimportlogdef_import_tf():try:importtensorflowastfreturntfexceptImportError:returnNonedef_import_jax():try:importjaxreturnjaxexceptImportError:returnNonedef_import_torch():try:importtorchreturntorchexceptImportError:returnNonetf_mod=_import_tf()jax_mod=_import_jax()
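

def tf_function(func=None, jit_compile=False, **kwargs):
    """Wrap ``func`` with ``tf.function`` when the TensorFlow backend is active.

    XLA compilation is off by default (``jit_compile=False``). With the JAX backend,
    ``func`` is returned unchanged; with other backends it is returned uncompiled and a
    warning is logged. Can be used directly, or as a decorator with or without arguments.
    """
    return jit(func, jax=False, jit_compile=jit_compile, **kwargs)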
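

# Usage sketch (illustrative only; assumes the TensorFlow backend is active, e.g. via
# KERAS_BACKEND="tensorflow", and that ``tensorflow`` is imported as ``tf``).
# ``tf_function`` works as a bare decorator or with arguments; extra keyword arguments
# such as ``reduce_retracing`` are forwarded unchanged to ``tf.function``:
#
#     @tf_function  # -> tf.function(normalize, jit_compile=False)
#     def normalize(x):
#         return x / tf.reduce_max(tf.abs(x))
#
#     @tf_function(jit_compile=True, reduce_retracing=True)  # XLA-compiled
#     def envelope(x):
#         return tf.abs(x)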
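

def jit(func=None, jax=True, tensorflow=True, **kwargs):
    """Apply JIT compilation to ``func`` based on the current Keras backend.

    Dispatches to ``tf.function`` (TensorFlow backend) or ``jax.jit`` (JAX backend).
    Can be used directly, as a bare decorator, or as a decorator with arguments.

    Args:
        func (callable, optional): The function to compile. If None, a decorator is
            returned instead.
        jax (bool): Whether to enable JIT compilation in the JAX backend.
        tensorflow (bool): Whether to enable JIT compilation in the TensorFlow backend.
        **kwargs: Keyword arguments forwarded to the backend JIT compiler
            (``tf.function`` or ``jax.jit``). Under TensorFlow, ``jit_compile``
            defaults to True.

    Returns:
        callable: The compiled function, or ``func`` unchanged if compilation is disabled
        for, or not supported by, the active backend. If ``func`` is None, a decorator
        that applies the same compilation.
    """
    if func is None:

        def decorator(func):
            return _jit_compile(func, jax=jax, tensorflow=tensorflow, **kwargs)

        return decorator
    else:
        return _jit_compile(func, jax=jax, tensorflow=tensorflow, **kwargs)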
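

# Usage sketch (illustrative only; assumes the JAX backend is active, e.g. via
# KERAS_BACKEND="jax", and that ``zea`` and ``jax`` are installed). ``jit`` can wrap a
# function directly or act as a decorator. Keyword arguments go straight to the backend
# compiler, so ``static_argnums`` below is valid only under JAX (``jax.jit``); under
# TensorFlow, ``tf.function`` would reject it.
#
#     import keras
#     from zea.backend import jit
#
#     def scale(x):
#         return 2.0 * x
#
#     scale_jit = jit(scale)  # -> jax.jit(scale)
#
#     @jit(static_argnums=(1,))  # -> jax.jit(power, static_argnums=(1,))
#     def power(x, n):
#         return x**n
#
#     power(keras.ops.ones((4,)), 3)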


def _jit_compile(func, jax=True, tensorflow=True, **kwargs):
    backend = keras.backend.backend()

    if backend == "tensorflow" and tensorflow:
        if tf_mod is None:
            raise ImportError("TensorFlow is not installed. Please install it to use this backend.")
        # XLA compilation is on by default unless the caller overrides it.
        jit_compile = kwargs.pop("jit_compile", True)
        return tf_mod.function(func, jit_compile=jit_compile, **kwargs)
    elif backend == "jax" and jax:
        if jax_mod is None:
            raise ImportError("JAX is not installed. Please install it to use this backend.")
        return jax_mod.jit(func, **kwargs)
    elif backend == "tensorflow" and not tensorflow:
        # Compilation explicitly disabled for this backend: return the function as-is.
        return func
    elif backend == "jax" and not jax:
        return func
    else:
        # Any other backend (e.g. torch): warn and fall back to eager execution.
        log.warning(
            f"JIT compilation is not currently supported for backend {backend}. "
            "Supported backends are 'tensorflow' and 'jax'."
        )
        log.warning("Initialize zea.Pipeline with jit_options=None to suppress this warning.")
        log.warning("Falling back to non-compiled mode.")
        return func