# Source code for zea.models.echonet

"""EchoNet-Dynamic segmentation model for cardiac ultrasound segmentation.

The link further below does not seem to work; the following repository is slightly
different but still contains useful information:
https://github.com/bryanhe/dynamic
"""

from pathlib import Path

import keras
import wget
from keras import backend, ops

from zea.backend import _import_tf
from zea.internal.registry import model_registry
from zea.models.base import BaseModel
from zea.models.preset_utils import get_preset_loader, register_presets
from zea.models.presets import echonet_dynamic_presets

# Side length (in pixels) of the square images the network runs on; `call`
# resizes inputs to INFERENCE_SIZE x INFERENCE_SIZE and resizes the output back.
INFERENCE_SIZE = 112

# Both original PyTorch checkpoints live under the same GitHub release.
_ECHONET_RELEASE_URL = "https://github.com/douyang/EchoNetDynamic/releases/download/v1.0.0"

# Original segmentation checkpoint (DeepLabV3 / ResNet-50, per the filename).
SEGMENTATION_WEIGHTS_URL = f"{_ECHONET_RELEASE_URL}/deeplabv3_resnet50_random.pt"
# Original ejection-fraction checkpoint (R(2+1)D-18, per the filename).
EJECTION_FRACTION_WEIGHTS_URL = f"{_ECHONET_RELEASE_URL}/r2plus1d_18_32_2_pretrained.pt"

# TensorFlow module handle; expected to be None when TensorFlow is not installed
# (EchoNetDynamic.__init__ refuses to run in that case). TensorFlow is needed on
# both the TensorFlow and JAX backends because the network is loaded from a
# TensorFlow SavedModel.
tf = _import_tf()


