Source code for zea.beamform.pfield

"""Pressure field computation for ultrasound imaging.

This module provides routines for automatic computation of the acoustic pressure field
used for compounding multiple transmit (Tx) events in ultrasound imaging.

The pressure field is computed by simulating the acoustic response of the probe and
medium for each transmit event. The computation involves:

- Subdividing each probe element into sub-elements to satisfy the Fraunhofer approximation.
- Calculating the distances and angles between each grid point and each sub-element.
- Computing the frequency response of the probe and the pulse spectrum.
- Summing the contributions from all relevant frequencies, taking into account
  transmit delays, apodization, and directivity.
- Optionally normalizing and thresholding the resulting field for use in
  transmit compounding or adaptive beamforming.

The main entry point is :func:`compute_pfield`, which returns a normalized pressure
field array for all transmit events.

"""

import keras
import numpy as np
from keras import ops

from zea import log
from zea.internal.cache import cache_output
from zea.tensor_ops import sinc


def _abs_sinc(x):
    return sinc(ops.abs(x))


@cache_output(verbose=True)
def compute_pfield(
    sound_speed,
    center_frequency,
    bandwidth_percent,
    n_el,
    probe_geometry,
    tx_apodizations,
    grid,
    t0_delays,
    frequency_step=4,
    db_thresh=-1,
    downsample=10,
    downmix=4,
    alpha=1,
    percentile=10,
    norm=True,
    verbose=True,
):
    """Compute the pressure field for ultrasound imaging.

    Args:
        sound_speed (float): Speed of sound in the medium.
        center_frequency (float): Center frequency of the probe in Hz.
        bandwidth_percent (float): Bandwidth of the probe, pulse-echo 6dB fractional bandwidth (%)
        n_el (int): Number of elements in the probe.
        probe_geometry (array): Geometry of the probe elements.
        tx_apodizations (array): Transmit apodization values.
        grid (array): Grid points where the pressure field is computed
            of shape (grid_size_z, grid_size_x, 3).
        t0_delays (array): Transmit delays for each transmit event.
        frequency_step (int, optional): Frequency step. Default is 4.
            Higher is faster but less accurate.
        db_thresh (int, optional): dB threshold. Default is -1.
            Higher is faster but less accurate.
        downsample (int, optional): Downsample the grid for faster computation.
            Default is 10. Higher is faster but less accurate.
        downmix (int, optional): Downmixing the frequency to facilitate a smaller grid.
            Default is 4. Higher requires lower number of grid points but is less accurate.
        alpha (float, optional): Exponent to 'sharpen or smooth' the weighting. Higher is sharper.
            Only works when norm is True. Default is 1.
        percentile (int, optional): minimum percentile threshold to keep in the weighting.
            Only works when norm is True. Higher is more aggressive. Default is 10.
        norm (bool, optional): per pixel normalization (True) or unnormalized (False)
        verbose (bool, optional): Whether to print progress.

    Returns:
        ops.array: The (normalized) pressure field (across tx events)
            of shape (n_tx, grid_size_z, grid_size_x).
    """
    # medium params
    alpha_db = 0  # currently we ignore attenuation in the compounding

    # probe params
    center_frequency = center_frequency / downmix  # downmixing the frequency

    # pulse params
    num_waveforms = 1  # number of waveforms in the pulse

    # array params
    probe_geometry = ops.convert_to_tensor(probe_geometry, dtype="float32")

    pitch = probe_geometry[1, 0] - probe_geometry[0, 0]  # element pitch

    kerf = 0.1 * pitch  # for now this is hardcoded
    element_width = pitch - kerf

    num_transmits = len(tx_apodizations)

    # %------------------------------------%
    # % POINT LOCATIONS, DISTANCES & GRIDS %
    # %------------------------------------%

    # subdivide elements into sub elements or not? (to satisfy Fraunhofer approximation)
    lambda_min = sound_speed / (center_frequency * (1 + bandwidth_percent / 200))
    num_sub_elements = ops.ceil(element_width / lambda_min)

    x_orig = ops.convert_to_tensor(grid[:, :, 0], dtype="float32")
    z_orig = ops.convert_to_tensor(grid[:, :, 2], dtype="float32")

    size_orig = ops.shape(x_orig)

    # Nearest-neighbor downsampling the grid
    x = x_orig[::downsample, ::downsample]
    z = z_orig[::downsample, ::downsample]
    size_downsampled = ops.shape(x)

    # Coordinates of the points where pressure is needed
    x = ops.reshape(x, (-1,))
    z = ops.reshape(z, (-1,))

    # Centers of the transducer elements (x- and z-coordinates)
    xe = (ops.arange(0.0, n_el) - (n_el - 1) / 2) * pitch
    ze = ops.zeros(n_el)
    the = ops.zeros(n_el)

    # Centroids of the sub-elements
    seg_length = element_width / num_sub_elements
    tmp = (
        -element_width / 2
        + seg_length / 2
        + ops.arange(0, num_sub_elements, dtype=seg_length.dtype) * seg_length
    )
    xi = tmp
    zi = ops.zeros((int(num_sub_elements),))

    # Distances between the points and the transducer elements
    x_expanded = x[:, None, None]
    xi_expanded = xi[None, :, None]
    xe_expanded = xe[None, None, :]

    dxi = x_expanded - xi_expanded - xe_expanded

    z_expanded = z[:, None, None]
    zi_expanded = zi[None, :, None]
    ze_expanded = ze[None, None, :]

    d2 = dxi**2 + (z_expanded - zi_expanded - ze_expanded) ** 2
    r = ops.sqrt(d2)
    r_flat = ops.reshape(r, (-1,))

    # Angle between the normal to the transducer and the line joining
    # the point and the transducer
    eps = keras.config.epsilon()
    theta = ops.arcsin((dxi + eps) / (ops.sqrt(d2) + eps)) - the
    sin_theta = ops.sin(theta)

    pulse_width = num_waveforms / center_frequency  # temporal pulse width
    wc = 2 * np.pi * center_frequency

    def pulse_spectrum(w):
        imag = _abs_sinc(pulse_width * (w - wc) / 2) - _abs_sinc(pulse_width * (w + wc) / 2)
        return 1j * ops.cast(imag, "complex64")

    # FREQUENCY RESPONSE of the ensemble PZT + probe
    w_bandwidth = bandwidth_percent * wc / 100  # angular frequency bandwidth
    p_shape = ops.log(126) / ops.log(eps + 2 * wc / w_bandwidth)

    def probe_spectrum(w):
        # Calculate the normalized frequency difference
        freq_diff = ops.abs(w - wc)
        # Calculate the denominator for normalization
        denom = (w_bandwidth / 2) / (ops.log(2) ** (1 / p_shape))
        # Raise the normalized difference to the power of p_shape
        exponent = (freq_diff / denom) ** p_shape
        # Apply the negative sign and exponential
        return ops.exp(-exponent)

    p_list = []

    if verbose:
        log.info("Computing pressure field for all transmits")
        progbar = keras.utils.Progbar(num_transmits, unit_name="transmits")
    for j in range(0, num_transmits):
        # delays and apodization of transmit event
        delays_tx = ops.convert_to_tensor(t0_delays[j], dtype="float32")
        idx_nan = ops.isnan(delays_tx)
        delays_tx = ops.where(idx_nan, 0, delays_tx)

        tx_apodization = ops.convert_to_tensor(tx_apodizations[j])
        idx_nan = ops.isnan(tx_apodization)
        tx_apodization = ops.where(idx_nan, 0, tx_apodization)
        tx_apodization = ops.squeeze(tx_apodization)

        # The frequency response is a pulse-echo (transmit + receive) response.
        # The spectrum of the pulse (pulse_spectrum) will be then multiplied
        # by the frequency-domain tapering window of the transducer (probe_spectrum)
        # The frequency step df is chosen to avoid interferences due to
        # inadequate discretization.
        # df = frequency step (must be sufficiently small):
        # One has exp[-i(k r + w delay)] = exp[-2i pi(f r/c + f delay)] in the Eq.
        # One wants: the phase increment 2pi(df r/c + df delay) be < 2pi.
        # Therefore: df < 1/(r/c + delay).

        delays_tx_flat = ops.reshape(delays_tx, (-1,))

        df = 1 / (ops.max(r_flat / sound_speed) + ops.max(delays_tx_flat))
        df = frequency_step * df

        # FREQUENCY SAMPLES
        num_freq = 2 * ops.cast(ops.ceil(center_frequency / df), "int32") + 1
        freq = ops.linspace(0, 2 * center_frequency, num_freq)
        df = freq[1]

        # keep the significant components only by using db_thresh
        spectrum = ops.abs(
            pulse_spectrum(2 * np.pi * freq)
            * ops.cast(probe_spectrum(2 * np.pi * freq), "complex64")
        )
        gain_db = 20 * ops.log10(eps + spectrum / (ops.max(spectrum)))
        idx = gain_db > db_thresh

        freq = freq[idx]

        pulse_spect = pulse_spectrum(2 * np.pi * freq)
        probe_spect = probe_spectrum(2 * np.pi * freq)

        # Exponential arrays of size [numel(x) n_el num_sub_elements]
        kw = 2 * np.pi * freq[0] / sound_speed
        kwa = alpha_db / 8.69 * freq[0] / 1e6 * 1e2

        r_complex = ops.cast(r, dtype="complex64")
        kwa = ops.cast(kwa, dtype="complex64")
        mod_out = ops.cast(ops.mod(kw * r, 2 * np.pi), dtype="complex64")
        exp_arr = ops.exp(-kwa * r_complex + 1j * mod_out)

        # Exponential array for the increment wavenumber dk
        dkw = 2 * np.pi * df / sound_speed
        dkwa = alpha_db / 8.69 * df / 1e6 * 1e2
        dkw = ops.cast(dkw, dtype="complex64")
        dkwa = ops.cast(dkwa, dtype="complex64")

        exp_df = ops.exp((-dkwa + 1j * dkw) * r_complex)

        exp_arr = exp_arr / ops.sqrt(r_complex)
        exp_arr = exp_arr * ops.cast(ops.min(ops.sqrt(r)), "complex64")  # normalize the field

        center_wavenumber = 2 * np.pi * center_frequency / sound_speed
        directivity = _abs_sinc(center_wavenumber * seg_length / 2 * sin_theta)
        exp_arr = exp_arr * ops.cast(directivity, "complex64")

        # Render pressure field for all relevant frequencies and sum them up
        rp = _pfield_freq_loop(
            freq,
            sound_speed,
            delays_tx,
            tx_apodization,
            exp_arr,
            exp_df,
            pulse_spect,
            probe_spect,
            z,
        )

        # RMS acoustic pressure
        p = ops.reshape(ops.sqrt(rp), size_downsampled)

        # resize p to exactly the original grid size
        p = ops.squeeze(ops.image.resize(p[..., None], size_orig, interpolation="nearest"), axis=-1)

        p_list.append(p)

        if verbose:
            progbar.add(1)

    p_arr = ops.convert_to_tensor(p_list)
    p_arr = ops.where(
        ops.isnan(p_arr), 0, p_arr
    )  # TODO: this is necessary for Jax / TF somehow. not sure why (not for torch)

    if norm:
        p_norm = normalize_pressure_field(p_arr, alpha=alpha, percentile=percentile)
    else:
        p_norm = p_arr

    return p_norm


