Source code for zea.ops

"""Operations and Pipelines for ultrasound data processing.

This module contains two important classes, :class:`Operation` and :class:`Pipeline`,
which are used to process ultrasound data. A pipeline is a sequence of operations
that are applied to the data in a specific order.

Stand-alone manual usage
------------------------

Operations can be run on their own:

Examples
^^^^^^^^
.. code-block:: python

    data = np.random.randn(2000, 128, 1)
    # static arguments are passed in the constructor
    envelope_detect = EnvelopeDetect(axis=-1)
    # other parameters can be passed here along with the data
    envelope_data = envelope_detect(data=data)

Using a pipeline
----------------

You can initialize with a default pipeline or create your own custom pipeline.

.. code-block:: python

    pipeline = Pipeline.from_default()

    operations = [
        EnvelopeDetect(),
        Normalize(),
        LogCompress(),
    ]
    pipeline_custom = Pipeline(operations)

One can also load a pipeline from a JSON string or from a YAML file:

.. code-block:: python

    json_string = '{"operations": ["identity"]}'
    pipeline = Pipeline.from_json(json_string)

    yaml_file = "pipeline.yaml"
    pipeline = Pipeline.from_yaml(yaml_file)

Example of a yaml file:

.. code-block:: yaml

    pipeline:
      operations:
        - name: demodulate
        - name: "patched_grid"
          params:
            operations:
              - name: tof_correction
                params:
                  apply_phase_rotation: true
              - name: pfield_weighting
              - name: delay_and_sum
            num_patches: 100
        - name: envelope_detect
        - name: normalize
        - name: log_compress

"""

import copy
import hashlib
import inspect
import json
from functools import partial
from typing import Any, Dict, List, Union

import keras
import numpy as np
import scipy
import yaml
from keras import ops
from keras.src.layers.preprocessing.tf_data_layer import TFDataLayer

from zea import log
from zea.backend import jit
from zea.beamform.beamformer import tof_correction
from zea.config import Config
from zea.display import scan_convert
from zea.internal.checks import _assert_keys_and_axes
from zea.internal.core import (
    DEFAULT_DYNAMIC_RANGE,
    DataTypes,
    ZEADecoderJSON,
    ZEAEncoderJSON,
    dict_to_tensor,
)
from zea.internal.core import Object as ZEAObject
from zea.internal.registry import ops_registry
from zea.probes import Probe
from zea.scan import Scan
from zea.simulator import simulate_rf
from zea.tensor_ops import batched_map, patched_map, resample, reshape_axis
from zea.utils import FunctionTimer, deep_compare, map_negative_indices, translate


def get_ops(ops_name):
    """Look up an operation in the ops registry.

    Args:
        ops_name: Name under which the operation was registered.

    Returns:
        The entry stored in ``ops_registry`` under ``ops_name``.
    """
    registered = ops_registry[ops_name]
    return registered
class Operation(keras.Operation):
    """
    A base abstract class for operations in the pipeline with caching functionality.
    """

    def __init__(
        self,
        input_data_type: Union[DataTypes, None] = None,
        output_data_type: Union[DataTypes, None] = None,
        key: Union[str, None] = "data",
        output_key: Union[str, None] = None,
        cache_inputs: Union[bool, List[str]] = False,
        cache_outputs: Union[bool, List[str]] = False,
        jit_compile: bool = True,
        with_batch_dim: bool = True,
        jit_kwargs: dict | None = None,
        jittable: bool = True,
        **kwargs,
    ):
        """
        Args:
            input_data_type (DataTypes): The data type of the input data
            output_data_type (DataTypes): The data type of the output data
            key: The key for the input data (operation will operate on this key)
                Defaults to "data".
            output_key: The key for the output data (operation will output to this key)
                Defaults to the same as the input key. If you want to store intermediate
                results, you can set this to a different key. But make sure to update the
                input key of the next operation to match the output key of this operation.
            cache_inputs: A list of input keys to cache or True to cache all inputs
            cache_outputs: A list of output keys to cache or True to cache all outputs
            jit_compile: Whether to JIT compile the 'call' method for faster execution
            with_batch_dim: Whether operations should expect a batch dimension in the input
            jit_kwargs: Additional keyword arguments for the JIT compiler. The dictionary
                is copied, so the caller's object is never modified.
            jittable: Whether the operation can be JIT compiled
            **kwargs: Forwarded to the ``keras.Operation`` constructor.
        """
        super().__init__(**kwargs)

        self.input_data_type = input_data_type
        self.output_data_type = output_data_type

        self.key = key  # Key for input data
        self.output_key = output_key  # Key for output data
        if self.output_key is None:
            self.output_key = self.key

        self.inputs = []  # Source(s) of input data (name of a previous operation)
        self.allow_multiple_inputs = False  # Only single input allowed by default

        self.cache_inputs = cache_inputs
        self.cache_outputs = cache_outputs

        # Initialize input and output caches
        self._input_cache = {}
        self._output_cache = {}

        # Obtain the input signature of the `call` method
        self._input_signature = None
        self._valid_keys = None  # Keys valid for the `call` method
        self._trace_signatures()

        # Work on a private copy: an in-place update would leak this operation's
        # `static_argnames` into a dict the caller may share with other operations.
        jit_kwargs = {} if jit_kwargs is None else dict(jit_kwargs)

        if keras.backend.backend() == "jax" and self.static_params:
            jit_kwargs["static_argnames"] = self.static_params

        self.jit_kwargs = jit_kwargs

        self.with_batch_dim = with_batch_dim
        self._jittable = jittable

        # Set the jit compilation flag and compile the `call` method
        # Set zea logger level to suppress warnings regarding
        # torch not being able to compile the function
        with log.set_level("ERROR"):
            self.set_jit(jit_compile)

    @property
    def static_params(self):
        """Get the static parameters of the operation.

        Read from the class attribute ``STATIC_PARAMS``; an empty list when the
        class does not define it.
        """
        return getattr(self.__class__, "STATIC_PARAMS", [])
def set_jit(self, jit_compile: bool):
    """Store the JIT flag and rebind ``_call`` to match it.

    ``_call`` becomes the JIT-wrapped ``call`` (using ``self.jit_kwargs``) only when
    JIT is requested and the operation is jittable; otherwise it is plain ``call``.

    Args:
        jit_compile: Whether JIT compilation is requested.
    """
    self._jit_compile = jit_compile
    use_jit = self._jit_compile and self.jittable
    self._call = jit(self.call, **self.jit_kwargs) if use_jit else self.call
def _trace_signatures(self):
    """Inspect ``call`` and record its signature and parameter names."""
    signature = inspect.signature(self.call)
    self._input_signature = signature
    self._valid_keys = set(signature.parameters)


@property
def valid_keys(self):
    """Parameter names accepted by ``call``, as recorded by ``_trace_signatures``."""
    return self._valid_keys


@property
def jittable(self):
    """Whether this operation was declared JIT-compilable."""
    return self._jittable
def call(self, **kwargs):
    """Processing logic of the operation; must be overridden by subclasses.

    Raises:
        NotImplementedError: Always, in this base implementation.
    """
    raise NotImplementedError()
def set_input_cache(self, input_cache: Dict[str, Any]):
    """Merge entries into the input cache and re-inspect the ``call`` signature.

    Existing cache entries with the same key are overwritten.

    Args:
        input_cache: A dictionary containing cached inputs.
    """
    self._input_cache |= input_cache
    # Re-run signature inspection after the cache changed.
    self._trace_signatures()
def set_output_cache(self, output_cache: Dict[str, Any]):
    """Merge entries into the output cache and re-inspect the ``call`` signature.

    Existing cache entries with the same key are overwritten.

    Args:
        output_cache: A dictionary containing cached outputs.
    """
    self._output_cache |= output_cache
    # Re-run signature inspection after the cache changed.
    self._trace_signatures()
def clear_cache(self):
    """Empty both the input cache and the output cache in place."""
    for cache in (self._input_cache, self._output_cache):
        cache.clear()
def _hash_inputs(self, kwargs: Dict) -> str:
    """Build a cache key from the given inputs.

    The inputs are serialized to JSON with sorted keys (values that are not JSON
    serializable are converted with ``str``) and the MD5 hex digest is returned.

    Args:
        kwargs: Keyword arguments.

    Returns:
        Hex digest string identifying the inputs.
    """
    serialized = json.dumps(kwargs, sort_keys=True, default=str)
    digest = hashlib.md5(serialized.encode())
    return digest.hexdigest()


def __call__(self, *args, **kwargs) -> Dict:
    """Run the operation on keyword inputs and return inputs merged with outputs.

    Cached inputs are merged underneath the provided keyword arguments. When
    output caching is enabled, a previously stored result for the same inputs is
    returned without calling the operation again.

    Args:
        kwargs: Keyword arguments to be processed.

    Returns:
        Combined input and output as kwargs.

    Raises:
        TypeError: If positional arguments are passed, or if ``call`` does not
            return a dictionary.
    """
    if args:
        usage = f" result = {ops_registry.get_name(self)}({self.key}=my_data"
        named_keys = self.valid_keys - {"kwargs"}
        if named_keys:
            usage += f", {list(named_keys)[0]}=param1, ..., **kwargs)"
        else:
            usage += ", **kwargs)"
        raise TypeError(
            f"{self.__class__.__name__}.__call__() only accepts keyword arguments. "
            "Positional arguments are not allowed.\n"
            f"Received positional arguments: {args}\n"
            "Example usage:\n"
            f"{usage}"
        )

    # Explicitly passed values win over cached inputs
    call_inputs = {**self._input_cache, **kwargs}

    # Serve from the output cache when possible
    if self.cache_outputs:
        cache_key = self._hash_inputs(call_inputs)
        if cache_key in self._output_cache:
            return {**call_inputs, **self._output_cache[cache_key]}

    # Only forward what `call` accepts, unless it takes **kwargs
    if "kwargs" in self.valid_keys:
        accepted = call_inputs
    else:
        accepted = {
            name: value for name, value in call_inputs.items() if name in self.valid_keys
        }

    # If you want to jump in with debugger please set `jit_compile=False`
    # when initializing the pipeline.
    result = self._call(**accepted)

    if not isinstance(result, dict):
        raise TypeError(
            f"The `call` method must return a dictionary. Got {type(result)}."
        )

    merged = {**call_inputs, **result}

    # Store the result (or the selected output keys) for later calls
    if self.cache_outputs:
        if isinstance(self.cache_outputs, list):
            to_store = {
                name: value for name, value in result.items() if name in self.cache_outputs
            }
        else:
            to_store = result
        self._output_cache[cache_key] = to_store

    return merged
def get_dict(self):
    """Return the operation's configuration as a dictionary.

    Returns:
        dict: ``{"name": <registry name>, "params": {...}}`` where ``params`` holds
        the key, output key, caching flags, JIT flag, batch-dim flag and JIT kwargs.
    """
    registry_name = ops_registry.get_name(self)
    params = {
        "key": self.key,
        "output_key": self.output_key,
        "cache_inputs": self.cache_inputs,
        "cache_outputs": self.cache_outputs,
        "jit_compile": self._jit_compile,
        "with_batch_dim": self.with_batch_dim,
        "jit_kwargs": self.jit_kwargs,
    }
    return {"name": registry_name, "params": params}
def __eq__(self, other):
    """Compare two operations by class name, registry name and configuration.

    Returns:
        bool: True only if ``other`` is an ``Operation`` with the same class name,
        the same registry name, and a ``get_dict()`` that ``deep_compare`` accepts
        as equal to this operation's.
    """
    return (
        isinstance(other, Operation)
        and self.__class__.__name__ == other.__class__.__name__
        and ops_registry.get_name(self) == ops_registry.get_name(other)
        and bool(deep_compare(self.get_dict(), other.get_dict()))
    )
@ops_registry("pipeline")
class Pipeline:
    """Pipeline class for processing ultrasound data through a series of operations."""

    def __init__(
        self,
        operations: List[Operation],
        with_batch_dim: bool = True,
        jit_options: Union[str, None] = "ops",
        jit_kwargs: dict | None = None,
        name="pipeline",
        validate=True,
        timed: bool = False,
    ):
        """Initialize a pipeline

        Args:
            operations (list): A list of Operation instances representing the operations
                to be performed.
            with_batch_dim (bool, optional): Whether operations should expect a batch dimension.
                Defaults to True.
            jit_options (str, optional): The JIT options to use. Must be "pipeline", "ops", or None.

                - "pipeline" compiles the entire pipeline as a single function.
                  This may be faster but, does not preserve python control flow, such as caching.

                - "ops" compiles each operation separately. This preserves python control flow and
                  caching functionality, but speeds up the operations.

                - None disables JIT compilation.

                Defaults to "ops".

            jit_kwargs (dict, optional): Additional keyword arguments for the JIT compiler.
                The dictionary is copied; on the jax backend ``static_argnames`` is added
                to (not substituted for) the provided entries.
            name (str, optional): The name of the pipeline. Defaults to "pipeline".
            validate (bool, optional): Whether to validate the pipeline. Defaults to True.
            timed (bool, optional): Whether the pipeline runs through ``timed_call``
                (each operation wrapped by ``self.timer``) instead of ``call``.
                Defaults to False.

        Raises:
            ValueError: If ``jit_options`` is not "pipeline", "ops", or None.
        """
        self._call_pipeline = self.call
        self.name = name
        self.timer = FunctionTimer()
        self.timed = timed

        self._pipeline_layers = operations

        if jit_options not in ["pipeline", "ops", None]:
            raise ValueError("jit_options must be 'pipeline', 'ops', or None")

        # The setter also pushes the flag down to every operation
        self.with_batch_dim = with_batch_dim

        self._validate_flag = validate

        if validate:
            self.validate()
        else:
            log.warning("Pipeline validation is disabled, make sure to validate manually.")

        # Copy so the caller's dict is untouched, then merge instead of replacing:
        # replacing would silently drop every user-supplied JIT option on jax.
        jit_kwargs = {} if jit_kwargs is None else dict(jit_kwargs)

        if keras.backend.backend() == "jax" and self.static_params:
            jit_kwargs["static_argnames"] = self.static_params

        self.jit_kwargs = jit_kwargs
        self.jit_options = jit_options  # will handle the jit compilation
def needs(self, key) -> bool:
    """Tell whether ``key`` is among the pipeline's valid keys."""
    accepted_keys = self.valid_keys
    return key in accepted_keys
@property
def valid_keys(self) -> set:
    """Union of the valid keys of every operation in the pipeline."""
    return set().union(*(operation.valid_keys for operation in self.operations))


@property
def static_params(self) -> List[str]:
    """De-duplicated static parameters of every operation in the pipeline."""
    collected = [
        param for operation in self.operations for param in operation.static_params
    ]
    return list(set(collected))
