zea.models.diffusion

Diffusion models

Classes

DPS(diffusion_model, operator[, disable_jit])

Diffusion Posterior Sampling guidance.

DiffusionGuidance(diffusion_model, operator[, disable_jit])

Base class for diffusion guidance methods.

DiffusionModel(*args, **kwargs)

Implementation of a diffusion generative model.

class zea.models.diffusion.DPS(diffusion_model, operator, disable_jit=False)

Bases: DiffusionGuidance

Diffusion Posterior Sampling guidance.

compute_error(noisy_images, measurements, noise_rates, signal_rates, omega, **kwargs)

Compute the measurement error for diffusion posterior sampling.

Parameters:
  • noisy_images – Noisy images.

  • measurements – Target measurements.

  • noise_rates – Current noise rates.

  • signal_rates – Current signal rates.

  • omega – Weight for the measurement error.

  • **kwargs – Additional arguments for the forward operator (the operator passed to the constructor).

Returns:

Tuple of (measurement_error, (pred_noises, pred_images))
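A minimal sketch of the value compute_error returns, assuming the predicted clean image is the Tweedie estimate x_0 = (x_t - noise_rate * noise) / signal_rate and that the error is omega times the squared L2 norm of the residual y - A(x_0). The predict_noise and operator callables are hypothetical stand-ins, not zea objects, and zea's exact error norm may differ. In DPS, the gradient of this error with respect to the noisy images steers each sampling step (presumably via the autograd function built by setup); the sketch computes only the forward value.

```python
from typing import Callable, List, Tuple

Vector = List[float]


def compute_error_sketch(
    noisy_images: Vector,
    measurements: Vector,
    noise_rate: float,
    signal_rate: float,
    omega: float,
    predict_noise: Callable[[Vector], Vector],  # hypothetical stand-in for the score network
    operator: Callable[[Vector], Vector],  # hypothetical forward operator A
) -> Tuple[float, Tuple[Vector, Vector]]:
    # Predict the noise component of x_t.
    pred_noises = predict_noise(noisy_images)
    # Tweedie estimate of the clean image: x_0 = (x_t - noise_rate * noise) / signal_rate.
    pred_images = [(x - noise_rate * e) / signal_rate for x, e in zip(noisy_images, pred_noises)]
    # Residual between the measurements and the simulated measurements A(x_0).
    residual = [y - a for y, a in zip(measurements, operator(pred_images))]
    # Assumed error form: omega times the squared L2 norm of the residual.
    measurement_error = omega * sum(r * r for r in residual)
    return measurement_error, (pred_noises, pred_images)


# Toy example: a masking operator that observes only the first and last pixel.
mask = [1.0, 0.0, 1.0]


def masking_operator(x: Vector) -> Vector:
    return [m * v for m, v in zip(mask, x)]


def zero_noise(x: Vector) -> Vector:
    return [0.0] * len(x)


error, (pred_noises, pred_images) = compute_error_sketch(
    noisy_images=[0.8, 0.5, -0.4],
    measurements=[1.0, 0.0, 0.0],
    noise_rate=0.6,
    signal_rate=0.8,
    omega=2.0,
    predict_noise=zero_noise,
    operator=masking_operator,
)
```

The masked-out middle pixel contributes nothing to the error, which is how a masking operator lets DPS fill in unobserved pixels from the prior.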

setup()

Set up the autograd function for DPS.

class zea.models.diffusion.DiffusionGuidance(diffusion_model, operator, disable_jit=False)

Bases: ABC, Object

Base class for diffusion guidance methods.

abstract setup()

Set up the guidance function. Must be implemented by subclasses.

class zea.models.diffusion.DiffusionModel(*args, **kwargs)

Bases: DeepGenerativeModel

Implementation of a diffusion generative model. Heavily inspired by https://keras.io/examples/generative/ddim/

call(inputs, training=False, network=None, **kwargs)

Calls the score network.

If network is not provided, the exponential moving average (EMA) network is used when training is False, and the regular network otherwise.

denoise(noisy_images, noise_rates, signal_rates, training, network=None)

Predict noise component and calculate the image component using it.
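Recovering the image component from the predicted noise amounts to inverting the forward relation x_t = signal_rate * x_0 + noise_rate * noise, as in the Keras DDIM example cited in the class description. A minimal sketch on plain Python lists; predict_noise is a hypothetical stand-in for the score network, and the round trip below uses an oracle that returns the true noise.

```python
from typing import Callable, List, Tuple


def denoise_sketch(
    noisy_images: List[float],
    noise_rate: float,
    signal_rate: float,
    predict_noise: Callable[[List[float]], List[float]],  # hypothetical network stand-in
) -> Tuple[List[float], List[float]]:
    # The network predicts the noise component of x_t.
    pred_noises = predict_noise(noisy_images)
    # Invert x_t = signal_rate * x_0 + noise_rate * noise to obtain the image component.
    pred_images = [(x - noise_rate * n) / signal_rate for x, n in zip(noisy_images, pred_noises)]
    return pred_noises, pred_images


# Round trip: build x_t from a known x_0 and noise, then recover x_0 with an oracle network.
x0 = [0.5, -1.0]
noise = [0.1, 0.3]
signal_rate, noise_rate = 0.8, 0.6
x_t = [signal_rate * a + noise_rate * b for a, b in zip(x0, noise)]
pred_noises, pred_images = denoise_sketch(x_t, noise_rate, signal_rate, lambda _: noise)
```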

diffusion_schedule(diffusion_times)

Cosine diffusion schedule https://arxiv.org/abs/2102.09672

Parameters:

diffusion_times – tensor with diffusion times in [0, 1]

Returns:

  • noise_rates – tensor with noise rates.

  • signal_rates – tensor with signal rates.

The rates are defined according to:

$$x_t = \text{signal\_rate} \cdot x_0 + \text{noise\_rate} \cdot \text{noise} = \sqrt{\alpha_t}\, x_0 + \sqrt{1 - \alpha_t}\, \text{noise}$$

or, with stochastic sampling:

$$x_t = \sqrt{\alpha_t}\, x_0 + \sqrt{1 - \alpha_t - \sigma_t^2}\, \text{noise} + \sigma_t\, \epsilon$$

where $\epsilon$ is freshly sampled standard Gaussian noise and

$$\sigma_t = \sqrt{\frac{1 - \alpha_t}{1 - \alpha_{t+1}}}\, \sqrt{1 - \frac{\alpha_{t+1}}{\alpha_t}}$$

Note

Here t+1 denotes the previous time step and t the current time step.
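A minimal sketch of a cosine schedule in the style of the Keras DDIM example cited in the class description: the diffusion time is mapped linearly to an angle, and the noise and signal rates are its sine and cosine, so signal_rate^2 + noise_rate^2 = 1 (signal_rate = sqrt(alpha_t), noise_rate = sqrt(1 - alpha_t)). The schedule in the cited paper is built on the same squared-cosine idea but with an offset and a normalization, so its numbers differ. The bounds min_signal_rate = 0.02 and max_signal_rate = 0.95 come from the Keras example and are assumptions here, not necessarily zea's values.

```python
import math
from typing import List, Tuple


def cosine_schedule_sketch(
    diffusion_times: List[float],
    min_signal_rate: float = 0.02,  # assumed, taken from the Keras DDIM example
    max_signal_rate: float = 0.95,  # assumed, taken from the Keras DDIM example
) -> Tuple[List[float], List[float]]:
    # t = 0 maps to the largest signal rate, t = 1 to the smallest.
    start_angle = math.acos(max_signal_rate)
    end_angle = math.acos(min_signal_rate)
    angles = [start_angle + t * (end_angle - start_angle) for t in diffusion_times]
    noise_rates = [math.sin(a) for a in angles]
    signal_rates = [math.cos(a) for a in angles]
    return noise_rates, signal_rates


noise_rates, signal_rates = cosine_schedule_sketch([0.0, 0.5, 1.0])
```

Clipping the endpoints away from 1 and 0 keeps the signal rate strictly positive, so dividing by it when recovering x_0 stays well defined.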

get_config()

Returns the config of the object.

An object config is a Python dictionary (serializable) containing the information needed to re-instantiate it.

linear_diffusion_schedule(diffusion_times)

Create a linear diffusion schedule.
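The page does not say which quantity this schedule is linear in. One common choice, assumed here purely for illustration, is the continuous-time form of DDPM's linear noise schedule: beta(t) grows linearly from beta_min to beta_max, and alpha_t = exp(-integral of beta(s) from 0 to t). The values beta_min = 0.1 and beta_max = 20 are those commonly used in this variance-preserving setting and are assumptions, not zea's defaults.

```python
import math
from typing import List, Tuple

BETA_MIN = 0.1  # assumed value, not taken from zea
BETA_MAX = 20.0  # assumed value, not taken from zea


def linear_beta_schedule_sketch(diffusion_times: List[float]) -> Tuple[List[float], List[float]]:
    noise_rates, signal_rates = [], []
    for t in diffusion_times:
        # Closed-form integral of beta(s) = BETA_MIN + s * (BETA_MAX - BETA_MIN) over [0, t].
        integral = BETA_MIN * t + 0.5 * (BETA_MAX - BETA_MIN) * t * t
        alpha = math.exp(-integral)  # alpha_t = signal_rate**2
        signal_rates.append(math.sqrt(alpha))
        noise_rates.append(math.sqrt(1.0 - alpha))
    return noise_rates, signal_rates


noise_rates, signal_rates = linear_beta_schedule_sketch([0.0, 0.25, 0.5, 1.0])
```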

log_likelihood(data, **kwargs)

Approximate log-likelihood of the data under the model.

Parameters:
  • data – Data to compute log-likelihood for.

  • **kwargs – Additional arguments.

Returns:

Approximate log-likelihood.

property metrics

Metrics for training.

posterior_sample(measurements, n_samples=1, n_steps=20, initial_step=0, initial_samples=None, seed=None, **kwargs)

Sample from the posterior distribution given measurements.

Parameters:
  • measurements – Input measurements. Typically of shape (batch_size, *input_shape).

  • n_samples – Number of posterior samples to generate. Will generate n_samples samples for each measurement in the measurements batch.

  • n_steps – Number of diffusion steps.

  • initial_step – Initial step to start from. Allows warm-starting the diffusion process from a partially noised image, thereby skipping part of the diffusion process. An initial step closer to n_steps results in a shorter diffusion process (i.e., less noise is added to the initial image). A value of 0 means that the diffusion process starts from pure noise.

  • initial_samples – Optional initial samples to start from. If provided, these samples will be used as the starting point for the diffusion process. Only used if initial_step is greater than 0. Must be of shape (batch_size, n_samples, *input_shape).

  • seed – Random seed generator.

  • **kwargs – Additional arguments.

Returns:

Posterior samples p(x|y), of shape (batch_size, n_samples, *input_shape).
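The shape contract for n_samples can be sketched in plain Python: each measurement in the batch is repeated n_samples times, giving a nested (batch_size, n_samples, ...) structure that matches the documented output shape. This illustrates the shape bookkeeping only; how zea tiles its tensors internally is not described on this page, and repeat_per_measurement is a hypothetical helper.

```python
from typing import List, TypeVar

T = TypeVar("T")


def repeat_per_measurement(measurements: List[T], n_samples: int) -> List[List[T]]:
    # Outer list: batch dimension; inner list: n_samples copies of that measurement.
    return [[m] * n_samples for m in measurements]


batch = ["y0", "y1"]  # placeholders for two measurement tensors of shape input_shape
tiled = repeat_per_measurement(batch, n_samples=3)
```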

prepare_diffusion(diffusion_steps, initial_step, verbose, disable_jit=False)

Prepare the diffusion process.

This method sets up the parameters for the diffusion process, including validation of the initial step and calculation of the step size.
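A hypothetical sketch of this bookkeeping, assuming diffusion time runs from 1 (pure noise) to 0 (clean image) in equal steps of size 1 / diffusion_steps, and that a warm start skips the first initial_step steps. The function name and the exact validation rule are assumptions, not zea's code.

```python
from typing import List, Tuple


def prepare_diffusion_sketch(diffusion_steps: int, initial_step: int) -> Tuple[float, List[float]]:
    # Validate the warm-start step: 0 means starting from pure noise.
    if not 0 <= initial_step < diffusion_steps:
        raise ValueError(f"initial_step must be in [0, {diffusion_steps}), got {initial_step}")
    step_size = 1.0 / diffusion_steps
    # Diffusion times still to be visited, from the (possibly skipped-ahead) start toward 0.
    times = [1.0 - step * step_size for step in range(initial_step, diffusion_steps)]
    return step_size, times


step_size, times = prepare_diffusion_sketch(diffusion_steps=4, initial_step=1)
# times == [0.75, 0.5, 0.25]: three remaining steps instead of four
```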

prepare_schedule(base_diffusion_times, initial_noise, initial_samples, initial_step, step_size)

Prepare the diffusion schedule.

This method sets up the initial noisy images based on the provided initial noise and samples. It handles the case where the initial step is greater than 0, allowing for the use of partially noised images for initialization of the diffusion process.

Parameters:
  • base_diffusion_times – Base diffusion times.

  • initial_noise – Initial noise tensor.

  • initial_samples – Optional initial samples to start from.

  • initial_step – Initial step to start from.

  • step_size – Step size for the diffusion process.

Returns:

next_noisy_images – noisy images after the initial step.
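For a warm start, a common construction (assumed here) is to noise the initial samples directly to the diffusion time at which sampling resumes, using the forward relation x_t = signal_rate * x_0 + noise_rate * noise; with initial_step = 0 the starting point is the initial noise itself. The schedule argument is a hypothetical stand-in for diffusion_schedule, and the fallback to pure noise when no initial samples are given is also an assumption.

```python
import math
from typing import Callable, List, Optional, Tuple

Schedule = Callable[[float], Tuple[float, float]]  # t -> (noise_rate, signal_rate)


def prepare_schedule_sketch(
    initial_noise: List[float],
    initial_samples: Optional[List[float]],
    initial_step: int,
    step_size: float,
    schedule: Schedule,
) -> List[float]:
    if initial_step == 0 or initial_samples is None:
        # Cold start (assumed fallback): begin from pure noise.
        return list(initial_noise)
    # Warm start: noise the initial samples to the time at which sampling resumes.
    t = 1.0 - initial_step * step_size
    noise_rate, signal_rate = schedule(t)
    return [signal_rate * x + noise_rate * n for x, n in zip(initial_samples, initial_noise)]


def toy_schedule(t: float) -> Tuple[float, float]:
    # The angle grows linearly with t, so signal_rate**2 + noise_rate**2 == 1.
    angle = t * math.pi / 2
    return math.sin(angle), math.cos(angle)


start = prepare_schedule_sketch([1.0, -1.0], [0.2, 0.4], initial_step=2, step_size=0.25,
                                schedule=toy_schedule)
```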

reverse_conditional_diffusion(measurements, initial_noise, diffusion_steps, initial_samples=None, initial_step=0, stochastic_sampling=False, seed=None, verbose=False, track_progress_type='x_0', disable_jit=False, **kwargs)

Reverse diffusion process conditioned on some measurement.

Effectively performs diffusion posterior sampling p(x_0 | y).

Parameters:
  • measurements – Conditioning data.

  • initial_noise – Initial noise tensor.

  • diffusion_steps (int) – Number of diffusion steps.

  • initial_samples – Optional initial samples to start from.

  • initial_step (int) – Initial step to start from.

  • stochastic_sampling (bool) – Whether to use stochastic sampling (DDPM).

  • seed – Random seed generator.

  • verbose (bool) – Whether to show a progress bar.

  • track_progress_type (Literal[None, 'x_0', 'x_t']) – Type of progress tracking (“x_0” or “x_t”).

  • disable_jit (bool) – Whether to disable JIT compilation.

  • **kwargs – Additional arguments, passed to the guidance function and the operator (for example omega or mask).

Returns:

Generated images.

reverse_diffusion(initial_noise, diffusion_steps, initial_samples=None, initial_step=0, stochastic_sampling=False, seed=None, verbose=True, track_progress_type='x_0', disable_jit=False, training=False, network_type=None)

Reverse diffusion process to generate images from noise.

Parameters:
  • initial_noise – Initial noise tensor.

  • diffusion_steps (int) – Number of diffusion steps.

  • initial_samples – Optional initial samples to start from.

  • initial_step (int) – Initial step to start from.

  • stochastic_sampling (bool) – Whether to use stochastic sampling (DDPM).

  • seed (SeedGenerator | None) – Random seed generator.

  • verbose (bool) – Whether to show a progress bar.

  • track_progress_type (Literal[None, 'x_0', 'x_t']) – Type of progress tracking (“x_0” or “x_t”).

  • disable_jit (bool) – Whether to disable JIT compilation.

  • training (bool) – Whether to use the training mode of the network.

  • network_type (Literal[None, 'main', 'ema']) – Which network to use (“main” or “ema”). If None, uses the network based on the training argument.

Returns:

Generated images.
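The network selection rule described for call and for network_type amounts to a small decision table. The helper below is hypothetical (not part of zea) and only returns the name of the network that would be used.

```python
from typing import Literal, Optional

NetworkType = Optional[Literal["main", "ema"]]


def resolve_network(training: bool, network_type: NetworkType = None) -> str:
    # An explicit choice always wins.
    if network_type is not None:
        if network_type not in ("main", "ema"):
            raise ValueError(f"unknown network_type: {network_type!r}")
        return network_type
    # Otherwise: the regular network while training, the EMA network at inference time.
    return "main" if training else "ema"
```

Defaulting to the EMA weights at inference is the usual choice because averaged weights tend to give smoother samples than the latest optimizer step.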

reverse_diffusion_step(shape, pred_images, pred_noises, signal_rates, next_signal_rates, next_noise_rates, seed=None, stochastic_sampling=False)

A single reverse diffusion step.

Parameters:
  • shape – Shape of the input tensor.

  • pred_images – Predicted images.

  • pred_noises – Predicted noises.

  • signal_rates – Current signal rates.

  • next_signal_rates – Next signal rates.

  • next_noise_rates – Next noise rates.

  • seed – Random seed generator.

  • stochastic_sampling – Whether to use stochastic sampling (DDPM).

Returns:

next_noisy_images – noisy images after the reverse diffusion step.
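A minimal sketch of one reverse step in the DDIM style, using the rates of diffusion_schedule (signal_rate = sqrt(alpha), noise_rate = sqrt(1 - alpha)). The deterministic branch re-noises the predicted image with the predicted noise at the next, lower noise level; the stochastic branch splits that noise level between the predicted noise and fresh Gaussian noise with standard deviation sigma_t, following the sigma_t formula given for diffusion_schedule. Argument names mirror the documented ones, but this is an illustration, not zea's implementation: a random.Random instance stands in for the seed generator, and shape is omitted because noise is drawn per element.

```python
import math
import random
from typing import List, Optional


def reverse_diffusion_step_sketch(
    pred_images: List[float],
    pred_noises: List[float],
    signal_rate: float,
    next_signal_rate: float,
    next_noise_rate: float,
    rng: Optional[random.Random] = None,  # stand-in for zea's seed generator
    stochastic_sampling: bool = False,
) -> List[float]:
    if not stochastic_sampling:
        # Deterministic (DDIM): re-noise x_0 with the predicted noise at the next noise level.
        return [next_signal_rate * x + next_noise_rate * e for x, e in zip(pred_images, pred_noises)]
    if rng is None:
        rng = random.Random()
    alpha_cur = signal_rate ** 2  # alpha_{t+1}: current, noisier step
    alpha_next = next_signal_rate ** 2  # alpha_t: next, cleaner step
    sigma = math.sqrt((1 - alpha_next) / (1 - alpha_cur)) * math.sqrt(1 - alpha_cur / alpha_next)
    # Part of the next noise level carried by the predicted noise; the rest is fresh noise.
    direction = math.sqrt(max(next_noise_rate ** 2 - sigma ** 2, 0.0))
    return [next_signal_rate * x + direction * e + sigma * rng.gauss(0.0, 1.0)
            for x, e in zip(pred_images, pred_noises)]


deterministic = reverse_diffusion_step_sketch([1.0], [0.0], 0.6, 0.8, 0.6)
run_a = reverse_diffusion_step_sketch([1.0, 2.0], [0.5, -0.5], 0.6, 0.8, 0.6,
                                      rng=random.Random(0), stochastic_sampling=True)
run_b = reverse_diffusion_step_sketch([1.0, 2.0], [0.5, -0.5], 0.6, 0.8, 0.6,
                                      rng=random.Random(0), stochastic_sampling=True)
```

Seeding the generator makes stochastic runs reproducible, which is the role the seed argument plays in the documented signature.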

sample(n_samples=1, n_steps=20, seed=None, **kwargs)

Sample from the model.

Parameters:
  • n_samples – Number of samples to generate.

  • n_steps – Number of diffusion steps.

  • seed – Random seed generator.

  • **kwargs – Additional arguments.

Returns:

Generated samples.

start_track_progress(diffusion_steps)

Initialize the progress tracking for the diffusion process.

To support animations of the diffusion process, its progress is tracked during sampling. For a large number of steps, not all intermediate images are stored, due to memory constraints.
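One way to bound memory, assumed here for illustration only, is to keep at most max_frames evenly spaced steps. The helper and the max_frames parameter are hypothetical; the page does not state zea's actual limit or spacing.

```python
from typing import List


def steps_to_store(diffusion_steps: int, max_frames: int = 50) -> List[int]:
    if max_frames < 2:
        raise ValueError("max_frames must be at least 2")
    if diffusion_steps <= max_frames:
        # Few enough steps: keep every one of them.
        return list(range(diffusion_steps))
    # Evenly spaced step indices, always including the first and the last step.
    stride = (diffusion_steps - 1) / (max_frames - 1)
    return [round(i * stride) for i in range(max_frames)]


kept = steps_to_store(1000, max_frames=50)
```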

store_progress(step, track_progress_type, next_noisy_images, pred_images)

Store the progress of the diffusion process.

Parameters:
  • step – Current diffusion step.

  • track_progress_type – Type of progress tracking (“x_0” or “x_t”).

  • next_noisy_images – Noisy images after the current step.

  • pred_images – Predicted images.

Notes

  • x_0 is the predicted clean image (also known as the Tweedie estimate)

  • x_t is the noisy intermediate image

train_step(data)

Custom train step so that model.fit() can be called on the diffusion model.

Note

Only implemented for the TensorFlow backend.
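The Keras DDIM example cited in the class description trains the network to predict the noise added at a random diffusion time, with a mean absolute error loss on the noise. A minimal sketch of that objective for a single example; whether zea's train_step uses the same loss is not stated on this page, and predict_noise and schedule are hypothetical stand-ins.

```python
import math
import random
from typing import Callable, List, Tuple

Schedule = Callable[[float], Tuple[float, float]]  # t -> (noise_rate, signal_rate)
NoisePredictor = Callable[[List[float], float, float], List[float]]


def noise_prediction_loss(images: List[float], predict_noise: NoisePredictor,
                          schedule: Schedule, rng: random.Random) -> float:
    # Sample a diffusion time uniformly in [0, 1) and standard Gaussian noise.
    t = rng.random()
    noises = [rng.gauss(0.0, 1.0) for _ in images]
    noise_rate, signal_rate = schedule(t)
    # Forward process: mix the clean image with the noise.
    noisy_images = [signal_rate * x + noise_rate * n for x, n in zip(images, noises)]
    # The network sees the noisy image and the noise level, and predicts the noise.
    pred_noises = predict_noise(noisy_images, noise_rate, signal_rate)
    # Mean absolute error between the true and the predicted noise.
    return sum(abs(n - p) for n, p in zip(noises, pred_noises)) / len(images)


def toy_schedule(t: float) -> Tuple[float, float]:
    # Keeps noise_rate > 0 so the oracle below never divides by zero.
    angle = (0.1 + 0.8 * t) * math.pi / 2
    return math.sin(angle), math.cos(angle)


images = [0.3, -0.7, 1.2]


def oracle(noisy: List[float], noise_rate: float, signal_rate: float) -> List[float]:
    # Cheats by knowing the clean images, so it recovers the exact noise.
    return [(z - signal_rate * x) / noise_rate for z, x in zip(noisy, images)]


def zero(noisy: List[float], noise_rate: float, signal_rate: float) -> List[float]:
    return [0.0] * len(noisy)


oracle_loss = noise_prediction_loss(images, oracle, toy_schedule, random.Random(0))
zero_loss = noise_prediction_loss(images, zero, toy_schedule, random.Random(0))
```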