Chapter 13

The VAE Component

Latent Diffusion Models

Learning Objectives

By the end of this section, you will be able to:

  1. Explain the role of the VAE in Latent Diffusion Models and why it's trained separately from the diffusion model
  2. Describe the encoder architecture including downsampling blocks, residual connections, and attention layers
  3. Understand KL regularization and why a small KL weight is crucial for high-quality reconstruction
  4. Compare reconstruction losses (MSE, L1, perceptual, adversarial) and their impact on image quality
  5. Analyze latent space properties that make it suitable for diffusion

VAE Fundamentals Review

Before diving into the specifics of LDM's VAE, let's review the core VAE framework. A Variational Autoencoder consists of a probabilistic encoder and decoder trained jointly under the ELBO objective.

The Probabilistic Interpretation

The encoder maps an input $x$ to a distribution over latents:

$$q_\phi(z|x) = \mathcal{N}(z;\ \mu_\phi(x),\ \sigma^2_\phi(x)\, I)$$

The decoder maps latents back to a reconstruction:

$$p_\theta(x|z) = \mathcal{N}(x;\ \mu_\theta(z),\ \sigma^2 I)$$

The ELBO Objective

VAEs are trained to maximize the Evidence Lower Bound:

$$\mathcal{L} = \mathbb{E}_{q_\phi(z|x)}[\log p_\theta(x|z)] - D_{KL}(q_\phi(z|x) \,\|\, p(z))$$

The first term is reconstruction (how well we can recover $x$ from $z$); the second is regularization (keeping the latent distribution close to a standard normal).

The LDM Twist: Standard VAEs use a strong KL penalty, leading to blurry reconstructions. LDMs use a very weak KL penalty (weight $\approx 10^{-6}$), prioritizing reconstruction quality. The diffusion model will handle the prior matching later!
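In code, the two ELBO terms and the reparameterized sampling step can be written as a minimal sketch (function names here are illustrative, not from any particular library):

```python
import torch

def elbo_terms(x, x_recon, mu, log_var):
    """Compute the two ELBO terms for a Gaussian encoder/decoder."""
    # Reconstruction: log p(x|z) for a fixed-variance Gaussian decoder
    # reduces (up to constants) to negative squared error
    recon = -((x - x_recon) ** 2).sum()
    # KL(q(z|x) || N(0, I)) in closed form for a diagonal Gaussian posterior
    kl = -0.5 * (1 + log_var - mu.pow(2) - log_var.exp()).sum()
    return recon, kl

def sample_latent(mu, log_var):
    """Reparameterization trick: z = mu + sigma * eps, eps ~ N(0, I)."""
    eps = torch.randn_like(mu)
    return mu + torch.exp(0.5 * log_var) * eps
```

Training maximizes `recon - beta * kl`; the LDM twist is simply setting `beta` near zero.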

Encoder Architecture

The encoder in Stable Diffusion's VAE is a convolutional neural network with the following structure:

Architecture Overview

| Stage | Input Size | Output Size | Components |
|-------|------------|-------------|------------|
| Input | 512 x 512 x 3 | 512 x 512 x 128 | Conv 3x3 |
| Down 1 | 512 x 512 x 128 | 256 x 256 x 128 | 2x ResBlock + Downsample |
| Down 2 | 256 x 256 x 128 | 128 x 128 x 256 | 2x ResBlock + Downsample |
| Down 3 | 128 x 128 x 256 | 64 x 64 x 512 | 2x ResBlock + Downsample |
| Down 4 | 64 x 64 x 512 | 64 x 64 x 512 | 2x ResBlock (no downsample) |
| Mid | 64 x 64 x 512 | 64 x 64 x 512 | ResBlock + Attention + ResBlock |
| Output | 64 x 64 x 512 | 64 x 64 x 8 | GroupNorm + SiLU + Conv |

The final 8 channels encode the mean and log-variance (4 channels each) of the latent distribution.

VAE Encoder
```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class Encoder(nn.Module):
    def __init__(self, in_channels=3, latent_channels=4, ch=128):
        super().__init__()

        # Initial convolution: 3 -> 128 channels
        self.conv_in = nn.Conv2d(in_channels, ch, 3, padding=1)

        # Downsampling blocks: progressively reduce spatial resolution
        self.down_blocks = nn.ModuleList([
            DownBlock(ch, ch, downsample=True),      # 512 -> 256
            DownBlock(ch, ch*2, downsample=True),    # 256 -> 128
            DownBlock(ch*2, ch*4, downsample=True),  # 128 -> 64
            DownBlock(ch*4, ch*4, downsample=False), # 64 -> 64
        ])

        # Middle block with attention for global context
        self.mid_block = MidBlock(ch*4, use_attention=True)

        # Output head: GroupNorm + SiLU + Conv, 512 -> 8 channels
        # (4 for mean, 4 for log_var)
        self.norm_out = nn.GroupNorm(32, ch*4)
        self.conv_out = nn.Conv2d(ch*4, 2*latent_channels, 3, padding=1)

    def forward(self, x):
        h = self.conv_in(x)
        for block in self.down_blocks:
            h = block(h)
        h = self.mid_block(h)
        h = self.conv_out(F.silu(self.norm_out(h)))
        return h.chunk(2, dim=1)  # (mean, log_var)
```

Residual Blocks

Each ResBlock contains:

  • GroupNorm (32 groups) for stable training
  • SiLU (Swish) activation: $\text{SiLU}(x) = x \cdot \sigma(x)$
  • Two 3x3 convolutions with skip connection
  • Optional 1x1 conv for channel dimension changes
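Putting those four ingredients together, a minimal ResBlock sketch in PyTorch looks like this (the `DownBlock` referenced in the encoder code would stack two of these plus a strided downsampling conv):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class ResBlock(nn.Module):
    """GroupNorm -> SiLU -> Conv, twice, with a skip connection."""
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.norm1 = nn.GroupNorm(32, in_ch)
        self.conv1 = nn.Conv2d(in_ch, out_ch, 3, padding=1)
        self.norm2 = nn.GroupNorm(32, out_ch)
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1)
        # Optional 1x1 conv on the skip path when channel counts differ
        self.skip = (nn.Conv2d(in_ch, out_ch, 1) if in_ch != out_ch
                     else nn.Identity())

    def forward(self, x):
        h = self.conv1(F.silu(self.norm1(x)))
        h = self.conv2(F.silu(self.norm2(h)))
        return self.skip(x) + h
```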

The Attention Layer

A single self-attention layer is used in the middle block at 64x64 resolution. This provides global context without the quadratic cost of attention at higher resolutions:

$$\text{Attention}(Q, K, V) = \text{softmax}\!\left(\frac{QK^T}{\sqrt{d}}\right) V$$

At 64x64, attention operates over 4,096 tokens, manageable compared to the 262,144 tokens a 512x512 feature map would require.
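A single-head version of this spatial self-attention could be sketched as follows (illustrative PyTorch, not the exact Stable Diffusion implementation, which uses the same idea with some extra bookkeeping):

```python
import torch
import torch.nn as nn

class AttnBlock(nn.Module):
    """Single-head self-attention over spatial positions."""
    def __init__(self, ch):
        super().__init__()
        self.norm = nn.GroupNorm(32, ch)
        self.q = nn.Conv2d(ch, ch, 1)
        self.k = nn.Conv2d(ch, ch, 1)
        self.v = nn.Conv2d(ch, ch, 1)
        self.proj = nn.Conv2d(ch, ch, 1)

    def forward(self, x):
        b, c, h, w = x.shape
        hn = self.norm(x)
        # Flatten spatial dims so each of the h*w positions is one token
        q = self.q(hn).reshape(b, c, h * w).transpose(1, 2)  # (b, hw, c)
        k = self.k(hn).reshape(b, c, h * w)                  # (b, c, hw)
        v = self.v(hn).reshape(b, c, h * w).transpose(1, 2)  # (b, hw, c)
        attn = torch.softmax(q @ k / c ** 0.5, dim=-1)       # (b, hw, hw)
        out = (attn @ v).transpose(1, 2).reshape(b, c, h, w)
        return x + self.proj(out)
```

The `(hw, hw)` attention matrix is exactly where the quadratic cost comes from, which is why this block only appears at the lowest resolution.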


Decoder Architecture

The decoder mirrors the encoder, using upsampling instead of downsampling:

| Stage | Input Size | Output Size | Components |
|-------|------------|-------------|------------|
| Input | 64 x 64 x 4 | 64 x 64 x 512 | Conv 3x3 |
| Mid | 64 x 64 x 512 | 64 x 64 x 512 | ResBlock + Attention + ResBlock |
| Up 1 | 64 x 64 x 512 | 64 x 64 x 512 | 3x ResBlock (no upsample) |
| Up 2 | 64 x 64 x 512 | 128 x 128 x 512 | 3x ResBlock + Upsample |
| Up 3 | 128 x 128 x 512 | 256 x 256 x 256 | 3x ResBlock + Upsample |
| Up 4 | 256 x 256 x 256 | 512 x 512 x 128 | 3x ResBlock + Upsample |
| Output | 512 x 512 x 128 | 512 x 512 x 3 | GroupNorm + SiLU + Conv |

VAE Decoder
```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class Decoder(nn.Module):
    def __init__(self, latent_channels=4, out_channels=3, ch=128):
        super().__init__()

        # Input convolution: 4 -> 512 channels
        self.conv_in = nn.Conv2d(latent_channels, ch*4, 3, padding=1)

        # Middle block with attention
        self.mid_block = MidBlock(ch*4, use_attention=True)

        # Upsampling blocks: progressively increase resolution
        self.up_blocks = nn.ModuleList([
            UpBlock(ch*4, ch*4, upsample=False),   # 64 -> 64
            UpBlock(ch*4, ch*4, upsample=True),    # 64 -> 128
            UpBlock(ch*4, ch*2, upsample=True),    # 128 -> 256
            UpBlock(ch*2, ch, upsample=True),      # 256 -> 512
        ])

        # Output head: GroupNorm + SiLU + Conv, 128 -> 3 RGB channels
        self.norm_out = nn.GroupNorm(32, ch)
        self.conv_out = nn.Conv2d(ch, out_channels, 3, padding=1)

    def forward(self, z):
        h = self.mid_block(self.conv_in(z))
        for block in self.up_blocks:
            h = block(h)
        return self.conv_out(F.silu(self.norm_out(h)))
```

Upsampling Strategy

The decoder uses nearest-neighbor upsampling followed by a convolution, rather than transposed convolutions:

  • Avoids checkerboard artifacts common with transposed convolutions
  • More stable gradients during training
  • Easier to control output resolution
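A minimal sketch of this upsample-then-convolve pattern (in the full decoder, each `UpBlock` also contains residual blocks before this step):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class Upsample(nn.Module):
    """Nearest-neighbor upsampling followed by a 3x3 conv."""
    def __init__(self, ch):
        super().__init__()
        self.conv = nn.Conv2d(ch, ch, 3, padding=1)

    def forward(self, x):
        # Nearest-neighbor doubling avoids the uneven kernel overlap
        # that gives transposed convolutions their checkerboard artifacts
        x = F.interpolate(x, scale_factor=2, mode="nearest")
        return self.conv(x)
```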

KL Regularization

The KL divergence term encourages the encoder to produce latents that follow a standard normal distribution:

$$D_{KL}(q_\phi(z|x) \,\|\, \mathcal{N}(0, I)) = \frac{1}{2} \sum_{i=1}^d \left( \mu_i^2 + \sigma_i^2 - \log \sigma_i^2 - 1 \right)$$
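A quick numeric sanity check of this formula: the divergence is exactly zero when the posterior matches the prior ($\mu = 0$, $\sigma = 1$, i.e. $\log \sigma^2 = 0$) and grows as the two distributions diverge:

```python
import torch

def kl_divergence(mu, log_var):
    """KL(N(mu, sigma^2) || N(0, I)) summed over latent dimensions."""
    return 0.5 * (mu.pow(2) + log_var.exp() - log_var - 1).sum()

# Posterior equal to prior -> each term is (0 + 1 - 0 - 1)/2 = 0
zero_kl = kl_divergence(torch.zeros(4), torch.zeros(4))
# Shifting the mean to 1 adds mu^2/2 = 0.5 per dimension
shifted_kl = kl_divergence(torch.ones(4), torch.zeros(4))
```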

The Weight Matters

Standard VAEs use KL weight $\beta = 1$, which leads to:

  • Posterior collapse: Encoder ignores input, produces same latent for all images
  • Blurry reconstructions: Can't encode fine details
  • Limited capacity: Information bottleneck too severe

LDM's VAE uses $\beta \approx 10^{-6}$, nearly zero! This means:

  • High-fidelity reconstruction: Encode all details
  • Slightly non-standard latents: Not exactly Gaussian, but close enough
  • Diffusion handles the rest: The diffusion model learns the actual latent distribution

Design Choice: We don't need the VAE to produce perfect Gaussian latents. We just need consistent, high-quality compression. The diffusion model will learn whatever distribution the encoder actually produces.

Reconstruction Quality

The choice of reconstruction loss dramatically affects output quality:

Loss Function Comparison

| Loss | Formula | Pros | Cons |
|------|---------|------|------|
| MSE (L2) | $(x - \hat{x})^2$ | Simple, smooth gradients | Blurry outputs, averages modes |
| L1 | $\lvert x - \hat{x} \rvert$ | Sharper than MSE | Still pixel-level |
| Perceptual (LPIPS) | $\lVert \text{VGG}(x) - \text{VGG}(\hat{x}) \rVert$ | Matches human perception | Slower, needs pretrained model |
| Adversarial | GAN loss on $D(\hat{x})$ | Sharp, realistic details | Training instability, mode collapse |

The SD-VAE Loss

Stable Diffusion's VAE uses a combination:

```python
import torch
import torch.nn.functional as F

def compute_vae_loss(x, x_recon, mu, log_var, lpips_model, discriminator=None):
    """Compute combined VAE loss for high-quality reconstruction."""

    # Perceptual reconstruction loss (more important than pixel loss)
    recon_loss = lpips_model(x, x_recon).mean()  # LPIPS perceptual distance
    recon_loss += 0.1 * F.l1_loss(x, x_recon)    # Small L1 for pixel accuracy

    # KL divergence toward N(0, I), in closed form
    kl_loss = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
    kl_loss = kl_loss / x.numel()  # Normalize by dimensionality

    # Combine with tiny KL weight to prioritize reconstruction
    total_loss = recon_loss + 1e-6 * kl_loss

    # Optional: add GAN loss for sharper outputs
    if discriminator is not None:
        gan_loss = -torch.mean(discriminator(x_recon))
        total_loss += 0.1 * gan_loss

    return total_loss
```

Quality Metrics

The SD-VAE achieves excellent reconstruction quality:

| Metric | Value | Interpretation |
|--------|-------|----------------|
| PSNR | ~32 dB | Excellent, nearly imperceptible loss |
| SSIM | ~0.95 | Very high structural similarity |
| LPIPS | ~0.05 | Low perceptual distance |
| FID (recon) | ~1.0 | Reconstructions nearly indistinguishable from originals |
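Of these, PSNR is the simplest to compute yourself; a minimal sketch, assuming images scaled to [0, 1]:

```python
import torch

def psnr(x, x_recon, max_val=1.0):
    """Peak signal-to-noise ratio in dB between an image and its reconstruction."""
    mse = ((x - x_recon) ** 2).mean()
    return 10 * torch.log10(max_val ** 2 / mse)
```

Note the log scale: each extra 10 dB corresponds to a 10x reduction in mean squared error, so the gap between ~25 dB (typical lossy compression) and ~32 dB is substantial.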

Latent Space Properties

The VAE's latent space has several properties that make it well-suited for diffusion:

1. Spatial Correspondence

The 64x64 latent grid maintains spatial correspondence with the 512x512 image. Each latent "pixel" corresponds to an 8x8 patch in the original image:

  • Position (i, j) in latent space maps to patch (8i:8i+8, 8j:8j+8) in pixel space
  • Local edits in latent space produce local edits in the image
  • The U-Net can use standard convolutional operations
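The position-to-patch mapping above is simple to express directly (a hypothetical helper, with downsampling factor `f = 8` as in SD):

```python
def latent_to_pixel_patch(i, j, f=8):
    """Map latent position (i, j) to its corresponding pixel-space patch.

    Returns (row_slice, col_slice) covering the f x f pixel region.
    """
    return slice(f * i, f * i + f), slice(f * j, f * j + f)
```

For example, latent position (2, 3) governs pixel rows 16:24 and columns 24:32, which is why masking a region of the latent masks the matching region of the decoded image.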

2. Smooth Interpolation

Linear interpolation in latent space produces semantically meaningful transitions:

$$z_{\alpha} = (1-\alpha)\, z_1 + \alpha\, z_2 \quad \Rightarrow \quad x_{\alpha} = \mathcal{D}(z_\alpha)$$

As $\alpha$ varies from 0 to 1, the decoded image smoothly morphs between the two source images.
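The interpolation itself is a one-liner per step (decoding each $z_\alpha$ is then left to the VAE's decoder):

```python
import torch

def interpolate_latents(z1, z2, num_steps=8):
    """Linear interpolation between two latents; decode each to get the morph."""
    alphas = torch.linspace(0.0, 1.0, num_steps)
    return [(1 - a) * z1 + a * z2 for a in alphas]
```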

3. Approximate Gaussianity

Despite the weak KL regularization, the latent distribution is approximately Gaussian:

  • Mean: close to 0 (within $\pm 0.5$)
  • Variance: close to 1 once multiplied by SD's scaling factor (~0.18)
  • Shape: Unimodal, roughly symmetric

Scaling Factor

Stable Diffusion multiplies latents by 0.18215 before diffusion and divides by it after. This scaling brings the latent variance close to 1, matching the unit-variance assumption of the noise schedule; without it, the noise levels would be mis-calibrated relative to the signal.

4. Semantic Disentanglement

The latent space exhibits some degree of disentanglement:

  • Content: Encoded in spatial structure of latent
  • Style: Partially separated in channel dimensions
  • Color: Can be manipulated somewhat independently

However, the disentanglement is not perfect - the VAE was trained for reconstruction, not interpretability.


Summary

The VAE component of LDMs provides efficient, high-quality image compression:

  1. Architecture: Encoder and decoder are symmetric convolutional networks with residual blocks, downsampling/upsampling, and one attention layer
  2. Minimal KL regularization: Weight ~10^-6 prioritizes reconstruction quality over strict Gaussian latents
  3. Perceptual loss: LPIPS + L1 + optional GAN produces sharp, detailed reconstructions
  4. Quality metrics: PSNR >30dB, LPIPS ~0.05 - nearly lossless perceptually
  5. Latent properties: Spatial correspondence, smooth interpolation, approximate Gaussianity, partial disentanglement

Looking Ahead: In the next section, we'll see how the diffusion process operates in this latent space - including noise schedule adaptations and the scaling factor that ensures compatibility between the VAE and diffusion model.