@classmethod
def from_default(cls, num_patches=100, baseband=False, pfield=False, **kwargs) -> "Pipeline":
    """Build the default pipeline: (demodulate) -> beamforming -> display operations.

    Args:
        num_patches (int): Number of patches for the PatchedGrid operation.
            Defaults to 100. If you get an out of memory error, try to increase this number.
            When it is not larger than 1, the beamforming operations are added
            directly without a PatchedGrid wrapper.
        baseband (bool): If True, assume the input data is baseband (I/Q) data,
            which has 2 channels (last dim). Defaults to False, which assumes RF data,
            so input signal has a single channel dim and is still on carrier frequency.
            A Demodulate operation is only added when this is False.
        pfield (bool): If True, apply Pfield weighting. Defaults to False.
            This will calculate pressure field and only beamform the data to those locations.
        **kwargs: Additional keyword arguments to be passed to the Pipeline constructor.
            They are also passed to PatchedGrid when patching is used.

    Returns:
        Pipeline: The assembled pipeline.
    """
    # RF input needs demodulation first; baseband input does not
    front_end = [] if baseband else [Demodulate()]

    beamforming_ops = [
        TOFCorrection(apply_phase_rotation=True),
        DelayAndSum(),
    ]
    if pfield:
        # Weighting goes between time-of-flight correction and summation
        beamforming_ops.insert(1, PfieldWeighting())

    if num_patches > 1:
        patched = PatchedGrid(operations=beamforming_ops, num_patches=num_patches, **kwargs)
        beamforming_ops = [patched]

    display_ops = [
        EnvelopeDetect(),
        Normalize(),
        LogCompress(),
    ]

    return cls(front_end + beamforming_ops + display_ops, **kwargs)
def copy(self) -> "Pipeline":
    """Create a copy of the pipeline.

    The list of operations is shallow-copied: the new pipeline has its own list
    but refers to the same operation objects. All constructor settings,
    including ``timed``, are carried over.

    Returns:
        Pipeline: A new pipeline with the same operations and settings.
    """
    return Pipeline(
        self._pipeline_layers.copy(),
        with_batch_dim=self.with_batch_dim,
        jit_options=self.jit_options,
        jit_kwargs=self.jit_kwargs,
        name=self.name,
        validate=self._validate_flag,
        # Previously omitted, so copies of a timed pipeline silently lost timing
        timed=self.timed,
    )
def prepend(self, operation: Operation):
    """Prepend an operation to the pipeline.

    The operation is inserted in place at the front of this pipeline's
    operation list.

    Args:
        operation: The operation to put first.

    Returns:
        Pipeline: A copy of the updated pipeline, matching what ``insert``
        returns (the copy used to be built and then discarded).
    """
    self._pipeline_layers.insert(0, operation)
    return self.copy()
def append(self, operation: Operation):
    """Append an operation to the pipeline.

    The operation is added in place at the end of this pipeline's
    operation list.

    Args:
        operation: The operation to put last.

    Returns:
        Pipeline: A copy of the updated pipeline, matching what ``insert``
        returns (the copy used to be built and then discarded).
    """
    self._pipeline_layers.append(operation)
    return self.copy()
def insert(self, index: int, operation: Operation):
    """Insert an operation at ``index`` and return a copy of the pipeline.

    Args:
        index: Position for the new operation; must lie between 0 and the
            current number of operations (inclusive).
        operation: The operation to insert.

    Returns:
        The result of ``self.copy()`` after the in-place insertion.

    Raises:
        IndexError: If ``index`` is negative or beyond the end of the list.
    """
    layer_count = len(self._pipeline_layers)
    if not 0 <= index <= layer_count:
        raise IndexError("Index out of bounds for inserting operation.")
    self._pipeline_layers.insert(index, operation)
    return self.copy()
@property
def operations(self):
    """The pipeline's list of operations (the internal ``_pipeline_layers`` list)."""
    layers = self._pipeline_layers
    return layers
def timed_call(self, **inputs):
    """Process input data through the pipeline, timing every operation.

    Each operation is wrapped by ``self.timer`` (registered under the operation's
    class name) before being called. The output of one operation is passed as
    keyword input to the next.

    Args:
        **inputs: Keyword inputs for the first operation.

    Returns:
        The output of the last operation, or ``inputs`` unchanged when the
        pipeline has no operations.

    Raises:
        KeyError: If an operation raises ``KeyError`` (reported as a missing input key).
        RuntimeError: If an operation raises any other exception.
    """
    # Start from the inputs so an empty pipeline acts as the identity instead
    # of failing on an unbound local.
    outputs = inputs
    for op in self._pipeline_layers:
        timed_op = self.timer(op, name=op.__class__.__name__)
        try:
            outputs = timed_op(**inputs)
        except KeyError as exc:
            raise KeyError(
                f"[zea.Pipeline] Operation '{op.__class__.__name__}' "
                f"requires input key '{exc.args[0]}', "
                "but it was not provided in the inputs.\n"
                "Check whether the objects (such as `zea.Scan`) passed to "
                "`pipeline.prepare_parameters()` contain all required keys.\n"
                f"Current list of all passed keys: {list(inputs.keys())}\n"
                f"Valid keys for this pipeline: {self.valid_keys}"
            ) from exc
        except Exception as exc:
            raise RuntimeError(
                f"[zea.Pipeline] Error in operation '{op.__class__.__name__}': {exc}"
            ) from exc
        inputs = outputs
    return outputs
def call(self, **inputs):
    """Process input data through the pipeline.

    The operations are called in order; the output of one operation is passed as
    keyword input to the next.

    Args:
        **inputs: Keyword inputs for the first operation.

    Returns:
        The output of the last operation, or ``inputs`` unchanged when the
        pipeline has no operations.

    Raises:
        KeyError: If an operation raises ``KeyError`` (reported as a missing input key).
        RuntimeError: If an operation raises any other exception; the original
            exception is attached as ``__cause__``.
    """
    # Start from the inputs so an empty pipeline acts as the identity instead
    # of failing on an unbound local.
    outputs = inputs
    for operation in self._pipeline_layers:
        try:
            outputs = operation(**inputs)
        except KeyError as exc:
            raise KeyError(
                f"[zea.Pipeline] Operation '{operation.__class__.__name__}' "
                f"requires input key '{exc.args[0]}', "
                "but it was not provided in the inputs.\n"
                "Check whether the objects (such as `zea.Scan`) passed to "
                "`pipeline.prepare_parameters()` contain all required keys.\n"
                f"Current list of all passed keys: {list(inputs.keys())}\n"
                f"Valid keys for this pipeline: {self.valid_keys}"
            ) from exc
        except Exception as exc:
            # Chain explicitly, as `timed_call` does, so the traceback reads
            # "direct cause" instead of "during handling ... another exception"
            raise RuntimeError(
                f"[zea.Pipeline] Error in operation '{operation.__class__.__name__}': {exc}"
            ) from exc
        inputs = outputs
    return outputs
def __call__(self, return_numpy=False, **inputs):
    """Process input data through the pipeline.

    Args:
        return_numpy (bool): If True, every tensor in the output dictionary is
            converted to a numpy array; non-tensor values (e.g. None) are kept as is.
        **inputs: Keyword inputs for the pipeline. Probe, Scan and Config objects
            must be converted with ``prepare_parameters`` first.

    Returns:
        dict: The outputs of the pipeline.

    Raises:
        ValueError: If ``probe``, ``scan`` or ``config`` is passed as a key, if any
            value is a zea object, or if any value is a string.
    """
    if any(key in inputs for key in ["probe", "scan", "config"]):
        raise ValueError(
            "Probe, Scan and Config objects should be first processed with "
            "`Pipeline.prepare_parameters` before calling the pipeline. "
            "e.g. inputs = Pipeline.prepare_parameters(probe, scan, config)"
        )

    if any(isinstance(arg, ZEAObject) for arg in inputs.values()):
        raise ValueError(
            "Probe, Scan and Config objects should be first processed with "
            "`Pipeline.prepare_parameters` before calling the pipeline. "
            "e.g. inputs = Pipeline.prepare_parameters(probe, scan, config)"
        )

    if any(isinstance(arg, str) for arg in inputs.values()):
        raise ValueError(
            "Pipeline does not support string inputs. "
            "Please ensure all inputs are convertible to tensors."
        )

    ## PROCESSING
    outputs = self._call_pipeline(**inputs)

    ## PREPARE OUTPUT
    if return_numpy:
        # Convert tensors to numpy arrays but preserve None values.
        # The condition must be the boolean result of `is_tensor`; comparing the
        # value's identity against that boolean never matched a tensor.
        outputs = {
            k: ops.convert_to_numpy(v) if ops.is_tensor(v) else v
            for k, v in outputs.items()
        }

    return outputs


@property
def jit_options(self):
    """Get the jit_options property of the pipeline."""
    return self._jit_options


@jit_options.setter
def jit_options(self, value: Union[str, None]):
    """Set the jit_options property of the pipeline.

    - "pipeline": JIT compile the whole pipeline call; nothing is changed on the
      individual operations.
    - "ops" or None: use the un-jitted pipeline call and propagate the setting:
      nested pipelines receive the same value, and operations that are jittable
      and have JIT enabled are re-set with JIT on ("ops") or off (None).
    """
    self._jit_options = value
    if value == "pipeline":
        # NOTE(review): validation by `assert` is skipped under `python -O`.
        assert self.jittable, log.error(
            "jit_options 'pipeline' cannot be used as the entire pipeline is not jittable. "
            "The following operations are not jittable: "
            f"{self.unjitable_ops}. "
            "Try setting jit_options to 'ops' or None."
        )
        self.jit()
        return
    else:
        self.unjit()

    for operation in self.operations:
        if isinstance(operation, Pipeline):
            operation.jit_options = value
        else:
            if operation.jittable and operation._jit_compile:
                operation.set_jit(value == "ops")


@property
def _call_fn(self):
    """Get the call function of the pipeline (``timed_call`` when ``timed`` is set)."""
    return self.call if not self.timed else self.timed_call
def jit(self):
    """Make the pipeline run through a JIT-compiled version of its call function.

    The function returned by ``_call_fn`` is wrapped with the module-level ``jit``
    helper, using ``self.jit_kwargs``.
    """
    compiled = jit(self._call_fn, **self.jit_kwargs)
    self._call_pipeline = compiled
def unjit(self):
    """Make the pipeline run through its plain (non-JIT) call function."""
    plain_call = self._call_fn
    self._call_pipeline = plain_call
@property
def jittable(self):
    """True when every operation in the pipeline is jittable."""
    for operation in self.operations:
        if not operation.jittable:
            return False
    return True


@property
def unjitable_ops(self):
    """The operations of the pipeline that are not jittable, in pipeline order."""
    return list(filter(lambda operation: not operation.jittable, self.operations))


@property
def with_batch_dim(self):
    """Whether the pipeline's operations expect a batch dimension."""
    return self._with_batch_dim


@with_batch_dim.setter
def with_batch_dim(self, value):
    """Store the flag and apply it to every operation in the pipeline."""
    self._with_batch_dim = value
    for operation in self.operations:
        operation.with_batch_dim = value


@property
def input_data_type(self):
    """Input data type of the first operation."""
    first_operation = self.operations[0]
    return first_operation.input_data_type


@property
def output_data_type(self):
    """Output data type of the last operation."""
    last_operation = self.operations[-1]
    return last_operation.output_data_type
def validate(self):
    """Check that consecutive operations have matching data types.

    For each pair of neighbouring operations, the output data type of the first
    must equal the input data type of the second. A pair is skipped when either
    of those two types is None.

    Raises:
        ValueError: On the first pair whose data types differ.
    """
    operations = self.operations
    for producer, consumer in zip(operations, operations[1:]):
        produced = producer.output_data_type
        if produced is None:
            continue
        expected = consumer.input_data_type
        if expected is None:
            continue
        if produced != expected:
            raise ValueError(
                f"Operation {producer.__class__.__name__} output data type "
                f"({producer.output_data_type}) is not compatible "
                f"with the input data type ({consumer.input_data_type}) "
                f"of operation {consumer.__class__.__name__}"
            )
def set_params(self, **params):
    """Store parameters in the input caches of the operations that accept them.

    Each operation receives only the entries of ``params`` whose names are among
    its valid keys; operations for which nothing matches are left untouched.

    Args:
        **params: Parameter names and values to cache.
    """
    for operation in self.operations:
        accepted = operation.valid_keys
        matching = {name: value for name, value in params.items() if name in accepted}
        if not matching:
            continue
        operation.set_input_cache(matching)
def get_params(self, per_operation: bool = False):
    """Return a snapshot of the parameters cached on the pipeline's operations.

    Args:
        per_operation (bool): If True, return one dictionary per operation (a copy
            of that operation's input cache), in pipeline order. If False, return a
            single dictionary combining all input caches; for a name present in
            several operations, the value of the later operation is kept.

    Returns:
        list[dict] or dict: See ``per_operation``.
    """
    if not per_operation:
        combined = {}
        for operation in self.operations:
            combined |= operation._input_cache
        return combined
    return [operation._input_cache.copy() for operation in self.operations]
def __str__(self):
    """String representation of the pipeline.

    The basic form is the class names of the operations joined by `` -> ``.

    Will print on two parallel pipeline lines if it detects a splitting operation
    (such as multi_bandpass_filter).
    Will merge the pipeline lines if it detects a stacking operation (such as stack).

    Note that ``split_operations`` is currently an empty list, so the two-line
    branch is never taken and the single-line form is always returned.
    """
    split_operations = []
    merge_operations = ["Stack"]

    operations = [operation.__class__.__name__ for operation in self.operations]
    string = " -> ".join(operations)

    if any(operation in split_operations for operation in operations):
        # a second line is needed with same length as the first line
        split_line = " " * len(string)
        # find the splitting operation and index and print \-> instead of -> after
        split_detected = False
        merge_detected = False
        split_operation = None
        for operation in operations:
            if operation in split_operations:
                # `string.index` finds the first occurrence of the name in the
                # first line; the branch marker goes right after it
                index = string.index(operation)
                index = index + len(operation)
                split_line = split_line[:index] + "\\->" + split_line[index + len("\\->") :]
                split_detected = True
                merge_detected = False
                split_operation = operation
                continue

            if operation in merge_operations:
                # the "/" is placed 4 characters before the merge operation's
                # name, i.e. at the start of the preceding " -> " separator
                index = string.index(operation)
                index = index - 4
                split_line = split_line[:index] + "/" + split_line[index + 1 :]
                split_detected = False
                merge_detected = True
                continue

            if split_detected:
                # print all operations in the second line
                index = string.index(operation)
                split_line = (
                    split_line[:index]
                    + operation
                    + " -> "
                    + split_line[index + len(operation) + len(" -> ") :]
                )
        # NOTE(review): validation by `assert` is skipped under `python -O`
        assert merge_detected is True, log.error(
            "Pipeline was never merged back together (with Stack operation), even "
            f"though it was split with {split_operation}. "
            "Please properly define your operation chain."
        )
        return f"\n{string}\n{split_line}\n"

    return string


