"""Frequency domain ultrasound simulator.

The simulator works in the frequency domain (RFFT domain) and simulates RF data as a superposition
of scatterer responses. Every scatterer has a location and a magnitude.

To use it in your code, call the :func:`simulate_rf` function with the desired transmit scheme
parameters and scatterers. To simulate a sequence of multiple frames, call :func:`simulate_rf`
repeatedly with different scatterer positions and magnitudes and then stack the results.

Example usage
^^^^^^^^^^^^^

A simple example of simulating RF data with a single scatterer located 20 mm in front of the
center of the probe. For a more in-depth example, see the notebook:
:doc:`../notebooks/data/zea_simulation_example`.
.. code-block:: python

    raw_data = simulate_rf(
        scatterer_positions=np.array([[0, 0, 20e-3]]),
        scatterer_magnitudes=np.array([1.0]),
        probe_geometry=np.stack(
            [np.linspace(-20e-3, 20e-3, 64), np.zeros(64), np.zeros(64)], axis=-1
        ),
        apply_lens_correction=True,
        lens_thickness=1e-3,
        lens_sound_speed=1000,
        sound_speed=1540,
        n_ax=1024,
        center_frequency=5e6,
        sampling_frequency=20e6,
        t0_delays=np.zeros((1, 64)),
        initial_times=np.zeros(1),
        element_width=0.2e-3,
        attenuation_coef=0.5,
        tx_apodizations=np.ones((1, 64)),
    )
"""
import numpy as np
from keras import ops
from zea.beamform.lens_correction import compute_lens_corrected_travel_times
[docs]
def simulate_rf(
    scatterer_positions,
    scatterer_magnitudes,
    probe_geometry,
    apply_lens_correction,
    lens_thickness,
    lens_sound_speed,
    sound_speed,
    n_ax,
    center_frequency,
    sampling_frequency,
    t0_delays,
    initial_times,
    element_width,
    attenuation_coef,
    tx_apodizations,
):
    """
    Simulates RF data for a given set of scatterers.

    For every transmit event the RF spectrum is built as a sum over scatterers and
    transmitting elements of: the pulse spectrum (a Hann-windowed pulse of 4 periods at
    ``center_frequency``), a propagation delay, the transmit and receive element
    directivities, frequency-dependent attenuation and geometric spreading. The spectrum is
    evaluated on an FFT grid whose length is ``n_ax`` rounded up to the next power of two.
    After the inverse FFT, the time signal is truncated to ``n_ax`` samples.

    Args:
        scatterer_positions (array-like): The positions of the scatterers [m] of shape (n_scat, 3).
        scatterer_magnitudes (array-like): The magnitudes of the scatterers of shape (n_scat,).
        probe_geometry (array-like): The geometry of the probe [m] of shape (n_el, 3).
        apply_lens_correction (bool): Whether to apply lens correction.
        lens_thickness (float): The thickness of the lens [m].
        lens_sound_speed (float): The speed of sound in the lens [m/s].
        sound_speed (float): The speed of sound in the medium [m/s].
        n_ax (int): The number of samples in the RF data.
        center_frequency (float): The center frequency of the pulse [Hz].
        sampling_frequency (float): The sampling frequency of the RF data [Hz].
        t0_delays (array-like): The delays of the transmitting elements [s] of shape (n_tx, n_el).
        initial_times (array-like): The initial times of the transmitting elements [s] of
            shape (n_tx,).
        element_width (float): The width of the elements [m].
        attenuation_coef (float): The attenuation coefficient [dB/cm/MHz].
        tx_apodizations (array-like): The apodizations of the transmitting elements of
            shape (n_tx, n_el).

    Returns:
        rf_data (array-like): The simulated RF data of shape (1, n_tx, n_ax, n_el, 1).
    """
    n_tx = t0_delays.shape[0]
    pulse_spectrum_fn = get_pulse_spectrum_fn(center_frequency, n_period=4)

    # One-way path length between every scatterer and every element, [n_scat, n_el].
    if not apply_lens_correction:
        dist = ops.linalg.norm(probe_geometry[None] - scatterer_positions[:, None], axis=-1)
    else:
        # The lens-corrected travel time is converted to an equivalent distance in the
        # medium. NOTE(review): assumes the helper returns the same [n_scat, n_el] layout as
        # the straight-ray branch above — confirm.
        dist = (
            compute_lens_corrected_travel_times(
                probe_geometry,
                scatterer_positions,
                lens_thickness=lens_thickness,
                c_lens=lens_sound_speed,
                c_medium=sound_speed,
                n_iter=3,
            )
            * sound_speed
        )

    n_ax_rounded = _round_up_to_power_of_two(int(n_ax)).astype("float32")
    freqs = ops.arange(n_ax_rounded // 2 + 1, dtype="float32") / n_ax_rounded * sampling_frequency
    waveform_spectrum = pulse_spectrum_fn(freqs)
    # Frequencies broadcast against [n_scat, n_txel, n_rxel, n_freq] tensors.
    freqs_4d = freqs[None, None, None]

    # ------------------------------------------------------------------------------------
    # Everything from here up to the transmit loop does not depend on the transmit event,
    # so it is computed once instead of once per transmit.
    # ------------------------------------------------------------------------------------
    # Round-trip distance (tx element -> scatterer -> rx element), [n_scat, n_txel, n_rxel].
    dist_total = dist[:, None] + dist[:, :, None]
    # Round-trip travel time without transmit delays, [n_scat, n_txel, n_rxel].
    travel_time = dist_total / sound_speed

    # 3D directivity: angles of each scatterer relative to each element, [n_scat, n_el].
    scat_pos_relative_to_probe = scatterer_positions[:, None] - probe_geometry[None]
    theta = ops.arctan2(scat_pos_relative_to_probe[:, :, 0], scat_pos_relative_to_probe[:, :, 2])
    phi = ops.arctan2(scat_pos_relative_to_probe[:, :, 1], scat_pos_relative_to_probe[:, :, 2])
    # Transmit directivity varies along the tx-element axis: [n_scat, n_txel, 1, n_freq].
    directivity_tx = directivity(
        freqs_4d,
        theta[..., None, None],
        element_width,
        sound_speed,
    ) * directivity(
        freqs_4d,
        phi[..., None, None],
        element_width,
        sound_speed,
    )
    # Receive directivity varies along the rx-element axis: [n_scat, 1, n_rxel, n_freq].
    directivity_rx = directivity(
        freqs_4d,
        theta[:, None, :, None],
        element_width,
        sound_speed,
    ) * directivity(
        freqs_4d,
        phi[:, None, :, None],
        element_width,
        sound_speed,
    )
    # Losses along the round-trip path, [n_scat, n_txel, n_rxel, n_freq] / [..., 1].
    attenuation = attenuate(
        freqs_4d,
        attenuation_coef=attenuation_coef,
        dist=dist_total[..., None],
    )
    spread_atten = spread(dist_total[..., None])

    parts = []
    for tx in range(n_tx):
        tx_idx = ops.array(tx)
        # Total delay including the transmit delay of each element, [n_scat, n_txel, n_rxel].
        tau_total = travel_time + t0_delays[tx_idx][None, :, None] - initial_times[tx_idx]
        # The amplitude product keeps the original operand order so the result is
        # numerically identical to computing every factor inside the loop.
        result = (
            waveform_spectrum[None, None, None]
            * delay2(
                freqs_4d,
                tau_total[..., None],
                n_fft=n_ax_rounded,
                sampling_frequency=sampling_frequency,
            )
            * ops.cast(
                scatterer_magnitudes[:, None, None, None]
                * tx_apodizations[tx, None, :, None, None]
                * directivity_tx
                * directivity_rx
                * attenuation
                * spread_atten,
                "complex64",
            )
        )
        # Sum over all scatterers and transmitting elements -> [n_rxel, n_freq].
        result = ops.sum(result, axis=[0, 1])
        # Back to the time domain -> [n_rxel, n_fft].
        result = ops.irfft((ops.real(result), ops.imag(result)))
        parts.append(result)

    rf_data = ops.stack(parts, axis=0)  # [n_tx, n_rxel, n_fft]
    rf_data = ops.transpose(rf_data, (0, 2, 1))  # [n_tx, n_fft, n_rxel]
    rf_data = rf_data[None, ..., None]  # [1, n_tx, n_fft, n_rxel, 1]
    # Drop the zero-padding introduced by rounding the FFT length up to a power of two.
    rf_data = rf_data[:, :, :n_ax, :, :]
    return rf_data
[docs]
def directivity(f, theta, element_width, sound_speed, rigid_baffle=True):
    """Computes the directivity of a single element.

    The response is ``sinc(element_width * f / sound_speed * sin(theta))`` using the
    normalized sinc, i.e. the element width expressed in wavelengths times the sine of the
    angle. If the element is not mounted on a rigid baffle, the response is additionally
    multiplied by ``cos(theta)``.

    Args:
        f (array-like): The input frequencies [Hz]. Broadcast against ``theta``.
        theta (array-like): The angles [rad].
        element_width (float): The width of the element [m].
        sound_speed (float): The speed of sound [m/s].
        rigid_baffle (bool): Whether the element is mounted on a rigid baffle,
            impacting the directivity.

    Returns:
        array-like: The directivity of the element, with the broadcast shape of ``f`` and
        ``theta``.
    """
    # element_width / wavelength == element_width * f / sound_speed. Writing it this way
    # avoids evaluating sound_speed / 0 (= inf) for the DC bin f = 0, which is part of the
    # frequency grid simulate_rf passes in.
    width_in_wavelengths = element_width * f / sound_speed
    response = sinc(width_in_wavelengths * ops.sin(theta))
    if not rigid_baffle:
        response = response * ops.cos(theta)
    return response
[docs]
def delay2(f, tau, n_fft, sampling_frequency):
    """
    Applies a delay in the frequency domain without phase wrapping.

    The delay is the linear-phase factor ``exp(-2j * pi * tau * f)``. A delay that is at or
    beyond the duration of the FFT window (``n_fft / sampling_frequency``) would wrap around
    to the start of the signal, so its spectrum is set to zero instead. Only this upper
    bound is checked; negative delays are not masked.

    Args:
        f (array-like): The input frequencies.
        tau (float): The delay to apply. Broadcast against ``f``.
        n_fft (int): The number of samples in the FFT.
        sampling_frequency (float): The sampling frequency.

    Returns:
        array-like: The complex64 spectrum of the delay.
    """
    window_duration = n_fft / sampling_frequency
    phase = ops.cast(2 * np.pi * tau * f, "complex64")
    phase_shift = ops.exp(ops.array(-1j, dtype="complex64") * phase)
    fits_in_window = tau < window_duration
    return ops.where(fits_in_window, phase_shift, ops.array(0.0, dtype="complex64"))
[docs]
def attenuate(f, attenuation_coef, dist):
    """
    Applies attenuation to the signal in the frequency domain.

    The amplitude factor is ``10 ** (-attenuation_coef * d * |f| / 20)`` with ``d`` in cm
    and ``f`` in MHz, evaluated as a natural exponential. The factors 100 and 1e-6 below
    convert ``dist`` and ``f`` to cm and MHz, which implies ``dist`` in meters and ``f`` in
    hertz.

    Args:
        f (array-like): The input frequencies [Hz].
        attenuation_coef (float): The attenuation coefficient in dB/cm/MHz.
        dist (float): The distance the signal has traveled [m].

    Returns:
        array-like: The spectrum of the attenuation.
    """
    # -ln(10) / 20 turns an attenuation in dB into a (negative) natural-log amplitude exponent.
    db_to_neper = -ops.log(10) * attenuation_coef / 20
    exponent = db_to_neper * dist * 100 * ops.abs(f) * 1e-6
    return ops.exp(exponent)
[docs]
def spread(dist, mindist=1e-4):
    """Models geometric spreading of the wavefront as ``mindist / dist``.

    Distances below ``mindist`` are clipped to ``mindist``. This prevents division by zero
    and, for a positive ``mindist``, keeps the factor at or below 1.

    Args:
        dist (array-like): The distance the wave has traveled.
        mindist (float): The minimum distance to prevent division by zero.

    Returns:
        array-like: The geometric spreading factor of same shape as `dist`.
    """
    clipped_dist = ops.clip(dist, mindist, float("inf"))
    spreading_factor = mindist / clipped_dist
    return spreading_factor
[docs]
def hann_fd(f, width):
    """The Fourier transform of a Hann window in the time domain with given width.

    Evaluates ``0.5 * sinc(f * width) / (1 - (f * width) ** 2)``, which is 0.5 at ``f = 0``.
    The exact transform of a Hann window of total duration ``width`` is this expression
    multiplied by ``width``.

    Args:
        f (array-like): The input frequencies [Hz].
        width (float): The total duration of the window [s].

    Returns:
        array-like: The (real-valued) window spectrum at ``f``.
    """
    scaled_f = f * width
    ratio = (0.5 * sinc(scaled_f)) / (1.0 - scaled_f**2)
    # At |f * width| == 1 the expression is a 0/0 form with analytic limit 0.25. The small
    # offset inside sinc makes it blow up numerically there, so large values become 0.25.
    ratio = ops.where(ops.abs(ratio) > 1.1, 0.25, ratio)
    return ops.nan_to_num(ratio, nan=0.0, posinf=0.0, neginf=0.25)
[docs]
def hann_unnormalized(x, width):
    """Hann window function that is 1 at the peak (``x = 0``).

    The window is not normalized, so its integral is not necessarily 1.

    Args:
        x (array-like): The input values.
        width (float): The width of the window. This is the total width from -x to x. The
            window will be nonzero in the range [-width/2, width/2].

    Returns:
        hann_vals (array-like): The values of the Hann window function, ``cos(pi*x/width)**2``
        inside the window and 0 outside.
    """
    inside_window = ops.abs(x) < width / 2
    raised_cosine = ops.cos(np.pi * x / width) ** 2
    return ops.where(inside_window, raised_cosine, 0)
[docs]
def get_pulse_spectrum_fn(center_frequency, n_period=3.0):
    """Computes the spectrum of a cosine that is windowed with a Hann window.

    The pulse is a carrier at ``center_frequency`` multiplied by a Hann window of total
    duration ``n_period / center_frequency``. By the modulation theorem its spectrum is
    ``(W(f - fc) + W(f + fc)) / 2``, where ``W`` is the window spectrum from
    :func:`hann_fd`. Both terms are real and are added with the same sign, so the spectrum
    is real-valued. That makes it the spectrum of a zero-phase, cosine-modulated pulse
    centered at ``t = 0``, not a sine.

    Args:
        center_frequency (float): The center frequency of the pulse [Hz].
        n_period (float): The number of periods to include in the pulse.

    Returns:
        spectrum_fn (callable): A function that computes the complex64 spectrum of the
        pulse for the input frequencies in Hz.
    """
    # Total duration of the Hann window [s].
    period = n_period / center_frequency

    def spectrum_fn(f):
        shifted_windows = hann_fd(f - center_frequency, period) + hann_fd(
            f + center_frequency, period
        )
        return ops.array(1 / 2, "complex64") * ops.cast(shifted_windows, "complex64")

    return spectrum_fn
[docs]
def get_transducer_bandwidth_fn(center_frequency, bandwidth):
    """Computes the spectrum of a probe with a center frequency and bandwidth.

    The response is an unnormalized Hann window in ``|f|``. It peaks at 1 where
    ``|f| == center_frequency`` and is nonzero for
    ``center_frequency - bandwidth/2 < |f| < center_frequency + bandwidth/2``.

    Args:
        center_frequency (float): The center frequency of the probe [Hz].
        bandwidth (float): The total width of the nonzero band of the probe [Hz].

    Returns:
        spectrum_fn (callable): A function that computes the spectrum of the pulse for
        the input frequencies in Hz.
    """

    def bandwidth_fn(f):
        offset_from_center = ops.abs(f) - center_frequency
        return hann_unnormalized(offset_from_center, bandwidth)

    return bandwidth_fn
[docs]
def sinc(x):
    """The normalized sinc function ``sin(pi * x) / (pi * x)``.

    A small offset (1e-9) is added to ``|pi * x|`` to prevent division by zero, so
    ``sinc(0)`` evaluates to approximately 1. Using the absolute value does not change the
    result, because ``sin(y) / y`` is an even function.
    """
    shifted_arg = ops.abs(np.pi * x) + 1e-9
    return ops.sin(shifted_arg) / shifted_arg
def _round_up_to_power_of_two(x):
    """Rounds ``x`` up to the nearest power of two.

    Returns:
        The smallest power of two that is >= ``x``, as a NumPy floating-point value.
        ``simulate_rf`` calls ``.astype`` on the result, so it must stay a NumPy type.
    """
    exponent = np.ceil(np.log2(x))
    return 2**exponent