Source code for zea.interface

"""Convenience interface for loading and displaying ultrasound data.

Example usage
^^^^^^^^^^^^^^

.. code-block:: python

    import zea
    from zea.internal.setup_zea import setup_config

    config = setup_config("hf://zeahub/configs/config_camus.yaml")

    interface = zea.Interface(config)
    interface.run(plot=True)

"""

import asyncio
import time
from pathlib import Path
from typing import List

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image

from zea import log
from zea.config import Config
from zea.data.file import File
from zea.datapaths import format_data_path
from zea.display import to_8bit
from zea.internal.core import DataTypes
from zea.internal.viewer import (
    ImageViewerMatplotlib,
    ImageViewerOpenCV,
    filename_from_window_dialog,
    running_in_notebook,
)
from zea.io_lib import matplotlib_figure_to_numpy
from zea.ops import Pipeline
from zea.utils import keep_trying, save_to_gif, save_to_mp4


class Interface:
    """Interface for selecting / loading / processing single ultrasound images.

    Useful for inspecting datasets and single ultrasound images.

    The interface opens the data file referenced by the config, builds a
    processing pipeline from ``config.pipeline`` and creates an image viewer
    (matplotlib or opencv, depending on ``config.plot.plot_lib``).

    # TODO: maybe we can refactor such that it is clear what needs to be in config.
    """

    def __init__(self, config: Config = None, verbose: bool = True, validate_file: bool = True):
        """Initialize Interface.

        Args:
            config (Config): Configuration object.
            verbose (bool): Whether to print verbose output.
            validate_file (bool): Whether to validate the file.
        """
        self.config = Config(config)
        self.verbose = verbose

        # Initialize attributes up front so that cleanup in ``__del__`` works
        # even when construction fails half-way.
        self.file = None
        self.data = None
        self.image = None
        self.mpl_img = None
        self.fig = None
        self.ax = None
        self.gui = None
        self.image_viewer = None
        # Reference to a scheduled movie task (see ``run``).
        self._movie_task = None

        self.plot_lib = self.config.plot.plot_lib

        # Headless mode has to be resolved *before* ``file_path`` is read: when
        # no file is configured the property falls back to ``choose_file_path``,
        # which inspects ``self.headless``.
        if self.config.plot.headless is None:
            self.headless = False
        else:
            self.headless = self.config.plot.headless
        self.check_for_display()

        self.file = File(self.file_path)
        if validate_file:
            self.file.validate()

        # get probe and scan from file
        self.probe = self.file.probe()
        self.scan = self.file.scan(**self.config.scan)

        # initialize Pipeline
        assert "pipeline" in self.config, (
            "Pipeline not found in config, please specify pipeline in config."
        )
        self.process = Pipeline.from_config(
            self.config.pipeline,
            with_batch_dim=False,
            jit_options=None,
        )
        self.parameters = self.process.prepare_parameters(self.probe, self.scan, self.config.scan)

        # The viewer pulls new frames through ``data_to_display``.
        if self.plot_lib == "opencv":
            self.image_viewer = ImageViewerOpenCV(
                self.data_to_display,
                window_name=self.file.name,
                num_threads=1,
                headless=self.headless,
            )
        elif self.plot_lib == "matplotlib":
            self.image_viewer = ImageViewerMatplotlib(
                self.data_to_display,
                window_name=self.file.name,
                num_threads=1,
            )

    @property
    def dtype(self):
        """Data type of data when loaded from file."""
        return self.config.data.dtype

    @property
    def dataset_folder(self):
        """Path to dataset folder."""
        return format_data_path(self.config.data.dataset_folder, self.config.data.user)

    @property
    def file_path(self):
        """Path to data file.

        Falls back to a file dialog (``choose_file_path``) when no file path is
        set in the config.
        """
        if self.config.data.file_path:
            return self.dataset_folder / self.config.data.file_path
        else:
            return self.choose_file_path()

    @file_path.setter
    def file_path(self, value):
        """Set file path to data file."""
        self.config.data.file_path = value

    def choose_file_path(self):
        """Choose file path from window dialog.

        The chosen path is stored in the config through the ``file_path`` setter.

        Returns:
            The value of ``self.file_path`` after the selection.

        Raises:
            ValueError: If running headless, since no dialog can be opened.
        """
        if self.headless:
            raise ValueError(
                "No file path specified for data file, which is required "
                "in headless mode as window dialog cannot be opened."
            )
        filetype = "hdf5"
        log.info("Please select file from window dialog...")
        self.file_path = filename_from_window_dialog(
            f"Choose .{filetype} file",
            filetypes=((filetype, "*." + filetype),),
            initialdir=self.dataset_folder,
        )
        return self.file_path

    @property
    def data_root(self):
        """Root path to data file."""
        return Path(self.config.user.data_root)

    @dtype.setter
    def dtype(self, value):
        self.config.data.dtype = value

    @property
    def to_dtype(self):
        """Data type to convert to for display."""
        return self.config.data.to_dtype

    @to_dtype.setter
    def to_dtype(self, value):
        self.config.data.to_dtype = value

    @property
    def frame_no(self):
        """Frame number to display."""
        return self.config.data.get("frame_no")

    @frame_no.setter
    def frame_no(self, value):
        self.config.data.frame_no = value

    def check_for_display(self):
        """Check if in headless mode (no monitor available).

        When headless mode was not requested but matplotlib reports the ``agg``
        backend, headless mode is switched on. When headless mode was requested,
        the ``agg`` backend is selected explicitly.
        """
        if self.headless is False:
            if matplotlib.get_backend().lower() == "agg":
                self.headless = True
                log.warning("Could not connect to display, running headless.")
        else:
            matplotlib.use("agg")
            log.info("Running in headless mode as set by config.")

    def set_backend_for_notebooks(self):
        """Set backend to QtAgg if running in notebook (and not headless)."""
        if running_in_notebook() and not self.headless:
            matplotlib.use("QtAgg")

    def get_data(self):
        """Get data.

        Chosen datafile should be listed in the dataset.
        Using either file specified in config or if None, the ui window.
        If no frame number is configured and the file holds more than one
        frame, the user is prompted for one.

        Returns:
            data (np.ndarray): data array of shape (n_tx, n_el, n_ax, N_ch)
        """
        if self.verbose:
            log.info(f"Selected {log.yellow(self.file_path)}")

        # grab frame number from config or user input if not set in config
        if self.frame_no == "all":
            log.info("Will run all frames as `all` was chosen in config...")
        elif self.frame_no is None:
            if self.file.n_frames == 1:
                self.frame_no = 0
            else:
                self.frame_no = keep_trying(
                    lambda: int(input(f">> Frame number (0 / {self.file.n_frames - 1}): "))
                )

        # get data from dataset
        data = self.file.load_data(self.dtype, self.frame_no)

        return data

    def data_to_display(self, data=None):
        """Get data, run it through the pipeline and return the image to display.

        Args:
            data (optional): Data to process. When None, data is loaded from
                the file with ``get_data``.

        Returns:
            The processed image, also stored in ``self.image``.
        """
        if data is None:
            self.data = self.get_data()
        else:
            self.data = data

        if self.to_dtype not in ["image", "image_sc"]:
            log.warning(
                f"Image to_dtype: {self.to_dtype} not supported for displaying data. "
                "Falling back to to_dtype: `image_sc`"
            )
            self.to_dtype = "image_sc"

        # select transmits if raw or aligned data
        data_type = self.process.operations[0].input_data_type
        if data_type in [DataTypes.RAW_DATA, DataTypes.ALIGNED_DATA]:
            n_tx = self.data.shape[0]
            assert len(self.scan.selected_transmits) <= n_tx, (
                f"Number of selected transmits {len(self.scan.selected_transmits)} "
                f"exceeds number of transmits in raw data {n_tx}"
            )
            self.data = np.take(self.data, self.scan.selected_transmits, axis=0)

        inputs = {self.process.key: self.data}
        outputs = self.process(**inputs, **self.parameters)
        self.image = outputs[self.process.output_key]

        # match orientation if necessary
        if self.config.plot.fliplr:
            self.image = np.fliplr(self.image)

        # opencv requires 8 bit images
        if self.plot_lib == "opencv":
            self.image = to_8bit(self.image, self.config.data.dynamic_range)

        return self.image

    def run(self, plot=False, block=True):
        """Run ui. Will retrieve, process and plot data if set to True.

        When the configured frame number is ``"all"`` the whole file is played
        as a movie. If an event loop is already running (for instance in a
        notebook) the movie is scheduled as a task on that loop; otherwise it
        is run to completion.

        Args:
            plot (bool): Whether to plot the image (single frame only).
            block (bool): Whether to block the UI while plotting.

        Returns:
            The current value of ``self.image``.
        """
        save = self.config.plot.save

        if self.frame_no == "all":
            # ``get_running_loop`` raises RuntimeError when no loop is running,
            # without the deprecated implicit loop creation of ``get_event_loop``.
            try:
                asyncio.get_running_loop()
                loop_running = True
            except RuntimeError:
                loop_running = False

            if not loop_running:
                asyncio.run(self.run_movie(save))
            else:
                # Keep a reference: the event loop only holds weak references
                # to tasks, so an unreferenced task may be garbage collected.
                self._movie_task = asyncio.create_task(self.run_movie(save))
        else:
            if plot:
                self.image = self.plot(
                    save=save,
                    block=block,
                )
            else:
                self.image = self.data_to_display()

        return self.image

    def plot(
        self,
        data: np.ndarray = None,
        save: bool = False,
        block: bool = True,
    ):
        """Plot image using matplotlib or opencv.

        Args:
            data (np.ndarray, optional): Data handed to the image viewer.
            save (bool): whether to save the image to disk.
            block (bool): whether to block the UI while plotting.

        Returns:
            image (np.ndarray): plotted image (grabbed from figure).
        """
        assert self.image_viewer is not None, "Image viewer not initialized."
        self.image_viewer.threading = False

        if self.plot_lib == "matplotlib":
            if self.image_viewer.fig is None:
                self._init_plt_figure()
            self.image_viewer.show(data)
            if save:
                self.save_image(self.fig)
            if not self.headless and block:
                plt.show(block=True)
            self.image = matplotlib_figure_to_numpy(self.fig)
            return self.image

        elif self.plot_lib == "opencv":
            self.image_viewer.show(data)
            if not self.headless and block:
                self.image_viewer._cv2.waitKey(0)
            # only write to disk when requested, same as the matplotlib branch
            if save:
                self.save_image(self.image)
            return self.image

    def _init_plt_figure(self):
        """Create the matplotlib figure / axes and hand them to the image viewer."""
        base_size = 10
        figsize = (base_size, base_size)

        if self.scan:
            # extent in millimeters: [left, right, bottom, top]
            extent = [
                self.scan.xlims[0] * 1e3,
                self.scan.xlims[1] * 1e3,
                self.scan.zlims[1] * 1e3,
                self.scan.zlims[0] * 1e3,
            ]
            # set figure aspect ratio to match scan: scale the width only,
            # scaling both sides would leave the figure square
            width = abs(extent[1] - extent[0])
            depth = abs(extent[3] - extent[2])
            if width > 0 and depth > 0:
                figsize = (base_size * width / depth, base_size)
        else:
            extent = None

        self.fig, self.ax = plt.subplots(figsize=figsize)

        image_range = self.config.data.dynamic_range
        imshow_kwargs = {
            "cmap": "gray",
            "vmin": image_range[0],
            "vmax": image_range[1],
            "origin": "upper",
            "extent": extent,
            "interpolation": "none",
        }
        cax_kwargs = {
            "pad": 0.05,
            "position": "right",
            "size": "5%",
        }

        self.ax.set_xlabel("Lateral Width (mm)", size=15)
        self.ax.set_ylabel("Axial length (mm)", size=15)
        self.ax.tick_params(axis="x")
        self.ax.tick_params(axis="y")

        # assign properties of fig, ax to image viewer
        self.image_viewer.imshow_kwargs = imshow_kwargs
        self.image_viewer.cax_kwargs = cax_kwargs
        self.image_viewer.fig = self.fig
        self.image_viewer.ax = self.ax

    async def run_movie(self, save: bool = False):
        """Run all frames in file in sequence.

        Args:
            save (bool): Whether to save the shown frames as a video afterwards.
        """
        log.info('Playing video, press/hold "q" while the window is active to exit...')
        self.image_viewer.threading = True
        images = await self._movie_loop(save)

        if save:
            self.save_video(images)

    async def _movie_loop(self, save: bool = False) -> List[np.ndarray]:
        """Process data and plot it in real time.

        NOTE: when plot loop is terminated by user, it will only save the shown frames.
        This is to prevent long waiting times when saving a movie (for large datasets).

        Args:
            save (bool): Whether to save the plotted images.

        Returns:
            list: A list of the plotted images.
        """
        # Initialize list of images
        images = []

        # Load correct number of frames (needs to get_data first)
        self.frame_no = 0
        self.get_data()
        n_frames = self.file.n_frames

        # Silence per-frame logging while playing; restored when the loop exits.
        verbose = self.verbose
        self.verbose = False

        try:
            while True:
                # first frame is already plotted during initialization of plotting
                start_time = time.time()
                frame_counter = 0
                self.image_viewer.frame_no = 0
                while frame_counter < n_frames:
                    if self.gui:
                        await self.gui.check_freeze()

                    await asyncio.sleep(0.01)

                    self.frame_no = frame_counter

                    if frame_counter == 0:
                        if self.plot_lib == "matplotlib":
                            if self.image_viewer.fig is None:
                                self._init_plt_figure()

                    self.image_viewer.show()

                    # set counter to frame number of image viewer (possibly not updated)
                    frame_counter = self.image_viewer.frame_no

                    # check if frame counter updated
                    if frame_counter != self.frame_no:
                        fps = frame_counter / (time.time() - start_time)
                        print(
                            f"frame {frame_counter} / {n_frames} ({fps:.2f} fps)",
                            end="\r",
                        )

                        if save and (len(images) < n_frames):
                            if self.plot_lib == "matplotlib":
                                # grab image from plt figure
                                image = matplotlib_figure_to_numpy(self.fig)
                            else:
                                image = np.array(self.image)
                            images.append(image)

                    if not self.headless:
                        # For opencv, show frame for 25 ms and check if "q" is pressed
                        if self.plot_lib == "opencv":
                            if self.image_viewer._cv2.waitKey(25) & 0xFF == ord("q"):
                                self.image_viewer.close()
                                return images
                            if self.image_viewer.has_been_closed():
                                return images
                        # For matplotlib, check if window has been closed
                        elif self.plot_lib == "matplotlib":
                            # ``time.sleep`` returns None, so it must not be part
                            # of the condition; also use the non-blocking sleep
                            # since this runs inside a coroutine.
                            await asyncio.sleep(0.025)
                            if self.image_viewer.has_been_closed():
                                return images

                    # For headless mode, check if all frames have been plotted
                    if self.headless:
                        if len(images) == n_frames:
                            return images

                # clear line, frame number
                print("\x1b[2K", end="\r")

                # only loop once if in headless mode
                if self.headless:
                    return images

        except KeyboardInterrupt:
            # save what has been shown so far before propagating the interrupt
            if save:
                if len(images) > 0:
                    self.save_video(images)
            raise
        finally:
            self.verbose = verbose

    def save_image(self, fig, path=None):
        """Save image to disk.

        Args:
            fig (fig object): matplotlib figure or PIL image.
            path (str, optional): path to save image to. Defaults to None, in
                which case the image is written to ``./figures`` with a name
                built from the data file name, frame number and config tag.

        Raises:
            ValueError: If ``fig`` is neither a matplotlib figure nor a PIL image.
        """
        if path is None:
            if self.config.plot.tag:
                tag = "_" + self.config.plot.tag
            else:
                tag = ""

            if self.frame_no is not None:
                filename = self.file_path.stem + "-" + str(self.frame_no) + tag
            else:
                filename = self.file_path.stem + tag

            ext = f".{self.config.plot.image_extension.lstrip('.')}"

            # Append the extension instead of ``with_suffix``: the latter would
            # cut off everything after the last dot in ``filename``.
            path = Path("./figures", filename + ext)

        path = Path(path)
        path.parent.mkdir(parents=True, exist_ok=True)

        if isinstance(fig, plt.Figure):
            fig.savefig(path, transparent=True)
        elif isinstance(fig, Image.Image):
            fig.save(path)
        else:
            raise ValueError(
                f"Figure is not PIL image or matplotlib figure object, got {type(fig)}"
            )

        if self.verbose:
            log.info(f"Image saved to {log.yellow(path)}")

    def save_video(self, images, path=None):
        """Save video to disk.

        Args:
            images (list): list of images (numpy arrays).
            path (str, optional): path to save video to. Defaults to None, in
                which case the video is written to ``./figures`` with a name
                built from the data file name and config tag.

        Raises:
            ValueError: If the images are not numpy arrays.
        """
        # Nothing was recorded, e.g. the window was closed right away.
        if len(images) == 0:
            log.warning("No images to save, skipping saving of video.")
            return

        if not isinstance(images[0], np.ndarray):
            raise ValueError("Images are not numpy arrays.")

        # tolerate a leading dot, consistent with ``save_image``
        extension = self.config.plot.video_extension.lstrip(".")

        if path is None:
            if self.config.plot.tag:
                tag = "_" + self.config.plot.tag
            else:
                tag = ""
            filename = self.file_path.stem + tag + "." + extension

            path = Path("./figures", filename)
            Path("./figures").mkdir(parents=True, exist_ok=True)

        fps = self.config.plot.fps

        if extension == "gif":
            save_to_gif(images, path, fps=fps)
        elif extension == "mp4":
            save_to_mp4(images, path, fps=fps)
        else:
            # do not claim success below when nothing has been written
            log.warning(f"Unsupported video extension `{extension}`, video not saved.")
            return

        if self.verbose:
            log.info(f"Video saved to {log.yellow(path)}")

    def __del__(self):
        # Best-effort cleanup: a destructor must never raise, and each resource
        # is released independently so one failure does not block the others.
        # ``getattr`` covers instances whose ``__init__`` did not complete.
        try:
            viewer = getattr(self, "image_viewer", None)
            if viewer is not None:
                viewer.close()
        except Exception:
            pass
        try:
            fig = getattr(self, "fig", None)
            if fig is not None:
                plt.close(fig)
        except Exception:
            pass
        try:
            file = getattr(self, "file", None)
            if file is not None:
                file.close()
        except Exception:
            pass