Source code for zea.beamform.beamformer

"""Main beamforming functions for ultrasound imaging."""

import keras
import numpy as np
from keras import ops

from zea.beamform.lens_correction import calculate_lens_corrected_delays
from zea.tensor_ops import safe_vectorize


def fnum_window_fn_rect(normalized_angle):
    """Rectangular (boxcar) window function for f-number masking.

    Args:
        normalized_angle (ops.Tensor): Angle between pixel and element,
            divided by the maximum angle allowed by the f-number.

    Returns:
        ops.Tensor: 1.0 wherever ``normalized_angle <= 1.0``, 0.0 elsewhere.
    """
    inside_cone = normalized_angle <= 1.0
    return ops.where(inside_cone, 1.0, 0.0)
def fnum_window_fn_hann(normalized_angle):
    """Hann window function for f-number masking.

    Inside the f-number cone the weight falls off smoothly as a half cosine,
    from 1.0 at ``normalized_angle == 0`` to 0.0 at ``normalized_angle == 1``.

    Args:
        normalized_angle (ops.Tensor): Angle between pixel and element,
            divided by the maximum angle allowed by the f-number.

    Returns:
        ops.Tensor: ``0.5 * (1 + cos(pi * normalized_angle))`` where
        ``normalized_angle <= 1.0``, and 0.0 elsewhere.
    """
    inside_cone = normalized_angle <= 1.0
    taper = 0.5 * (1 + ops.cos(np.pi * normalized_angle))
    return ops.where(inside_cone, taper, 0.0)
def fnum_window_fn_tukey(normalized_angle, alpha=0.5):
    """Tukey (tapered cosine) window function for f-number masking.

    The window is flat (1.0) up to ``1 - alpha``. Between there and 1 it
    falls off with a half cosine, and at 1 it is 0.0.

    Args:
        normalized_angle (ops.Tensor): Angle between pixel and element,
            divided by the maximum angle allowed by the f-number. The
            absolute value is taken and then clipped to [0, 1], so inputs
            at or beyond 1 map to 0.0.
        alpha (float, optional): Fraction of the window that is tapered.
            0.0 gives a rectangular window and 1.0 approximately a Hann
            window. Defaults to 0.5.

    Returns:
        ops.Tensor: Window weights with the same shape as ``normalized_angle``.
    """
    x = ops.clip(ops.abs(normalized_angle), 0.0, 1.0)
    flat_end = 1.0 - alpha
    # The small epsilon keeps the division finite when alpha == 0 (rectangular limit).
    taper = 0.5 * (1 + ops.cos(np.pi * (x - flat_end) / (ops.abs(alpha) + 1e-6)))
    tapered_or_zero = ops.where(x < 1.0, taper, 0.0)
    return ops.where(x < flat_end, 1.0, tapered_or_zero)
