Source code for zea.tools.fit_scan_cone

"""
Identifies the scan cone of an ultrasound image and crops it such that the apex of the cone
is at the top center of the image. In this form, it can be scan converted to polar coordinates.

This module provides functionality to:
1. Detect the scan cone boundaries in ultrasound images
2. Fit lines to the cone edges
3. Calculate cone parameters (apex position, opening angle, etc.)
4. Crop and center the image around the cone
5. Visualize the detected cone and its parameters
"""

import argparse
import os
from pathlib import Path

if __name__ == "__main__":
    os.environ["KERAS_BACKEND"] = (
        "numpy"  # recommend using numpy for this since some line fitting is performed on CPU
    )

import keras
import matplotlib.pyplot as plt
import numpy as np
from keras import ops

from zea import log


def filter_edge_points_by_boundary(edge_points, is_left=True, min_cone_half_angle_deg=20):
    """
    Filter edge points to keep only those that form the actual boundary of the scan cone.

    Enforces a minimum cone angle constraint to ensure valid cone shapes.

    Args:
        edge_points: Tensor of shape (N, 2) containing (x, y) coordinates of edge points
        is_left: Boolean indicating whether these are left (True) or right (False) edge points
        min_cone_half_angle_deg: Minimum expected half-angle of the cone in degrees

    Returns:
        Tensor of shape (M, 2) containing filtered edge points that satisfy
        the boundary constraints
    """
    if ops.shape(edge_points)[0] == 0:
        return edge_points

    # Convert minimum angle to slope
    min_slope = ops.tan(np.radians(min_cone_half_angle_deg))

    # Sort by y coordinate (top to bottom)
    sorted_indices = ops.argsort(edge_points[:, 1])
    sorted_points = ops.take(edge_points, sorted_indices, axis=0)

    filtered_points = []

    # Convert to numpy for the iterative logic (this part is hard to vectorize)
    sorted_points_np = ops.convert_to_numpy(sorted_points)

    for i, point in enumerate(sorted_points_np):
        x, y = point
        is_boundary_point = True

        # Check all points above this one
        for j in range(i):
            above_x, above_y = sorted_points_np[j]
            dy = y - above_y
            min_dx_required = min_slope * dy

            if is_left:
                # For left boundary
                if above_x < x or above_x - x < min_dx_required:
                    is_boundary_point = False
                    break
            else:
                # For right boundary
                if above_x > x or x - above_x < min_dx_required:
                    is_boundary_point = False
                    break

        if is_boundary_point:
            filtered_points.append(point)

    return ops.convert_to_tensor(filtered_points) if filtered_points else ops.zeros((0, 2))
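
# Illustrative sketch (not part of the original module): the filter above keeps a point only if
# every point above it lies far enough outward, i.e. the horizontal offset satisfies
# |dx| >= tan(min_cone_half_angle_deg) * dy. The hypothetical helper below builds a small
# synthetic left edge with one inward spike and shows that the spike is rejected while the
# straight boundary survives.
def _example_filter_left_edge():
    """Hypothetical demo of filter_edge_points_by_boundary on synthetic points."""
    # Straight left boundary whose x decreases with depth at tan(30 deg) per pixel,
    # steeper than the 20-degree minimum, so all of its points satisfy the constraint.
    ys = np.arange(0, 50, 5, dtype="float32")
    xs = 100.0 - np.tan(np.radians(30.0)) * ys
    points = np.stack([xs, ys], axis=1)
    points[5, 0] += 20.0  # push one point inward (to larger x) -> it should be filtered out
    kept = filter_edge_points_by_boundary(
        ops.convert_to_tensor(points), is_left=True, min_cone_half_angle_deg=20
    )
    return ops.convert_to_numpy(kept)  # 9 of the 10 points remain
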
def detect_cone_parameters(image, min_cone_half_angle_deg=20, threshold=15):
    """
    Detect the ultrasound cone parameters from a grayscale image.

    This function performs the following steps:
    1. Thresholds the image to create a binary mask
    2. Detects left and right edge points of the cone
    3. Filters edge points to ensure physical constraints
    4. Fits lines to the cone boundaries
    5. Calculates cone parameters including apex position, opening angle, and crop boundaries

    Args:
        image: 2D Keras tensor (grayscale image)
        min_cone_half_angle_deg: Minimum expected half-angle of the cone in degrees
        threshold: Threshold for binary image (pixels above this are considered data)

    Returns:
        Dictionary containing cone parameters including:
        - apex_x, apex_y: Coordinates of the cone apex
        - crop_left, crop_right, crop_top, crop_bottom: Crop boundaries
        - cone_height: Height of the cone
        - opening_angle: Opening angle of the cone
        - symmetry_ratio: Measure of cone symmetry
        - data_coverage: Fraction of crop region containing data
        - And other geometric parameters

    Raises:
        ValueError: If input image is not 2D or cone detection fails
    """
    try:
        import cv2
    except ImportError as exc:
        raise ImportError(
            "OpenCV is required for cone detection and visualization. "
            "Please install it with 'pip install opencv-python' or "
            "'pip install opencv-python-headless'."
        ) from exc

    if len(ops.shape(image)) != 2:
        raise ValueError("Input image must be 2D (grayscale)")

    # Ensure image is in proper range for cv2.threshold
    if image.dtype != "uint8":
        image = ops.cast(image * 255, "uint8")

    h, w = ops.shape(image)

    # OpenCV threshold requires numpy array
    image_np = ops.convert_to_numpy(image)
    _, thresh_np = cv2.threshold(image_np, threshold, 255, cv2.THRESH_BINARY)
    thresh = ops.convert_to_tensor(thresh_np)

    # Find non-zero pixel bounds
    non_zero_indices = ops.convert_to_tensor(ops.where(thresh > 0))
    if ops.shape(non_zero_indices)[0] == 0:
        return None

    min_y = ops.min(non_zero_indices[:, 0])
    max_y = ops.max(non_zero_indices[:, 0])

    # Collect left and right edge points
    left_edge_points = []
    right_edge_points = []

    # Convert back to numpy for row-wise processing (this is hard to vectorize efficiently)
    thresh_np = ops.convert_to_numpy(thresh)
    min_y_np = int(ops.convert_to_numpy(min_y))
    max_y_np = int(ops.convert_to_numpy(max_y))

    for y in range(min_y_np, max_y_np + 1):
        row = thresh_np[y, :]
        non_zero_x = np.where(row > 0)[0]
        if len(non_zero_x) > 0:
            width = non_zero_x[-1] - non_zero_x[0]
            if width > 5:  # Minimum width to avoid noise
                left_edge_points.append([non_zero_x[0], y])
                right_edge_points.append([non_zero_x[-1], y])

    if len(left_edge_points) < 10:
        return None

    left_edge_points = ops.convert_to_tensor(left_edge_points, dtype="float32")
    right_edge_points = ops.convert_to_tensor(right_edge_points, dtype="float32")

    # Filter edge points to keep only boundary points
    filtered_left_points = filter_edge_points_by_boundary(
        left_edge_points, is_left=True, min_cone_half_angle_deg=min_cone_half_angle_deg
    )
    filtered_right_points = filter_edge_points_by_boundary(
        right_edge_points,
        is_left=False,
        min_cone_half_angle_deg=min_cone_half_angle_deg,
    )

    if ops.shape(filtered_left_points)[0] < 3 or ops.shape(filtered_right_points)[0] < 3:
        return None

    # Fit lines using least squares: x = a + b*y
    # Convert to numpy for lstsq
    filtered_left_np = ops.convert_to_numpy(filtered_left_points)
    filtered_right_np = ops.convert_to_numpy(filtered_right_points)

    # Left line
    A_left = np.vstack([np.ones(len(filtered_left_np)), filtered_left_np[:, 1]]).T
    left_coeffs, _, _, _ = np.linalg.lstsq(A_left, filtered_left_np[:, 0], rcond=None)
    left_a, left_b = left_coeffs

    # Right line
    A_right = np.vstack([np.ones(len(filtered_right_np)), filtered_right_np[:, 1]]).T
    right_coeffs, _, _, _ = np.linalg.lstsq(A_right, filtered_right_np[:, 0], rcond=None)
    right_a, right_b = right_coeffs

    # Convert back to tensors
    left_a = ops.convert_to_tensor(left_a)
    left_b = ops.convert_to_tensor(left_b)
    right_a = ops.convert_to_tensor(right_a)
    right_b = ops.convert_to_tensor(right_b)

    # Calculate apex as intersection of fitted lines
    if ops.abs(left_b - right_b) < 1e-6:  # Lines are parallel
        return None

    apex_y = (right_a - left_a) / (left_b - right_b)
    apex_x = left_a + left_b * apex_y

    # Calculate cone height
    cone_height = max_y - apex_y

    # Calculate opening angle from the line slopes
    # Convert slopes to angles and calculate opening angle
    left_angle = ops.arctan(left_b)  # angle of left line from horizontal
    right_angle = ops.arctan(right_b)  # angle of right line from horizontal
    opening_angle = ops.abs(left_angle - right_angle)

    min_non_zero_pixel_idx = (0, 0)
    for i in reversed(range(0, thresh_np.shape[0])):
        row = thresh_np[i]
        non_zero_pixel_col = np.where(row > 0)[0]
        if np.any(non_zero_pixel_col):
            min_non_zero_pixel_idx = (i, non_zero_pixel_col[0])
            break

    circle_radius = float(
        np.sqrt(
            (min_non_zero_pixel_idx[1] - ops.convert_to_numpy(apex_x)) ** 2
            + (min_non_zero_pixel_idx[0] - ops.convert_to_numpy(apex_y)) ** 2
        )
    )
    circle_center_x = float(ops.convert_to_numpy(apex_x))
    circle_center_y = float(ops.convert_to_numpy(apex_y))

    # Calculate where the circle intersects with the cone lines
    # For line: x = a + b*y and circle: (x - cx)^2 + (y - cy)^2 = r^2
    def line_circle_intersection(a, b, cx, cy, r):
        """Find intersection of line x = a + b*y with circle centered at (cx, cy) with radius r"""
        # Substitute line equation into circle equation
        # (a + b*y - cx)^2 + (y - cy)^2 = r^2
        # Expand and collect terms to get quadratic in y
        A = 1 + b**2
        B = 2 * b * (a - cx) - 2 * cy
        C = (a - cx) ** 2 + cy**2 - r**2

        discriminant = B**2 - 4 * A * C
        if discriminant < 0:
            return None

        # Two solutions
        y1 = (-B + np.sqrt(discriminant)) / (2 * A)
        y2 = (-B - np.sqrt(discriminant)) / (2 * A)

        # Return the solution closest to the bottom (larger y)
        if y1 > y2:
            return y1, a + b * y1
        else:
            return y2, a + b * y2

    # Find intersection points
    left_intersect = line_circle_intersection(
        ops.convert_to_numpy(left_a),
        ops.convert_to_numpy(left_b),
        circle_center_x,
        circle_center_y,
        circle_radius,
    )
    right_intersect = line_circle_intersection(
        ops.convert_to_numpy(right_a),
        ops.convert_to_numpy(right_b),
        circle_center_x,
        circle_center_y,
        circle_radius,
    )

    if left_intersect is None or right_intersect is None:
        # Fallback to line endpoints at max_y
        left_y_bottom, left_x_bottom = (
            ops.convert_to_numpy(max_y),
            ops.convert_to_numpy(left_a + left_b * max_y),
        )
        right_y_bottom, right_x_bottom = (
            ops.convert_to_numpy(max_y),
            ops.convert_to_numpy(right_a + right_b * max_y),
        )
    else:
        left_y_bottom, left_x_bottom = left_intersect
        right_y_bottom, right_x_bottom = right_intersect

    sector_bottom = int(circle_radius + apex_y)  # this is the coordinate on the unpadded image

    # Calculate crop boundaries (can be negative)
    padding_x = 0
    padding_y = 0
    crop_left = int(left_x_bottom) - padding_x
    crop_right = int(right_x_bottom) + padding_x
    crop_top = int(ops.convert_to_numpy(apex_y)) - padding_y
    crop_bottom = int(sector_bottom)

    # Calculate final dimensions
    new_width = crop_right - crop_left
    new_height = crop_bottom - crop_top

    # Store the intersection points for visualization
    sector_left_x = left_x_bottom
    sector_left_y = left_y_bottom
    sector_right_x = right_x_bottom
    sector_right_y = right_y_bottom

    # Calculate symmetry ratio (how symmetric the cone is)
    symmetry_ratio = float(
        ops.convert_to_numpy(
            ops.abs(left_b + right_b) / (ops.abs(left_b) + ops.abs(right_b) + 1e-8)
        )
    )

    # Calculate data coverage in the crop region
    h_np = int(ops.convert_to_numpy(h))
    w_np = int(ops.convert_to_numpy(w))
    crop_left_clipped = max(0, crop_left)
    crop_right_clipped = min(w_np, crop_right)
    crop_top_clipped = max(0, crop_top)
    crop_bottom_clipped = min(h_np, crop_bottom)

    data_coverage = 0.0
    assert crop_left_clipped < crop_right_clipped and crop_top_clipped < crop_bottom_clipped
    crop_region = thresh_np[
        crop_top_clipped:crop_bottom_clipped, crop_left_clipped:crop_right_clipped
    ]
    data_coverage = float(np.sum(crop_region > 0) / crop_region.size)

    return {
        "apex_x": float(ops.convert_to_numpy(apex_x)),
        "apex_y": float(ops.convert_to_numpy(apex_y)),
        "crop_left": crop_left,
        "crop_right": crop_right,
        "crop_top": crop_top,
        "crop_bottom": crop_bottom,
        "original_width": int(ops.convert_to_numpy(w)),
        "original_height": int(ops.convert_to_numpy(h)),
        # Additional parameters for debug and analysis
        "cone_height": float(ops.convert_to_numpy(cone_height)),
        "opening_angle": float(ops.convert_to_numpy(opening_angle)),
        "new_width": new_width,
        "new_height": new_height,
        "symmetry_ratio": symmetry_ratio,
        "first_data_row": int(ops.convert_to_numpy(min_y)),
        "data_coverage": data_coverage,
        "apex_above_image": bool(ops.convert_to_numpy(apex_y) < 0),
        # Line parameters for reconstruction if needed
        "left_slope": float(ops.convert_to_numpy(left_b)),
        "right_slope": float(ops.convert_to_numpy(right_b)),
        "left_intercept": float(ops.convert_to_numpy(left_a)),
        "right_intercept": float(ops.convert_to_numpy(right_a)),
        # Circle parameters for the bottom boundary
        "circle_center_x": circle_center_x,
        "circle_center_y": circle_center_y,
        "circle_radius": circle_radius,
        # Sector intersection points for visualization
        "sector_left_x": sector_left_x,
        "sector_left_y": sector_left_y,
        "sector_right_x": sector_right_x,
        "sector_right_y": sector_right_y,
        "sector_bottom": sector_bottom,
    }
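
# Illustrative sketch (not part of the original module): the apex above is the intersection of
# the two fitted boundary lines x = a_l + b_l * y and x = a_r + b_r * y, i.e.
#     apex_y = (a_r - a_l) / (b_l - b_r),   apex_x = a_l + b_l * apex_y.
# The hypothetical check below recovers a known apex at (64, -10) from two synthetic lines.
def _example_apex_from_lines():
    """Hypothetical demo: recover a known apex from two line parameterizations."""
    apex = np.array([64.0, -10.0])  # (x, y) of the synthetic apex, above the image (y < 0)
    left_b, right_b = -0.6, 0.6  # slopes dx/dy of the left and right boundaries
    left_a = apex[0] - left_b * apex[1]  # intercepts chosen so both lines pass through the apex
    right_a = apex[0] - right_b * apex[1]
    apex_y = (right_a - left_a) / (left_b - right_b)
    apex_x = left_a + left_b * apex_y
    return apex_x, apex_y  # == (64.0, -10.0) up to floating point
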
def crop_and_center_cone(image, cone_params):
    """
    Crop the image to the sector bounding box and pad as needed to center the apex.

    This function:
    1. Crops the image to the detected cone boundaries
    2. Adds padding if the apex is above the image
    3. Centers the apex horizontally in the final image

    Args:
        image: 2D Keras tensor (grayscale image)
        cone_params: Dictionary of cone parameters from detect_cone_parameters()

    Returns:
        Keras tensor of the cropped and centered image with the cone apex at the top center
    """
    # Get crop boundaries
    crop_left = cone_params["crop_left"]
    crop_right = cone_params["crop_right"]
    crop_top = cone_params["crop_top"]
    crop_bottom = cone_params["crop_bottom"]

    # Crop the image (handle negative crop_top)
    if crop_top < 0:
        cropped = image[0:crop_bottom, crop_left:crop_right]
        # Add top padding
        top_padding = -crop_top
        cropped_width = ops.shape(cropped)[1]
        top_pad = ops.zeros((top_padding, cropped_width), dtype=cropped.dtype)
        cropped = ops.concatenate([top_pad, cropped], axis=0)
    else:
        cropped = image[crop_top:crop_bottom, crop_left:crop_right]

    # Now handle horizontal centering
    # Calculate where the apex is in the cropped image
    apex_x_in_crop = cone_params["apex_x"] - crop_left

    cropped_height = ops.shape(cropped)[0]
    cropped_width = ops.shape(cropped)[1]

    # Calculate the target center position
    target_center_x = ops.cast(cropped_width / 2, "float32")

    # Calculate how much padding we need on each side
    # We want: left_padding + apex_x_in_crop = final_width / 2
    # And: final_width = cropped_width + left_padding + right_padding
    # For symmetric padding: left_padding = right_padding
    # So: left_padding + apex_x_in_crop = (cropped_width + 2*left_padding) / 2
    # Solving: left_padding = cropped_width/2 - apex_x_in_crop
    left_padding_needed = target_center_x - apex_x_in_crop

    # Ensure we have non-negative padding
    left_padding = ops.maximum(0, ops.cast(left_padding_needed, "int32"))
    right_padding = ops.maximum(0, ops.cast(-left_padding_needed, "int32"))

    # Apply horizontal padding if needed
    if left_padding > 0 or right_padding > 0:
        if left_padding > 0:
            left_pad = ops.zeros((cropped_height, left_padding), dtype=cropped.dtype)
            cropped = ops.concatenate([left_pad, cropped], axis=1)
        if right_padding > 0:
            right_pad = ops.zeros((cropped_height, right_padding), dtype=cropped.dtype)
            cropped = ops.concatenate([cropped, right_pad], axis=1)

    return cropped
def fit_and_crop_around_scan_cone(
    image_tensor, min_cone_half_angle_deg=20, threshold=15, return_params=False
):
    """
    Detect scan cone in ultrasound image and return cropped/padded image with centered apex.

    Args:
        image_tensor: Keras tensor (2D grayscale image)
        min_cone_half_angle_deg: Minimum expected half-angle of the cone in degrees (default: 20)
        threshold: Threshold for binary image - pixels above this are considered data (default: 15)
        return_params: If True, also return cone parameters (default: False)

    Returns:
        - If return_params is False: Keras tensor (cropped and padded image with apex at center)
        - If return_params is True: Tuple of (cropped_tensor, cone_parameters_dict)

    Raises:
        ValueError: If cone detection fails or image is not 2D
    """
    if keras.backend.backend() != "numpy":
        log.info(
            f"❗️ It is recommended to use {log.blue('numpy')} backend for `fit_scan_cone()`."
        )

    # Ensure image is 2D
    if len(ops.shape(image_tensor)) != 2:
        raise ValueError(
            f"Input must be 2D grayscale image, got shape {ops.shape(image_tensor)}"
        )

    # Detect cone parameters
    cone_params = detect_cone_parameters(
        image_tensor,
        min_cone_half_angle_deg=min_cone_half_angle_deg,
        threshold=threshold,
    )

    if cone_params is None:
        raise ValueError("Failed to detect ultrasound cone in image")

    # Crop and center the image
    cropped_image = crop_and_center_cone(image_tensor, cone_params)

    if return_params:
        return cropped_image, cone_params
    else:
        return cropped_image
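
# Illustrative usage sketch (not part of the original module): `_example_fit_frame` is a
# hypothetical helper showing the intended call pattern on a single grayscale frame already
# loaded as a numpy array (assumed uint8, values in [0, 255]).
def _example_fit_frame(frame_np):
    """Hypothetical demo: crop a grayscale frame around its scan cone and inspect parameters."""
    centered, params = fit_and_crop_around_scan_cone(
        ops.convert_to_tensor(frame_np),
        min_cone_half_angle_deg=20,
        threshold=15,
        return_params=True,
    )
    # The apex now sits at the top center of `centered`; `params` carries the fitted geometry.
    print(f"opening angle: {np.degrees(params['opening_angle']):.1f} deg")
    return ops.convert_to_numpy(centered)
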
def visualize_scan_cone(image, cone_params, output_dir="output"):
    """
    Create visualization plots for the scan cone detection.

    Args:
        image: Original grayscale image
        cone_params: Dictionary of cone parameters from detect_cone_parameters()
        output_dir: Directory to save output plots (default: "output")

    The visualization is saved as 'scan_cone_visualization.png' in the output directory.
    """
    # Create output directory
    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)

    # Create figure
    _, ax = plt.subplots(1, 1, figsize=(12, 10))
    ax.imshow(image, cmap="gray" if len(image.shape) == 2 else None)

    # Extract parameters
    apex_x, apex_y = cone_params["apex_x"], cone_params["apex_y"]
    left_a = cone_params["left_intercept"]
    left_b = cone_params["left_slope"]
    right_a = cone_params["right_intercept"]
    right_b = cone_params["right_slope"]

    # Get bottom intersection points
    bottom_y = image.shape[0] - 1
    left_x_bottom = left_a + left_b * bottom_y
    right_x_bottom = right_a + right_b * bottom_y

    # Draw cone lines
    ax.plot(
        [apex_x, left_x_bottom],
        [apex_y, bottom_y],
        color="#FF0000",  # Bright red
        linewidth=2,
        label="Cone boundary",
    )
    ax.plot(
        [apex_x, right_x_bottom],
        [apex_y, bottom_y],
        color="#FF0000",  # Bright red
        linewidth=2,
    )

    # Draw circular arc between the fitted lines
    if "circle_center_x" in cone_params:
        circle_center_x = cone_params["circle_center_x"]
        circle_center_y = cone_params["circle_center_y"]
        circle_radius = cone_params["circle_radius"]

        # Calculate angles for left and right intersection points
        left_angle = np.arctan2(left_x_bottom - circle_center_x, bottom_y - circle_center_y)
        right_angle = np.arctan2(right_x_bottom - circle_center_x, bottom_y - circle_center_y)

        # Ensure angles are in the correct order
        if right_angle < left_angle:
            right_angle += 2 * np.pi

        # Create arc points
        theta = np.linspace(left_angle, right_angle, 100)
        arc_x = circle_center_x + circle_radius * np.sin(theta)
        arc_y = circle_center_y + circle_radius * np.cos(theta)

        # Only draw the part that's visible in the image
        visible_mask = (
            (arc_y >= 0) & (arc_y < image.shape[0]) & (arc_x >= 0) & (arc_x < image.shape[1])
        )
        if np.any(visible_mask):
            ax.plot(
                arc_x[visible_mask],
                arc_y[visible_mask],
                color="#00FF00",  # Bright green
                linestyle="--",
                linewidth=2,
                label="Sector arc",
            )

    # Draw apex as a star
    ax.plot(
        apex_x,
        apex_y,
        marker="*",  # Star marker
        markersize=15,
        color="#FFD700",  # Gold
        # markeredgecolor="white",
        markeredgewidth=2,
        label="Cone apex",
    )

    # Draw crop rectangle
    rect = plt.Rectangle(
        (cone_params["crop_left"], cone_params["crop_top"]),
        cone_params["crop_right"] - cone_params["crop_left"],
        cone_params["crop_bottom"] - cone_params["crop_top"],
        fill=False,
        color="#00FFFF",  # Cyan
        linewidth=2,
        label="Crop region",
        linestyle="dotted",
    )
    ax.add_patch(rect)

    # Add info text
    opening_angle_deg = np.degrees(cone_params["opening_angle"])
    info_text = (
        f"Opening Angle: {opening_angle_deg:.1f}°\n"
        f"Symmetry Ratio: {cone_params.get('symmetry_ratio', 0.0):.3f}"
    )
    if "circle_radius" in cone_params:
        info_text += f"\nCircle Radius: {cone_params['circle_radius']:.1f} px"

    ax.text(
        0.02,
        0.98,
        info_text,
        transform=ax.transAxes,
        fontsize=11,
        va="top",
        bbox={
            "boxstyle": "round",
            "facecolor": "#FFFF00",
            "alpha": 0.9,
        },  # Bright yellow
    )

    # Add crop info
    crop_info = (
        f"Crop: ({cone_params['crop_left']}, {cone_params['crop_top']}) to "
        f"({cone_params['crop_right']}, {cone_params['crop_bottom']})\n"
        f"Final Size: {cone_params['new_width']} x {cone_params['new_height']}"
    )
    ax.text(
        0.02,
        0.02,
        crop_info,
        transform=ax.transAxes,
        fontsize=11,
        va="bottom",
        bbox={
            "boxstyle": "round",
            "facecolor": "#00FFFF",
            "alpha": 0.9,
        },
    )

    # Set up the plot
    ax.set_xlim(0, image.shape[1])
    ax.set_ylim(image.shape[0], 0)
    ax.legend(loc="upper right")
    ax.set_title("Scan Cone Detection", fontsize=16)
    ax.axis("off")

    plt.tight_layout()
    plt.savefig(
        output_path / "scan_cone_visualization.png",
        dpi=300,
        bbox_inches="tight",
        pad_inches=0.5,
    )
    log.info(
        f"Saved cone visualization to {log.yellow(output_path / 'scan_cone_visualization.png')}"
    )
    plt.close()
def main(avi_path):
    """Demonstrate scan cone fitting on a sample AVI file."""
    try:
        import cv2
    except ImportError as exc:
        raise ImportError(
            "OpenCV is required for cone detection and visualization. "
            "Please install it with 'pip install opencv-python' or "
            "'pip install opencv-python-headless'."
        ) from exc

    # Load first frame
    cap = cv2.VideoCapture(avi_path)
    ret, frame = cap.read()
    cap.release()

    if not ret:
        print(f"Failed to read video file: {avi_path}")
        return

    # Convert to grayscale
    frame = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)

    # Convert to tensor
    frame_tensor = ops.convert_to_tensor(frame)

    try:
        # Fit scan cone
        _, cone_params = fit_and_crop_around_scan_cone(
            frame_tensor, min_cone_half_angle_deg=20, threshold=15, return_params=True
        )

        # Create visualization
        visualize_scan_cone(frame, cone_params)
        print("Visualization saved to output/scan_cone_visualization.png")
    except ValueError as e:
        print(f"Error fitting scan cone: {e}")
def get_args():
    """Parse command line arguments."""
    parser = argparse.ArgumentParser(
        description="Precompute cone parameters for EchoNet-LVH dataset"
    )
    parser.add_argument(
        "--input_file",
        type=str,
        required=True,
    )
    return parser.parse_args()
if __name__ == "__main__":
    args = get_args()
    input_file = args.input_file
    main(input_file)