[docs] def normalize_pressure_field(pfield, alpha: float = 1.0, percentile: float = 10.0): """ Normalize the input array of intensities by zeroing out values below a given percentile. Args: pfield (array): The unnormalized pressure field array of shape (n_tx, grid_size_z, grid_size_x). alpha (float, optional): Exponent to 'sharpen or smooth' the weighting. Higher values result in sharper weighting. Default is 1.0. percentile (int, optional): minimum percentile threshold to keep in the weighting. Higher is more aggressive. Default is 10. Returns: ops.array: Normalized intensity array. """ # Convert percentile to quantile (0–1 range) q = percentile / 100.0 # Compute per-transmitter quantile thresholds threshold = ops.quantile(pfield, q, axis=(1, 2), keepdims=True) # Zero out values below the threshold pfield = ops.where(pfield < threshold, 0, pfield) # Sharpen the beam pfield = ops.power(pfield, alpha) # Normalize over transmit events (axis=0) p_norm = pfield / (keras.config.epsilon() + ops.sum(pfield, axis=0, keepdims=True)) return p_norm
def _pfield_freq_step( k, freq, sound_speed, delays_tx, tx_apodization, rp_mono, pulse_spect, probe_spect, z, ): """ Calculates the pressure field for a single frequency step. Args: k (int): Frequency index. freq (list): List of frequencies. sound_speed (float): Speed of sound. delays_tx (list): List of transmit delays. tx_apodization (list): List of transmit apodization values (complex64). rp_mono: (Tensor): Per-element, per-field-point complex pressure response (including directivity and propagation effects) at the current frequency sample. pulse_spect (list): List of pulse spectra. probe_spect (list): List of probe spectra (complex64). z (list): List of z-coordinates. Returns: rp_k (Tensor): Pressure field for this frequency. """ kw = 2 * np.pi * freq[k] / sound_speed del_apod = ops.exp(1j * ops.cast(kw * sound_speed * delays_tx, "complex64")) * tx_apodization rp_k = ops.matmul(rp_mono, del_apod) * pulse_spect[k] * probe_spect[k] rp_k = ops.where(z < 0, 0, rp_k) return ops.abs(rp_k) ** 2 def _pfield_freq_loop( freq, sound_speed, delays_tx, tx_apodization, exp_arr, exp_df, pulse_spect, probe_spect, z, ): """Calculates the pressure field using frequency loop method. Args: freq (list): List of frequencies. sound_speed (float): Speed of sound. delays_tx (list): List of transmit delays. tx_apodization (list): List of transmit apodization values. exp_arr (list): List of complex exponentials. exp_df (list): List of complex exponential frequency shifts. pulse_spect (list): List of pulse spectra. probe_spect (list): List of probe spectra. z (list): List of z-coordinates. Returns: (Tensor): Pressure field. """ tx_apodization = ops.cast(tx_apodization, "complex64") probe_spect = ops.cast(probe_spect, "complex64") rp_mono = exp_arr rp = 0 for k in range(len(freq)): if k > 0: rp_mono *= exp_df rp_k = _pfield_freq_step( k, freq, sound_speed, delays_tx, tx_apodization, ops.mean(rp_mono, axis=1), # avg over sub-elements pulse_spect, probe_spect, z, ) rp += rp_k return rp