Source code for zea.data.augmentations

"""Augmentation layers for ultrasound data."""

import keras
import numpy as np
from keras import layers, ops

from zea.tensor_ops import is_jax_prng_key, split_seed


class RandomCircleInclusion(layers.Layer):
    """Adds a circular inclusion to the image, optionally at random locations.

    Since this can accept N-dimensional inputs, you'll need to specify your
    ``circle_axes`` -- these are the axes onto which a circle will be drawn.
    This circle will then be broadcast along the remaining dimensions.

    You can then optionally specify whether there is a batch dim, and whether
    the circles should be located randomly across that batch.

    For example, if you have a batch of videos, e.g. of shape
    [batch, frame, height, width], then you might want to specify
    ``circle_axes=(2, 3)``, and ``randomize_location_across_batch=True``.
    This would result in a circle that is located in the same place per video,
    but different locations for different videos.

    Once your method has recovered the circles, you can evaluate them using
    the ``evaluate_recovered_circle_accuracy()`` method, which will expect an
    input shape matching your inputs to ``call()``.
    """

    def __init__(
        self,
        radius: int,
        fill_value: float = 1.0,
        circle_axes: tuple[int, int] = (1, 2),
        with_batch_dim=True,
        return_centers=False,
        recovery_threshold=0.1,
        randomize_location_across_batch=True,
        seed=None,
        **kwargs,
    ):
        """Initialize RandomCircleInclusion.

        Args:
            radius (int): Radius of the circle to include.
            fill_value (float): Value to fill inside the circle.
            circle_axes (tuple[int, int]): Axes along which to draw the circle
                (height, width).
            with_batch_dim (bool): Whether input has a batch dimension.
            return_centers (bool): Whether to return circle centers along with
                images.
            recovery_threshold (float): Threshold for considering a pixel as
                recovered; used by ``evaluate_recovered_circle_accuracy`` when
                no threshold is passed there.
            randomize_location_across_batch (bool): If True, randomize circle
                location per batch element.
            seed (Any): Optional random seed for reproducibility; used by
                ``call`` when no seed is passed there.
            **kwargs: Additional keyword arguments for the parent Layer.
        """
        super().__init__(**kwargs)
        self.radius = radius
        self.fill_value = fill_value
        self.circle_axes = circle_axes
        self.with_batch_dim = with_batch_dim
        self.return_centers = return_centers
        self.recovery_threshold = recovery_threshold
        self.randomize_location_across_batch = randomize_location_across_batch
        self.seed = seed

        # Populated by ``build`` once the input shape is known.
        self._axis1 = None
        self._axis2 = None
        self._perm = None
        self._inv_perm = None
        self._static_shape = None
        self._static_batch = None
        self._static_h = None
        self._static_w = None
        self._static_flat_batch = 1

    def build(self, input_shape):
        """Build the layer and compute static shape and permutation info.

        Args:
            input_shape (tuple): Shape of the input tensor.

        Raises:
            ValueError: If a circle axis is the batch axis, is out of range,
                or if both circle axes are the same.
        """
        rank = len(input_shape) - 1 if self.with_batch_dim else len(input_shape)
        raw_a1, raw_a2 = self.circle_axes
        if self.with_batch_dim and (raw_a1 == 0 or raw_a2 == 0):
            raise ValueError("The circle axes should not be a batch dim")

        def _to_sample_axis(axis):
            # Express the axis relative to a single sample (batch dim removed):
            # negative axes count from the end, positive axes shift down by
            # one when a leading batch dim is present.
            if axis < 0:
                return axis + rank
            if axis > 0 and self.with_batch_dim:
                return axis - 1
            return axis

        a1 = _to_sample_axis(raw_a1)
        a2 = _to_sample_axis(raw_a2)
        if not (0 <= a1 < rank and 0 <= a2 < rank):
            raise ValueError(f"circle_axes {self.circle_axes} out of range for rank {rank}")
        if a1 == a2:
            raise ValueError("circle_axes must be two distinct axes")
        self._axis1, self._axis2 = a1, a2

        # Permutation that moves the circle axes to the last two positions,
        # plus its inverse to restore the original layout afterwards.
        other_axes = [ax for ax in range(rank) if ax not in (a1, a2)]
        self._perm = other_axes + [a1, a2]
        inv = [0] * rank
        for position, ax in enumerate(self._perm):
            inv[ax] = position
        self._inv_perm = inv

        # Static shape bookkeeping is done per sample, i.e. without batch dim.
        sample_shape = input_shape[1:] if self.with_batch_dim else input_shape
        permuted_shape = [sample_shape[ax] for ax in self._perm]
        if len(permuted_shape) > 2:
            self._static_flat_batch = int(np.prod(permuted_shape[:-2]))
        self._static_h = int(permuted_shape[-2])
        self._static_w = int(permuted_shape[-1])
        self._static_shape = tuple(permuted_shape)
        super().build(input_shape)

    def compute_output_shape(self, input_shape):
        """Compute output shape for the layer.

        Args:
            input_shape (tuple): Shape of the input tensor.

        Returns:
            tuple: The output shape of the augmented images (same as input).
            The shape of the optional centers output is not described here.
        """
        return input_shape

    def _permute_axes_to_circle_last(self, x):
        """Permute axes so that circle axes are last.

        Args:
            x (Tensor): Input tensor (a single sample, without batch dim).

        Returns:
            Tensor: Tensor with circle axes as the last two dimensions.
        """
        return ops.transpose(x, axes=self._perm)

    def _flatten_batch_and_other_dims(self, x):
        """Flatten all axes except the last two (circle axes).

        Args:
            x (Tensor): Input tensor with circle axes last.

        Returns:
            tuple: (reshaped tensor, flat batch size, height, width).
        """
        shape = x.shape
        flat_batch = int(np.prod(shape[:-2])) if len(shape) > 2 else 1
        h, w = shape[-2], shape[-1]
        return ops.reshape(x, [flat_batch, h, w]), flat_batch, h, w

    def _make_circle_mask(self, centers, h, w, radius, dtype):
        """Create a mask for each center (batch, h, w) using Keras ops.

        Args:
            centers (Tensor): Tensor of shape (batch, 2) with circle centers,
                stored as (x, y), i.e. (column, row).
            h (int): Height of the image.
            w (int): Width of the image.
            radius (int): Radius of the circle.
            dtype (str or dtype): Data type for the mask.

        Returns:
            Tensor: Mask of shape (batch, h, w), 1 inside the circle
            (boundary included) and 0 elsewhere.
        """
        Y = ops.arange(h)
        X = ops.arange(w)
        Y, X = ops.meshgrid(Y, X, indexing="ij")
        Y = ops.expand_dims(Y, 0)  # (1, h, w)
        X = ops.expand_dims(X, 0)  # (1, h, w)
        cx = centers[:, 0][:, None, None]
        cy = centers[:, 1][:, None, None]
        dist2 = (X - cx) ** 2 + (Y - cy) ** 2
        return ops.cast(dist2 <= radius**2, dtype)

    def call(self, x, seed=None):
        """Apply the random circle inclusion augmentation.

        Args:
            x (Tensor): Input tensor.
            seed (Any, optional): Optional random seed for reproducibility.
                Falls back to the seed given at construction when None.

        Returns:
            Tensor or tuple: Augmented images, and optionally the circle
            centers if return_centers is True.

        Raises:
            NotImplementedError: On the JAX backend when the effective seed is
                rejected by ``is_jax_prng_key``, or when a symbolic batch is
                used with ``randomize_location_across_batch=False``.
        """
        # Resolve the effective seed first so that the backend check below
        # validates the seed that is actually going to be used (including the
        # one supplied at construction time).
        seed = seed if seed is not None else self.seed
        if keras.backend.backend() == "jax" and not is_jax_prng_key(seed):
            raise NotImplementedError(
                "jax.random.key() is not supported, please use jax.random.PRNGKey()"
            )

        if self.with_batch_dim:
            # A non-int leading dim means the batch size is not known here.
            x_is_symbolic_tensor = not isinstance(ops.shape(x)[0], int)
            if x_is_symbolic_tensor:
                if self.randomize_location_across_batch:
                    imgs, centers = ops.map(lambda arg: self._call(arg, seed), x)
                else:
                    raise NotImplementedError(
                        "You cannot fix circle locations across the batch while using "
                        "RandomCircleInclusion as a dataset augmentation, "
                        "since samples in a batch are handled independently."
                    )
            else:
                if self.randomize_location_across_batch:
                    batch_size = ops.shape(x)[0]
                    seeds = split_seed(seed, batch_size)
                    # If splitting returned the very same object for every
                    # element, there is nothing to map over: reuse it directly.
                    if all(s is seeds[0] for s in seeds):
                        imgs, centers = ops.map(lambda arg: self._call(arg, seeds[0]), x)
                    else:
                        imgs, centers = ops.map(
                            lambda args: self._call(args[0], args[1]), (x, seeds)
                        )
                else:
                    imgs, centers = ops.map(lambda arg: self._call(arg, seed), x)
        else:
            imgs, centers = self._call(x, seed)

        if self.return_centers:
            return imgs, centers
        return imgs

    def _call(self, x, seed):
        """Internal method to apply the augmentation to a single sample.

        Args:
            x (Tensor): Input sample (no batch dim) in its original axis
                order; the circle axes are moved last internally.
            seed (Any): Random seed for circle location.

        Returns:
            tuple: (augmented sample in the original axis order, center
            coordinates as (x, y)).
        """
        x = self._permute_axes_to_circle_last(x)
        flat, flat_batch_size, h, w = self._flatten_batch_and_other_dims(x)

        def _draw_circle_2d(img2d):
            # Sample the center so that the circle lies within the image.
            cx = ops.cast(
                keras.random.uniform((), self.radius, w - self.radius, seed=seed),
                "int32",
            )
            new_seed, _ = split_seed(seed, 2)  # ensure that cx and cy are independent
            cy = ops.cast(
                keras.random.uniform((), self.radius, h - self.radius, seed=new_seed),
                "int32",
            )
            center = ops.stack([cx, cy])
            mask = self._make_circle_mask(center[None, :], h, w, self.radius, img2d.dtype)[0]
            img_aug = img2d * (1 - mask) + self.fill_value * mask
            return img_aug, center

        aug_imgs, centers = ops.vectorized_map(_draw_circle_2d, flat)
        aug_imgs = ops.reshape(aug_imgs, x.shape)
        aug_imgs = ops.transpose(aug_imgs, axes=self._inv_perm)
        centers_shape = [2] if flat_batch_size == 1 else [flat_batch_size, 2]
        centers = ops.reshape(centers, centers_shape)
        return (aug_imgs, centers)

    def get_config(self):
        """Get layer configuration for serialization.

        Returns:
            dict: Dictionary of layer configuration, containing the
            constructor arguments needed to rebuild an equivalent layer.
        """
        cfg = super().get_config()
        cfg.update(
            {
                "radius": self.radius,
                "fill_value": self.fill_value,
                "circle_axes": self.circle_axes,
                "with_batch_dim": self.with_batch_dim,
                "return_centers": self.return_centers,
                "recovery_threshold": self.recovery_threshold,
                "randomize_location_across_batch": self.randomize_location_across_batch,
            }
        )
        # Only plain seeds are stored: ``call`` also accepts backend PRNG keys,
        # which are arrays and may not survive config serialization.
        if self.seed is None or isinstance(self.seed, int):
            cfg["seed"] = self.seed
        return cfg

    def evaluate_recovered_circle_accuracy(
        self, images, centers, recovery_threshold=None, fill_value=None
    ):
        """Evaluate the percentage of the true circle that has been recovered.

        Args:
            images (Tensor): Tensor of images (any shape, with circle axes as
                specified).
            centers (Tensor): Tensor of circle centers (matching batch size).
            recovery_threshold (float, optional): Threshold for considering a
                pixel as recovered. Defaults to the layer's
                ``recovery_threshold`` when None.
            fill_value (float, optional): Optionally override fill_value for
                cases where image range has changed. Defaults to the layer's
                ``fill_value`` when None; an explicit ``0.0`` is honoured.

        Returns:
            Tensor: Fraction of circle pixels recovered for each circle.
        """
        # Use explicit None checks so that falsy overrides such as 0.0 are kept.
        if fill_value is None:
            fill_value = self.fill_value
        if recovery_threshold is None:
            recovery_threshold = self.recovery_threshold

        def _evaluate_single(image, center):
            image_perm = self._permute_axes_to_circle_last(image)
            h, w = image_perm.shape[-2], image_perm.shape[-1]
            flat_image, _, _, _ = self._flatten_batch_and_other_dims(image_perm)
            flat_center = ops.reshape(center, [-1, 2])
            mask = self._make_circle_mask(flat_center, h, w, self.radius, flat_image.dtype)
            diff = ops.abs(flat_image - fill_value)
            # A pixel counts as recovered when it is inside the circle and
            # within ``recovery_threshold`` of the fill value.
            recovered = ops.cast(diff <= recovery_threshold, flat_image.dtype) * mask
            recovered_sum = ops.sum(recovered, axis=[1, 2])
            mask_sum = ops.sum(mask, axis=[1, 2])
            # Small epsilon guards against division by zero for empty masks.
            return recovered_sum / (mask_sum + 1e-8)

        if self.with_batch_dim:
            return ops.vectorized_map(
                lambda args: _evaluate_single(args[0], args[1]),
                (images, centers),
            )[..., 0]
        return _evaluate_single(images, centers)