Chapter 12

EM Algorithm

Methods of Estimation

Learning Objectives

By the end of this section, you will be able to:

📚 Core Knowledge

  • Explain the E-step and M-step of the EM algorithm
  • Derive the EM updates for Gaussian Mixture Models
  • Understand why EM never decreases the likelihood
  • Connect EM to the Evidence Lower Bound (ELBO)

🔧 Practical Skills

  • Implement EM for mixture models from scratch
  • Diagnose convergence issues in EM
  • Choose appropriate initialization strategies
  • Apply EM to real-world clustering problems

🧠 Deep Learning Connections

  • Variational Autoencoders (VAEs): EM inspiration for amortized inference
  • Latent Variable Models: Foundation for generative modeling
  • Semi-supervised Learning: Handling partially labeled data
  • Missing Data Imputation: Principled approaches to incomplete datasets
Where You'll Apply This: Clustering with GMM, topic modeling (LDA), hidden Markov models, image segmentation, semi-supervised learning, and any scenario with latent (hidden) variables or missing data.

The Big Picture

The Expectation-Maximization (EM) algorithm, introduced by Dempster, Laird, and Rubin in their landmark 1977 paper, is one of the most influential algorithms in statistics and machine learning. It provides an elegant solution to a fundamental problem: How do we find maximum likelihood estimates when some data is missing or hidden?

Historical Context

Special cases of the method had been used earlier, but Dempster, Laird & Rubin (1977) unified them under the EM framework, showed that an iteration can never decrease the likelihood, and studied the algorithm's convergence. The paper remains one of the most cited in statistics, with tens of thousands of citations.

EM is particularly powerful because it:

  • Handles latent variables: Hidden cluster assignments, missing observations, unobserved states
  • Always improves: Each iteration is guaranteed to increase (or maintain) the likelihood
  • Simple to implement: Often reduces complex optimization to closed-form updates
  • Connects theory and practice: Bridges MLE, Bayesian inference, and information theory

The Missing Data Problem

Consider a dataset where some information is hidden or latent. This happens in many real scenarios:

Mixture Models

We observe data points but don't know which cluster (component) each point belongs to. The cluster assignments are latent variables.

Missing Values

Some entries in our dataset are missing (sensors failed, survey questions skipped). We want to estimate parameters despite incomplete data.

Hidden Markov Models

We observe emissions but the underlying states are hidden. In speech recognition, we hear sounds but don't directly observe phoneme states.

Semi-supervised Learning

Some data points have labels, others don't. The missing labels are latent variables we want to infer while learning the model.

The core challenge is that the complete-data likelihood (if we knew all the hidden information) would be easy to maximize, but we only have access to the observed-data likelihood, which marginalizes over the hidden variables and is often intractable.

The Likelihood Challenge

Complete-data log-likelihood (if we knew Z):

$$\log L(\theta; X, Z) = \log p(X, Z | \theta)$$

Often simple — exponential family structure

Observed-data log-likelihood (marginalizing Z):

$$\log L(\theta; X) = \log \sum_Z p(X, Z | \theta)$$

Often intractable — log of sum is hard!
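
To make the contrast concrete, here is a minimal sketch for a two-component 1D Gaussian mixture. The parameter values, the toy data, and the function names are illustrative assumptions, not part of the text above. The complete-data version is a plain sum of per-point terms, while the observed-data version needs a log-sum-exp over components for every point.

```python
import numpy as np
from scipy.special import logsumexp
from scipy.stats import norm

# Assumed parameters of a 2-component 1D mixture (illustrative only)
weights = np.array([0.4, 0.6])
means = np.array([-2.0, 3.0])
stds = np.array([1.0, 1.5])

x = np.array([-2.5, -1.0, 2.0, 4.0])  # observed data
z = np.array([0, 0, 1, 1])            # latent labels (unknown in practice)

def complete_data_loglik(x, z):
    """log p(X, Z | theta): each point touches only its own component."""
    return np.sum(np.log(weights[z]) + norm.logpdf(x, means[z], stds[z]))

def observed_data_loglik(x):
    """log p(X | theta): each point needs a log of a sum over components."""
    # log_joint[i, k] = log pi_k + log N(x_i | mu_k, sigma_k)
    log_joint = np.log(weights) + norm.logpdf(x[:, None], means, stds)
    return np.sum(logsumexp(log_joint, axis=1))

ll_complete = complete_data_loglik(x, z)
ll_observed = observed_data_loglik(x)
print(f"complete-data: {ll_complete:.3f}, observed-data: {ll_observed:.3f}")
```

Because p(xᵢ) sums p(xᵢ, zᵢ = k) over all k, the observed-data value can never be smaller than the complete-data value for any fixed labeling.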


EM Algorithm Intuition

The Chicken-and-Egg Problem

Consider clustering data with a Gaussian Mixture Model. We face a fundamental dilemma:

If we knew the assignments...

Estimating cluster parameters (means, variances) would be trivial — just compute sample statistics for each cluster separately!

If we knew the parameters...

Assigning points to clusters would be easy — just assign each point to the cluster with highest probability!

But we know neither! 🐔🥚

The EM Solution

EM breaks this circularity by alternating between two steps:

E: Expectation Step

Given current parameter estimates, compute the posterior distribution of the latent variables and the expected values it implies. Instead of hard assignments, we compute soft responsibilities: probabilities that each point belongs to each cluster.

$$Q(\theta | \theta^{(t)}) = E_{Z|X,\theta^{(t)}}[\log p(X, Z | \theta)]$$

M: Maximization Step

Given the expected latent variables, update the parameters by maximizing the expected complete-data log-likelihood. This is usually a weighted MLE problem with closed-form solutions.

$$\theta^{(t+1)} = \arg\max_\theta Q(\theta | \theta^{(t)})$$

Key Insight: EM turns an intractable optimization into a sequence of tractable ones. By "filling in" the missing data with expected values under the current posterior, we can reuse standard MLE techniques. Each iteration is guaranteed not to decrease the observed-data likelihood, and under mild regularity conditions the iterates converge to a stationary point, typically a local maximum.

The EM Algorithm

EM Algorithm

Input: Observed data X, initial parameters θ⁽⁰⁾

Repeat until convergence:

E-Step:

Compute $Q(\theta | \theta^{(t)}) = E_{Z|X,\theta^{(t)}}[\log p(X, Z | \theta)]$

M-Step:

Set $\theta^{(t+1)} = \arg\max_\theta Q(\theta | \theta^{(t)})$

Output: Converged parameters θ*
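
The loop above can be sketched on a very small latent-variable problem: several batches of coin flips, each batch produced by one of two coins with unknown head probabilities, where the identity of the coin is hidden. This setup, the equal 50/50 prior over coins, the data, and the function name `em_two_coins` are all assumptions chosen for illustration.

```python
import numpy as np
from scipy.stats import binom

def em_two_coins(heads, n_flips, theta_init, max_iter=200, tol=1e-8):
    """EM for an equal-weight mixture of two coins with unknown biases."""
    theta = np.asarray(theta_init, dtype=float)
    prev_ll = -np.inf
    ll = prev_ll
    for _ in range(max_iter):
        # E-step: posterior probability that each batch came from each coin
        lik = binom.pmf(heads[:, None], n_flips, theta)  # shape (batches, 2)
        resp = lik / lik.sum(axis=1, keepdims=True)

        # Observed-data log-likelihood at the parameters used in this E-step
        ll = np.sum(np.log(0.5 * lik.sum(axis=1)))
        if ll - prev_ll < tol:
            break
        prev_ll = ll

        # M-step: responsibility-weighted fraction of heads for each coin
        theta = (resp * heads[:, None]).sum(axis=0) / (resp.sum(axis=0) * n_flips)
    return theta, ll

heads = np.array([5, 9, 8, 4, 7])  # heads observed in each batch of 10 flips
theta_hat, final_ll = em_two_coins(heads, n_flips=10, theta_init=(0.6, 0.5))
print(theta_hat, final_ll)
```

The stopping rule watches the change in observed-data log-likelihood, which is the quantity EM is guaranteed not to decrease.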

E-Step: Expectation

The E-step computes the Q-function: the expected complete-data log-likelihood, where the expectation is taken over the conditional distribution of the latent variables Z given the observed data X and current parameters θ⁽ᵗ⁾.

$$Q(\theta | \theta^{(t)}) = \sum_Z p(Z | X, \theta^{(t)}) \log p(X, Z | \theta)$$

For continuous Z, replace sum with integral.

In practice, this often means computing posterior responsibilities: the probability that each data point belongs to each latent class, given current parameters.
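
As a worked special case, assume the data are i.i.d. pairs (xᵢ, zᵢ) with a discrete latent class zᵢ ∈ {1, ..., K}. Under that assumption the sum over all joint configurations Z collapses to a double sum over points and classes, weighted by the responsibilities (written here as γ, a symbol the GMM section later uses for the same quantity):

```latex
Q(\theta \mid \theta^{(t)})
  = \sum_{i=1}^{n} \sum_{k=1}^{K} \gamma_{ik}^{(t)} \, \log p(x_i, z_i = k \mid \theta),
\qquad
\gamma_{ik}^{(t)} = p(z_i = k \mid x_i, \theta^{(t)})
```

So the E-step only has to store an n × K table of responsibilities rather than a distribution over all Kⁿ labelings.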

M-Step: Maximization

The M-step maximizes the Q-function over θ. Because we've taken the expectation over Z, this is typically a weighted maximum likelihood problem that often has closed-form solutions.

$$\theta^{(t+1)} = \arg\max_\theta Q(\theta | \theta^{(t)})$$

Why M-Step is Often Easy: The complete-data likelihood often belongs to an exponential family, for which the MLE has a closed-form solution in terms of sufficient statistics. The E-step just provides "soft counts" or weights for these solutions.
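
A quick numerical check of the weighted-MLE claim, as a sketch: for a single Gaussian with per-point weights (standing in for one component's responsibilities), the weighted mean and weighted variance should beat nearby parameter values on the weighted log-likelihood. The data, weights, and helper name below are made up for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(loc=2.0, scale=1.0, size=200)
w = rng.uniform(0.0, 1.0, size=200)  # stand-in for one component's responsibilities

def weighted_gaussian_loglik(mu, var):
    """Sum over points of w_i * log N(x_i | mu, var)."""
    return np.sum(w * (-0.5 * np.log(2 * np.pi * var) - 0.5 * (x - mu) ** 2 / var))

# Closed-form weighted MLE: the "soft count" versions of mean and variance
mu_hat = np.sum(w * x) / np.sum(w)
var_hat = np.sum(w * (x - mu_hat) ** 2) / np.sum(w)

best = weighted_gaussian_loglik(mu_hat, var_hat)
perturbed = [
    weighted_gaussian_loglik(mu_hat + d_mu, var_hat * scale)
    for d_mu in (-0.1, 0.1)
    for scale in (0.9, 1.1)
]
print(best, max(perturbed))
```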

Convergence Properties

EM has beautiful theoretical properties:

| Property | Description |
|---|---|
| Monotonicity | Each iteration increases (or maintains) the observed-data log-likelihood: ℓ(θ⁽ᵗ⁺¹⁾) ≥ ℓ(θ⁽ᵗ⁾) |
| Convergence | Under regularity conditions, EM converges to a stationary point of the likelihood |
| Local Maximum | The limit point is typically a local maximum (but not necessarily global!) |
| Linear Rate | Convergence is typically linear; can be slow near the optimum |

EM finds local, not global, optima! Different initializations may converge to different solutions. For mixture models, it's common practice to run EM multiple times with different random initializations and keep the best result.
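
The restart strategy can be sketched as follows. The compact 1D fitter `fit_gmm_1d`, the synthetic data, and the choice of ten seeds are assumptions made for this illustration; the only point is the final line, which keeps the run with the highest log-likelihood.

```python
import numpy as np
from scipy.special import logsumexp
from scipy.stats import norm

def fit_gmm_1d(x, k, seed, n_iter=200):
    """Tiny 1D GMM fitted by EM from one random start; returns (log-lik, means)."""
    rng = np.random.default_rng(seed)
    means = rng.choice(x, size=k, replace=False)  # random data points as centers
    stds = np.full(k, x.std())
    weights = np.full(k, 1.0 / k)
    for _ in range(n_iter):
        # E-step in log space
        log_joint = np.log(weights) + norm.logpdf(x[:, None], means, stds)
        resp = np.exp(log_joint - logsumexp(log_joint, axis=1, keepdims=True))
        # M-step: weighted MLEs, with small floors for numerical safety
        n_k = resp.sum(axis=0) + 1e-12
        weights = n_k / n_k.sum()
        means = (resp * x[:, None]).sum(axis=0) / n_k
        stds = np.sqrt((resp * (x[:, None] - means) ** 2).sum(axis=0) / n_k) + 1e-3
    log_joint = np.log(weights) + norm.logpdf(x[:, None], means, stds)
    return logsumexp(log_joint, axis=1).sum(), np.sort(means)

data_rng = np.random.default_rng(1)
x = np.concatenate([
    data_rng.normal(-4, 1, 150),
    data_rng.normal(0, 1, 150),
    data_rng.normal(5, 1, 150),
])

# Run EM from several seeds and keep the solution with the best log-likelihood
runs = [fit_gmm_1d(x, k=3, seed=s) for s in range(10)]
best_ll, best_means = max(runs, key=lambda r: r[0])
print(best_ll, best_means)
```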

Canonical Example: Gaussian Mixture Models

The Gaussian Mixture Model (GMM) is the classic example for EM. Let's derive the algorithm step by step.

GMM Setup

A K-component Gaussian mixture assumes each data point is drawn from one of K Gaussian distributions, chosen at random according to the mixing proportions:

GMM Generative Process

For each data point i = 1, ..., n:

  1. Draw cluster assignment: $z_i \sim \text{Categorical}(\pi_1, \ldots, \pi_K)$
  2. Draw observation: $x_i | z_i = k \sim \mathcal{N}(\mu_k, \Sigma_k)$

The parameters are $\theta = \{\pi_1, \ldots, \pi_K, \mu_1, \ldots, \mu_K, \Sigma_1, \ldots, \Sigma_K\}$:

  • Mixing proportions: πₖ = P(zᵢ = k), where the πₖ are non-negative and sum to 1 over k
  • Component means: μₖ ∈ ℝᵈ
  • Component covariances: Σₖ ∈ ℝᵈˣᵈ (positive definite)

The latent variables are the cluster assignments z₁, ..., zₙ. If we knew them, MLE would be trivial!

E-Step for GMM

We compute the posterior responsibility — the probability that point i belongs to cluster k, given current parameters:

E-Step: Compute Responsibilities

$$\gamma_{ik} = p(z_i = k | x_i, \theta^{(t)}) = \frac{\pi_k^{(t)} \mathcal{N}(x_i | \mu_k^{(t)}, \Sigma_k^{(t)})}{\sum_{j=1}^K \pi_j^{(t)} \mathcal{N}(x_i | \mu_j^{(t)}, \Sigma_j^{(t)})}$$

This is just Bayes' rule: (prior × likelihood) / evidence

The responsibility γᵢₖ represents a "soft assignment" — instead of saying "point i belongs to cluster k", we say "point i belongs to cluster k with probability γᵢₖ".
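
In floating point, the densities in this ratio can underflow to zero for points far from every component, which turns the normalization into 0/0. One common workaround, sketched here under assumed toy parameters, is to compute the same Bayes-rule ratio in log space with `scipy.special.logsumexp`. The function name `responsibilities_logspace` is hypothetical.

```python
import numpy as np
from scipy.special import logsumexp
from scipy.stats import multivariate_normal

def responsibilities_logspace(X, weights, means, covs):
    """gamma_ik computed as exp(log pi_k + log N_k(x_i) - log sum_j(...))."""
    n, k = X.shape[0], len(weights)
    log_joint = np.empty((n, k))
    for j in range(k):
        log_joint[:, j] = np.log(weights[j]) + multivariate_normal.logpdf(
            X, mean=means[j], cov=covs[j]
        )
    log_evidence = logsumexp(log_joint, axis=1, keepdims=True)
    return np.exp(log_joint - log_evidence)

# Assumed toy setup: the second point is very far from both components
X = np.array([[0.0, 0.0], [50.0, 50.0]])
weights = np.array([0.5, 0.5])
means = np.array([[0.0, 0.0], [4.0, 4.0]])
covs = np.array([np.eye(2), np.eye(2)])

naive_density = multivariate_normal.pdf(X[1], mean=means[1], cov=covs[1])
gamma = responsibilities_logspace(X, weights, means, covs)
print(naive_density)  # underflows in double precision
print(gamma)
```

The far point still gets a well-defined soft assignment to the nearer component even though its raw density underflows.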

M-Step for GMM

Given responsibilities, we update parameters using weighted MLEs:

M-Step: Update Parameters

Effective count for cluster k:

$$N_k = \sum_{i=1}^n \gamma_{ik}$$

Mixing proportions:

$$\pi_k^{(t+1)} = \frac{N_k}{n}$$

Component means (weighted average):

$$\mu_k^{(t+1)} = \frac{1}{N_k} \sum_{i=1}^n \gamma_{ik} x_i$$

Component covariances (weighted sample covariance):

$$\Sigma_k^{(t+1)} = \frac{1}{N_k} \sum_{i=1}^n \gamma_{ik} (x_i - \mu_k^{(t+1)})(x_i - \mu_k^{(t+1)})^T$$

Notice how similar these are to standard MLEs — just with weights γᵢₖ instead of hard assignments!
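
The four updates map almost line for line onto array operations. The following is a vectorized sketch (the function name `m_step` and the optional `reg` ridge term are assumptions, not part of the derivation). With one-hot responsibilities it reduces to ordinary per-cluster sample means and covariances, which makes a convenient sanity check.

```python
import numpy as np

def m_step(X, resp, reg=1e-6):
    """Weighted-MLE updates for all K components at once."""
    n, d = X.shape
    n_k = resp.sum(axis=0)                     # effective counts, shape (K,)
    weights = n_k / n                          # mixing proportions
    means = (resp.T @ X) / n_k[:, None]        # weighted averages, shape (K, d)
    diff = X[None, :, :] - means[:, None, :]   # shape (K, n, d)
    # Weighted outer products summed over points, shape (K, d, d)
    covs = np.einsum("kn,kni,knj->kij", resp.T, diff, diff) / n_k[:, None, None]
    covs += reg * np.eye(d)                    # small ridge for numerical stability
    return weights, means, covs

# Sanity check with hard (one-hot) responsibilities on made-up data
rng = np.random.default_rng(0)
X_demo = rng.normal(size=(6, 2))
labels = np.array([0, 0, 0, 1, 1, 1])
resp_hard = np.eye(2)[labels]

weights, means, covs = m_step(X_demo, resp_hard, reg=0.0)
print(weights)
print(means)
```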


Theoretical Foundation

Evidence Lower Bound (ELBO)

Why does EM work? The key insight is that EM implicitly maximizes a lower bound on the log-likelihood, called the Evidence Lower Bound (ELBO).

The ELBO Decomposition

$$\log p(X|\theta) = \mathcal{L}(q, \theta) + \text{KL}(q(Z) \| p(Z|X,\theta))$$

$$\mathcal{L}(q, \theta) = E_q[\log p(X, Z|\theta)] - E_q[\log q(Z)]$$

Since KL divergence ≥ 0, the ELBO is always ≤ log-likelihood

The EM algorithm can be understood as coordinate ascent on the ELBO:

| Step | Action | Effect on ELBO |
|---|---|---|
| E-Step | Set q(Z) = p(Z ∣ X, θ⁽ᵗ⁾) | KL = 0, so ELBO = log-likelihood (tight bound) |
| M-Step | Maximize ELBO over θ | Increases ELBO (and thus log-likelihood) |
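
The decomposition is easy to verify numerically for a single observation and a two-valued latent variable. The mixture parameters, the observation `x`, and the arbitrary distribution `q_arbitrary` below are assumed values chosen only for this check.

```python
import numpy as np
from scipy.stats import norm

# Assumed 2-component 1D mixture and a single observation
weights = np.array([0.3, 0.7])
means = np.array([-1.0, 2.0])
stds = np.array([1.0, 1.0])
x = 0.5

joint = weights * norm.pdf(x, means, stds)  # p(x, z = k | theta)
log_evidence = np.log(joint.sum())          # log p(x | theta)
posterior = joint / joint.sum()             # p(z = k | x, theta)

def elbo(q):
    """E_q[log p(x, z)] - E_q[log q(z)] for a discrete q."""
    return np.sum(q * (np.log(joint) - np.log(q)))

def kl(q, p):
    """KL(q || p) for discrete distributions with full support."""
    return np.sum(q * (np.log(q) - np.log(p)))

q_arbitrary = np.array([0.5, 0.5])
gap = log_evidence - elbo(q_arbitrary)

print(elbo(q_arbitrary) + kl(q_arbitrary, posterior), log_evidence)
print(elbo(posterior), log_evidence)  # E-step choice: bound is tight
```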

Why EM Always Improves

Here's the key guarantee:

EM Monotonicity Theorem

$$\log p(X|\theta^{(t+1)}) \geq \log p(X|\theta^{(t)})$$

Each EM iteration increases (or maintains) the observed-data log-likelihood.
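
The guarantee follows in three short steps from the ELBO decomposition. Write q⁽ᵗ⁾(Z) = p(Z|X, θ⁽ᵗ⁾) for the distribution chosen in the E-step (the superscript notation on q is introduced here for convenience):

```latex
\begin{align*}
\log p(X \mid \theta^{(t+1)})
  &\geq \mathcal{L}\big(q^{(t)}, \theta^{(t+1)}\big)
  && \text{(KL} \geq 0 \text{, so the ELBO is a lower bound)} \\
  &\geq \mathcal{L}\big(q^{(t)}, \theta^{(t)}\big)
  && \text{(the M-step maximizes } \mathcal{L}(q^{(t)}, \theta) \text{ over } \theta \text{)} \\
  &= \log p(X \mid \theta^{(t)})
  && \text{(the E-step makes the bound tight at } \theta^{(t)} \text{)}
\end{align*}
```

The middle step uses the fact that the ELBO with q⁽ᵗ⁾ held fixed equals Q(θ|θ⁽ᵗ⁾) plus the entropy of q⁽ᵗ⁾, which does not depend on θ, so maximizing Q and maximizing the ELBO over θ are the same problem.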

Connection to Rao-Blackwell

The EM algorithm has a beautiful connection to the Rao-Blackwell theorem we studied in the previous section:

EM as Iterative Rao-Blackwell

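Rao-Blackwell: Given any unbiased estimator T, conditioning on a sufficient statistic S gives an estimator E[T|S] that is still unbiased and has variance no larger than that of T.

EM: Given current parameters θ⁽ᵗ⁾, the E-step computes conditional expectations E[·|X, θ⁽ᵗ⁾], loosely "Rao-Blackwellizing" the complete-data log-likelihood by averaging out the unobserved Z. The M-step then optimizes this averaged objective. The parallel is an analogy rather than a formal equivalence: EM conditions on the observed data, not on a sufficient statistic.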

Both techniques use conditional expectations to remove "noise" and improve estimation!


EM Variants

Generalized EM (GEM)

Instead of fully maximizing Q in the M-step, just find θ⁽ᵗ⁺¹⁾ such that Q(θ⁽ᵗ⁺¹⁾|θ⁽ᵗ⁾) ≥ Q(θ⁽ᵗ⁾|θ⁽ᵗ⁾), for example with a single gradient step. The monotonicity guarantee still holds. Useful when the M-step has no closed form.

Variational EM

When E-step is intractable, approximate p(Z|X,θ) with a simpler distribution q(Z). Foundation of Variational Autoencoders (VAEs)!

Monte Carlo EM (MCEM)

When expectations in E-step are intractable, use Monte Carlo sampling to approximate them. Trade-off: more samples = better approximation but slower.

Stochastic EM

Instead of computing expectations, sample Z from p(Z|X,θ⁽ᵗ⁾) and use these samples in M-step. Adds noise but can help escape local optima.
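
The difference between the exact E-step, Stochastic EM, and Monte Carlo EM can be sketched on a fixed table of responsibilities. The table `resp` and both helper names are hypothetical. Stochastic EM would pass one sampled labeling to the M-step, while MCEM averages many samples to approximate the expectation.

```python
import numpy as np

rng = np.random.default_rng(0)

# Assumed output of an exact E-step for three points and two components
resp = np.array([[0.9, 0.1], [0.4, 0.6], [0.2, 0.8]])

def sample_assignments(resp, rng):
    """Draw one hard label per point from its responsibility row (Stochastic EM)."""
    cum = resp.cumsum(axis=1)
    u = rng.random((resp.shape[0], 1))
    labels = (u > cum).sum(axis=1)
    return np.minimum(labels, resp.shape[1] - 1)  # guard against rounding in cumsum

def monte_carlo_resp(resp, rng, n_draws):
    """Approximate the responsibilities by averaging one-hot samples (MCEM)."""
    counts = np.zeros_like(resp)
    rows = np.arange(resp.shape[0])
    for _ in range(n_draws):
        counts[rows, sample_assignments(resp, rng)] += 1
    return counts / n_draws

one_sample = sample_assignments(resp, rng)
approx_resp = monte_carlo_resp(resp, rng, n_draws=5000)
print(one_sample)
print(approx_resp)
```

More draws shrink the Monte Carlo error at the cost of more computation per iteration, which is the trade-off noted above.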


AI/ML Applications

🎨 Variational Autoencoders (VAEs)

VAEs can be viewed as amortized variational EM. Instead of running a separate E-step per data point, an encoder network predicts the approximate posterior q(z|x), and training the decoder network plays the role of the M-step. In practice both networks are updated jointly by stochastic gradient ascent on the ELBO, the same lower bound that underlies EM:

$$\mathcal{L} = E_{q(z|x)}[\log p(x|z)] - \text{KL}(q(z|x) \| p(z))$$
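
For the common modeling choice of a diagonal-Gaussian q(z|x) and a standard-normal prior p(z) (an assumption here, since the text does not fix these distributions), the KL term has a closed form. This sketch compares it with a Monte Carlo estimate built from reparameterized samples; the function name and the example values are illustrative.

```python
import numpy as np

def kl_diag_gaussian_to_standard_normal(mu, log_var):
    """KL( N(mu, diag(exp(log_var))) || N(0, I) ) in closed form."""
    return 0.5 * np.sum(mu**2 + np.exp(log_var) - 1.0 - log_var)

# Assumed encoder outputs for one input x
mu = np.array([0.5, -1.0])
log_var = np.array([-0.5, 0.2])

kl_closed = kl_diag_gaussian_to_standard_normal(mu, log_var)

# Monte Carlo check: KL = E_q[log q(z) - log p(z)], with z = mu + sigma * eps
rng = np.random.default_rng(0)
eps = rng.standard_normal((200_000, 2))
z = mu + np.exp(0.5 * log_var) * eps
log_q_minus_log_p = np.sum(-0.5 * log_var - 0.5 * eps**2 + 0.5 * z**2, axis=1)
kl_monte_carlo = log_q_minus_log_p.mean()

print(kl_closed, kl_monte_carlo)
```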

Topic Modeling (LDA)

Latent Dirichlet Allocation uses variational EM to discover topics in documents. Latent variables are topic assignments for each word.

Hidden Markov Models

The Baum-Welch algorithm is EM for HMMs. E-step computes forward-backward probabilities; M-step updates transition and emission probabilities.
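
The E-step of Baum-Welch can be sketched for a small discrete HMM as below. The transition and emission values are made up, and `forward_backward` is a hypothetical helper that uses the standard scaled recursions. It returns the posterior state marginals that the M-step would use as soft counts.

```python
import numpy as np

def forward_backward(obs, start, trans, emit):
    """Posterior marginals p(s_t | obs) and log p(obs) for a discrete HMM."""
    T, S = len(obs), len(start)
    alpha = np.zeros((T, S))
    beta = np.ones((T, S))
    scale = np.zeros(T)

    # Forward pass, normalizing each step to avoid underflow
    alpha[0] = start * emit[:, obs[0]]
    scale[0] = alpha[0].sum()
    alpha[0] /= scale[0]
    for t in range(1, T):
        alpha[t] = (alpha[t - 1] @ trans) * emit[:, obs[t]]
        scale[t] = alpha[t].sum()
        alpha[t] /= scale[t]

    # Backward pass, reusing the same scaling factors
    for t in range(T - 2, -1, -1):
        beta[t] = trans @ (emit[:, obs[t + 1]] * beta[t + 1]) / scale[t + 1]

    gamma = alpha * beta  # each row sums to 1
    return gamma, np.log(scale).sum()

# Assumed 2-state, 2-symbol HMM
start = np.array([0.6, 0.4])
trans = np.array([[0.7, 0.3], [0.4, 0.6]])
emit = np.array([[0.9, 0.1], [0.2, 0.8]])
obs = [0, 1, 0]

gamma, loglik = forward_backward(obs, start, trans, emit)
print(gamma)
print(loglik)
```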

Semi-supervised Learning

EM naturally handles partially labeled data: known labels are observed, missing labels are latent variables. E-step imputes missing labels.

Image Segmentation

Treating pixel cluster assignments as latent variables, EM-based GMM provides principled segmentation with uncertainty quantification.


Python Implementation

```python
import warnings
from typing import Tuple

import numpy as np
from scipy.stats import multivariate_normal


class GaussianMixtureEM:
    """
    Gaussian Mixture Model fitted with the EM algorithm.

    Implementation from scratch for educational purposes.
    """

    def __init__(
        self,
        n_components: int = 3,
        max_iter: int = 100,
        tol: float = 1e-6,
        random_state: int = 42
    ):
        """
        Initialize GMM.

        Parameters
        ----------
        n_components : int
            Number of mixture components (clusters)
        max_iter : int
            Maximum EM iterations
        tol : float
            Convergence tolerance for log-likelihood
        random_state : int
            Random seed for reproducibility
        """
        self.n_components = n_components
        self.max_iter = max_iter
        self.tol = tol
        self.random_state = random_state

        # Parameters (initialized in fit)
        self.weights_ = None      # π_k: mixing proportions
        self.means_ = None        # μ_k: component means
        self.covariances_ = None  # Σ_k: component covariances

        # Fitting results
        self.log_likelihoods_ = []
        self.converged_ = False
        self.n_iter_ = 0

    def _initialize_parameters(self, X: np.ndarray):
        """Initialize means at randomly chosen data points (simpler than k-means++)."""
        rng = np.random.default_rng(self.random_state)
        n_samples, n_features = X.shape

        # Pick distinct data points as initial means; cast to float so that
        # later in-place updates are not truncated for integer input
        indices = rng.choice(n_samples, self.n_components, replace=False)
        self.means_ = X[indices].astype(float)

        # Initialize covariances as identity matrices
        self.covariances_ = np.array([
            np.eye(n_features) for _ in range(self.n_components)
        ])

        # Initialize equal mixing proportions
        self.weights_ = np.ones(self.n_components) / self.n_components

    def _weighted_densities(self, X: np.ndarray) -> np.ndarray:
        """Return the (n_samples, n_components) array of π_k N(x_i|μ_k,Σ_k)."""
        weighted_probs = np.zeros((X.shape[0], self.n_components))

        for k in range(self.n_components):
            try:
                rv = multivariate_normal(
                    mean=self.means_[k],
                    cov=self.covariances_[k],
                    allow_singular=True
                )
                weighted_probs[:, k] = self.weights_[k] * rv.pdf(X)
            except (ValueError, np.linalg.LinAlgError):
                # Invalid covariance: fall back to a tiny constant density
                weighted_probs[:, k] = 1e-10

        return weighted_probs

    def _e_step(self, X: np.ndarray) -> np.ndarray:
        """
        E-Step: Compute responsibilities (posterior probabilities).

        γ_ik = P(z_i = k | x_i, θ) = π_k N(x_i|μ_k,Σ_k) / Σ_j π_j N(x_i|μ_j,Σ_j)

        Parameters
        ----------
        X : array of shape (n_samples, n_features)

        Returns
        -------
        responsibilities : array of shape (n_samples, n_components)
            γ_ik = probability that point i belongs to component k
        """
        weighted_probs = self._weighted_densities(X)

        # Normalize to get responsibilities (Bayes' rule)
        total = weighted_probs.sum(axis=1, keepdims=True)
        total = np.maximum(total, 1e-10)  # Prevent division by zero
        return weighted_probs / total

    def _m_step(self, X: np.ndarray, responsibilities: np.ndarray):
        """
        M-Step: Update parameters using weighted MLE.

        Parameters
        ----------
        X : array of shape (n_samples, n_features)
        responsibilities : array of shape (n_samples, n_components)
        """
        n_samples, n_features = X.shape

        for k in range(self.n_components):
            # Effective number of points assigned to component k
            N_k = responsibilities[:, k].sum()
            N_k = max(N_k, 1e-10)  # Prevent division by zero

            # Update mixing proportion: π_k = N_k / n
            self.weights_[k] = N_k / n_samples

            # Update mean: μ_k = (1/N_k) Σ γ_ik x_i
            self.means_[k] = responsibilities[:, k] @ X / N_k

            # Update covariance: Σ_k = (1/N_k) Σ γ_ik (x_i - μ_k)(x_i - μ_k)^T
            diff = X - self.means_[k]
            weighted_diff = responsibilities[:, k:k+1] * diff
            self.covariances_[k] = (weighted_diff.T @ diff) / N_k

            # Add small regularization for numerical stability
            self.covariances_[k] += 1e-6 * np.eye(n_features)

    def _compute_log_likelihood(self, X: np.ndarray) -> float:
        """Compute observed-data log-likelihood: Σ_i log Σ_k π_k N(x_i|μ_k,Σ_k)."""
        mixture_density = self._weighted_densities(X).sum(axis=1)
        # Floor the density so that log never sees zero
        return float(np.sum(np.log(np.maximum(mixture_density, 1e-10))))

    def fit(self, X: np.ndarray) -> 'GaussianMixtureEM':
        """
        Fit GMM using EM algorithm.

        Parameters
        ----------
        X : array of shape (n_samples, n_features)

        Returns
        -------
        self
        """
        # Initialize parameters and reset results from any earlier fit
        self._initialize_parameters(X)
        self.converged_ = False

        # Initial log-likelihood
        prev_ll = self._compute_log_likelihood(X)
        self.log_likelihoods_ = [prev_ll]

        # EM iterations
        for iteration in range(self.max_iter):
            # E-Step: compute responsibilities
            responsibilities = self._e_step(X)

            # M-Step: update parameters
            self._m_step(X, responsibilities)

            # Compute new log-likelihood
            curr_ll = self._compute_log_likelihood(X)
            self.log_likelihoods_.append(curr_ll)

            # Check convergence
            ll_change = curr_ll - prev_ll

            if ll_change < 0:
                warnings.warn(f"Log-likelihood decreased at iteration {iteration}")

            if abs(ll_change) < self.tol:
                self.converged_ = True
                self.n_iter_ = iteration + 1
                break

            prev_ll = curr_ll

        if not self.converged_:
            self.n_iter_ = self.max_iter

        return self

    def predict(self, X: np.ndarray) -> np.ndarray:
        """Predict cluster labels (hard assignment)."""
        responsibilities = self._e_step(X)
        return np.argmax(responsibilities, axis=1)

    def predict_proba(self, X: np.ndarray) -> np.ndarray:
        """Predict cluster probabilities (soft assignment)."""
        return self._e_step(X)


# ============================================
# Demonstration
# ============================================

def generate_gmm_data(
    n_samples: int = 500,
    n_components: int = 3,
    random_state: int = 42
) -> Tuple[np.ndarray, np.ndarray]:
    """Generate synthetic GMM data (true parameters are defined for 3 components)."""
    rng = np.random.default_rng(random_state)

    # True parameters
    true_means = np.array([
        [0, 0],
        [4, 4],
        [0, 5]
    ])
    true_covs = [
        np.array([[1, 0.5], [0.5, 1]]),
        np.array([[1.5, 0], [0, 0.5]]),
        np.array([[0.5, 0], [0, 1.5]])
    ]
    true_weights = [0.3, 0.4, 0.3]

    # Generate data
    X = []
    y = []

    for _ in range(n_samples):
        # Sample component
        k = rng.choice(n_components, p=true_weights)
        # Sample from component
        x = rng.multivariate_normal(true_means[k], true_covs[k])
        X.append(x)
        y.append(k)

    return np.array(X), np.array(y)


if __name__ == "__main__":
    print("=" * 60)
    print("EM ALGORITHM FOR GAUSSIAN MIXTURE MODELS")
    print("=" * 60)

    # Generate data
    X, true_labels = generate_gmm_data(n_samples=500)
    print(f"\nGenerated {len(X)} samples from 3-component GMM")

    # Fit GMM with EM
    gmm = GaussianMixtureEM(n_components=3, max_iter=100)
    gmm.fit(X)

    print(f"\nConverged: {gmm.converged_} in {gmm.n_iter_} iterations")
    print(f"Final log-likelihood: {gmm.log_likelihoods_[-1]:.2f}")

    # Show learned parameters
    print("\nLearned mixing proportions:")
    for k, w in enumerate(gmm.weights_):
        print(f"  Component {k}: {w:.3f}")

    print("\nLearned means:")
    for k, mu in enumerate(gmm.means_):
        print(f"  Component {k}: [{mu[0]:.2f}, {mu[1]:.2f}]")

    # Predict labels
    pred_labels = gmm.predict(X)

    # Note: Label matching requires permutation due to label switching
    print("\nPredicted label distribution:")
    for k in range(3):
        print(f"  Cluster {k}: {(pred_labels == k).sum()} points")

    # Show likelihood progression
    print("\nLog-likelihood progression (first 10 iterations):")
    for i, ll in enumerate(gmm.log_likelihoods_[:10]):
        print(f"  Iteration {i}: {ll:.2f}")
```
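
The demonstration notes that predicted cluster indices only match the true labels up to a permutation. One way to resolve this, sketched below, is to build a confusion matrix and solve the assignment problem with `scipy.optimize.linear_sum_assignment`. The helper name `align_labels` and the toy label arrays are assumptions for illustration and are not part of the implementation above.

```python
import numpy as np
from scipy.optimize import linear_sum_assignment

def align_labels(true_labels, pred_labels, n_clusters):
    """Relabel predictions with the permutation that best matches true labels."""
    # confusion[i, j] = number of points with true label i and predicted label j
    confusion = np.zeros((n_clusters, n_clusters), dtype=int)
    np.add.at(confusion, (true_labels, pred_labels), 1)

    # Pick the one-to-one matching that maximizes total agreement
    true_idx, pred_idx = linear_sum_assignment(confusion, maximize=True)
    mapping = {int(p): int(t) for t, p in zip(true_idx, pred_idx)}

    aligned = np.array([mapping[int(p)] for p in pred_labels])
    return aligned, mapping

# Toy example: predictions are a pure relabeling of the truth
true_demo = np.array([0, 0, 1, 1, 2, 2])
pred_demo = np.array([2, 2, 0, 0, 1, 1])

aligned_demo, mapping_demo = align_labels(true_demo, pred_demo, n_clusters=3)
accuracy = (aligned_demo == true_demo).mean()
print(mapping_demo, accuracy)
```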

Common Pitfalls


Summary

Key Takeaways

  1. The EM algorithm is a general method for finding maximum likelihood estimates when data has latent (hidden) variables or missing values.
  2. E-step computes the expected complete-data log-likelihood given current parameters — essentially "filling in" missing data with expected values.
  3. M-step maximizes this expectation over parameters — typically a weighted MLE problem with closed-form solutions.
  4. EM never decreases the observed-data likelihood; under regularity conditions it converges to a stationary point, typically a local (not necessarily global) maximum.
  5. Connection to ELBO: EM is coordinate ascent on the Evidence Lower Bound, making it a special case of variational inference.
  6. Applications: VAEs, topic models, HMMs, and semi-supervised learning all use EM or EM-inspired techniques.
Looking Ahead: In Chapter 13, we'll explore Interval Estimation — moving beyond point estimates to quantify uncertainty through confidence intervals, bootstrap methods, and Bayesian credible intervals.