Source code for zea.backend.tensorflow.dataloader

"""HDF5 Tensorflow dataloader.

Convenient way of loading data from hdf5 files in an ML pipeline.
"""

from functools import partial
from typing import List

import keras
import tensorflow as tf
from keras.src.trainers.data_adapters import TFDatasetAdapter

from zea.data.dataloader import H5Generator
from zea.data.layers import Resizer
from zea.utils import find_methods_with_return_type, translate

METHODS_THAT_RETURN_DATASET = find_methods_with_return_type(tf.data.Dataset, "DatasetV2")


class TFDatasetToKeras(TFDatasetAdapter):
    """Tensorflow Dataset to Keras Dataset.

    This class wraps a tf.data.Dataset object and allows it to be used with Keras backends.
    """

    def __init__(self, dataset):
        super().__init__(dataset)

    def __iter__(self):
        backend = keras.backend.backend()
        if backend == "tensorflow":
            return iter(self.get_tf_dataset())
        elif backend == "jax":
            return self.get_jax_iterator()
        elif backend == "torch":
            return iter(self.get_torch_dataloader())
        elif backend == "numpy":
            return self.get_numpy_iterator()
        else:
            raise ValueError(
                f"Unsupported backend: {backend}. "
                "Please use one of the following: 'tensorflow', 'jax', 'torch', 'numpy'."
            )

    def __len__(self):
        return self.num_batches

    def __getattr__(self, name):
        # Delegate all calls to self._dataset, and wrap the result in TFDatasetToKeras
        if name in METHODS_THAT_RETURN_DATASET:

            def method(*args, **kwargs):
                result = getattr(self._dataset, name)(*args, **kwargs)
                return TFDatasetToKeras(result)

            return method
        else:
            return getattr(self._dataset, name)

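# Hedged usage sketch (illustrative, not part of the original module): wrapping an
# arbitrary ``tf.data.Dataset`` in TFDatasetToKeras so that iteration yields tensors
# native to the active Keras backend. The toy dataset below is an assumption made
# purely for illustration.
def _example_tf_dataset_to_keras():
    # Build a small tf.data pipeline of float batches
    toy = tf.data.Dataset.range(8).map(lambda x: tf.cast(x, tf.float32)).batch(4)
    wrapped = TFDatasetToKeras(toy)

    print(len(wrapped))  # number of batches (2), via TFDatasetAdapter.num_batches

    # Iteration dispatches on keras.backend.backend(): tensorflow / jax / torch / numpy
    for batch in wrapped:
        print(batch)

    # tf.data methods that return a Dataset stay wrapped, thanks to __getattr__
    doubled = wrapped.map(lambda x: x * 2)
    return doubled
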
class H5GeneratorTF(H5Generator):
    """Adds a tensorflow dtype property and output_signature to the H5Generator class."""

    @property
    def tensorflow_dtype(self):
        """Extracts one image from the dataset to get the dtype.

        Converts it to a tensorflow dtype.
        """
        out = next(self.iterator())
        if self.return_filename:
            out = out[0]
        dtype = out.dtype
        if "float" in str(dtype):
            dtype = tf.float32
        elif "complex" in str(dtype):
            dtype = tf.complex64
        elif "uint8" in str(dtype):
            dtype = tf.uint8
        else:
            raise ValueError(f"Unsupported dtype: {dtype}")
        return dtype

    @property
    def output_signature(self):
        """Get the output signature of the generator as a tensorflow `TensorSpec`.

        This is useful for creating a `tf.data.Dataset` from the generator.
        """
        output_signature = tf.TensorSpec(shape=self.shape, dtype=self.tensorflow_dtype)
        if self.return_filename:
            output_signature = (
                output_signature,
                tf.TensorSpec(shape=(), dtype=tf.string),
            )
        return output_signature

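# Hedged sketch (illustrative, not part of the original module): using
# ``output_signature`` to build a ``tf.data.Dataset`` directly from H5GeneratorTF,
# which is what ``make_dataloader`` does internally. The file path and key below
# are hypothetical placeholders.
def _example_h5_generator_tf():
    generator = H5GeneratorTF(
        ["/path/to/file.hdf5"],  # hypothetical path
        "data/image",
        n_frames=1,
    )
    dataset = tf.data.Dataset.from_generator(
        generator, output_signature=generator.output_signature
    )
    return dataset
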
def _assert_image_range(images, image_range):
    # Check if there are outliers in the image range
    minval = tf.reduce_min(images)
    maxval = tf.reduce_max(images)
    _msg = f"Image range {image_range} is not in the range of the data {minval} - {maxval}"
    tf.debugging.assert_greater_equal(
        minval,
        tf.cast(image_range[0], minval.dtype),
        message=_msg,
    )
    tf.debugging.assert_less_equal(
        maxval,
        tf.cast(image_range[1], maxval.dtype),
        message=_msg,
    )
    return images

def make_dataloader(
    file_paths: List[str],
    batch_size: int,
    key: str = "data/image",
    n_frames: int = 1,
    shuffle: bool = True,
    return_filename: bool = False,
    limit_n_samples: int | None = None,
    limit_n_frames: int | None = None,
    seed: int | None = None,
    drop_remainder: bool = False,
    resize_type: str | None = None,
    resize_axes: tuple | None = None,
    resize_kwargs: dict | None = None,
    image_size: tuple | None = None,
    image_range: tuple | None = None,
    normalization_range: tuple | None = None,
    dataset_repetitions: int | None = None,
    cache: bool = False,
    additional_axes_iter: tuple | None = None,
    sort_files: bool = True,
    overlapping_blocks: bool = False,
    augmentation: callable = None,
    assert_image_range: bool = True,
    clip_image_range: bool = False,
    initial_frame_axis: int = 0,
    insert_frame_axis: bool = True,
    frame_index_stride: int = 1,
    frame_axis: int = -1,
    validate: bool = True,
    prefetch: bool = True,
    shard_index: int | None = None,
    num_shards: int = 1,
    wrap_in_keras: bool = True,
    **kwargs,
) -> tf.data.Dataset:
    """Creates a ``tf.data.Dataset`` from .hdf5 files in the specified directory or directories.

    Mimics the native TF function ``tf.keras.utils.image_dataset_from_directory``
    but for .hdf5 files.

    Saves a dataset_info.yaml file in the directory with information about the dataset.
    This file is used to load the dataset later on, which speeds up the initial loading
    of the dataset for very large datasets.

    Does the following in order to load a dataset:

    - Find all .hdf5 files in the directory (or directories)
    - Load the data from each file using the specified key
    - Apply the following transformations in order (if specified):
        - limit_n_samples
        - cache
        - shuffle
        - shard
        - add channel dim
        - assert_image_range
        - clip_image_range
        - resize
        - repeat
        - batch
        - normalize
        - augmentation
        - prefetch
        - tf -> keras tensor

    Args:
        file_paths (str or list): Path(s) to the folder(s) or h5 file(s) to load.
        batch_size (int): Batch the dataset.
        key (str): The key to access the HDF5 dataset.
        n_frames (int, optional): Number of frames to load from each hdf5 file.
            Defaults to 1. These frames are stacked along the last axis (channel).
        shuffle (bool, optional): Shuffle dataset.
        return_filename (bool, optional): Return file name with image. Defaults to False.
        limit_n_samples (int, optional): Take only a subset of samples.
            Useful for debugging. Defaults to None.
        limit_n_frames (int, optional): Limit the number of frames to load from each
            file. Only the first ``limit_n_frames`` frames of each file will be used.
            Defaults to None.
        seed (int, optional): Random seed of shuffle.
        drop_remainder (bool, optional): Whether the last batch should be dropped.
        resize_type (str, optional): Resize type. Defaults to 'center_crop'.
            Can be 'center_crop', 'random_crop' or 'resize'.
        resize_axes (tuple, optional): Axes to resize along. Should be of length 2
            (height, width) as the resizing function only supports 2D resizing / cropping.
            Should only be set when your data is more than (h, w, c). Defaults to None.
            Note that it considers the axes after inserting the frame axis.
        resize_kwargs (dict, optional): Kwargs for the resize function.
        image_size (tuple, optional): Resize images to image_size. Should be of length
            two (height, width). Defaults to None.
        image_range (tuple, optional): Image range. Defaults to (0, 255). Will always
            translate from the specified image range to the normalization range.
            If image_range is set to None, no normalization will be done.
            Note that it does not clip to the image range, so values outside the image
            range will be outside the normalization range!
        normalization_range (tuple, optional): Normalization range. Defaults to (0, 1).
            See image_range for more info!
        dataset_repetitions (int, optional): Repeat dataset. Note that this happens
            after sharding, so the shard will be repeated. Defaults to None.
        cache (bool, optional): Cache dataset to RAM.
        additional_axes_iter (tuple, optional): Additional axes to iterate over in the
            dataset. Defaults to None, in which case we only iterate over the first
            axis (we assume it contains the frames).
        sort_files (bool, optional): Sort files by number. Defaults to True.
        overlapping_blocks (bool, optional): If True, blocks overlap by n_frames - 1.
            Defaults to False. Has no effect if n_frames = 1.
        augmentation (keras.Sequential, optional): Keras augmentation layer.
        assert_image_range (bool, optional): Assert that the image range is within the
            specified image range. Defaults to True.
        clip_image_range (bool, optional): Clip the image range to the specified
            image range. Defaults to False.
        initial_frame_axis (int, optional): Axis where in the files the frames are
            stored. Defaults to 0.
        insert_frame_axis (bool, optional): If True, a new dimension to stack frames
            along will be created. Defaults to True. If False, frames will be stacked
            along an existing dimension (frame_axis).
        frame_index_stride (int, optional): Interval between frames to load.
            Defaults to 1. If n_frames > 1, a lower frame rate can be simulated.
        frame_axis (int, optional): Dimension to stack frames along. Defaults to -1.
            If insert_frame_axis is True, this will be the new dimension to stack
            frames along.
        validate (bool, optional): Validate if the dataset adheres to the zea format.
            Defaults to True.
        prefetch (bool, optional): Prefetch the dataset. Defaults to True.
        shard_index (int, optional): Index of the part of the dataset that should be
            selected. Can only be used if num_shards is specified. Defaults to None.
            See for info: https://www.tensorflow.org/api_docs/python/tf/data/Dataset#shard
        num_shards (int, optional): This is used to divide the dataset into
            ``num_shards`` parts. Sharding happens before all other operations.
            Defaults to 1.
            See for info: https://www.tensorflow.org/api_docs/python/tf/data/Dataset#shard
        wrap_in_keras (bool, optional): Wrap dataset in TFDatasetToKeras. Defaults to
            True. If True, the dataset will be converted so that it returns
            backend-native tensors.

    Returns:
        tf.data.Dataset: The constructed dataset.
    """
    # Setup
    if normalization_range is not None:
        assert image_range is not None, (
            "If normalization_range is set, image_range must be set as well."
        )

    resize_kwargs = resize_kwargs or {}

    if num_shards > 1:
        assert shard_index is not None, "shard_index must be specified"
        assert shard_index < num_shards, "shard_index must be less than num_shards"
        assert shard_index >= 0, "shard_index must be greater than or equal to 0"

    image_extractor = H5GeneratorTF(
        file_paths,
        key,
        n_frames=n_frames,
        frame_index_stride=frame_index_stride,
        frame_axis=frame_axis,
        insert_frame_axis=insert_frame_axis,
        initial_frame_axis=initial_frame_axis,
        return_filename=return_filename,
        shuffle=shuffle,
        sort_files=sort_files,
        overlapping_blocks=overlapping_blocks,
        limit_n_samples=limit_n_samples,
        limit_n_frames=limit_n_frames,
        seed=seed,
        additional_axes_iter=additional_axes_iter,
        cache=cache,
        validate=validate,
        **kwargs,
    )

    # Create dataset
    dataset = tf.data.Dataset.from_generator(
        image_extractor, output_signature=image_extractor.output_signature
    )

    # Assert cardinality
    dataset = dataset.apply(tf.data.experimental.assert_cardinality(len(image_extractor)))

    # Shard dataset
    if num_shards > 1:
        dataset = dataset.shard(num_shards, shard_index)

    # Define helper function to apply map function to dataset
    def dataset_map(dataset, func):
        """Does not apply func to filename."""
        if return_filename:
            return dataset.map(
                lambda x, filename: (func(x), filename),
                num_parallel_calls=tf.data.AUTOTUNE,
            )
        else:
            return dataset.map(func, num_parallel_calls=tf.data.AUTOTUNE)

    # add channel dim
    if len(image_extractor.shape) != 3:
        dataset = dataset_map(dataset, lambda x: tf.expand_dims(x, axis=-1))

    # Clip to image range
    if clip_image_range and image_range is not None:
        dataset = dataset_map(
            dataset,
            partial(
                tf.clip_by_value,
                clip_value_min=image_range[0],
                clip_value_max=image_range[1],
            ),
        )

    # Check if there are outliers in the image range
    if assert_image_range and image_range is not None:
        dataset = dataset_map(dataset, partial(_assert_image_range, image_range=image_range))

    if image_size or resize_type:
        if frame_axis != -1:
            assert resize_axes is not None, (
                "Resizing only works with frame_axis = -1. Alternatively, "
                "you can specify resize_axes."
            )
        # Let resizer handle the assertions.
        resizer = Resizer(
            image_size=image_size,
            resize_type=resize_type,
            resize_axes=resize_axes,
            seed=seed,
            **resize_kwargs,
        )
        dataset = dataset_map(dataset, resizer)

    # repeat dataset if needed (used for smaller datasets)
    if dataset_repetitions:
        dataset = dataset.repeat(dataset_repetitions)

    # batch
    if batch_size:
        dataset = dataset.batch(batch_size, drop_remainder=drop_remainder)

    # normalize
    if normalization_range is not None:
        dataset = dataset_map(
            dataset,
            lambda x: translate(x, image_range, normalization_range),
        )

    # augmentation
    if augmentation is not None:
        dataset = dataset_map(dataset, augmentation)

    # prefetch
    if prefetch:
        dataset = dataset.prefetch(tf.data.AUTOTUNE)

    if wrap_in_keras:
        dataset = TFDatasetToKeras(dataset)

    return dataset

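# Hedged usage sketch (illustrative, not part of the original module): a typical call
# to ``make_dataloader``. The folder path, image size, and ranges below are assumptions
# chosen purely for illustration.
def _example_make_dataloader():
    dataset = make_dataloader(
        file_paths=["/path/to/hdf5_folder"],  # hypothetical directory of .hdf5 files
        batch_size=8,
        key="data/image",
        n_frames=1,
        shuffle=True,
        image_size=(256, 256),
        resize_type="center_crop",
        image_range=(0, 255),
        normalization_range=(0, 1),
        seed=42,
    )
    # Batches are backend-native tensors (wrap_in_keras=True by default)
    for images in dataset:
        print(images.shape)
        break
    return dataset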