Learning Objectives
By the end of this section, you will be able to:
- Explain why pixel-space diffusion is computationally prohibitive at high resolutions
- Describe how VAEs compress images to a tractable latent space
- Understand the latent diffusion architecture and training procedure
- Implement a complete latent diffusion training pipeline
- Compare KL-VAE and VQ-VAE approaches for latent compression
The Pixel Space Problem
Early diffusion models operated directly on pixel space. While theoretically sound, this approach faces severe computational challenges at high resolutions:
The Scaling Problem
| Resolution | Pixels | Attention Memory | Practical? |
|---|---|---|---|
| 64x64 | 4,096 | ~16 MB | Yes |
| 256x256 | 65,536 | ~4 GB | Difficult |
| 512x512 | 262,144 | ~64 GB | Very difficult |
| 1024x1024 | 1,048,576 | ~1 TB | Impossible |
The problem: attention has O(N^2) complexity, where N is the number of tokens (here, pixels). For 512x512 images, that's 262,144 tokens - each attending to all others.
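The quadratic blow-up is easy to verify. The snippet below assumes fp32 attention scores and a single head/layer - real kernels differ by constant factors, which is why published tables vary - but the 16x memory jump per resolution doubling is inherent:

```python
# Rough attention-score memory: N^2 entries, assumed 4 bytes each.
# The constants vary with precision and kernel; the quadratic growth does not.
def attention_memory_bytes(side: int) -> int:
    n = side * side  # one token per pixel
    return n * n * 4

for side in (64, 256, 512, 1024):
    print(f"{side}x{side}: {attention_memory_bytes(side) / 1024**3:.2f} GiB")
```

Doubling the resolution quadruples N and thus multiplies attention memory by 16.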
The Key Insight: Most of the information in an image is redundant. A 512x512x3 image has 786,432 values, but could be represented by a much smaller set of meaningful features. Why waste compute on pixel-level noise?
What Do We Really Need?
Diffusion models primarily learn:
- Semantic content: What objects and scenes are present
- Composition: Where things are spatially arranged
- Style: Colors, textures, artistic qualities
High-frequency pixel details (exact edge pixels, precise textures) are largely perceptual - our eyes fill in the details. The solution: work in a compressed semantic space.
VAE Compression
Variational Autoencoders (VAEs) learn to compress images to a low-dimensional latent space while preserving essential information:
The Compression-Reconstruction Trade-off
- x: Original image, shape [512, 512, 3]
- z = E(x): Latent representation, shape [64, 64, 4]
- x̂ = D(z): Reconstructed image, shape [512, 512, 3]
This is an 8x compression along each spatial dimension (64x fewer spatial positions), with 4 channels instead of 3. The total compression ratio is (512 x 512 x 3) / (64 x 64 x 4) = 786,432 / 16,384 = 48x.
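As a quick sanity check on the arithmetic:

```python
pixel_values = 512 * 512 * 3   # values in the original image
latent_values = 64 * 64 * 4    # values in the latent
print(pixel_values, latent_values, pixel_values // latent_values)
# 786432 16384 48
```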
Why 8x Spatial Compression?
- 4x: Still too expensive, limited compression
- 8x: Good balance of quality and efficiency
- 16x: Loses too much detail, blurry outputs
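The token counts behind this trade-off, for a 512x512 input:

```python
# Latent grid sizes at different spatial downsampling factors
for factor in (4, 8, 16):
    side = 512 // factor
    print(f"{factor}x: {side}x{side} latent grid = {side * side} tokens")
# 4x: 128x128 latent grid = 16384 tokens
# 8x: 64x64 latent grid = 4096 tokens
# 16x: 32x32 latent grid = 1024 tokens
```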
VAE Architecture
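A minimal sketch of the two halves, using plain strided convolutions. Production VAEs (e.g. Stable Diffusion's AutoencoderKL) add ResNet blocks, attention, and perceptual/adversarial losses on top of this skeleton:

```python
import torch
import torch.nn as nn

class Encoder(nn.Module):
    """Compress an RGB image 8x per spatial dimension into a 4-channel latent."""

    def __init__(self, latent_channels: int = 4):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(3, 64, 3, stride=2, padding=1),    # H/2
            nn.SiLU(),
            nn.Conv2d(64, 128, 3, stride=2, padding=1),  # H/4
            nn.SiLU(),
            nn.Conv2d(128, 256, 3, stride=2, padding=1), # H/8
            nn.SiLU(),
        )
        # Predict mean and log-variance of the Gaussian posterior q(z|x)
        self.to_moments = nn.Conv2d(256, 2 * latent_channels, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        mean, logvar = self.to_moments(self.net(x)).chunk(2, dim=1)
        # Reparameterization trick: z ~ N(mean, exp(logvar))
        return mean + torch.exp(0.5 * logvar) * torch.randn_like(mean)

class Decoder(nn.Module):
    """Upsample a latent 8x per spatial dimension back to RGB."""

    def __init__(self, latent_channels: int = 4):
        super().__init__()
        self.net = nn.Sequential(
            nn.ConvTranspose2d(latent_channels, 256, 4, stride=2, padding=1),  # H*2
            nn.SiLU(),
            nn.ConvTranspose2d(256, 128, 4, stride=2, padding=1),              # H*4
            nn.SiLU(),
            nn.ConvTranspose2d(128, 3, 4, stride=2, padding=1),                # H*8
        )

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        return self.net(z)
```

A 512x512x3 input yields a 64x64x4 latent; the decoder inverts the spatial scaling exactly.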
Visualizing Latent Space
Latent Diffusion Architecture
Latent Diffusion Models (LDM) perform the diffusion process in the compressed latent space instead of pixel space:
Architecture Overview
Key Components
| Component | Role | Trainable? |
|---|---|---|
| VAE Encoder | Compress images to latents | No (pre-trained) |
| VAE Decoder | Decompress latents to images | No (pre-trained) |
| U-Net | Denoise latents conditioned on text | Yes |
| Text Encoder | Encode prompts to embeddings | No (pre-trained) |
The Efficiency Gain: Instead of denoising 512x512x3 = 786,432 values, we denoise 64x64x4 = 16,384 values. That's 48x fewer values, and since attention scales quadratically, the attention speedup is even more dramatic: roughly 48^2 ≈ 2300x.
Why a Pre-trained VAE?
The VAE is trained separately and then frozen during diffusion training. This separation has several benefits:
- Stable training: VAE reconstruction objective is simpler than joint training
- Reusable: Same VAE can be used for different diffusion models
- Memory efficient: VAE weights don't need gradients during diffusion training
- Quality: VAE can be trained with perceptual and adversarial losses for sharp reconstructions
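Freezing is a one-liner in PyTorch. The module below is a tiny stand-in, not a real VAE:

```python
import torch.nn as nn

vae = nn.Sequential(nn.Conv2d(3, 4, 1), nn.Conv2d(4, 3, 1))  # stand-in for a real VAE
vae.requires_grad_(False)  # freeze: no gradients stored for VAE weights
vae.eval()                 # fix dropout / norm statistics

trainable = sum(p.numel() for p in vae.parameters() if p.requires_grad)
print(trainable)  # 0
```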
PyTorch Implementation
A latent diffusion model wraps four components - a frozen VAE, a frozen text encoder, a trainable U-Net, and a noise scheduler - behind a training_step method.
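A sketch of that wrapper, assuming diffusers-style components: a VAE whose encode() returns an object with a .latent_dist and whose decode() returns an object with .sample, a U-Net whose output has a .sample field, and a DDPM-style scheduler with add_noise():

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class LatentDiffusion(nn.Module):
    """Minimal LDM wrapper (sketch, not the reference implementation)."""

    scale_factor = 0.18215  # rescales SD-VAE latents to ~unit variance

    def __init__(self, vae, text_encoder, unet, scheduler):
        super().__init__()
        self.vae = vae.requires_grad_(False)  # frozen
        self.text_encoder = text_encoder      # frozen
        self.unet = unet                      # the only trainable part
        self.scheduler = scheduler

    @torch.no_grad()
    def encode_images(self, images):
        # Sample the VAE posterior, then normalize latent variance
        return self.vae.encode(images).latent_dist.sample() * self.scale_factor

    @torch.no_grad()
    def decode_latents(self, latents):
        return self.vae.decode(latents / self.scale_factor).sample

    def training_step(self, images, prompts):
        latents = self.encode_images(images)
        noise = torch.randn_like(latents)
        t = torch.randint(
            0, self.scheduler.config.num_train_timesteps,
            (latents.shape[0],), device=latents.device,
        )
        noisy_latents = self.scheduler.add_noise(latents, noise, t)
        text_emb = self.text_encoder(prompts)
        pred = self.unet(noisy_latents, t, encoder_hidden_states=text_emb).sample
        return F.mse_loss(pred, noise)  # epsilon-prediction objective
```

The training loop and sampler shown later in this section call training_step, text_encoder, scheduler, and decode_latents on this object.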
The Scale Factor
The magic number 0.18215 deserves explanation:
- VAE latents have a certain variance distribution after training
- Diffusion works best when inputs have unit variance
- The scale factor normalizes latent variance to approximately 1
- It's computed as 1 / std(z) over latents encoded from the training data
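How such a number is obtained, with a fake latent batch standing in for real encoder outputs (we set its std to about 5.5, since 1 / 0.18215 ≈ 5.49):

```python
import torch

# Stand-in for real VAE latents over the training set, std ≈ 5.5
latents = torch.randn(256, 4, 64, 64) * 5.5
scale_factor = 1.0 / latents.std()
print(f"{scale_factor.item():.4f}")  # ≈ 0.1818
```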
Complete Training Pipeline
```python
import torch
from torch.utils.data import DataLoader


def train_latent_diffusion(
    model: LatentDiffusion,
    dataloader: DataLoader,
    optimizer: torch.optim.Optimizer,
    num_epochs: int,
    device: str = "cuda",
):
    """
    Complete training loop for Latent Diffusion.
    """
    model.train()
    model.unet.train()  # Only U-Net is trainable

    for epoch in range(num_epochs):
        epoch_loss = 0
        for batch in dataloader:
            images = batch["images"].to(device)
            prompts = batch["prompts"]

            # Forward pass
            loss = model.training_step(images, prompts)

            # Backward pass
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()

        avg_loss = epoch_loss / len(dataloader)
        print(f"Epoch {epoch+1}: Loss = {avg_loss:.4f}")


# Sampling (inference)
@torch.no_grad()
def sample(
    model: LatentDiffusion,
    prompts: list[str],
    height: int = 512,
    width: int = 512,
    num_inference_steps: int = 50,
    guidance_scale: float = 7.5,
    device: str = "cuda",
) -> torch.Tensor:
    """
    Generate images from text prompts using CFG.
    """
    batch_size = len(prompts)
    latent_height = height // 8
    latent_width = width // 8

    # Get text embeddings
    text_embeddings = model.text_encoder(prompts)
    uncond_embeddings = model.text_encoder([""] * batch_size)

    # CFG: concatenate unconditional and conditional
    text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

    # Initialize latents from noise
    latents = torch.randn(
        batch_size, 4, latent_height, latent_width, device=device
    )

    # Setup scheduler for inference
    model.scheduler.set_timesteps(num_inference_steps)

    # Denoising loop
    for t in model.scheduler.timesteps:
        # CFG: double the batch
        latent_input = torch.cat([latents] * 2)
        timestep = torch.tensor([t] * batch_size * 2, device=device)

        # Predict noise
        noise_pred = model.unet(
            latent_input,
            timestep,
            encoder_hidden_states=text_embeddings,
        ).sample

        # CFG: combine conditional and unconditional
        noise_uncond, noise_cond = noise_pred.chunk(2)
        noise_pred = noise_uncond + guidance_scale * (noise_cond - noise_uncond)

        # Scheduler step
        latents = model.scheduler.step(noise_pred, t, latents).prev_sample

    # Decode latents to images
    images = model.decode_latents(latents)
    return images
```
Memory Optimization Tips
- Gradient checkpointing: Trade compute for memory in U-Net
- Mixed precision (fp16): Halves memory, speeds up training
- xFormers: Memory-efficient attention implementation
- Gradient accumulation: Simulate larger batch sizes
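Of these, gradient accumulation needs no library support. A sketch of the pattern, with a tiny linear layer standing in for the U-Net:

```python
import torch

model = torch.nn.Linear(8, 1)  # stand-in for the U-Net
opt = torch.optim.AdamW(model.parameters(), lr=1e-4)
accum_steps = 4                # effective batch = 4x the per-step batch

for step in range(8):
    x = torch.randn(2, 8)
    loss = model(x).pow(2).mean() / accum_steps  # scale so grads average
    loss.backward()            # gradients accumulate across micro-batches
    if (step + 1) % accum_steps == 0:
        opt.step()             # one optimizer update per accum_steps batches
        opt.zero_grad()
```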
KL-VAE vs VQ-VAE
Two main approaches exist for the autoencoder component:
KL-VAE (Continuous Latents)
Used by Stable Diffusion. Latents are continuous Gaussian:
- Pros: Smooth latent space, natural for diffusion
- Cons: Can be blurry without adversarial training
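The KL penalty that keeps these latents Gaussian has a closed form for a diagonal posterior; a sketch, using a posterior set exactly equal to the N(0, I) prior:

```python
import torch

mean = torch.zeros(1, 4, 8, 8)    # posterior mean
logvar = torch.zeros(1, 4, 8, 8)  # posterior log-variance
z = mean + torch.exp(0.5 * logvar) * torch.randn_like(mean)  # reparameterized sample
kl = 0.5 * (mean.pow(2) + logvar.exp() - 1.0 - logvar).sum()
print(kl.item())  # 0.0: this posterior matches the prior exactly
```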
VQ-VAE (Discrete Latents)
Used by DALL-E, Parti. Latents are discrete codebook indices:
- Pros: Sharp reconstructions, prevents posterior collapse
- Cons: Discrete space less natural for continuous diffusion
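The discrete lookup at the heart of VQ is a nearest-neighbor search over a codebook. A minimal sketch; real VQ-VAEs add a commitment loss and a straight-through gradient estimator:

```python
import torch

codebook = torch.randn(512, 4)   # 512 entries, 4-dim latent vectors
z = torch.randn(16, 4)           # 16 encoder outputs to quantize
dists = torch.cdist(z, codebook) # pairwise L2 distances, shape [16, 512]
indices = dists.argmin(dim=1)    # discrete codes
z_q = codebook[indices]          # quantized latents
print(indices.shape, z_q.shape)  # torch.Size([16]) torch.Size([16, 4])
```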
Comparison
| Aspect | KL-VAE | VQ-VAE |
|---|---|---|
| Latent type | Continuous Gaussian | Discrete codebook |
| Diffusion compatibility | Natural fit | Requires adaptation |
| Reconstruction quality | Good with GAN loss | Excellent |
| Used by | Stable Diffusion, FLUX | DALL-E, Parti, MaskGIT |
| Training | Simpler | Codebook collapse issues |
Modern Trend: Recent models like Stable Diffusion 3 and FLUX use improved KL-VAEs with adversarial training, achieving excellent reconstruction quality while maintaining smooth latent spaces ideal for diffusion.
Key Takeaways
- Pixel-space diffusion doesn't scale: Attention has O(N^2) complexity, making high-resolution generation prohibitive
- VAEs provide 48x compression: 512x512x3 images become 64x64x4 latents while preserving semantic content
- Latent diffusion separates compression from generation: Pre-trained VAE handles reconstruction, diffusion handles generation
- Only the U-Net is trained: VAE and text encoder are frozen, saving memory and preserving their learned representations
- Scale factor normalizes latent variance: The magic 0.18215 keeps latent variance around 1 for stable diffusion training
- KL-VAE is preferred for diffusion: Continuous latent space is a natural fit for the continuous diffusion process
Part IV Complete: You now understand the complete text-to-image pipeline: text encoding with CLIP, cross-attention for spatial alignment, classifier-free guidance for controllable generation, and latent diffusion for efficient high-resolution synthesis. These foundations power systems like Stable Diffusion, DALL-E 3, and Midjourney.