Using zea.Models: a UNet example

In this notebook, we demonstrate how to use the zea.Models interface with a popular deep learning architecture: the UNet. We’ll use a pretrained UNet to inpaint missing regions in ultrasound images, and visualize the results. This workflow can be adapted for other tasks and models in the zea toolbox.

[ ]:
%%capture
%pip install zea
[2]:
import os

os.environ["KERAS_BACKEND"] = "jax"
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
[3]:
import matplotlib.pyplot as plt
from keras import ops

from zea import init_device, log
from zea.backend.tensorflow.dataloader import make_dataloader
from zea.models.unet import UNet
from zea.models.lpips import LPIPS
from zea.agent.masks import random_uniform_lines
from zea.visualize import plot_image_grid, set_mpl_style
zea: Using backend 'jax'

We initialize with init_device, which picks the best available device (GPU if available). Optionally, we also set the matplotlib style for plotting.

[4]:
device = init_device(verbose=False)
set_mpl_style()

Load Data

We load a small batch from the CAMUS validation dataset hosted on Hugging Face Hub.

[5]:
n_imgs = 8

val_dataset = make_dataloader(
    "hf://zeahub/camus-sample/val",
    key="data/image",
    batch_size=n_imgs,
    shuffle=True,
    image_size=[128, 128],
    resize_type="resize",
    image_range=[-60, 0],
    normalization_range=[-1, 1],
    seed=42,
)
batch = next(iter(val_dataset))
batch = ops.clip(batch, -1, 1)
zea: Using pregenerated dataset info file: /home/devcontainer15/.cache/zea/huggingface/datasets/datasets--zeahub--camus-sample/snapshots/617cf91a1267b5ffbcfafe9bebf0813c7cee8493/val/dataset_info.yaml ...
zea: ...for reading file paths in /home/devcontainer15/.cache/zea/huggingface/datasets/datasets--zeahub--camus-sample/snapshots/617cf91a1267b5ffbcfafe9bebf0813c7cee8493/val
zea: Dataset was validated on June 15, 2025
zea: Remove /home/devcontainer15/.cache/zea/huggingface/datasets/datasets--zeahub--camus-sample/snapshots/617cf91a1267b5ffbcfafe9bebf0813c7cee8493/val/validated.flag if you want to redo validation.
zea: WARNING H5Generator: Not all files have the same shape. This can lead to issues when resizing images later....
zea: H5Generator: Shuffled data.
zea: H5Generator: Shuffled data.
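
As a quick sanity check, we can inspect the batch shape and value range to confirm the images were resized to 128×128 and normalized to [-1, 1] (a minimal sketch; the exact shape depends on the dataloader settings above).

[ ]:
log.info(f"Batch shape: {batch.shape}")
log.info(f"Value range: [{float(ops.min(batch)):.2f}, {float(ops.max(batch)):.2f}]")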

Load UNet Model

We use a pretrained UNet model from zea for inpainting.

[6]:
presets = list(UNet.presets.keys())
log.info(f"Available built-in zea presets for UNet: {presets}")

model = UNet.from_preset("unet-echonet-inpainter")
zea: Available built-in zea presets for UNet: ['unet-echonet-inpainter']
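
Since zea is built on Keras, the loaded model can be inspected with standard Keras utilities (a minimal sketch, assuming the preset returns a regular keras.Model):

[ ]:
# Assumes the zea UNet behaves like a standard keras.Model
log.info(f"Number of parameters: {model.count_params():,}")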

Simulate Missing Data

We simulate missing data by masking out random columns in each image (here, 75% missing). This is a common scenario in cognitive ultrasound, where some scanlines may be missing (i.e., not acquired) or corrupted.

[7]:
n_columns = 128  # batch.shape[2]
mask = random_uniform_lines(n_columns // 4, n_columns, n_imgs)  # keep 25% of scanlines
corrupted = batch * ops.cast(mask[:, None, :, None], batch.dtype)  # broadcast mask over rows and channels
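
To confirm the corruption level, we can check the fraction of scanlines the mask keeps (a minimal sketch, assuming random_uniform_lines marks acquired scanlines with 1, consistent with the multiplication above):

[ ]:
kept = float(ops.mean(ops.cast(mask, "float32")))
log.info(f"Scanlines kept: {kept:.0%} (i.e. {1 - kept:.0%} missing)")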

Inpaint with UNet

We use the UNet to inpaint the missing regions.

[8]:
inpainted = model(corrupted)
inpainted = ops.clip(inpainted, -1, 1)

Evaluate Perceptual Similarity

We use the LPIPS metric to evaluate perceptual similarity between the ground truth and inpainted images. For a more detailed example of this metric, see this notebook.

[9]:
lpips = LPIPS.from_preset("lpips")
lpips_scores = lpips([batch, inpainted])  # compare ground truth to inpainted images
lpips_scores = ops.convert_to_numpy(lpips_scores)
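
To summarize performance over the whole batch, we can report the mean and standard deviation of the LPIPS scores (lower is better); a minimal sketch using the numpy scores computed above:

[ ]:
log.info(f"LPIPS over batch: {lpips_scores.mean():.4f} ± {lpips_scores.std():.4f}")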

Visualization

We plot the ground truth, corrupted, inpainted, and error images, with the LPIPS score shown on each inpainted image. Note that this model was trained on the EchoNet-Dynamic dataset, whereas we are now testing on the CAMUS dataset.

[10]:
error = ops.abs(batch - inpainted)
imgs = ops.concatenate([batch, corrupted, inpainted, error], axis=0)
imgs = ops.convert_to_numpy(imgs)

cmaps = ["gray"] * (3 * n_imgs) + ["viridis"] * n_imgs

fig, _ = plot_image_grid(
    imgs,
    vmin=-1,
    vmax=1,
    ncols=n_imgs,
    remove_axis=False,
    cmap=cmaps,
    figsize=(n_imgs * 2, 6),
)

titles = ["Ground Truth", "Corrupted", "Inpainted", "Error"]
for i, ax in enumerate(fig.axes[: len(titles) * n_imgs]):
    if i % n_imgs == 0:
        ax.set_ylabel(titles[i // n_imgs])

# Show LPIPS score on each inpainted image
for ax, lpips_score in zip(fig.axes[n_imgs * 2 : 3 * n_imgs], lpips_scores):
    ax.text(
        0.95,
        0.95,
        f"LPIPS: {float(lpips_score):.4f}",
        ha="right",
        va="top",
        transform=ax.transAxes,
        fontsize=8,
        color="yellow",
    )
plt.show()
[Image grid: Ground Truth, Corrupted, Inpainted, and Error rows, with LPIPS scores annotated on the inpainted images]

You can try other UNet presets or experiment with different masking strategies to explore the capabilities of zea.Models!
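
For instance, instead of uniformly random scanlines, you could hand-roll a contiguous block mask and rerun the inpainting step above. The sketch below keeps only the central half of the scanlines; block_mask is a hypothetical strategy built with numpy, not a function from zea.agent.masks.

[ ]:
import numpy as np

# Hypothetical masking strategy: keep only the central 50% of scanlines
block_mask = np.zeros((n_imgs, n_columns), dtype="float32")
block_mask[:, n_columns // 4 : 3 * n_columns // 4] = 1.0

corrupted_block = batch * ops.cast(block_mask[:, None, :, None], batch.dtype)
inpainted_block = ops.clip(model(corrupted_block), -1, 1)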