def tof_correction(
    data,
    flatgrid,
    t0_delays,
    tx_apodizations,
    sound_speed,
    probe_geometry,
    initial_times,
    sampling_frequency,
    demodulation_frequency,
    fnum,
    angles,
    focus_distances,
    apply_phase_rotation=False,
    apply_lens_correction=False,
    lens_thickness=1e-3,
    lens_sound_speed=1000,
    fnum_window_fn=fnum_window_fn_rect,
):
    """Time-of-flight correction for a flat grid.

    Args:
        data (ops.Tensor): Input RF/IQ data of shape `(n_tx, n_ax, n_el, n_ch)`.
        flatgrid (ops.Tensor): Pixel locations x, y, z of shape `(n_pix, 3)`.
        t0_delays (ops.Tensor): Times at which the elements fire, shifted such
            that the first element fires at t=0, of shape `(n_tx, n_el)`.
        tx_apodizations (ops.Tensor): Transmit apodizations of shape `(n_tx, n_el)`.
        sound_speed (float): Speed of sound in m/s.
        probe_geometry (ops.Tensor): Element positions x, y, z of shape `(n_el, 3)`.
        initial_times (ops.Tensor): Time offset per transmission of shape `(n_tx,)`.
        sampling_frequency (float): Sampling frequency in Hz.
        demodulation_frequency (float): Demodulation frequency in Hz. Only used
            when ``apply_phase_rotation`` is True.
        fnum (float): Receive f-number. Set to 0 to disable f-number masking.
        angles (ops.Tensor): The angles of the plane waves in radians of shape `(n_tx,)`.
        focus_distances (ops.Tensor): The focus distances of shape `(n_tx,)`.
        apply_phase_rotation (bool, optional): Whether to apply a phase rotation
            to the time-of-flight corrected samples. This expects I/Q data with
            `n_ch == 2`. Defaults to False.
        apply_lens_correction (bool, optional): Whether to apply lens correction
            to the time-of-flights. This makes it slower, but more accurate in
            the near field. Defaults to False.
        lens_thickness (float, optional): Thickness of the lens in meters. Used
            for lens correction. Defaults to 1e-3.
        lens_sound_speed (float, optional): Speed of sound in the lens in m/s.
            Used for lens correction. Defaults to 1000.
        fnum_window_fn (callable, optional): F-number window function. It defines
            the transition from straight in front of the element (fn(0.0)) to
            the largest angle within the f-number cone (fn(1.0)), and should be
            zero for fn(x > 1.0). Defaults to ``fnum_window_fn_rect``.

    Returns:
        ops.Tensor: Time-of-flight corrected data of shape `(n_tx, n_pix, n_el, n_ch)`.
    """
    assert len(data.shape) == 4, (
        "The input data should have 4 dimensions, "
        "namely num_transmits, num_samples, num_elements, "
        f"num_rf_iq_channels, got {len(data.shape)} dimensions: "
        f"{data.shape}."
    )

    n_tx, n_ax, n_el, _ = ops.shape(data)

    # Calculate delays (in samples)
    # --------------------------------------------------------------------
    # txdel: The delay from t=0 to the wavefront reaching the pixel.
    #        Returned with shape (n_pix, n_tx).
    #
    # rxdel: The delay from the wavefront reaching the pixel to the scattered
    #        wave reaching the transducer element.
    #        Returned with shape (n_pix, n_el).
    # --------------------------------------------------------------------
    delay_fn = calculate_lens_corrected_delays if apply_lens_correction else calculate_delays
    txdel, rxdel = delay_fn(
        flatgrid,
        t0_delays,
        tx_apodizations,
        probe_geometry,
        initial_times,
        sampling_frequency,
        sound_speed,
        n_tx,
        n_el,
        focus_distances,
        angles,
        lens_thickness=lens_thickness,
        lens_sound_speed=lens_sound_speed,
    )

    n_pix = ops.shape(flatgrid)[0]
    # fnum == 0 disables masking: every element contributes to every pixel.
    mask = ops.cond(
        fnum == 0,
        lambda: ops.ones((n_pix, n_el, 1)),
        lambda: fnumber_mask(flatgrid, probe_geometry, fnum, fnum_window_fn=fnum_window_fn),
    )

    def _apply_delays(data_tx, txdel):
        """Delay-correct the data of a single transmit.

        Args:
            data_tx: Data of one transmit, shape `(n_ax, n_el, n_ch)`.
            txdel: Transmit delays of this transmit, shape `(n_pix, 1)`.

        Returns:
            Delay-corrected samples of shape `(n_pix, n_el, n_ch)`.
        """
        # The transmit delay is the same for every element, so broadcasting
        # (n_pix, 1) against rxdel (n_pix, n_el) gives delays of shape (n_pix, n_el).
        delays = rxdel + txdel

        # Time-of-flight corrected samples for each element from each pixel,
        # shape (n_pix, n_el, n_ch). Indices are clipped to the valid sample range.
        tof_tx = apply_delays(data_tx, delays, clip_min=0, clip_max=n_ax - 1)

        # Apply the f-number mask
        tof_tx = tof_tx * mask

        # Phase correction
        if apply_phase_rotation:
            tshift = delays[:, :] / sampling_frequency
            tdemod = flatgrid[:, None, 2] * 2 / sound_speed
            theta = 2 * np.pi * demodulation_frequency * (tshift - tdemod)
            tof_tx = _complex_rotate(tof_tx, theta)
        return tof_tx

    # Reshape (n_pix, n_tx) -> (n_tx, n_pix, 1) so it can be mapped over transmits
    txdel = ops.moveaxis(txdel, 1, 0)
    txdel = txdel[..., None]

    return safe_vectorize(
        _apply_delays,
        signature="(n_samples,n_el,n_ch),(n_pix,1)->(n_pix,n_el,n_ch)",
    )(data, txdel)
