"""Structure containing parameters defining an ultrasound scan.

This module provides the :class:`Scan` class, a flexible structure
for managing all parameters related to an ultrasound scan acquisition.

Features
^^^^^^^^

- **Flexible initialization:** The :class:`Scan` class supports lazy initialization,
  allowing you to specify any combination of supported parameters. You can pass only
  the parameters you have, and the rest will be computed or set to defaults as needed.
- **Automatic computation:** Many scan properties (such as
  grid, number of pixels, wavelength, etc.) are computed automatically from the
  provided parameters. This enables you to work with minimal input and still obtain
  all necessary scan configuration details.
- **Dependency tracking and lazy evaluation:** Derived properties are computed only
  when accessed, and are automatically invalidated and recomputed if their dependencies
  change. This ensures efficient memory usage and avoids unnecessary computations.
- **Parameter validation:** All parameters are type-checked and validated against
  a predefined schema, reducing errors and improving robustness.
- **Selection of transmits:** The scan supports flexible selection of transmit events,
  using the :meth:`set_transmits` method. You can select all, a specific number,
  or specific transmit indices. The selection is stored and can be accessed via
  the :attr:`selected_transmits` property.

Comparison to ``zea.Config`` and ``zea.Probe``
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

- :class:`zea.config.Config`: A general-purpose parameter dictionary for experiment and pipeline
  configuration. It is not specific to ultrasound acquisition and does not compute
  derived parameters.
- :class:`zea.probes.Probe`: Contains only probe-specific parameters (e.g., geometry, frequency).
- :class:`zea.scan.Scan`: Combines all parameters relevant to an ultrasound acquisition,
  including probe, acquisition, and scan region. It also provides automatic computation
  of derived properties and dependency management.

Example Usage
^^^^^^^^^^^^^

.. code-block:: python

    from zea import Config, Probe, Scan

    # Initialize Scan from a Probe's parameters
    probe = Probe.from_name("verasonics_l11_4v")
    scan = Scan(**probe.get_parameters(), grid_size_z=256)

    # Or initialize from a Config object
    config = Config.from_hf("zeahub/configs", "config_picmus_rf.yaml", repo_type="dataset")
    scan = Scan(**config.scan, n_tx=11)

    # Or manually specify parameters
    scan = Scan(
        grid_size_x=128,
        grid_size_z=256,
        xlims=(-0.02, 0.02),
        zlims=(0.0, 0.06),
        center_frequency=6.25e6,
        sound_speed=1540.0,
        sampling_frequency=25e6,
        n_el=128,
        n_tx=11,
    )

    # Access a derived property (computed lazily)
    grid = scan.grid  # shape: (grid_size_z, grid_size_x, 3)

    # Select a subset of transmit events
    scan.set_transmits(3)  # Use 3 evenly spaced transmits
    scan.set_transmits([0, 2, 4])  # Use specific transmit indices
    scan.set_transmits("all")  # Use all transmits
"""
import numpy as np
from keras import ops
from zea import log
from zea.beamform.pfield import compute_pfield
from zea.beamform.pixelgrid import cartesian_pixel_grid, check_for_aliasing, polar_pixel_grid
from zea.display import (
compute_scan_convert_2d_coordinates,
compute_scan_convert_3d_coordinates,
)
from zea.internal.core import DEFAULT_DYNAMIC_RANGE
from zea.internal.parameters import Parameters, cache_with_dependencies
class Scan(Parameters):
    """Represents an ultrasound scan configuration with computed properties.

    Args:
        grid_size_x (int): Grid width in pixels. For a cartesian grid, this is the lateral (x)
            pixels in the grid, set to prevent aliasing if not provided. For a polar grid, this
            can be thought of as the number of rays in the polar direction.
        grid_size_z (int): Grid height in pixels. This is the number of axial (z) pixels in the
            grid, set to prevent aliasing if not provided.
        sound_speed (float, optional): Speed of sound in the medium in m/s.
            Defaults to 1540.0.
        sampling_frequency (float): Sampling frequency in Hz.
        center_frequency (float): Center frequency of the transducer in Hz.
        n_el (int): Number of elements in the transducer array.
        n_tx (int): Number of transmit events in the dataset.
        n_ax (int): Number of axial samples in the received signal.
        n_ch (int, optional): Number of channels (1 for RF, 2 for IQ data).
        xlims (tuple of float): Lateral (x) limits of the imaging region in
            meters (min, max). If not provided, derived from ``polar_limits``/``zlims``
            and/or the lateral extent of ``probe_geometry``.
        ylims (tuple of float, optional): Elevation (y) limits of the imaging
            region in meters (min, max).
        zlims (tuple of float): Axial (z) limits of the imaging region
            in meters (min, max). If not provided, derived from ``n_ax``,
            ``sampling_frequency`` and ``sound_speed``.
        probe_geometry (np.ndarray): Element positions as array of shape (n_el, 3).
        polar_angles (np.ndarray): Polar angles for each transmit event in radians of shape
            (n_tx,). These angles are often used in 2D imaging.
        polar_limits (tuple of float, optional): (min, max) polar angles in radians bounding
            a polar grid. If not provided, derived from ``polar_angles`` with a 15% margin
            on each side.
        azimuth_angles (np.ndarray): Azimuth angles for each transmit event in radians
            of shape (n_tx,). These angles are often used in 3D imaging.
        t0_delays (np.ndarray): Transmit delays in seconds of
            shape (n_tx, n_el), shifted such that the smallest delay is 0.
        tx_apodizations (np.ndarray): Transmit apodizations of shape (n_tx, n_el).
        focus_distances (np.ndarray): Focus distances in meters for each event of shape (n_tx,).
        initial_times (np.ndarray): Initial times in seconds for each event of shape (n_tx,).
        bandwidth_percent (float, optional): Bandwidth as percentage of center
            frequency. Defaults to 200.0.
        demodulation_frequency (float, optional): Demodulation frequency in Hz. If not
            provided, equals ``center_frequency`` when ``n_ch == 2`` and 0.0 otherwise.
        time_to_next_transmit (np.ndarray): The time between subsequent
            transmit events of shape (n_frames, n_tx).
        pixels_per_wavelength (int, optional): Number of pixels per wavelength.
            Defaults to 4.
        element_width (float, optional): Width of each transducer element in meters. If not
            provided, the distance between the first two elements of ``probe_geometry``
            is used.
        resolution (float, optional): Resolution for scan conversion in mm / pixel.
            If None, it is calculated based on the input image.
        pfield_kwargs (dict, optional): Additional parameters for pressure field computation.
            See `zea.beamform.pfield.compute_pfield` for details.
        apply_lens_correction (bool, optional): Whether to apply lens correction to
            delays. Defaults to False.
        lens_sound_speed (float, optional): Speed of sound in the lens (presumably in m/s).
        lens_thickness (float, optional): Thickness of the lens in meters.
            Defaults to None.
        f_number (float, optional): F-number of the transducer. Defaults to 1.0.
        theta_range (tuple, optional): Range of theta angles for 3D imaging.
        phi_range (tuple, optional): Range of phi angles for 3D imaging.
        rho_range (tuple, optional): Range of rho (radial) distances for 3D imaging.
        fill_value (float, optional): Value to use for out-of-bounds pixels.
            Defaults to 0.0.
        attenuation_coef (float, optional): Attenuation coefficient in dB/(MHz*cm).
            Defaults to 0.0.
        selected_transmits (None, str, int, list, or np.ndarray, optional):
            Specifies which transmit events to select.

            - None or "all": Use all transmits.
            - "center": Use only the center transmit.
            - int: Select this many evenly spaced transmits.
            - list/array: Use these specific transmit indices.
        grid_type (str, optional): Type of grid to use for beamforming.
            Can be "cartesian" or "polar". Defaults to "cartesian".
        dynamic_range (tuple, optional): Dynamic range for image display.
            Defined in dB as (min_dB, max_dB). Defaults to (-60, 0).
    """

    VALID_PARAMS = {
        # beamforming related parameters
        "grid_size_x": {"type": int},
        "grid_size_z": {"type": int},
        "xlims": {"type": (tuple, list)},
        "ylims": {"type": (tuple, list)},
        "zlims": {"type": (tuple, list)},
        "pixels_per_wavelength": {"type": int, "default": 4},
        # NOTE(review): mutable default; confirm Parameters copies defaults per instance.
        "pfield_kwargs": {"type": dict, "default": {}},
        "apply_lens_correction": {"type": bool, "default": False},
        "lens_sound_speed": {"type": (float, int)},
        "lens_thickness": {"type": float},
        "grid_type": {"type": str, "default": "cartesian"},
        "polar_limits": {"type": (tuple, list)},
        "dynamic_range": {"type": (tuple, list), "default": DEFAULT_DYNAMIC_RANGE},
        # acquisition parameters
        "sound_speed": {"type": (float, int), "default": 1540.0},
        "sampling_frequency": {"type": float},
        "center_frequency": {"type": float},
        "n_el": {"type": int},
        "n_tx": {"type": int},
        "n_ax": {"type": int},
        "n_ch": {"type": int},
        "bandwidth_percent": {"type": float, "default": 200.0},
        "demodulation_frequency": {"type": float},
        "element_width": {"type": float},
        "attenuation_coef": {"type": float, "default": 0.0},
        "f_number": {"type": float, "default": 1.0},
        # array parameters
        "probe_geometry": {"type": np.ndarray},
        "polar_angles": {"type": np.ndarray},
        "azimuth_angles": {"type": np.ndarray},
        "t0_delays": {"type": np.ndarray},
        "tx_apodizations": {"type": np.ndarray},
        "focus_distances": {"type": np.ndarray},
        "initial_times": {"type": np.ndarray},
        "time_to_next_transmit": {"type": np.ndarray},
        # scan conversion parameters
        "theta_range": {"type": (tuple, list)},
        "phi_range": {"type": (tuple, list)},
        "rho_range": {"type": (tuple, list)},
        "fill_value": {"type": float, "default": 0.0},
        "resolution": {"type": float, "default": None},
    }

    def __init__(self, **kwargs):
        # ``selected_transmits`` is not part of VALID_PARAMS: remove it before the parent
        # handles the kwargs, and apply it once ``n_tx`` is known.
        selected_transmits_input = kwargs.pop("selected_transmits", None)

        super().__init__(**kwargs)

        # None means "all transmits" (see the ``selected_transmits`` property)
        self._selected_transmits = None

        if selected_transmits_input is not None:
            self.set_transmits(selected_transmits_input)

    @cache_with_dependencies(
        "xlims",
        "zlims",
        "grid_size_x",
        "grid_size_z",
        "sound_speed",
        "center_frequency",
        "pixels_per_wavelength",
        "grid_type",
        # used by the polar branch; must invalidate the cached grid when it changes
        "polar_limits",
    )
    def grid(self):
        """The beamforming grid of shape (grid_size_z, grid_size_x, 3).

        Raises:
            ValueError: If ``grid_type`` is neither "cartesian" nor "polar".
        """
        if self.grid_type == "polar":
            return polar_pixel_grid(
                self.polar_limits, self.zlims, self.grid_size_z, self.grid_size_x
            )
        if self.grid_type == "cartesian":
            return cartesian_pixel_grid(
                self.xlims, self.zlims, grid_size_z=self.grid_size_z, grid_size_x=self.grid_size_x
            )
        raise ValueError(
            f"Unsupported grid type: {self.grid_type}. Supported types are "
            "'cartesian' and 'polar'."
        )

    def _min_grid_size(self, span):
        """Smallest pixel count (at least 1) covering ``span`` meters at the target density.

        The pixel spacing is ``wavelength / pixels_per_wavelength``.
        """
        pixel_spacing = self.wavelength / self.pixels_per_wavelength
        return max(int(np.ceil(span / pixel_spacing)), 1)

    @cache_with_dependencies(
        "xlims",
        "wavelength",
        "pixels_per_wavelength",
        "grid_type",
    )
    def grid_size_x(self):
        """Grid width in pixels. For a cartesian grid, this is the lateral (x) pixels in the grid,
        set to prevent aliasing if not provided. For a polar grid, this can be thought of as
        the number of rays in the polar direction.
        """
        grid_size_x = self._params.get("grid_size_x")
        if grid_size_x is not None:
            return grid_size_x
        return self._min_grid_size(self.xlims[1] - self.xlims[0])

    @cache_with_dependencies(
        "zlims",
        "wavelength",
        "pixels_per_wavelength",
    )
    def grid_size_z(self):
        """Grid height in pixels. This is the number of axial (z) pixels in the grid,
        set to prevent aliasing if not provided."""
        grid_size_z = self._params.get("grid_size_z")
        if grid_size_z is not None:
            return grid_size_z
        return self._min_grid_size(self.zlims[1] - self.zlims[0])

    @cache_with_dependencies("sound_speed", "center_frequency")
    def wavelength(self):
        """Calculate the wavelength based on sound speed and center frequency."""
        return self.sound_speed / self.center_frequency

    @cache_with_dependencies("zlims", "grid_type", "polar_limits", "probe_geometry")
    def xlims(self):
        """The x-limits of the beamforming grid [m].

        If not explicitly set, they are the union of:

        - the lateral positions reached at depth ``max(zlims)`` along the two
          ``polar_limits`` angles (when polar limits are available), and
        - the lateral positions of the first and last element in ``probe_geometry``
          (when a probe geometry is available).

        Raises:
            ValueError: If xlims is not set and neither polar limits nor a probe
                geometry are available to derive it from.
        """
        xlims = self._params.get("xlims")
        if xlims is not None:
            return xlims

        lower, upper = [], []

        polar_limits = self.polar_limits
        if polar_limits is not None:
            radius = max(self.zlims)
            lower.append(radius * np.cos(-np.pi / 2 + polar_limits[0]))
            upper.append(radius * np.cos(-np.pi / 2 + polar_limits[1]))

        probe_geometry = self._params.get("probe_geometry")
        if probe_geometry is not None:
            # x-coordinate of the first and last element
            lower.append(probe_geometry[0, 0])
            upper.append(probe_geometry[-1, 0])

        if not lower:
            raise ValueError(
                "Cannot derive xlims: set xlims explicitly, or provide polar_limits "
                "(or polar_angles) and/or probe_geometry."
            )
        return min(lower), max(upper)

    @cache_with_dependencies("sound_speed", "sampling_frequency", "n_ax")
    def zlims(self):
        """The z-limits of the beamforming grid [m].

        If not explicitly set, spans from 0 to the depth covered by ``n_ax`` samples,
        i.e. ``sound_speed * n_ax / sampling_frequency / 2`` (the factor 2 accounts for
        the round trip).
        """
        zlims = self._params.get("zlims")
        if zlims is None:
            return [0, self.sound_speed * self.n_ax / self.sampling_frequency / 2]
        return zlims

    @cache_with_dependencies("xlims", "zlims")
    def extent(self):
        """The extent of the beamforming grid in the format (xmin, xmax, zmax, zmin).

        Can be directly used with `plt.imshow(x, extent=scan.extent)` for visualization.
        """
        return np.array([self.xlims[0], self.xlims[1], self.zlims[1], self.zlims[0]])

    @cache_with_dependencies("grid")
    def flatgrid(self):
        """The beamforming grid of shape (grid_size_z*grid_size_x, 3)."""
        return self.grid.reshape(-1, 3)

    @property
    def selected_transmits(self):
        """Get the currently selected transmit indices.

        Returns:
            list: The list of selected transmit indices. If none were explicitly
            selected and n_tx is available, all transmits are used. If n_tx is not
            available either, an empty list is returned.
        """
        if self._selected_transmits is None:
            if "n_tx" in self._params:
                return list(range(self._params["n_tx"]))
            return []
        return self._selected_transmits

    @property
    def n_tx_total(self):
        """The total number of transmits in the full dataset."""
        return self._params["n_tx"]

    @cache_with_dependencies("selected_transmits")
    def n_tx(self):
        """The number of currently selected transmits."""
        return len(self.selected_transmits)

    def _apply_transmit_selection(self, indices):
        """Store ``indices`` (None meaning all) and invalidate everything derived from it.

        Returns:
            Scan: The current instance for method chaining.
        """
        self._selected_transmits = indices
        self._invalidate("selected_transmits")
        self._invalidate_dependents("selected_transmits")
        return self

    def set_transmits(self, selection):
        """Select which transmit events to use.

        Args:
            selection: Specifies which transmits to select:

                - None or "all": Use all transmits.
                - "center": Use only the center transmit (index ``n_tx // 2``).
                - int: Select this many evenly spaced transmits. Selecting a single
                  transmit selects the center one.
                - list, tuple, range or 1D array: Use these specific transmit indices.
                  A 0-d array is treated as an int.

        Returns:
            Scan: The current instance for method chaining.

        Raises:
            ValueError: If ``n_tx`` is not set, or the selection is invalid or incompatible
                with the scan.
        """
        n_tx_total = self._params.get("n_tx")
        if n_tx_total is None:
            raise ValueError("n_tx must be set before calling set_transmits")

        # Normalize array-like inputs to a plain list of indices
        if isinstance(selection, np.ndarray):
            if selection.ndim == 0:
                return self.set_transmits(int(selection))
            if selection.ndim != 1:
                raise ValueError(f"Invalid array shape: {selection.shape}")
            selection = selection.tolist()
        elif isinstance(selection, (tuple, range)):
            selection = list(selection)

        if selection is None or (isinstance(selection, str) and selection == "all"):
            return self._apply_transmit_selection(None)

        if isinstance(selection, str) and selection == "center":
            return self._apply_transmit_selection([n_tx_total // 2])

        if isinstance(selection, (int, np.integer)):
            n_selected = int(selection)  # normalize numpy integers to Python int
            if n_selected <= 0:
                raise ValueError(f"Number of transmits must be positive, got {n_selected}")
            if n_selected > n_tx_total:
                raise ValueError(
                    f"Requested {n_selected} transmits exceeds available transmits ({n_tx_total})"
                )
            if n_selected == 1:
                return self._apply_transmit_selection([n_tx_total // 2])
            # Evenly spaced indices from the first to the last transmit; ``tolist`` yields
            # Python ints, matching the representation used for explicit index lists.
            tx_indices = np.rint(np.linspace(0, n_tx_total - 1, n_selected)).astype(int)
            return self._apply_transmit_selection(tx_indices.tolist())

        if isinstance(selection, list):
            if not all(isinstance(i, (int, np.integer)) for i in selection):
                raise ValueError("All transmit indices must be integers")
            if any(i < 0 or i >= n_tx_total for i in selection):
                raise ValueError(f"Transmit indices must be between 0 and {n_tx_total - 1}")
            # Convert numpy integers to Python ints
            return self._apply_transmit_selection([int(i) for i in selection])

        raise ValueError(f"Unsupported selection type: {type(selection)}")

    @cache_with_dependencies("n_ch", "center_frequency")
    def demodulation_frequency(self):
        """The demodulation frequency.

        Uses the explicitly set value if available; otherwise the center frequency for
        IQ data (``n_ch == 2``) and 0.0 for anything else.
        """
        if self._params.get("demodulation_frequency") is not None:
            return self._params["demodulation_frequency"]
        return self.center_frequency if self.n_ch == 2 else 0.0

    @cache_with_dependencies("selected_transmits")
    def polar_angles(self):
        """Polar angles for each transmit event in radians of shape (n_tx,).

        These angles are often used in 2D imaging. Returns None if not provided.
        """
        value = self._params.get("polar_angles")
        if value is None:
            return None
        return value[self.selected_transmits]

    @cache_with_dependencies("polar_angles")
    def polar_limits(self):
        """The limits of the polar angles.

        If not explicitly set, derived from the (selected) ``polar_angles`` range, widened
        by 15% of that range on each side. Returns None if neither is available.
        """
        value = self._params.get("polar_limits")
        if value is None and self.polar_angles is not None:
            low, high = self.polar_angles.min(), self.polar_angles.max()
            margin = 0.15 * (high - low)
            value = (low - margin, high + margin)
        return value

    @cache_with_dependencies("selected_transmits")
    def azimuth_angles(self):
        """Azimuth angles for each transmit event in radians of shape (n_tx,).

        These angles are often used in 3D imaging. Defaults to zeros (with a warning).
        """
        value = self._params.get("azimuth_angles")
        if value is None:
            log.warning("No azimuth angles provided, using zeros")
            value = np.zeros(self.n_tx_total)
        return value[self.selected_transmits]

    @cache_with_dependencies("selected_transmits", "n_el")
    def t0_delays(self):
        """Transmit delays in seconds of shape (n_tx, n_el), shifted such that the smallest
        delay is 0. Defaults to zeros (with a warning).
        """
        value = self._params.get("t0_delays")
        if value is None:
            log.warning("No transmit delays provided, using zeros")
            value = np.zeros((self.n_tx_total, self.n_el))
        # Apply the selection to the default as well, so the shape is always (n_tx, n_el)
        return value[self.selected_transmits]

    @cache_with_dependencies("selected_transmits", "n_el")
    def tx_apodizations(self):
        """Transmit apodizations of shape (n_tx, n_el). Defaults to ones (with a warning)."""
        value = self._params.get("tx_apodizations")
        if value is None:
            log.warning("No transmit apodizations provided, using ones")
            value = np.ones((self.n_tx_total, self.n_el))
        return value[self.selected_transmits]

    @cache_with_dependencies("selected_transmits")
    def focus_distances(self):
        """Focus distances in meters for each event of shape (n_tx,).

        Defaults to zeros (with a warning).
        """
        value = self._params.get("focus_distances")
        if value is None:
            log.warning("No focus distances provided, using zeros")
            value = np.zeros(self.n_tx_total)
        return value[self.selected_transmits]

    @cache_with_dependencies("selected_transmits")
    def initial_times(self):
        """Initial times in seconds for each event of shape (n_tx,).

        Defaults to zeros (with a warning).
        """
        value = self._params.get("initial_times")
        if value is None:
            log.warning("No initial times provided, using zeros")
            value = np.zeros(self.n_tx_total)
        return value[self.selected_transmits]

    @cache_with_dependencies("selected_transmits")
    def time_to_next_transmit(self):
        """The time between subsequent transmit events of shape (n_frames, n_tx).

        Only the selected transmits (columns) are returned. Returns None if not provided.
        """
        value = self._params.get("time_to_next_transmit")
        if value is None:
            return None
        return value[:, self.selected_transmits]

    @cache_with_dependencies(
        "sound_speed",
        "center_frequency",
        "bandwidth_percent",
        "n_el",
        "probe_geometry",
        "tx_apodizations",
        "grid",
        "t0_delays",
        # forwarded to compute_pfield, so changing it must invalidate the cache
        "pfield_kwargs",
    )
    def pfield(self):
        """Compute or return the pressure field (pfield) for weighting, as a numpy array."""
        pfield = compute_pfield(
            sound_speed=self.sound_speed,
            center_frequency=self.center_frequency,
            bandwidth_percent=self.bandwidth_percent,
            n_el=self.n_el,
            probe_geometry=self.probe_geometry,
            tx_apodizations=self.tx_apodizations,
            grid=self.grid,
            t0_delays=self.t0_delays,
            **self.pfield_kwargs,
        )
        return ops.convert_to_numpy(pfield)

    @cache_with_dependencies("pfield")
    def flat_pfield(self):
        """Flattened pfield for weighting, of shape (n_pixels, n_tx)."""
        return self.pfield.reshape(self.n_tx, -1).swapaxes(0, 1)

    @cache_with_dependencies("zlims")
    def rho_range(self):
        """A tuple specifying the range of rho values (min_rho, max_rho) used for scan
        conversion. Defaults to ``zlims`` (meters) when not explicitly set."""
        value = self._params.get("rho_range")
        if value is None:
            return self.zlims
        return value

    @cache_with_dependencies("polar_limits")
    def theta_range(self):
        """A tuple specifying the range of theta values (min_theta, max_theta).

        Defined in radians. Used for scan conversion. Defaults to ``polar_limits``.
        """
        value = self._params.get("theta_range")
        if value is None and self.polar_limits is not None:
            return self.polar_limits
        return value

    @cache_with_dependencies("rho_range", "theta_range", "resolution", "grid_size_z", "grid_size_x")
    def coordinates_2d(self):
        """The coordinates for 2D scan conversion."""
        coords, _ = compute_scan_convert_2d_coordinates(
            (self.grid_size_z, self.grid_size_x),
            self.rho_range,
            self.theta_range,
            self.resolution,
        )
        return coords

    @cache_with_dependencies(
        "rho_range", "theta_range", "phi_range", "resolution", "grid_size_z", "grid_size_x"
    )
    def coordinates_3d(self):
        """The coordinates for 3D scan conversion."""
        coords, _ = compute_scan_convert_3d_coordinates(
            (self.grid_size_z, self.grid_size_x),
            self.rho_range,
            self.theta_range,
            self.phi_range,
            self.resolution,
        )
        return coords

    @property
    def coordinates(self):
        """Get the coordinates for scan conversion, will be 3D if phi_range is set,
        otherwise 2D."""
        return self.coordinates_3d if getattr(self, "phi_range", None) else self.coordinates_2d

    @cache_with_dependencies("time_to_next_transmit")
    def frames_per_second(self):
        """The number of frames per second [Hz]. Assumes a constant frame rate.

        Frames per second computed based on time between transmits within a frame.
        Ignores time between frames (e.g. due to processing).

        Uses the time it took to do all transmits (per frame). So if you only use some portion
        of the transmits, the fps will still be calculated based on all.
        """
        # Read the raw (unselected) array so the frame time covers all transmits,
        # independent of the current transmit selection.
        time_to_next_transmit = self._params.get("time_to_next_transmit")
        if time_to_next_transmit is None:
            log.warning("Time to next transmit is not set, cannot compute fps")
            return None

        # Rows are frames; more than one unique row means the timing varies per frame
        uniq = np.unique(time_to_next_transmit, axis=0)
        if uniq.shape[0] != 1:
            log.warning("Time to next transmit is not constant")

        frame_time = np.mean(np.sum(time_to_next_transmit, axis=1))
        return 1 / frame_time

    @cache_with_dependencies("probe_geometry")
    def element_width(self):
        """The width of each transducer element in meters.

        If not explicitly set, falls back to the distance between the first two elements
        of ``probe_geometry``.
        """
        value = self._params.get("element_width")
        if value is None:
            # NOTE(review): this is the element spacing (pitch), used as an approximation
            # of the element width.
            return np.linalg.norm(self.probe_geometry[1] - self.probe_geometry[0])
        return value

    def __setattr__(self, key, value):
        # Route assignments to ``selected_transmits`` through set_transmits so the
        # selection is validated and dependent caches are invalidated.
        if key == "selected_transmits":
            self.set_transmits(value)
            return
        return super().__setattr__(key, value)