
"""Basic tensor operations implemented with the multi-backend ``keras.ops``."""

from typing import Tuple, Union

import keras
import numpy as np
from keras import ops
from scipy.ndimage import _ni_support
from scipy.ndimage._filters import _gaussian_kernel1d

from zea import log
from zea.utils import map_negative_indices


def split_seed(seed, n):
    """Split a seed into n seeds for reproducible random ops.

    Supports `keras.random.SeedGenerator
    <https://keras.io/api/random/#seedgenerator-class>`_ and `JAX random keys
    <https://jax.readthedocs.io/en/latest/jax.random.html#jax.random.PRNGKey>`_.

    Args:
        seed: None, jax.Array, or keras.random.SeedGenerator.
        n (int): Number of seeds to generate.

    Returns:
        list: List of n seeds (JAX keys, SeedGenerator, or None).
    """
    # If seed is None, return a list of None
    if seed is None:
        return [None for _ in range(n)]
    # If seed is a JAX key, split it into n keys
    if keras.backend.backend() == "jax":
        import jax

        return jax.random.split(seed, n)
    # For other backends, we have to use Keras SeedGenerator
    else:
        assert isinstance(seed, keras.random.SeedGenerator), (
            "seed must be a SeedGenerator when not using JAX."
        )
        # Just duplicate the SeedGenerator
        return [seed for _ in range(n)]

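# Example (illustrative sketch, not part of the zea API): drawing two
# independent random tensors from one seed. With the JAX backend the key is
# split into distinct keys; on other backends the same SeedGenerator is reused,
# which still yields fresh values because a SeedGenerator advances its own
# state on every call.
#
#   seed = keras.random.SeedGenerator(42)
#   seed_a, seed_b = split_seed(seed, 2)
#   x = keras.random.uniform((3,), seed=seed_a)
#   y = keras.random.uniform((3,), seed=seed_b)
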
def is_jax_prng_key(x):
    """Distinguish between jax.random.PRNGKey() and jax.random.key()."""
    if keras.backend.backend() == "jax":
        import jax

        return isinstance(x, jax.Array) and x.shape == (2,) and x.dtype == jax.numpy.uint32
    else:
        return False

def add_salt_and_pepper_noise(image, salt_prob, pepper_prob=None, seed=None):
    """Adds salt and pepper noise to the input image.

    Args:
        image (ndarray): The input image, must be of type float32 and normalized
            between 0 and 1.
        salt_prob (float): The probability of adding salt noise to each pixel.
        pepper_prob (float, optional): The probability of adding pepper noise to
            each pixel. If not provided, it will be set to the same value as
            `salt_prob`.
        seed: A Python integer or instance of `keras.random.SeedGenerator`.
            Used to make the behavior deterministic. Note that a call seeded
            with an integer or None (unseeded) will produce the same random
            values across multiple calls. To get different random values across
            multiple calls, use as seed an instance of
            `keras.random.SeedGenerator`.

    Returns:
        ndarray: The noisy image with salt and pepper noise added.
    """
    if pepper_prob is None:
        pepper_prob = salt_prob

    if salt_prob == 0.0 and pepper_prob == 0.0:
        return image

    assert ops.dtype(image) == "float32", "Image should be of type float32."

    noisy_image = ops.copy(image)

    # Add salt noise
    salt_mask = keras.random.uniform(ops.shape(image), seed=seed) < salt_prob
    noisy_image = ops.where(salt_mask, 1.0, noisy_image)

    # Add pepper noise
    pepper_mask = keras.random.uniform(ops.shape(image), seed=seed) < pepper_prob
    noisy_image = ops.where(pepper_mask, 0.0, noisy_image)

    return noisy_image

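# Example (hedged usage sketch): corrupting roughly 5% of pixels with salt and
# 5% with pepper. The image must be float32 in [0, 1].
#
#   image = keras.random.uniform((64, 64), dtype="float32")
#   noisy = add_salt_and_pepper_noise(
#       image, salt_prob=0.05, seed=keras.random.SeedGenerator(0)
#   )
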
def extend_n_dims(arr, axis, n_dims):
    """Extend the number of dimensions of an array.

    Inserts `n_dims` ones at the specified axis.

    Args:
        arr: The input array.
        axis: The axis at which to insert the new dimensions.
        n_dims: The number of dimensions to insert.

    Returns:
        The array with the extended number of dimensions.

    Raises:
        AssertionError: If the axis is out of range.
    """
    assert axis <= ops.ndim(arr), (
        "Axis must be less than or equal to the number of dimensions in the array"
    )
    assert axis >= -ops.ndim(arr) - 1, (
        "Axis must be greater than or equal to the negative number of dimensions minus 1"
    )
    axis = ops.ndim(arr) + axis + 1 if axis < 0 else axis

    # Get the current shape of the array
    shape = ops.shape(arr)

    # Create the new shape, inserting 'n_dims' ones at the specified axis
    new_shape = shape[:axis] + (1,) * n_dims + shape[axis:]

    # Reshape the array to the new shape
    return ops.reshape(arr, new_shape)

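# Example (illustrative): inserting two singleton dimensions at axis 1 turns a
# (2, 3) tensor into shape (2, 1, 1, 3), which is handy for broadcasting.
#
#   x = ops.ones((2, 3))
#   y = extend_n_dims(x, axis=1, n_dims=2)  # shape (2, 1, 1, 3)
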
def func_with_one_batch_dim(
    func,
    tensor,
    n_batch_dims: int,
    batch_size: int | None = None,
    func_axis: int | None = None,
    **kwargs,
):
    """Wraps a function to apply it to an input tensor with one or more batch dimensions.

    The function will be executed in parallel on all batch elements.

    Args:
        func (function): The function to apply to the tensor. If it returns
            multiple outputs, the `func_axis` output will be selected.
        tensor (Tensor): The input tensor.
        n_batch_dims (int): The number of batch dimensions in the input tensor.
            Expects the input to start with `n_batch_dims` batch dimensions.
        batch_size (int, optional): Integer specifying the size of the batch for
            each step to execute in parallel. Defaults to None, in which case the
            function will run everything in parallel.
        func_axis (int, optional): If `func` returns multiple outputs, this axis
            will be returned.
        **kwargs: Additional keyword arguments to pass to the function.

    Returns:
        The output tensor with the same batch dimensions as the input tensor.

    Raises:
        ValueError: If the number of batch dimensions is greater than the rank
            of the input tensor.
    """
    # Extract the shape of the batch dimensions from the input tensor
    batch_dims = ops.shape(tensor)[:n_batch_dims]
    # Extract the shape of the remaining (non-batch) dimensions
    other_dims = ops.shape(tensor)[n_batch_dims:]

    # Reshape the input tensor to merge all batch dimensions into one
    reshaped_input = ops.reshape(tensor, [-1, *other_dims])

    # Apply the given function to the reshaped input tensor
    if batch_size is None:
        reshaped_output = func(reshaped_input, **kwargs)
    else:
        # Pass the (non-batched) kwargs through to `func` on every batch step.
        reshaped_output = batched_map(
            lambda t: func(t, **kwargs), reshaped_input, batch_size=batch_size
        )

    # If the function returns multiple outputs, select the one corresponding to `func_axis`
    if isinstance(reshaped_output, (tuple, list)):
        if func_axis is None:
            raise ValueError(
                "func_axis must be specified when the function returns multiple outputs."
            )
        reshaped_output = reshaped_output[func_axis]

    # Extract the shape of the output tensor after applying the function (excluding the batch dim)
    output_other_dims = ops.shape(reshaped_output)[1:]

    # Reshape the output tensor to restore the original batch dimensions
    return ops.reshape(reshaped_output, [*batch_dims, *output_other_dims])

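# Example (hedged sketch): applying a single-batch-dim function to a tensor
# with two leading batch dimensions; the (4, 5) batch axes are flattened to 20,
# processed, and restored.
#
#   x = keras.random.uniform((4, 5, 32, 32))
#   out = func_with_one_batch_dim(lambda t: t * 2.0, x, n_batch_dims=2)
#   # ops.shape(out) == (4, 5, 32, 32)
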
def matrix_power(matrix, power):
    """Compute the power of a square matrix.

    Should match the [numpy](https://numpy.org/doc/stable/reference/generated/numpy.linalg.matrix_power.html)
    implementation.

    Args:
        matrix (array-like): A square matrix to be raised to a power.
        power (int): The exponent to which the matrix is to be raised.
            Must be a non-negative integer.

    Returns:
        array-like: The resulting matrix after raising the input matrix to the
        specified power.
    """
    if power == 0:
        return ops.eye(matrix.shape[0])
    if power == 1:
        return matrix
    if power % 2 == 0:
        half_power = matrix_power(matrix, power // 2)
        return ops.matmul(half_power, half_power)
    return ops.matmul(matrix, matrix_power(matrix, power - 1))

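# The even-power branch above implements exponentiation by squaring, so the
# number of matmuls grows as O(log(power)) rather than O(power).
#
# Example (illustrative):
#
#   m = ops.array([[2.0, 0.0], [0.0, 3.0]])
#   matrix_power(m, 3)  # diag(8, 27)
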
def boolean_mask(tensor, mask, size=None):
    """Apply a boolean mask to a tensor.

    Args:
        tensor (Tensor): The input tensor.
        mask (Tensor): The boolean mask to apply.
        size (int, optional): The size of the output tensor. Only used for the
            JAX backend if you want to trace the function. Defaults to None.

    Returns:
        Tensor: The masked tensor.
    """
    if keras.backend.backend() == "jax" and size is not None:
        import jax.numpy as jnp

        indices = jnp.where(mask, size=size)  # Fixed size allows JAX tracing
        return tensor[indices]
    elif keras.backend.backend() == "tensorflow":
        import tensorflow as tf

        return tf.boolean_mask(tensor, mask)
    else:
        return tensor[mask]

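# Example (hedged sketch): with the JAX backend, passing a static `size` keeps
# the output shape fixed so the call can be traced under jit; if `size` exceeds
# the number of True entries, jnp.where pads the indices (with 0 by default).
#
#   x = ops.array([1.0, 2.0, 3.0, 4.0])
#   mask = ops.array([True, False, True, False])
#   boolean_mask(x, mask)          # [1.0, 3.0]
#   boolean_mask(x, mask, size=2)  # [1.0, 3.0], trace-friendly on JAX
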
if keras.backend.backend() == "jax":
    import jax.numpy as jnp

    def nonzero(x, size=None, fill_value=None):
        """Return the indices of the elements that are non-zero.

        Args:
            x (Tensor): Input tensor.
            size (int, optional): Optional static integer specifying the number
                of nonzero entries to return. If there are more nonzero elements
                than the specified size, then indices will be truncated at the
                end. If there are fewer nonzero elements than the specified
                size, then indices will be padded with fill_value.
            fill_value (int, optional): Value to fill in case there are not
                enough non-zero elements. Defaults to None.
        """
        return jnp.nonzero(x, size=size, fill_value=fill_value)

else:

    def nonzero(x, size=None, fill_value=None):
        """Return the indices of the elements that are non-zero."""
        return ops.nonzero(x)

def flatten(tensor, start_dim=0, end_dim=-1):
    """Flattens a contiguous range of dims into a single dim.

    Should be similar to:
    https://pytorch.org/docs/stable/generated/torch.flatten.html
    """
    # Get the shape of the input tensor
    old_shape = ops.shape(tensor)

    # Adjust end_dim if it's negative
    end_dim = ops.ndim(tensor) + end_dim if end_dim < 0 else end_dim

    # Create a new shape with -1 in the flattened dimensions
    new_shape = [*old_shape[:start_dim], -1, *old_shape[end_dim + 1 :]]

    # Reshape the tensor
    return ops.reshape(tensor, new_shape)

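# Example (illustrative): flattening only the middle dimensions, mirroring
# torch.flatten(x, start_dim=1, end_dim=2).
#
#   x = ops.ones((2, 3, 4, 5))
#   flatten(x, start_dim=1, end_dim=2)  # shape (2, 12, 5)
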
def batch_cov(x, rowvar=True, bias=False, ddof=None):
    """Compute the batch covariance matrices of the input tensor.

    Args:
        x (Tensor): Input tensor of shape (..., m, n) where m is the number of
            features and n is the number of observations.
        rowvar (bool, optional): If True, each row represents a variable, while
            each column represents an observation. If False, each column
            represents a variable, while each row represents an observation.
            Defaults to True.
        bias (bool, optional): If True, the biased estimator of the covariance
            is computed. If False, the unbiased estimator is computed.
            Defaults to False.
        ddof (int, optional): Delta degrees of freedom. The divisor used in the
            calculation is (num_obs - ddof), where num_obs is the number of
            observations. If ddof is not specified, it is set to 0 if bias is
            True, and 1 if bias is False. Defaults to None.

    Returns:
        Tensor: Batch covariance matrices of shape (..., m, m) if rowvar=True,
        or (..., n, n) if rowvar=False.
    """
    # Ensure the input has at least 3 dimensions
    if ops.ndim(x) == 2:
        x = x[None]

    if not rowvar:
        x = ops.moveaxis(x, -1, -2)

    num_obs = x.shape[-1]

    if ddof is None:
        ddof = 0 if bias else 1

    # Subtract the mean from each observation
    mean_x = ops.mean(x, axis=-1, keepdims=True)
    x_centered = x - mean_x

    # Compute the covariance using einsum
    cov_matrices = ops.einsum("...ik,...jk->...ij", x_centered, x_centered) / (num_obs - ddof)
    return cov_matrices

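# Example (hedged sketch): covariance of 5 variables observed 100 times, for a
# batch of 8 such matrices. Should roughly agree with np.cov per batch element.
#
#   x = keras.random.normal((8, 5, 100))
#   c = batch_cov(x)  # shape (8, 5, 5)
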
def patched_map(f, xs, patches: int, jit=True, **batch_kwargs):
    """Wrapper around `batched_map` for patching.

    Allows you to specify the number of patches rather than the batch size.
    """
    assert patches > 0, "Number of patches must be greater than 0."
    if patches == 1:
        return f(xs, **batch_kwargs)
    else:
        length = ops.shape(xs)[0]
        batch_size = np.ceil(length / patches).astype(int)
        return batched_map(f, xs, batch_size, jit, **batch_kwargs)

def batched_map(f, xs, batch_size=None, jit=True, **batch_kwargs):
    """Map a function over leading array axes.

    Args:
        f (callable): Function to apply element-wise over the first axis.
        xs (Tensor): Values over which to map along the leading axis.
        batch_size (int, optional): Size of the batch for each step. Defaults to
            None, in which case the function will be equivalent to `ops.map`,
            and thus map over the leading axis.
        jit (bool, optional): If True, use a jitted version of the function for
            faster batched mapping. Else, loop over the data with the original
            function.
        batch_kwargs (dict, optional): Additional keyword arguments (tensors) to
            batch along with xs. Must have the same first dimension size as xs.

    Returns:
        The mapped tensor(s).

    Idea taken from:
    https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.map.html
    """
    # Ensure all batch kwargs have the same leading dimension as xs.
    if batch_kwargs:
        assert all(
            ops.shape(xs)[0] == ops.shape(v)[0] for v in batch_kwargs.values() if v is not None
        ), "All batch kwargs must have the same first dimension size as xs."

    total = ops.shape(xs)[0]

    # TODO: could be rewritten with ops.cond such that it also works for jit=True.
    if not jit and batch_size is not None and total <= batch_size:
        return f(xs, **batch_kwargs)

    ## Non-jitted version: simply iterate over batches.
    if not jit:
        bs = batch_size or 1  # Default batch size to 1 if not specified.
        outputs = []
        for i in range(0, total, bs):
            idx = slice(i, i + bs)
            current_kwargs = {k: v[idx] for k, v in batch_kwargs.items()}
            outputs.append(f(xs[idx], **current_kwargs))
        return ops.concatenate(outputs, axis=0)

    ## Jitted version.
    # Helper to create the batched function for use with ops.map.
    def create_batched_f(kw_keys):
        def batched_f(inputs):
            x, *kw_values = inputs
            kw = dict(zip(kw_keys, kw_values))
            return f(x, **kw)

        return batched_f

    if batch_size is None:
        batched_f = create_batched_f(list(batch_kwargs.keys()))
        return ops.map(batched_f, (xs, *batch_kwargs.values()))

    # Pad and reshape primary tensor.
    xs_padded = pad_array_to_divisible(xs, batch_size, axis=0)
    new_shape = (-1, batch_size) + ops.shape(xs_padded)[1:]
    xs_reshaped = ops.reshape(xs_padded, new_shape)

    # Pad and reshape batch_kwargs similarly.
    reshaped_kwargs = {}
    for k, v in batch_kwargs.items():
        if v is None:
            reshaped_kwargs[k] = None
        else:
            v_padded = pad_array_to_divisible(v, batch_size, axis=0)
            reshaped_kwargs[k] = ops.reshape(v_padded, (-1, batch_size) + ops.shape(v_padded)[1:])

    batched_f = create_batched_f(list(reshaped_kwargs.keys()))
    out = ops.map(batched_f, (xs_reshaped, *reshaped_kwargs.values()))
    out_reshaped = ops.reshape(out, (-1,) + ops.shape(out)[2:])
    return out_reshaped[:total]  # Remove any padding added.

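# Example (illustrative): mapping a function over 10 rows in chunks of 4. The
# input is padded to a multiple of the batch size, processed with ops.map, and
# the padding is stripped again, so the result matches a plain per-row map.
#
#   xs = keras.random.uniform((10, 3))
#   out = batched_map(lambda x: x * 2.0, xs, batch_size=4)
#   # ops.shape(out) == (10, 3)
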
if keras.backend.backend() == "jax":
    # For jit purposes
    def _get_padding(N, remainder):
        return N - remainder if remainder != 0 else 0

else:

    def _get_padding(N, remainder):
        return ops.where(remainder != 0, N - remainder, 0)

def pad_array_to_divisible(arr, N, axis=0, mode="constant", pad_value=None):
    """Pad an array to be divisible by N along the specified axis.

    Args:
        arr (Tensor): The input array to pad.
        N (int): The number to which the length of the specified axis should be
            divisible.
        axis (int, optional): The axis along which to pad the array.
            Defaults to 0.
        mode (str, optional): The padding mode to use. One of `"constant"`,
            `"edge"`, `"linear_ramp"`, `"maximum"`, `"mean"`, `"median"`,
            `"minimum"`, `"reflect"`, `"symmetric"`, `"wrap"`, `"empty"`,
            `"circular"`. Defaults to `"constant"`.
        pad_value (float, optional): The value to use for padding when
            mode='constant'. Defaults to None. If mode is not `constant`, this
            value should be None.

    Returns:
        Tensor: The padded array.
    """
    # Get the length of the specified axis
    length = ops.shape(arr)[axis]

    # Calculate how much padding is needed for the specified axis
    remainder = length % N
    padding = _get_padding(N, remainder)

    # Create a tuple with (before, after) padding for each axis
    pad_width = [(0, 0)] * ops.ndim(arr)  # No padding for other axes
    pad_width[axis] = (0, padding)  # Padding for the specified axis

    # Pad the array
    padded_array = ops.pad(arr, pad_width, mode=mode, constant_values=pad_value)

    return padded_array

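# Example (illustrative): padding a length-10 axis up to the next multiple of 4.
#
#   x = ops.ones((10, 3))
#   pad_array_to_divisible(x, 4, axis=0, pad_value=0.0)  # shape (12, 3)
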
def interpolate_data(subsampled_data, mask, order=1, axis=-1):
    """Interpolate subsampled data along a specified axis using `map_coordinates`.

    Args:
        subsampled_data (ndarray): The data subsampled along the specified axis.
            Its shape matches `mask` except along the subsampled axis.
        mask (ndarray): Boolean array with the same shape as the full data.
            `True` where data is known.
        order (int, optional): The order of the spline interpolation.
            Default is `1`.
        axis (int, optional): The axis along which the data is subsampled.
            Default is `-1`.

    Returns:
        ndarray: The data interpolated back to the original grid.

    Raises:
        ValueError: If `mask` does not indicate any missing data or if `mask`
            has `False` values along multiple axes.
    """
    mask = ops.cast(mask, "bool")

    # Check that mask indicates subsampled data along the specified axis
    if ops.sum(mask) == 0:
        raise ValueError("Mask does not indicate any known data.")
    if ops.sum(mask) == ops.prod(mask.shape):
        raise ValueError("Mask does not indicate any missing data.")

    # Make sure subsampled data corresponds with the number of 1s in the mask
    assert len(ops.where(mask)[0]) == ops.prod(subsampled_data.shape), (
        "Subsampled data does not match the number of 1s in the mask."
    )
    assert subsampled_data.ndim == 1, "Subsampled data should be a flattened 1D array"
    assert mask.ndim == 2, "Currently only 2D interpolation supported"

    # Get the indices of the known data points
    known_indices = ops.stack(ops.where(mask), axis=-1)
    # Get the indices of the unknown data points
    unknown_indices = ops.stack(ops.where(~mask), axis=-1)

    # Map the unknown indices to the new coordinate system, which basically is
    # range(0, mask.shape[axis]) for each axis, but with the gaps removed.
    interp_coords = []
    subsampled_shape = []
    axis = axis if axis >= 0 else mask.ndim + axis
    for _axis in range(mask.ndim):
        length_axis = mask.shape[_axis]
        if _axis == axis:
            indices = ops.where(
                ops.any(~mask, axis=tuple(i for i in range(mask.ndim) if i != _axis))
            )[0]  # unknown indices
            indices = map_indices_for_interpolation(indices)
            subsampled_shape.append(length_axis - len(indices))
        else:
            # broadcast indices
            indices = ops.arange(length_axis, dtype="float32")
            subsampled_shape.append(length_axis)
        interp_coords.append(indices)

    # Create the grid of coordinates for the interpolation
    interp_coords = ops.meshgrid(*interp_coords, indexing="ij")
    # should be of shape (mask.ndim, -1)

    subsampled_data = ops.reshape(subsampled_data, subsampled_shape)

    # Use map_coordinates to interpolate the data
    interpolated_data = ops.image.map_coordinates(
        subsampled_data,
        interp_coords,
        order=order,
    )
    interpolated_data = ops.reshape(interpolated_data, -1)

    # Now distribute the interpolated data back to the original grid
    output_data = ops.zeros_like(mask, dtype=subsampled_data.dtype)
    output_data = ops.scatter_update(output_data, unknown_indices, interpolated_data)

    # Get the values at the known data points
    known_values = ops.reshape(subsampled_data, (-1,))
    output_data = ops.scatter_update(
        output_data,
        known_indices,
        known_values,
    )
    return output_data

def is_monotonic(array, increasing=True):
    """Checks if a given 1D array is monotonic.

    Either entirely non-decreasing or non-increasing.

    Args:
        array (ndarray): A 1D array.
        increasing (bool, optional): If True, check that the array is
            non-decreasing; if False, check that it is non-increasing.
            Defaults to True.

    Returns:
        bool: True if the array is monotonic, False otherwise.
    """
    # Convert to a tensor to handle general cases
    array = ops.array(array)

    # Check if the array is non-decreasing or non-increasing
    if increasing:
        return ops.all(array[1:] >= array[:-1])
    return ops.all(array[1:] <= array[:-1])

def map_indices_for_interpolation(indices):
    """Interpolates a 1D array of indices with gaps.

    Maps a 1D array of indices with gaps to a 1D array where gaps would be
    between the integers. Used in the `interpolate_data` function.

    Args:
        indices: A 1D array of indices with gaps.

    Returns:
        indices mapped to a 1D array where gaps fall between the integers.

    Example:
        >>> # Two contiguous segments of length 4 ([5, 6, 7, 8]) and 2
        >>> # ([12, 13]) precede the final index 19.
        >>> indices = [5, 6, 7, 8, 12, 13, 19]
        >>> mapped_indices = [5, 5.25, 5.5, 5.75, 8, 8.5, 12.5]
    """
    indices = ops.array(indices, dtype="int32")
    assert is_monotonic(indices, increasing=True), "Indices should be monotonically increasing"

    gap_starts = ops.where(indices[1:] - indices[:-1] > 1)[0]
    gap_starts = ops.concatenate([ops.array([0]), gap_starts + 1], axis=0)
    gap_lengths = ops.concatenate(
        [gap_starts[1:] - gap_starts[:-1], ops.array([len(indices) - gap_starts[-1]])],
        axis=0,
    )
    cumul_gap_lengths = ops.cumsum(gap_lengths)
    cumul_gap_lengths = ops.concatenate([ops.array([0]), cumul_gap_lengths], axis=0)

    gap_start_values = ops.take(indices, gap_starts)
    mapped_starts = gap_start_values - cumul_gap_lengths[:-1]
    mapped_starts = ops.cast(mapped_starts, "float32")

    gap_lengths = ops.cast(gap_lengths, "float32")
    spacing = 1 / (gap_lengths + 1)

    # Vectorized creation of gap_length entries between the start and end
    mapped_indices = ops.concatenate(
        [
            (mapped_starts[i] + spacing[i]) + spacing[i] * ops.arange(gap_lengths[i])
            for i in range(len(gap_lengths))
        ],
        axis=0,
    )
    mapped_indices -= 1
    return mapped_indices

def stack_volume_data_along_axis(data, batch_axis: int, stack_axis: int, number: int):
    """Stacks tensor data along a specified stack axis.

    Stack tensor data along a specified stack axis by splitting it into blocks
    along the batch axis.

    Args:
        data (Tensor): Input tensor to be stacked.
        batch_axis (int): Axis along which to split the data into blocks.
        stack_axis (int): Axis along which to stack the blocks.
        number (int): Number of slices per stack.

    Returns:
        Tensor: Reshaped tensor with data stacked along stack_axis.

    Example:
        .. code-block:: python

            import keras

            data = keras.random.uniform((10, 20, 30))
            # stacking along 1st axis with 2 frames per block
            stacked_data = stack_volume_data_along_axis(data, 0, 1, 2)
            stacked_data.shape
    """
    blocks = int(ops.ceil(data.shape[batch_axis] / number))
    data = pad_array_to_divisible(data, axis=batch_axis, N=blocks, mode="reflect")
    data = ops.split(data, blocks, axis=batch_axis)
    data = ops.stack(data, axis=batch_axis)

    # put batch_axis in front
    data = ops.transpose(
        data,
        (
            batch_axis + 1,
            *range(batch_axis + 1),
            *range(batch_axis + 2, data.ndim),
        ),
    )
    data = ops.concatenate(list(data), axis=stack_axis)
    return data

def split_volume_data_from_axis(data, batch_axis: int, stack_axis: int, number: int, padding: int):
    """Splits previously stacked tensor data back to its original shape.

    This function reverses the operation performed by
    `stack_volume_data_along_axis`.

    Args:
        data (Tensor): Input tensor to be split.
        batch_axis (int): Axis along which to restore the blocks.
        stack_axis (int): Axis from which to split the stacked data.
        number (int): Number of slices per stack.
        padding (int): Amount of padding to remove from the result.

    Returns:
        Tensor: Reshaped tensor with data split back to original format.

    Example:
        .. code-block:: python

            import keras

            data = keras.random.uniform((20, 10, 30))
            split_data = split_volume_data_from_axis(data, 0, 1, 2, 2)
            split_data.shape
    """
    if data.shape[stack_axis] == 1:
        # in this case it was a broadcasted axis which does not need to be split
        return data

    data = ops.split(data, number, axis=stack_axis)
    data = ops.stack(data, axis=batch_axis + 1)

    # combine the unstacked axes (dim 1 and 2)
    total_block_size = data.shape[batch_axis] * data.shape[batch_axis + 1]
    data = ops.reshape(
        data,
        (*data.shape[:batch_axis], total_block_size, *data.shape[batch_axis + 2 :]),
    )

    # cut off padding
    if padding > 0:
        indices = ops.arange(data.shape[batch_axis] - padding + 1)
        data = ops.take(data, indices, axis=batch_axis)
    return data

def compute_required_patch_overlap(image_shape, patch_shape):
    """Compute required overlap between patches to cover the entire image.

    Args:
        image_shape: Tuple of (height, width)
        patch_shape: Tuple of (patch_height, patch_width)

    Returns:
        Tuple of (overlap_y, overlap_x), or None if there is no overlap that
        will result in an integer number of patches given the image and patch
        shapes.
    """
    assert len(image_shape) == 2, "image_shape must be a tuple of (height, width)"
    assert len(patch_shape) == 2, "patch_shape must be a tuple of (patch_height, patch_width)"
    assert all(image_shape[i] >= patch_shape[i] for i in range(2)), (
        "patch_shape must be equal or smaller than image_shape"
    )

    image_y, image_x = image_shape
    patch_y, patch_x = patch_shape

    # Calculate number of patches needed in each dimension
    n_patch_y = max(1, int(ops.ceil(image_y / patch_y)))
    n_patch_x = max(1, int(ops.ceil(image_x / patch_x)))

    # Calculate new overlap only if we have more than one patch
    new_overlap = (
        ((patch_y * n_patch_y - image_y) / (n_patch_y - 1) if n_patch_y > 1 else 0),
        ((patch_x * n_patch_x - image_x) / (n_patch_x - 1) if n_patch_x > 1 else 0),
    )

    # Check whether the overlap can be an integer
    if not all(ops.isclose(new_overlap, ops.round(new_overlap))):
        return None

    new_overlap = tuple(map(int, new_overlap))
    return new_overlap

def compute_required_patch_shape(image_shape, patch_shape, overlap):
    """Compute required patch shape to cover the entire image.

    Compute the patch_shape closest to the original patch_shape that will
    result in an integer number of patches given the image and overlap.

    Args:
        image_shape: Tuple of (height, width)
        patch_shape: Tuple of (patch_height, patch_width)
        overlap: Tuple of (overlap_y, overlap_x)

    Returns:
        Tuple of (patch_shape_y, patch_shape_x), or None if there is no
        patch_shape that will result in an integer number of patches given the
        image and overlap.
    """
    image_y, image_x = image_shape
    overlap_y, overlap_x = overlap
    patch_y, patch_x = patch_shape

    def compute_patch_size(image_size, patch_size, overlap):
        n_patches = (image_size - overlap) // (patch_size - overlap)
        new_patch_size = (image_size + (n_patches - 1) * overlap) / n_patches
        return int(new_patch_size)

    new_patch_y = compute_patch_size(image_y, patch_y, overlap_y)
    new_patch_x = compute_patch_size(image_x, patch_x, overlap_x)

    if (image_y - new_patch_y) % (new_patch_y - overlap_y) != 0 or (image_x - new_patch_x) % (
        new_patch_x - overlap_x
    ) != 0:
        return None

    return new_patch_y, new_patch_x

def check_patches_fit(
    image_shape: tuple, patch_shape: tuple, overlap: Union[int, Tuple[int, int]]
) -> tuple:
    """Checks if patches with overlap fit an integer amount in the original image.

    Args:
        image_shape: A tuple representing the shape of the original image.
        patch_shape: A tuple representing the shape of the patches.
        overlap: An int or tuple representing the overlap between patches.

    Returns:
        A tuple containing a boolean indicating if the patches fit an integer
        amount in the original image and the new image shape if the patches
        do not fit.

    Example:
        .. code-block:: python

            image_shape = (10, 10)
            patch_shape = (4, 4)
            overlap = (2, 2)
            patches_fit, new_shape = check_patches_fit(image_shape, patch_shape, overlap)
            patches_fit
            new_shape
    """
    if overlap:
        stride = (np.array(patch_shape) - np.array(overlap)).astype(int)
    else:
        stride = (np.array(patch_shape)).astype(int)
        overlap = (0, 0)

    stride_y, stride_x = stride
    patch_y, patch_x = patch_shape
    image_y, image_x = image_shape

    if (image_y - patch_y) % stride_y != 0 or (image_x - patch_x) % stride_x != 0:
        new_shape = (
            (image_y - patch_y) // stride_y * stride_y + patch_y,
            (image_x - patch_x) // stride_x * stride_x + patch_x,
        )
        new_patch_shape = compute_required_patch_shape(image_shape, patch_shape, overlap)
        # Calculate new overlap only if we have more than one patch
        new_overlap = compute_required_patch_overlap(image_shape, patch_shape)
        msg = (
            "patches with overlap do not fit an integer amount in the original image. "
            f"Cropping image to closest dimensions that work: {new_shape}. "
        )
        if new_patch_shape is not None:
            msg += f"Alternatively, change patch shape to: {new_patch_shape} "
        if new_overlap is not None:
            msg += f"or change overlap to: {new_overlap}"
        log.warning(msg)
        return False, new_shape

    return True, image_shape

def images_to_patches(
    images: keras.KerasTensor,
    patch_shape: Union[int, Tuple[int, int]],
    overlap: Union[int, Tuple[int, int]] = None,
) -> keras.KerasTensor:
    """Creates patches from images.

    Args:
        images (Tensor): input images [batch, height, width, channels].
        patch_shape (int or tuple): Height and width of patch.
        overlap (int or tuple, optional): Overlap between patches in px.
            Defaults to None.

    Returns:
        patches (Tensor): batch of patches of size:
            [batch, #patch_y, #patch_x, patch_size_y, patch_size_x, #channels].

    Example:
        .. code-block:: python

            import keras

            images = keras.random.uniform((2, 8, 8, 3))
            patches = images_to_patches(images, patch_shape=(4, 4), overlap=(2, 2))
            patches.shape
    """
    assert len(images.shape) == 4, (
        f"input array should have 4 dimensions, but has {len(images.shape)} dimensions"
    )
    assert isinstance(patch_shape, int) or len(patch_shape) == 2, (
        f"patch_shape should be an integer or a tuple of length 2, but is {patch_shape}"
    )
    assert isinstance(overlap, (int, type(None))) or len(overlap) == 2, (
        f"overlap should be an integer or a tuple of length 2, but is {overlap}"
    )

    batch_size, *image_shape, n_channels = images.shape

    if isinstance(patch_shape, int):
        patch_shape = (patch_shape, patch_shape)
    if isinstance(overlap, int):
        overlap = (overlap, overlap)

    patch_size_y, patch_size_x = patch_shape

    patches_fit, image_shape = check_patches_fit(image_shape, patch_shape, overlap)
    if not patches_fit:
        images = images[:, : image_shape[0], : image_shape[1], :]

    if overlap:
        stride = (np.array(patch_shape) - np.array(overlap)).astype(int)
    else:
        stride = np.array(patch_shape).astype(int)

    # Clamp stride so it is never smaller than 1 or larger than patch_shape
    stride = np.maximum(stride, 1)
    stride = np.minimum(stride, patch_shape)
    assert np.all(stride <= patch_shape), "Stride should not exceed patch shape"
    assert np.all(stride >= 0), "Stride should be non-negative"

    ## Creating patches with ops.image.extract_patches is too memory intensive:
    # patches = ops.image.extract_patches(
    #     images, size=patch_shape, strides=list(stride), padding="valid"
    # )
    ## so we use a manual solution instead.
    patches_list = []
    for i in range(0, image_shape[0] - patch_size_y + 1, stride[0]):
        row_patches = []
        for j in range(0, image_shape[1] - patch_size_x + 1, stride[1]):
            patch = images[:, i : i + patch_size_y, j : j + patch_size_x, :]
            row_patches.append(patch)
        patches_list.append(ops.stack(row_patches, axis=1))
    patches = ops.stack(patches_list, axis=1)

    _, n_patch_y, n_patch_x, *_ = patches.shape
    shape = [batch_size, n_patch_y, n_patch_x, patch_size_y, patch_size_x, n_channels]
    patches = ops.reshape(patches, shape)
    return patches

def patches_to_images(
    patches: keras.KerasTensor,
    image_shape: tuple,
    overlap: Union[int, Tuple[int, int]] = None,
    window_type="average",
) -> keras.KerasTensor:
    """Reconstructs images from patches.

    Args:
        patches (Tensor): Array with batch of patches to convert to batch of images.
            [batch_size, #patch_y, #patch_x, patch_size_y, patch_size_x, n_channels]
        image_shape (Tuple): Shape of output image. (height, width, channels)
        overlap (int or tuple, optional): Overlap between patches in px.
            Defaults to None.
        window_type (str, optional): Type of stitching to use. Defaults to
            'average'. Options: 'average', 'replace'.

    Returns:
        images (Tensor): Reconstructed batch of images from batch of patches.

    Example:
        .. code-block:: python

            import keras

            patches = keras.random.uniform((2, 3, 3, 4, 4, 3))
            images = patches_to_images(patches, image_shape=(8, 8, 3), overlap=(2, 2))
            images.shape
    """
    # Input validation
    assert len(image_shape) == 3, "image_shape must have 3 dimensions: (height, width, channels)."
    assert len(patches.shape) == 6, (
        "patches must have 6 dimensions: [batch_size, n_patch_y, n_patch_x, "
        "patch_size_y, patch_size_x, n_channels]."
    )
    assert window_type in [
        "average",
        "replace",
    ], "window_type must be one of 'average' or 'replace'."

    # Extract dimensions
    batch_size, n_patches_y, n_patches_x, patch_size_y, patch_size_x, _ = patches.shape
    dtype = patches.dtype

    if isinstance(overlap, int):
        overlap = (overlap, overlap)
    if overlap is None:
        overlap = (0, 0)

    stride_y, stride_x = np.array([patch_size_y, patch_size_x]) - np.array(overlap)

    # Initialize the output tensor (image) and mask
    images = ops.zeros((batch_size, *image_shape), dtype=dtype)
    mask = ops.zeros((batch_size, *image_shape), dtype=dtype)

    # Loop through each patch
    for i in range(n_patches_y):
        for j in range(n_patches_x):
            start_y = i * stride_y
            start_x = j * stride_x
            patch = patches[:, i, j]

            if window_type == "replace":
                # Replace pixels directly with the current patch
                images = ops.slice_update(images, [0, start_y, start_x, 0], patch)
            else:
                # Add the current patch to the image
                images = ops.slice_update(
                    images,
                    [0, start_y, start_x, 0],
                    images[
                        :,
                        start_y : start_y + patch_size_y,
                        start_x : start_x + patch_size_x,
                        :,
                    ]
                    + patch,
                )
                # Update the mask for averaging
                mask = ops.slice_update(
                    mask,
                    [0, start_y, start_x, 0],
                    mask[
                        :,
                        start_y : start_y + patch_size_y,
                        start_x : start_x + patch_size_x,
                        :,
                    ]
                    + 1,
                )

    if window_type == "average":
        # Normalize overlapping regions if needed
        images = ops.where(mask > 0, images / mask, images)

    return images

def reshape_axis(data, newshape: tuple, axis: int):
    """Reshape data along axis.

    Args:
        data (tensor): input data.
        newshape (tuple): new shape of data along axis.
        axis (int): axis to reshape.

    Example:
        .. code-block:: python

            import keras

            data = keras.random.uniform((3, 4, 5))
            newshape = (2, 2)
            reshaped_data = reshape_axis(data, newshape, axis=1)
            reshaped_data.shape
    """
    axis = map_negative_indices([axis], data.ndim)[0]
    shape = list(ops.shape(data))
    shape = shape[:axis] + list(newshape) + shape[axis + 1 :]
    return ops.reshape(data, shape)

def _gaussian_filter1d(array, kernel, radius, cval=None, axis=-1, mode="symmetric"):
    if keras.backend.backend() == "torch":
        assert mode == "constant", (
            "Only constant padding is for sure correct in torch. "
            "Symmetric padding produces different results in torch compared to tensorflow..."
        )

    # Pad input along the specified axis.
    pad_width = [(0, 0)] * array.ndim
    pad_width[axis] = (radius, radius)
    padded = ops.pad(array, pad_width, mode=mode, constant_values=cval)

    # Move the convolution axis to the last axis.
    moved = ops.moveaxis(padded, axis, -1)  # shape: (..., length)
    orig_shape = moved.shape
    length = orig_shape[-1]

    # Collapse all non-convolution dimensions into the batch.
    reshaped = ops.reshape(moved, (-1, length, 1))  # shape: (batch, length, in_channels=1)

    # Reshape kernel for convolution: expected shape (kernel_size, in_channels, out_channels)
    kernel_size = kernel.shape[0]
    kernel_reshaped = ops.reshape(kernel, (kernel_size, 1, 1))

    # Run the convolution using 'VALID' padding.
    conv_result = ops.depthwise_conv(
        reshaped,
        kernel_reshaped,
        padding="valid",
        data_format="channels_last",
    )

    # Reshape the convolved result back to the padded shape.
    new_length = conv_result.shape[1]
    conv_result = ops.reshape(conv_result, (*orig_shape[:-1], new_length))

    # Move the convolution axis back to its original position.
    result = ops.moveaxis(conv_result, -1, axis)
    return result

def gaussian_filter1d(array, sigma, axis=-1, order=0, mode="symmetric", truncate=4.0, cval=None):
    """1-D Gaussian filter.

    Args:
        array (Tensor): The input array.
        sigma (float): Standard deviation for the Gaussian kernel.
        axis (int, optional): The axis along which to filter. Defaults to -1.
        order (int): An order of 0 corresponds to convolution with a Gaussian
            kernel. A positive order corresponds to convolution with that
            derivative of a Gaussian. Default is 0.
        mode (str, optional): Padding mode for the input image. Default is
            'symmetric'. See [keras docs](https://www.tensorflow.org/api_docs/python/tf/keras/ops/pad)
            for all options and [tensorflow docs](https://www.tensorflow.org/api_docs/python/tf/pad)
            for some examples. Note that the naming differs from
            scipy.ndimage.gaussian_filter1d!
        truncate (float, optional): Truncate the filter at this many standard
            deviations. Default is 4.0.
        cval (float, optional): Value to fill past edges of input if mode is
            'constant'. Default is None.
    """
    # Determine the effective kernel radius and generate the Gaussian kernel
    radius = int(round(truncate * sigma))
    kernel = _gaussian_kernel1d(sigma, order, radius).astype(
        ops.dtype(array)
    )  # shape: (kernel_size,)

    # Reverse the kernel for odd orders to mimic correlation (SciPy behavior)
    if order % 2:
        kernel = kernel[::-1]

    return _gaussian_filter1d(array, kernel, radius, cval, axis, mode)

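# Example (hedged usage sketch): smoothing a noisy 1-D signal. With sigma=2 and
# the default truncate=4.0, the kernel radius is 8 samples.
#
#   x = keras.random.normal((128,))
#   smooth = gaussian_filter1d(x, sigma=2.0)
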
def gaussian_filter(
    array,
    sigma,
    order: int | Tuple[int] = 0,
    mode: str = "symmetric",
    cval: float | None = None,
    truncate: float = 4.0,
    axes: Tuple[int] = None,
):
    """Multidimensional Gaussian filter.

    If you want to use this function with jax.jit, you can set:
    `static_argnames=("truncate", "sigma")`

    Args:
        array (Tensor): The input array.
        sigma (float or tuple): Standard deviation for Gaussian kernel. The
            standard deviations of the Gaussian filter are given for each axis
            as a sequence, or as a single number, in which case it is equal for
            all axes.
        order (int or Tuple[int]): The order of the filter along each axis is
            given as a sequence of integers, or as a single number. An order of
            0 corresponds to convolution with a Gaussian kernel. A positive
            order corresponds to convolution with that derivative of a
            Gaussian. Default is 0.
        mode (str, optional): Padding mode for the input image. Default is
            'symmetric'. See [keras docs](https://www.tensorflow.org/api_docs/python/tf/keras/ops/pad)
            for all options and [tensorflow docs](https://www.tensorflow.org/api_docs/python/tf/pad)
            for some examples. Note that the naming differs from
            scipy.ndimage.gaussian_filter!
        cval (float, optional): Value to fill past edges of input if mode is
            'constant'. Default is None.
        truncate (float, optional): Truncate the filter at this many standard
            deviations. Default is 4.0.
        axes (Tuple[int], optional): If None, input is filtered along all axes.
            Otherwise, input is filtered along the specified axes. When axes is
            specified, any tuples used for sigma, order, mode and/or radius must
            match the length of axes. The ith entry in any of these tuples
            corresponds to the ith entry in axes.
    """
    axes = _ni_support._check_axes(axes, array.ndim)
    num_axes = len(axes)
    orders = _ni_support._normalize_sequence(order, num_axes)
    sigmas = _ni_support._normalize_sequence(sigma, num_axes)
    modes = _ni_support._normalize_sequence(mode, num_axes)

    axes = [(axes[ii], sigmas[ii], orders[ii], modes[ii]) for ii in range(num_axes)]
    if len(axes) > 0:
        # Apply the 1-D filter successively along each requested axis.
        for axis, sigma, order, mode in axes:
            output = gaussian_filter1d(array, sigma, axis, order, mode, truncate, cval)
            array = output
    else:
        output = array
    return output

def resample(x, n_samples, axis=-2, order=1):
    """Resample tensor along axis. Similar to scipy.signal.resample.

    Args:
        x: input tensor.
        n_samples: number of samples after resampling.
        axis: axis to resample along.
        order: interpolation order (1=linear).

    Returns:
        Resampled tensor.
    """
    shape = ops.shape(x)
    rank = len(shape)

    # Move axis-to-resample to position 1
    perm = list(range(rank))
    perm_axis1 = perm.copy()
    perm_axis1[axis], perm_axis1[1] = perm_axis1[1], perm_axis1[axis]
    x = ops.transpose(x, perm_axis1)

    # Shape after transpose
    shape = ops.shape(x)
    batch_size = shape[0]
    old_n = shape[1]
    other_dims = shape[2:]

    # Create sampling grid
    batch_coords = ops.arange(batch_size, dtype="float32")  # (batch_size,)
    new_coords = ops.linspace(0.0, ops.cast(old_n - 1, dtype="float32"), n_samples)  # (n_samples,)
    other_coords = [ops.arange(d, dtype="float32") for d in other_dims]

    # Meshgrid
    grid = ops.meshgrid(
        batch_coords, new_coords, *other_coords, indexing="ij"
    )  # list of (batch_size, n_samples, ...)
    coord_grid = ops.stack(grid, axis=0)  # shape: (rank, batch_size, n_samples, ...)

    # Interpolate
    resampled = ops.image.map_coordinates(x, coord_grid, order=order)

    # Inverse transpose to restore original axis order
    inv_perm = [perm_axis1.index(i) for i in range(rank)]
    resampled = ops.transpose(resampled, inv_perm)
    return resampled

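# Example (illustrative): resampling 100 samples to 150 along axis -2, e.g. for
# time-axis interpolation of (batch, time, channels) data.
#
#   x = keras.random.uniform((2, 100, 4))
#   y = resample(x, n_samples=150, axis=-2)  # shape (2, 150, 4)
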
def fori_loop(lower, upper, body_fun, init_val, disable_jit=False):
    """For loop allowing for a non-jitted for loop with the same signature as jax.

    Args:
        lower (int): Lower bound of the loop.
        upper (int): Upper bound of the loop.
        body_fun (function): Function to be executed in the loop.
        init_val (any): Initial value for the loop.
        disable_jit (bool, optional): If True, disables JIT compilation.
            Defaults to False.
    """
    if not disable_jit:
        return ops.fori_loop(lower, upper, body_fun, init_val)

    # Fallback to a non-jitted for loop
    val = init_val
    for i in range(lower, upper):
        val = body_fun(i, val)
    return val

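# Example (illustrative): summing 0..4 with the same signature as
# jax.lax.fori_loop; disable_jit=True runs a plain Python loop instead, which
# is easier to debug.
#
#   total = fori_loop(0, 5, lambda i, acc: acc + i, 0)  # 10
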
def L2(x):
    """L2 norm of a tensor.

    Implementation of L2 norm: https://mathworld.wolfram.com/L2-Norm.html
    """
    return ops.sqrt(ops.sum(x**2))

def linear_sum_assignment(cost):
    """Greedy linear sum assignment.

    Args:
        cost (Tensor): Cost matrix of shape (n, n).

    Returns:
        Tuple: Row indices and column indices for assignment.
    """
    n = ops.shape(cost)[0]
    assigned_true = ops.zeros((n,), dtype="bool")
    row_ind = []
    col_ind = []
    for i in range(n):
        mask = 1.0 - ops.cast(assigned_true, "float32")
        masked_cost = cost[i] + (1.0 - mask) * 1e6
        idx = int(ops.argmin(masked_cost))
        row_ind.append(i)
        col_ind.append(idx)
        assigned_true = ops.scatter_update(assigned_true, [[idx]], [True])
    return np.array(row_ind), np.array(col_ind)

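# Example (illustrative): the greedy strategy picks the cheapest free column
# per row in order, so unlike scipy.optimize.linear_sum_assignment it is not
# guaranteed to be globally optimal.
#
#   cost = ops.array([[1.0, 2.0], [2.0, 1.0]])
#   rows, cols = linear_sum_assignment(cost)  # rows=[0, 1], cols=[0, 1]
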
def sinc(x, eps=keras.config.epsilon()):
    """Unnormalized sinc function, sin(x) / x, with a small epsilon added to
    avoid division by zero at x = 0."""
    return ops.sin(x + eps) / (x + eps)

if keras.backend.backend() == "tensorflow":

    def safe_vectorize(
        pyfunc,
        excluded=None,
        signature=None,
    ):
        """Just a wrapper around ops.vectorize.

        Because tensorflow does not support multiple arguments to
        ops.vectorize(func)(...), we just map the function manually.
        """

        def _map(*args):
            outputs = []
            for i in range(ops.shape(args[0])[0]):
                outputs.append(pyfunc(*[arg[i] for arg in args]))
            return ops.stack(outputs)

        return _map

else:

    def safe_vectorize(pyfunc, excluded=None, signature=None):
        """Just a wrapper around ops.vectorize."""
        return ops.vectorize(pyfunc, excluded=excluded, signature=signature)

def apply_along_axis(func1d, axis, arr, *args, **kwargs):
    """Apply a function to 1D array slices along an axis.

    Keras implementation of numpy.apply_along_axis using keras.ops.vectorized_map.

    Args:
        func1d: A callable function with signature ``func1d(arr, /, *args, **kwargs)``
            where ``*args`` and ``**kwargs`` are the additional positional and
            keyword arguments passed to apply_along_axis.
        axis: Integer axis along which to apply the function.
        arr: The array over which to apply the function.
        *args: Additional positional arguments passed through to func1d.
        **kwargs: Additional keyword arguments passed through to func1d.

    Returns:
        The result of func1d applied along the specified axis.
    """
    # Convert to keras tensor
    arr = ops.convert_to_tensor(arr)

    # Get array dimensions
    num_dims = len(arr.shape)

    # Canonicalize axis (handle negative indices)
    if axis < 0:
        axis = num_dims + axis
    if axis < 0 or axis >= num_dims:
        raise ValueError(f"axis {axis} is out of bounds for array of dimension {num_dims}")

    # Create a wrapper function that applies func1d with the additional arguments
    def func(slice_arr):
        return func1d(slice_arr, *args, **kwargs)

    # Recursively build up vectorized maps following the JAX pattern
    # For dimensions after the target axis (right side)
    for i in range(1, num_dims - axis):
        prev_func = func

        def make_func(f, dim_offset):
            def vectorized_func(x):
                # Move the dimension we want to map over to the front
                perm = list(range(len(x.shape)))
                perm[0], perm[dim_offset] = perm[dim_offset], perm[0]
                x_moved = ops.transpose(x, perm)
                result = ops.vectorized_map(f, x_moved)
                # Move the result dimension back if needed
                if len(result.shape) > 0:
                    result_perm = list(range(len(result.shape)))
                    if len(result_perm) > dim_offset:
                        result_perm[0], result_perm[dim_offset] = (
                            result_perm[dim_offset],
                            result_perm[0],
                        )
                    result = ops.transpose(result, result_perm)
                return result

            return vectorized_func

        func = make_func(prev_func, i)

    # For dimensions before the target axis (left side)
    for i in range(axis):
        prev_func = func

        def make_func(f):
            return lambda x: ops.vectorized_map(f, x)

        func = make_func(prev_func)

    return func(arr)

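# Example (hedged sketch): applying a 1-D reduction along a chosen axis, in the
# spirit of numpy.apply_along_axis. Each length-4 slice along axis 1 is reduced
# to a scalar, so the result drops that axis.
#
#   x = keras.random.uniform((3, 4, 5))
#   out = apply_along_axis(ops.sum, axis=1, arr=x)  # shape (3, 5)
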
def correlate(x, y, mode="full"):
    """Complex correlation via splitting real and imaginary parts.

    Equivalent to np.correlate(x, y, mode).

    NOTE: this function exists because tensorflow does not support complex
    correlation.
    NOTE: tensorflow also handles padding differently than numpy, so we
    manually pad the input.

    Args:
        x: np.ndarray (complex or real)
        y: np.ndarray (complex or real)
        mode: "full", "valid", or "same"
    """
    x = ops.convert_to_tensor(x)
    y = ops.convert_to_tensor(y)

    is_complex = "complex" in ops.dtype(x) or "complex" in ops.dtype(y)

    # Split into real and imaginary
    xr, xi = ops.real(x), ops.imag(x)
    yr, yi = ops.real(y), ops.imag(y)

    # Pad to do full correlation
    pad_left = ops.shape(y)[0] - 1
    pad_right = ops.shape(y)[0] - 1
    xr = ops.pad(xr, [[pad_left, pad_right]])
    xi = ops.pad(xi, [[pad_left, pad_right]])

    # Correlation: sum over x[n] * conj(y[n+k])
    rr = ops.correlate(xr, yr, mode="valid")
    ii = ops.correlate(xi, yi, mode="valid")
    ri = ops.correlate(xr, yi, mode="valid")
    ir = ops.correlate(xi, yr, mode="valid")

    real_part = rr + ii
    imag_part = ir - ri

    real_part = ops.cast(real_part, "complex64")
    imag_part = ops.cast(imag_part, "complex64")
    complex_tensor = real_part + 1j * imag_part

    # Extract relevant part based on mode
    full_length = ops.shape(real_part)[0]
    x_len = ops.shape(x)[0]
    y_len = ops.shape(y)[0]

    if mode == "same":
        # Return output of length max(M, N)
        target_len = ops.maximum(x_len, y_len)
        start = ops.floor((full_length - target_len) / 2)
        start = ops.cast(start, "int32")
        end = start + target_len
        complex_tensor = complex_tensor[start:end]
    elif mode == "valid":
        # Return output of length max(M, N) - min(M, N) + 1
        target_len = ops.maximum(x_len, y_len) - ops.minimum(x_len, y_len) + 1
        start = ops.ceil((full_length - target_len) / 2)
        start = ops.cast(start, "int32")
        end = start + target_len
        complex_tensor = complex_tensor[start:end]
    # For "full" mode, use the entire result (no slicing needed)

    if is_complex:
        return complex_tensor
    else:
        return ops.real(complex_tensor)

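# Example (illustrative): full cross-correlation of two complex sequences,
# intended to match np.correlate(x, y, "full") for these inputs.
#
#   x = ops.array([1 + 1j, 2 + 0j], dtype="complex64")
#   y = ops.array([1 - 1j], dtype="complex64")
#   correlate(x, y, mode="full")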