Chapter 8

Implementing Classifier-Free Guidance

Conditional Generation

Learning Objectives

By the end of this section, you will be able to:

  1. Implement training with label dropout for CFG
  2. Write the CFG sampling loop with guidance scale
  3. Optimize inference with batched conditional/unconditional passes
  4. Tune guidance scale for different use cases
  5. Debug common issues with CFG implementations

Training with Label Dropout

The key to CFG is training a model that can operate in both conditional and unconditional modes. This is achieved by randomly dropping conditions during training.

CFG Training with Label Dropout
cfg_training.py

Condition Drop Probability (cond_drop_prob)

The probability of replacing a label with the null class. 10% is a standard choice. Dropping labels is what lets one model learn both conditional and unconditional generation.

Null Class Index (null_class)

We add one extra class index to represent "no condition". Dropped labels are replaced with this index, so the model's class embedding needs num_classes + 1 entries.

Label Dropout (drop_mask)

Randomly replace some labels with the null class. This is the key idea behind CFG training: a single model learns both the conditional and the unconditional noise prediction.

Conditional Prediction (noise_pred)

The model receives the possibly-dropped labels. When it sees the null class, it learns to predict noise unconditionally.
import torch
import torch.nn as nn
import torch.nn.functional as F

class CFGTrainer:
    """Trainer for Classifier-Free Guidance with label dropout."""

    def __init__(
        self,
        model: nn.Module,
        num_classes: int,
        cond_drop_prob: float = 0.1,
        num_timesteps: int = 1000,
        device: str = "cuda",
    ):
        self.model = model
        self.num_classes = num_classes
        self.cond_drop_prob = cond_drop_prob
        self.null_class = num_classes  # Extra class index for "no condition"
        self.device = device

        # Precompute the linear noise schedule and its cumulative product
        betas = torch.linspace(1e-4, 0.02, num_timesteps)
        alphas = 1 - betas
        self.alphas_cumprod = torch.cumprod(alphas, dim=0).to(device)

    def train_step(
        self,
        images: torch.Tensor,
        labels: torch.Tensor,
    ) -> torch.Tensor:
        # Keep inputs on the same device as the noise schedule
        images = images.to(self.device)
        labels = labels.to(self.device)
        batch_size = images.shape[0]

        # Apply label dropout for CFG training
        drop_mask = torch.rand(batch_size, device=self.device) < self.cond_drop_prob
        null_labels = torch.full_like(labels, self.null_class)
        labels = torch.where(drop_mask, null_labels, labels)

        # Sample timesteps uniformly
        t = torch.randint(0, len(self.alphas_cumprod), (batch_size,), device=self.device)

        # Add noise to images (forward diffusion at timestep t)
        noise = torch.randn_like(images)
        alpha_t = self.alphas_cumprod[t][:, None, None, None]
        x_t = torch.sqrt(alpha_t) * images + torch.sqrt(1 - alpha_t) * noise

        # Predict noise conditioned on (possibly dropped) labels
        noise_pred = self.model(x_t, t, labels)

        # Simple MSE loss between predicted and true noise
        loss = F.mse_loss(noise_pred, noise)
        return loss

Training Tips

  • Drop probability: 10-20% works well. Too high reduces conditional quality; too low weakens the unconditional prediction that guidance relies on.
  • Null embedding: Can be learned (extra class) or fixed (zeros). Learned often works better
  • Loss weighting: Standard MSE loss works; no special weighting needed
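
The learned-versus-fixed null embedding choice from the tips above can be sketched as follows. This is a minimal illustration under assumptions, not the model used in this section: the class name `ClassEmbedding` and the `learned_null` flag are hypothetical.

```python
import torch
import torch.nn as nn


class ClassEmbedding(nn.Module):
    """Hypothetical label embedding with a null slot at index num_classes."""

    def __init__(self, num_classes: int, dim: int, learned_null: bool = True):
        super().__init__()
        self.null_class = num_classes
        self.learned_null = learned_null
        # One extra row so the null index is a valid lookup.
        self.table = nn.Embedding(num_classes + 1, dim)

    def forward(self, labels: torch.Tensor) -> torch.Tensor:
        emb = self.table(labels)
        if not self.learned_null:
            # Fixed variant: force the null embedding to zeros.
            is_null = (labels == self.null_class).unsqueeze(-1)
            emb = torch.where(is_null, torch.zeros_like(emb), emb)
        return emb


embed = ClassEmbedding(num_classes=10, dim=16, learned_null=False)
labels = torch.tensor([3, 10])  # a real class and the null class
out = embed(labels)             # row 0 is a learned vector, row 1 is all zeros
```

With `learned_null=True` the null row is trained like any other class, which is the "extra class" option described above.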

CFG Sampling Implementation

During sampling, we compute both conditional and unconditional predictions at each step, then combine them using the CFG formula.

Classifier-Free Guidance Sampling
cfg_sampling.py

Guidance Scale (guidance_scale)

The strength of conditioning, written w below. w = 7.5 is a common default for Stable Diffusion. Higher values give stronger conditioning but may reduce quality and diversity.

Prepare Labels (cond_label, uncond_label)

We need both the actual class label and the null (unconditional) label for CFG.

Dual Forward Pass (eps_cond, eps_uncond)

Get predictions with and without conditioning. This is the computational cost of CFG: two forward passes per step.

CFG Formula (eps_guided)

The core equation: start at the unconditional prediction and move toward the conditional one, scaled by guidance_scale. When w > 1, we extrapolate beyond the conditional prediction.

Predict Clean Image (x_0_pred)

Use the guided noise prediction to estimate x_0. Clamping to [-1, 1] helps stability.
import torch
import torch.nn as nn

@torch.no_grad()
def cfg_sample(
    model: nn.Module,
    class_label: int,
    guidance_scale: float = 7.5,
    num_steps: int = 1000,
    img_shape: tuple = (1, 3, 64, 64),
    null_class: int = 1000,  # For ImageNet
    device: str = "cuda",
) -> torch.Tensor:
    """Generate image with Classifier-Free Guidance."""

    # Start from pure noise
    x = torch.randn(img_shape, device=device)

    # Prepare class labels
    cond_label = torch.tensor([class_label], device=device)
    uncond_label = torch.tensor([null_class], device=device)

    # Noise schedule (must match the one used in training)
    betas = torch.linspace(1e-4, 0.02, num_steps, device=device)
    alphas = 1 - betas
    alphas_cumprod = torch.cumprod(alphas, dim=0)

    for i in reversed(range(num_steps)):
        t = torch.tensor([i], device=device)

        # Get conditional and unconditional predictions
        eps_cond = model(x, t, cond_label)
        eps_uncond = model(x, t, uncond_label)

        # Apply CFG formula
        eps_guided = eps_uncond + guidance_scale * (eps_cond - eps_uncond)

        # DDPM sampling step (alpha_t and alpha_prev are cumulative products)
        alpha_t = alphas_cumprod[i]
        alpha_prev = alphas_cumprod[i - 1] if i > 0 else torch.ones((), device=device)

        # Predict x_0
        x_0_pred = (x - torch.sqrt(1 - alpha_t) * eps_guided) / torch.sqrt(alpha_t)
        x_0_pred = torch.clamp(x_0_pred, -1, 1)

        # Compute posterior mean
        coef1 = betas[i] * torch.sqrt(alpha_prev) / (1 - alpha_t)
        coef2 = (1 - alpha_prev) * torch.sqrt(alphas[i]) / (1 - alpha_t)
        mean = coef1 * x_0_pred + coef2 * x

        # Add noise (except at final step)
        if i > 0:
            noise = torch.randn_like(x)
            var = betas[i] * (1 - alpha_prev) / (1 - alpha_t)
            x = mean + torch.sqrt(var) * noise
        else:
            x = mean

    return x
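
For reference, the quantities computed inside the sampling loop correspond to the standard DDPM posterior. The notation below is an assumption made for this note: $\bar{\alpha}_t$ is `alphas_cumprod[i]` (named `alpha_t` in the code), $\alpha_t$ is `alphas[i]`, $\beta_t$ is `betas[i]`, and $\tilde{\boldsymbol{\epsilon}}$ is `eps_guided`.

```latex
\hat{\mathbf{x}}_0 = \frac{\mathbf{x}_t - \sqrt{1 - \bar{\alpha}_t}\,\tilde{\boldsymbol{\epsilon}}}{\sqrt{\bar{\alpha}_t}}

\boldsymbol{\mu}_t = \frac{\sqrt{\bar{\alpha}_{t-1}}\,\beta_t}{1 - \bar{\alpha}_t}\,\hat{\mathbf{x}}_0
                   + \frac{\sqrt{\alpha_t}\,(1 - \bar{\alpha}_{t-1})}{1 - \bar{\alpha}_t}\,\mathbf{x}_t

\sigma_t^2 = \frac{1 - \bar{\alpha}_{t-1}}{1 - \bar{\alpha}_t}\,\beta_t
```

The two coefficients of $\boldsymbol{\mu}_t$ are `coef1` and `coef2`, and $\sigma_t^2$ is `var`. At the final step, $\bar{\alpha}_{t-1}$ is taken to be 1, so the mean reduces to $\hat{\mathbf{x}}_0$ and no noise is added.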

CFG with DDIM

CFG works with any sampler (DDPM, DDIM, DPM++, etc.). The guided noise prediction replaces the standard prediction:

$$\tilde{\boldsymbol{\epsilon}} = \boldsymbol{\epsilon}_{\text{uncond}} + w \, (\boldsymbol{\epsilon}_{\text{cond}} - \boldsymbol{\epsilon}_{\text{uncond}})$$

Then use this $\tilde{\boldsymbol{\epsilon}}$ in your chosen sampling algorithm.
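
A small numeric check can make the role of w concrete. The sketch below uses a hypothetical helper named `apply_cfg` and made-up two-element predictions; it only illustrates the formula and is not tied to any particular sampler.

```python
import torch


def apply_cfg(eps_cond: torch.Tensor, eps_uncond: torch.Tensor, w: float) -> torch.Tensor:
    """Combine the two noise predictions with the CFG formula."""
    return eps_uncond + w * (eps_cond - eps_uncond)


# Made-up predictions, chosen so the arithmetic is easy to follow.
eps_cond = torch.tensor([1.0, 2.0])
eps_uncond = torch.tensor([0.0, 1.0])

uncond_only = apply_cfg(eps_cond, eps_uncond, 0.0)   # w = 0: purely unconditional
cond_only = apply_cfg(eps_cond, eps_uncond, 1.0)     # w = 1: purely conditional
extrapolated = apply_cfg(eps_cond, eps_uncond, 3.0)  # w > 1: pushed past conditional
```

Whatever sampler is used, only the noise prediction changes; the update rule that consumes it stays the same.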


Batched Inference Optimization

The naive CFG implementation requires two forward passes per step. We can optimize this by batching the conditional and unconditional passes together.

Batched CFG for Efficiency
cfg_batched.py

Doubled Batch Strategy (doubled_labels)

Instead of two separate forward passes, we double the batch size: half conditional, half unconditional. This usually makes better use of the GPU.

Concatenate Inputs (x_doubled, t_doubled)

Stack the same noisy images twice with different labels. The model processes both in parallel.

Single Forward Pass (eps_doubled)

One forward pass processes 2*batch_size samples, getting both conditional and unconditional predictions simultaneously.

Split Results (eps_cond, eps_uncond)

After the forward pass, split the output back into conditional and unconditional predictions.
import torch
import torch.nn as nn

@torch.no_grad()
def cfg_sample_batched(
    model: nn.Module,
    class_labels: torch.Tensor,  # [batch_size]
    guidance_scale: float = 7.5,
    num_steps: int = 50,
    null_class: int = 1000,
    num_train_timesteps: int = 1000,  # Assumed to match the training schedule
    device: str = "cuda",
) -> torch.Tensor:
    """Batched CFG sampling: one forward pass per step on a doubled batch."""

    class_labels = class_labels.to(device)
    batch_size = class_labels.shape[0]
    img_shape = (batch_size, 3, 64, 64)

    # Start from noise
    x = torch.randn(img_shape, device=device)

    # Create doubled batch: [cond_labels, uncond_labels]
    null_labels = torch.full_like(class_labels, null_class)
    doubled_labels = torch.cat([class_labels, null_labels], dim=0)

    # Noise schedule (same linear schedule as in training)
    betas = torch.linspace(1e-4, 0.02, num_train_timesteps, device=device)
    alphas_cumprod = torch.cumprod(1 - betas, dim=0)

    # Evenly spaced subset of training timesteps, noisiest first
    timesteps = torch.linspace(num_train_timesteps - 1, 0, num_steps, device=device).long()

    for step, t in enumerate(timesteps):
        t_batch = t.expand(batch_size)

        # Double the batch: [x, x] for [cond, uncond]
        x_doubled = torch.cat([x, x], dim=0)
        t_doubled = torch.cat([t_batch, t_batch], dim=0)

        # Single forward pass for both predictions
        eps_doubled = model(x_doubled, t_doubled, doubled_labels)

        # Split predictions
        eps_cond, eps_uncond = eps_doubled.chunk(2, dim=0)

        # Apply CFG
        eps_guided = eps_uncond + guidance_scale * (eps_cond - eps_uncond)

        # Simplified deterministic DDIM step (for demonstration)
        alpha_t = alphas_cumprod[t]
        if step + 1 < num_steps:
            # Cumulative alpha at the next (less noisy) timestep in the subset
            alpha_prev = alphas_cumprod[timesteps[step + 1]]
        else:
            alpha_prev = torch.ones((), device=device)

        x_0_pred = (x - torch.sqrt(1 - alpha_t) * eps_guided) / torch.sqrt(alpha_t)
        x = torch.sqrt(alpha_prev) * x_0_pred + torch.sqrt(1 - alpha_prev) * eps_guided

    return x

Approach | Forward Passes | Memory | Latency
Naive (sequential) | 2 per step | 1x batch | 2x inference time
Batched (parallel) | 1 per step | 2x batch | ~1x inference time

Production Tip: Most production systems use the batched approach. The memory overhead is usually acceptable, and the speedup can approach 2x when the GPU is not already saturated at the original batch size.
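
One way to gain confidence in a batched implementation is to check that it reproduces the sequential predictions. The sketch below does this with a hypothetical toy model, `ToyEps`, which only mimics the `model(x, t, labels)` signature used in this section; the shapes and class count are arbitrary.

```python
import torch
import torch.nn as nn


class ToyEps(nn.Module):
    """Hypothetical stand-in for a noise predictor: model(x, t, labels)."""

    def __init__(self, num_classes: int = 10):
        super().__init__()
        self.embed = nn.Embedding(num_classes + 1, 3)  # +1 row for the null class
        self.conv = nn.Conv2d(3, 3, kernel_size=3, padding=1)

    def forward(self, x, t, labels):
        bias = self.embed(labels)[:, :, None, None]
        scale = (t.float() / 1000.0)[:, None, None, None]
        return self.conv(x) + bias + scale


torch.manual_seed(0)
model = ToyEps().eval()
x = torch.randn(4, 3, 8, 8)
t = torch.full((4,), 500)
labels = torch.tensor([0, 1, 2, 3])
null = torch.full_like(labels, 10)  # null class index = num_classes

with torch.no_grad():
    # Sequential: two forward passes
    eps_cond_seq = model(x, t, labels)
    eps_uncond_seq = model(x, t, null)
    # Batched: one pass on a doubled batch, then split
    eps_cond_bat, eps_uncond_bat = model(
        torch.cat([x, x]), torch.cat([t, t]), torch.cat([labels, null])
    ).chunk(2, dim=0)

# Largest elementwise disagreement between the two strategies
max_diff = max(
    (eps_cond_seq - eps_cond_bat).abs().max().item(),
    (eps_uncond_seq - eps_uncond_bat).abs().max().item(),
)
```

The two strategies should agree up to floating-point noise as long as the model treats batch elements independently. Layers that mix information across the batch, such as BatchNorm in training mode, would break this equivalence.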

Tuning the Guidance Scale

Choosing the right guidance scale is crucial for getting good results. Here are practical guidelines:

Scale by Application

Application | Recommended w | Reason
Class-conditional ImageNet | 2-4 | Limited condition complexity
Text-to-image (general) | 7-8 | Balance quality and diversity
Text-to-image (specific) | 10-12 | Strong prompt adherence
Inpainting | 5-7 | Must blend with context
Image editing | 3-5 | Preserve original content

Scale by Prompt Type

  • Simple prompts ("a cat"): Lower scale (5-7) to avoid oversaturation
  • Detailed prompts (specific styles, attributes): Higher scale (8-12) to capture details
  • Negative prompts: Same scale works; the negative prompt takes the place of the null condition in the unconditional branch, so guidance pushes away from it

Dynamic Guidance

Some advanced techniques vary the guidance scale during sampling:

  • Linear decay: Start high, decrease toward the end
  • Cosine schedule: Smooth transition from high to low
  • Per-resolution: Higher at low resolution, lower at high
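
The first two schedules above can be sketched as plain functions of the sampling step. The function names and the start and end values (10.0 and 3.0) are hypothetical choices for illustration, and step 0 is assumed to be the first, noisiest sampling step.

```python
import math


def linear_decay(step: int, num_steps: int, w_start: float = 10.0, w_end: float = 3.0) -> float:
    """Guidance scale that falls linearly from w_start to w_end."""
    frac = step / max(num_steps - 1, 1)
    return w_start + frac * (w_end - w_start)


def cosine_decay(step: int, num_steps: int, w_start: float = 10.0, w_end: float = 3.0) -> float:
    """Guidance scale that follows a half cosine from w_start to w_end."""
    frac = step / max(num_steps - 1, 1)
    return w_end + 0.5 * (w_start - w_end) * (1 + math.cos(math.pi * frac))


num_steps = 5
linear_schedule = [linear_decay(s, num_steps) for s in range(num_steps)]
cosine_schedule = [cosine_decay(s, num_steps) for s in range(num_steps)]
```

Inside a sampling loop, the constant `guidance_scale` would be replaced by the schedule value for the current step.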

Complete Working Example

Here's a summary of all the pieces working together:

  1. Model architecture: U-Net with class embedding (num_classes + 1 for null)
  2. Training: Standard diffusion loss with 10% label dropout
  3. Sampling: Batched CFG with guidance scale 7.5
  4. Output: High-quality conditional samples
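
The wiring of these pieces can be sketched end to end on random stand-in data. Everything below is an assumption made for illustration: `TinyEpsModel` is a hypothetical two-layer network, not a real U-Net, and a single training step on random tensors will not produce meaningful samples. The sketch only shows how the class embedding, label dropout, and batched guidance fit together.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

NUM_CLASSES, NULL_CLASS, T = 10, 10, 1000


class TinyEpsModel(nn.Module):
    """Hypothetical stand-in for a class-conditional U-Net."""

    def __init__(self):
        super().__init__()
        self.label_emb = nn.Embedding(NUM_CLASSES + 1, 8)  # +1 row for null
        self.inp = nn.Conv2d(3, 8, 3, padding=1)
        self.out = nn.Conv2d(8, 3, 3, padding=1)

    def forward(self, x, t, labels):
        h = self.inp(x)
        h = h + self.label_emb(labels)[:, :, None, None]
        h = h + (t.float() / T)[:, None, None, None]  # crude time conditioning
        return self.out(F.silu(h))


torch.manual_seed(0)
model = TinyEpsModel()
opt = torch.optim.Adam(model.parameters(), lr=1e-3)
alphas_cumprod = torch.cumprod(1 - torch.linspace(1e-4, 0.02, T), dim=0)

# Steps 1-2: one training step with 10% label dropout
images = torch.rand(16, 3, 8, 8) * 2 - 1  # random data scaled to [-1, 1]
labels = torch.randint(0, NUM_CLASSES, (16,))
drop = torch.rand(16) < 0.1
labels = torch.where(drop, torch.full_like(labels, NULL_CLASS), labels)
t = torch.randint(0, T, (16,))
noise = torch.randn_like(images)
a = alphas_cumprod[t][:, None, None, None]
x_t = a.sqrt() * images + (1 - a).sqrt() * noise
loss = F.mse_loss(model(x_t, t, labels), noise)
opt.zero_grad()
loss.backward()
opt.step()

# Step 3: one batched CFG evaluation at guidance scale 7.5
with torch.no_grad():
    x = torch.randn(2, 3, 8, 8)
    cls = torch.tensor([1, 7])
    tt = torch.full((2,), T - 1)
    eps_c, eps_u = model(
        torch.cat([x, x]),
        torch.cat([tt, tt]),
        torch.cat([cls, torch.full_like(cls, NULL_CLASS)]),
    ).chunk(2)
    eps_guided = eps_u + 7.5 * (eps_c - eps_u)
```

In a real pipeline the training step would run over a dataset for many iterations, and `eps_guided` would feed a full sampling loop such as the ones shown earlier in this section.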

Common Issues and Fixes

  • Samples look unconditional: Check that the null class index differs from every real class index and matches the one used during training
  • Oversaturated colors: Reduce guidance scale
  • Low diversity: Reduce guidance scale; high guidance trades diversity for condition adherence
  • Artifacts at edges: Try lower guidance or a different sampler
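
For the first issue, a quick diagnostic is to measure how much the model's output changes when the label is swapped for the null class. The helper `cond_uncond_gap` and both toy models below are hypothetical and exist only to illustrate the check.

```python
import torch
import torch.nn as nn


@torch.no_grad()
def cond_uncond_gap(model, x, t, labels, null_class: int) -> float:
    """Mean absolute difference between conditional and unconditional predictions."""
    null = torch.full_like(labels, null_class)
    return (model(x, t, labels) - model(x, t, null)).abs().mean().item()


class LabelBlindModel(nn.Module):
    """Hypothetical broken model that ignores its labels."""

    def forward(self, x, t, labels):
        return 0.5 * x


class LabelAwareModel(nn.Module):
    """Hypothetical model whose output depends on the label."""

    def __init__(self):
        super().__init__()
        self.emb = nn.Embedding(11, 3)  # 10 classes + 1 null

    def forward(self, x, t, labels):
        return 0.5 * x + self.emb(labels)[:, :, None, None]


torch.manual_seed(0)
x = torch.randn(4, 3, 8, 8)
t = torch.full((4,), 500)
labels = torch.tensor([0, 1, 2, 3])

gap_blind = cond_uncond_gap(LabelBlindModel(), x, t, labels, null_class=10)
gap_aware = cond_uncond_gap(LabelAwareModel(), x, t, labels, null_class=10)
```

A gap near zero on a trained model suggests that the labels are not reaching the network or that the null index collides with a real class, in which case guidance has nothing to amplify.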

Key Takeaways

  1. Training requires label dropout: Randomly replace conditions with null during training
  2. Sampling combines two predictions: Conditional and unconditional, weighted by guidance scale
  3. Batch for efficiency: Double the batch to avoid two separate forward passes
  4. Tune guidance scale: 7-8 for general use, adjust based on application and prompt
  5. Works with any sampler: DDPM, DDIM, DPM++, etc.

Looking Ahead: In the next chapter, we'll explore how to extend these conditioning techniques to text-to-image, using text encoders like CLIP and cross-attention mechanisms.