def calculate_delays(
    grid,
    t0_delays,
    tx_apodizations,
    probe_geometry,
    initial_times,
    sampling_frequency,
    sound_speed,
    n_tx,
    n_el,
    focus_distances,
    polar_angles,
    **kwargs,
):
    """Compute transmit and receive delays, in samples, to every pixel.

    The total delay of a pixel/element pair has two parts:

    - the transmit delay, i.e. the time from transmission until the
      wavefront reaches the pixel, and
    - the receive delay, i.e. the time for the reflection to travel from the
      pixel back to a specific element.

    Args:
        grid (Tensor): Pixel coordinates to beamform to, of shape `(n_pix, 3)`.
        t0_delays (Tensor): Per-element transmit delays in seconds of shape
            `(n_tx, n_el)`, shifted such that the smallest delay is 0.
        tx_apodizations (Tensor): Transmit apodizations of shape `(n_tx, n_el)`.
        probe_geometry (Tensor): Element positions of shape `(n_el, 3)`.
        initial_times (Tensor): Per-transmit time offsets of shape `(n_tx,)`.
            They are subtracted from the transmit travel time.
        sampling_frequency (float): Sampling frequency of the probe in Hz.
        sound_speed (float): Assumed speed of sound in m/s.
        n_tx (int): Number of transmits. Not used in this function.
        n_el (int): Number of elements. Not used in this function.
        focus_distances (Tensor): Focus distances of shape `(n_tx,)`. An
            infinite focus distance corresponds to plane-wave transmission.
        polar_angles (Tensor): Polar angles of the transmits in radians of
            shape `(n_tx,)`.
        **kwargs: Ignored. Lets callers pass options (such as lens parameters)
            that only alternative delay functions use.

    Returns:
        tuple: ``(transmit_delays, receive_delays)``:
            transmit_delays (Tensor): Delays to every pixel, shape `(n_pix, n_tx)`.
            receive_delays (Tensor): Delays from every pixel back to every
                element, shape `(n_pix, n_el)`.
    """

    def _transmit_distance(angle, element_delays, apodization, focus):
        # Effective wavefront distance for one transmit, shape (n_pix,)
        return distance_Tx_generic(
            grid,
            element_delays,
            apodization,
            probe_geometry,
            focus,
            angle,
            sound_speed,
        )

    per_transmit = safe_vectorize(
        _transmit_distance,
        signature="(),(n_el),(n_el),()->(n_pix)",
    )(polar_angles, t0_delays, tx_apodizations, focus_distances)
    # (n_tx, n_pix) -> (n_pix, n_tx)
    tx_distances = ops.transpose(per_transmit, (1, 0))

    def _receive_distance(element_position):
        # Distance from every pixel to one element, shape (n_pix,)
        return distance_Rx(grid, element_position)

    per_element = safe_vectorize(_receive_distance, signature="(3)->(n_pix)")(probe_geometry)
    # (n_el, n_pix) -> (n_pix, n_el)
    rx_distances = ops.transpose(per_element, (1, 0))

    # Convert distances to delays in samples: ([m]/[m/s] - [s]) * [1/s] is unitless.
    # TODO: Add pulse width to transmit delays
    tx_delays = (tx_distances / sound_speed - initial_times[None]) * sampling_frequency
    rx_delays = (rx_distances / sound_speed) * sampling_frequency
    return tx_delays, rx_delays
