Learning Objectives
By the end of this section, you will be able to:
📚 Core Knowledge
- Explain the E-step and M-step of the EM algorithm
- Derive the EM updates for Gaussian Mixture Models
- Understand why EM never decreases the likelihood
- Connect EM to the Evidence Lower Bound (ELBO)
🔧 Practical Skills
- Implement EM for mixture models from scratch
- Diagnose convergence issues in EM
- Choose appropriate initialization strategies
- Apply EM to real-world clustering problems
🧠 Deep Learning Connections
- Variational Autoencoders (VAEs): EM inspiration for amortized inference
- Latent Variable Models: Foundation for generative modeling
- Semi-supervised Learning: Handling partially labeled data
- Missing Data Imputation: Principled approaches to incomplete datasets
Where You'll Apply This: Clustering with GMM, topic modeling (LDA), hidden Markov models, image segmentation, semi-supervised learning, and any scenario with latent (hidden) variables or missing data.
The Big Picture
The Expectation-Maximization (EM) algorithm, introduced by Dempster, Laird, and Rubin in their landmark 1977 paper, is one of the most influential algorithms in statistics and machine learning. It provides an elegant solution to a fundamental problem: How do we find maximum likelihood estimates when some data is missing or hidden?
Historical Context
While special cases of the technique had appeared earlier, Dempster, Laird & Rubin (1977) unified them under the EM framework, established its monotonicity property, and revealed connections to information theory; a gap in their general convergence proof was later repaired by Wu (1983). The paper has been cited over 70,000 times!
EM is particularly powerful because it:
- Handles latent variables: Hidden cluster assignments, missing observations, unobserved states
- Always improves: Each iteration is guaranteed to increase (or maintain) the likelihood
- Simple to implement: Often reduces complex optimization to closed-form updates
- Connects theory and practice: Bridges MLE, Bayesian inference, and information theory
The Missing Data Problem
Consider a dataset where some information is hidden or latent. This happens in many real scenarios:
Mixture Models
We observe data points but don't know which cluster (component) each point belongs to. The cluster assignments are latent variables.
Missing Values
Some entries in our dataset are missing (sensors failed, survey questions skipped). We want to estimate parameters despite incomplete data.
Hidden Markov Models
We observe emissions but the underlying states are hidden. In speech recognition, we hear sounds but don't directly observe phoneme states.
Semi-supervised Learning
Some data points have labels, others don't. The missing labels are latent variables we want to infer while learning the model.
The core challenge is that the complete-data likelihood (if we knew all the hidden information) would be easy to maximize, but we only have access to the observed-data likelihood, which marginalizes over the hidden variables. That marginal is often intractable to compute and, even when it can be evaluated (as in mixture models), usually has no closed-form maximizer.
The Likelihood Challenge
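Complete-data log-likelihood (if we knew Z):
$$\ell_c(\theta) = \log p(X, Z \mid \theta) = \sum_{i=1}^{n} \log p(x_i, z_i \mid \theta)$$
Often simple: with exponential-family components the log acts directly on each density, so maximizing it is a standard MLE problem.
Observed-data log-likelihood (marginalizing Z):
$$\ell(\theta) = \log p(X \mid \theta) = \sum_{i=1}^{n} \log \sum_{z_i} p(x_i, z_i \mid \theta)$$
Often hard: the log of a sum couples the parameters of all components, so setting the gradient to zero rarely yields a closed-form solution.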
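To make the contrast concrete, here is a small sketch with a hypothetical two-component, one-dimensional mixture (the parameters `pi`, `mu`, `sigma` and the five data points are made up for illustration). When the labels `z` are known, the log sits on each term; when they are not, each point contributes a log of a sum, which this sketch evaluates stably with `scipy.special.logsumexp`.

```python
import numpy as np
from scipy.special import logsumexp
from scipy.stats import norm

# Hypothetical 1-D, two-component mixture (parameters chosen for illustration).
pi = np.array([0.4, 0.6])
mu = np.array([-2.0, 3.0])
sigma = np.array([1.0, 1.5])

x = np.array([-2.5, -1.0, 2.0, 3.5, 4.0])
z = np.array([0, 0, 1, 1, 1])  # pretend the component labels were observed

# Complete data: the log sits directly on each term, so it splits per point.
complete_ll = np.sum(np.log(pi[z]) + norm.logpdf(x, loc=mu[z], scale=sigma[z]))

# Observed data: marginalize z_i inside the log.
# log_joint[i, k] = log pi_k + log N(x_i | mu_k, sigma_k^2)
log_joint = np.log(pi) + norm.logpdf(x[:, None], loc=mu, scale=sigma)
observed_ll = np.sum(logsumexp(log_joint, axis=1))  # log-sum-exp avoids underflow

print(f"complete-data log-likelihood: {complete_ll:.3f}")
print(f"observed-data log-likelihood: {observed_ll:.3f}")
```

For the same labels, the observed-data value can never be smaller than the complete-data value, because p(xᵢ) sums p(xᵢ, zᵢ = k) over every k, including the observed one.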
EM Algorithm Intuition
The Chicken-and-Egg Problem
Consider clustering data with a Gaussian Mixture Model. We face a fundamental dilemma:
If we knew the assignments...
Estimating cluster parameters (means, variances) would be trivial — just compute sample statistics for each cluster separately!
If we knew the parameters...
Assigning points to clusters would be easy — just assign each point to the cluster with highest probability!
But we know neither! 🐔🥚
The EM Solution
EM breaks this circularity by alternating between two steps:
Expectation Step (E)
Given current parameter estimates, compute the posterior distribution of the latent variables and the expected values it implies. Instead of hard assignments, we compute soft responsibilities: probabilities that each point belongs to each cluster.
Maximization Step (M)
Given the expected latent variables, update the parameters by maximizing the expected complete-data log-likelihood. This is usually a weighted MLE problem with closed-form solutions.
Key Insight: EM turns an intractable optimization into a sequence of tractable ones. By "filling in" the missing data with its posterior expectation, we can use standard MLE techniques. The magic is that no iteration can decrease the likelihood, and under mild regularity conditions the iterates converge to a stationary point, usually a local maximum.
The EM Algorithm
EM Algorithm
Input: Observed data X, initial parameters θ⁽⁰⁾
Repeat until convergence:
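E-Step:
Compute $Q(\theta \mid \theta^{(t)}) = \mathbb{E}_{Z \sim p(Z \mid X, \theta^{(t)})}\left[\log p(X, Z \mid \theta)\right]$
M-Step:
Set $\theta^{(t+1)} = \arg\max_{\theta} Q(\theta \mid \theta^{(t)})$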
Output: Converged parameters θ*
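The algorithm box is model-agnostic. A minimal Python sketch of the same loop, assuming the model supplies three hypothetical callables (`e_step`, `m_step`, `log_likelihood`), might look like this; the toy usage estimates only the mixing weight of a two-component mixture whose component densities are assumed known.

```python
from typing import Any, Callable, List, Tuple

import numpy as np
from scipy.stats import norm


def run_em(
    theta0: Any,
    e_step: Callable[[Any], Any],            # theta -> posterior statistics of Z
    m_step: Callable[[Any], Any],            # posterior statistics -> new theta
    log_likelihood: Callable[[Any], float],  # theta -> observed-data log-likelihood
    max_iter: int = 200,
    tol: float = 1e-8,
) -> Tuple[Any, List[float]]:
    """Alternate E and M steps until the log-likelihood changes by less than tol."""
    theta = theta0
    history = [log_likelihood(theta)]
    for _ in range(max_iter):
        stats = e_step(theta)   # E-step: summarize p(Z | X, theta^(t))
        theta = m_step(stats)   # M-step: argmax_theta Q(theta | theta^(t))
        history.append(log_likelihood(theta))
        if abs(history[-1] - history[-2]) < tol:
            break
    return theta, history


# Toy usage: estimate p = P(z = 1) when both component densities are known.
rng = np.random.default_rng(0)
x = np.concatenate([rng.normal(-2, 1, 300), rng.normal(2, 1, 700)])
f0, f1 = norm.pdf(x, -2, 1), norm.pdf(x, 2, 1)


def e_step(p):         # responsibility of component 1 for each point
    return p * f1 / (p * f1 + (1 - p) * f0)


def m_step(r):         # weighted MLE of the mixing weight
    return float(np.mean(r))


def log_likelihood(p):
    return float(np.sum(np.log(p * f1 + (1 - p) * f0)))


p_hat, history = run_em(0.5, e_step, m_step, log_likelihood)
print(f"estimated p = {p_hat:.3f} after {len(history) - 1} iterations")
```

Passing the steps in as functions keeps the convergence logic separate from any particular model; the data were simulated with 70% of points in component 1.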
E-Step: Expectation
The E-step computes the Q-function: the expected complete-data log-likelihood, where the expectation is taken over the conditional distribution of the latent variables Z given the observed data X and current parameters θ⁽ᵗ⁾.
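$$Q(\theta \mid \theta^{(t)}) = \sum_{Z} p(Z \mid X, \theta^{(t)}) \log p(X, Z \mid \theta)$$
For continuous Z, replace the sum with an integral.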
In practice, this often means computing posterior responsibilities: the probability that each data point belongs to each latent class, given current parameters.
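One way to see what the E-step produces is to compute Q numerically. The sketch below, for a hypothetical 1-D two-component mixture, builds responsibilities under θ⁽ᵗ⁾ and checks the identity log p(X|θ⁽ᵗ⁾) = Q(θ⁽ᵗ⁾|θ⁽ᵗ⁾) + H(γ), where H is the entropy of the responsibilities. Because H does not depend on the candidate θ, the M-step can ignore it and maximize Q alone. The helper names are ours.

```python
import numpy as np
from scipy.special import logsumexp, xlogy
from scipy.stats import norm


def log_joint_1d(x, pi, mu, sigma):
    """log p(x_i, z_i = k | theta) for a 1-D Gaussian mixture, shape (n, K)."""
    return np.log(pi) + norm.logpdf(x[:, None], loc=mu, scale=sigma)


def q_function(x, gamma, pi, mu, sigma):
    """Q(theta | theta_t) = sum_i sum_k gamma_ik log p(x_i, z_i = k | theta).
    gamma holds posteriors computed under theta_t; (pi, mu, sigma) is the candidate."""
    return float(np.sum(gamma * log_joint_1d(x, pi, mu, sigma)))


# Hypothetical current parameters theta_t and a few data points.
x = np.array([-2.5, -1.0, 0.5, 2.0, 3.5, 4.0])
pi_t, mu_t, sigma_t = np.array([0.5, 0.5]), np.array([-1.0, 1.0]), np.array([1.0, 1.0])

# E-step: responsibilities gamma_ik = p(z_i = k | x_i, theta_t)
lj = log_joint_1d(x, pi_t, mu_t, sigma_t)
log_px = logsumexp(lj, axis=1)          # log p(x_i | theta_t)
gamma = np.exp(lj - log_px[:, None])

q_t = q_function(x, gamma, pi_t, mu_t, sigma_t)
entropy = -np.sum(xlogy(gamma, gamma))  # H(gamma), summed over points
print(f"log p(X | theta_t)       = {log_px.sum():.4f}")
print(f"Q(theta_t | theta_t) + H = {q_t + entropy:.4f}")
```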
M-Step: Maximization
The M-step maximizes the Q-function over θ. Because we've taken the expectation over Z, this is typically a weighted maximum likelihood problem that often has closed-form solutions.
Convergence Properties
EM has beautiful theoretical properties:
| Property | Description |
|---|---|
| Monotonicity | Each iteration increases (or maintains) the observed-data log-likelihood: ℓ(θ⁽ᵗ⁺¹⁾) ≥ ℓ(θ⁽ᵗ⁾) |
| Convergence | Under regularity conditions, EM converges to a stationary point of the likelihood |
| Local Maximum | The limit point is typically a local maximum (but not necessarily global!) |
| Linear Rate | Convergence is typically linear; can be slow near the optimum |
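The monotonicity and rate claims in the table can be checked empirically. This sketch runs plain EM on a hypothetical 1-D, two-component dataset from a deliberately poor starting point and records the observed-data log-likelihood before each update; `em_1d_gmm` is our own helper, not a library function.

```python
import numpy as np
from scipy.special import logsumexp
from scipy.stats import norm


def em_1d_gmm(x, mu, sigma, pi, n_iter=100):
    """Plain EM for a 1-D Gaussian mixture; returns the log-likelihood trace."""
    mu, sigma, pi = (np.asarray(a, dtype=float).copy() for a in (mu, sigma, pi))
    trace = []
    for _ in range(n_iter):
        log_joint = np.log(pi) + norm.logpdf(x[:, None], loc=mu, scale=sigma)
        log_px = logsumexp(log_joint, axis=1)
        trace.append(float(log_px.sum()))            # LL at the current parameters
        gamma = np.exp(log_joint - log_px[:, None])  # E-step
        n_k = gamma.sum(axis=0)                      # M-step (weighted MLEs)
        pi = n_k / len(x)
        mu = gamma.T @ x / n_k
        sigma = np.sqrt((gamma * (x[:, None] - mu) ** 2).sum(axis=0) / n_k)
    return trace, (pi, mu, sigma)


rng = np.random.default_rng(1)
x = np.concatenate([rng.normal(0, 1, 200), rng.normal(3, 1, 200)])
trace, (pi_hat, mu_hat, sigma_hat) = em_1d_gmm(x, mu=[-1.0, 0.0], sigma=[1.0, 1.0], pi=[0.5, 0.5])
diffs = np.diff(trace)
print(f"smallest per-iteration change: {diffs.min():.2e}")
print(f"last change: {diffs[-1]:.2e}")
```

On runs like this the changes are typically large in the first few iterations and then shrink steadily, consistent with a linear rate; the exact numbers depend on the data and the starting point.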
Canonical Example: Gaussian Mixture Models
The Gaussian Mixture Model (GMM) is the classic example for EM. Let's derive the algorithm step by step.
GMM Setup
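A K-component Gaussian mixture assumes each data point comes from one of K Gaussian distributions, so the marginal density is:
$$p(x \mid \theta) = \sum_{k=1}^{K} \pi_k \, \mathcal{N}(x \mid \mu_k, \Sigma_k)$$
GMM Generative Process
For each data point i = 1, ..., n:
- Draw cluster assignment: $z_i \sim \text{Categorical}(\pi_1, \ldots, \pi_K)$
- Draw observation: $x_i \mid z_i = k \sim \mathcal{N}(\mu_k, \Sigma_k)$
The parameters are $\theta = \{\pi_k, \mu_k, \Sigma_k\}_{k=1}^{K}$: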
- Mixing proportions: πₖ = P(zᵢ = k), where Σπₖ = 1
- Component means: μₖ ∈ ℝᵈ
- Component covariances: Σₖ ∈ ℝᵈˣᵈ (positive definite)
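The generative process translates directly into ancestral sampling: draw the label first, then the observation. A minimal sketch using NumPy's `Generator` API; the two-component, 2-D parameters below are hypothetical.

```python
import numpy as np


def sample_gmm(n, weights, means, covs, rng):
    """Ancestral sampling: z_i ~ Categorical(pi), then x_i | z_i = k ~ N(mu_k, Sigma_k)."""
    z = rng.choice(len(weights), size=n, p=weights)
    x = np.stack([rng.multivariate_normal(means[k], covs[k]) for k in z])
    return x, z


rng = np.random.default_rng(0)
weights = np.array([0.3, 0.7])
means = np.array([[0.0, 0.0], [5.0, 5.0]])
covs = np.array([np.eye(2), 0.5 * np.eye(2)])

x, z = sample_gmm(2000, weights, means, covs, rng)
print(x.shape, np.bincount(z) / len(z))  # empirical mixing proportions
```

Discarding `z` and keeping only `x` is exactly the situation EM addresses: the labels existed when the data were generated but are not observed.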
The latent variables are the cluster assignments z₁, ..., zₙ. If we knew them, MLE would be trivial!
E-Step for GMM
We compute the posterior responsibility — the probability that point i belongs to cluster k, given current parameters:
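E-Step: Compute Responsibilities
$$\gamma_{ik} = P(z_i = k \mid x_i, \theta^{(t)}) = \frac{\pi_k \, \mathcal{N}(x_i \mid \mu_k, \Sigma_k)}{\sum_{j=1}^{K} \pi_j \, \mathcal{N}(x_i \mid \mu_j, \Sigma_j)}$$
This is just Bayes' rule: (prior × likelihood) / evidence, evaluated at the current parameters.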
The responsibility γᵢₖ represents a "soft assignment" — instead of saying "point i belongs to cluster k", we say "point i belongs to cluster k with probability γᵢₖ".
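A vectorized E-step is a direct transcription of the responsibility formula, but evaluating the Gaussian densities directly can underflow to 0/0 for points far from every component. A common remedy, sketched here, is to work with log-densities and normalize with `scipy.special.logsumexp`; the function name `e_step` and the test parameters are illustrative.

```python
import numpy as np
from scipy.special import logsumexp
from scipy.stats import multivariate_normal


def e_step(X, weights, means, covs):
    """Return (gamma, log-likelihood) with
    gamma[i, k] = pi_k N(x_i | mu_k, Sigma_k) / sum_j pi_j N(x_i | mu_j, Sigma_j)."""
    log_joint = np.column_stack([
        np.log(weights[k]) + multivariate_normal.logpdf(X, mean=means[k], cov=covs[k])
        for k in range(len(weights))
    ])
    log_evidence = logsumexp(log_joint, axis=1, keepdims=True)  # log p(x_i | theta)
    return np.exp(log_joint - log_evidence), float(log_evidence.sum())


weights = np.array([0.5, 0.5])
means = np.array([[0.0, 0.0], [10.0, 0.0]])
covs = np.array([np.eye(2), np.eye(2)])
X = np.array([[1.0, 0.0], [9.0, 0.0], [200.0, 0.0]])  # last point is far from both
gamma, ll = e_step(X, weights, means, covs)
print(np.round(gamma, 3))
```

For the far point both densities are far below the smallest representable double, so a direct ratio of `pdf` values would be 0/0; in log space the point is still assigned, sensibly, to the nearer component.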
M-Step for GMM
Given responsibilities, we update parameters using weighted MLEs:
M-Step: Update Parameters
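Effective count for cluster k:
$$N_k = \sum_{i=1}^{n} \gamma_{ik}$$
Mixing proportions:
$$\pi_k^{\text{new}} = \frac{N_k}{n}$$
Component means (weighted average):
$$\mu_k^{\text{new}} = \frac{1}{N_k} \sum_{i=1}^{n} \gamma_{ik} \, x_i$$
Component covariances (weighted sample covariance, using the updated mean):
$$\Sigma_k^{\text{new}} = \frac{1}{N_k} \sum_{i=1}^{n} \gamma_{ik} \, (x_i - \mu_k^{\text{new}})(x_i - \mu_k^{\text{new}})^\top$$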
Notice how similar these are to standard MLEs — just with weights γᵢₖ instead of hard assignments!
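The four update formulas vectorize naturally. The sketch below uses a hypothetical helper `m_step`; the `reg` ridge term is our addition for numerical stability, not part of the MLE. The usage checks that, with one-hot responsibilities, it reduces to ordinary per-cluster sample statistics.

```python
import numpy as np


def m_step(X, gamma, reg=1e-6):
    """Weighted MLEs given responsibilities gamma of shape (n, K)."""
    n, d = X.shape
    n_k = gamma.sum(axis=0)                   # N_k = sum_i gamma_ik
    weights = n_k / n                         # pi_k = N_k / n
    means = (gamma.T @ X) / n_k[:, None]      # mu_k = (1/N_k) sum_i gamma_ik x_i
    covs = np.empty((len(n_k), d, d))
    for k in range(len(n_k)):
        diff = X - means[k]                   # uses the *updated* mean
        covs[k] = (gamma[:, k, None] * diff).T @ diff / n_k[k]
        covs[k] += reg * np.eye(d)            # optional ridge keeps Sigma_k positive definite
    return weights, means, covs


# Hard assignments as one-hot responsibilities: should match plain sample statistics.
rng = np.random.default_rng(0)
X = rng.normal(size=(50, 2))
labels = np.repeat([0, 1], 25)
gamma = np.eye(2)[labels]
w, m, c = m_step(X, gamma, reg=0.0)
print(w, m[0], X[:25].mean(axis=0))
```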
Theoretical Foundation
Evidence Lower Bound (ELBO)
Why does EM work? The key insight is that EM implicitly maximizes a lower bound on the log-likelihood, called the Evidence Lower Bound (ELBO).
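The ELBO Decomposition
For any distribution q(Z) over the latent variables:
$$\log p(X \mid \theta) = \underbrace{\mathbb{E}_{q}\left[\log p(X, Z \mid \theta)\right] - \mathbb{E}_{q}\left[\log q(Z)\right]}_{\text{ELBO}(q,\, \theta)} + \underbrace{\mathrm{KL}\left(q(Z) \,\|\, p(Z \mid X, \theta)\right)}_{\ge 0}$$
Since KL divergence ≥ 0, the ELBO is always ≤ the log-likelihood, with equality exactly when q(Z) = p(Z|X,θ).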
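The decomposition holds for any q, which is easy to confirm numerically. This sketch uses a single observation and a hypothetical two-component 1-D mixture, so q(Z) is just a probability vector over two labels.

```python
import numpy as np
from scipy.special import logsumexp
from scipy.stats import norm

# Hypothetical mixture parameters and one observation.
pi, mu, sigma = np.array([0.3, 0.7]), np.array([-1.0, 2.0]), np.array([1.0, 0.5])
x = 1.2

log_joint = np.log(pi) + norm.logpdf(x, loc=mu, scale=sigma)  # log p(x, z = k)
log_px = logsumexp(log_joint)                                 # log p(x)
posterior = np.exp(log_joint - log_px)                        # p(z | x)


def elbo(q):
    """E_q[log p(x, z) - log q(z)]."""
    return float(np.sum(q * (log_joint - np.log(q))))


def kl(q, p):
    """KL(q || p) for discrete distributions with full support."""
    return float(np.sum(q * (np.log(q) - np.log(p))))


q = np.array([0.5, 0.5])  # an arbitrary q(z)
print(f"log p(x)          = {log_px:.4f}")
print(f"ELBO(q) + KL      = {elbo(q) + kl(q, posterior):.4f}")
print(f"ELBO at posterior = {elbo(posterior):.4f}")
```

Setting q to the exact posterior, which is what the E-step does, makes the KL term vanish and the bound tight.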
The EM algorithm can be understood as coordinate ascent on the ELBO:
| Step | Action | Effect on ELBO |
|---|---|---|
| E-Step | Set q(Z) = p(Z|X,θ⁽ᵗ⁾) | KL = 0, so ELBO = log-likelihood (tight bound) |
| M-Step | Maximize ELBO over θ | Increases (or maintains) the ELBO; because the bound was tight at θ⁽ᵗ⁾, the log-likelihood cannot decrease |
Why EM Always Improves
Here's the key guarantee:
EM Monotonicity Theorem
Each EM iteration increases (or maintains) the observed-data log-likelihood.
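A short proof follows from the ELBO decomposition. Write 𝓛(q, θ) = E_q[log p(X, Z|θ)] − E_q[log q(Z)] for the ELBO and q⁽ᵗ⁾(Z) = p(Z|X, θ⁽ᵗ⁾) for the distribution chosen in the E-step. Since 𝓛(q⁽ᵗ⁾, θ) = Q(θ|θ⁽ᵗ⁾) + H(q⁽ᵗ⁾) and the entropy term does not depend on θ, maximizing Q in the M-step also maximizes 𝓛(q⁽ᵗ⁾, ·):

$$
\begin{aligned}
\ell(\theta^{(t+1)}) &\ge \mathcal{L}(q^{(t)}, \theta^{(t+1)}) && \text{(the ELBO is a lower bound for any } q\text{)} \\
&\ge \mathcal{L}(q^{(t)}, \theta^{(t)}) && \text{(the M-step maximizes } \mathcal{L}(q^{(t)}, \cdot)\text{)} \\
&= \ell(\theta^{(t)}) && \text{(KL} = 0 \text{ because } q^{(t)} = p(Z \mid X, \theta^{(t)})\text{)}
\end{aligned}
$$

The same chain goes through if the M-step merely increases Q rather than maximizing it, which is why the guarantee carries over to Generalized EM.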
Connection to Rao-Blackwell
The EM algorithm has a beautiful connection to the Rao-Blackwell theorem we studied in the previous section:
EM as Iterative Rao-Blackwell
Rao-Blackwell: Given any unbiased estimator T and a sufficient statistic S, the conditioned estimator E[T|S] is still unbiased and has variance no larger than that of T.
EM: Given current parameters θ⁽ᵗ⁾, the E-step computes conditional expectations E[·|X, θ⁽ᵗ⁾], effectively "Rao-Blackwellizing" the complete-data log-likelihood. The M-step then optimizes this improved objective.
Both techniques use conditional expectations to remove "noise" and improve estimation!
EM Variants
Generalized EM (GEM)
Instead of fully maximizing Q in the M-step, just find θ⁽ᵗ⁺¹⁾ such that Q(θ⁽ᵗ⁺¹⁾|θ⁽ᵗ⁾) ≥ Q(θ⁽ᵗ⁾|θ⁽ᵗ⁾); the monotonicity guarantee still holds. Useful when the M-step has no closed form.
Variational EM
When the E-step is intractable, approximate p(Z|X,θ) with a simpler distribution q(Z). Foundation of Variational Autoencoders (VAEs)!
Monte Carlo EM (MCEM)
When the expectations in the E-step are intractable, approximate them with Monte Carlo samples. Trade-off: more samples give a better approximation but a slower iteration.
Stochastic EM
Instead of computing expectations, sample Z from p(Z|X,θ⁽ᵗ⁾) and use these samples in the M-step. This adds noise but can help escape local optima.
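Stochastic EM is the simplest of these variants to sketch. Below, for a hypothetical 1-D Gaussian mixture, the soft responsibilities are replaced by one sampled label per point, and the M-step is ordinary hard-assignment MLE. The helper name and the `min_count` guard for (nearly) empty components are our own assumptions; MCEM would instead draw several samples per point and average over them.

```python
import numpy as np
from scipy.special import logsumexp
from scipy.stats import norm


def stochastic_em_step(x, pi, mu, sigma, rng, min_count=2):
    """One Stochastic EM iteration for a 1-D Gaussian mixture.
    S-step: draw z_i ~ p(z_i | x_i, theta); M-step: hard-assignment MLE."""
    log_joint = np.log(pi) + norm.logpdf(x[:, None], loc=mu, scale=sigma)
    gamma = np.exp(log_joint - logsumexp(log_joint, axis=1, keepdims=True))
    # Inverse-CDF sampling of one label per row of gamma.
    u = rng.random(len(x))[:, None]
    z = (u > np.cumsum(gamma, axis=1)).sum(axis=1)
    z = np.minimum(z, len(pi) - 1)  # guard against cumsum rounding slightly below 1
    new_pi, new_mu, new_sigma = pi.copy(), mu.copy(), sigma.copy()
    for k in range(len(pi)):
        xk = x[z == k]
        if len(xk) >= min_count:    # otherwise keep the old parameters for component k
            new_pi[k], new_mu[k], new_sigma[k] = len(xk) / len(x), xk.mean(), xk.std()
    return new_pi / new_pi.sum(), new_mu, new_sigma


rng = np.random.default_rng(0)
x = np.concatenate([rng.normal(-2, 1, 300), rng.normal(3, 1, 300)])
pi, mu, sigma = np.array([0.5, 0.5]), np.array([-1.0, 1.0]), np.array([1.0, 1.0])
for _ in range(50):
    pi, mu, sigma = stochastic_em_step(x, pi, mu, sigma, rng)
print(np.round(pi, 3), np.round(mu, 2), np.round(sigma, 2))
```

Unlike standard EM, the parameter sequence here keeps fluctuating around a solution instead of settling exactly, which is the injected noise the text refers to.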
AI/ML Applications
🎨 Variational Autoencoders (VAEs)
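VAEs can be viewed as amortized variational EM. Instead of running a separate E-step per data point, an encoder network predicts the approximate posterior q(z|x), and the decoder plays the role of the generative model whose parameters the M-step would fit. Unlike classical EM, encoder and decoder are trained jointly by stochastic gradient ascent on the ELBO rather than in alternating steps. The ELBO objective comes directly from EM theory!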
Topic Modeling (LDA)
Latent Dirichlet Allocation uses variational EM to discover topics in documents. The latent variables are each document's topic proportions and the topic assignment of each word.
Hidden Markov Models
The Baum-Welch algorithm is EM for HMMs. E-step computes forward-backward probabilities; M-step updates transition and emission probabilities.
Semi-supervised Learning
EM naturally handles partially labeled data: known labels are observed, missing labels are latent variables. E-step imputes missing labels.
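In code, the only change to the E-step is that labeled points get a one-hot responsibility instead of a computed posterior. A minimal sketch, assuming labels are stored as integers with `-1` marking an unlabeled point (a convention chosen for illustration; the function name is hypothetical):

```python
import numpy as np


def semi_supervised_responsibilities(gamma, labels):
    """Overwrite the posterior of labeled points with a one-hot vector.
    labels[i] is the known class of point i, or -1 if point i is unlabeled."""
    gamma = gamma.copy()
    labeled = labels >= 0
    gamma[labeled] = np.eye(gamma.shape[1])[labels[labeled]]
    return gamma


gamma = np.array([[0.6, 0.4], [0.2, 0.8], [0.5, 0.5]])  # E-step output from any model
labels = np.array([1, -1, 0])  # point 0 labeled 1, point 1 unlabeled, point 2 labeled 0
out = semi_supervised_responsibilities(gamma, labels)
print(out)
```

The clamped matrix is then passed to the usual M-step unchanged, so labeled points pull their own component's parameters with full weight while unlabeled points contribute fractionally.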
Image Segmentation
Treating pixel cluster assignments as latent variables, EM-based GMM provides principled segmentation with uncertainty quantification.
Python Implementation
```python
import warnings
from typing import Tuple

import numpy as np
from scipy.special import logsumexp
from scipy.stats import multivariate_normal


class GaussianMixtureEM:
    """
    Gaussian Mixture Model fitted with the EM algorithm.

    Implementation from scratch for educational purposes. Densities are
    handled in log space so that points far from every component do not
    underflow to 0/0 in the E-step.
    """

    def __init__(
        self,
        n_components: int = 3,
        max_iter: int = 100,
        tol: float = 1e-6,
        random_state: int = 42,
    ):
        """
        Initialize GMM.

        Parameters
        ----------
        n_components : int
            Number of mixture components (clusters)
        max_iter : int
            Maximum EM iterations
        tol : float
            Convergence tolerance on the change in total log-likelihood
        random_state : int
            Random seed for reproducibility
        """
        self.n_components = n_components
        self.max_iter = max_iter
        self.tol = tol
        self.random_state = random_state

        # Parameters (initialized in fit)
        self.weights_ = None      # π_k: mixing proportions
        self.means_ = None        # μ_k: component means
        self.covariances_ = None  # Σ_k: component covariances

        # Fitting results
        self.log_likelihoods_ = []
        self.converged_ = False
        self.n_iter_ = 0

    def _initialize_parameters(self, X: np.ndarray):
        """Place the initial means at randomly chosen data points (simpler than
        k-means++, which samples centers proportional to squared distance)."""
        rng = np.random.default_rng(self.random_state)
        n_samples, n_features = X.shape

        indices = rng.choice(n_samples, self.n_components, replace=False)
        self.means_ = X[indices].copy()

        # Identity covariances and equal mixing proportions
        self.covariances_ = np.array([np.eye(n_features) for _ in range(self.n_components)])
        self.weights_ = np.full(self.n_components, 1.0 / self.n_components)

    def _e_step(self, X: np.ndarray) -> Tuple[np.ndarray, float]:
        """
        E-Step: compute responsibilities (posterior probabilities).

        γ_ik = π_k N(x_i|μ_k,Σ_k) / Σ_j π_j N(x_i|μ_j,Σ_j)

        Returns
        -------
        responsibilities : array of shape (n_samples, n_components)
        log_likelihood : float
            Observed-data log-likelihood Σ_i log p(x_i | θ) at the current θ
        """
        # log_joint[i, k] = log π_k + log N(x_i | μ_k, Σ_k)
        log_joint = np.empty((X.shape[0], self.n_components))
        for k in range(self.n_components):
            log_joint[:, k] = np.log(self.weights_[k]) + multivariate_normal.logpdf(
                X, mean=self.means_[k], cov=self.covariances_[k], allow_singular=True
            )

        # Bayes' rule in log space: subtract log p(x_i | θ) = logsumexp_k log_joint[i, k]
        log_evidence = logsumexp(log_joint, axis=1, keepdims=True)
        responsibilities = np.exp(log_joint - log_evidence)
        return responsibilities, float(log_evidence.sum())

    def _m_step(self, X: np.ndarray, responsibilities: np.ndarray):
        """M-Step: update parameters using weighted MLE."""
        n_samples, n_features = X.shape

        # Effective number of points per component, N_k = Σ_i γ_ik (kept > 0)
        N_k = np.maximum(responsibilities.sum(axis=0), 1e-10)

        # π_k = N_k / n
        self.weights_ = N_k / n_samples

        # μ_k = (1/N_k) Σ_i γ_ik x_i
        self.means_ = (responsibilities.T @ X) / N_k[:, None]

        for k in range(self.n_components):
            # Σ_k = (1/N_k) Σ_i γ_ik (x_i - μ_k)(x_i - μ_k)^T, using the updated μ_k
            diff = X - self.means_[k]
            weighted_diff = responsibilities[:, k:k + 1] * diff
            self.covariances_[k] = (weighted_diff.T @ diff) / N_k[k]
            # Small ridge keeps Σ_k positive definite
            self.covariances_[k] += 1e-6 * np.eye(n_features)

    def _compute_log_likelihood(self, X: np.ndarray) -> float:
        """Compute observed-data log-likelihood."""
        return self._e_step(X)[1]

    def fit(self, X: np.ndarray) -> "GaussianMixtureEM":
        """
        Fit GMM using EM algorithm.

        Parameters
        ----------
        X : array of shape (n_samples, n_features)

        Returns
        -------
        self
        """
        X = np.asarray(X, dtype=float)  # integer input would truncate the means
        self._initialize_parameters(X)
        self.converged_ = False

        # Initial E-step also yields the initial log-likelihood
        responsibilities, prev_ll = self._e_step(X)
        self.log_likelihoods_ = [prev_ll]

        for iteration in range(self.max_iter):
            # M-Step: update parameters from the latest responsibilities
            self._m_step(X, responsibilities)

            # E-Step: new responsibilities plus log-likelihood of the new parameters
            responsibilities, curr_ll = self._e_step(X)
            self.log_likelihoods_.append(curr_ll)

            ll_change = curr_ll - prev_ll
            # EM never decreases the likelihood; tolerate tiny float/ridge noise
            if ll_change < -1e-8 * max(1.0, abs(prev_ll)):
                warnings.warn(f"Log-likelihood decreased at iteration {iteration}")

            if abs(ll_change) < self.tol:
                self.converged_ = True
                self.n_iter_ = iteration + 1
                break

            prev_ll = curr_ll

        if not self.converged_:
            self.n_iter_ = self.max_iter

        return self

    def predict(self, X: np.ndarray) -> np.ndarray:
        """Predict cluster labels (hard assignment)."""
        return np.argmax(self.predict_proba(X), axis=1)

    def predict_proba(self, X: np.ndarray) -> np.ndarray:
        """Predict cluster probabilities (soft assignment)."""
        return self._e_step(np.asarray(X, dtype=float))[0]


# ============================================
# Demonstration
# ============================================

def generate_gmm_data(
    n_samples: int = 500,
    random_state: int = 42,
) -> Tuple[np.ndarray, np.ndarray]:
    """Generate synthetic data from a fixed 3-component, 2-D GMM."""
    rng = np.random.default_rng(random_state)

    # True parameters
    true_means = np.array([[0, 0], [4, 4], [0, 5]], dtype=float)
    true_covs = [
        np.array([[1, 0.5], [0.5, 1]]),
        np.array([[1.5, 0], [0, 0.5]]),
        np.array([[0.5, 0], [0, 1.5]]),
    ]
    true_weights = [0.3, 0.4, 0.3]

    # Ancestral sampling: component label first, then the observation
    y = rng.choice(len(true_weights), size=n_samples, p=true_weights)
    X = np.array([rng.multivariate_normal(true_means[k], true_covs[k]) for k in y])
    return X, y


if __name__ == "__main__":
    print("=" * 60)
    print("EM ALGORITHM FOR GAUSSIAN MIXTURE MODELS")
    print("=" * 60)

    # Generate data
    X, true_labels = generate_gmm_data(n_samples=500)
    print(f"\nGenerated {len(X)} samples from 3-component GMM")

    # Fit GMM with EM
    gmm = GaussianMixtureEM(n_components=3, max_iter=100)
    gmm.fit(X)

    print(f"\nConverged: {gmm.converged_} in {gmm.n_iter_} iterations")
    print(f"Final log-likelihood: {gmm.log_likelihoods_[-1]:.2f}")

    # Show learned parameters
    print("\nLearned mixing proportions:")
    for k, w in enumerate(gmm.weights_):
        print(f"  Component {k}: {w:.3f}")

    print("\nLearned means:")
    for k, mu in enumerate(gmm.means_):
        print(f"  Component {k}: [{mu[0]:.2f}, {mu[1]:.2f}]")

    # Predict labels
    pred_labels = gmm.predict(X)

    # Note: component indices are arbitrary (label switching), so comparing
    # with true_labels requires matching clusters up to a permutation
    print("\nPredicted label distribution:")
    for k in range(gmm.n_components):
        print(f"  Cluster {k}: {(pred_labels == k).sum()} points")

    # Show likelihood progression
    print("\nLog-likelihood progression (first 10 iterations):")
    for i, ll in enumerate(gmm.log_likelihoods_[:10]):
        print(f"  Iteration {i}: {ll:.2f}")
```
Summary
Key Takeaways
- The EM algorithm is a general method for finding maximum likelihood estimates when data has latent (hidden) variables or missing values.
- E-step computes the expected complete-data log-likelihood given current parameters — essentially "filling in" missing data with expected values.
- M-step maximizes this expectation over parameters — typically a weighted MLE problem with closed-form solutions.
- EM never decreases the observed-data likelihood and, under regularity conditions, converges to a stationary point (typically a local maximum, not necessarily the global one).
- Connection to ELBO: EM is coordinate ascent on the Evidence Lower Bound, making it a special case of variational inference.
- Deep learning applications: VAEs, topic models, HMMs, and semi-supervised learning all use EM or EM-inspired techniques.
Looking Ahead: In Chapter 13, we'll explore Interval Estimation — moving beyond point estimates to quantify uncertainty through confidence intervals, bootstrap methods, and Bayesian credible intervals.