def __repr__(self):
    """Representation of the form ``<Pipeline name=(Op1, Op2, ...)>``.

    Operations are shown by class name; nested pipelines are shown by their own
    ``repr``.
    """
    operations = []
    for operation in self.operations:
        if isinstance(operation, Pipeline):
            operations.append(repr(operation))
        else:
            operations.append(operation.__class__.__name__)
    return f"<Pipeline {self.name}=({', '.join(operations)})>"
@classmethod
def load(cls, file_path: str, **kwargs) -> "Pipeline":
    """Load a pipeline from a file, choosing the parser by file extension.

    Args:
        file_path (str): Path ending in ``.json``, ``.yaml`` or ``.yml``.
        **kwargs: Forwarded to ``pipeline_from_json`` / ``pipeline_from_yaml``.

    Returns:
        Pipeline: The loaded pipeline.

    Raises:
        ValueError: If the file extension is not one of the supported ones.
    """
    if file_path.endswith(".json"):
        with open(file_path, "r", encoding="utf-8") as handle:
            contents = handle.read()
        return pipeline_from_json(contents, **kwargs)
    if file_path.endswith((".yaml", ".yml")):
        return pipeline_from_yaml(file_path, **kwargs)
    raise ValueError("File must have extension .json, .yaml, or .yml")
def get_dict(self) -> dict:
    """Return the pipeline's configuration as a dictionary.

    Returns:
        dict: With keys ``name`` (registry name), ``operations`` (one config
        dictionary per operation) and ``params`` (batch-dim flag, JIT options
        and JIT kwargs).
    """
    registry_name = ops_registry.get_name(self)
    operation_configs = self._pipeline_to_list(self)
    params = {
        "with_batch_dim": self.with_batch_dim,
        "jit_options": self.jit_options,
        "jit_kwargs": self.jit_kwargs,
    }
    return {"name": registry_name, "operations": operation_configs, "params": params}
@staticmethod
def _pipeline_to_list(pipeline):
    """Collect ``get_dict()`` of every operation of ``pipeline``, in order."""
    return [operation.get_dict() for operation in pipeline.operations]
@classmethod
def from_config(cls, config: Dict, **kwargs) -> "Pipeline":
    """Create a pipeline from a dictionary or ``zea.Config`` object.

    The input is wrapped in ``Config`` and handed to ``pipeline_from_config``.

    Args:
        config (dict or Config): Configuration dictionary or ``zea.Config`` object.
        **kwargs: Additional keyword arguments to be passed to the pipeline.

    Note:
        NOTE(review): this note used to require a ``pipeline`` key with an
        ``operations`` subkey, while the example below passes ``operations`` at
        the top level. Confirm the expected layout against ``pipeline_from_config``.

    Example:
        .. code-block:: python

            config = Config(
                {
                    "operations": [
                        "identity",
                    ],
                }
            )
            pipeline = Pipeline.from_config(config)
    """
    wrapped = Config(config)
    return pipeline_from_config(wrapped, **kwargs)
@classmethod
def from_yaml(cls, file_path: str, **kwargs) -> "Pipeline":
    """Create a pipeline from a YAML file.

    Delegates to ``pipeline_from_yaml``.

    Args:
        file_path (str): Path to the YAML file.
        **kwargs: Additional keyword arguments to be passed to the pipeline.

    Note:
        Must have a `pipeline` key with a subkey `operations`.

    Example:
        ```python
        pipeline = Pipeline.from_yaml("pipeline.yaml")
        ```
    """
    loaded = pipeline_from_yaml(file_path, **kwargs)
    return loaded
@classmethod
def from_json(cls, json_string: str, **kwargs) -> "Pipeline":
    """Create a pipeline from a JSON string.

    Delegates to ``pipeline_from_json``.

    Args:
        json_string (str): JSON string representing the pipeline.
        **kwargs: Additional keyword arguments to be passed to the pipeline.

    Note:
        Must have the `operations` key.

    Example:
        ```python
        json_string = '{"operations": ["identity"]}'
        pipeline = Pipeline.from_json(json_string)
        ```
    """
    loaded = pipeline_from_json(json_string, **kwargs)
    return loaded
def to_config(self) -> Config:
    """Return this pipeline as a `zea.Config` object (via ``pipeline_to_config``)."""
    config = pipeline_to_config(self)
    return config
def to_json(self) -> str:
    """Return this pipeline as a JSON string (via ``pipeline_to_json``)."""
    json_string = pipeline_to_json(self)
    return json_string
def to_yaml(self, file_path: str) -> None:
    """Write this pipeline to a YAML file (via ``pipeline_to_yaml``).

    Args:
        file_path (str): Destination path, passed on to ``pipeline_to_yaml``.
    """
    pipeline_to_yaml(self, file_path=file_path) if False else pipeline_to_yaml(self, file_path)
@property
def key(self) -> str:
    """Input key of the pipeline, i.e. the ``key`` of its first operation."""
    first_operation = self.operations[0]
    return first_operation.key

@property
def output_key(self) -> str:
    """Output key of the pipeline, i.e. the ``output_key`` of its last operation."""
    last_operation = self.operations[-1]
    return last_operation.output_key

def __eq__(self, other):
    """Compare two pipelines operation by operation.

    Returns ``False`` when ``other`` is not a ``Pipeline``, when the number of
    operations differs, or when any pair of operations compares unequal.
    """
    if not isinstance(other, Pipeline):
        return False

    mine, theirs = self.operations, other.operations
    if len(mine) != len(theirs):
        return False

    # Pairwise comparison; stops at the first mismatch.
    return all(op1 == op2 for op1, op2 in zip(mine, theirs))
def prepare_parameters(
    self,
    probe: Probe = None,
    scan: Scan = None,
    config: Config = None,
    **kwargs,
):
    """Turn Probe, Scan and Config objects into one dictionary of pipeline inputs.

    Each provided object is converted with its ``to_tensor`` method; the
    pipeline's ``static_params`` are passed as ``keep_as_is``. The scan is
    additionally restricted to the pipeline's ``valid_keys``.

    Args:
        probe: Probe object (optional).
        scan: Scan object (optional).
        config: Config object (optional).
        **kwargs: Extra inputs; converted with ``dict_to_tensor`` and merged in.

    Returns:
        dict: Merged inputs. On duplicate keys the precedence is
        kwargs > config > scan > probe.
    """
    probe_dict = {}
    if probe is not None:
        assert isinstance(probe, Probe), (
            f"Expected an instance of `zea.probes.Probe`, got {type(probe)}"
        )
        probe_dict = probe.to_tensor(keep_as_is=self.static_params)

    scan_dict = {}
    if scan is not None:
        assert isinstance(scan, Scan), (
            f"Expected an instance of `zea.scan.Scan`, got {type(scan)}"
        )
        scan_dict = scan.to_tensor(include=self.valid_keys, keep_as_is=self.static_params)

    config_dict = {}
    if config is not None:
        assert isinstance(config, Config), (
            f"Expected an instance of `zea.config.Config`, got {type(config)}"
        )
        config_dict.update(config.to_tensor(keep_as_is=self.static_params))

    tensor_kwargs = dict_to_tensor(kwargs, keep_as_is=self.static_params)

    # Later entries win, which makes the precedence explicit:
    # kwargs > config > scan > probe
    return {
        **probe_dict,
        **scan_dict,
        **config_dict,
        **tensor_kwargs,
    }
def make_operation_chain(
    operation_chain: List[Union[str, Dict, Config, Operation, Pipeline]],
) -> List[Operation]:
    """Make an operation chain from a custom list of operations.

    The input specifications are not modified: parameters are copied before
    nested ``operations`` entries are removed from them.

    Args:
        operation_chain (list): List of operations to be performed.
            Each operation can be:

            - A string: operation initialized with default parameters
            - A dictionary: operation initialized with parameters in the dictionary
            - A Config object: converted to a dictionary and initialized
            - An Operation/Pipeline instance: used as-is

    Returns:
        list: List of operations to be performed.

    Raises:
        AssertionError: If an entry is none of the supported types.
        KeyError: If a ``patched_grid`` entry has no nested ``operations``.

    Example:
        .. code-block:: python

            chain = make_operation_chain(
                [
                    "envelope_detect",
                    {"name": "normalize", "params": {"output_range": (0, 1)}},
                    SomeCustomOperation(),
                ]
            )
    """
    chain = []
    for operation in operation_chain:
        # Already instantiated Operation or Pipeline objects are used as-is
        if isinstance(operation, (Operation, Pipeline)):
            chain.append(operation)
            continue

        assert isinstance(operation, (str, dict, Config)), (
            f"Operation {operation} should be a string, dict, Config object, Operation, or Pipeline"
        )

        if isinstance(operation, str):
            chain.append(get_ops(operation)())
            continue

        if isinstance(operation, Config):
            operation = operation.serialize()

        # Shallow copy so that popping nested operations below never mutates the
        # caller's configuration; `or {}` also tolerates an explicit `params: null`.
        params = dict(operation.get("params") or {})
        op_name = operation.get("name")
        operation_cls = get_ops(op_name)

        if op_name == "branched_pipeline" and "branches" in operation:
            branches = []
            for branch_config in operation.get("branches", {}).values():
                if isinstance(branch_config, (list, np.ndarray)):
                    # A plain list of operations
                    branch = make_operation_chain(branch_config)
                elif "operations" in branch_config:
                    # A pipeline-like branch
                    branch = make_operation_chain(branch_config["operations"])
                else:
                    # A single-operation branch
                    branch_op_cls = get_ops(branch_config["name"])
                    branch = branch_op_cls(**branch_config.get("params", {}))
                branches.append(branch)
            operation_instance = operation_cls(branches=branches, **params)
        elif "operations" in operation:
            # Nested operations given at the same level as `params`
            nested_operations = make_operation_chain(operation["operations"])
            operation_instance = operation_cls(operations=nested_operations, **params)
        elif op_name == "patched_grid":
            # Nested operations given inside `params`
            nested_operations = make_operation_chain(params.pop("operations"))
            operation_instance = operation_cls(operations=nested_operations, **params)
        else:
            operation_instance = operation_cls(**params)

        chain.append(operation_instance)

    return chain
def pipeline_from_config(config: Config, **kwargs) -> Pipeline:
    """Create a Pipeline instance from a Config object.

    Args:
        config (Config): Must contain an ``operations`` entry holding a list or
            numpy array. All remaining entries are forwarded to the ``Pipeline``
            constructor.
        **kwargs: Extra constructor arguments; they override entries of ``config``.

    Returns:
        Pipeline: Pipeline built from the operation chain described by ``config``.
    """
    assert "operations" in config, (
        "Config object must have an 'operations' key for pipeline creation."
    )
    assert isinstance(config.operations, (list, np.ndarray)), (
        "Config object must have a list or numpy array of operations for pipeline creation."
    )

    chain = make_operation_chain(config.operations)

    # Everything except the operations becomes a constructor argument;
    # explicitly passed kwargs take precedence.
    remaining = copy.deepcopy(config)
    remaining.pop("operations")
    constructor_kwargs = {**remaining, **kwargs}

    return Pipeline(operations=chain, **constructor_kwargs)
def pipeline_from_json(json_string: str, **kwargs) -> Pipeline:
    """Create a Pipeline instance from a JSON string.

    The string is decoded with ``ZEADecoderJSON``, wrapped in a ``Config`` and
    passed to :func:`pipeline_from_config` together with ``kwargs``.
    """
    decoded = json.loads(json_string, cls=ZEADecoderJSON)
    return pipeline_from_config(Config(decoded), **kwargs)
def pipeline_from_yaml(yaml_path: str, **kwargs) -> Pipeline:
    """Create a Pipeline instance from a YAML file.

    Two layouts are accepted: ``operations`` at the top level of the document,
    or ``operations`` nested under a top-level ``pipeline`` key. Only the
    ``operations`` entry is read from the file.

    Args:
        yaml_path (str): Path to the YAML file.
        **kwargs: Additional keyword arguments passed on to
            :func:`pipeline_from_config`.

    Returns:
        Pipeline: Pipeline built from the listed operations.

    Raises:
        TypeError: If the YAML document is not a mapping (e.g. an empty file).
        KeyError: If no ``operations`` entry can be found.
    """
    with open(yaml_path, "r", encoding="utf-8") as f:
        pipeline_config = yaml.safe_load(f)

    if not isinstance(pipeline_config, dict):
        raise TypeError(
            f"Expected a mapping at the top level of {yaml_path}, "
            f"got {type(pipeline_config).__name__}"
        )

    # Unwrap the optional `pipeline:` wrapper.
    if "operations" not in pipeline_config and "pipeline" in pipeline_config:
        pipeline_config = pipeline_config["pipeline"]

    operations = pipeline_config["operations"]
    return pipeline_from_config(Config({"operations": operations}), **kwargs)
def pipeline_to_config(pipeline: Pipeline) -> Config:
    """Convert a Pipeline instance into a Config object.

    Returns:
        Config: Config with a single ``operations`` entry. For a pipeline
        registered as ``"pipeline"`` this is the list of its serialized
        operations; for any other pipeline type it is a one-element list holding
        the serialized pipeline itself.
    """
    # TODO: we currently add the full pipeline as 1 operation to the config.
    # In another PR we should add a "pipeline" entry to the config instead of the "operations"
    # entry. This allows us to also have non-default pipeline classes as top level op.
    serialized = pipeline.get_dict()

    # HACK: If the top level operation is a single pipeline, collapse it into the operations list.
    if serialized["name"] == "pipeline":
        return Config({"operations": serialized["operations"]})

    return Config({"operations": [serialized]})
def pipeline_to_json(pipeline: Pipeline) -> str:
    """Convert a Pipeline instance into a JSON string.

    Returns:
        str: JSON document (4-space indent, encoded with ``ZEAEncoderJSON``) with
        a single ``operations`` entry. For a pipeline registered as
        ``"pipeline"`` this is the list of its serialized operations; otherwise
        it is a one-element list holding the serialized pipeline itself.
    """
    serialized = pipeline.get_dict()

    # HACK: If the top level operation is a single pipeline, collapse it into the operations list.
    if serialized["name"] == "pipeline":
        payload = {"operations": serialized["operations"]}
    else:
        payload = {"operations": [serialized]}

    return json.dumps(payload, cls=ZEAEncoderJSON, indent=4)
