"""Gaussian Mixture Model (GMM) implementation"""
import keras
import numpy as np
from keras import ops
from zea.models.generative import GenerativeModel
from zea.tensor_ops import linear_sum_assignment


class GaussianMixtureModel(GenerativeModel):
"""
    Gaussian Mixture Model fitted with the EM (expectation-maximization) algorithm.

    Args:
        n_components: Number of mixture components.
        n_features: Number of features (dimensions).
        tol: Convergence tolerance on the change in log-likelihood.
        seed: Random seed for reproducibility.

    Note:
        The maximum number of EM steps, ``max_iter``, is an argument of
        ``fit``, not of the constructor.

    Example:
        ```python
        data = np.random.randn(500, 2)  # (n_samples, n_features)
        gmm = GaussianMixtureModel(n_components=2, n_features=2)
        gmm.fit(data, max_iter=100)
        samples = gmm.sample(100)
        ```
"""

    def __init__(self, n_components=2, n_features=1, tol=1e-4, seed=None):
self.n_components = n_components
self.n_features = n_features
self.tol = tol
self.seed = seed
self._initialized = False
self.means = None # (n_components, n_features)
self.vars = None # (n_components, n_features)
self.pi = None # (n_components,)

    def _initialize(self, X):
# X: (n_samples, n_features)
n_samples = ops.shape(X)[0]
n_features = ops.shape(X)[1]
chosen = []
        # Pick the first mean uniformly at random from the data points
idx = ops.cast(
keras.random.uniform(
shape=(),
minval=0,
maxval=n_samples,
seed=self.seed,
),
"int32",
)
chosen.append(idx)
        # Greedy farthest-point seeding (k-means++-style): each new mean is
        # the data point farthest from its nearest already-chosen mean.
        for _ in range(1, self.n_components):
            # Gather the means chosen so far
chosen_means = ops.stack(
[ops.take(X, i, axis=0) for i in chosen], axis=0
) # (len(chosen), n_features)
# Compute distances from all points to each chosen mean
# (n_samples, len(chosen), n_features)
diffs = ops.expand_dims(X, 1) - ops.expand_dims(chosen_means, 0)
dists = ops.sqrt(ops.sum(diffs**2, axis=-1)) # (n_samples, len(chosen))
min_dists = ops.min(dists, axis=1) # (n_samples,)
idx = ops.argmax(min_dists, axis=0)
chosen.append(idx)
means = ops.stack([ops.take(X, i, axis=0) for i in chosen], axis=0)
self.means = means
# Initialize variances to variance of data
var = ops.var(X, axis=0)
self.vars = ops.ones((self.n_components, n_features)) * var
# Initialize mixture weights uniformly
self.pi = ops.ones((self.n_components,)) / self.n_components
self._initialized = True
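
    # Worked example of the seeding (hypothetical toy data): with
    #     X = [[0.0], [0.1], [5.0], [5.1]]
    # and the first mean drawn as X[1] = 0.1, each point's distance to its
    # nearest chosen mean is [0.1, 0.0, 4.9, 5.0], so argmax selects
    # X[3] = 5.1 and both modes are seeded before EM starts.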

    def _e_step(self, X):
# X: (n_samples, n_features)
X_exp = ops.expand_dims(X, axis=1) # (n_samples, 1, n_features)
means = ops.expand_dims(self.means, axis=0) # (1, n_components, n_features)
vars_ = ops.expand_dims(self.vars, axis=0) # (1, n_components, n_features)
pi = self.pi # (n_components,)
# Compute log Gaussian pdf for each component
log_prob = -0.5 * ops.sum(
ops.log(2 * np.pi * vars_) + ((X_exp - means) ** 2) / vars_, axis=-1
) # (n_samples, n_components)
# Add log mixture weights
log_prob = log_prob + ops.log(pi)
# Normalize to get responsibilities
log_prob_norm = log_prob - ops.logsumexp(log_prob, axis=1, keepdims=True)
gamma = ops.exp(log_prob_norm) # (n_samples, n_components)
return gamma # responsibilities
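
    # For reference, the responsibilities computed above are
    #     gamma[n, k] = pi_k * N(x_n; mu_k, diag(var_k))
    #                   / sum_j pi_j * N(x_n; mu_j, diag(var_j)),
    # evaluated in log space with logsumexp for numerical stability.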

    def _m_step(self, X, gamma):
# X: (n_samples, n_features)
# gamma: (n_samples, n_components)
Nk = ops.sum(gamma, axis=0) # (n_components,)
# Update means
means = ops.sum(
ops.expand_dims(gamma, -1) * ops.expand_dims(X, 1), axis=0
) / ops.expand_dims(Nk, -1)
# Update variances
X_exp = ops.expand_dims(X, axis=1) # (n_samples, 1, n_features)
means_exp = ops.expand_dims(means, axis=0) # (1, n_components, n_features)
vars_ = ops.sum(gamma[..., None] * (X_exp - means_exp) ** 2, axis=0) / ops.expand_dims(
Nk, -1
)
# Update mixture weights
pi = Nk / ops.sum(Nk)
return means, vars_, pi
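
    # The updates above are the standard closed-form EM M-step:
    #     N_k   = sum_n gamma[n, k]
    #     mu_k  = (1 / N_k) * sum_n gamma[n, k] * x_n
    #     var_k = (1 / N_k) * sum_n gamma[n, k] * (x_n - mu_k) ** 2
    #     pi_k  = N_k / sum_j N_j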

    def fit(self, data, max_iter=100, verbose=0, **kwargs):
        """Fit the mixture to ``data`` with the EM algorithm.

        Args:
            data: Array of shape (n_samples, n_features).
            max_iter: Maximum number of EM steps.
            verbose: Verbosity of the progress bar (0 = silent).
        """
X = ops.convert_to_tensor(data, dtype="float32")
if not self._initialized:
self._initialize(X)
prev_ll = None
progbar = keras.utils.Progbar(max_iter, verbose=verbose)
for i in range(max_iter):
# E-step
gamma = self._e_step(X)
# M-step
means, vars_, pi = self._m_step(X, gamma)
# Compute log-likelihood
self.means, self.vars, self.pi = means, vars_, pi
ll = ops.sum(ops.log(ops.sum(self._component_pdf(X) * self.pi, axis=1)))
if verbose:
progbar.update(i + 1, values=[("log-likelihood", float(ll))])
if prev_ll is not None and abs(float(ll) - float(prev_ll)) < self.tol:
if verbose:
print(f"\nConverged at iter {i}")
break
prev_ll = ll
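
    # Usage sketch (illustrative; assumes 2-D toy data as in the class
    # docstring):
    #
    #     gmm = GaussianMixtureModel(n_components=2, n_features=2)
    #     gmm.fit(data, max_iter=200, verbose=1)
    #     gmm.means  # (2, 2) fitted component means
    #     gmm.pi     # (2,) fitted mixture weights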

    def _component_pdf(self, X):
# X: (n_samples, n_features)
X_exp = ops.expand_dims(X, axis=1) # (n_samples, 1, n_features)
means = ops.expand_dims(self.means, axis=0) # (1, n_components, n_features)
vars_ = ops.expand_dims(self.vars, axis=0) # (1, n_components, n_features)
# Gaussian PDF (no mixture weights)
norm = ops.prod(ops.sqrt(2 * np.pi * vars_), axis=-1)
exp_term = ops.exp(-0.5 * ops.sum(((X_exp - means) ** 2) / vars_, axis=-1))
return exp_term / norm # (n_samples, n_components)
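
    # Each entry is the diagonal-covariance Gaussian density
    #     N(x; mu_k, diag(var_k)) = prod_d (2 * pi * var_{k,d}) ** -0.5
    #                               * exp(-(x_d - mu_{k,d}) ** 2 / (2 * var_{k,d})),
    # i.e. the per-component likelihood without the mixture weight pi_k.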

    def sample(self, n_samples=1, seed=None, **kwargs):
        """Draw ``n_samples`` points from the fitted mixture by ancestral sampling."""
        # Sample component indices from the mixture weights
comp_idx = keras.random.categorical(ops.log(self.pi[None, :]), n_samples, seed=seed)
comp_idx = ops.squeeze(comp_idx, axis=0)
means = ops.take(self.means, comp_idx, axis=0)
vars_ = ops.take(self.vars, comp_idx, axis=0)
eps = keras.random.normal(ops.shape(means), seed=seed)
samples = means + eps * ops.sqrt(vars_)
return samples
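
    # Usage sketch (illustrative, after `fit` has been called):
    #
    #     samples = gmm.sample(1000, seed=0)  # (1000, n_features)
    #
    # A component index is first drawn per sample from pi, then a point is
    # drawn from that component's diagonal Gaussian.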

    def posterior_sample(self, measurements, n_samples=1, seed=None, **kwargs):
"""
Sample component indices from the posterior p(z|x) for each measurement.
Args:
measurements: Input data, shape (batch, n_features).
n_samples: Number of posterior samples per measurement.
seed: Random seed.
Returns:
Component indices, shape (batch, n_samples).
"""
X = ops.convert_to_tensor(measurements, dtype="float32")
gamma = self._e_step(X) # (batch, n_components)
# Sample n_samples times for each measurement
comp_idx = keras.random.categorical(
ops.log(gamma), n_samples, seed=seed
) # (batch, n_samples)
# Return as (batch, n_samples)
return comp_idx
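
    # Usage sketch (illustrative): map the sampled indices back to component
    # means to get hard posterior assignments per measurement:
    #
    #     idx = gmm.posterior_sample(measurements, n_samples=10, seed=0)
    #     assigned = ops.take(gmm.means, idx, axis=0)  # (batch, 10, n_features)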

    def log_density(self, data, **kwargs):
        """Evaluate the mixture log-density log p(x) for each row of ``data``."""
X = ops.convert_to_tensor(data, dtype="float32")
pdf = ops.sum(self._component_pdf(X) * self.pi, axis=1)
return ops.log(pdf)
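
    # Usage sketch (illustrative): the summed per-sample log-densities equal
    # the log-likelihood monitored for convergence in `fit`:
    #
    #     ll = ops.sum(gmm.log_density(data))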


def match_means_covariances(means, true_means, covs, true_covs):
    """Match estimated means/covs to true ones.

    Uses optimal linear sum assignment (the Hungarian algorithm) on the
    pairwise Euclidean distances between estimated and true means.

Args:
means: Estimated means (n_components, n_features).
true_means: True means (n_components, n_features).
covs: Estimated covariances (n_components, n_features, n_features).
true_covs: True covariances (n_components, n_features, n_features).
Returns:
means_matched: Matched estimated means.
true_means_matched: Matched true means.
covs_matched: Matched estimated covariances.
true_covs_matched: Matched true covariances.
"""
diff = ops.expand_dims(means, 1) - ops.expand_dims(true_means, 0)
cost = ops.sqrt(ops.sum(diff**2, axis=-1))
row_ind, col_ind = linear_sum_assignment(cost)
means_matched = ops.take(means, row_ind, axis=0)
true_means_matched = ops.take(true_means, col_ind, axis=0)
covs_matched = [covs[i] for i in row_ind]
true_covs_matched = [true_covs[j] for j in col_ind]
return means_matched, true_means_matched, covs_matched, true_covs_matched
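
# Usage sketch (illustrative; `true_means` and `true_covs` are assumed to come
# from a known simulation, and the fitted diagonal variances are expanded to
# full covariance matrices):
#
#     covs = [np.diag(v) for v in ops.convert_to_numpy(gmm.vars)]
#     m, tm, c, tc = match_means_covariances(gmm.means, true_means, covs, true_covs)
#     alignment_error = ops.sum((m - tm) ** 2)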