Chapter 9

From Pixel Space to Latent Space

Text-to-Image Foundations

Learning Objectives

By the end of this section, you will be able to:

  1. Explain why pixel-space diffusion is computationally prohibitive at high resolutions
  2. Describe how VAEs compress images to a tractable latent space
  3. Understand the latent diffusion architecture and training procedure
  4. Implement a complete latent diffusion training pipeline
  5. Compare KL-VAE and VQ-VAE approaches for latent compression

The Pixel Space Problem

Early diffusion models operated directly on pixel space. While theoretically sound, this approach faces severe computational challenges at high resolutions:

The Scaling Problem

| Resolution | Pixels | Attention Memory | Practical? |
|---|---|---|---|
| 64x64 | 4,096 | ~16 MB | Yes |
| 256x256 | 65,536 | ~4 GB | Difficult |
| 512x512 | 262,144 | ~64 GB | Very difficult |
| 1024x1024 | 1,048,576 | ~1 TB | Impossible |

The problem: attention has $O(N^2)$ complexity, where $N$ is the number of pixels. For 512x512 images, that's 262,144 tokens - each attending to all the others.
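The memory column in the table above is roughly $N^2$ bytes for the full attention score matrix (about one byte per score, a deliberately optimistic constant; fp16 or fp32 doubles or quadruples it). A quick back-of-envelope check reproduces the scaling:

```python
def attn_matrix_bytes(side: int, bytes_per_el: int = 1) -> int:
    """Memory for the full N x N attention score matrix, one token per pixel."""
    n = side * side          # N: number of pixel tokens
    return n * n * bytes_per_el

for side in (64, 256, 512, 1024):
    print(f"{side}x{side}: ~{attn_matrix_bytes(side) / 1e9:.3g} GB")
```

The fourth power of resolution in this function (N squared, with N itself quadratic in side length) is why each doubling of resolution multiplies attention memory by 16x.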

The Key Insight: Most of the information in an image is redundant. A 512x512x3 image has 786,432 values, but could be represented by a much smaller set of meaningful features. Why waste compute on pixel-level noise?

What Do We Really Need?

Diffusion models primarily learn:

  • Semantic content: What objects and scenes are present
  • Composition: Where things are spatially arranged
  • Style: Colors, textures, artistic qualities

High-frequency pixel details (exact edge pixels, precise textures) are largely perceptually redundant - our eyes fill them in. The solution: work in a compressed semantic space.


VAE Compression

Variational Autoencoders (VAEs) learn to compress images to a low-dimensional latent space while preserving essential information:

The Compression-Reconstruction Trade-off

$$\mathbf{x} \xrightarrow{\text{Encoder}} \mathbf{z} \xrightarrow{\text{Decoder}} \hat{\mathbf{x}}$$

  • $\mathbf{x}$: Original image [512, 512, 3]
  • $\mathbf{z}$: Latent representation [64, 64, 4]
  • $\hat{\mathbf{x}}$: Reconstructed image [512, 512, 3]

This is an 8x reduction along each spatial dimension (so 64x fewer spatial positions), with 4 latent channels in place of 3. The total compression ratio is:

$$\frac{512 \times 512 \times 3}{64 \times 64 \times 4} = \frac{786{,}432}{16{,}384} = 48\times$$

Why 8x Spatial Compression?

Stable Diffusion uses 8x compression (512 -> 64). This is a sweet spot:
  • 4x: Still too expensive, limited compression
  • 8x: Good balance of quality and efficiency
  • 16x: Loses too much detail, blurry outputs

VAE Architecture

The encoder, annotated step by step:

  • Imports: Standard PyTorch modules for building the encoder.
  • Encoder purpose: Compresses high-resolution images (e.g., 512x512x3) to low-resolution latents (64x64x4). The 8x spatial compression is key to efficiency.
  • Architecture parameters: Input is 3-channel RGB, output is a 4-channel latent. Hidden channels progressively increase to capture more abstract features.
  • Initial convolution: Projects input RGB to the first hidden dimension. Maintains spatial resolution.
  • Downsampling blocks: Each block has two ResNet blocks for processing plus a strided convolution for 2x downsampling. To get Stable Diffusion's 8x (rather than 16x from four downsamples), the final block should skip its downsample.
  • Middle block: After downsampling, apply ResNet + Attention + ResNet. Attention at low resolution is cheap and helps capture global structure.
  • Output projection: Outputs 2x latent channels - mean and log-variance of the Gaussian latent distribution. This enables the reparameterization trick.
  • Forward pass: Process through all blocks, normalize, and split the output into mean and log_var.
```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class VAEEncoder(nn.Module):
    """
    Encoder that compresses images to latent space.
    512x512x3 -> 64x64x4 (8x spatial compression)
    """

    def __init__(
        self,
        in_channels: int = 3,
        latent_channels: int = 4,
        hidden_channels: list[int] = [128, 256, 512, 512],
    ):
        super().__init__()

        # Initial convolution: project RGB to the first hidden dim (keeps resolution)
        self.conv_in = nn.Conv2d(in_channels, hidden_channels[0], 3, padding=1)

        # Downsampling blocks: two ResNet blocks, then a strided conv for 2x
        # downsampling. The last block skips the downsample, so four blocks
        # give 2^3 = 8x total, matching the docstring.
        self.down_blocks = nn.ModuleList()
        in_ch = hidden_channels[0]
        for i, out_ch in enumerate(hidden_channels):
            layers = [ResnetBlock(in_ch, out_ch), ResnetBlock(out_ch, out_ch)]
            if i < len(hidden_channels) - 1:
                layers.append(nn.Conv2d(out_ch, out_ch, 3, stride=2, padding=1))
            self.down_blocks.append(nn.Sequential(*layers))
            in_ch = out_ch

        # Middle: ResNet + Attention + ResNet at the lowest resolution
        self.mid = nn.Sequential(
            ResnetBlock(hidden_channels[-1], hidden_channels[-1]),
            AttentionBlock(hidden_channels[-1]),
            ResnetBlock(hidden_channels[-1], hidden_channels[-1]),
        )

        # Output: mean and log_var for reparameterization
        self.norm_out = nn.GroupNorm(32, hidden_channels[-1])
        self.conv_out = nn.Conv2d(hidden_channels[-1], 2 * latent_channels, 3, padding=1)

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            x: Image tensor [B, 3, H, W]
        Returns:
            mean: [B, latent_channels, H/8, W/8]
            log_var: [B, latent_channels, H/8, W/8]
        """
        h = self.conv_in(x)

        for block in self.down_blocks:
            h = block(h)

        h = self.mid(h)
        h = self.norm_out(h)
        h = F.silu(h)
        h = self.conv_out(h)

        # Split into mean and log_variance
        mean, log_var = h.chunk(2, dim=1)
        return mean, log_var
```

Decoder

The decoder, annotated:

  • Decoder purpose: Inverse of the encoder: takes the compressed latent and reconstructs the full-resolution image.
  • Architecture parameters: Mirror of the encoder: hidden channels decrease as we upsample. Output is 3-channel RGB.
  • Input projection: Projects latent channels to the first hidden dimension.
  • Middle block: Same structure as the encoder middle: ResNet + Attention + ResNet for global processing.
  • Upsampling blocks: Each block processes features with ResNets, then upsamples 2x with nearest-neighbor interpolation followed by a convolution (fewer checkerboard artifacts than transposed convolution). To mirror the encoder's 8x, the final block should skip its upsample.
  • Output: Normalize, activate, and project to 3 RGB channels. Output is typically in the [-1, 1] range.
```python
class VAEDecoder(nn.Module):
    """
    Decoder that reconstructs images from latent space.
    64x64x4 -> 512x512x3
    """

    def __init__(
        self,
        latent_channels: int = 4,
        out_channels: int = 3,
        hidden_channels: list[int] = [512, 512, 256, 128],
    ):
        super().__init__()

        # Input projection
        self.conv_in = nn.Conv2d(latent_channels, hidden_channels[0], 3, padding=1)

        # Middle block: same structure as the encoder middle
        self.mid = nn.Sequential(
            ResnetBlock(hidden_channels[0], hidden_channels[0]),
            AttentionBlock(hidden_channels[0]),
            ResnetBlock(hidden_channels[0], hidden_channels[0]),
        )

        # Upsampling blocks: ResNets, then 2x nearest-neighbor upsample + conv
        # (fewer checkerboard artifacts than transposed conv). The last block
        # skips the upsample, mirroring the encoder's 8x total.
        self.up_blocks = nn.ModuleList()
        in_ch = hidden_channels[0]
        for i, out_ch in enumerate(hidden_channels):
            layers = [ResnetBlock(in_ch, out_ch), ResnetBlock(out_ch, out_ch)]
            if i < len(hidden_channels) - 1:
                layers += [
                    nn.Upsample(scale_factor=2, mode="nearest"),
                    nn.Conv2d(out_ch, out_ch, 3, padding=1),
                ]
            self.up_blocks.append(nn.Sequential(*layers))
            in_ch = out_ch

        # Output
        self.norm_out = nn.GroupNorm(32, hidden_channels[-1])
        self.conv_out = nn.Conv2d(hidden_channels[-1], out_channels, 3, padding=1)

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        """
        Args:
            z: Latent tensor [B, latent_channels, H/8, W/8]
        Returns:
            Reconstructed image [B, 3, H, W]
        """
        h = self.conv_in(z)
        h = self.mid(h)

        for block in self.up_blocks:
            h = block(h)

        h = self.norm_out(h)
        h = F.silu(h)
        return self.conv_out(h)
```
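With both halves defined, a full autoencoding pass samples $\mathbf{z}$ from the encoder's predicted Gaussian via the reparameterization trick. A minimal sketch (the `encoder`/`decoder` arguments are any callables with the interfaces above):

```python
import torch

def autoencode(encoder, decoder, x: torch.Tensor):
    """Encode -> sample z via the reparameterization trick -> decode."""
    mean, log_var = encoder(x)
    std = torch.exp(0.5 * log_var)            # log_var -> standard deviation
    z = mean + std * torch.randn_like(std)    # differentiable sampling
    return decoder(z), mean, log_var
```

Writing the sample as `mean + std * noise` keeps the path from encoder parameters to the reconstruction differentiable, which direct sampling would break.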

Latent Diffusion Architecture

Latent Diffusion Models (LDM) perform the diffusion process in the compressed latent space instead of pixel space:

Architecture Overview

$$\mathbf{x} \xrightarrow{\text{VAE Encoder}} \mathbf{z}_0 \xrightarrow{\text{Diffusion}} \mathbf{z}_T \sim \mathcal{N}(0, I)$$

$$\mathbf{z}_T \xrightarrow{\text{Denoise}} \mathbf{z}_0 \xrightarrow{\text{VAE Decoder}} \hat{\mathbf{x}}$$

Key Components

| Component | Role | Trainable? |
|---|---|---|
| VAE Encoder | Compress images to latents | No (pre-trained) |
| VAE Decoder | Decompress latents to images | No (pre-trained) |
| U-Net | Denoise latents conditioned on text | Yes |
| Text Encoder | Encode prompts to embeddings | No (pre-trained) |

The Efficiency Gain: Instead of denoising 512x512x3 = 786,432 values, we denoise 64x64x4 = 16,384 values. That's 48x fewer values, and since attention scales quadratically, the speedup in attention is even more dramatic: ~2300x.

Why a Pre-trained VAE?

The VAE is trained separately and then frozen during diffusion training. This separation has several benefits:

  • Stable training: VAE reconstruction objective is simpler than joint training
  • Reusable: Same VAE can be used for different diffusion models
  • Memory efficient: VAE weights don't need gradients during diffusion training
  • Quality: VAE can be trained with perceptual and adversarial losses for sharp reconstructions

PyTorch Implementation

Here's a complete latent diffusion model implementation:

The model in code, annotated:

  • LatentDiffusion class: Combines all components: VAE for compression, U-Net for diffusion, CLIP for text encoding.
  • Components: The pre-trained VAE and text encoder are frozen; only the U-Net is trained. scale_factor normalizes latent variance.
  • Freeze components: Critical: the VAE and text encoder don't need gradients. This saves memory and preserves their learned representations.
  • Encode images: Sample from the VAE's latent distribution and scale by a constant factor (0.18215 for Stable Diffusion). This keeps latent variance reasonable for diffusion.
  • Decode latents: The inverse: unscale, then decode through the VAE decoder to get images.
  • Training step: Standard diffusion training, but operating on latents instead of pixels.
  • Encode to latent: First step: compress the image to latent space. This is where the 8x compression happens.
  • Text encoding: Get CLIP embeddings for cross-attention (prompts are tokenized first). No gradients needed.
  • Noise and timesteps: Sample Gaussian noise and random timesteps, just as in pixel-space diffusion.
  • Predict and loss: The U-Net predicts the noise given noisy latent + timestep + text embeddings; MSE loss on the noise prediction.
```python
class LatentDiffusion(nn.Module):
    """
    Complete Latent Diffusion Model.
    Combines a frozen VAE with a trainable U-Net.
    """

    def __init__(
        self,
        vae: AutoencoderKL,           # Pre-trained VAE
        unet: UNet2DConditionModel,   # Trainable U-Net
        text_encoder: CLIPTextModel,  # Frozen CLIP text encoder
        tokenizer: CLIPTokenizer,     # Converts prompts to token ids
        noise_scheduler: DDPMScheduler,
        scale_factor: float = 0.18215,  # Latent scaling
    ):
        super().__init__()
        self.vae = vae
        self.unet = unet
        self.text_encoder = text_encoder
        self.tokenizer = tokenizer
        self.scheduler = noise_scheduler
        self.scale_factor = scale_factor

        # Freeze VAE and text encoder
        self.vae.requires_grad_(False)
        self.text_encoder.requires_grad_(False)

    def encode_images(self, images: torch.Tensor) -> torch.Tensor:
        """Encode images to scaled latent space."""
        with torch.no_grad():
            latent_dist = self.vae.encode(images).latent_dist
            latents = latent_dist.sample()
            latents = latents * self.scale_factor
        return latents

    def decode_latents(self, latents: torch.Tensor) -> torch.Tensor:
        """Decode latents to images."""
        latents = latents / self.scale_factor
        with torch.no_grad():
            images = self.vae.decode(latents).sample
        return images

    def encode_prompts(self, prompts: list[str]) -> torch.Tensor:
        """Tokenize prompts and return CLIP embeddings for cross-attention."""
        tokens = self.tokenizer(
            prompts,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        with torch.no_grad():
            return self.text_encoder(tokens.input_ids.to(self.text_encoder.device))[0]

    def training_step(
        self,
        images: torch.Tensor,
        prompts: list[str],
    ) -> torch.Tensor:
        """
        Single training step for latent diffusion.

        Args:
            images: [B, 3, H, W] images in [-1, 1]
            prompts: List of text prompts

        Returns:
            Loss value
        """
        batch_size = images.shape[0]

        # 1. Encode images to latent space (the 8x compression happens here)
        latents = self.encode_images(images)

        # 2. Encode text prompts (no gradients)
        text_embeddings = self.encode_prompts(prompts)

        # 3. Sample noise and random timesteps
        noise = torch.randn_like(latents)
        timesteps = torch.randint(
            0, self.scheduler.config.num_train_timesteps,
            (batch_size,), device=latents.device
        )

        # 4. Add noise to latents
        noisy_latents = self.scheduler.add_noise(latents, noise, timesteps)

        # 5. Predict noise with the U-Net
        noise_pred = self.unet(
            noisy_latents,
            timesteps,
            encoder_hidden_states=text_embeddings,
        ).sample

        # 6. MSE loss on the noise prediction
        loss = F.mse_loss(noise_pred, noise)
        return loss
```

The Scale Factor

The magic number $0.18215$ deserves explanation:

  • VAE latents have a certain variance distribution after training
  • Diffusion works best when inputs have unit variance
  • The scale factor normalizes latent variance to approximately 1
  • It's computed as $1 / \sigma_{\text{latent}}$ from the training data
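The factor can be estimated empirically: encode a representative batch of training images and invert the standard deviation of the resulting latents. A sketch (here `encode_fn` stands in for the VAE's encode-and-sample step, an assumption of this example):

```python
import torch

def estimate_scale_factor(encode_fn, images: torch.Tensor) -> float:
    """scale_factor = 1 / std of latents over a batch of training images."""
    with torch.no_grad():
        latents = encode_fn(images)   # e.g. [B, 4, H/8, W/8]
    return 1.0 / latents.std().item()
```

Multiplying latents by this value gives them roughly unit variance, matching the unit-variance Gaussian noise the diffusion process mixes in.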

Complete Training Pipeline

```python
def train_latent_diffusion(
    model: LatentDiffusion,
    dataloader: DataLoader,
    optimizer: torch.optim.Optimizer,
    num_epochs: int,
    device: str = "cuda",
):
    """
    Complete training loop for Latent Diffusion.
    """
    model.train()
    model.unet.train()  # Only the U-Net is trainable

    for epoch in range(num_epochs):
        epoch_loss = 0.0
        for batch in dataloader:
            images = batch["images"].to(device)
            prompts = batch["prompts"]

            # Forward pass
            loss = model.training_step(images, prompts)

            # Backward pass
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()

        avg_loss = epoch_loss / len(dataloader)
        print(f"Epoch {epoch + 1}: Loss = {avg_loss:.4f}")


# Sampling (inference)
@torch.no_grad()
def sample(
    model: LatentDiffusion,
    prompts: list[str],
    height: int = 512,
    width: int = 512,
    num_inference_steps: int = 50,
    guidance_scale: float = 7.5,
    device: str = "cuda",
) -> torch.Tensor:
    """
    Generate images from text prompts using classifier-free guidance (CFG).
    """
    batch_size = len(prompts)
    latent_height = height // 8
    latent_width = width // 8

    # Get text embeddings (tokenization elided for brevity: a real
    # CLIPTextModel takes token ids, not raw strings)
    text_embeddings = model.text_encoder(prompts)
    uncond_embeddings = model.text_encoder([""] * batch_size)

    # CFG: concatenate unconditional and conditional embeddings
    text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

    # Initialize latents from pure noise
    latents = torch.randn(
        batch_size, 4, latent_height, latent_width, device=device
    )

    # Set up the scheduler for inference
    model.scheduler.set_timesteps(num_inference_steps)

    # Denoising loop
    for t in model.scheduler.timesteps:
        # CFG: double the batch so one U-Net call covers both branches
        latent_input = torch.cat([latents] * 2)

        # Predict noise (the scalar timestep t broadcasts over the batch)
        noise_pred = model.unet(
            latent_input,
            t,
            encoder_hidden_states=text_embeddings,
        ).sample

        # CFG: combine conditional and unconditional predictions
        noise_uncond, noise_cond = noise_pred.chunk(2)
        noise_pred = noise_uncond + guidance_scale * (noise_cond - noise_uncond)

        # Scheduler step
        latents = model.scheduler.step(noise_pred, t, latents).prev_sample

    # Decode latents to images
    images = model.decode_latents(latents)
    return images
```

Memory Optimization Tips

  • Gradient checkpointing: Trade compute for memory in U-Net
  • Mixed precision (fp16): Halves memory, speeds up training
  • xFormers: Memory-efficient attention implementation
  • Gradient accumulation: Simulate larger batch sizes
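Gradient accumulation, in particular, is a few lines of training-loop logic rather than a library feature. A sketch (`loss_fn` is any callable returning a scalar loss for a batch):

```python
import torch

def train_with_accumulation(loss_fn, batches, optimizer, accum_steps: int = 4):
    """Simulate a batch accum_steps times larger by deferring optimizer.step()."""
    optimizer.zero_grad()
    for i, batch in enumerate(batches):
        loss = loss_fn(batch) / accum_steps   # scale so gradients average correctly
        loss.backward()                       # gradients accumulate in .grad
        if (i + 1) % accum_steps == 0:        # step once per accumulated macro-batch
            optimizer.step()
            optimizer.zero_grad()
```

Dividing each loss by `accum_steps` makes the accumulated gradient equal the average over the macro-batch, matching what a single large batch would produce.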

KL-VAE vs VQ-VAE

Two main approaches exist for the autoencoder component:

KL-VAE (Continuous Latents)

Used by Stable Diffusion. Latents are continuous Gaussian:

$$\mathbf{z} \sim \mathcal{N}(\boldsymbol{\mu}, \boldsymbol{\sigma}^2)$$

  • Pros: Smooth latent space, natural for diffusion
  • Cons: Can be blurry without adversarial training
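The "KL" refers to the Kullback-Leibler regularizer pushing latents toward a standard normal. For a diagonal Gaussian it has a closed form (and in KL-VAE training it is weighted very lightly, so reconstruction quality dominates):

```latex
D_{\mathrm{KL}}\big(\mathcal{N}(\boldsymbol{\mu}, \boldsymbol{\sigma}^2) \,\|\, \mathcal{N}(0, I)\big)
= \frac{1}{2} \sum_{i} \left( \mu_i^2 + \sigma_i^2 - \log \sigma_i^2 - 1 \right)
```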

VQ-VAE (Discrete Latents)

Used by DALL-E, Parti. Latents are discrete codebook indices:

$$\mathbf{z}_q = \text{Quantize}(\mathbf{z}_e) = \arg\min_{\mathbf{e}_k} \|\mathbf{z}_e - \mathbf{e}_k\|$$

  • Pros: Sharp reconstructions, prevents posterior collapse
  • Cons: Discrete space less natural for continuous diffusion
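The nearest-codebook lookup is only a few lines. A sketch of the quantization step (straight-through gradient handling and codebook losses omitted):

```python
import torch

def vector_quantize(z_e: torch.Tensor, codebook: torch.Tensor) -> torch.Tensor:
    """
    z_e:      [N, D] encoder outputs
    codebook: [K, D] learned embeddings e_k
    Returns the nearest codebook vector for each input (argmin ||z_e - e_k||).
    """
    dists = torch.cdist(z_e, codebook)   # [N, K] pairwise distances
    indices = dists.argmin(dim=1)        # nearest code index per input
    return codebook[indices]
```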

Comparison

| Aspect | KL-VAE | VQ-VAE |
|---|---|---|
| Latent type | Continuous Gaussian | Discrete codebook |
| Diffusion compatibility | Natural fit | Requires adaptation |
| Reconstruction quality | Good with GAN loss | Excellent |
| Used by | Stable Diffusion, FLUX | DALL-E, Parti, MaskGIT |
| Training | Simpler | Codebook collapse issues |

Modern Trend: Recent models like Stable Diffusion 3 and FLUX use improved KL-VAEs with adversarial training, achieving excellent reconstruction quality while maintaining smooth latent spaces ideal for diffusion.

Key Takeaways

  1. Pixel-space diffusion doesn't scale: Attention has O(N^2) complexity, making high-resolution generation prohibitive
  2. VAEs provide 48x compression: 512x512x3 images become 64x64x4 latents while preserving semantic content
  3. Latent diffusion separates compression from generation: Pre-trained VAE handles reconstruction, diffusion handles generation
  4. Only the U-Net is trained: VAE and text encoder are frozen, saving memory and preserving their learned representations
  5. Scale factor normalizes latent variance: The magic 0.18215 keeps latent variance around 1 for stable diffusion training
  6. KL-VAE is preferred for diffusion: Continuous latent space is a natural fit for the continuous diffusion process

Part IV Complete: You now understand the complete text-to-image pipeline: text encoding with CLIP, cross-attention for spatial alignment, classifier-free guidance for controllable generation, and latent diffusion for efficient high-resolution synthesis. These foundations power systems like Stable Diffusion, DALL-E 3, and Midjourney.