zea.models.gmm

Gaussian Mixture Model (GMM) implementation

Functions

match_means_covariances(means, true_means, ...)

Match estimated means/covs to true ones.

Classes

GaussianMixtureModel([n_components, ...])

Gaussian Mixture Model fitted with EM algorithm.

class zea.models.gmm.GaussianMixtureModel(n_components=2, n_features=1, tol=0.0001, seed=None)[source]

Bases: GenerativeModel

Gaussian Mixture Model fitted with EM algorithm.

Parameters:
  • n_components – Number of mixture components.

  • n_features – Number of features (dimensions).

  • max_iter – Maximum number of EM steps. This is an argument of fit(), not of the constructor.

  • tol – Convergence tolerance.

  • seed – Random seed for reproducibility.

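Example:

```python
from zea.models.gmm import GaussianMixtureModel

# `data` is assumed to be an array of shape (n_samples, n_features)
gmm = GaussianMixtureModel(n_components=2, n_features=2)
gmm.fit(data, max_iter=100)
samples = gmm.sample(100)
```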
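To make the "fitted with EM algorithm" description concrete, here is a minimal NumPy sketch of EM for a full-covariance Gaussian mixture. It is not the zea implementation: the names `log_gaussian` and `em_fit`, the farthest-point initialisation, the `1e-6` covariance jitter, and the stopping rule on the mean log-likelihood are all assumptions made for illustration.

```python
import numpy as np


def log_gaussian(x, mean, cov):
    """Log-density of N(mean, cov) at each row of x, shape (n,)."""
    d = x.shape[1]
    chol = np.linalg.cholesky(cov)
    # Solve L y = (x - mean)^T; ||y||^2 is the squared Mahalanobis distance.
    y = np.linalg.solve(chol, (x - mean).T)
    maha = np.sum(y**2, axis=0)
    log_det = 2.0 * np.sum(np.log(np.diag(chol)))
    return -0.5 * (d * np.log(2.0 * np.pi) + log_det + maha)


def em_fit(data, n_components=2, max_iter=100, tol=1e-4, seed=None):
    """Hypothetical EM loop; returns (weights, means, covs)."""
    rng = np.random.default_rng(seed)
    n, d = data.shape
    # Initialisation (an assumption): one random point, then farthest points.
    idx = [int(rng.integers(n))]
    for _ in range(n_components - 1):
        gaps = np.linalg.norm(data[:, None, :] - data[idx][None, :, :], axis=-1)
        idx.append(int(np.argmax(gaps.min(axis=1))))
    means = data[idx].copy()
    cov0 = np.cov(data.T).reshape(d, d) + 1e-6 * np.eye(d)
    covs = np.stack([cov0] * n_components)
    weights = np.full(n_components, 1.0 / n_components)

    prev_ll = -np.inf
    for _ in range(max_iter):
        # E-step: responsibilities p(z = k | x_i) via a stable log-sum-exp.
        log_joint = np.stack(
            [np.log(weights[k]) + log_gaussian(data, means[k], covs[k])
             for k in range(n_components)],
            axis=1,
        )
        log_norm = np.logaddexp.reduce(log_joint, axis=1)
        resp = np.exp(log_joint - log_norm[:, None])

        # M-step: responsibility-weighted updates of weights, means, covs.
        nk = resp.sum(axis=0) + 1e-12
        weights = nk / nk.sum()
        means = (resp.T @ data) / nk[:, None]
        for k in range(n_components):
            diff = data - means[k]
            covs[k] = (resp[:, k, None] * diff).T @ diff / nk[k] + 1e-6 * np.eye(d)

        # Stop when the mean log-likelihood changes by less than tol.
        ll = log_norm.mean()
        if abs(ll - prev_ll) < tol:
            break
        prev_ll = ll
    return weights, means, covs
```

In this sketch `tol` and `max_iter` play the roles described in the parameter list above. The actual convergence criterion used by `GaussianMixtureModel` may differ.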

fit(data, max_iter=100, verbose=0, **kwargs)[source]

Fit the model to the data.

Parameters:
  • data – The data to fit the model to.

  • max_iter – Maximum number of EM steps.

  • verbose – Verbosity level (default 0).

  • **kwargs – Additional arguments to pass to the fitting procedure.

log_density(data, **kwargs)[source]

Compute the log-density $\log p(x)$ of the data under the model.

Parameters:
  • data – The data $x$ to compute the log-density for.

  • **kwargs – Additional arguments.

Returns:

Log-density $\log p(x)$ of the data.
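For a Gaussian mixture the log-density is $\log p(x) = \log \sum_k \pi_k \, \mathcal{N}(x \mid \mu_k, \Sigma_k)$. The sketch below evaluates it with a log-sum-exp over components for numerical stability. The function name `gmm_log_density` and its explicit `weights`, `means`, `covs` arguments are hypothetical; how zea stores these parameters internally is not shown in this page.

```python
import numpy as np


def gmm_log_density(x, weights, means, covs):
    """log p(x) for each row of x under a Gaussian mixture, shape (batch,)."""
    x = np.atleast_2d(np.asarray(x, dtype=float))
    n, d = x.shape
    log_joint = np.empty((n, len(weights)))
    for k, (w, mu, cov) in enumerate(zip(weights, means, covs)):
        chol = np.linalg.cholesky(np.asarray(cov, dtype=float))
        # Squared Mahalanobis distance via a triangular-factor solve.
        y = np.linalg.solve(chol, (x - mu).T)
        maha = np.sum(y**2, axis=0)
        log_det = 2.0 * np.sum(np.log(np.diag(chol)))
        log_joint[:, k] = np.log(w) - 0.5 * (d * np.log(2.0 * np.pi) + log_det + maha)
    # log sum_k exp(log pi_k + log N_k), computed without overflow.
    return np.logaddexp.reduce(log_joint, axis=1)
```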

posterior_sample(measurements, n_samples=1, seed=None, **kwargs)[source]

Sample component indices from the posterior p(z|x) for each measurement.

Parameters:
  • measurements – Input data, shape (batch, n_features).

  • n_samples – Number of posterior samples per measurement.

  • seed – Random seed.

Returns:

Component indices, shape (batch, n_samples).
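One way to draw component indices from $p(z \mid x)$ is the Gumbel-max trick: add independent Gumbel noise to the unnormalised log-posterior $\log \pi_k + \log \mathcal{N}(x \mid \mu_k, \Sigma_k)$ and take the argmax. The sketch below illustrates this under assumptions. `log_joint` and `posterior_component_sample` are hypothetical helpers, and zea may sample the categorical posterior differently.

```python
import numpy as np


def log_joint(x, weights, means, covs):
    """log pi_k + log N(x | mu_k, Sigma_k), shape (batch, n_components)."""
    x = np.asarray(x, dtype=float)
    d = x.shape[1]
    cols = []
    for w, mu, cov in zip(weights, means, covs):
        chol = np.linalg.cholesky(np.asarray(cov, dtype=float))
        y = np.linalg.solve(chol, (x - mu).T)
        maha = np.sum(y**2, axis=0)
        log_det = 2.0 * np.sum(np.log(np.diag(chol)))
        cols.append(np.log(w) - 0.5 * (d * np.log(2.0 * np.pi) + log_det + maha))
    return np.stack(cols, axis=1)


def posterior_component_sample(x, weights, means, covs, n_samples=1, seed=None):
    """Sample component indices from p(z | x); returns shape (batch, n_samples)."""
    rng = np.random.default_rng(seed)
    logits = log_joint(x, weights, means, covs)
    # Gumbel-max: argmax_k(logits_k + G_k) is distributed as softmax(logits).
    noise = rng.gumbel(size=(logits.shape[0], n_samples, logits.shape[1]))
    return np.argmax(logits[:, None, :] + noise, axis=-1)
```

The output shape matches the documented (batch, n_samples) layout: each row holds the sampled component indices for one measurement.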

sample(n_samples=1, seed=None, **kwargs)[source]

Draw samples $x \sim p(x)$ from the model.

Parameters:
  • n_samples – Number of samples to generate.

  • seed – Random seed.

  • **kwargs – Additional arguments to pass to the sampling procedure.

Returns:

Samples $x$ from the model distribution $p(x)$.
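Sampling from a Gaussian mixture is usually done ancestrally: first draw a component index $z$ from the mixture weights, then draw $x$ from that component's Gaussian. The following is a minimal sketch of that idea, not the zea code. `gmm_sample` and its explicit parameter arguments are assumptions, and it also returns the component indices for inspection.

```python
import numpy as np


def gmm_sample(weights, means, covs, n_samples=1, seed=None):
    """Ancestral sampling; returns (x, z) with x of shape (n_samples, n_features)."""
    rng = np.random.default_rng(seed)
    weights = np.asarray(weights, dtype=float)
    means = np.asarray(means, dtype=float)
    covs = np.asarray(covs, dtype=float)
    # Step 1: pick a component for every sample, z ~ Categorical(weights).
    z = rng.choice(len(weights), size=n_samples, p=weights)
    # Step 2: x = mu_z + L_z @ eps with eps ~ N(0, I) and L_z L_z^T = Sigma_z.
    chols = np.linalg.cholesky(covs)  # batched over components
    eps = rng.standard_normal((n_samples, means.shape[1]))
    x = means[z] + np.einsum("nij,nj->ni", chols[z], eps)
    return x, z
```

Passing the same `seed` twice reproduces the same draws, which mirrors the purpose of the `seed` argument in the signature above.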

zea.models.gmm.match_means_covariances(means, true_means, covs, true_covs)[source]

Match estimated means/covs to true ones.

Uses greedy minimal distance assignment.

Parameters:
  • means – Estimated means (n_components, n_features).

  • true_means – True means (n_components, n_features).

  • covs – Estimated covariances (n_components, n_features, n_features).

  • true_covs – True covariances (n_components, n_features, n_features).

Returns:
  • means_matched – Matched estimated means.

  • true_means_matched – Matched true means.

  • covs_matched – Matched estimated covariances.

  • true_covs_matched – Matched true covariances.
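A greedy minimal-distance assignment can be sketched as follows: build the pairwise distance matrix between true and estimated means, repeatedly take the smallest remaining entry, and strike out its row and column. The helper `greedy_match` is hypothetical. It returns only a permutation, and the Euclidean metric and equal component counts are assumptions rather than documented behaviour of `match_means_covariances`.

```python
import numpy as np


def greedy_match(means, true_means):
    """Return perm such that means[perm[i]] is matched to true_means[i]."""
    means = np.asarray(means, dtype=float)
    true_means = np.asarray(true_means, dtype=float)
    # dist[i, j] = Euclidean distance between true mean i and estimated mean j.
    dist = np.linalg.norm(true_means[:, None, :] - means[None, :, :], axis=-1)
    perm = np.full(len(true_means), -1, dtype=int)
    for _ in range(len(true_means)):
        # Take the globally closest remaining (true, estimated) pair.
        i, j = np.unravel_index(np.argmin(dist), dist.shape)
        perm[i] = j
        # Remove both from further consideration.
        dist[i, :] = np.inf
        dist[:, j] = np.inf
    return perm
```

With such a permutation, estimated covariances can be reordered the same way as the means, for example `covs[perm]`. Note that greedy matching is not guaranteed to minimise the total distance; an optimal assignment would need something like the Hungarian algorithm.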