def pipeline_to_yaml(pipeline: Pipeline, file_path: str) -> None:
    """Convert a Pipeline instance into a YAML file.

    The dictionary returned by ``pipeline.get_dict()`` is dumped to
    ``file_path``. When the pipeline wraps exactly one operation that is itself
    a ``"pipeline"``, only that inner pipeline's operations are written.

    Args:
        pipeline (Pipeline): Pipeline to serialize.
        file_path (str): Destination path; the file is overwritten.
    """
    pipeline_dict = pipeline.get_dict()

    # HACK: If the top level operation is a single pipeline, collapse it into the operations list.
    # The length is checked first so that a pipeline without operations does not
    # raise an IndexError.
    nested = pipeline_dict["operations"]
    if len(nested) == 1 and nested[0]["name"] == "pipeline":
        pipeline_dict = {"operations": nested[0]["operations"]}

    with open(file_path, "w", encoding="utf-8") as f:
        yaml.dump(pipeline_dict, f, Dumper=yaml.Dumper, indent=4)
@ops_registry("patched_grid")
class PatchedGrid(Pipeline):
    """
    With this class you can form a pipeline that will be applied to patches of the grid.
    This is useful to avoid OOM errors when processing large grids.

    Some things to NOTE about this class:

    - The ops have to use flatgrid and flat_pfield as inputs, these will be patched.
    - Changing anything other than `self.output_data_type` in the dict will not be propagated!
    - Will be jitted as a single operation, not the individual operations.
    - This class handles the batching.
    """

    def __init__(self, *args, num_patches=10, **kwargs):
        """Initialize the patched pipeline.

        Args:
            *args: Positional arguments forwarded to ``Pipeline.__init__``.
            num_patches (int): Number of patches handed to ``patched_map``.
                Defaults to 10.
            **kwargs: Keyword arguments forwarded to ``Pipeline.__init__``.
        """
        super().__init__(*args, name="patched_grid", **kwargs)
        self.num_patches = num_patches

        # Reshaping to the full grid is done once in `call_item`, after all
        # patches are combined, so DelayAndSum must keep the flat pixel axis.
        for operation in self.operations:
            if isinstance(operation, DelayAndSum):
                operation.reshape_grid = False

        # NOTE(review): this assignment runs after the parent constructor, so it
        # replaces whatever `_jittable_call` was set to there — confirm intended.
        self._jittable_call = self.jittable_call

    @property
    def jit_options(self):
        """Get the jit_options property of the pipeline."""
        return self._jit_options

    @jit_options.setter
    def jit_options(self, value):
        """Set the jit_options property of the pipeline.

        ``"pipeline"`` and ``"ops"`` both jit the whole patched call; any other
        value un-jits it.
        """
        self._jit_options = value
        if value in ["pipeline", "ops"]:
            self.jit()
        else:
            self.unjit()

    def jit(self):
        """JIT compile the pipeline."""
        self._jittable_call = jit(self.jittable_call, **self.jit_kwargs)

    def unjit(self):
        """Un-JIT compile the pipeline."""
        self._jittable_call = self.jittable_call
        self._call_pipeline = self.call

    @property
    def with_batch_dim(self):
        """Get the with_batch_dim property of the pipeline."""
        return self._with_batch_dim

    @with_batch_dim.setter
    def with_batch_dim(self, value):
        """Set the with_batch_dim property of the pipeline.
        The class handles the batching so the operations have to be set to False."""
        self._with_batch_dim = value
        for operation in self.operations:
            operation.with_batch_dim = False

    @property
    def valid_keys(self) -> set:
        """Get a set of valid keys for the pipeline.
        Adds the parameters that PatchedGrid itself operates on (even if not used by operations
        inside it)."""
        return super().valid_keys.union({"flatgrid", "grid_size_x", "grid_size_z"})

    def call_item(self, inputs):
        """Process a single (un-batched) item in patches.

        Args:
            inputs (dict): Pipeline inputs. Must contain ``grid_size_x``,
                ``grid_size_z``, ``flatgrid`` and the data under ``self.key``.
                ``flatgrid``, ``rx_apo`` and ``flat_pfield`` are removed from
                this dict and fed to the operations patch by patch.

        Returns:
            The output of the wrapped operations for all patches, reshaped so
            that the leading axis becomes ``(grid_size_z, grid_size_x)``.
        """
        # Extract necessary parameters
        # make sure to add those as valid keys above!
        grid_size_x = inputs["grid_size_x"]
        grid_size_z = inputs["grid_size_z"]
        flatgrid = inputs.pop("flatgrid")

        # TODO: maybe using n_tx and n_el from kwargs is better but these are tensors now
        # and this is not supported in broadcast_to
        n_tx = inputs[self.key].shape[0]
        n_pix = flatgrid.shape[0]
        n_el = inputs[self.key].shape[2]
        # A missing rx_apo defaults to 1.0 broadcast to the full shape.
        inputs["rx_apo"] = ops.broadcast_to(inputs.get("rx_apo", 1.0), (n_tx, n_pix, n_el))
        inputs["rx_apo"] = ops.swapaxes(inputs["rx_apo"], 0, 1)  # put n_pix first

        # Define a list of keys to look up for patching
        patch_keys = ["flat_pfield", "rx_apo"]
        patch_arrays = {}
        for key in patch_keys:
            if key in inputs:
                patch_arrays[key] = inputs.pop(key)

        def patched_call(flatgrid, **patch_kwargs):
            """Run the wrapped operations on one patch and return their output."""
            patch_args = {k: v for k, v in patch_kwargs.items() if v is not None}
            # Undo the swap above so the operations get their original axis order.
            patch_args["rx_apo"] = ops.swapaxes(patch_args["rx_apo"], 0, 1)
            out = super(PatchedGrid, self).call(flatgrid=flatgrid, **patch_args, **inputs)
            return out[self.output_key]

        out = patched_map(
            patched_call,
            flatgrid,
            self.num_patches,
            **patch_arrays,
            jit=bool(self.jit_options),
        )
        return ops.reshape(out, (grid_size_z, grid_size_x, *ops.shape(out)[1:]))

    def jittable_call(self, **inputs):
        """Process input data through the pipeline.

        With ``with_batch_dim`` set, ``call_item`` is mapped over the data stored
        under ``self.key`` while all other inputs are shared between items.

        Returns:
            dict: ``{self.output_key: output}``.
        """
        if self._with_batch_dim:
            input_data = inputs.pop(self.key)
            output = ops.map(
                lambda x: self.call_item({self.key: x, **inputs}),
                input_data,
            )
        else:
            output = self.call_item(inputs)

        return {self.output_key: output}

    def call(self, **inputs):
        """Process input data through the pipeline.

        Returns:
            dict: The inputs updated with the pipeline output.
        """
        output = self._jittable_call(**inputs)
        inputs.update(output)
        return inputs

    def get_dict(self):
        """Get the configuration of the pipeline, including ``num_patches``."""
        config = super().get_dict()
        config.update({"name": "patched_grid"})
        config["params"].update({"num_patches": self.num_patches})
        return config
## Base Operations
@ops_registry("identity")
class Identity(Operation):
    """Pass-through operation that leaves its inputs untouched."""

    def call(self, **kwargs) -> Dict:
        """Return the received keyword arguments unchanged."""
        passthrough = kwargs
        return passthrough
