zea.backend

Backend-specific utilities.

This subpackage provides backend-specific utilities for the zea library. Most backend logic is handled by Keras 3, but a few features require custom wrappers to ensure compatibility and performance across JAX, TensorFlow, and PyTorch.

Note

This subpackage is intentionally minimal. It implements only the features that Keras 3 does not support natively, such as JIT compilation and automatic differentiation.

Key Features

  • JIT Compilation (zea.backend.jit()): Provides a unified interface for just-in-time (JIT) compilation of functions, dispatching to the appropriate backend (JAX or TensorFlow) as needed. This enables accelerated execution of computationally intensive routines. Note that JIT compilation is not yet supported with the PyTorch backend.

  • Automatic Differentiation (zea.backend.AutoGrad): Offers a backend-agnostic wrapper for automatic differentiation, allowing gradient computation regardless of the underlying ML library.

  • Backend Submodules:

      • Data Loading (zea.backend.tensorflow.make_dataloader()): Built on TensorFlow’s tf.data.Dataset API, this function provides a convenient way to load and preprocess data for machine learning workflows.

Functions

jit([func, jax, tensorflow])

Applies JIT compilation to the given function based on the current Keras backend.

tf_function([func, jit_compile])

Applies default tf.function to the given function.

zea.backend.jit(func=None, jax=True, tensorflow=True, **kwargs)[source]

Applies JIT compilation to the given function based on the current Keras backend. Can be used as a decorator or as a function.

Parameters:
  • func (callable) – The function to be JIT compiled.

  • jax (bool) – Whether to enable JIT compilation in the JAX backend.

  • tensorflow (bool) – Whether to enable JIT compilation in the TensorFlow backend.

  • **kwargs – Keyword arguments to be passed to the JIT compiler.

Returns:

The JIT-compiled function.

Return type:

callable
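
The decorator-or-function behavior and the per-backend flags described above can be illustrated with a minimal sketch. This is not zea's implementation: `jit_sketch`, `fake_compile`, and the `BACKEND` variable are hypothetical stand-ins (for `zea.backend.jit`, a real compiler such as `jax.jit`, and the active Keras backend, respectively). Returning the function unchanged when a backend is disabled or unsupported is also an assumption.

```python
import functools

# Hypothetical stand-in for the active Keras backend name.
BACKEND = "jax"


def fake_compile(func, **kwargs):
    """Stand-in for a real JIT compiler; it only tags the wrapped function."""

    @functools.wraps(func)
    def compiled(*args, **kw):
        return func(*args, **kw)

    compiled.is_compiled = True
    return compiled


def jit_sketch(func=None, jax=True, tensorflow=True, **kwargs):
    """Illustrative dispatcher with the same call shape as the documented jit()."""
    # Used as @jit_sketch(...) with arguments: return a decorator.
    if func is None:
        return lambda f: jit_sketch(f, jax=jax, tensorflow=tensorflow, **kwargs)

    # Look up whether compilation is enabled for the active backend.
    # Backends without an entry (e.g. "torch") fall through to False.
    enabled = {"jax": jax, "tensorflow": tensorflow}.get(BACKEND, False)
    if not enabled:
        return func  # assumption: pass the function through untouched
    return fake_compile(func, **kwargs)


@jit_sketch  # bare decorator form
def add(a, b):
    return a + b


@jit_sketch(jax=False)  # decorator form with JAX compilation switched off
def mul(a, b):
    return a * b


def sub(a, b):
    return a - b


sub_compiled = jit_sketch(sub)  # plain function-call form
```

In this sketch `add` and `sub_compiled` carry the `is_compiled` tag, while `mul` is returned as the original function because its backend flag is off.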

zea.backend.tf_function(func=None, jit_compile=False, **kwargs)[source]

Applies default tf.function to the given function. Only applies when the TensorFlow backend is in use.
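
A wrapper of this kind might look roughly like the sketch below. It is an assumption-laden illustration, not the library's code: `tf_function_sketch` is a hypothetical name, the explicit `backend` argument replaces whatever lookup the real function performs, and passing the function through unchanged on other backends is a guess at what "only in TensorFlow" means in practice.

```python
def tf_function_sketch(func=None, jit_compile=False, backend="tensorflow", **kwargs):
    """Hypothetical stand-in for a tf.function wrapper (not zea's code)."""
    # Used as @tf_function_sketch(...) with arguments: return a decorator.
    if func is None:
        return lambda f: tf_function_sketch(
            f, jit_compile=jit_compile, backend=backend, **kwargs
        )

    # Assumption: on any other backend the function is returned as-is.
    if backend != "tensorflow":
        return func

    # Import lazily so the sketch runs without TensorFlow installed
    # whenever another backend is selected.
    import tensorflow as tf

    return tf.function(func, jit_compile=jit_compile, **kwargs)


@tf_function_sketch(backend="jax")  # not TensorFlow, so no wrapping occurs
def square(x):
    return x * x
```

With `backend="tensorflow"` and TensorFlow installed, the same decorator would return a `tf.function`-wrapped callable, with `jit_compile` forwarded to it.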

Modules

autograd

Autograd wrapper for different backends.

jax

JAX utilities for zea.

tensorflow

TensorFlow Ultrasound Beamforming Library.

torch

PyTorch Ultrasound Beamforming Library.