zea.models.gmm
Gaussian Mixture Model (GMM) implementation
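Functions
| match_means_covariances(means, true_means, covs, true_covs) | Match estimated means/covs to true ones. |
Classes
| GaussianMixtureModel([n_components, n_features, tol, seed]) | Gaussian Mixture Model fitted with the EM algorithm. |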
- class zea.models.gmm.GaussianMixtureModel(n_components=2, n_features=1, tol=0.0001, seed=None)
Bases: GenerativeModel
Gaussian Mixture Model fitted with the EM (expectation-maximization) algorithm.
- Parameters:
n_components – Number of mixture components.
n_features – Number of features (dimensions).
tol – Convergence tolerance.
seed – Random seed for reproducibility.
The maximum number of EM steps is not a constructor argument; pass it to fit() as max_iter.
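Example:
```python
import numpy as np
from zea.models.gmm import GaussianMixtureModel

# Placeholder data with shape (n_samples, n_features).
data = np.random.default_rng(0).normal(size=(500, 2)).astype("float32")
gmm = GaussianMixtureModel(n_components=2, n_features=2)
gmm.fit(data, max_iter=100)
samples = gmm.sample(100)
```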
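Sampling from a mixture is commonly done by ancestral sampling: first draw a component index from the mixture weights, then draw a point from that component's Gaussian. The NumPy sketch below only illustrates that idea and is not zea's implementation; `sample_gmm`, `weights`, `means` and `covs` are hypothetical names for the mixture parameters.

```python
import numpy as np


def sample_gmm(weights, means, covs, n_samples, seed=None):
    """Ancestral sampling: z ~ Categorical(weights), then x ~ N(means[z], covs[z])."""
    rng = np.random.default_rng(seed)
    # Step 1: pick a mixture component for every sample.
    z = rng.choice(len(weights), size=n_samples, p=weights)
    # Step 2: draw each sample from the Gaussian of its component.
    x = np.stack([rng.multivariate_normal(means[k], covs[k]) for k in z])
    return x, z


# Hypothetical 2-component, 2-D mixture.
weights = np.array([0.3, 0.7])
means = np.array([[-2.0, 0.0], [3.0, 1.0]])
covs = np.stack([np.eye(2), 0.5 * np.eye(2)])
x, z = sample_gmm(weights, means, covs, n_samples=100, seed=0)
```

Returning the component indices alongside the samples makes it easy to check which component generated each point.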
- fit(data, max_iter=100, verbose=0, **kwargs)
Fit the model to the data.
- Parameters:
data – The data to fit the model to.
max_iter – Maximum number of EM steps (default 100).
verbose – Verbosity level (default 0).
**kwargs – Additional arguments to pass to the fitting procedure.
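As a rough picture of what an EM fit does, the NumPy sketch below alternates an E-step (component responsibilities) and an M-step (weights, means, full covariances) for at most max_iter iterations. It assumes that tol bounds the improvement of the mean log-likelihood, which this page does not state, and adds a small hypothetical `reg` term to keep covariances invertible. It is a sketch, not zea's code; `fit_gmm_em` and `log_gauss` are hypothetical helpers.

```python
import numpy as np


def log_gauss(x, mean, cov):
    """Log-density of N(mean, cov) at each row of x, computed via a Cholesky factor."""
    d = mean.shape[0]
    chol = np.linalg.cholesky(cov)
    sol = np.linalg.solve(chol, (x - mean).T)  # L^{-1} (x - mean), shape (d, n)
    log_det = 2.0 * np.sum(np.log(np.diag(chol)))
    return -0.5 * (d * np.log(2.0 * np.pi) + log_det + np.sum(sol**2, axis=0))


def fit_gmm_em(data, n_components, max_iter=100, tol=1e-4, seed=None, reg=1e-6):
    """Full-covariance EM; returns (weights, means, covs)."""
    rng = np.random.default_rng(seed)
    n, d = data.shape
    # Initialise: uniform weights, means at random data points, data covariance for all.
    weights = np.full(n_components, 1.0 / n_components)
    means = data[rng.choice(n, size=n_components, replace=False)]
    base_cov = np.atleast_2d(np.cov(data, rowvar=False)) + reg * np.eye(d)
    covs = np.stack([base_cov] * n_components)
    prev_ll = -np.inf
    for _ in range(max_iter):
        # E-step: responsibilities r_ik = p(z_i = k | x_i).
        log_joint = np.stack(
            [np.log(weights[k]) + log_gauss(data, means[k], covs[k]) for k in range(n_components)],
            axis=1,
        )
        log_norm = np.logaddexp.reduce(log_joint, axis=1)  # log p(x_i)
        resp = np.exp(log_joint - log_norm[:, None])
        # M-step: re-estimate weights, means and covariances from the responsibilities.
        nk = resp.sum(axis=0)
        weights = nk / n
        means = (resp.T @ data) / nk[:, None]
        for k in range(n_components):
            diff = data - means[k]
            covs[k] = (resp[:, k, None] * diff).T @ diff / nk[k] + reg * np.eye(d)
        # Assumed stopping rule: mean log-likelihood improved by less than tol.
        ll = log_norm.mean()
        if ll - prev_ll < tol:
            break
        prev_ll = ll
    return weights, means, covs


# Usage on two well-separated 2-D clusters.
rng = np.random.default_rng(0)
data = np.concatenate(
    [rng.normal(-3.0, 1.0, size=(200, 2)), rng.normal(3.0, 1.0, size=(200, 2))]
)
weights, means, covs = fit_gmm_em(data, n_components=2, max_iter=100, seed=0)
```

Working in log space for the E-step avoids underflow when a point is far from every component.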
- log_density(data, **kwargs)
Compute the log-density $\log p(x)$ of the data under the model.
- Parameters:
data – The data $x$ for which to compute the log-density.
**kwargs – Additional arguments.
- Returns:
Log-density $\log p(x)$ of the data.
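For a GMM with weights $\pi_k$, means $\mu_k$ and covariances $\Sigma_k$, the log-density is $\log p(x) = \log \sum_{k=1}^{K} \pi_k \, \mathcal{N}(x \mid \mu_k, \Sigma_k)$, which is usually evaluated with a log-sum-exp over components to avoid underflow. A NumPy sketch under those assumptions follows; `gmm_log_density` and `log_gauss` are hypothetical helpers, not part of zea.

```python
import numpy as np


def log_gauss(x, mean, cov):
    """Log-density of N(mean, cov) at each row of x, computed via a Cholesky factor."""
    d = mean.shape[0]
    chol = np.linalg.cholesky(cov)
    sol = np.linalg.solve(chol, (x - mean).T)  # L^{-1} (x - mean), shape (d, batch)
    maha = np.sum(sol**2, axis=0)  # squared Mahalanobis distance
    log_det = 2.0 * np.sum(np.log(np.diag(chol)))
    return -0.5 * (d * np.log(2.0 * np.pi) + log_det + maha)


def gmm_log_density(x, weights, means, covs):
    """log p(x) = log sum_k pi_k N(x | mu_k, Sigma_k) for each row of x."""
    # log pi_k + log N(x | mu_k, Sigma_k), shape (batch, K).
    log_joint = np.stack(
        [np.log(w) + log_gauss(x, m, c) for w, m, c in zip(weights, means, covs)],
        axis=1,
    )
    # Log-sum-exp over components, stabilised by subtracting the row maximum.
    top = log_joint.max(axis=1, keepdims=True)
    return top[:, 0] + np.log(np.exp(log_joint - top).sum(axis=1))


# A single standard-normal component in 1-D.
x = np.array([[0.0], [1.0]])
lp = gmm_log_density(x, np.array([1.0]), np.array([[0.0]]), np.array([[[1.0]]]))
```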
- posterior_sample(measurements, n_samples=1, seed=None, **kwargs)
Sample component indices from the posterior $p(z \mid x)$ for each measurement.
- Parameters:
measurements – Input data, shape (batch, n_features).
n_samples – Number of posterior samples per measurement.
seed – Random seed.
- Returns:
Component indices, shape (batch, n_samples).
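By Bayes' rule, $p(z = k \mid x) \propto \pi_k \, \mathcal{N}(x \mid \mu_k, \Sigma_k)$. The sketch below draws n_samples component indices per measurement from this posterior, returning shape (batch, n_samples). It uses the Gumbel-max trick as one possible categorical sampler, which may differ from what zea does; the function and argument names are hypothetical.

```python
import numpy as np


def log_gauss(x, mean, cov):
    """Log-density of N(mean, cov) at each row of x, computed via a Cholesky factor."""
    d = mean.shape[0]
    chol = np.linalg.cholesky(cov)
    sol = np.linalg.solve(chol, (x - mean).T)  # L^{-1} (x - mean), shape (d, batch)
    log_det = 2.0 * np.sum(np.log(np.diag(chol)))
    return -0.5 * (d * np.log(2.0 * np.pi) + log_det + np.sum(sol**2, axis=0))


def posterior_sample_components(x, weights, means, covs, n_samples=1, seed=None):
    """Draw component indices z ~ p(z | x); returns shape (batch, n_samples)."""
    rng = np.random.default_rng(seed)
    # Unnormalised log-posterior: log pi_k + log N(x | mu_k, Sigma_k), shape (batch, K).
    log_joint = np.stack(
        [np.log(w) + log_gauss(x, m, c) for w, m, c in zip(weights, means, covs)],
        axis=1,
    )
    # Normalise with log-sum-exp to get log p(z = k | x).
    log_post = log_joint - np.logaddexp.reduce(log_joint, axis=1, keepdims=True)
    # Gumbel-max trick: argmax_k(log p(z=k|x) + Gumbel noise) ~ Categorical(p(z|x)).
    gumbel = rng.gumbel(size=(x.shape[0], n_samples, len(weights)))
    return np.argmax(log_post[:, None, :] + gumbel, axis=-1)


# Two well-separated 1-D components; each measurement sits on one of them.
weights = np.array([0.5, 0.5])
means = np.array([[-10.0], [10.0]])
covs = np.array([[[1.0]], [[1.0]]])
x = np.array([[-10.0], [10.0]])
idx = posterior_sample_components(x, weights, means, covs, n_samples=5, seed=0)
```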
- zea.models.gmm.match_means_covariances(means, true_means, covs, true_covs)
Match estimated means and covariances to the true ones.
Uses a greedy minimum-distance assignment.
- Parameters:
means – Estimated means, shape (n_components, n_features).
true_means – True means, shape (n_components, n_features).
covs – Estimated covariances, shape (n_components, n_features, n_features).
true_covs – True covariances, shape (n_components, n_features, n_features).
- Returns:
A tuple (means_matched, true_means_matched, covs_matched, true_covs_matched):
means_matched – Matched estimated means.
true_means_matched – Matched true means.
covs_matched – Matched estimated covariances.
true_covs_matched – Matched true covariances.
- Return type:
tuple
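One plausible reading of the greedy minimum-distance assignment is: compute the Euclidean distance between every estimated and every true mean, repeatedly take the closest unmatched pair, and carry each pair's covariances along. The sketch below assumes that reading and returns the four arrays aligned by index, so that entry k of the estimated outputs pairs with entry k of the true outputs; the exact distance and output order used by zea are not documented here, and `greedy_match` is a hypothetical name.

```python
import numpy as np


def greedy_match(means, true_means, covs, true_covs):
    """Greedily pair estimated and true components by Euclidean distance between means."""
    # Pairwise distances between estimated (rows) and true (columns) means, shape (K, K).
    dist = np.linalg.norm(means[:, None, :] - true_means[None, :, :], axis=-1)
    order_est, order_true = [], []
    for _ in range(len(means)):
        # Take the closest remaining pair, then exclude its row and column.
        i, j = np.unravel_index(np.argmin(dist), dist.shape)
        order_est.append(i)
        order_true.append(j)
        dist[i, :] = np.inf
        dist[:, j] = np.inf
    est, true = np.array(order_est), np.array(order_true)
    return means[est], true_means[true], covs[est], true_covs[true]


# Estimated components come out in a different order than the true ones.
means = np.array([[5.0, 5.0], [0.0, 0.0]])
true_means = np.array([[0.1, 0.0], [5.0, 4.9]])
covs = np.stack([1.0 * np.eye(2), 2.0 * np.eye(2)])
true_covs = np.stack([2.0 * np.eye(2), 1.0 * np.eye(2)])
mm, tm, cm, tcm = greedy_match(means, true_means, covs, true_covs)
```

A greedy assignment is simple but not always globally optimal; an exact alternative would be the Hungarian algorithm on the same distance matrix.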