@ops_registry("merge")
class Merge(Operation):
    """Operation that merges several input dictionaries into one."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Flag that this operation accepts more than one positional input.
        self.allow_multiple_inputs = True

    def call(self, *args, **kwargs) -> Dict:
        """Merge the positional input dictionaries.

        On duplicate keys the last input wins. Keyword arguments are not part of
        the result.

        Raises:
            TypeError: If any positional input is not a dictionary.
        """
        if not all(isinstance(arg, dict) for arg in args):
            raise TypeError("All inputs must be dictionaries.")

        merged = {}
        for mapping in args:
            merged.update(mapping)
        return merged
@ops_registry("split")
class Split(Operation):
    """Operation that duplicates an input dictionary into ``n`` copies."""

    def __init__(self, n: int, **kwargs):
        super().__init__(**kwargs)
        self.n = n

    def call(self, **kwargs) -> List[Dict]:
        """Return ``n`` shallow copies of the input dictionary."""
        copies = []
        for _ in range(self.n):
            copies.append(dict(kwargs))
        return copies
@ops_registry("stack")
class Stack(Operation):
    """Stack multiple data arrays along a new axis.

    Useful to merge data from parallel pipelines.
    """

    def __init__(
        self,
        keys: Union[str, List[str], None],
        axes: Union[int, List[int], None],
        **kwargs,
    ):
        """
        Args:
            keys: Key or keys of the inputs to stack.
            axes: Axis or axes along which to stack. Both arguments are
                validated and normalized by ``_assert_keys_and_axes``.
        """
        super().__init__(**kwargs)
        self.keys, self.axes = _assert_keys_and_axes(keys, axes)

    def call(self, **kwargs) -> Dict:
        """
        Stacks the inputs corresponding to the specified keys along the specified axis.
        If a list of axes is provided, the length must match the number of keys.

        Every key is overwritten with the stack of the *original* inputs along
        that key's axis.
        """
        # Snapshot the inputs before overwriting any of them; otherwise later
        # iterations would stack an already stacked tensor and hit a shape mismatch.
        arrays = [kwargs[key] for key in self.keys]
        for key, axis in zip(self.keys, self.axes):
            kwargs[key] = ops.stack(arrays, axis=axis)
        return kwargs
@ops_registry("mean")
class Mean(Operation):
    """Average the selected inputs along the given axes."""

    def __init__(self, keys, axes, **kwargs):
        super().__init__(**kwargs)
        # Validation / normalization is delegated to the shared helper.
        self.keys, self.axes = _assert_keys_and_axes(keys, axes)

    def call(self, **kwargs):
        """Replace each configured key by its mean along the paired axis."""
        for name, reduce_axis in zip(self.keys, self.axes):
            averaged = ops.mean(kwargs[name], axis=reduce_axis)
            kwargs[name] = averaged
        return kwargs
@ops_registry("transpose")
class Transpose(Operation):
    """Permute the axes of the input data."""

    def __init__(self, axes, **kwargs):
        super().__init__(**kwargs)
        self.axes = axes

    def call(self, **kwargs):
        """Return the data under ``self.key`` transposed according to ``self.axes``."""
        return {self.output_key: ops.transpose(kwargs[self.key], axes=self.axes)}
@ops_registry("simulate_rf")
class Simulate(Operation):
    """Simulate RF data."""

    # Define operation-specific static parameters
    STATIC_PARAMS = ["n_ax", "apply_lens_correction"]

    def __init__(self, **kwargs):
        super().__init__(output_data_type=DataTypes.RAW_DATA, **kwargs)

    def call(
        self,
        scatterer_positions,
        scatterer_magnitudes,
        probe_geometry,
        apply_lens_correction,
        lens_thickness,
        lens_sound_speed,
        sound_speed,
        n_ax,
        center_frequency,
        sampling_frequency,
        t0_delays,
        initial_times,
        element_width,
        attenuation_coef,
        tx_apodizations,
        **kwargs,
    ):
        """Run ``simulate_rf`` with the given scatterers and acquisition parameters.

        The scatterer positions and magnitudes are converted to tensors first;
        all other arguments are forwarded by keyword.

        Returns:
            dict: The simulated data under ``self.output_key`` and ``"n_ch": 1``.
        """
        positions = ops.convert_to_tensor(scatterer_positions)
        magnitudes = ops.convert_to_tensor(scatterer_magnitudes)

        simulation_kwargs = {
            "probe_geometry": probe_geometry,
            "apply_lens_correction": apply_lens_correction,
            "lens_thickness": lens_thickness,
            "lens_sound_speed": lens_sound_speed,
            "sound_speed": sound_speed,
            "n_ax": n_ax,
            "center_frequency": center_frequency,
            "sampling_frequency": sampling_frequency,
            "t0_delays": t0_delays,
            "initial_times": initial_times,
            "element_width": element_width,
            "attenuation_coef": attenuation_coef,
            "tx_apodizations": tx_apodizations,
        }
        simulated = simulate_rf(positions, magnitudes, **simulation_kwargs)

        return {
            self.output_key: simulated,
            "n_ch": 1,  # Simulate always returns RF data (so single channel)
        }
@ops_registry("tof_correction")
class TOFCorrection(Operation):
    """Time-of-flight correction operation for ultrasound data."""

    # Define operation-specific static parameters
    STATIC_PARAMS = [
        "f_number",
        "apply_lens_correction",
        "apply_phase_rotation",
        "grid_size_x",
        "grid_size_z",
    ]

    def __init__(self, apply_phase_rotation=True, **kwargs):
        super().__init__(
            input_data_type=DataTypes.RAW_DATA,
            output_data_type=DataTypes.ALIGNED_DATA,
            **kwargs,
        )
        self.apply_phase_rotation = apply_phase_rotation

    def call(
        self,
        flatgrid,
        sound_speed,
        polar_angles,
        focus_distances,
        sampling_frequency,
        f_number,
        demodulation_frequency,
        t0_delays,
        tx_apodizations,
        initial_times,
        probe_geometry,
        apply_lens_correction=None,
        lens_thickness=None,
        lens_sound_speed=None,
        **kwargs,
    ):
        """Perform time-of-flight correction on the data stored under ``self.key``.

        Args:
            flatgrid (ops.Tensor): Grid points at which to evaluate the time-of-flight
            sound_speed (float): Sound speed in the medium
            polar_angles (ops.Tensor): Polar angles for scan lines
            focus_distances (ops.Tensor): Focus distances for scan lines
            sampling_frequency (float): Sampling frequency
            f_number (float): F-number for apodization
            demodulation_frequency (float): Demodulation frequency
            t0_delays (ops.Tensor): T0 delays
            tx_apodizations (ops.Tensor): Transmit apodizations
            initial_times (ops.Tensor): Initial times
            probe_geometry (ops.Tensor): Probe element positions
            apply_lens_correction (bool): Whether to apply lens correction
            lens_thickness (float): Lens thickness
            lens_sound_speed (float): Sound speed in the lens
            **kwargs: Must contain the raw data under ``self.key``.

        Returns:
            dict: The time-of-flight corrected data under ``self.output_key``.
        """
        raw_data = kwargs[self.key]

        # Map this operation's parameter names onto those of `tof_correction`.
        tof_kwargs = {
            "flatgrid": flatgrid,
            "sound_speed": sound_speed,
            "angles": polar_angles,
            "focus_distances": focus_distances,
            "sampling_frequency": sampling_frequency,
            "fnum": f_number,
            "apply_phase_rotation": self.apply_phase_rotation,
            "demodulation_frequency": demodulation_frequency,
            "t0_delays": t0_delays,
            "tx_apodizations": tx_apodizations,
            "initial_times": initial_times,
            "probe_geometry": probe_geometry,
            "apply_lens_correction": apply_lens_correction,
            "lens_thickness": lens_thickness,
            "lens_sound_speed": lens_sound_speed,
        }

        def _correct_single(frame):
            return tof_correction(frame, **tof_kwargs)

        if self.with_batch_dim:
            # Apply the correction to every item along the leading axis.
            tof_corrected = ops.map(_correct_single, raw_data)
        else:
            tof_corrected = _correct_single(raw_data)

        return {self.output_key: tof_corrected}
@ops_registry("pfield_weighting")
class PfieldWeighting(Operation):
    """Weighting aligned data with the pressure field."""

    def __init__(self, **kwargs):
        super().__init__(
            input_data_type=DataTypes.ALIGNED_DATA,
            output_data_type=DataTypes.ALIGNED_DATA,
            **kwargs,
        )

    def call(self, flat_pfield=None, **kwargs):
        """Multiply the data with the pressure field weights.

        Args:
            flat_pfield (ops.Tensor): Pressure field weight mask of shape (n_pix, n_tx).
                When ``None`` the data is returned unweighted.

        Returns:
            dict: Dictionary containing the (weighted) data under ``self.output_key``.
        """
        data = kwargs[self.key]

        if flat_pfield is None:
            return {self.output_key: data}

        # (n_pix, n_tx) -> (n_tx, n_pix)
        weights = ops.swapaxes(flat_pfield, 0, 1)

        # Leading axis for the batch (if any), two trailing axes for broadcasting.
        if self.with_batch_dim:
            weights = ops.expand_dims(weights, axis=0)
        weights = weights[..., None, None]

        return {self.output_key: data * weights}
@ops_registry("sum")
class Sum(Operation):
    """Reduce the data by summation along one axis."""

    def __init__(self, axis, **kwargs):
        super().__init__(**kwargs)
        self.axis = axis

    def call(self, **kwargs):
        """Return the data under ``self.key`` summed along ``self.axis``."""
        summed = ops.sum(kwargs[self.key], axis=self.axis)
        return {self.output_key: summed}
@ops_registry("delay_and_sum")
class DelayAndSum(Operation):
    """Sums time-delayed signals along channels and transmits."""

    def __init__(self, reshape_grid=True, **kwargs):
        super().__init__(
            input_data_type=None,
            output_data_type=DataTypes.BEAMFORMED_DATA,
            **kwargs,
        )
        self.reshape_grid = reshape_grid

    def process_image(self, data, rx_apo):
        """Performs DAS beamforming on tof-corrected input.

        Args:
            data (ops.Tensor): The TOF corrected input of shape `(n_tx, n_pix, n_el, n_ch)`
            rx_apo (ops.Tensor): Receive apodization window of shape
                `(n_tx, n_pix, n_el, n_ch)`.

        Returns:
            ops.Tensor: The beamformed data of shape `(n_pix, n_ch)`
        """
        weighted = rx_apo * data
        # Sum over axis -2 (DAS), then over axis 0 (compounding over transmits).
        per_transmit = ops.sum(weighted, -2)
        return ops.sum(per_transmit, 0)

    def call(self, rx_apo=None, grid=None, **kwargs):
        """Performs DAS beamforming on tof-corrected input.

        Args:
            rx_apo (ops.Tensor): Receive apodization window of shape
                `(n_tx, grid_size_z*grid_size_x, n_el)` with optional batch dimension.
                Defaults to ones.
            grid (ops.Tensor): Grid whose first two dimensions give the output
                shape; only read when ``reshape_grid`` is True.
            **kwargs: Must contain the TOF corrected data of shape
                `(n_tx, grid_size_z*grid_size_x, n_el, n_ch)` (with optional batch
                dimension) under ``self.key``.

        Returns:
            dict: Dictionary containing beamformed_data
                of shape `(grid_size_z*grid_size_x, n_ch)`
                when reshape_grid is False or
                `(grid_size_z, grid_size_x, n_ch)` when reshape_grid is True,
                with optional batch dimension.
        """
        data = kwargs[self.key]

        if rx_apo is None:
            rx_apo = ops.ones(1, dtype=ops.dtype(data))
        # Add the trailing axis and expand to the full data shape.
        rx_apo = ops.broadcast_to(rx_apo[..., None], data.shape)

        if self.with_batch_dim:
            # Apply process_image to each item in the batch
            beamformed = batched_map(
                lambda data, rx_apo: self.process_image(data, rx_apo), data, rx_apo=rx_apo
            )
        else:
            beamformed = self.process_image(data, rx_apo)

        if self.reshape_grid:
            # The pixel axis sits behind the batch axis when one is present.
            pixel_axis = int(self.with_batch_dim)
            beamformed = reshape_axis(beamformed, grid.shape[:2], axis=pixel_axis)

        return {self.output_key: beamformed}
@ops_registry("envelope_detect")
class EnvelopeDetect(Operation):
    """Envelope detection of RF signals."""

    def __init__(self, axis=-3, **kwargs):
        super().__init__(
            input_data_type=DataTypes.BEAMFORMED_DATA,
            output_data_type=DataTypes.ENVELOPE_DATA,
            **kwargs,
        )
        self.axis = axis

    def call(self, **kwargs):
        """Compute the envelope (magnitude of the analytic signal) of the data.

        Args:
            **kwargs: Must contain under ``self.key`` the beamformed data of shape
                (..., grid_size_z, grid_size_x, n_ch).

        Returns:
            dict: The envelope detected data of shape
            (..., grid_size_z, grid_size_x) as float32 under ``self.output_key``.
        """
        signal = kwargs[self.key]

        if signal.shape[-1] != 2:
            # Build the analytic signal with a Hilbert transform whose length is
            # the next power of two, then cut back to the original length.
            n_ax = signal.shape[self.axis]
            fft_length = 2 ** int(np.ceil(np.log2(n_ax)))
            signal = hilbert(signal, N=fft_length, axis=self.axis)
            signal = ops.take(signal, ops.arange(n_ax), axis=self.axis)
            signal = ops.squeeze(signal, axis=-1)
        else:
            # Two channels: combine them into a complex signal.
            signal = channels_to_complex(signal)

        # Magnitude from real and imaginary parts.
        real_part = ops.real(signal)
        imag_part = ops.imag(signal)
        envelope = ops.sqrt(real_part**2 + imag_part**2)

        return {self.output_key: ops.cast(envelope, "float32")}
@ops_registry("upmix")
class UpMix(Operation):
    """Upmix IQ data to RF data."""

    def __init__(self, upsampling_rate=1, **kwargs):
        """
        Args:
            upsampling_rate (int): Upsampling rate handed to ``upmix``. Defaults to 1.
        """
        super().__init__(**kwargs)
        self.upsampling_rate = upsampling_rate

    def call(
        self,
        sampling_frequency=None,
        center_frequency=None,
        **kwargs,
    ):
        """Upmix the data stored under ``self.key``.

        Args:
            sampling_frequency: Sampling frequency handed to ``upmix``.
            center_frequency: Center frequency handed to ``upmix``.

        Returns:
            dict: The upmixed data with a trailing singleton channel axis under
            ``self.output_key``. Single-channel input is passed through
            unchanged (with a warning).
        """
        data = kwargs[self.key]

        if data.shape[-1] == 1:
            log.warning("Upmixing is not applicable to RF data.")
            # Return a dictionary like every other code path, not the bare tensor.
            return {self.output_key: data}
        elif data.shape[-1] == 2:
            data = channels_to_complex(data)

        data = upmix(data, sampling_frequency, center_frequency, self.upsampling_rate)
        data = ops.expand_dims(data, axis=-1)
        return {self.output_key: data}
@ops_registry("log_compress")
class LogCompress(Operation):
    """Logarithmic compression of data."""

    def __init__(self, **kwargs):
        super().__init__(
            input_data_type=DataTypes.ENVELOPE_DATA,
            output_data_type=DataTypes.IMAGE,
            **kwargs,
        )

    def call(self, dynamic_range=None, **kwargs):
        """Apply logarithmic compression to data.

        Args:
            dynamic_range (tuple, optional): Lower and upper clipping bound in dB.
                Defaults to ``DEFAULT_DYNAMIC_RANGE``.

        Returns:
            dict: Dictionary containing ``20 * log10(data)`` clipped to the
            dynamic range under ``self.output_key``.
        """
        data = kwargs[self.key]

        if dynamic_range is None:
            dynamic_range = ops.array(DEFAULT_DYNAMIC_RANGE)
        dynamic_range = ops.cast(dynamic_range, data.dtype)
        lower, upper = dynamic_range[0], dynamic_range[1]

        # Replace exact zeros so the logarithm stays finite.
        epsilon = ops.convert_to_tensor(1e-16, dtype=data.dtype)
        nonzero = ops.where(data == 0, epsilon, data)

        decibels = 20 * ops.log10(nonzero)
        return {self.output_key: ops.clip(decibels, lower, upper)}
@ops_registry("normalize")
class Normalize(Operation):
    """Normalize data to a given range."""

    def __init__(self, output_range=None, input_range=None, **kwargs):
        """
        Args:
            output_range (tuple, optional): Range to which data should be mapped.
                Defaults to (0, 1).
            input_range (tuple, optional): Range of the input data. Entries may be
                None, in which case they are resolved at call time.
        """
        super().__init__(**kwargs)
        if output_range is None:
            output_range = (0, 1)
        self.output_range = self.to_float32(output_range)
        self.input_range = self.to_float32(input_range)
        assert output_range is None or len(output_range) == 2
        assert input_range is None or len(input_range) == 2

    @staticmethod
    def to_float32(data):
        """Converts an iterable to float32 and leaves None values as is."""
        if data is None:
            return None
        return [None if value is None else np.float32(value) for value in data]

    def call(self, **kwargs):
        """Normalize data to a given range.

        The input range is resolved in this order: the ``input_range`` given at
        construction, otherwise ``minval`` / ``maxval`` found in the inputs (this
        allows normalizing a sequence on its first frame and avoids flicker),
        otherwise the minimum / maximum of the data itself.

        Returns:
            dict: Dictionary containing normalized data, along with the computed
            or provided input range (minval and maxval).
        """
        data = kwargs[self.key]

        if self.input_range is not None:
            minval, maxval = self.input_range
        else:
            minval = kwargs.get("minval", None)
            maxval = kwargs.get("maxval", None)

        # Whatever is still unknown is computed from the data
        if minval is None:
            minval = ops.min(data)
        if maxval is None:
            maxval = ops.max(data)

        # Clip to the input range, then map onto the output range
        clipped = ops.clip(data, minval, maxval)
        normalized_data = translate(clipped, (minval, maxval), self.output_range)

        return {self.output_key: normalized_data, "minval": minval, "maxval": maxval}
@ops_registry("scan_convert")
class ScanConvert(Operation):
    """Scan convert images to cartesian coordinates."""

    STATIC_PARAMS = ["fill_value"]

    def __init__(self, order=1, **kwargs):
        """Initialize the ScanConvert operation.

        Args:
            order (int, optional): Interpolation order. Defaults to 1. Currently only
                GPU support for order=1.
        """
        if order > 1:
            jittable = False
            log.warning(
                "GPU support for order > 1 is not available. " + "Disabling jit for ScanConvert."
            )
        else:
            jittable = True

        super().__init__(
            input_data_type=DataTypes.IMAGE,
            output_data_type=DataTypes.IMAGE_SC,
            jittable=jittable,
            **kwargs,
        )
        self.order = order

    def call(
        self,
        rho_range=None,
        theta_range=None,
        phi_range=None,
        resolution=None,
        coordinates=None,
        fill_value=None,
        **kwargs,
    ):
        """Scan convert images to cartesian coordinates.

        Args:
            rho_range (Tuple): Range of the rho axis in the polar coordinate system.
                Defined in meters.
            theta_range (Tuple): Range of the theta axis in the polar coordinate system.
                Defined in radians.
            phi_range (Tuple): Range of the phi axis in the polar coordinate system.
                Defined in radians.
            resolution (float): Resolution of the output image in meters per pixel.
                if None, the resolution is computed based on the input data.
            coordinates (Tensor): Coordinates for scan conversion. If None, will be computed
                based on rho_range, theta_range, phi_range and resolution. If provided, this
                operation can be jitted.
            fill_value (float): Value to fill the image with outside the defined region.
                Defaults to NaN.

        Returns:
            dict: The scan converted data under ``self.output_key`` merged with
            the parameters returned by ``scan_convert``.
        """
        if fill_value is None:
            fill_value = np.nan

        data = kwargs[self.key]

        if self._jit_compile and self.jittable:
            # The trailing space keeps the two sentences of the message apart.
            assert coordinates is not None, (
                "coordinates must be provided to jit scan conversion. "
                "You can set ScanConvert(jit_compile=False) to disable jitting."
            )

        data_out, parameters = scan_convert(
            data,
            rho_range,
            theta_range,
            phi_range,
            resolution,
            coordinates,
            fill_value,
            self.order,
            with_batch_dim=self.with_batch_dim,
        )

        return {self.output_key: data_out, **parameters}
@ops_registry("gaussian_blur")
class GaussianBlur(Operation):
    """
    GaussianBlur is an operation that applies a Gaussian blur to an input image.
    Uses scipy.ndimage.gaussian_filter to create a kernel. Every channel is
    blurred independently of the others.
    """

    def __init__(
        self,
        sigma: float,
        kernel_size: int | None = None,
        pad_mode="symmetric",
        truncate=4.0,
        **kwargs,
    ):
        """
        Args:
            sigma (float): Standard deviation for Gaussian kernel.
            kernel_size (int, optional): The size of the kernel. If None, the kernel
                size is calculated based on the sigma and truncate. Default is None.
            pad_mode (str): Padding mode for the input image. Default is 'symmetric'.
            truncate (float): Truncate the filter at this many standard deviations.
                Only used to derive ``kernel_size`` when that is None.
        """
        super().__init__(**kwargs)
        if kernel_size is None:
            radius = round(truncate * sigma)
            self.kernel_size = 2 * radius + 1
        else:
            self.kernel_size = kernel_size
        self.sigma = sigma
        self.pad_mode = pad_mode
        self.radius = self.kernel_size // 2
        self.kernel = self.get_kernel()

    def get_kernel(self):
        """
        Create a gaussian kernel for blurring.

        Returns:
            kernel (Tensor): A gaussian kernel for blurring.
                Shape is (kernel_size, kernel_size, 1, 1).
        """
        # Filtering a unit impulse yields the kernel itself.
        n = np.zeros((self.kernel_size, self.kernel_size))
        n[self.radius, self.radius] = 1
        kernel = scipy.ndimage.gaussian_filter(n, sigma=self.sigma, mode="constant").astype(
            np.float32
        )
        kernel = kernel[:, :, None, None]
        return ops.convert_to_tensor(kernel)

    def call(self, **kwargs):
        """Blur the data stored under ``self.key``.

        The data is expected as (height, width, channels), with a leading batch
        axis when ``with_batch_dim`` is set.

        Returns:
            dict: The blurred data, same shape as the input, under ``self.output_key``.
        """
        data = kwargs[self.key]

        # Add batch dimension if not present
        if not self.with_batch_dim:
            data = data[None]

        # Expand the kernel to (k, k, in_channels, out_channels) with the Gaussian
        # on the channel diagonal only. Filling every in/out pair (as a plain tile
        # would) makes the convolution sum all input channels into each output.
        n_ch = data.shape[-1]
        kernel = self.kernel * ops.eye(n_ch, dtype=ops.dtype(self.kernel))

        # Pad the input image according to the padding mode
        padded = ops.pad(
            data,
            [[0, 0], [self.radius, self.radius], [self.radius, self.radius], [0, 0]],
            mode=self.pad_mode,
        )

        # Apply the gaussian kernel to the padded image
        out = ops.conv(padded, kernel, padding="valid", data_format="channels_last")

        # Crop back to the input size (an even kernel size leaves one extra row/column)
        out = ops.slice(
            out,
            [0, 0, 0, 0],
            [out.shape[0], data.shape[1], data.shape[2], data.shape[3]],
        )

        # Remove batch dimension if it was not present before
        if not self.with_batch_dim:
            out = ops.squeeze(out, axis=0)

        return {self.output_key: out}
@ops_registry("lee_filter")
class LeeFilter(Operation):
    """
    The Lee filter is a speckle reduction filter commonly used in synthetic aperture radar (SAR)
    and ultrasound image processing. It smooths the image while preserving edges and details.
    This implementation uses Gaussian filter for local statistics and treats channels
    independently.

    Lee, J.S. (1980). Digital image enhancement and noise filtering by use of local statistics.
    IEEE Transactions on Pattern Analysis and Machine Intelligence, (2), 165-168.
    """

    def __init__(self, sigma=3, kernel_size=None, pad_mode="symmetric", **kwargs):
        """
        Args:
            sigma (float): Standard deviation for Gaussian kernel. Default is 3.
            kernel_size (int, optional): Size of the Gaussian kernel. If None,
                it will be calculated based on sigma.
            pad_mode (str): Padding mode to be used for Gaussian blur. Default is "symmetric".
        """
        super().__init__(**kwargs)
        self.sigma = sigma
        self.kernel_size = kernel_size
        self.pad_mode = pad_mode

        # Create a GaussianBlur instance for computing local statistics
        self.gaussian_blur = GaussianBlur(
            sigma=self.sigma,
            kernel_size=self.kernel_size,
            pad_mode=self.pad_mode,
            with_batch_dim=self.with_batch_dim,
            jittable=self._jittable,
            key=self.key,
        )

    @property
    def with_batch_dim(self):
        """Get the with_batch_dim property of the LeeFilter operation."""
        return self._with_batch_dim

    @with_batch_dim.setter
    def with_batch_dim(self, value):
        """Set the with_batch_dim property of the LeeFilter operation."""
        self._with_batch_dim = value
        # The blur does not exist yet while the parent constructor runs.
        if hasattr(self, "gaussian_blur"):
            self.gaussian_blur.with_batch_dim = value

    def call(self, **kwargs):
        """Apply the Lee filter to the data stored under ``self.key``.

        The data is expected as (height, width, channels), with a leading batch
        axis when ``with_batch_dim`` is set.

        Returns:
            dict: The filtered data under ``self.output_key``.
        """
        data = kwargs[self.key]

        # Apply Gaussian blur to get local mean
        img_mean = self.gaussian_blur.call(**kwargs)[self.gaussian_blur.output_key]

        # Apply Gaussian blur to squared data to get local squared mean
        data_squared = data**2
        kwargs[self.gaussian_blur.key] = data_squared
        img_sqr_mean = self.gaussian_blur.call(**kwargs)[self.gaussian_blur.output_key]

        # Calculate local variance
        img_variance = img_sqr_mean - img_mean**2

        # Calculate global variance (per channel). Channels are last, so the
        # spatial axes are (-3, -2) with and without a batch axis; reducing over
        # (-2, -1) would mix width and channels.
        overall_variance = ops.var(data, axis=(-3, -2), keepdims=True)

        # Calculate adaptive weights
        img_weights = img_variance / (img_variance + overall_variance)

        # Apply Lee filter formula
        img_output = img_mean + img_weights * (data - img_mean)

        return {self.output_key: img_output}
@ops_registry("demodulate")
class Demodulate(Operation):
    """Demodulate the input data to baseband.

    After this operation the carrier frequency is removed (0 Hz) and the data
    is in IQ format, stored in two real-valued channels.
    """

    def __init__(self, axis=-3, **kwargs):
        """
        Args:
            axis (int): Axis along which to demodulate. Defaults to -3.
            **kwargs: Forwarded to the ``Operation`` base class.
        """
        super().__init__(
            input_data_type=DataTypes.RAW_DATA,
            output_data_type=DataTypes.RAW_DATA,
            jittable=True,
            **kwargs,
        )
        self.axis = axis

    def call(self, center_frequency=None, sampling_frequency=None, **kwargs):
        """Demodulate ``kwargs[self.key]``.

        Args:
            center_frequency: Carrier frequency handed to :func:`demodulate`.
            sampling_frequency: Sampling frequency handed to :func:`demodulate`.

        Returns:
            dict: The two-channel IQ data under ``self.output_key``, the
            original carrier under ``"demodulation_frequency"``, and the
            updated ``"center_frequency"`` (0.0) and ``"n_ch"`` (2).
        """
        baseband = demodulate(
            data=kwargs[self.key],
            center_frequency=center_frequency,
            sampling_frequency=sampling_frequency,
            axis=self.axis,
        )

        outputs = {self.output_key: baseband}
        # Remember which frequency was removed; the data itself now sits at 0 Hz.
        outputs["demodulation_frequency"] = center_frequency
        outputs["center_frequency"] = 0.0
        outputs["n_ch"] = 2
        return outputs
@ops_registry("lambda")
class Lambda(Operation):
    """Wrap an arbitrary function so it can be used as an operation."""

    def __init__(self, func, func_kwargs=None, **kwargs):
        """
        Args:
            func (callable): Function applied to the data.
            func_kwargs (dict, optional): Keyword arguments bound to ``func``
                up front. Defaults to no extra arguments.
            **kwargs: Forwarded to the ``Operation`` base class.
        """
        super().__init__(**kwargs)
        bound_kwargs = func_kwargs if func_kwargs else {}
        self.func = partial(func, **bound_kwargs)

    def call(self, **kwargs):
        """Run the wrapped function on ``kwargs[self.key]``."""
        return {self.output_key: self.func(kwargs[self.key])}
@ops_registry("clip")
class Clip(Operation):
    """Clip the input data to a given range."""

    def __init__(self, min_value=None, max_value=None, **kwargs):
        """
        Args:
            min_value: Lower bound passed to ``ops.clip``.
            max_value: Upper bound passed to ``ops.clip``.
            **kwargs: Forwarded to the ``Operation`` base class.
        """
        super().__init__(**kwargs)
        self.min_value, self.max_value = min_value, max_value

    def call(self, **kwargs):
        """Clip ``kwargs[self.key]`` to ``[min_value, max_value]``."""
        clipped = ops.clip(kwargs[self.key], self.min_value, self.max_value)
        return {self.output_key: clipped}
@ops_registry("pad")
class Pad(Operation, TFDataLayer):
    """Pad layer for padding tensors to a specified shape."""

    def __init__(
        self,
        target_shape: list | tuple,
        uniform: bool = True,
        axis: Union[int, List[int]] = None,
        fail_on_bigger_shape: bool = True,
        pad_kwargs: dict = None,
        **kwargs,
    ):
        """
        Args:
            target_shape: Shape to pad to (see :meth:`pad`).
            uniform: Whether to split the padding over both sides.
            axis: Axis or axes that ``target_shape`` refers to.
            fail_on_bigger_shape: Whether an input larger than the target is
                an error.
            pad_kwargs: Extra keyword arguments for the backend pad function.
            **kwargs: Forwarded to the base classes.
        """
        super().__init__(**kwargs)
        self.target_shape = target_shape
        self.uniform = uniform
        self.axis = axis
        self.pad_kwargs = pad_kwargs or {}
        self.fail_on_bigger_shape = fail_on_bigger_shape

    @staticmethod
    def _format_target_shape(shape_array, target_shape, axis):
        """Expand a per-axis ``target_shape`` into a full-rank target shape.

        Dimensions that are not listed in ``axis`` keep their current size.
        """
        if isinstance(axis, int):
            axis = [axis]
        assert len(axis) == len(target_shape), (
            "The length of axis must be equal to the length of target_shape."
        )
        axis = map_negative_indices(axis, len(shape_array))

        full_shape = []
        for dim in range(len(shape_array)):
            if dim in axis:
                full_shape.append(target_shape[axis.index(dim)])
            else:
                full_shape.append(shape_array[dim])
        return full_shape

    def pad(
        self,
        z,
        target_shape: list | tuple,
        uniform: bool = True,
        axis: Union[int, List[int]] = None,
        fail_on_bigger_shape: bool = True,
        **kwargs,
    ):
        """Pad the input tensor ``z`` to the specified shape.

        Args:
            z (tensor): The input tensor to be padded.
            target_shape (list or tuple): The target shape to pad the tensor to.
            uniform (bool, optional): If True (default), the padding is split
                over both sides of each dimension. When the total amount is
                odd, the extra element is placed after the data. If False, all
                padding is placed after the data.
            axis (int or list of int, optional): The axis or axes along which
                ``target_shape`` was specified. If None,
                ``len(target_shape) == len(ops.shape(z))`` must hold.
                Default is None.
            fail_on_bigger_shape (bool, optional): If True (default), raises an
                error if any target dimension is smaller than the input shape;
                if False, pads only where the target shape exceeds the input
                shape and leaves other dimensions unchanged.
            kwargs: Additional keyword arguments to pass to the padding function.

        Returns:
            tensor: The padded tensor with the specified shape.

        Raises:
            ValueError: If a negative amount of padding would be required.
        """
        shape_array = self.backend.shape(z)

        # A per-axis target is first expanded to one entry per dimension.
        if axis is not None:
            target_shape = self._format_target_shape(shape_array, target_shape, axis)

        if not fail_on_bigger_shape:
            target_shape = [max(target_shape[i], shape_array[i]) for i in range(len(shape_array))]

        # Total number of elements to add along every dimension.
        pad_shape = np.array(target_shape) - shape_array

        if uniform:
            pad_before = pad_shape // 2
            pad_after = pad_shape - pad_before
        else:
            pad_before = np.zeros_like(pad_shape)
            pad_after = pad_shape
        paddings = np.stack([pad_before, pad_after], axis=1)

        if np.any(paddings < 0):
            raise ValueError(
                f"Target shape {target_shape} must be greater than or equal "
                f"to the input shape {shape_array}."
            )

        return self.backend.numpy.pad(z, paddings, **kwargs)

    def call(self, **kwargs):
        """Pad ``kwargs[self.key]`` with the settings given at construction."""
        padded_data = self.pad(
            kwargs[self.key],
            target_shape=self.target_shape,
            uniform=self.uniform,
            axis=self.axis,
            fail_on_bigger_shape=self.fail_on_bigger_shape,
            **self.pad_kwargs,
        )
        return {self.output_key: padded_data}
@ops_registry("companding")
class Companding(Operation):
    """Companding according to the A- or μ-law algorithm.

    Invertible compressing operation. Used to compress the dynamic range of
    input data (and subsequently expand it again). Inputs are clipped to
    ``[-1, 1]`` before the law is applied.

    μ-law companding: https://en.wikipedia.org/wiki/%CE%9C-law_algorithm
    A-law companding: https://en.wikipedia.org/wiki/A-law_algorithm

    Args:
        expand (bool, optional): If set to False (default), data is
            compressed, else expanded.
        comp_type (str): either `a` or `mu`.

    The compression parameters ``mu`` (default 255) and ``A`` (default 87.6)
    are passed to :meth:`call`.
    """

    def __init__(self, expand=False, comp_type="mu", **kwargs):
        super().__init__(**kwargs)
        self.expand = expand
        self.comp_type = comp_type.lower()
        if self.comp_type not in ["mu", "a"]:
            raise ValueError("comp_type must be 'mu' or 'a'.")

        # (law, expand?) -> implementation
        implementations = {
            ("mu", False): self._mu_law_compress,
            ("mu", True): self._mu_law_expand,
            ("a", False): self._a_law_compress,
            ("a", True): self._a_law_expand,
        }
        self._compand_func = implementations[(self.comp_type, bool(self.expand))]

    @staticmethod
    def _mu_law_compress(x, mu=255, **kwargs):
        """μ-law compression of ``x``."""
        x = ops.clip(x, -1, 1)
        compressed_magnitude = ops.log(1.0 + mu * ops.abs(x))
        return ops.sign(x) * compressed_magnitude / ops.log(1.0 + mu)

    @staticmethod
    def _mu_law_expand(y, mu=255, **kwargs):
        """μ-law expansion of ``y`` (inverse of :meth:`_mu_law_compress`)."""
        y = ops.clip(y, -1, 1)
        expanded_magnitude = (1.0 + mu) ** ops.abs(y) - 1.0
        return ops.sign(y) * expanded_magnitude / mu

    @staticmethod
    def _a_law_compress(x, A=87.6, **kwargs):
        """A-law compression of ``x``."""
        x = ops.clip(x, -1, 1)
        x_sign = ops.sign(x)
        x_abs = ops.abs(x)
        denominator = 1.0 + ops.log(A)

        # Linear segment for small magnitudes, logarithmic segment otherwise.
        linear_part = x_sign * A * x_abs / denominator
        log_part = x_sign * (1.0 + ops.log(A * x_abs)) / denominator
        in_linear_segment = (x_abs >= 0) & (x_abs < (1.0 / A))
        return ops.where(in_linear_segment, linear_part, log_part)

    @staticmethod
    def _a_law_expand(y, A=87.6, **kwargs):
        """A-law expansion of ``y`` (inverse of :meth:`_a_law_compress`)."""
        y = ops.clip(y, -1, 1)
        y_sign = ops.sign(y)
        y_abs = ops.abs(y)
        scale = 1.0 + ops.log(A)

        linear_part = y_sign * y_abs * scale / A
        exp_part = y_sign * ops.exp(y_abs * scale - 1.0) / A
        in_linear_segment = (y_abs >= 0) & (y_abs < (1.0 / scale))
        return ops.where(in_linear_segment, linear_part, exp_part)

    def call(self, mu=255, A=87.6, **kwargs):
        """Compress or expand ``kwargs[self.key]``.

        Args:
            mu (float, optional): μ-law parameter. Defaults to 255.
            A (float, optional): A-law parameter. Defaults to 87.6.
        """
        data = kwargs[self.key]
        # Cast the parameters so they share the dtype of the data.
        law_params = {
            "mu": ops.cast(mu, data.dtype),
            "A": ops.cast(A, data.dtype),
        }
        return {self.output_key: self._compand_func(data, **law_params)}
@ops_registry("downsample")
class Downsample(Operation):
    """Downsample data along a specific axis."""

    def __init__(self, factor: int = 1, phase: int = 0, axis: int = -3, **kwargs):
        """
        Args:
            factor (int): Keep every ``factor``-th sample. Defaults to 1.
            phase (int): Index of the first sample that is kept. Defaults to 0.
            axis (int): Axis to downsample. Defaults to -3.
            **kwargs: Forwarded to the ``Operation`` base class.
        """
        super().__init__(
            **kwargs,
        )
        self.factor = factor
        self.phase = phase
        self.axis = axis

    def call(self, sampling_frequency=None, n_ax=None, **kwargs):
        """Keep samples ``phase, phase + factor, ...`` along ``self.axis``.

        Args:
            sampling_frequency (optional): If given, the reduced sampling
                frequency is added to the output.
            n_ax (optional): If given, the number of samples that remain out
                of ``n_ax`` is added to the output.

        Returns:
            dict: The downsampled data under ``self.output_key`` plus the
            updated ``"sampling_frequency"`` and ``"n_ax"`` when provided.
        """
        data = kwargs[self.key]
        length = ops.shape(data)[self.axis]
        sample_idx = ops.arange(self.phase, length, self.factor)
        data_downsampled = ops.take(data, sample_idx, axis=self.axis)

        output = {self.output_key: data_downsampled}
        # downsampling also affects the sampling frequency
        if sampling_frequency is not None:
            output["sampling_frequency"] = sampling_frequency / self.factor
        if n_ax is not None:
            # arange(phase, n_ax, factor) holds ceil((n_ax - phase) / factor)
            # entries. A plain floor division under-counts by one whenever the
            # remaining length is not a multiple of the factor.
            output["n_ax"] = (n_ax - self.phase + self.factor - 1) // self.factor
        return output
@ops_registry("branched_pipeline")
class BranchedPipeline(Operation):
    """Operation that processes data through multiple branches.

    The input is handed to every branch, and the per-branch results are then
    combined with the configured merge strategy.
    """

    def __init__(self, branches=None, merge_strategy="nested", **kwargs):
        """Initialize a branched pipeline.

        Args:
            branches (List[Union[List, Pipeline, Operation]]): List of branch
                operations. Branches are named ``branch_1``, ``branch_2``, ...
                in the order given.
            merge_strategy (str or callable): How to merge the outputs from
                branches:
                - "nested" (default): Return outputs as a dictionary keyed by
                  branch name
                - "flatten": Flatten outputs by prefixing keys with the branch
                  name
                - "suffix": Flatten outputs by suffixing keys with the branch
                  name
                - callable: A custom merge function that accepts the branch
                  outputs dict
            **kwargs: Additional arguments for the Operation base class

        Raises:
            ValueError: For an unsupported branch type or merge strategy.
        """
        super().__init__(**kwargs)

        if branches is None:
            branches = []

        self.branches = {}
        for i, branch in enumerate(branches, start=1):
            if isinstance(branch, list):
                # A plain list of operations is turned into an operation chain.
                resolved = make_operation_chain(branch)
            elif isinstance(branch, (Pipeline, Operation)):
                resolved = branch
            else:
                raise ValueError(
                    f"Branch must be a list, Pipeline, or Operation, got {type(branch)}"
                )
            self.branches[f"branch_{i}"] = resolved

        self.merge_strategy = merge_strategy
        if isinstance(merge_strategy, str):
            named_strategies = {
                "nested": lambda outputs: outputs,
                "flatten": self.flatten_outputs,
                "suffix": self.suffix_merge_outputs,
            }
            if merge_strategy not in named_strategies:
                raise ValueError(f"Unknown merge_strategy: {merge_strategy}")
            self._merge_function = named_strategies[merge_strategy]
        elif callable(merge_strategy):
            self._merge_function = merge_strategy
        else:
            raise ValueError("Invalid merge_strategy type provided.")

    def call(self, **kwargs):
        """Process input through branches and merge results.

        Args:
            **kwargs: Input keyword arguments

        Returns:
            dict: Merged outputs from all branches according to merge strategy
        """
        # Every branch receives its own (shallow) copy of the inputs so one
        # branch cannot add or replace keys seen by another.
        branch_outputs = {
            branch_name: branch(**kwargs.copy()) for branch_name, branch in self.branches.items()
        }
        return self._merge_function(branch_outputs)

    def flatten_outputs(self, outputs: dict) -> dict:
        """Flatten a nested dictionary by prefixing keys with the branch name.

        For each branch, the resulting key is "{branch_name}_{original_key}".

        Raises:
            ValueError: If two entries map to the same flattened key.
        """
        merged = {}
        for branch_name, branch_dict in outputs.items():
            for key in branch_dict:
                new_key = f"{branch_name}_{key}"
                if new_key in merged:
                    raise ValueError(f"Key collision detected for {new_key}")
                merged[new_key] = branch_dict[key]
        return merged

    def suffix_merge_outputs(self, outputs: dict) -> dict:
        """Flatten a nested dictionary by suffixing keys with the branch name.

        For each branch, the resulting key is "{original_key}_{branch_name}".

        Raises:
            ValueError: If two entries map to the same flattened key.
        """
        merged = {}
        for branch_name, branch_dict in outputs.items():
            for key in branch_dict:
                new_key = f"{key}_{branch_name}"
                if new_key in merged:
                    raise ValueError(f"Key collision detected for {new_key}")
                merged[new_key] = branch_dict[key]
        return merged

    def get_config(self):
        """Return the config dictionary for serialization."""
        config = super().get_config()

        branch_configs = {}
        for branch_name, branch in self.branches.items():
            if isinstance(branch, list) and not isinstance(branch, Pipeline):
                # List of operations -> list of operation configs
                branch_configs[branch_name] = {
                    "operations": [op.get_config() for op in branch]
                }
            else:
                # Pipelines and single operations serialize themselves
                branch_configs[branch_name] = branch.get_config()

        if isinstance(self.merge_strategy, str):
            merge_strategy_config = self.merge_strategy
        else:
            # For custom functions, use the name if available
            merge_strategy_config = getattr(self.merge_strategy, "__name__", "custom")

        config.update(
            {
                "branches": branch_configs,
                "merge_strategy": merge_strategy_config,
            }
        )
        return config

    def get_dict(self):
        """Get the configuration of the operation."""
        config = super().get_dict()
        config.update({"name": "branched_pipeline"})

        # Add branches (recursively) to the config
        branches = {}
        for branch_name, branch in self.branches.items():
            if isinstance(branch, list) and not isinstance(branch, Pipeline):
                branches[branch_name] = [op.get_dict() for op in branch]
            else:
                branches[branch_name] = branch.get_dict()

        config["branches"] = branches
        config["merge_strategy"] = self.merge_strategy
        return config
@ops_registry("threshold")
class Threshold(Operation):
    """Threshold an array, setting values below/above a threshold to a fill value."""

    def __init__(
        self,
        threshold_type="hard",
        below_threshold=True,
        fill_value="min",
        **kwargs,
    ):
        """
        Args:
            threshold_type (str): "hard" replaces values past the threshold by
                the fill value; "soft" shifts the data by the threshold, clamps
                it at zero and adds the fill value.
            below_threshold (bool): If True, values below the threshold are
                affected, otherwise values above it.
            fill_value: A number, or one of "min", "max", "threshold".
            **kwargs: Forwarded to the ``Operation`` base class.

        Raises:
            ValueError: If ``threshold_type`` is not "hard" or "soft".
        """
        super().__init__(**kwargs)
        if threshold_type not in ("hard", "soft"):
            raise ValueError("threshold_type must be 'hard' or 'soft'")
        self.threshold_type = threshold_type
        self.below_threshold = below_threshold
        self._fill_value_type = fill_value

        # (threshold type, below?) -> thresholding function, chosen once here
        threshold_funcs = {
            ("hard", True): lambda data, threshold, fill: ops.where(
                data < threshold, fill, data
            ),
            ("hard", False): lambda data, threshold, fill: ops.where(
                data > threshold, fill, data
            ),
            ("soft", True): lambda data, threshold, fill: (
                ops.maximum(data - threshold, 0) + fill
            ),
            ("soft", False): lambda data, threshold, fill: (
                ops.minimum(data - threshold, 0) + fill
            ),
        }
        self._threshold_func = threshold_funcs[(threshold_type, bool(below_threshold))]

    def _resolve_fill_value(self, data, threshold):
        """Get the fill value based on the fill_value_type."""
        fv = self._fill_value_type
        if isinstance(fv, (int, float)):
            return ops.convert_to_tensor(fv, dtype=data.dtype)
        if fv == "min":
            return ops.min(data)
        if fv == "max":
            return ops.max(data)
        if fv == "threshold":
            return threshold
        raise ValueError("Unknown fill_value")

    def call(
        self,
        threshold=None,
        percentile=None,
        **kwargs,
    ):
        """Threshold the input data.

        Exactly one of ``threshold`` and ``percentile`` must be given.

        Args:
            threshold: Numeric threshold.
            percentile: Percentile (0-100) to derive threshold from.

        Returns:
            dict: The thresholded tensor under ``self.output_key``.

        Raises:
            ValueError: If both or neither of ``threshold`` and ``percentile``
                are passed.
        """
        data = kwargs[self.key]

        if (threshold is None) == (percentile is None):
            raise ValueError("Pass either threshold or percentile, not both or neither.")

        if percentile is not None:
            # Convert percentile to quantile value (0-1 range)
            threshold = ops.quantile(data, percentile / 100.0)

        fill = self._resolve_fill_value(data, threshold)
        return {self.output_key: self._threshold_func(data, threshold, fill)}
@ops_registry("anisotropic_diffusion")
class AnisotropicDiffusion(Operation):
    """Speckle Reducing Anisotropic Diffusion (SRAD) filter.

    Reference:
    - https://www.researchgate.net/publication/5602035_Speckle_reducing_anisotropic_diffusion
    - https://nl.mathworks.com/matlabcentral/fileexchange/54044-image-despeckle-filtering-toolbox
    """

    def call(self, niter=100, lmbda=0.1, rect=None, eps=1e-6, **kwargs):
        """Anisotropic diffusion filter.

        Assumes input data is non-negative.

        Args:
            niter: Number of iterations.
            lmbda: Lambda parameter.
            rect: Rectangle [x1, y1, x2, y2] for homogeneous noise (optional).
            eps: Small epsilon for stability.

        Returns:
            dict: Filtered image (2D tensor or batch of images) under
            ``self.output_key``.
        """
        data = kwargs[self.key]

        # Always work on a batch; a single image becomes a batch of one.
        if not self.with_batch_dim:
            data = ops.expand_dims(data, axis=0)

        batch_size = ops.shape(data)[0]
        filtered = [
            self._anisotropic_diffusion_single(data[i], niter, lmbda, rect, eps)
            for i in range(batch_size)
        ]
        result = ops.stack(filtered, axis=0)

        if not self.with_batch_dim:
            result = ops.squeeze(result, axis=0)

        return {self.output_key: result}

    def _anisotropic_diffusion_single(self, image, niter, lmbda, rect, eps):
        """Apply anisotropic diffusion to a single image (2D).

        The image is exponentiated first, diffused for ``niter`` iterations and
        mapped back with a logarithm.
        """
        image = ops.exp(image)
        M, N = image.shape

        for _ in range(niter):
            zero_row = ops.zeros((1, N), dtype=image.dtype)
            zero_col = ops.zeros((M, 1), dtype=image.dtype)

            # Neighbouring pixels; positions outside the image are zero.
            neighbour_n = ops.concatenate([image[1:], zero_row], axis=0)
            neighbour_s = ops.concatenate([zero_row, image[:-1]], axis=0)
            neighbour_w = ops.concatenate([image[:, 1:], zero_col], axis=1)
            neighbour_e = ops.concatenate([zero_col, image[:, :-1]], axis=1)

            if rect is not None:
                # Squared coefficient of variation inside the reference region
                x1, y1, x2, y2 = rect
                imageuniform = image[x1:x2, y1:y2]
                q0_squared = (ops.std(imageuniform) / (ops.mean(imageuniform) + eps)) ** 2

            dN = neighbour_n - image
            dS = neighbour_s - image
            dW = neighbour_w - image
            dE = neighbour_e - image

            # Normalized squared gradient magnitude and normalized Laplacian
            G2 = (dN**2 + dS**2 + dW**2 + dE**2) / (image**2 + eps)
            L = (dN + dS + dW + dE) / (image + eps)

            num = (0.5 * G2) - ((1 / 16) * (L**2))
            den = (1 + ((1 / 4) * L)) ** 2
            q_squared = num / (den + eps)

            # NOTE(review): without ``rect`` the diffusion coefficient below is
            # built from the ``den`` computed above - confirm this is intended.
            if rect is not None:
                den = (q_squared - q0_squared) / (q0_squared * (1 + q0_squared) + eps)
            c = 1.0 / (1 + den)

            cS = ops.concatenate([zero_row, c[:-1]], axis=0)
            cE = ops.concatenate([zero_col, c[:, :-1]], axis=1)

            D = (cS * dS) + (c * dN) + (cE * dE) + (c * dW)
            image = image + (lmbda / 4) * D

        return ops.log(image)
class ChannelsToComplex(Operation):
    """Operation wrapper around :func:`channels_to_complex`."""

    def call(self, **kwargs):
        """Convert ``kwargs[self.key]`` with :func:`channels_to_complex`."""
        return {self.output_key: channels_to_complex(kwargs[self.key])}
class ComplexToChannels(Operation):
    """Operation wrapper around :func:`complex_to_channels`."""

    def __init__(self, axis=-1, **kwargs):
        """
        Args:
            axis (int): Axis passed on to :func:`complex_to_channels`.
                Defaults to -1.
            **kwargs: Forwarded to the ``Operation`` base class.
        """
        super().__init__(**kwargs)
        self.axis = axis

    def call(self, **kwargs):
        """Convert ``kwargs[self.key]`` with :func:`complex_to_channels`."""
        converted = complex_to_channels(kwargs[self.key], axis=self.axis)
        return {self.output_key: converted}
def demodulate_not_jitable(
    rf_data,
    sampling_frequency=None,
    center_frequency=None,
    bandwidth=None,
    filter_coeff=None,
):
    """Demodulates an RF signal to complex base-band (IQ).

    Demodulates the radiofrequency (RF) bandpass signals and returns the
    Inphase/Quadrature (I/Q) components. IQ is a complex whose real (imaginary)
    part contains the in-phase (quadrature) component.

    This function operates (i.e. demodulates) on the RF signal over the
    (fast-) time axis, which is the second to last axis.

    Args:
        rf_data (ndarray): real valued input array of size [..., n_ax, n_el].
            second to last axis is fast-time axis.
        sampling_frequency (float): the sampling frequency of the RF signals
            (in Hz). Only optional when filter_coeff is provided.
        center_frequency (float, optional): represents the center frequency
            (in Hz). If None, it is estimated from the power spectrum of (at
            most 100 randomly selected) scanlines. Defaults to None.
        bandwidth (float, optional): Bandwidth of RF signal in % of center
            frequency. Defaults to None.
            The bandwidth in % is defined by:
            B = Bandwidth_in_% = Bandwidth_in_Hz*(100/center_frequency).
            The cutoff frequency:
            Wn = Bandwidth_in_Hz/sampling_frequency, i.e:
            Wn = B*(center_frequency/100)/sampling_frequency.
        filter_coeff (list, optional): (b, a), numerator and denominator
            coefficients of FIR filter for quadratic band pass filter. All
            other parameters are ignored if filter_coeff are provided. Instead
            the given filter_coeff is directly used. If not provided, a filter
            is derived from the other params (sampling_frequency,
            center_frequency, bandwidth).
            see https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.lfilter.html

    Returns:
        iq_data (ndarray): complex valued base-band signal.
    """
    rf_data = ops.convert_to_numpy(rf_data)
    assert np.isreal(rf_data).all(), f"RF must contain real RF signals, got {rf_data.dtype}"

    input_shape = rf_data.shape
    n_dim = len(input_shape)
    *_, n_ax, n_el = input_shape

    if filter_coeff is not None:
        # A user supplied filter replaces the whole mixing / low-pass chain.
        b, a = filter_coeff
        return scipy.signal.lfilter(b, a, rf_data, axis=-2) * 2

    assert sampling_frequency is not None, "provide sampling_frequency when no filter is given."

    # Time vector
    t = np.arange(n_ax) / sampling_frequency

    # Estimate center frequency
    if center_frequency is None:
        # Keep a maximum of 100 randomly selected scanlines
        idx = np.arange(n_el)
        if n_el > 100:
            idx = np.random.permutation(idx)[:100]
        # Power Spectrum
        P = np.sum(
            np.abs(np.fft.fft(np.take(rf_data, idx, axis=-1), axis=-2)) ** 2,
            axis=-1,
        )
        P = P[: n_ax // 2]
        # Carrier frequency: power-weighted mean frequency bin
        idx = np.sum(np.arange(n_ax // 2) * P) / np.sum(P)
        center_frequency = idx * sampling_frequency / n_ax

    # Normalized cut-off frequency
    if bandwidth is None:
        Wn = min(2 * center_frequency / sampling_frequency, 0.5)
        bandwidth = center_frequency * Wn
    else:
        assert np.isscalar(bandwidth), "The signal bandwidth (in %) must be a scalar."
        assert (bandwidth > 0) & (bandwidth <= 200), (
            "The signal bandwidth (in %) must be within the interval of ]0,200]."
        )
        # bandwidth in Hz
        bandwidth = center_frequency * bandwidth / 100
        Wn = bandwidth / sampling_frequency

    assert (Wn > 0) & (Wn <= 1), (
        "The normalized cutoff frequency is not within the interval of (0,1). "
        "Check the input parameters!"
    )

    # Down-mixing of the RF signals
    carrier = np.exp(-1j * 2 * np.pi * center_frequency * t)
    # add the singleton dimensions
    carrier = np.reshape(carrier, (*[1] * (n_dim - 2), n_ax, 1))
    iq_data = rf_data * carrier

    # Low-pass filter
    N = 5
    b, a = scipy.signal.butter(N, Wn, "low")

    # factor 2: to preserve the envelope amplitude
    iq_data = scipy.signal.filtfilt(b, a, iq_data, axis=-2) * 2

    # Display a warning message if harmful aliasing is suspected
    # the RF signal is undersampled
    if sampling_frequency < (2 * center_frequency + bandwidth):
        # lower and higher frequencies of the bandpass signal
        fL = center_frequency - bandwidth / 2
        fH = center_frequency + bandwidth / 2
        n = fH // (fH - fL)
        # Band-pass sampling: aliasing is harmless when, for some integer k in
        # [1, n], 2*fH/k <= fs <= 2*fL/(k - 1). k = 1 would require
        # fs >= 2*fH, which this branch already rules out, so start at k = 2
        # (this also avoids dividing by zero).
        k = np.arange(2, n + 1)
        harmless_aliasing = bool(
            np.any(
                (2 * fH / k <= sampling_frequency) & (sampling_frequency <= 2 * fL / (k - 1))
            )
        )
        if not harmless_aliasing:
            log.warning(
                "rf2iq:harmful_aliasing Harmful aliasing is present: the aliases"
                " are not mutually exclusive!"
            )

    return iq_data
def upmix(iq_data, sampling_frequency, center_frequency, upsampling_rate=6):
    """Upsamples and upmixes complex base-band signals (IQ) to RF.

    Args:
        iq_data (ndarray): complex valued input array of size [..., n_ax, n_el].
            second to last axis is fast-time axis.
        sampling_frequency (float): the sampling frequency of the input IQ
            signal (in Hz). The resulting sampling frequency of the RF data is
            upsampling_rate times higher.
        center_frequency (float): represents the center frequency (in Hz).
        upsampling_rate (int, optional): factor by which the fast-time axis is
            upsampled before mixing. Defaults to 6.

    Returns:
        rf_data (ndarray): output real valued rf data (float32).
    """
    assert iq_data.dtype in [
        "complex64",
        "complex128",
    ], "IQ must contain all complex signals."

    input_shape = iq_data.shape
    n_dim = len(input_shape)
    *_, n_ax, _ = input_shape

    # Upsample along the fast-time axis
    n_ax_up = n_ax * upsampling_rate
    iq_data_upsampled = resample(
        iq_data,
        n_samples=n_ax_up,
        axis=-2,
        order=1,
    )

    # Time vector at the upsampled rate
    t = ops.arange(n_ax_up, dtype="float32") / (sampling_frequency * upsampling_rate)

    # Up-mixing of the IQ signals
    t = ops.cast(t, dtype="complex64")
    center_frequency = ops.cast(center_frequency, dtype="complex64")
    carrier = ops.exp(1j * 2 * np.pi * center_frequency * t)
    # Singleton dimensions everywhere except the fast-time axis
    carrier = ops.reshape(carrier, (*[1] * (n_dim - 2), n_ax_up, 1))

    rf_data = ops.real(iq_data_upsampled * carrier) * ops.sqrt(2)
    return ops.cast(rf_data, "float32")
def get_band_pass_filter(num_taps, sampling_frequency, f1, f2):
    """Design an FIR band pass filter.

    Args:
        num_taps (int): number of taps in filter.
        sampling_frequency (float): sample frequency in Hz.
        f1 (float): cutoff frequency in Hz of left band edge.
        f2 (float): cutoff frequency in Hz of right band edge.

    Returns:
        ndarray: band pass filter coefficients.
    """
    band_edges = [f1, f2]
    return scipy.signal.firwin(num_taps, band_edges, pass_zero=False, fs=sampling_frequency)
def get_low_pass_iq_filter(num_taps, sampling_frequency, f, bw):
    """Design complex low-pass filter.

    The filter is a low-pass FIR filter modulated to the center frequency.

    Args:
        num_taps (int): number of taps in filter.
        sampling_frequency (float): sample frequency.
        f (float): center frequency.
        bw (float): bandwidth in Hz.

    Raises:
        ValueError: if cutoff frequency (bw / 2) is not within
            (0, sampling_frequency / 2)

    Returns:
        ndarray: Complex-valued low-pass filter
    """
    cutoff = bw / 2
    if not (0 < cutoff < sampling_frequency / 2):
        raise ValueError(
            f"Cutoff frequency must be within (0, sampling_frequency / 2), "
            f"got {cutoff} Hz, must be within (0, {sampling_frequency / 2}) Hz"
        )

    # Real-valued prototype low-pass filter
    prototype = scipy.signal.firwin(num_taps, cutoff, pass_zero=True, fs=sampling_frequency)

    # Shift the pass band to ``f`` with a complex exponential
    tap_times = np.arange(num_taps) / sampling_frequency
    modulation = np.exp(1j * 2 * np.pi * f * tap_times)
    return prototype * modulation
def complex_to_channels(complex_data, axis=-1):
    """Unroll complex data to separate channels.

    Args:
        complex_data (complex ndarray): complex input data.
        axis (int, optional): on which axis to extend. Defaults to -1.

    Returns:
        ndarray: real array with real and imaginary components
            unrolled over two channels at axis (real part first).
    """
    components = (ops.real(complex_data), ops.imag(complex_data))
    # Give every component its own length-one channel axis, then join them.
    channels = tuple(ops.expand_dims(part, axis=axis) for part in components)
    return ops.concatenate(channels, axis=axis)
def channels_to_complex(data):
    """Convert a two-channel real array to a complex array.

    Args:
        data (ndarray): input data whose last axis holds the real component at
            index 0 and the imaginary component at index 1.

    Returns:
        ndarray: complex array with real and imaginary components.
    """
    assert data.shape[-1] == 2, "Data must have two channels."
    as_complex = ops.cast(data, "complex64")
    real_part = as_complex[..., 0]
    imag_part = as_complex[..., 1]
    return real_part + 1j * imag_part
def hilbert(x, N: int = None, axis=-1):
    """Manual implementation of the Hilbert transform function.

    The function returns the analytical signal and operates in the Fourier
    domain.

    Note:
        THIS IS NOT THE MATHEMATICAL HILBERT TRANSFORM as you will find it on
        wikipedia, but computes the analytical signal. The implementation
        reproduces the behavior of the `scipy.signal.hilbert` function.

    Args:
        x (ndarray): input data of any shape.
        N (int, optional): number of points in the FFT. Defaults to None, in
            which case the length of ``x`` along ``axis`` is used.
        axis (int, optional): axis to operate on. Defaults to -1.

    Returns:
        x (ndarray): complex iq data of any shape.

    Raises:
        ValueError: If ``N`` is smaller than the length of ``x`` along ``axis``.
    """
    input_shape = x.shape
    n_dim = len(input_shape)

    n_ax = input_shape[axis]

    if axis < 0:
        axis = n_dim + axis

    if N is None:
        N = n_ax
    else:
        if N < n_ax:
            raise ValueError("N must be greater or equal to n_ax.")
        # only pad along the axis, use manual padding
        pad = N - n_ax
        zeros = ops.zeros(
            input_shape[:axis] + (pad,) + input_shape[axis + 1 :],
        )
        x = ops.concatenate((x, zeros), axis=axis)

    # Spectral weights: keep DC (and Nyquist for even N), double the positive
    # frequencies, zero the negative ones.
    h = np.zeros(N)
    h[0] = 1
    if N % 2 == 0:
        h[N // 2] = 1
        h[1 : N // 2] = 2
    else:
        h[1 : (N + 1) // 2] = 2

    # The FFT operates on the last axis, so move ``axis`` to the end
    to_last = [dim for dim in range(n_dim) if dim != axis] + [axis]
    x = ops.transpose(x, to_last)

    if x.ndim > 1:
        # Broadcast the weights over all leading dimensions
        h = h[(np.newaxis,) * (x.ndim - 1) + (slice(None),)]

    h = ops.convert_to_tensor(h)
    h = ops.cast(h, "complex64")
    h = h + 1j * ops.zeros_like(h)

    # Forward FFT of the (real) input
    spec_real, spec_imag = ops.fft((x, ops.zeros_like(x)))
    spec_real = ops.cast(spec_real, "complex64")
    spec_imag = ops.cast(spec_imag, "complex64")

    spectrum = (spec_real + 1j * spec_imag) * h

    # Inverse FFT expressed through a forward FFT:
    # ifft(X) = conj(fft(conj(X))) / N
    conj_real = ops.real(spectrum)
    conj_imag = -ops.imag(spectrum)
    inv_real, inv_imag = ops.fft((conj_real, conj_imag))

    inv_imag = ops.cast(inv_imag, "complex64")
    inv_real = ops.cast(inv_real, "complex64")

    x = inv_real / N
    x = x + 1j * (-inv_imag / N)

    # switch back to original shape
    to_original = list(range(axis)) + [n_dim - 1] + list(range(axis, n_dim - 1))
    x = ops.transpose(x, to_original)
    return x
def demodulate(data, center_frequency, sampling_frequency, axis=-3):
    """Demodulates the input data to baseband.

    The function computes the analytical signal (the signal with negative
    frequencies removed) and then shifts the spectrum of the signal to baseband
    by multiplying with a complex exponential. Where the spectrum was centered
    around `center_frequency` before, it is now centered around 0 Hz. The
    baseband IQ data are complex-valued. The real and imaginary parts are
    stored in two real-valued channels.

    Args:
        data (ops.Tensor): The input data to demodulate of shape `(..., axis, ..., 1)`.
        center_frequency (float): The center frequency of the signal.
        sampling_frequency (float): The sampling frequency of the signal.
        axis (int, optional): The axis along which to demodulate. Defaults to -3.

    Returns:
        ops.Tensor: The demodulated IQ data of shape `(..., axis, ..., 2)`.
    """
    # Compute the analytical signal
    analytical_signal = hilbert(data, axis=axis)

    # Sample index along the demodulation axis, shaped so that it broadcasts
    # against the data (length one on every other axis)
    sample_indices = ops.arange(analytical_signal.shape[axis])
    broadcast_index = [None] * data.ndim
    broadcast_index[axis] = slice(None)
    sample_indices = sample_indices[tuple(broadcast_index)]

    # Cast to complex64
    center_frequency = ops.cast(center_frequency, dtype="complex64")
    sampling_frequency = ops.cast(sampling_frequency, dtype="complex64")
    sample_indices = ops.cast(sample_indices, dtype="complex64")

    # Shift to baseband
    phasor_exponent = -1j * 2 * np.pi * center_frequency * sample_indices / sampling_frequency
    baseband = analytical_signal * ops.exp(phasor_exponent)

    # Drop the singleton channel and split the complex signal into two channels
    return complex_to_channels(baseband[..., 0])