Learning Objectives
By the end of this section, you will be able to:
- Understand latent variable models and why direct likelihood computation is often intractable
- Derive the Evidence Lower Bound (ELBO) and understand its two components: reconstruction and regularization
- Apply the reparameterization trick to enable gradient-based optimization through sampling operations
- Connect variational inference to diffusion models and understand how the diffusion ELBO decomposes into per-timestep terms
- Implement a simple VAE in PyTorch using the ELBO objective
The Big Picture: Inference as Optimization
Variational inference emerged from a profound insight: when exact probabilistic inference is computationally intractable, we can turn inference into an optimization problem. Instead of computing the true posterior distribution exactly, we search for the best approximation from a tractable family of distributions.
Historical Context: Variational methods have roots in statistical physics (mean-field theory). Their application to machine learning was pioneered by researchers such as Geoffrey Hinton and Michael Jordan in the 1990s. The modern era of variational inference began with Kingma and Welling's VAE paper (2013) and Rezende et al.'s concurrent work on stochastic backpropagation in deep generative models (2014).
The key insight is that we can convert an intractable integration problem into a tractable optimization problem. This is exactly what we need for diffusion models, where we want to learn a generative model without being able to compute exact likelihoods.
| Exact Inference | Variational Inference |
|---|---|
| Compute true posterior p(z|x) | Find approximate q(z) close to p(z|x) |
| Often intractable integration | Tractable optimization |
| Exact but expensive | Approximate but scalable |
| Limited to conjugate models | Works with neural networks |
Latent Variable Models
Latent variable models assume that observed data $x$ is generated through unobserved (latent) variables $z$. The generative process is:

$$z \sim p(z), \qquad x \sim p_\theta(x \mid z)$$

The marginal likelihood of the data requires integrating over all possible latent configurations:

$$p_\theta(x) = \int p_\theta(x \mid z)\, p(z)\, dz$$
Why Latent Variables?
Latent variables serve multiple purposes in generative modeling:
- Expressiveness: Simple distributions in latent space can map to complex distributions in data space
- Disentanglement: Latent dimensions can capture independent factors of variation (e.g., pose, lighting, identity)
- Compression: High-dimensional data can be encoded into lower-dimensional representations
- Generation: New samples can be created by sampling latent codes and decoding them
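The expressiveness and generation points can be made concrete with a toy sketch (the 1-D latent and the `decode` function here are hypothetical, chosen only for illustration): a standard normal latent, pushed through a nonlinear decoder, yields a curved, non-Gaussian data distribution.

```python
import numpy as np

rng = np.random.default_rng(0)

def decode(z):
    # Hypothetical nonlinear "decoder": maps a 1-D latent to 2-D data,
    # bending the simple Gaussian latent into a curved manifold.
    return np.stack([np.tanh(2 * z), z ** 2], axis=-1)

# Ancestral sampling: draw latent codes, decode, add observation noise
z = rng.standard_normal(10_000)
x = decode(z) + 0.05 * rng.standard_normal((10_000, 2))

print(x.shape)  # (10000, 2)
```

Even though $z$ is a plain standard normal, the samples $x$ lie near a curved one-dimensional manifold in 2-D, which no single Gaussian in data space could capture.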
Real-World Example (Vision): Consider face images. The latent space might capture age, expression, pose, and lighting as separate dimensions. A VAE learns this without supervision!
The Intractability Problem
For learning and inference, we need the posterior distribution:

$$p_\theta(z \mid x) = \frac{p_\theta(x \mid z)\, p(z)}{p_\theta(x)}$$

The denominator $p_\theta(x)$ (the evidence or marginal likelihood) requires integrating over all possible latent configurations. For continuous latent spaces and nonlinear relationships (like those modeled by neural networks), this integral is almost always intractable.
Why Is This Hard?
Consider a VAE with a 128-dimensional latent space. Computing $p_\theta(x)$ exactly would require integrating over all points in $\mathbb{R}^{128}$. Even with Monte Carlo methods, the variance of such estimates would be prohibitive.
| Approach | Problem |
|---|---|
| Exact integration | No closed form for nonlinear models |
| Grid-based methods | Exponential in dimension |
| Naive Monte Carlo | Extremely high variance |
| Importance sampling | Requires good proposal distribution |
This is where variational inference comes to the rescue: instead of computing the true posterior, we approximate it with a simpler distribution that we can work with.
Variational Inference: Core Idea
The core idea of variational inference is to introduce an approximate posterior $q_\phi(z \mid x)$ from a tractable family of distributions (parameterized by $\phi$) and optimize it to be as close as possible to the true posterior $p_\theta(z \mid x)$.
We measure "closeness" using KL divergence (which we covered in the information theory section):

$$\mathrm{KL}\big(q_\phi(z \mid x) \,\|\, p_\theta(z \mid x)\big) = \mathbb{E}_{q_\phi(z \mid x)}\left[\log \frac{q_\phi(z \mid x)}{p_\theta(z \mid x)}\right]$$
Recall from the previous section that KL divergence is asymmetric. We use $\mathrm{KL}(q \,\|\, p)$ (reverse KL) rather than $\mathrm{KL}(p \,\|\, q)$ (forward KL) because it leads to a tractable objective.
Key Insight: The reverse KL tends to produce mode-seeking behavior: $q$ will concentrate on regions where $p$ has high probability, rather than trying to cover all of $p$'s support. This is important for understanding VAE behavior.
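Mode-seeking can be illustrated numerically (the bimodal target and the two candidate approximations below are made up for this example): on a 1-D grid we compare the reverse KL of a $q$ that sits on one mode of $p$ against a broad $q$ that tries to cover both modes.

```python
import numpy as np

z = np.linspace(-10, 10, 4001)
dz = z[1] - z[0]

def normal_pdf(x, mu, sigma):
    return np.exp(-0.5 * ((x - mu) / sigma) ** 2) / (sigma * np.sqrt(2 * np.pi))

# p: equal mixture of two well-separated Gaussians at -4 and +4
p = 0.5 * normal_pdf(z, -4.0, 1.0) + 0.5 * normal_pdf(z, 4.0, 1.0)

def reverse_kl(mu, sigma):
    """Grid approximation of KL(q || p) for q = N(mu, sigma^2)."""
    q = normal_pdf(z, mu, sigma)
    mask = q > 1e-300                      # skip regions where q underflows
    return np.sum(q[mask] * (np.log(q[mask]) - np.log(p[mask])) * dz)

kl_mode = reverse_kl(4.0, 1.0)    # q locks onto one mode: KL is about log 2
kl_cover = reverse_kl(0.0, 4.2)   # q spreads over both modes and the valley

print(kl_mode < kl_cover)  # True: under reverse KL, mode-seeking q wins
```

The broad $q$ pays a large penalty for placing mass in the low-probability valley between the modes, so reverse KL prefers the solution that commits to a single mode.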
Deriving the ELBO
We want to minimize $\mathrm{KL}\big(q_\phi(z \mid x) \,\|\, p_\theta(z \mid x)\big)$, but this requires knowing $p_\theta(z \mid x)$, which in turn requires the intractable $p_\theta(x)$. Here's the elegant solution:
Start with the definition of KL divergence:

$$\mathrm{KL}\big(q_\phi(z \mid x) \,\|\, p_\theta(z \mid x)\big) = \mathbb{E}_{q_\phi(z \mid x)}\big[\log q_\phi(z \mid x) - \log p_\theta(z \mid x)\big]$$

Use Bayes' rule to expand $\log p_\theta(z \mid x)$:

$$\log p_\theta(z \mid x) = \log p_\theta(x \mid z) + \log p(z) - \log p_\theta(x)$$

Since $\log p_\theta(x)$ doesn't depend on $z$, it comes out of the expectation:

$$\mathrm{KL}\big(q_\phi(z \mid x) \,\|\, p_\theta(z \mid x)\big) = \mathbb{E}_{q_\phi(z \mid x)}\big[\log q_\phi(z \mid x) - \log p_\theta(x \mid z) - \log p(z)\big] + \log p_\theta(x)$$

Rearranging:

$$\log p_\theta(x) = \mathrm{KL}\big(q_\phi(z \mid x) \,\|\, p_\theta(z \mid x)\big) + \mathcal{L}(\theta, \phi; x)$$

The second term is the Evidence Lower Bound (ELBO):

$$\mathcal{L}(\theta, \phi; x) = \mathbb{E}_{q_\phi(z \mid x)}\big[\log p_\theta(x \mid z)\big] - \mathrm{KL}\big(q_\phi(z \mid x) \,\|\, p(z)\big)$$

Since KL divergence is always non-negative, we have the fundamental inequality:

$$\log p_\theta(x) \geq \mathcal{L}(\theta, \phi; x)$$
This is why it's called a "lower bound" on the evidence! By maximizing the ELBO, we simultaneously:
- Maximize the log-likelihood (as much as the bound allows)
- Minimize $\mathrm{KL}\big(q_\phi(z \mid x) \,\|\, p_\theta(z \mid x)\big)$, making our approximation to the true posterior better
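The bound can be checked numerically in a toy conjugate model where the evidence and the true posterior are known in closed form (the model and the observed value below are illustrative):

```python
import numpy as np

# Conjugate toy model:
#   p(z) = N(0, 1),  p(x|z) = N(z, 1)  =>  evidence p(x) = N(0, 2),
#   true posterior p(z|x) = N(x/2, 1/2).
x = 1.3  # an arbitrary observation

def elbo(m, s2):
    """ELBO for the Gaussian variational family q = N(m, s2)."""
    # E_q[log p(x|z)] for a unit-variance Gaussian likelihood
    recon = -0.5 * np.log(2 * np.pi) - 0.5 * ((x - m) ** 2 + s2)
    # KL( N(m, s2) || N(0, 1) ) in closed form
    kl = 0.5 * (m ** 2 + s2 - np.log(s2) - 1)
    return recon - kl

log_evidence = -0.5 * np.log(4 * np.pi) - x ** 2 / 4  # log N(x; 0, 2)

print(elbo(0.0, 1.0) <= log_evidence)              # True: any q gives a lower bound
print(np.isclose(elbo(x / 2, 0.5), log_evidence))  # True: optimal q makes it tight
```

The gap between `log_evidence` and `elbo(m, s2)` is exactly $\mathrm{KL}(q \,\|\, p(z \mid x))$, so the bound is tight precisely when $q$ equals the true posterior $\mathcal{N}(x/2, 1/2)$.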
ELBO Decomposition
The ELBO has a beautiful interpretation with two competing terms:

$$\mathcal{L}(\theta, \phi; x) = \underbrace{\mathbb{E}_{q_\phi(z \mid x)}\big[\log p_\theta(x \mid z)\big]}_{\text{reconstruction}} - \underbrace{\mathrm{KL}\big(q_\phi(z \mid x) \,\|\, p(z)\big)}_{\text{regularization}}$$
Reconstruction Term
$\mathbb{E}_{q_\phi(z \mid x)}\big[\log p_\theta(x \mid z)\big]$ measures how well we can reconstruct the input $x$ from the latent code $z$. This is the "quality" term that encourages the model to encode useful information in $z$.
- For continuous data with a fixed-variance Gaussian likelihood: equivalent (up to constants) to negative MSE
- For binary data with Bernoulli likelihood: equivalent to binary cross-entropy
- Encourages the encoder to preserve information
Regularization (KL) Term
$\mathrm{KL}\big(q_\phi(z \mid x) \,\|\, p(z)\big)$ measures how much the approximate posterior deviates from the prior. This term:
- Prevents the encoder from memorizing each data point
- Encourages smooth, regular latent spaces
- Enables generation by sampling from the prior
The Trade-off: If we only maximize reconstruction, the model becomes an autoencoder without generative capability. If we only minimize KL, we get random noise. The ELBO balances these naturally!
Mean-Field Approximation
The most common choice for the variational family is the mean-field approximation, where we assume the latent dimensions are independent:

$$q_\phi(z \mid x) = \prod_{i=1}^{d} q_\phi(z_i \mid x)$$

For VAEs, we typically use a multivariate Gaussian with diagonal covariance:

$$q_\phi(z \mid x) = \mathcal{N}\big(z;\, \mu_\phi(x),\, \mathrm{diag}(\sigma_\phi^2(x))\big)$$

The encoder neural network outputs the mean $\mu_\phi(x)$ and variance $\sigma_\phi^2(x)$ for each data point. In practice, we output $\log \sigma_\phi^2(x)$ for numerical stability.
Advantages of Mean-Field
- Closed-form KL: For Gaussian distributions with Gaussian prior, the KL term has an analytical solution
- Simple parameterization: Only need to output mean and variance vectors
- Efficient sampling: Can sample from each dimension independently
The Closed-Form KL for Gaussians
When $q_\phi(z \mid x) = \mathcal{N}\big(\mu, \mathrm{diag}(\sigma^2)\big)$ and $p(z) = \mathcal{N}(0, I)$ (standard normal prior), the KL divergence has a simple form:

$$\mathrm{KL}\big(q_\phi(z \mid x) \,\|\, p(z)\big) = \frac{1}{2} \sum_{i=1}^{d} \big(\mu_i^2 + \sigma_i^2 - \log \sigma_i^2 - 1\big)$$
This is the famous KL regularization term used in VAEs!
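A quick sanity check of this formula against `torch.distributions` (the random `mu`/`logvar` tensors are placeholders for encoder outputs):

```python
import torch
from torch.distributions import Normal, kl_divergence

torch.manual_seed(0)
mu = torch.randn(4, 8)        # batch of 4 examples, 8 latent dims
logvar = torch.randn(4, 8)    # log sigma^2, as output by a typical encoder

# Closed-form KL( N(mu, diag(sigma^2)) || N(0, I) ), summed over latent dims
kl_closed = 0.5 * (mu.pow(2) + logvar.exp() - logvar - 1).sum(dim=1)

# Cross-check: factorized Gaussians, so per-dimension KLs simply add up
q = Normal(mu, (0.5 * logvar).exp())   # scale = sigma = exp(logvar / 2)
p = Normal(torch.zeros_like(mu), torch.ones_like(mu))
kl_ref = kl_divergence(q, p).sum(dim=1)

print(torch.allclose(kl_closed, kl_ref, atol=1e-5))  # True
```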
The Reparameterization Trick
There's one remaining challenge: the ELBO requires taking expectations over $q_\phi(z \mid x)$, which depends on the parameters $\phi$ we want to optimize. We need to estimate gradients through this expectation.
The problem: If we sample $z \sim q_\phi(z \mid x)$ directly, this sampling operation is not differentiable! We can't backpropagate $\nabla_\phi$ through a random draw.
The solution: The reparameterization trick rewrites the sampling operation as a deterministic transformation of a parameter-free random variable:

$$z = \mu_\phi(x) + \sigma_\phi(x) \odot \epsilon, \qquad \epsilon \sim \mathcal{N}(0, I)$$

Now the randomness is in $\epsilon$, which doesn't depend on $\phi$. The gradients flow through $\mu_\phi$ and $\sigma_\phi$!
Why This Works
Consider the gradient of the reconstruction term:

$$\nabla_\phi\, \mathbb{E}_{q_\phi(z \mid x)}\big[\log p_\theta(x \mid z)\big] = \mathbb{E}_{\epsilon \sim \mathcal{N}(0, I)}\Big[\nabla_\phi \log p_\theta\big(x \mid \mu_\phi(x) + \sigma_\phi(x) \odot \epsilon\big)\Big]$$

The expectation is now over $\epsilon$, which doesn't depend on $\phi$. We can:
- Sample $\epsilon^{(1)}, \ldots, \epsilon^{(L)}$ from $\mathcal{N}(0, I)$
- Compute $z^{(l)} = \mu_\phi(x) + \sigma_\phi(x) \odot \epsilon^{(l)}$ for each sample
- Estimate the gradient with Monte Carlo:

$$\nabla_\phi\, \mathbb{E}_{q_\phi(z \mid x)}\big[\log p_\theta(x \mid z)\big] \approx \frac{1}{L} \sum_{l=1}^{L} \nabla_\phi \log p_\theta\big(x \mid z^{(l)}\big)$$
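The estimator can be verified on a toy objective with a known gradient (the scalar setup is illustrative): we estimate $\nabla\, \mathbb{E}_{z \sim \mathcal{N}(\mu, \sigma^2)}[z^2]$, whose exact value with respect to $\mu$ is $2\mu$, since $\mathbb{E}[z^2] = \mu^2 + \sigma^2$.

```python
import torch

torch.manual_seed(0)
mu = torch.tensor(1.5, requires_grad=True)
log_sigma = torch.tensor(0.0, requires_grad=True)  # sigma = 1

L = 100_000
eps = torch.randn(L)                   # parameter-free noise: eps ~ N(0, 1)
z = mu + log_sigma.exp() * eps         # reparameterized samples: z ~ N(mu, sigma^2)
loss = (z ** 2).mean()                 # MC estimate of E[z^2] = mu^2 + sigma^2
loss.backward()                        # gradients flow through mu and log_sigma

# Exact gradients of the expectation: d/dmu = 2*mu = 3.0,
# d/dlog_sigma = 2*sigma^2 = 2.0; the MC estimates land close to these.
print(round(mu.grad.item(), 2), round(log_sigma.grad.item(), 2))
```

Had we sampled `z` with a non-differentiable call instead, `mu.grad` would simply not exist; the reparameterized form is what lets autograd reach the variational parameters.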
Key Insight for Diffusion: The reparameterization trick is exactly how we train diffusion models! When we sample $x_t = \sqrt{\bar{\alpha}_t}\, x_0 + \sqrt{1 - \bar{\alpha}_t}\, \epsilon$ during training, we're using the same Gaussian reparameterization to enable gradient flow.
Connection to Diffusion Models
Diffusion models are a special case of hierarchical variational autoencoders! The key insight is that diffusion models have a fixed, predetermined forward process (the encoder) and only learn the reverse process (the decoder).
The Diffusion ELBO
In diffusion models, the (negative) ELBO decomposes into a sum over timesteps:

$$\mathbb{E}_q\Big[\underbrace{-\log p_\theta(x_0 \mid x_1)}_{L_0} + \sum_{t=2}^{T} \underbrace{\mathrm{KL}\big(q(x_{t-1} \mid x_t, x_0) \,\|\, p_\theta(x_{t-1} \mid x_t)\big)}_{L_{t-1}} + \underbrace{\mathrm{KL}\big(q(x_T \mid x_0) \,\|\, p(x_T)\big)}_{L_T}\Big]$$
| Term | Meaning | Diffusion Interpretation |
|---|---|---|
| $-\log p_\theta(x_0 \mid x_1)$ | Reconstruction from the first latent | Final denoising step quality |
| $\sum_t \mathrm{KL}\big(q(x_{t-1} \mid x_t, x_0) \,\|\, p_\theta(x_{t-1} \mid x_t)\big)$ | Match reverse to true posterior | Each denoising step matches the optimal transition |
| $\mathrm{KL}\big(q(x_T \mid x_0) \,\|\, p(x_T)\big)$ | Match final latent to prior | Fully noised image should look like pure noise |
The Simplified Loss
Ho et al. (2020) showed that this complex ELBO simplifies to a remarkably simple objective when the reverse process is parameterized to predict noise:

$$L_{\text{simple}} = \mathbb{E}_{t,\, x_0,\, \epsilon}\Big[\big\|\epsilon - \epsilon_\theta(x_t, t)\big\|^2\Big], \qquad x_t = \sqrt{\bar{\alpha}_t}\, x_0 + \sqrt{1 - \bar{\alpha}_t}\, \epsilon$$
This is just predicting the noise that was added! The connection to variational inference explains why this works: we're still maximizing a lower bound on the log-likelihood, just in a computationally convenient form.
Deep Connection: The diffusion model training objective is derived from variational inference principles. Each denoising step learns to approximate the true reverse conditional, which has a closed form because the forward process is Gaussian!
Implementation in PyTorch
Let's implement a simple VAE to solidify our understanding of the ELBO and reparameterization trick:
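A minimal sketch, assuming flattened binary images (e.g. 28x28 MNIST) and a Bernoulli decoder; the layer sizes and the random stand-in batch are illustrative, not tuned:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class VAE(nn.Module):
    """Minimal VAE for flattened binary images."""

    def __init__(self, x_dim=784, h_dim=256, z_dim=16):
        super().__init__()
        self.enc = nn.Sequential(nn.Linear(x_dim, h_dim), nn.ReLU())
        self.enc_mu = nn.Linear(h_dim, z_dim)
        self.enc_logvar = nn.Linear(h_dim, z_dim)  # log sigma^2 for stability
        self.dec = nn.Sequential(
            nn.Linear(z_dim, h_dim), nn.ReLU(), nn.Linear(h_dim, x_dim)
        )

    def reparameterize(self, mu, logvar):
        eps = torch.randn_like(mu)            # parameter-free noise
        return mu + (0.5 * logvar).exp() * eps

    def forward(self, x):
        h = self.enc(x)
        mu, logvar = self.enc_mu(h), self.enc_logvar(h)
        z = self.reparameterize(mu, logvar)
        return self.dec(z), mu, logvar        # decoder outputs Bernoulli logits

def elbo_loss(x, logits, mu, logvar):
    """Negative ELBO: reconstruction (BCE) + closed-form Gaussian KL."""
    recon = F.binary_cross_entropy_with_logits(logits, x, reduction="sum")
    kl = 0.5 * (mu.pow(2) + logvar.exp() - logvar - 1).sum()
    return (recon + kl) / x.shape[0]          # per-example negative ELBO

# One optimization step on random stand-in data, just to show the mechanics
model = VAE()
opt = torch.optim.Adam(model.parameters(), lr=1e-3)
x = torch.rand(32, 784).round()               # stand-in for binarized images
logits, mu, logvar = model(x)
loss = elbo_loss(x, logits, mu, logvar)
opt.zero_grad()
loss.backward()
opt.step()
print(loss.item() > 0)  # True: BCE and KL are both non-negative
```

Note that minimizing `elbo_loss` maximizes the ELBO: the BCE term is the negative Bernoulli log-likelihood, and the KL term is exactly the closed-form expression derived above.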
Now let's see how the same principles apply to diffusion model training:
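A minimal sketch of a DDPM-style training step, with a small MLP standing in for the usual U-Net and random vectors standing in for images; only the noise schedule, the closed-form forward process, and $L_{\text{simple}}$ follow the text, everything else is illustrative:

```python
import torch
import torch.nn as nn

T = 1000
betas = torch.linspace(1e-4, 0.02, T)            # DDPM-style linear schedule
alpha_bar = torch.cumprod(1.0 - betas, dim=0)    # \bar{alpha}_t

class NoisePredictor(nn.Module):
    """Stand-in for epsilon_theta(x_t, t); a real model would be a U-Net."""

    def __init__(self, dim=64):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim + 1, 128), nn.ReLU(), nn.Linear(128, dim)
        )

    def forward(self, x_t, t):
        t_feat = t.float().unsqueeze(-1) / T      # crude timestep embedding
        return self.net(torch.cat([x_t, t_feat], dim=-1))

model = NoisePredictor()
opt = torch.optim.Adam(model.parameters(), lr=1e-3)

x0 = torch.randn(32, 64)                          # stand-in for training data
t = torch.randint(0, T, (32,))                    # random timestep per example
eps = torch.randn_like(x0)                        # the noise we must predict

# Forward process in closed form (the Gaussian reparameterization again):
#   x_t = sqrt(abar_t) * x_0 + sqrt(1 - abar_t) * eps
ab = alpha_bar[t].unsqueeze(-1)
x_t = ab.sqrt() * x0 + (1 - ab).sqrt() * eps

loss = ((eps - model(x_t, t)) ** 2).mean()        # L_simple
opt.zero_grad()
loss.backward()
opt.step()
```

The structural similarity to the VAE step is the point: sample noise, form a reparameterized latent, and backpropagate a simple per-sample loss.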
Connection to ELBO
The simplified noise-prediction loss is not a departure from variational inference. Each KL term in the diffusion ELBO compares two Gaussians, so it reduces to a weighted squared error between the true noise $\epsilon$ and the prediction $\epsilon_\theta(x_t, t)$; dropping the per-timestep weights gives $L_{\text{simple}}$, which Ho et al. found works better in practice.
Summary
Variational inference provides the theoretical foundation for training generative models when exact likelihood computation is intractable. Here are the key takeaways:
- The Intractability Problem: Computing $p_\theta(x) = \int p_\theta(x \mid z)\, p(z)\, dz$ is intractable for nonlinear latent variable models
- The ELBO: We can maximize a lower bound on log-likelihood that decomposes into reconstruction and regularization terms
- Reparameterization: By writing $z = \mu_\phi(x) + \sigma_\phi(x) \odot \epsilon$ with $\epsilon \sim \mathcal{N}(0, I)$, we enable gradient-based optimization through sampling
- Mean-Field: Assuming independent latent dimensions gives tractable optimization with closed-form KL terms
- Diffusion Connection: Diffusion models are hierarchical VAEs with fixed encoders and learned decoders, trained using the same variational principles
Looking Ahead: In the next section, we'll explore Markov chains and stochastic processes - the mathematical framework that describes how diffusion models progressively add and remove noise. Understanding these processes will complete our prerequisite toolkit for diving into diffusion models proper.