zea.tensor_ops

Basic tensor operations implemented with the multi-backend keras.ops.

Functions

L2(x)

L2 norm of a tensor.

add_salt_and_pepper_noise(image, salt_prob)

Adds salt and pepper noise to the input image.

apply_along_axis(func1d, axis, arr, *args, ...)

Apply a function to 1D array slices along an axis.

batch_cov(x[, rowvar, bias, ddof])

Compute the batch covariance matrices of the input tensor.

batched_map(f, xs[, batch_size, jit])

Map a function over leading array axes.

boolean_mask(tensor, mask[, size])

Apply a boolean mask to a tensor.

check_patches_fit(image_shape, patch_shape, ...)

Checks if patches with overlap fit an integer amount in the original image.

compute_required_patch_overlap(image_shape, ...)

Compute required overlap between patches to cover the entire image.

compute_required_patch_shape(image_shape, ...)

Compute required patch shape to cover the entire image.

correlate(x, y[, mode])

Complex correlation via splitting real and imaginary parts.

extend_n_dims(arr, axis, n_dims)

Extend the number of dimensions of an array.

flatten(tensor[, start_dim, end_dim])

Flatten a range of dimensions; should behave like https://pytorch.org/docs/stable/generated/torch.flatten.html

fori_loop(lower, upper, body_fun, init_val)

For loop with the same signature as jax.lax.fori_loop that can also run without JIT compilation.

func_with_one_batch_dim(func, tensor, ...[, ...])

Wraps a function to apply it to an input tensor with one or more batch dimensions.

gaussian_filter(array, sigma[, order, mode, ...])

Multidimensional Gaussian filter.

gaussian_filter1d(array, sigma[, axis, ...])

1-D Gaussian filter.

images_to_patches(images, patch_shape[, overlap])

Creates patches from images.

interpolate_data(subsampled_data, mask[, ...])

Interpolate subsampled data along a specified axis using map_coordinates.

is_jax_prng_key(x)

Distinguish between jax.random.PRNGKey() and jax.random.key()

is_monotonic(array[, increasing])

Checks if a given 1D array is monotonic.

linear_sum_assignment(cost)

Greedy linear sum assignment.

map_indices_for_interpolation(indices)

Interpolates a 1D array of indices with gaps.

matrix_power(matrix, power)

Compute the power of a square matrix.

nonzero(x[, size, fill_value])

Return the indices of the elements that are non-zero.

pad_array_to_divisible(arr, N[, axis, mode, ...])

Pad an array to be divisible by N along the specified axis.

patched_map(f, xs, patches[, jit])

Wrapper around batched_map for patching.

patches_to_images(patches, image_shape[, ...])

Reconstructs images from patches.

resample(x, n_samples[, axis, order])

Resample tensor along axis.

reshape_axis(data, newshape, axis)

Reshape data along axis.

safe_vectorize(pyfunc[, excluded, signature])

A wrapper around ops.vectorize.

sinc(x[, eps])

Sinc function.

split_seed(seed, n)

Split a seed into n seeds for reproducible random ops.

split_volume_data_from_axis(data, ...)

Splits previously stacked tensor data back to its original shape.

stack_volume_data_along_axis(data, ...)

Stacks tensor data along a specified stack axis.

zea.tensor_ops.L2(x)

L2 norm of a tensor.

Implementation of L2 norm: https://mathworld.wolfram.com/L2-Norm.html
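
For reference, the quantity computed here is the square root of the sum of squared element magnitudes. Below is a minimal NumPy sketch; `l2_norm` is a hypothetical helper rather than zea's keras.ops implementation, and taking the absolute value first (so complex inputs also give a real result) is an assumption.

```python
import numpy as np


# Hypothetical helper illustrating the formula; not zea's implementation.
def l2_norm(x):
    """Square root of the sum of squared magnitudes over all elements."""
    x = np.asarray(x)
    # abs() keeps the result real for complex inputs as well
    return np.sqrt(np.sum(np.abs(x) ** 2))


print(l2_norm([3.0, 4.0]))  # 5.0
```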

zea.tensor_ops.add_salt_and_pepper_noise(image, salt_prob, pepper_prob=None, seed=None)

Adds salt and pepper noise to the input image.

Parameters:
  • image (ndarray) – The input image, must be of type float32 and normalized between 0 and 1.

  • salt_prob (float) – The probability of adding salt noise to each pixel.

  • pepper_prob (float, optional) – The probability of adding pepper noise to each pixel. If not provided, it will be set to the same value as salt_prob.

  • seed – A Python integer or instance of keras.random.SeedGenerator. Used to make the noise deterministic. Note that seeding with an integer or None (unseeded) will produce the same random values across multiple calls. To get different random values across multiple calls, pass an instance of keras.random.SeedGenerator as seed.

Returns:

The noisy image with salt and pepper noise added.

Return type:

ndarray
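
A minimal NumPy sketch of the noise model described above, assuming each pixel becomes 1.0 (salt) with probability salt_prob or 0.0 (pepper) with probability pepper_prob. `salt_and_pepper` is a hypothetical helper: it draws one uniform number per pixel so salt and pepper never hit the same pixel, and it seeds a NumPy Generator instead of a keras.random.SeedGenerator. Both choices are assumptions, not zea's documented behavior.

```python
import numpy as np


# Hypothetical helper illustrating the noise model; not zea's implementation.
def salt_and_pepper(image, salt_prob, pepper_prob=None, seed=None):
    """Set random pixels of a float image in [0, 1] to 1.0 (salt) or 0.0 (pepper)."""
    if pepper_prob is None:
        pepper_prob = salt_prob  # same default as the documented pepper_prob
    rng = np.random.default_rng(seed)
    u = rng.random(image.shape)  # one uniform draw per pixel
    noisy = image.copy()
    noisy[u < salt_prob] = 1.0  # salt
    # pepper uses the next pepper_prob slice of [0, 1), so it never overlaps salt
    noisy[(u >= salt_prob) & (u < salt_prob + pepper_prob)] = 0.0
    return noisy


image = np.full((64, 64), 0.5, dtype=np.float32)
noisy = salt_and_pepper(image, salt_prob=0.05, seed=0)
print(np.mean(noisy == 1.0), np.mean(noisy == 0.0))  # roughly 0.05 each
```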

zea.tensor_ops.apply_along_axis(func1d, axis, arr, *args, **kwargs)

Apply a function to 1D array slices along an axis.

Keras implementation of numpy.apply_along_axis using keras.ops.vectorized_map.

Parameters:
  • func1d – A callable function with signature func1d(arr, /, *args, **kwargs) where *args and **kwargs are the additional positional and keyword arguments passed to apply_along_axis.

  • axis – Integer axis along which to apply the function.

  • arr – The array over which to apply the function.

  • *args – Additional positional arguments passed through to func1d.

  • **kwargs – Additional keyword arguments passed through to func1d.

Returns:

The result of func1d applied along the specified axis.
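
Since this function is documented as a Keras version of numpy.apply_along_axis, NumPy's own function illustrates the expected calling convention, including how extra positional and keyword arguments reach func1d. That zea returns identical results for every input is an assumption; `span` is a hypothetical example function.

```python
import numpy as np

arr = np.arange(12).reshape(3, 4)


# Hypothetical example function: receives one 1D slice plus any extra arguments
def span(row, offset=0):
    return row.max() - row.min() + offset


per_row = np.apply_along_axis(span, 1, arr)             # one result per row
per_col = np.apply_along_axis(span, 0, arr, offset=10)  # kwargs are forwarded to span
print(per_row)  # [3 3 3]
print(per_col)  # [18 18 18 18]
```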

zea.tensor_ops.batch_cov(x, rowvar=True, bias=False, ddof=None)

Compute the batch covariance matrices of the input tensor.

Parameters:
  • x (Tensor) – Input tensor of shape (…, m, n) where m is the number of features and n is the number of observations.

  • rowvar (bool, optional) – If True, each row represents a variable, while each column represents an observation. If False, each column represents a variable, while each row represents an observation. Defaults to True.

  • bias (bool, optional) – If True, the biased estimator of the covariance is computed. If False, the unbiased estimator is computed. Defaults to False.

  • ddof (int, optional) – Delta degrees of freedom. The divisor used in the calculation is (num_obs - ddof), where num_obs is the number of observations. If ddof is not specified, it is set to 0 if bias is True, and 1 if bias is False. Defaults to None.

Returns:

Batch covariance matrices of shape (…, m, m) if rowvar=True, or (…, n, n) if rowvar=False.

Return type:

Tensor
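
The computation can be sketched in NumPy: center each variable over its observations, multiply the centered matrix by its transpose, and divide by num_obs - ddof. `batch_cov_sketch` is a hypothetical helper that mirrors the documented parameters (real inputs only), not zea's code; for a single matrix it should agree with numpy.cov.

```python
import numpy as np


# Hypothetical helper mirroring the documented parameters; not zea's implementation.
def batch_cov_sketch(x, rowvar=True, bias=False, ddof=None):
    """Covariance matrices over the last two axes, batched over all leading axes."""
    x = np.asarray(x, dtype=float)
    if not rowvar:
        x = np.swapaxes(x, -1, -2)  # make rows the variables
    if ddof is None:
        ddof = 0 if bias else 1
    n_obs = x.shape[-1]
    centered = x - x.mean(axis=-1, keepdims=True)  # subtract each variable's mean
    # (..., m, n) @ (..., n, m) -> (..., m, m)
    return centered @ np.swapaxes(centered, -1, -2) / (n_obs - ddof)


rng = np.random.default_rng(0)
x = rng.normal(size=(2, 3, 5, 7))  # batch (2, 3), 5 variables, 7 observations
cov = batch_cov_sketch(x)
print(cov.shape, np.allclose(cov[1, 2], np.cov(x[1, 2])))  # (2, 3, 5, 5) True
```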

zea.tensor_ops.batched_map(f, xs, batch_size=None, jit=True, **batch_kwargs)

Map a function over leading array axes.

Parameters:
  • f (callable) – Function to apply element-wise over the first axis.

  • xs (Tensor) – Values over which to map along the leading axis.

  • batch_size (int, optional) – Size of the batch for each step. Defaults to None, in which case the function will be equivalent to ops.map, and thus map over the leading axis.

  • jit (bool, optional) – If True, use a jitted version of the function for faster batched mapping. Else, loop over the data with the original function.

  • batch_kwargs (dict, optional) – Additional keyword arguments (tensors) to batch along with xs. Must have the same first dimension size as xs.

Returns:

The mapped tensor(s).

Idea taken from: https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.map.html
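
A minimal sketch of the chunking idea, assuming f is written for a single element as in jax.lax.map. `batched_map_sketch` is hypothetical and simply loops in Python; a JIT backend would instead vectorize f over each chunk, so only batch_size elements are processed at once.

```python
import numpy as np


# Hypothetical helper illustrating batched mapping; not zea's implementation.
def batched_map_sketch(f, xs, batch_size=None):
    """Map a per-element function f over axis 0 of xs, batch_size elements per step."""
    if batch_size is None:
        batch_size = 1  # one element per step, like a plain map
    out = []
    for start in range(0, len(xs), batch_size):
        chunk = xs[start:start + batch_size]  # the last chunk may be shorter
        # a backend such as JAX would vectorize f over the chunk (e.g. with vmap)
        out.extend(f(x) for x in chunk)
    return np.stack(out)


xs = np.arange(10.0).reshape(5, 2)
print(batched_map_sketch(lambda row: row.sum(), xs, batch_size=2).tolist())
# [1.0, 5.0, 9.0, 13.0, 17.0]
```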

zea.tensor_ops.boolean_mask(tensor, mask, size=None)

Apply a boolean mask to a tensor.

Parameters:
  • tensor (Tensor) – The input tensor.

  • mask (Tensor) – The boolean mask to apply.

  • size (int, optional) – The size of the output tensor. Only used with the JAX backend, when the function needs to be traced. Defaults to None.

Returns:

The masked tensor.

Return type:

Tensor

zea.tensor_ops.check_patches_fit(image_shape, patch_shape, overlap)

Checks if patches with overlap fit an integer amount in the original image.

Parameters:
  • image_shape (tuple) – A tuple representing the shape of the original image.

  • patch_shape (tuple) – A tuple representing the shape of the patches.

  • overlap (Union[int, Tuple[int, int]]) – An int or tuple representing the overlap between patches.

Return type:

tuple

Returns:

A tuple containing a boolean indicating if the patches fit an integer amount in the original image and the new image shape if the patches do not fit.
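
The fit condition comes down to integer arithmetic. Assuming neighbouring patches start patch - overlap pixels apart, n patches span overlap + n * (patch - overlap) pixels, so they tile an axis exactly when (size - patch) is divisible by (patch - overlap). The sketch checks that per axis; `patches_fit_exactly` is a hypothetical helper that returns patch counts instead of a new image shape, and zea's handling of edge cases may differ.

```python
# Hypothetical helper illustrating the arithmetic; not zea's implementation.
def patches_fit_exactly(image_shape, patch_shape, overlap):
    """Check per axis whether patches with the given overlap tile the image exactly."""
    counts = []
    for size, patch, ov in zip(image_shape, patch_shape, overlap):
        stride = patch - ov  # distance between the starts of neighbouring patches
        if stride <= 0 or size < patch or (size - patch) % stride:
            return False, None
        counts.append((size - patch) // stride + 1)  # patches along this axis
    return True, tuple(counts)


print(patches_fit_exactly((10, 10), (4, 4), (2, 2)))  # (True, (4, 4))
print(patches_fit_exactly((10, 10), (4, 4), (0, 0)))  # (False, None)
```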

Example

from zea.tensor_ops import check_patches_fit

image_shape = (10, 10)
patch_shape = (4, 4)
overlap = (2, 2)
patches_fit, new_shape = check_patches_fit(image_shape, patch_shape, overlap)
print(patches_fit)
print(new_shape)

zea.tensor_ops.compute_required_patch_overlap(image_shape, patch_shape)

Compute required overlap between patches to cover the entire image.

Parameters:
  • image_shape – Tuple of (height, width)

  • patch_shape – Tuple of (patch_height, patch_width)

Returns:

Tuple of (overlap_y, overlap_x)

or None if no overlap results in an integer number of patches for the given image and patch shapes.

zea.tensor_ops.compute_required_patch_shape(image_shape, patch_shape, overlap)

Compute required patch shape to cover the entire image.

Computes the patch shape closest to the given patch_shape that results in an integer number of patches for the given image shape and overlap.

Parameters:
  • image_shape – Tuple of (height, width)

  • patch_shape – Tuple of (patch_height, patch_width)

  • overlap – Tuple of (overlap_y, overlap_x)

Returns:

Tuple of (patch_shape_y, patch_shape_x)

or None if no patch shape results in an integer number of patches for the given image shape and overlap.

zea.tensor_ops.correlate(x, y, mode='full')

Complex correlation via splitting real and imaginary parts. Equivalent to np.correlate(x, y, mode).

NOTE: This function exists because TensorFlow does not support complex correlation. TensorFlow also handles padding differently than NumPy, so the input is padded manually.

Parameters:
  • x – np.ndarray (complex or real)

  • y – np.ndarray (complex or real)

  • mode – “full”, “valid”, or “same”
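
The splitting trick follows from bilinearity. numpy.correlate conjugates its second argument, so with x = xr + i·xi and y = yr + i·yi, the real part of the result is corr(xr, yr) + corr(xi, yi) and the imaginary part is corr(xi, yr) - corr(xr, yi), where each corr is a purely real correlation. The NumPy sketch below checks this against numpy.correlate; `complex_correlate` is hypothetical and skips the manual padding zea applies for TensorFlow.

```python
import numpy as np


# Hypothetical helper illustrating the real/imaginary split; not zea's implementation.
def complex_correlate(x, y, mode="full"):
    """numpy.correlate for complex inputs, built from real-valued correlations only."""
    x, y = np.asarray(x), np.asarray(y)
    xr, xi, yr, yi = x.real, x.imag, y.real, y.imag
    # sum_n x[n+k] * conj(y[n]) = (xr*yr + xi*yi) + 1j * (xi*yr - xr*yi)
    real = np.correlate(xr, yr, mode) + np.correlate(xi, yi, mode)
    imag = np.correlate(xi, yr, mode) - np.correlate(xr, yi, mode)
    return real + 1j * imag


rng = np.random.default_rng(0)
x = rng.normal(size=8) + 1j * rng.normal(size=8)
y = rng.normal(size=5) + 1j * rng.normal(size=5)
print(np.allclose(complex_correlate(x, y), np.correlate(x, y, "full")))  # True
```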

zea.tensor_ops.extend_n_dims(arr, axis, n_dims)

Extend the number of dimensions of an array.

Inserts n_dims dimensions of size 1 at the specified axis.

Parameters:
  • arr – The input array.

  • axis – The axis at which to insert the new dimensions.

  • n_dims – The number of dimensions to insert.

Returns:

The array with the extended number of dimensions.

Raises:

AssertionError – If the axis is out of range.
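
A minimal NumPy sketch of the same operation as a single reshape. `extend_n_dims_sketch` is hypothetical, and it only accepts non-negative axis values from 0 to arr.ndim, which is an assumption about what counts as "in range".

```python
import numpy as np


# Hypothetical helper illustrating the operation; not zea's implementation.
def extend_n_dims_sketch(arr, axis, n_dims):
    """Insert n_dims axes of size 1 at position axis."""
    arr = np.asarray(arr)
    assert 0 <= axis <= arr.ndim, "axis out of range"
    new_shape = arr.shape[:axis] + (1,) * n_dims + arr.shape[axis:]
    return arr.reshape(new_shape)


print(extend_n_dims_sketch(np.zeros((3, 4)), axis=1, n_dims=2).shape)  # (3, 1, 1, 4)
```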

zea.tensor_ops.flatten(tensor, start_dim=0, end_dim=-1)

Flattens the dimensions from start_dim to end_dim (inclusive) into a single dimension. Should behave like torch.flatten: https://pytorch.org/docs/stable/generated/torch.flatten.html
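
A NumPy sketch of the torch.flatten semantics referenced above: the dimensions from start_dim to end_dim, inclusive, are merged into one whose size is their product. `flatten_sketch` is a hypothetical helper and assumes an input with at least one dimension.

```python
import math

import numpy as np


# Hypothetical helper illustrating torch.flatten-style semantics; not zea's implementation.
def flatten_sketch(x, start_dim=0, end_dim=-1):
    """Merge dimensions start_dim..end_dim (inclusive) into a single dimension."""
    x = np.asarray(x)
    start = start_dim % x.ndim  # normalize negative dims
    end = end_dim % x.ndim
    merged = math.prod(x.shape[start:end + 1])
    return x.reshape(x.shape[:start] + (merged,) + x.shape[end + 1:])


print(flatten_sketch(np.zeros((2, 3, 4, 5)), start_dim=1, end_dim=2).shape)  # (2, 12, 5)
```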

zea.tensor_ops.fori_loop(lower, upper, body_fun, init_val, disable_jit=False)

For loop with the same signature as jax.lax.fori_loop that can also run without JIT compilation.

Parameters:
  • lower (int) – Lower bound of the loop.

  • upper (int) – Upper bound of the loop.

  • body_fun (function) – Function to be executed in the loop.

  • init_val (any) – Initial value for the loop.

  • disable_jit (bool, optional) – If True, disables JIT compilation. Defaults to False.
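
jax.lax.fori_loop is documented as equivalent to the plain Python loop below, which is presumably what the non-jitted path runs. `fori_loop_sketch` is a hypothetical name; body_fun receives the loop index and the carried value and returns the new value.

```python
# Hypothetical helper showing the reference semantics of jax.lax.fori_loop.
def fori_loop_sketch(lower, upper, body_fun, init_val):
    """Run val = body_fun(i, val) for i in range(lower, upper)."""
    val = init_val
    for i in range(lower, upper):
        val = body_fun(i, val)  # body_fun gets the loop index and the carried value
    return val


# sum of squares 0^2 + 1^2 + ... + 4^2
print(fori_loop_sketch(0, 5, lambda i, acc: acc + i * i, 0))  # 30
```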

zea.tensor_ops.func_with_one_batch_dim(func, tensor, n_batch_dims, batch_size=None, func_axis=None, **kwargs)

Wraps a function to apply it to an input tensor with one or more batch dimensions.

The function will be executed in parallel on all batch elements.

Parameters:
  • func (function) – The function to apply to the tensor. If func returns multiple outputs, only the output selected by func_axis is used.

  • tensor (Tensor) – The input tensor.

  • n_batch_dims (int) – The number of batch dimensions in the input tensor. The input is expected to start with n_batch_dims batch dimensions.

  • batch_size (int, optional) – Integer specifying the size of the batch for each step to execute in parallel. Defaults to None, in which case the function will run everything in parallel.

  • func_axis (int, optional) – If func returns multiple outputs, the output at this index is returned.

  • **kwargs – Additional keyword arguments to pass to the function.

Returns:

The output tensor with the same batch dimensions as the input tensor.

Raises:

ValueError – If the number of batch dimensions is greater than the rank of the input tensor.
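
The core reshaping can be sketched in NumPy: merge the leading n_batch_dims axes into one, call func once, then split the merged axis back. `with_one_batch_dim` is hypothetical and leaves out batch_size and func_axis; it assumes func keeps the merged batch axis first in its output.

```python
import math

import numpy as np


# Hypothetical helper illustrating the batch-axis merging; not zea's implementation.
def with_one_batch_dim(func, tensor, n_batch_dims):
    """Apply func, which expects one leading batch axis, to a tensor with several."""
    if n_batch_dims > tensor.ndim:
        raise ValueError("n_batch_dims exceeds the rank of the tensor")
    batch_shape = tensor.shape[:n_batch_dims]
    # merge all batch axes into one: (b1, b2, ..., rest) -> (b1 * b2 * ..., rest)
    flat = tensor.reshape((math.prod(batch_shape),) + tensor.shape[n_batch_dims:])
    out = func(flat)
    # split the merged batch axis back into the original batch axes
    return out.reshape(batch_shape + out.shape[1:])


x = np.arange(2 * 3 * 4 * 5).reshape(2, 3, 4, 5)
out = with_one_batch_dim(lambda t: t.sum(axis=-1), x, n_batch_dims=2)
print(out.shape)  # (2, 3, 4)
```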

zea.tensor_ops.gaussian_filter(array, sigma, order=0, mode='symmetric', cval=None, truncate=4.0, axes=None)

Multidimensional Gaussian filter.

To use this function with jax.jit, set static_argnames=("truncate", "sigma").

Parameters:
  • array (Tensor) – The input array.

  • sigma (float or tuple) – Standard deviation for Gaussian kernel. The standard deviations of the Gaussian filter are given for each axis as a sequence, or as a single number, in which case it is equal for all axes.

  • order (int or Tuple[int]) – The order of the filter along each axis is given as a sequence of integers, or as a single number. An order of 0 corresponds to convolution with a Gaussian kernel. A positive order corresponds to convolution with that derivative of a Gaussian. Default is 0.

  • mode (str, optional) – Padding mode for the input image. Default is ‘symmetric’. See [keras docs](https://www.tensorflow.org/api_docs/python/tf/keras/ops/pad) for all options and [TensorFlow docs](https://www.tensorflow.org/api_docs/python/tf/pad) for some examples. Note that the naming differs from scipy.ndimage.gaussian_filter!

  • cval (float, optional) – Value to fill past edges of input if mode is ‘constant’. Default is None.

  • truncate (float, optional) – Truncate the filter at this many standard deviations. Default is 4.0.

  • axes (Tuple[int], optional) – If None, input is filtered along all axes. Otherwise, input is filtered along the specified axes. When axes is specified, any tuples used for sigma, order and/or mode must match the length of axes. The ith entry in any of these tuples corresponds to the ith entry in axes.

zea.tensor_ops.gaussian_filter1d(array, sigma, axis=-1, order=0, mode='symmetric', truncate=4.0, cval=None)

1-D Gaussian filter.

Parameters:
  • array (Tensor) – The input array.

  • sigma (float) – Standard deviation for the Gaussian kernel.

  • axis (int, optional) – The axis of array along which to filter. Default is -1.

  • order (int, optional) – The order of the filter. An order of 0 corresponds to convolution with a Gaussian kernel. A positive order corresponds to convolution with that derivative of a Gaussian. Default is 0.

  • mode (str, optional) – Padding mode for the input image. Default is ‘symmetric’. See [keras docs](https://www.tensorflow.org/api_docs/python/tf/keras/ops/pad) for all options and [TensorFlow docs](https://www.tensorflow.org/api_docs/python/tf/pad) for some examples. Note that the naming differs from scipy.ndimage.gaussian_filter!

  • cval (float, optional) – Value to fill past edges of input if mode is ‘constant’. Default is None.

  • truncate (float, optional) – Truncate the filter at this many standard deviations. Default is 4.0.
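
A NumPy sketch of an order-0 1D Gaussian filter. The kernel radius int(truncate * sigma + 0.5) follows scipy.ndimage's convention, which is an assumption for zea. NumPy's "symmetric" padding (edge sample repeated in the mirror image) corresponds to what scipy.ndimage calls "reflect", one example of the naming difference noted above. `gaussian_kernel1d` and `gaussian_filter1d_sketch` are hypothetical helpers.

```python
import numpy as np


# Hypothetical helpers illustrating an order-0 Gaussian filter; not zea's implementation.
def gaussian_kernel1d(sigma, truncate=4.0):
    """Normalized Gaussian kernel; radius rule borrowed from scipy.ndimage (an assumption here)."""
    radius = int(truncate * sigma + 0.5)
    x = np.arange(-radius, radius + 1)
    kernel = np.exp(-0.5 * (x / sigma) ** 2)
    return kernel / kernel.sum()  # weights sum to 1, so constant signals are preserved


def gaussian_filter1d_sketch(signal, sigma, truncate=4.0):
    """Smooth a 1D signal, padding with NumPy's "symmetric" mode (edge sample repeated)."""
    kernel = gaussian_kernel1d(sigma, truncate)
    radius = len(kernel) // 2
    padded = np.pad(signal, radius, mode="symmetric")
    # the kernel is symmetric, so convolution equals correlation; "valid" restores len(signal)
    return np.convolve(padded, kernel, mode="valid")


print(len(gaussian_kernel1d(2.0)))  # 17: radius int(4.0 * 2.0 + 0.5) = 8 on each side
```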

zea.tensor_ops.images_to_patches(images, patch_shape, overlap=None)

Creates patches from images.

Parameters:
  • images (Tensor) – input images [batch, height, width, channels].

  • patch_shape (int or tuple) – Height and width of each patch.

  • overlap (int or tuple, optional) – Overlap between patches in px. Defaults to None.

Returns:

A batch of patches of shape [batch, #patch_y, #patch_x, patch_size_y, patch_size_x, #channels].

Return type:

patches (Tensor)
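
A NumPy sketch of the patch layout described above, using sliding_window_view and assuming patch starts are spaced patch - overlap pixels apart and that the patches fit the image exactly. `images_to_patches_sketch` is hypothetical and only takes tuple patch shapes and overlaps.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view


# Hypothetical helper illustrating the patch layout; not zea's implementation.
def images_to_patches_sketch(images, patch_shape, overlap=(0, 0)):
    """[B, H, W, C] -> [B, n_y, n_x, patch_y, patch_x, C]; assumes the patches fit exactly."""
    py, px = patch_shape
    sy, sx = py - overlap[0], px - overlap[1]  # spacing between patch starts
    # every window of size (py, px) over H and W: [B, H-py+1, W-px+1, C, py, px]
    windows = sliding_window_view(images, (py, px), axis=(1, 2))
    windows = windows[:, ::sy, ::sx]  # keep only windows starting on the stride grid
    return windows.transpose(0, 1, 2, 4, 5, 3)  # move channels back to the end


rng = np.random.default_rng(0)
images = rng.random((2, 8, 8, 3))
patches = images_to_patches_sketch(images, (4, 4), overlap=(2, 2))
print(patches.shape)  # (2, 3, 3, 4, 4, 3)
```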

Example

import keras
from zea.tensor_ops import images_to_patches

images = keras.random.uniform((2, 8, 8, 3))
patches = images_to_patches(images, patch_shape=(4, 4), overlap=(2, 2))
print(patches.shape)

zea.tensor_ops.interpolate_data(subsampled_data, mask, order=1, axis=-1)

Interpolate subsampled data along a specified axis using map_coordinates.

Parameters:
  • subsampled_data (ndarray) – The data subsampled along the specified axis. Its shape matches mask except along the subsampled axis.

  • mask (ndarray) – Boolean array with the same shape as the full data. True where data is known.

  • order (int, optional) – The order of the spline interpolation. Default is 1.

  • axis (int, optional) – The axis along which the data is subsampled. Default is -1.

Returns:

The data interpolated back to the original grid.

Return type:

ndarray

Raises:

ValueError – If mask does not indicate any missing data, or if mask has False values along multiple axes.
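
For the simplest case (a 1D signal with order=1) the idea reduces to linear interpolation from the known positions onto the full grid. This sketch swaps map_coordinates for numpy.interp, so it handles neither higher spline orders nor N-D data. `interpolate_1d_sketch` is hypothetical and assumes subsampled holds the values at the True positions of mask, in order.

```python
import numpy as np


# Hypothetical helper (linear, 1D only); not zea's implementation.
def interpolate_1d_sketch(subsampled, mask):
    """Linearly fill a 1D signal from samples known where mask is True."""
    mask = np.asarray(mask, dtype=bool)
    if mask.all():
        raise ValueError("mask does not indicate any missing data")
    known = np.flatnonzero(mask)  # positions of the known samples on the full grid
    full_grid = np.arange(mask.size)
    return np.interp(full_grid, known, subsampled)


mask = np.array([True, False, True, False, False, True])
filled = interpolate_1d_sketch(np.array([0.0, 2.0, 5.0]), mask)
print(filled.tolist())  # [0.0, 1.0, 2.0, 3.0, 4.0, 5.0]
```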

zea.tensor_ops.is_jax_prng_key(x)

Distinguish between jax.random.PRNGKey() and jax.random.key().

zea.tensor_ops.is_monotonic(array, increasing=True)

Checks if a given 1D array is monotonic.

That is, whether it is entirely non-decreasing or entirely non-increasing.

Parameters:

array (ndarray) – A 1D numpy array.

Returns:

True if the array is monotonic, False otherwise.

Return type:

bool

zea.tensor_ops.linear_sum_assignment(cost)

Greedy linear sum assignment.

Parameters:

cost (Tensor) – Cost matrix of shape (n, n).

Returns:

Row indices and column indices for assignment.

Return type:

Tuple
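
One common greedy strategy, sketched in NumPy: repeatedly take the cheapest remaining entry and remove its row and column. Whether zea uses exactly this rule is an assumption; `greedy_assignment` is hypothetical. The example also shows why greedy matching can be suboptimal: it reaches a total cost of 6, while the best assignment for this matrix (rows 0, 1, 2 to columns 1, 0, 2) costs 5.

```python
import numpy as np


# Hypothetical helper showing one greedy rule; not necessarily zea's.
def greedy_assignment(cost):
    """Pair rows with columns, always taking the cheapest entry still available."""
    cost = np.array(cost, dtype=float)  # copy, so the caller's matrix is not modified
    n = cost.shape[0]
    rows, cols = [], []
    for _ in range(n):
        r, c = np.unravel_index(np.argmin(cost), cost.shape)
        rows.append(int(r))
        cols.append(int(c))
        cost[r, :] = np.inf  # row r is used up
        cost[:, c] = np.inf  # column c is used up
    order = np.argsort(rows)  # list the pairs by row index
    return np.array(rows)[order], np.array(cols)[order]


cost = [[4, 1, 3], [2, 0, 5], [3, 2, 2]]
rows, cols = greedy_assignment(cost)
total = float(np.asarray(cost)[rows, cols].sum())
print(rows.tolist(), cols.tolist(), total)  # [0, 1, 2] [0, 1, 2] 6.0
```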


zea.tensor_ops.map_indices_for_interpolation(indices)

Interpolates a 1D array of indices with gaps.

Maps a 1D array of indices with gaps to a 1D array where gaps would be between the integers.

Used in the interpolate_data function.

Parameters:

indices – A 1D array of indices with gaps.

Returns:

The indices mapped to a 1D array in which the gaps fall between consecutive integers.

Example

>>> indices = [5, 6, 7, 8, 12, 13, 19]
>>> mapped_indices = [5, 5.25, 5.5, 5.75, 8, 8.5, 12.5]

There are two segments here, of length 4 and 2.

zea.tensor_ops.matrix_power(matrix, power)

Compute the power of a square matrix.

Should match the [numpy](https://numpy.org/doc/stable/reference/generated/numpy.linalg.matrix_power.html) implementation.

Parameters:
  • matrix (array-like) – A square matrix to be raised to a power.

  • power (int) – The exponent to which the matrix is to be raised. Must be a non-negative integer.

Returns:

The resulting matrix after raising the input matrix to the specified power.

Return type:

array-like
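
A common way to compute integer matrix powers is repeated squaring, which needs only O(log power) matrix multiplications. The NumPy sketch below is hypothetical (`matrix_power_sketch`) and is checked against numpy.linalg.matrix_power; it is not necessarily how zea computes the power.

```python
import numpy as np


# Hypothetical helper illustrating repeated squaring; not zea's implementation.
def matrix_power_sketch(matrix, power):
    """Raise a square matrix to a non-negative integer power by repeated squaring."""
    matrix = np.asarray(matrix)
    assert power >= 0, "power must be a non-negative integer"
    result = np.eye(matrix.shape[0], dtype=matrix.dtype)
    base = matrix
    while power:
        if power & 1:
            result = result @ base  # this bit of the exponent is set
        base = base @ base  # base now holds matrix ** (2 ** k)
        power >>= 1
    return result


fib = np.array([[1, 1], [1, 0]])
print(matrix_power_sketch(fib, 10).tolist())  # [[89, 55], [55, 34]]
```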

zea.tensor_ops.nonzero(x, size=None, fill_value=None)

Return the indices of the elements that are non-zero.

zea.tensor_ops.pad_array_to_divisible(arr, N, axis=0, mode='constant', pad_value=None)

Pad an array to be divisible by N along the specified axis.

Parameters:
  • arr (Tensor) – The input array to pad.

  • N (int) – The number by which the length of the specified axis should be divisible.

  • axis (int, optional) – The axis along which to pad the array. Defaults to 0.

  • mode (str, optional) – The padding mode to use. One of “constant”, “edge”, “linear_ramp”, “maximum”, “mean”, “median”, “minimum”, “reflect”, “symmetric”, “wrap”, “empty”, “circular”. Defaults to “constant”.

  • pad_value (float, optional) – The value to use for padding when mode='constant'. Defaults to None. If mode is not constant, this value should be None.

Returns:

The padded array.

Return type:

Tensor
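
The amount of padding is the smallest non-negative p for which length + p is divisible by N, which Python computes as -length % N. A NumPy sketch for constant padding, assuming the padding is appended at the end of the axis (zea may place it differently); `pad_to_divisible_sketch` is hypothetical.

```python
import numpy as np


# Hypothetical helper; assumes padding goes at the end of the axis.
def pad_to_divisible_sketch(arr, N, axis=0, pad_value=0):
    """Constant-pad arr at the end of `axis` until its length is a multiple of N."""
    arr = np.asarray(arr)
    missing = -arr.shape[axis] % N  # smallest p >= 0 with (length + p) % N == 0
    pad_width = [(0, 0)] * arr.ndim
    pad_width[axis] = (0, missing)
    return np.pad(arr, pad_width, mode="constant", constant_values=pad_value)


print(pad_to_divisible_sketch(np.ones((7, 3)), 4).shape)  # (8, 3)
```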

zea.tensor_ops.patched_map(f, xs, patches, jit=True, **batch_kwargs)

Wrapper around batched_map for patching.

Allows you to specify the number of patches rather than the batch size.

zea.tensor_ops.patches_to_images(patches, image_shape, overlap=None, window_type='average')

Reconstructs images from patches.

Parameters:
  • patches (Tensor) – Batch of patches to convert to a batch of images, of shape [batch_size, #patch_y, #patch_x, patch_size_y, patch_size_x, n_channels].

  • image_shape (Tuple) – Shape of output image. (height, width, channels)

  • overlap (int or tuple, optional) – Overlap between patches in px. Defaults to None.

  • window_type (str, optional) – Type of stitching to use. Defaults to ‘average’. Options: ‘average’, ‘replace’.

Returns:

Reconstructed batch of images from batch of patches.

Return type:

images (Tensor)
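
A NumPy sketch of "average" stitching: every patch is added into an accumulator at its position, a second array counts how many patches cover each pixel, and the sum is divided by the count. It assumes patch starts are spaced patch - overlap pixels apart; `stitch_average` is hypothetical and does not implement the "replace" option.

```python
import numpy as np


# Hypothetical helper illustrating "average" stitching; not zea's implementation.
def stitch_average(patches, image_shape, overlap=(0, 0)):
    """[B, n_y, n_x, py, px, C] -> [B, H, W, C], averaging pixels covered by several patches."""
    batch, n_y, n_x, py, px, _ = patches.shape
    sy, sx = py - overlap[0], px - overlap[1]  # spacing between patch starts
    height, width, channels = image_shape
    total = np.zeros((batch, height, width, channels))
    count = np.zeros((1, height, width, 1))
    for i in range(n_y):
        for j in range(n_x):
            rows = slice(i * sy, i * sy + py)
            cols = slice(j * sx, j * sx + px)
            total[:, rows, cols, :] += patches[:, i, j]
            count[:, rows, cols, :] += 1
    return total / np.maximum(count, 1)  # pixels no patch covers stay 0


# cut a random image into 4x4 patches with 2 px overlap, then stitch it back
rng = np.random.default_rng(0)
image = rng.random((2, 8, 8, 3))
patches = np.stack(
    [np.stack([image[:, 2 * i:2 * i + 4, 2 * j:2 * j + 4] for j in range(3)], axis=1)
     for i in range(3)],
    axis=1,
)
recon = stitch_average(patches, (8, 8, 3), overlap=(2, 2))
print(patches.shape, np.allclose(recon, image))  # (2, 3, 3, 4, 4, 3) True
```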

Example

import keras
from zea.tensor_ops import patches_to_images

patches = keras.random.uniform((2, 3, 3, 4, 4, 3))
images = patches_to_images(patches, image_shape=(8, 8, 3), overlap=(2, 2))
print(images.shape)

zea.tensor_ops.resample(x, n_samples, axis=-2, order=1)

Resample tensor along axis.

Similar to scipy.signal.resample.

Parameters:
  • x – input tensor.

  • n_samples – number of samples after resampling.

  • axis – axis to resample along.

  • order – interpolation order (1=linear).

Returns:

Resampled tensor.

zea.tensor_ops.reshape_axis(data, newshape, axis)

Reshape data along axis.

Parameters:
  • data (tensor) – input data.

  • newshape (tuple) – new shape of data along axis.

  • axis (int) – axis to reshape.
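
A NumPy sketch of the documented behavior: the single axis is replaced by the axes in newshape, whose product must equal that axis's original length. `reshape_axis_sketch` is hypothetical; under this reading, the (3, 4, 5) example that follows would produce shape (3, 2, 2, 5).

```python
import numpy as np


# Hypothetical helper illustrating the documented behavior; not zea's implementation.
def reshape_axis_sketch(data, newshape, axis):
    """Replace axis `axis` of data by the axes in newshape (their product must match)."""
    data = np.asarray(data)
    axis = axis % data.ndim  # allow negative axes
    return data.reshape(data.shape[:axis] + tuple(newshape) + data.shape[axis + 1:])


print(reshape_axis_sketch(np.zeros((3, 4, 5)), (2, 2), axis=1).shape)  # (3, 2, 2, 5)
```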

Example

import keras
from zea.tensor_ops import reshape_axis

data = keras.random.uniform((3, 4, 5))
newshape = (2, 2)
reshaped_data = reshape_axis(data, newshape, axis=1)
print(reshaped_data.shape)

zea.tensor_ops.safe_vectorize(pyfunc, excluded=None, signature=None)

A wrapper around ops.vectorize.

Because TensorFlow does not support passing multiple arguments to ops.vectorize(func)(…), the function is mapped manually instead.

zea.tensor_ops.sinc(x, eps=1e-07)

Sinc function.

zea.tensor_ops.split_seed(seed, n)

Split a seed into n seeds for reproducible random ops.

Supports keras.random.SeedGenerator and JAX random keys.

Parameters:
  • seed – None, jax.Array, or keras.random.SeedGenerator.

  • n (int) – Number of seeds to generate.

Returns:

List of n seeds (JAX keys, SeedGenerator, or None).

Return type:

list

zea.tensor_ops.split_volume_data_from_axis(data, batch_axis, stack_axis, number, padding)

Splits previously stacked tensor data back to its original shape.

This function reverses the operation performed by stack_volume_data_along_axis.

Parameters:
  • data (Tensor) – Input tensor to be split.

  • batch_axis (int) – Axis along which to restore the blocks.

  • stack_axis (int) – Axis from which to split the stacked data.

  • number (int) – Number of slices per stack.

  • padding (int) – Amount of padding to remove from the result.

Returns:

Reshaped tensor with data split back to original format.

Return type:

Tensor

Example

import keras
from zea.tensor_ops import split_volume_data_from_axis

data = keras.random.uniform((20, 10, 30))
split_data = split_volume_data_from_axis(data, 0, 1, 2, 2)
print(split_data.shape)

zea.tensor_ops.stack_volume_data_along_axis(data, batch_axis, stack_axis, number)

Stacks tensor data along a specified stack axis.

Stack tensor data along a specified stack axis by splitting it into blocks along the batch axis.

Parameters:
  • data (Tensor) – Input tensor to be stacked.

  • batch_axis (int) – Axis along which to split the data into blocks.

  • stack_axis (int) – Axis along which to stack the blocks.

  • number (int) – Number of slices per stack.

Returns:

Reshaped tensor with data stacked along stack_axis.

Return type:

Tensor
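
One reading of the block layout, sketched in NumPy: consecutive groups of `number` slices along batch_axis are concatenated along stack_axis, and the inverse cuts stack_axis back into `number` parts. `stack_blocks` and `unstack_blocks` are hypothetical helpers that assume the batch length is divisible by `number`, so they skip the padding that split_volume_data_from_axis removes; the exact ordering zea uses is an assumption.

```python
import numpy as np


# Hypothetical helpers illustrating one reading of the layout; not zea's implementation.
def stack_blocks(data, batch_axis, stack_axis, number):
    """Concatenate each group of `number` consecutive batch slices along stack_axis."""
    n_blocks = data.shape[batch_axis] // number  # assumes exact divisibility
    out = []
    for block in np.split(data, n_blocks, axis=batch_axis):
        slices = np.split(block, number, axis=batch_axis)  # each keeps a size-1 batch axis
        out.append(np.concatenate(slices, axis=stack_axis))
    return np.concatenate(out, axis=batch_axis)


def unstack_blocks(data, batch_axis, stack_axis, number):
    """Inverse of stack_blocks: cut stack_axis into `number` parts along batch_axis."""
    out = []
    for block in np.split(data, data.shape[batch_axis], axis=batch_axis):
        out.extend(np.split(block, number, axis=stack_axis))
    return np.concatenate(out, axis=batch_axis)


data = np.arange(10 * 20 * 30).reshape(10, 20, 30)
stacked = stack_blocks(data, batch_axis=0, stack_axis=1, number=2)
print(stacked.shape)  # (5, 40, 30)
```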

Example

import keras
from zea.tensor_ops import stack_volume_data_along_axis

data = keras.random.uniform((10, 20, 30))
# split axis 0 into blocks of 2 frames and stack each block along axis 1
stacked_data = stack_volume_data_along_axis(data, 0, 1, 2)
print(stacked_data.shape)