@model_registry(name="echonet-dynamic")
class EchoNetDynamic(BaseModel):
    """EchoNet-Dynamic segmentation model for cardiac ultrasound segmentation.

    Original paper and code: https://echonet.github.io/dynamic/

    This class extracts useful parts of the original code and wraps them in an
    easy-to-use class. The network is stored as a TensorFlow SavedModel; on the
    JAX backend it is converted to a JAX layer when the model is built.

    Preprocessing should normalize the input images with mean and standard deviation.
    """

    def __init__(self, **kwargs):
        """Create an EchoNet-Dynamic model without weights.

        Weights are not loaded here; use ``EchoNetDynamic.from_preset()`` to get a
        usable model.

        Raises:
            NotImplementedError: If the Keras backend is neither TensorFlow nor JAX.
            ImportError: If TensorFlow is not installed. It is required on both
                backends because the network is a TensorFlow SavedModel.
        """
        if backend.backend() not in ["tensorflow", "jax"]:
            raise NotImplementedError(
                "EchoNetDynamic is only currently supported with the TensorFlow or Jax backend."
            )
        # Explicit raise (not `assert`) so the check survives `python -O`.
        if tf is None:
            raise ImportError(
                "TensorFlow is not installed. Please install TensorFlow to use EchoNetDynamic. This is "
                "required even if you are using the Jax backend, the model is built using TensorFlow."
            )
        super().__init__(**kwargs)
        # SavedModel files fetched from the preset by `custom_load_weights`.
        self.download_files = [
            "variables/variables.data-00000-of-00001",
            "variables/variables.index",
            "saved_model.pb",
            "fingerprint.pb",
        ]
        # Loaded by `custom_load_weights`; replaced by a JaxLayer on the JAX backend.
        self.network = None

    def _check_network_loaded(self):
        """Raise a descriptive error if no network weights have been loaded yet."""
        if self.network is None:
            raise ValueError(
                "Please load model using `EchoNetDynamic.from_preset()` before calling."
            )

    def build(self, input_shape):
        """Builds the network.

        On the JAX backend this converts the loaded TensorFlow network to a JAX
        layer (see `maybe_convert_to_jax`); on TensorFlow nothing needs to happen.
        """
        self.maybe_convert_to_jax(input_shape)

    def maybe_convert_to_jax(self, input_shape):
        """Converts the network to Jax if backend is Jax.

        The TensorFlow network is traced with a dummy input. The trace must use the
        shape the network actually receives in `call`: inputs are resized to
        ``INFERENCE_SIZE x INFERENCE_SIZE`` and tiled to 3 channels before they reach
        the network. Only the batch dimension of `input_shape` is therefore used.

        Args:
            input_shape: Shape of the model input, ``(B, H, W, C)``.

        Raises:
            ValueError: On the JAX backend, if no weights have been loaded.
        """
        if backend.backend() != "jax":
            return
        # Fail with the descriptive error instead of inside `tf.function(None)`.
        self._check_network_loaded()
        # Already converted (e.g. a repeated build): nothing left to trace.
        if isinstance(self.network, keras.layers.JaxLayer):
            return

        from zea.backend import tf2jax

        network_input_shape = (input_shape[0], INFERENCE_SIZE, INFERENCE_SIZE, 3)
        inputs = ops.zeros(network_input_shape)
        jax_func, jax_params = tf2jax.convert(tf.function(self.network), inputs)

        def call_fn(params, state, rng, inputs, training):
            # The converted TF variables are handed to the JaxLayer as `state`.
            return jax_func(state, inputs)

        self.network = keras.layers.JaxLayer(call_fn, state=jax_params)

    def call(self, inputs):
        """Segment the input image.

        Args:
            inputs: Batch of images with shape ``(B, H, W, C)``, where ``C`` is 1 or 3.
                Single-channel inputs are tiled to 3 channels.

        Returns:
            The network's segmentation output, resized back to the input's ``(H, W)``.

        Raises:
            ValueError: If no weights are loaded, or `inputs` does not have 4
                dimensions with 1 or 3 channels.
            NotImplementedError: If the backend is neither TensorFlow nor JAX.
        """
        self._check_network_loaded()
        # Explicit raises (not `assert`) so validation survives `python -O`.
        if inputs.ndim != 4:
            raise ValueError(
                f"Input should have 4 dimensions (B, H, W, C), but has {inputs.ndim}."
            )
        if inputs.shape[-1] not in (1, 3):
            raise ValueError(
                f"Input should have 1 or 3 channels, but has {inputs.shape[-1]}."
            )

        # resize image to 112x112 (INFERENCE_SIZE); remember the size to undo it later
        original_size = ops.shape(inputs)[1:3]
        inputs = ops.image.resize(inputs, [INFERENCE_SIZE, INFERENCE_SIZE])
        if inputs.shape[-1] != 3:
            inputs = ops.tile(inputs, [1, 1, 1, 3])

        if backend.backend() == "tensorflow":
            # TFSMLayer returns the serving signature's output dict.
            output = self.network(inputs)["segmentation"]
        elif backend.backend() == "jax":
            output = self.network(inputs)
        else:
            raise NotImplementedError(
                f"{self.__class__.__name__} is only currently supported with the "
                f"TensorFlow or Jax backend. You are using {backend.backend()}."
            )

        # resize output to original size
        return ops.image.resize(output, original_size)

    def _load_layer(self, path: Path | str):
        """Load the SavedModel directory at `path` in a form usable by the backend.

        TensorFlow gets a `keras.layers.TFSMLayer` on the ``serving_default``
        endpoint. JAX gets the raw `tf.saved_model.load` object, which `build`
        later converts.
        """
        # TFSMLayer stores the path in its config, so pass a plain string.
        path = str(path)
        if backend.backend() == "tensorflow":
            return keras.layers.TFSMLayer(path, call_endpoint="serving_default")
        elif backend.backend() == "jax":
            return tf.saved_model.load(path)
        else:
            raise NotImplementedError(
                f"{self.__class__.__name__} is only currently supported with the "
                f"TensorFlow or Jax backend. You are using {backend.backend()}."
            )

    def custom_load_weights(self, preset, **kwargs):
        """Load the weights for the segmentation model.

        Fetches every file in `self.download_files` through the preset loader, then
        loads the SavedModel from the directory that contains ``saved_model.pb``.

        Args:
            preset: Preset identifier passed to `get_preset_loader`.
            **kwargs: Unused; accepted for interface compatibility.
        """
        loader = get_preset_loader(preset)
        local_files = {name: Path(loader.get_file(name)) for name in self.download_files}
        # Take the SavedModel root from saved_model.pb explicitly instead of from
        # whichever file was fetched last (files under variables/ sit one level deeper).
        saved_model_dir = local_files["saved_model.pb"].parent
        self.network = self._load_layer(saved_model_dir)

    def _download_original_weights(self, weights_folder=None, *, url=SEGMENTATION_WEIGHTS_URL):
        """Download the original weights from the EchoNet GitHub repository.

        Args:
            weights_folder: Folder to store the weights in. Defaults to
                ``./echonet_weights``; created (with parents) if missing.
            url: URL of the ``.pt`` file to fetch. Defaults to the segmentation
                weights; pass ``EJECTION_FRACTION_WEIGHTS_URL`` for the
                ejection-fraction model.

        Returns:
            Path to the local weights file. The download is skipped if it already exists.

        Raises:
            NotADirectoryError: If `weights_folder` exists but is not a directory.
            RuntimeError: If the download did not produce the expected file.
        """
        weights_folder = Path("./echonet_weights" if weights_folder is None else weights_folder)
        if not weights_folder.exists():
            print(f"Creating folder at {weights_folder} to store weights")
            weights_folder.mkdir(parents=True, exist_ok=True)
        if not weights_folder.is_dir():
            raise NotADirectoryError(
                f"weights_folder {weights_folder} is not a directory. "
                "Please specify the path to the folder containing the weights"
            )

        file_path = weights_folder / Path(url).name
        if file_path.is_file():
            print(f"EchoNet weights found in {file_path}")
            return file_path

        print(f"Downloading EchoNet weights from {url} to {file_path}")
        filename = wget.download(url, out=str(weights_folder))
        if Path(filename).name != file_path.name:
            raise RuntimeError(
                f"Downloaded file {Path(filename).name} does not match expected filename "
                f"{file_path.name}"
            )
        if not file_path.is_file():
            raise RuntimeError(
                f"Expected weights at {file_path} after download, but no file was found."
            )
        return file_path
# Register the EchoNet-Dynamic presets (from zea.models.presets) for this class;
# presumably this is what lets `EchoNetDynamic.from_preset()` resolve them.
register_presets(echonet_dynamic_presets, EchoNetDynamic)