zea.models.generative

Generative models for zea.

Classes

DeepGenerativeModel(*args, **kwargs)

Base class for deep generative models.

GenerativeModel()

Abstract base class for generative models.

class zea.models.generative.DeepGenerativeModel(*args, **kwargs)

Bases: BaseModel, GenerativeModel

Base class for deep generative models.

Inherits from both BaseModel and GenerativeModel, combining Keras model functionality (from BaseModel) with the generative-model interface (fit, sample, log_density, posterior_sample) declared by GenerativeModel.
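The multiple-inheritance pattern can be sketched with plain Python stand-ins. BaseModelStandIn, GenerativeModelStandIn, DeepGenerativeModelStandIn and ConstantModel are hypothetical names, not zea classes, and the stand-in base class is not Keras-based. The sketch shows that the combined class keeps the concrete base class's behavior, while the ABC refuses to instantiate it until the abstract methods are implemented.

```python
from abc import ABC, abstractmethod


class BaseModelStandIn:
    """Hypothetical stand-in for a framework model base class."""

    def __init__(self, name="model"):
        self.name = name

    def summary(self):
        return f"{type(self).__name__}(name={self.name!r})"


class GenerativeModelStandIn(ABC):
    """Hypothetical stand-in for the abstract generative interface."""

    @abstractmethod
    def sample(self, n_samples=1, **kwargs):
        ...


# Same base order as documented: framework base first, generative ABC second.
class DeepGenerativeModelStandIn(BaseModelStandIn, GenerativeModelStandIn):
    """Still abstract: sample() is not implemented yet."""


class ConstantModel(DeepGenerativeModelStandIn):
    """Concrete subclass: implements the abstract method, so it can be built."""

    def sample(self, n_samples=1, **kwargs):
        return [0.0] * n_samples


model = ConstantModel(name="toy")
print(model.summary())  # inherited from the framework-side base class
print(model.sample(3))  # provided by the concrete subclass
```

Because ABCMeta is a subclass of type, the two bases combine without a metaclass conflict here; whether the same holds for zea's actual BaseModel is not stated on this page.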

class zea.models.generative.GenerativeModel

Bases: ABC

Abstract base class for generative models.

fit(data, **kwargs)

Fit the model to the data.

Parameters:
  • data – The data to fit the model to.

  • **kwargs – Additional arguments to pass to the fitting procedure.

log_density(data, **kwargs)

Compute the log-density $\log p(x)$ of the data under the model.

Parameters:
  • data – The data $x$ to compute the log-density for.

  • **kwargs – Additional arguments.

Returns:

Log-density $\log p(x)$ of the data.
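As a worked illustration of the quantity log_density returns, assume a hypothetical model (not a zea model) whose density is a diagonal Gaussian $p(x) = \mathcal{N}(x; \mu, \operatorname{diag}(\sigma^2))$ over $d$ features. Its log-density sums one term per feature:

```latex
\log p(x) = -\frac{1}{2} \sum_{i=1}^{d} \left[ \log\left(2\pi\sigma_i^2\right) + \frac{(x_i - \mu_i)^2}{\sigma_i^2} \right]
```

For $d = 1$, $\mu = 0$, $\sigma = 1$ and $x = 0$, this gives $\log p(0) = -\tfrac{1}{2}\log(2\pi) \approx -0.919$.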

posterior_sample(measurements, n_samples=1, **kwargs)

Draw samples $z \sim p(z \mid x)$ from the posterior given measurements.

Parameters:
  • measurements – The measurements $x$ to condition the posterior on.

  • n_samples – Number of posterior samples to draw per measurement. The samples are stacked along a new axis inserted directly after the batch axis: if measurements has shape (batch_size, …), the output has shape (batch_size, n_samples, …).

  • **kwargs – Additional arguments to pass to the sampling procedure.

Returns:

Samples $z$ from the posterior $p(z \mid x)$.
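The output-shape convention for n_samples can be illustrated with a toy problem whose posterior is available in closed form. This is a minimal sketch under stated assumptions: gaussian_posterior_sample is a hypothetical helper, not part of zea, and it assumes the elementwise measurement model $x = z + n$ with prior $z \sim \mathcal{N}(\mu_0, \sigma_0^2)$ and noise $n \sim \mathcal{N}(0, s^2)$, for which the posterior $p(z \mid x)$ is Gaussian.

```python
import numpy as np


def gaussian_posterior_sample(measurements, n_samples=1, prior_mean=0.0,
                              prior_std=1.0, noise_std=0.5, seed=None):
    """Hypothetical sketch: draw z ~ p(z | x) for the toy model x = z + noise.

    Assumes an independent Gaussian prior and Gaussian noise per element,
    so the posterior is Gaussian with a closed-form mean and variance.
    """
    rng = np.random.default_rng(seed)
    x = np.asarray(measurements, dtype=float)  # shape (batch_size, ...)

    # Conjugate Gaussian update: precisions add, means are precision-weighted.
    post_var = 1.0 / (1.0 / prior_std**2 + 1.0 / noise_std**2)
    post_mean = post_var * (prior_mean / prior_std**2 + x / noise_std**2)

    # Insert the sample axis right after the batch axis:
    # (batch_size, ...) -> (batch_size, 1, ...), then broadcast to n_samples.
    post_mean = np.expand_dims(post_mean, axis=1)
    shape = (x.shape[0], n_samples) + x.shape[1:]
    return post_mean + np.sqrt(post_var) * rng.standard_normal(shape)


z = gaussian_posterior_sample(np.zeros((4, 8, 8)), n_samples=3, seed=0)
print(z.shape)  # (4, 3, 8, 8)
```

A batch of 4 measurements of shape (8, 8) with n_samples=3 yields an array of shape (4, 3, 8, 8), matching the (batch_size, n_samples, …) layout described in the parameter list.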

sample(n_samples=1, **kwargs)

Draw samples $x \sim p(x)$ from the model.

Parameters:
  • n_samples – Number of samples to generate.

  • **kwargs – Additional arguments to pass to the sampling procedure.

Returns:

Samples $x$ from the model distribution $p(x)$.
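To show how the fit, sample and log_density methods fit together, here is a toy concrete model written against a local abstract class that mirrors the documented signatures. This is a sketch, not zea code: GenerativeInterface and DiagonalGaussian are hypothetical names, the seed keyword on sample is an assumed extra argument, and posterior_sample is left out because it needs a measurement model.

```python
from abc import ABC, abstractmethod

import numpy as np


class GenerativeInterface(ABC):
    """Hypothetical local stand-in mirroring the documented signatures."""

    @abstractmethod
    def fit(self, data, **kwargs): ...

    @abstractmethod
    def sample(self, n_samples=1, **kwargs): ...

    @abstractmethod
    def log_density(self, data, **kwargs): ...


class DiagonalGaussian(GenerativeInterface):
    """Hypothetical toy model: an independent Gaussian for every feature."""

    def fit(self, data, **kwargs):
        data = np.asarray(data, dtype=float)  # shape (n_examples, ...)
        # Maximum-likelihood estimates per feature; the small floor on the
        # standard deviation avoids division by zero for constant features.
        self.mean = data.mean(axis=0)
        self.std = data.std(axis=0) + 1e-6
        return self

    def sample(self, n_samples=1, seed=None, **kwargs):
        # `seed` is an assumed extra keyword, used only for reproducibility.
        rng = np.random.default_rng(seed)
        noise = rng.standard_normal((n_samples,) + self.mean.shape)
        return self.mean + self.std * noise  # shape (n_samples, ...)

    def log_density(self, data, **kwargs):
        data = np.asarray(data, dtype=float)  # shape (batch_size, ...)
        z = (data - self.mean) / self.std
        per_feature = -0.5 * (np.log(2 * np.pi * self.std**2) + z**2)
        # Sum over every non-batch axis to get one log p(x) per example.
        return per_feature.reshape(data.shape[0], -1).sum(axis=1)
```

Typical usage is model = DiagonalGaussian().fit(train_data), then model.sample(n_samples=16) for new draws and model.log_density(batch) for one log-density per example. Returning self from fit is a convenience choice of this sketch that allows the chained call.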