zea.backend.jax

JAX utilities for zea.

Functions

on_device_jax(func, inputs, device[, ...])

Applies a JAX function to inputs on a specified device.

zea.backend.jax.on_device_jax(func, inputs, device, return_numpy=False, **kwargs)

Applies a JAX function to inputs on a specified device.

Parameters:
  • func (callable) – The function to apply.

  • inputs (ndarray) – Input array.

  • device (str) – Device string, e.g. 'cuda', 'gpu', or 'cpu'.

  • return_numpy (bool, optional) – Whether to convert output data back to numpy. Defaults to False.

  • **kwargs – Additional keyword arguments to be passed to the func.
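The device parameter is a plain string. The sketch below shows one way such a string could be mapped to a JAX device. It assumes that 'cuda' is an alias for JAX's 'gpu' platform and that an optional ':index' suffix selects a specific device. resolve_device is a hypothetical helper for illustration, not part of zea.

```python
import jax


def resolve_device(device: str):
    """Map a string such as 'cpu', 'gpu', 'cuda' or 'gpu:1' to a JAX device.

    Hypothetical helper: treating 'cuda' as an alias for 'gpu' and the
    ':index' suffix are assumptions, not documented zea behaviour.
    """
    name, _, index = device.lower().partition(":")
    if name == "cuda":
        name = "gpu"  # assumed alias; "gpu" is JAX's generic GPU platform name
    # jax.devices raises RuntimeError if no backend exists for this platform
    devices = jax.devices(name)
    return devices[int(index) if index else 0]


cpu = resolve_device("cpu")
print(cpu.platform)  # cpu
```

On a machine without a GPU, only 'cpu' resolves; 'gpu' and 'cuda' raise an error from jax.devices.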

Returns:

The output of func. If func returns a dictionary, only its first value is returned.

Return type:

jax.Array (jax.numpy.DeviceArray in older JAX versions) or numpy.ndarray

Raises:

AssertionError – If func is not a function from the JAX library.

Note

This function converts inputs to a JAX array, moves it to the specified device, and then calls func on it, passing along any additional keyword arguments. If the output is a dictionary, only its first value is kept. If return_numpy is True, the output is converted back to a NumPy array before being returned.
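The steps in the note can be sketched as a short function. This is an illustrative re-implementation under stated assumptions, not zea's actual code: on_device_sketch is a hypothetical name, device is assumed to be a JAX platform name such as 'cpu' or 'gpu', and the first device of that platform is used.

```python
import jax
import jax.numpy as jnp
import numpy as np


def on_device_sketch(func, inputs, device="cpu", return_numpy=False, **kwargs):
    """Hypothetical sketch of the behaviour described in the note (not zea's code)."""
    target = jax.devices(device)[0]  # assumption: first device of the platform
    x = jax.device_put(jnp.asarray(inputs), target)  # convert and move the input
    out = func(x, **kwargs)  # apply func, forwarding extra keyword arguments
    if isinstance(out, dict):
        out = next(iter(out.values()))  # keep only the first value
    if return_numpy:
        out = np.asarray(out)  # copy back to host as a numpy.ndarray
    return out


y = on_device_sketch(jnp.square, [1.0, 2.0, 3.0], "cpu", return_numpy=True)
print(type(y).__name__, y)  # ndarray [1. 4. 9.]
```

Because Python dictionaries preserve insertion order, "first value" here means the value of the first key that func inserted.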

Example

import numpy as np

from zea.backend.jax import on_device_jax


def square(x):
    return x**2


inputs = np.array([1, 2, 3, 4, 5])
device = "gpu"  # use "cpu" if no GPU is available
output = on_device_jax(square, inputs, device)