
Variational Inference Primer


Learning Objectives

By the end of this section, you will be able to:

  1. Understand latent variable models and why direct likelihood computation is often intractable
  2. Derive the Evidence Lower Bound (ELBO) and understand its two components: reconstruction and regularization
  3. Apply the reparameterization trick to enable gradient-based optimization through sampling operations
  4. Connect variational inference to diffusion models and understand how the diffusion ELBO decomposes into per-timestep terms
  5. Implement a simple VAE in PyTorch using the ELBO objective

The Big Picture: Inference as Optimization

Variational inference emerged from a profound insight: when exact probabilistic inference is computationally intractable, we can turn inference into an optimization problem. Instead of computing the true posterior distribution exactly, we search for the best approximation from a tractable family of distributions.

Historical Context: Variational methods have roots in physics (mean-field theory) and statistical mechanics. The application to machine learning was pioneered by researchers like Geoffrey Hinton and Michael Jordan in the 1990s. The modern era of variational inference began with Kingma and Welling's VAE paper (2013) and Rezende et al.'s work on stochastic variational inference.

The key insight is that we can convert an intractable integration problem into a tractable optimization problem. This is exactly what we need for diffusion models, where we want to learn a generative model without being able to compute exact likelihoods.

| Exact Inference | Variational Inference |
|---|---|
| Compute true posterior p(z\|x) | Find approximate q(z) close to p(z\|x) |
| Often intractable integration | Tractable optimization |
| Exact but expensive | Approximate but scalable |
| Limited to conjugate models | Works with neural networks |

Latent Variable Models

Latent variable models assume that observed data x is generated through unobserved (latent) variables z. The generative process is:

z \sim p(z) \quad \text{(sample latent variable from prior)}
x \sim p(x|z) \quad \text{(generate observation from latent)}

The marginal likelihood of the data requires integrating over all possible latent configurations:

p(x) = \int p(x|z) p(z) \, dz
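To make the two-step process concrete, here is a tiny ancestral-sampling sketch (a toy one-dimensional model chosen for illustration, not one from the text): a standard Gaussian latent pushed through a nonlinear "decoder" already yields a clearly non-Gaussian, bimodal marginal over x.

```python
import numpy as np

rng = np.random.default_rng(0)

def sample_latent_variable_model(n):
    z = rng.standard_normal(n)                          # z ~ p(z) = N(0, 1)
    x = np.tanh(3 * z) + 0.1 * rng.standard_normal(n)   # x ~ p(x|z): nonlinear mean, Gaussian noise
    return x

x = sample_latent_variable_model(10_000)
# Most of the mass piles up near x = -1 and x = +1: a bimodal marginal from a unimodal prior
print(round((np.abs(x) > 0.5).mean(), 2))
```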

Why Latent Variables?

Latent variables serve multiple purposes in generative modeling:

  • Expressiveness: Simple distributions in latent space can map to complex distributions in data space
  • Disentanglement: Latent dimensions can capture independent factors of variation (e.g., pose, lighting, identity)
  • Compression: High-dimensional data can be encoded into lower-dimensional representations
  • Generation: New samples can be created by sampling latent codes and decoding them
Real-World Example (Vision): Consider face images. The latent space might capture age, expression, pose, and lighting as separate dimensions. A VAE learns this without supervision!

The Intractability Problem

For learning and inference, we need the posterior distribution:

p(z|x) = \frac{p(x|z) p(z)}{p(x)} = \frac{p(x|z) p(z)}{\int p(x|z') p(z') \, dz'}

The denominator p(x) (the evidence or marginal likelihood) requires integrating over all possible latent configurations. For continuous latent spaces and nonlinear relationships (like those modeled by neural networks), this integral is almost always intractable.

Why Is This Hard?

Consider a VAE with a 128-dimensional latent space. Computing p(x) exactly would require integrating over all points in \mathbb{R}^{128}. Even with Monte Carlo methods, the variance of such estimates would be prohibitive.

| Approach | Problem |
|---|---|
| Exact integration | Impossible for nonlinear models |
| Grid-based methods | Exponential in dimension |
| Naive Monte Carlo | Extremely high variance |
| Importance sampling | Requires a good proposal distribution |

This is where variational inference comes to the rescue: instead of computing the true posterior, we approximate it with a simpler distribution that we can work with.
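The Monte Carlo row of the table can be demonstrated in a few lines. The following toy simulation (an illustrative setup of our own choosing: standard normal prior, sharp Gaussian likelihood with sigma = 0.1) estimates log p(x) by averaging p(x|z) over prior samples; the estimates are stable in one dimension but scatter wildly by d = 32, because almost no prior samples land where the likelihood is large.

```python
import numpy as np

rng = np.random.default_rng(0)

def log_likelihood(x, z, sigma=0.1):
    # log N(x; z, sigma^2 I), summed over dimensions
    d = x.shape[-1]
    return (-0.5 * np.sum((x - z) ** 2, axis=-1) / sigma**2
            - d * np.log(sigma * np.sqrt(2 * np.pi)))

def naive_mc_log_px(x, n_samples=10_000):
    # p(x) = E_{z ~ p(z)}[p(x|z)], estimated with samples from the prior
    d = x.shape[-1]
    z = rng.standard_normal((n_samples, d))
    log_w = log_likelihood(x, z)
    m = log_w.max()  # log-mean-exp for numerical stability
    return m + np.log(np.mean(np.exp(log_w - m)))

for d in [1, 2, 8, 32]:
    x = np.zeros(d)  # observation at the prior mode (the easiest case)
    estimates = [naive_mc_log_px(x) for _ in range(5)]
    print(d, np.round(estimates, 2))  # spread of estimates grows sharply with d
```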


Variational Inference: Core Idea

The core idea of variational inference is to introduce an approximate posterior q_\phi(z|x) from a tractable family of distributions (parameterized by \phi) and optimize it to be as close as possible to the true posterior.

q_\phi(z|x) \approx p(z|x)

We measure "closeness" using KL divergence (which we covered in the information theory section):

D_{\text{KL}}(q_\phi(z|x) \| p(z|x)) = \mathbb{E}_{q_\phi}\left[\log \frac{q_\phi(z|x)}{p(z|x)}\right]

Recall from the previous section that KL divergence is asymmetric. We use D_{\text{KL}}(q \| p) (reverse KL) rather than D_{\text{KL}}(p \| q) (forward KL) because it leads to a tractable objective.

Key Insight: The reverse KL tends to produce mode-seeking behavior: q will concentrate on regions where p has high probability, rather than trying to cover all of p's support. This is important for understanding VAE behavior.

Deriving the ELBO

We want to minimize D_{\text{KL}}(q_\phi(z|x) \| p(z|x)), but this requires knowing p(z|x), which in turn requires the intractable p(x). Here's the elegant solution:

Start with the definition of KL divergence:

D_{\text{KL}}(q_\phi \| p) = \mathbb{E}_{q_\phi}\left[\log q_\phi(z|x) - \log p(z|x)\right]

Use Bayes' rule to expand p(z|x):

D_{\text{KL}}(q_\phi \| p) = \mathbb{E}_{q_\phi}\left[\log q_\phi(z|x) - \log p(x|z) - \log p(z) + \log p(x)\right]

Since p(x) doesn't depend on z, it comes out of the expectation:

D_{\text{KL}}(q_\phi \| p) = \log p(x) + \mathbb{E}_{q_\phi}\left[\log q_\phi(z|x) - \log p(x|z) - \log p(z)\right]

Rearranging:

\log p(x) = D_{\text{KL}}(q_\phi \| p) + \mathbb{E}_{q_\phi}\left[\log p(x|z) + \log p(z) - \log q_\phi(z|x)\right]

The second term is the Evidence Lower Bound (ELBO):

\mathcal{L}(\phi, \theta; x) = \mathbb{E}_{q_\phi(z|x)}\left[\log p_\theta(x|z)\right] - D_{\text{KL}}(q_\phi(z|x) \| p(z))

Since KL divergence is always non-negative, we have the fundamental inequality:

\log p(x) \geq \mathcal{L}(\phi, \theta; x)

This is why it's called a "lower bound" on the evidence! By maximizing the ELBO, we simultaneously:

  1. Maximize the log-likelihood \log p(x) (as much as the bound allows)
  2. Minimize D_{\text{KL}}(q_\phi \| p), making our approximation better
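The bound is easy to verify numerically in a model where everything is tractable. Take the conjugate toy model p(z) = N(0, 1), p(x|z) = N(z, 1) (our choice for illustration), so the evidence is p(x) = N(0, 2) and the true posterior is N(x/2, 1/2). The ELBO for any Gaussian q stays below log p(x) and touches it exactly at the true posterior:

```python
import numpy as np

def elbo(x, m, s2):
    """ELBO for p(z)=N(0,1), p(x|z)=N(z,1), with q(z) = N(m, s2)."""
    recon = -0.5 * np.log(2 * np.pi) - 0.5 * ((x - m) ** 2 + s2)  # E_q[log p(x|z)]
    kl = 0.5 * (m**2 + s2 - 1 - np.log(s2))                       # KL(q || p(z))
    return recon - kl

def log_px(x):
    """Exact evidence: p(x) = N(0, 2)."""
    return -0.5 * np.log(4 * np.pi) - x**2 / 4

x = 1.3
# The bound log p(x) >= ELBO holds for any q ...
for m, s2 in [(0.0, 1.0), (2.0, 0.3), (-1.0, 5.0)]:
    assert elbo(x, m, s2) <= log_px(x)
# ... and is tight exactly at the true posterior q(z) = N(x/2, 1/2)
print(np.isclose(elbo(x, x / 2, 0.5), log_px(x)))  # True
```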

ELBO Decomposition

The ELBO has a beautiful interpretation with two competing terms:

\mathcal{L} = \underbrace{\mathbb{E}_{q_\phi(z|x)}[\log p_\theta(x|z)]}_{\text{Reconstruction}} - \underbrace{D_{\text{KL}}(q_\phi(z|x) \| p(z))}_{\text{Regularization}}

Reconstruction Term

\mathbb{E}_{q_\phi(z|x)}[\log p_\theta(x|z)] measures how well we can reconstruct the input from the latent code. This is the "quality" term that encourages the model to encode useful information in z.

  • For continuous data with Gaussian likelihood: equivalent to negative MSE
  • For binary data with Bernoulli likelihood: equivalent to binary cross-entropy
  • Encourages the encoder to preserve information

Regularization (KL) Term

D_{\text{KL}}(q_\phi(z|x) \| p(z)) measures how much the approximate posterior deviates from the prior. This term:

  • Prevents the encoder from memorizing each data point
  • Encourages smooth, regular latent spaces
  • Enables generation by sampling from the prior
The Trade-off: If we only maximize reconstruction, the model becomes an autoencoder without generative capability. If we only minimize KL, we get random noise. The ELBO balances these naturally!

Mean-Field Approximation

The most common choice for the variational family is the mean-field approximation, where we assume the latent dimensions are independent:

q_\phi(z|x) = \prod_{i=1}^{d} q_\phi(z_i|x)

For VAEs, we typically use a multivariate Gaussian with diagonal covariance:

q_\phi(z|x) = \mathcal{N}(z; \mu_\phi(x), \text{diag}(\sigma^2_\phi(x)))

The encoder neural network outputs the mean \mu and variance \sigma^2 for each data point. In practice, we output \log \sigma^2 for numerical stability.

Advantages of Mean-Field

  • Closed-form KL: For Gaussian distributions with Gaussian prior, the KL term has an analytical solution
  • Simple parameterization: Only need to output mean and variance vectors
  • Efficient sampling: Can sample from each dimension independently

The Closed-Form KL for Gaussians

When q = \mathcal{N}(\mu, \sigma^2) and p = \mathcal{N}(0, 1) (standard normal prior), the KL divergence has a simple form:

D_{\text{KL}}(q \| p) = -\frac{1}{2}\sum_{i=1}^{d}\left(1 + \log \sigma_i^2 - \mu_i^2 - \sigma_i^2\right)

This is the famous KL regularization term used in VAEs!
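As a sanity check, the closed-form expression can be compared against torch.distributions, which computes the same Gaussian KL dimension by dimension (the batch and latent sizes here are arbitrary):

```python
import torch
from torch.distributions import Normal, kl_divergence

torch.manual_seed(0)
mu = torch.randn(4, 8)       # per-dimension means for a batch of 4
logvar = torch.randn(4, 8)   # per-dimension log-variances

# Analytic VAE KL term: -0.5 * sum(1 + log sigma^2 - mu^2 - sigma^2)
kl_analytic = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=1)

# Same quantity via torch.distributions, summed over independent dimensions
q = Normal(mu, torch.exp(0.5 * logvar))
p = Normal(torch.zeros_like(mu), torch.ones_like(mu))
kl_library = kl_divergence(q, p).sum(dim=1)

print(torch.allclose(kl_analytic, kl_library, atol=1e-4))  # True
```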


The Reparameterization Trick

There's one remaining challenge: the ELBO requires taking expectations over q_\phi(z|x), which depends on the parameters \phi we want to optimize. We need to estimate gradients through this expectation.

The problem: If we sample z \sim q_\phi(z|x) directly, this sampling operation is not differentiable! We can't compute \partial z / \partial \phi.

The solution: The reparameterization trick rewrites the sampling operation as a deterministic transformation of a parameter-free random variable:

\epsilon \sim \mathcal{N}(0, I) \quad \text{(sample noise)}
z = \mu_\phi(x) + \sigma_\phi(x) \odot \epsilon \quad \text{(deterministic transform)}

Now the randomness is in \epsilon, which doesn't depend on \phi. The gradients flow through \mu and \sigma!

Why This Works

Consider the gradient of the reconstruction term:

\nabla_\phi \mathbb{E}_{q_\phi}[\log p(x|z)] = \mathbb{E}_{\epsilon}\left[\nabla_\phi \log p(x|\mu_\phi(x) + \sigma_\phi(x) \cdot \epsilon)\right]

The expectation is now over \epsilon, which doesn't depend on \phi. We can:

  1. Sample \epsilon_1, \ldots, \epsilon_K from \mathcal{N}(0, I)
  2. Compute z_k = \mu + \sigma \cdot \epsilon_k for each sample
  3. Estimate the gradient with Monte Carlo: \frac{1}{K}\sum_k \nabla_\phi \log p(x|z_k)
Key Insight for Diffusion: The reparameterization trick is exactly how we train diffusion models! When we sample x_t = \sqrt{\bar{\alpha}_t} x_0 + \sqrt{1-\bar{\alpha}_t} \epsilon, we're using the same reparameterization to enable gradient flow.
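A minimal autograd sketch of the trick (values chosen arbitrarily for illustration): for z = \mu + \sigma\epsilon with \sigma = 1, the true gradient of \mathbb{E}[z^2] with respect to \mu is 2\mu, and a reparameterized Monte Carlo estimate recovers it.

```python
import torch

torch.manual_seed(0)
mu = torch.tensor(1.5, requires_grad=True)
log_sigma = torch.tensor(0.0, requires_grad=True)  # sigma = 1

# Reparameterize: z = mu + sigma * eps, with parameter-free noise eps ~ N(0, 1)
eps = torch.randn(100_000)
z = mu + torch.exp(log_sigma) * eps

# Monte Carlo estimate of E[z^2]; true gradient d/dmu E[z^2] = 2*mu = 3.0
loss = (z ** 2).mean()
loss.backward()
print(mu.grad)  # close to 3.0
```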

Connection to Diffusion Models

Diffusion models are a special case of hierarchical variational autoencoders! The key insight is that diffusion models have a fixed, predetermined forward process (the encoder) and only learn the reverse process (the decoder).

The Diffusion ELBO

In diffusion models, the negative ELBO, used as a loss to be minimized, decomposes into a sum over timesteps:

\mathcal{L} = \mathbb{E}_q\left[-\log p_\theta(x_0|x_1) + \sum_{t=2}^{T} D_{\text{KL}}(q(x_{t-1}|x_t, x_0) \| p_\theta(x_{t-1}|x_t)) + D_{\text{KL}}(q(x_T|x_0) \| p(x_T))\right]

| Term | Meaning | Diffusion Interpretation |
|---|---|---|
| -log p_θ(x_0\|x_1) | Reconstruction from first latent | Final denoising step quality |
| Sum of KL terms | Match reverse to true posterior | Each denoising step matches the optimal transition |
| KL(q(x_T\|x_0) \|\| p(x_T)) | Match final latent to prior | Noisy image should look like pure noise |

The Simplified Loss

Ho et al. (2020) showed that this complex ELBO simplifies to a remarkably simple objective when parameterized correctly:

L_{\text{simple}} = \mathbb{E}_{t, x_0, \epsilon}\left[\|\epsilon - \epsilon_\theta(x_t, t)\|^2\right]

This is just predicting the noise that was added! The connection to variational inference explains why this works: we're still maximizing a lower bound on the log-likelihood, just in a computationally convenient form.

Deep Connection: The diffusion model training objective is derived from variational inference principles. Each denoising step learns to approximate the true reverse conditional q(x_{t-1}|x_t, x_0), which has a closed form because the forward process is Gaussian!
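The "fixed Gaussian encoder" can also be checked numerically: running the forward chain x_t = \sqrt{1-\beta_t}\, x_{t-1} + \sqrt{\beta_t}\, \epsilon step by step produces the same mean and variance as sampling the marginal q(x_T|x_0) = \mathcal{N}(\sqrt{\bar{\alpha}_T}\, x_0, (1-\bar{\alpha}_T)I) in one shot. A scalar toy schedule (our own, not the DDPM one):

```python
import numpy as np

rng = np.random.default_rng(0)
T = 50
betas = np.linspace(1e-4, 0.2, T)      # toy linear schedule
alphas_bar = np.cumprod(1.0 - betas)   # cumulative products alpha_bar_t

x0 = 1.7        # a scalar "image"
n = 200_000     # number of independent chains

# Iterative forward process: x_t = sqrt(1 - beta_t) * x_{t-1} + sqrt(beta_t) * eps
x = np.full(n, x0)
for beta in betas:
    x = np.sqrt(1 - beta) * x + np.sqrt(beta) * rng.standard_normal(n)

# Direct marginal via reparameterization: q(x_T | x_0) = N(sqrt(alpha_bar_T) * x0, 1 - alpha_bar_T)
mean_direct = np.sqrt(alphas_bar[-1]) * x0
var_direct = 1 - alphas_bar[-1]

print(np.round([x.mean(), mean_direct], 3))  # empirical vs. closed-form mean
print(np.round([x.var(), var_direct], 3))    # empirical vs. closed-form variance
```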

Implementation in PyTorch

Let's implement a simple VAE to solidify our understanding of the ELBO and reparameterization trick:

VAE with ELBO Loss
import torch
import torch.nn as nn
import torch.nn.functional as F

class SimpleVAE(nn.Module):
    def __init__(self, input_dim=784, latent_dim=32):
        super().__init__()
        self.encoder = nn.Sequential(nn.Linear(input_dim, 256), nn.ReLU())
        self.fc_mu = nn.Linear(256, latent_dim)
        self.fc_logvar = nn.Linear(256, latent_dim)
        self.decoder = nn.Sequential(nn.Linear(latent_dim, 256), nn.ReLU(), nn.Linear(256, input_dim))

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)  # Convert log-variance to std
        eps = torch.randn_like(std)    # Sample noise: eps ~ N(0, I)
        return mu + eps * std          # Reparameterization: z = mu + std * eps

    def forward(self, x):
        h = self.encoder(x.view(-1, 784))
        mu, logvar = self.fc_mu(h), self.fc_logvar(h)
        z = self.reparameterize(mu, logvar)
        recon = torch.sigmoid(self.decoder(z))
        return recon, mu, logvar

def vae_loss(recon_x, x, mu, logvar):
    # Reconstruction loss (binary cross-entropy)
    BCE = F.binary_cross_entropy(recon_x, x.view(-1, 784), reduction='sum')
    # KL divergence: -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return BCE + KLD  # Negative ELBO: minimizing BCE + KLD maximizes the ELBO

Now let's see how the same principles apply to diffusion model training:

Diffusion ELBO (Simplified)
import torch
import torch.nn.functional as F
from diffusers import DDPMScheduler

# Setup noise schedule
scheduler = DDPMScheduler(num_train_timesteps=1000, beta_schedule="linear")
alphas_cumprod = scheduler.alphas_cumprod  # cumulative products of (1 - beta_t)

def diffusion_loss(model, x_0, device):
    """Simplified diffusion ELBO loss - just predict the noise!"""
    batch_size = x_0.shape[0]
    # Sample random timesteps and noise
    t = torch.randint(0, 1000, (batch_size,), device=device)
    noise = torch.randn_like(x_0)

    # Create noisy image using reparameterization: x_t = sqrt(alpha_bar) * x_0 + sqrt(1 - alpha_bar) * eps
    alpha_bar = alphas_cumprod.to(device)[t]
    sqrt_alpha = alpha_bar.sqrt().view(-1, 1, 1, 1)
    sqrt_one_minus_alpha = (1 - alpha_bar).sqrt().view(-1, 1, 1, 1)
    x_t = sqrt_alpha * x_0 + sqrt_one_minus_alpha * noise

    # Model predicts the noise; simple L2 loss on noise prediction
    predicted_noise = model(x_t, t)
    return F.mse_loss(predicted_noise, noise)

Connection to ELBO

Notice how both implementations use the reparameterization trick! In VAEs, we reparameterize the latent sampling. In diffusion models, we reparameterize the noising process. Both enable gradient-based optimization of variational objectives.

Summary

Variational inference provides the theoretical foundation for training generative models when exact likelihood computation is intractable. Here are the key takeaways:

  1. The Intractability Problem: Computing p(x) = \int p(x|z) p(z) \, dz is intractable for nonlinear latent variable models
  2. The ELBO: We can maximize a lower bound on log-likelihood that decomposes into reconstruction and regularization terms
  3. Reparameterization: By writing z = \mu + \sigma \cdot \epsilon, we enable gradient-based optimization through sampling
  4. Mean-Field: Assuming independent latent dimensions gives tractable optimization with closed-form KL terms
  5. Diffusion Connection: Diffusion models are hierarchical VAEs with fixed encoders and learned decoders, trained using the same variational principles
Looking Ahead: In the next section, we'll explore Markov chains and stochastic processes - the mathematical framework that describes how diffusion models progressively add and remove noise. Understanding these processes will complete our prerequisite toolkit for diving into diffusion models proper.