def apply_delays(data, delays, clip_min: int = -1, clip_max: int = -1):
    """Applies time delays for a single transmit using linear interpolation.

    Most delays are not an integer number of samples, so there is no
    measurement at exactly that time instant. For every delay this function
    takes the sample just before and just after it, and linearly
    interpolates between them.

    Args:
        data (ops.Tensor): The RF or IQ data of shape `(n_ax, n_el, n_ch)`.
            Samples are drawn from this tensor for each element and pixel.
        delays (ops.Tensor): The delays in samples of shape `(n_pix, n_el)`.
            Contains one delay value for every pixel in the image for every
            transducer element.
        clip_min (int, optional): Lower bound applied to the sample indices.
            Set to -1 to disable this bound. Defaults to -1.
        clip_max (int, optional): Upper bound applied to the sample indices.
            Set to -1 to disable this bound. Defaults to -1.

    Returns:
        ops.Tensor: The samples received by each transducer element,
        corresponding to the reflection from each pixel in the image. Shape is
        `(n_pix, n_el, n_ch)`. When a bound is active, delays below
        ``clip_min`` or at/above ``clip_max`` evaluate to zero. This happens
        because both interpolation indices collapse onto the bound and the
        weights cancel.
    """
    # Add a trailing channel axis so the delays have the same rank as the
    # data: (n_pix, n_el) -> (n_pix, n_el, 1)
    delays = delays[..., None]

    # Integer sample indices just below (d0) and just above (d1) each delay
    d0 = ops.cast(ops.floor(delays), "int32")
    d1 = d0 + 1

    # Clip the indices so the gather stays in bounds. This is needed for
    # correct behavior on CPU. Each bound is applied independently; -1
    # disables that bound.
    if clip_min != -1:
        lower = ops.cast(clip_min, d0.dtype)
        d0 = ops.maximum(d0, lower)
        d1 = ops.maximum(d1, lower)
    if clip_max != -1:
        upper = ops.cast(clip_max, d0.dtype)
        d0 = ops.minimum(d0, upper)
        d1 = ops.minimum(d1, upper)

    # Tile the indices across the channel axis (e.g. I and Q) so that their
    # shape matches the data on every backend.
    n_ch = data.shape[-1]
    if n_ch is not None and n_ch > 1:
        d0 = ops.repeat(d0, n_ch, axis=-1)
        d1 = ops.repeat(d1, n_ch, axis=-1)

    # Gather, for each element, the samples around the reflection from each
    # pixel. Both results have shape (n_pix, n_el, n_ch).
    data0 = ops.take_along_axis(data, d0, 0)
    data1 = ops.take_along_axis(data, d1, 0)

    # Linear interpolation in the dtype of the delays
    d0 = ops.cast(d0, delays.dtype)
    d1 = ops.cast(d1, delays.dtype)
    data0 = ops.cast(data0, delays.dtype)
    data1 = ops.cast(data1, delays.dtype)
    reflection_samples = (d1 - delays) * data0 + (delays - d0) * data1
    return reflection_samples
def _complex_rotate(iq, theta):
    """Rotate I/Q samples by a phase angle.

    The last axis of ``iq`` holds the (I, Q) pair. The result is
    ``(I + jQ) * exp(j * theta)``, written back in the same (I, Q) layout.

    Args:
        iq (ops.Tensor): The I/Q data of shape `(..., 2)`.
        theta (float or ops.Tensor): Rotation angle in radians. It must
            broadcast against `iq[..., 0]`.

    Returns:
        Tensor: The rotated tensor of shape `(..., 2)`.
    """
    cos_theta = ops.cos(theta)
    sin_theta = ops.sin(theta)

    in_phase = iq[..., 0]
    quadrature = iq[..., 1]

    rotated_i = in_phase * cos_theta - quadrature * sin_theta
    rotated_q = quadrature * cos_theta + in_phase * sin_theta

    return ops.stack([rotated_i, rotated_q], axis=-1)
def distance_Rx(grid, probe_geometry):
    """Computes the distance from user-defined pixels to the transducer elements.

    Accepts either a single element position, as used when vectorizing over
    elements, or the positions of all elements at once.

    Args:
        grid (ops.Tensor): Pixel positions in x, y, z of shape `(n_pix, 3)`.
        probe_geometry (ops.Tensor): Element position(s) in x, y, z, of shape
            `(3,)` for a single element or `(n_el, 3)` for all elements.

    Returns:
        dist (ops.Tensor): Euclidean distance from each pixel to each element,
        in the units of the inputs. Shape is `(n_pix,)` for a single element
        and `(n_pix, n_el)` for multiple elements.
    """
    if len(ops.shape(probe_geometry)) == 1:
        # Single element: (n_pix, 3) - (1, 3) -> (n_pix, 3)
        diff = grid - probe_geometry[None, ...]
    else:
        # All elements: (n_pix, 1, 3) - (1, n_el, 3) -> (n_pix, n_el, 3).
        # Without the explicit pixel axis, (n_pix, 3) would be broadcast
        # against (1, n_el, 3) and pair pixels with elements index-by-index.
        diff = grid[:, None, :] - probe_geometry[None, :, :]
    dist = ops.linalg.norm(diff, axis=-1)
    return dist
