"""
H5 dataloader for loading images from zea datasets.
"""
import re
from itertools import product
from pathlib import Path
from typing import List
import numpy as np
from zea import log
from zea.data.datasets import Dataset, H5FileHandleCache, count_samples_per_directory
from zea.data.file import File
from zea.data.utils import json_dumps
from zea.io_lib import retry_on_io_error
from zea.utils import map_negative_indices
DEFAULT_NORMALIZATION_RANGE = (0, 1)
MAX_RETRY_ATTEMPTS = 3
INITIAL_RETRY_DELAY = 0.1
def generate_h5_indices(
file_paths: List[str],
file_shapes: list,
n_frames: int,
frame_index_stride: int,
key: str = "data/image",
initial_frame_axis: int = 0,
additional_axes_iter: List[int] | None = None,
sort_files: bool = True,
overlapping_blocks: bool = False,
limit_n_frames: int | None = None,
):
"""Generate indices for h5 files.
Generates a list of indices to extract images from hdf5 files. Length of this list
is the length of the extracted dataset.
Args:
file_paths (list): List of file paths.
file_shapes (list): List of file shapes.
n_frames (int): Number of frames to load from each hdf5 file.
frame_index_stride (int): Interval between frames to load.
key (str, optional): Key of hdf5 dataset to grab data from. Defaults to "data/image".
initial_frame_axis (int, optional): Axis to iterate over. Defaults to 0.
additional_axes_iter (list, optional): Additional axes to iterate over in the dataset.
Defaults to None.
sort_files (bool, optional): Sort files by number. Defaults to True.
        overlapping_blocks (bool, optional): If True, blocks overlap: take `n_frames`
            from the sequence, then advance the window by one frame instead of a full
            block. Defaults to False.
        limit_n_frames (int, optional): Limit the number of frames to load from each
            file. Only the first `limit_n_frames` frames in each file are considered.
            Defaults to None.

    Returns:
        list: List of tuples with indices to extract images from hdf5 files.
            Each tuple is ``(file_name, key, indices)``, with ``indices`` a list
            of ranges and slices.

    Example:
        .. code-block:: python

[
(
"/folder/path_to_file.hdf5",
"data/image",
[range(0, 1), slice(None, 256, None), slice(None, 256, None)],
),
(
"/folder/path_to_file.hdf5",
"data/image",
[range(1, 2), slice(None, 256, None), slice(None, 256, None)],
),
...,
]
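
    A minimal usage sketch (hypothetical paths and shapes; two files with eight
    256x256 frames each):

    .. code-block:: python

        indices = generate_h5_indices(
            file_paths=["/data/file_0001.hdf5", "/data/file_0002.hdf5"],
            file_shapes=[(8, 256, 256), (8, 256, 256)],
            n_frames=2,
            frame_index_stride=2,
        )
        # block_size = 2 * 2 = 4, so each file yields two blocks, e.g.:
        # ("/data/file_0001.hdf5", "data/image",
        #  [range(0, 4, 2), slice(None, 256, None), slice(None, 256, None)])
        assert len(indices) == 4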
"""
if not limit_n_frames:
limit_n_frames = np.inf
assert len(file_paths) == len(file_shapes), "file_paths and file_shapes must have same length"
if additional_axes_iter:
# cannot contain initial_frame_axis
assert initial_frame_axis not in additional_axes_iter, (
"initial_frame_axis cannot be in additional_axes_iter. "
"We are already iterating over that axis."
)
else:
additional_axes_iter = []
if sort_files:
try:
# this is like an np.argsort, returns the indices that would sort the array
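            # sort key: the second-to-last run of digits in the path, since for
            # "*.hdf5" paths the trailing "5" of the extension is matched by \d+ too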
indices_sorting_file_paths = sorted(
range(len(file_paths)),
key=lambda i: int(re.findall(r"\d+", file_paths[i])[-2]),
)
file_paths = [file_paths[i] for i in indices_sorting_file_paths]
file_shapes = [file_shapes[i] for i in indices_sorting_file_paths]
except Exception:
log.warning("H5Generator: Could not sort file_paths by number.")
# block size with stride included
block_size = n_frames * frame_index_stride
if not overlapping_blocks:
block_step_size = block_size
else:
        # consecutive blocks now start one frame apart
        # (with stride 1, adjacent blocks share n_frames - 1 frames)
        block_step_size = 1
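    # e.g. n_frames=2, frame_index_stride=2 -> block_size=4; windows start at
    # 0, 4, 8, ... by default, or at 0, 1, 2, ... with overlapping_blocks=True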
def axis_indices_files():
# For every file
for shape in file_shapes:
n_frames_in_file = shape[initial_frame_axis]
# Optionally limit frames to load from each file
n_frames_in_file = min(n_frames_in_file, limit_n_frames)
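            # each window selects n_frames frames, frame_index_stride apart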
indices = [
range(i, i + block_size, frame_index_stride)
for i in range(0, n_frames_in_file - block_size + 1, block_step_size)
]
yield [indices]
indices = []
skipped_files = 0
    for file, shape, axis_indices in zip(file_paths, file_shapes, axis_indices_files()):
# remove all the files that have empty list at initial_frame_axis
# this can happen if the file is too small to fit a block
if not axis_indices[0]: # initial_frame_axis is the first entry in axis_indices
skipped_files += 1
continue
if additional_axes_iter:
axis_indices += [range(shape[axis]) for axis in additional_axes_iter]
axis_indices = product(*axis_indices)
for axis_index in axis_indices:
full_indices = [slice(size) for size in shape]
for i, axis in enumerate([initial_frame_axis] + list(additional_axes_iter)):
full_indices[axis] = axis_index[i]
indices.append((file, key, full_indices))
if skipped_files > 0:
log.warning(
f"H5Generator: Skipping {skipped_files} files with not enough frames "
f"which is about {skipped_files / len(file_paths) * 100:.2f}% of the "
f"dataset. This can be fine if you expect set `n_frames` and "
"`frame_index_stride` to be high. Minimum frames in a file needs to be at "
f"least n_frames * frame_index_stride = {n_frames * frame_index_stride}. "
)
return indices
def _h5_reopen_on_io_error(
dataloader_obj: H5FileHandleCache,
file,
key,
indices,
retry_count,
**kwargs,
):
"""Reopen the file if an I/O error occurs.
Also removes the file from the cache and try to close file.
"""
    file_name = file.filename
try:
file_handle = dataloader_obj._file_handle_cache.pop(file_name, None)
if file_handle is not None:
file_handle.close()
except Exception:
pass
log.warning(
f"H5Generator: I/O error occurred while reading file {file_name}. "
f"Retry opening file. Retry count: {retry_count}."
)
class H5Generator(Dataset):
"""H5Generator class for iterating over hdf5 files in an advanced way.
Mostly used internally, you might want to use the Dataloader class instead.
Loads one item at a time. Always outputs numpy arrays.
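
    Example (a minimal sketch; file paths are hypothetical):

    .. code-block:: python

        gen = H5Generator(
            file_paths=["/data/train/a.hdf5", "/data/train/b.hdf5"],
            key="data/image",
            n_frames=2,
            shuffle=False,
        )
        for image in gen:
            ...  # numpy array with 2 frames stacked along `frame_axis`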
"""
def __init__(
self,
file_paths: List[str],
key: str = "data/image",
n_frames: int = 1,
shuffle: bool = True,
return_filename: bool = False,
limit_n_samples: int | None = None,
limit_n_frames: int | None = None,
seed: int | None = None,
cache: bool = False,
additional_axes_iter: tuple | None = None,
sort_files: bool = True,
overlapping_blocks: bool = False,
initial_frame_axis: int = 0,
insert_frame_axis: bool = True,
frame_index_stride: int = 1,
frame_axis: int = -1,
validate: bool = True,
**kwargs,
):
super().__init__(file_paths, key, validate=validate, **kwargs)
self.n_frames = int(n_frames)
self.frame_index_stride = int(frame_index_stride)
self.frame_axis = int(frame_axis)
self.insert_frame_axis = insert_frame_axis
self.initial_frame_axis = int(initial_frame_axis)
self.return_filename = return_filename
self.shuffle = shuffle
self.sort_files = sort_files
self.overlapping_blocks = overlapping_blocks
self.limit_n_samples = limit_n_samples
self.limit_n_frames = limit_n_frames
self.seed = seed
self.additional_axes_iter = additional_axes_iter or []
assert self.frame_index_stride > 0, (
f"`frame_index_stride` must be greater than 0, got {self.frame_index_stride}"
)
assert self.n_frames > 0, f"`n_frames` must be greater than 0, got {self.n_frames}"
# Extract some general information about the dataset
image_shapes = np.array(self.file_shapes)
image_shapes = np.delete(
image_shapes, (self.initial_frame_axis, *self.additional_axes_iter), axis=1
)
n_dims = len(image_shapes[0])
self.equal_file_shapes = np.all(image_shapes == image_shapes[0])
if not self.equal_file_shapes:
log.warning(
"H5Generator: Not all files have the same shape. "
"This can lead to issues when resizing images later...."
)
self.shape = np.array([None] * n_dims)
else:
self.shape = np.array(image_shapes[0])
if insert_frame_axis:
_frame_axis = map_negative_indices([frame_axis], len(self.shape) + 1)
self.shape = np.insert(self.shape, _frame_axis, 1)
if self.shape[frame_axis]:
self.shape[frame_axis] = self.shape[frame_axis] * n_frames
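        # e.g. per-file image shape (512, 512) with insert_frame_axis=True,
        # frame_axis=-1 and n_frames=2 -> sample shape (512, 512, 2)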
# Set random number generator
self.rng = np.random.default_rng(self.seed)
self.indices = generate_h5_indices(
file_paths=self.file_paths,
file_shapes=self.file_shapes,
n_frames=self.n_frames,
frame_index_stride=self.frame_index_stride,
key=self.key,
initial_frame_axis=self.initial_frame_axis,
additional_axes_iter=self.additional_axes_iter,
sort_files=self.sort_files,
overlapping_blocks=self.overlapping_blocks,
limit_n_frames=self.limit_n_frames,
)
if not self.shuffle:
log.warning("H5Generator: Not shuffling data.")
if limit_n_samples:
log.warning(
f"H5Generator: Limiting number of samples to {limit_n_samples} "
f"out of {len(self.indices)}"
)
self.indices = self.indices[:limit_n_samples]
self.shuffled_items = list(range(len(self.indices)))
# Retry count for I/O errors
self.retry_count = 0
# Create a cache for the data
self.cache = cache
self._data_cache = {}
def _get_single_item(self, idx):
# Check if the item is already in the cache
if self.cache and idx in self._data_cache:
return self._data_cache[idx]
# Get the data
file_name, key, indices = self.indices[idx]
file = self.get_file(file_name)
image = self.load(file, key, indices)
file_data = json_dumps(
{
"fullpath": file.filename,
"filename": file.stem,
"indices": indices,
}
)
if self.cache:
# Store the image and file data in the cache
self._data_cache[idx] = [image, file_data]
return image, file_data
def __getitem__(self, index):
image, file_data = self._get_single_item(self.shuffled_items[index])
if self.return_filename:
return image, file_data
else:
return image
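    # `load` is retried on I/O errors by `retry_on_io_error` (up to
    # MAX_RETRY_ATTEMPTS retries, starting at INITIAL_RETRY_DELAY seconds);
    # `_h5_reopen_on_io_error` evicts the stale file handle between attempts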
@retry_on_io_error(
max_retries=MAX_RETRY_ATTEMPTS,
initial_delay=INITIAL_RETRY_DELAY,
retry_action=_h5_reopen_on_io_error,
)
def load(self, file: File, key: str, indices: tuple | str):
"""Extract data from hdf5 file.
Args:
file_name (str): name of the file to extract image from.
key (str): key of the hdf5 dataset to grab data from.
indices (tuple): indices to extract image from (tuple of slices)
Returns:
np.ndarray: image extracted from hdf5 file and indexed by indices.
"""
try:
images = file.load_data(key, indices)
except (OSError, IOError):
# Let the decorator handle I/O errors
raise
except Exception as exc:
# For non-I/O errors, provide detailed context
raise ValueError(
f"Could not load image at index {indices} "
f"and file {file.name} of shape {file[key].shape}"
) from exc
# stack frames along frame_axis
if self.insert_frame_axis:
# move frames axis to self.frame_axis
initial_frame_axis = self.initial_frame_axis
if self.additional_axes_iter:
# offset initial_frame_axis if we have additional axes that are before
# the initial_frame_axis
additional_axes_before = sum(
axis < self.initial_frame_axis for axis in self.additional_axes_iter
)
initial_frame_axis = initial_frame_axis - additional_axes_before
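                # e.g. additional_axes_iter=[0], initial_frame_axis=1: axis 0 was
                # consumed by integer indexing, so the frame axis is now axis 0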
images = np.moveaxis(images, initial_frame_axis, self.frame_axis)
else:
# append frames to existing axis
images = np.concatenate(images, axis=self.frame_axis)
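            # e.g. images of shape (2, 256, 256) with frame_axis=-1 -> (256, 512)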
return images
def _shuffle(self):
self.rng.shuffle(self.shuffled_items)
log.info("H5Generator: Shuffled data.")
def __len__(self):
return len(self.indices)
def iterator(self):
"""Generator that yields images from the hdf5 files."""
if self.shuffle:
self._shuffle()
for idx in range(len(self)):
yield self[idx]
def __iter__(self):
"""
Generator that yields images from the hdf5 files.
"""
return self.iterator()
def __repr__(self):
return (
f"<{self.__class__.__name__} at 0x{id(self):x}: "
f"{len(self)} batches, n_frames={self.n_frames}, key='{self.key}', "
f"shuffle={self.shuffle}, file_paths={len(self.file_paths)}>"
)
def __str__(self):
return (
f"H5Generator with {len(self)} batches from {len(self.file_paths)} files "
f"(key='{self.key}')"
)
def summary(self):
"""Return a string with dataset statistics and per-directory breakdown."""
total_samples = len(self.indices)
file_names = [idx[0] for idx in self.indices]
# Try to infer directories from file_names
directories = sorted({str(Path(f).parent) for f in file_names})
samples_per_dir = count_samples_per_directory(file_names, directories)
parts = [f"H5Generator with {total_samples} total samples:"]
for dir_path, count in samples_per_dir.items():
percentage = (count / total_samples) * 100 if total_samples else 0
parts.append(f" {dir_path}: {count} samples ({percentage:.1f}%)")
print("\n".join(parts))