def distance_Tx_generic(
    grid,
    t0_delays,
    tx_apodization,
    probe_geometry,
    focus_distance,
    polar_angle,
    sound_speed=1540,
):
    """Generic transmit distance calculation for a single transmit.

    Computes, for every pixel, the effective distance travelled by the
    transmitted wavefront. The calculation uses the per-element firing
    delays `t0_delays`.

    For each pixel and each firing element, the path length is the
    element's firing delay converted to a distance, plus the straight-line
    distance from element to pixel. These path lengths are reduced over the
    elements with a minimum or a maximum, chosen as follows:

    - minimum: when ``sign(focus_distance) * (z - cos(polar_angle) * focus_distance) <= 0``;
    - maximum: otherwise. For a positive focus distance, this is the case
      for pixels deeper than the focal depth.

    Elements with zero apodization are excluded from the reduction.

    Args:
        grid (ops.Tensor): Flattened tensor of pixel positions in x, y, z of
            shape `(n_pix, 3)`.
        t0_delays (ops.Tensor): The transmit delays in seconds of shape
            `(n_el,)`, shifted such that the smallest delay is 0.
        tx_apodization (ops.Tensor): The transmit apodizations of shape `(n_el,)`.
        probe_geometry (ops.Tensor): The positions of the transducer elements
            of shape `(n_el, 3)`.
        focus_distance (float or ops.Tensor): Focus distance of this transmit.
        polar_angle (float or ops.Tensor): Polar angle of this transmit in radians.
        sound_speed (float): The speed of sound in m/s. Defaults to 1540.

    Returns:
        Tensor: Effective transmit distance for each pixel, of shape `(n_pix,)`.
    """
    # Element-to-pixel offsets: (n_pix, 1, 3) - (1, n_el, 3) -> (n_pix, n_el, 3)
    delta = grid[:, None, :] - probe_geometry[None, :, :]
    element_distance = ops.sqrt(delta[..., 0] ** 2 + delta[..., 1] ** 2 + delta[..., 2] ** 2)

    # Path length per (pixel, element): firing delay expressed as a distance
    # plus the geometric distance. Shape (n_pix, n_el).
    path_length = t0_delays[None] * sound_speed + element_distance

    # Non-firing elements get an infinite offset so that they never win
    # the min/max reduction below.
    inactive_offset = ops.where(tx_apodization == 0, np.inf, 0.0)

    # Depth of the focal point / virtual source
    focal_z = ops.cos(polar_angle) * focus_distance

    use_min_distance = (
        ops.cast(ops.sign(focus_distance), "float32") * (grid[:, 2] - focal_z) <= 0.0
    )
    return ops.where(
        use_min_distance,
        ops.min(path_length + inactive_offset[None], 1),
        ops.max(path_length - inactive_offset[None], 1),
    )
def fnumber_mask(flatgrid, probe_geometry, f_number, fnum_window_fn):
    """Apodization mask for the receive beamformer.

    Computes a mask that disregards pixels outside the vision cone of a
    transducer element. Transducer elements can only accurately measure
    signals within some range of incidence angles. Waves coming in from the
    side do not register correctly, which leads to a worse image.

    Args:
        flatgrid (ops.Tensor): The flattened image grid `(n_pix, 3)`.
        probe_geometry (ops.Tensor): The transducer element positions of
            shape `(n_el, 3)`.
        f_number (float): The receive f-number. Set to zero to disable
            masking; the mask is then all ones. (The f-number is the ratio
            between the distance from the transducer and the size of the
            aperture whose elements contribute to the signal of a pixel.)
        fnum_window_fn (callable): F-number window function. It defines the
            transition from straight in front of the element (fn(0.0)) to the
            largest angle within the f-number cone (fn(1.0)), and should be
            zero for fn(x > 1.0).

    Returns:
        Tensor: Mask of shape `(n_pix, n_el, 1)`.
    """
    # Vector from every element to every pixel: (n_pix, n_el, 3)
    grid_relative_to_probe = flatgrid[:, None] - probe_geometry[None]
    grid_relative_to_probe_norm = ops.linalg.norm(grid_relative_to_probe, axis=-1)
    # Cosine of the angle between the element->pixel vector and the +z axis.
    # The epsilon avoids 0/0 for pixels located exactly on an element.
    grid_relative_to_probe_z = grid_relative_to_probe[..., 2] / (grid_relative_to_probe_norm + 1e-6)
    alpha = ops.arccos(grid_relative_to_probe_z)

    # The f-number is fnum = z/aperture = 1/(2 * tan(alpha)).
    # Rearranging gives alpha = arctan(1/(2 * fnum)), the maximum allowed angle.
    max_alpha = ops.arctan(1 / (2 * f_number + keras.backend.epsilon()))
    normalized_angle = alpha / max_alpha
    mask = fnum_window_fn(normalized_angle)

    # f_number == 0 means "no masking": every element contributes fully.
    # Without this, the window would still be applied with max_alpha ~ pi/2.
    mask = ops.where(f_number == 0, ops.ones_like(mask), mask)

    # Add dummy channel dimension
    mask = mask[..., None]
    return mask