Chapter 8

Classifier Guidance

Conditional Generation

Learning Objectives

By the end of this section, you will be able to:

  1. Derive classifier guidance from Bayes' rule and score function decomposition
  2. Explain how the classifier gradient steers the diffusion sampling process
  3. Train a noise-aware classifier for guidance
  4. Implement classifier-guided sampling
  5. Understand the limitations that motivated classifier-free guidance

The Bayesian Perspective

Classifier guidance takes a different approach than training a conditional model directly. Instead, it uses Bayes' rule to combine an unconditional diffusion model with a separate classifier:

p(\mathbf{x}|c) = \frac{p(c|\mathbf{x}) \cdot p(\mathbf{x})}{p(c)}

Key Insight: We can sample from the conditional distribution p(\mathbf{x}|c) by combining an unconditional model p(\mathbf{x}) with a classifier p(c|\mathbf{x}). No need to retrain the diffusion model!
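The factorization can be checked numerically on a toy discrete distribution (all numbers below are illustrative, not from any trained model):

```python
import numpy as np

# Toy discrete example: 4 possible "images" x and a binary class c.
p_x = np.array([0.1, 0.4, 0.3, 0.2])          # unconditional model p(x)
p_c_given_x = np.array([0.9, 0.2, 0.6, 0.1])  # classifier p(c=0 | x)

# Bayes' rule: p(x | c) = p(c | x) p(x) / p(c)
p_c = np.sum(p_c_given_x * p_x)               # marginal p(c=0)
p_x_given_c = p_c_given_x * p_x / p_c

# The result is a valid distribution over x, built from p(x) and p(c|x) alone.
assert np.isclose(p_x_given_c.sum(), 1.0)
```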

Why This Matters

  • Modularity: Train the diffusion model once, then guide with any classifier
  • Flexibility: Different classifiers for different tasks without retraining
  • Interpretability: The classifier provides explicit control signal

Score Function Decomposition

To implement guidance, we need to work with score functions. The score is the gradient of the log probability:

\nabla_{\mathbf{x}} \log p(\mathbf{x})

Deriving the Conditional Score

Taking the gradient of the log of both sides of Bayes' rule:

\nabla_{\mathbf{x}} \log p(\mathbf{x}|c) = \nabla_{\mathbf{x}} \log p(c|\mathbf{x}) + \nabla_{\mathbf{x}} \log p(\mathbf{x})

Note that \nabla_{\mathbf{x}} \log p(c) = 0 because p(c) doesn't depend on \mathbf{x}.

Score Decomposition

The conditional score decomposes beautifully:

\underbrace{\nabla_{\mathbf{x}} \log p(\mathbf{x}|c)}_{\text{conditional score}} = \underbrace{\nabla_{\mathbf{x}} \log p(c|\mathbf{x})}_{\text{classifier gradient}} + \underbrace{\nabla_{\mathbf{x}} \log p(\mathbf{x})}_{\text{unconditional score}}

This tells us: to sample conditionally, follow the unconditional score plus the classifier gradient!
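The decomposition can be verified with autograd on a toy 1-D model; the Gaussian prior and logistic classifier below are illustrative choices, not from the chapter:

```python
import torch

# Toy 1-D setup:  p(x) = N(0, 1),  p(c=1 | x) = sigmoid(a * x)
a = 2.0
x = torch.tensor(0.5, requires_grad=True)

log_p_x = -0.5 * x ** 2                                  # log N(0, 1) up to a constant
log_p_c_given_x = torch.nn.functional.logsigmoid(a * x)  # log p(c=1 | x)

# log p(x|c) = log p(c|x) + log p(x) + const, so the scores must add.
(cond_score,) = torch.autograd.grad(log_p_c_given_x + log_p_x, x)

uncond_score = -x.detach()                                          # d/dx log N(0, 1)
classifier_grad = a * torch.sigmoid(torch.tensor(-a) * x.detach())  # d/dx log sigmoid(a x)

assert torch.isclose(cond_score, uncond_score + classifier_grad)
```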

Connection to Noise Prediction

Recall that the noise prediction \boldsymbol{\epsilon}_\theta is related to the score:

\boldsymbol{\epsilon}_\theta(\mathbf{x}_t, t) \approx -\sqrt{1 - \bar{\alpha}_t} \nabla_{\mathbf{x}_t} \log p(\mathbf{x}_t)

Therefore, the guided noise prediction becomes:

\tilde{\boldsymbol{\epsilon}} = \boldsymbol{\epsilon}_\theta - \sqrt{1 - \bar{\alpha}_t} \nabla_{\mathbf{x}_t} \log p(c|\mathbf{x}_t)


The Guidance Scale

In practice, we introduce a guidance scale w to control the strength of conditioning:

\tilde{\boldsymbol{\epsilon}} = \boldsymbol{\epsilon}_\theta - w \cdot \sqrt{1 - \bar{\alpha}_t} \cdot \nabla_{\mathbf{x}_t} \log p(c|\mathbf{x}_t)
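As a sketch, the scaled update is a one-liner (the helper name guided_epsilon and the test values are illustrative):

```python
import torch

def guided_epsilon(eps_uncond, classifier_grad, alpha_bar_t, w):
    """eps_tilde = eps - w * sqrt(1 - alpha_bar_t) * grad_x log p(c|x_t)."""
    return eps_uncond - w * torch.sqrt(1 - alpha_bar_t) * classifier_grad

eps = torch.zeros(2, 3)
grad = torch.ones(2, 3)
alpha_bar = torch.tensor(0.75)  # sqrt(1 - 0.75) = 0.5

# w = 0 recovers the unconditional prediction; w = 1 is standard Bayes.
assert torch.equal(guided_epsilon(eps, grad, alpha_bar, 0.0), eps)
assert torch.allclose(guided_epsilon(eps, grad, alpha_bar, 1.0), -0.5 * torch.ones(2, 3))
```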

Interpreting the Scale

| Scale w | Effect             | Interpretation                                         |
|---------|--------------------|--------------------------------------------------------|
| w = 0   | Pure unconditional | Classifier has no influence                            |
| w = 1   | Standard Bayes     | Exact conditional sampling                             |
| w > 1   | Amplified guidance | Push harder toward condition (may reduce diversity)    |
| w < 0   | Negative guidance  | Push away from condition (useful for avoiding classes) |

Quality-Diversity Trade-off: Higher guidance scales produce samples that more strongly match the condition, but reduce diversity. Very high scales can cause artifacts as the model is pushed outside its learned distribution.

Training a Noisy Classifier

A critical requirement: the classifier must work on noisy images at all noise levels, not just clean images. A standard ImageNet classifier won't work because it was trained only on clean data.

Why Noisy Classification is Hard

  • At high noise levels (large t), the image is nearly pure noise - classification seems impossible
  • The classifier must learn to extract whatever signal remains at each noise level
  • Gradients from the classifier must be meaningful at all timesteps

Noise-Aware Classifier

noisy_classifier.py

Noisy Classifier

Unlike standard classifiers that see clean images, this classifier is trained on noisy images at all noise levels. It must learn to extract class information even when the image is heavily corrupted.

Time Embedding

The classifier needs to know the noise level (timestep) to interpret the noisy image correctly. An image at t=900 is almost pure noise, while t=100 is nearly clean.

Time-Conditioned Head

We concatenate image features with the time embedding before classification. This allows the classifier to adapt its decision boundary based on noise level.

Standard Classifier Output

The output is standard logits over classes. We can compute gradients of log p(c|x_t) with respect to x_t for guidance.

import math

import torch
import torch.nn as nn

# Assumed helper: standard sinusoidal timestep embedding, matching the one
# used by the diffusion model (not shown in the original listing).
def sinusoidal_embedding(t: torch.Tensor, dim: int = 256) -> torch.Tensor:
    half = dim // 2
    freqs = torch.exp(-math.log(10000.0) * torch.arange(half, device=t.device) / half)
    args = t.float()[:, None] * freqs[None, :]
    return torch.cat([torch.sin(args), torch.cos(args)], dim=-1)

class NoisyClassifier(nn.Module):
    """
    A classifier trained on noisy images at all timesteps.
    Must handle the same noise levels as the diffusion model.
    """
    def __init__(self, num_classes: int, img_channels: int = 3):
        super().__init__()
        # Time embedding (same as diffusion model)
        self.time_embed = nn.Sequential(
            nn.Linear(256, 512),
            nn.SiLU(),
            nn.Linear(512, 512),
        )

        # Simple CNN backbone (could be ResNet, etc.)
        self.backbone = nn.Sequential(
            nn.Conv2d(img_channels, 64, 3, padding=1),
            nn.GroupNorm(8, 64),
            nn.SiLU(),
            nn.Conv2d(64, 128, 3, stride=2, padding=1),
            nn.GroupNorm(8, 128),
            nn.SiLU(),
            # ... more layers ...
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
        )

        # Classification head with time conditioning
        self.head = nn.Sequential(
            nn.Linear(128 + 512, 256),  # Features + time embedding
            nn.SiLU(),
            nn.Linear(256, num_classes),
        )

    def forward(self, x_t: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        # x_t: noisy image [B, C, H, W]
        # t: timestep [B]

        # Embed timestep
        t_emb = self.time_embed(sinusoidal_embedding(t))

        # Extract features from noisy image
        features = self.backbone(x_t)  # [B, 128]

        # Concatenate with time embedding
        combined = torch.cat([features, t_emb], dim=1)  # [B, 128 + 512]

        # Classify
        logits = self.head(combined)  # [B, num_classes]
        return logits

Training the Noisy Classifier

Training is similar to the diffusion model:

  1. Sample clean images and labels from dataset
  2. Sample random timesteps
  3. Add noise according to the forward process
  4. Train classifier to predict class from noisy image + timestep
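The four steps above can be sketched as a single training step (the function name train_step and its signature are hypothetical; any optimizer and any noise-aware classifier with a (x_t, t) interface work):

```python
import torch
import torch.nn.functional as F

def train_step(classifier, optimizer, x0, labels, alphas_cumprod, num_steps=1000):
    """One gradient step for a noise-aware classifier (sketch).

    `classifier` takes (x_t, t), e.g. the NoisyClassifier above;
    `alphas_cumprod` is the diffusion model's cumulative alpha schedule.
    """
    b = x0.shape[0]
    # 2. Sample random timesteps
    t = torch.randint(0, num_steps, (b,), device=x0.device)
    # 3. Add noise according to the forward process q(x_t | x_0)
    noise = torch.randn_like(x0)
    ab = alphas_cumprod[t].view(b, 1, 1, 1)
    x_t = torch.sqrt(ab) * x0 + torch.sqrt(1 - ab) * noise
    # 4. Predict the class from (noisy image, timestep) and take a step
    logits = classifier(x_t, t)
    loss = F.cross_entropy(logits, labels)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
```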

Classifier Architecture

Any classification architecture works (ResNet, ViT, etc.), but it must:
  • Accept timestep as an additional input
  • Be trained on noisy images at all noise levels
  • Have gradients that are well-behaved for guidance

Classifier-Guided Sampling

Here's the complete classifier-guided sampling algorithm:

Classifier-Guided DDPM Sampling
classifier_guided_sampling.py
Target Class

The class we want to generate. The classifier gradient will push samples toward images that this classifier believes are this class.

Guidance Scale

Controls how strongly the classifier influences generation. w=0 is unconditional, w=1 is standard guidance, w>1 amplifies the effect.

Unconditional Score

First, get the standard noise prediction from the diffusion model. This represents the unconditional score.

Classifier Gradient

Enable gradients to compute d/dx log p(c|x_t). This gradient points toward regions of x-space that the classifier associates with class c.

Log Probability

We need log p(c|x_t), not just the logits. Log softmax gives us the log probabilities.

Guided Noise Prediction

The key equation: eps_guided = eps_uncond - w * sqrt(1 - alpha_bar) * grad. The negative sign comes from the relationship between epsilon and score.

import torch
import torch.nn as nn
import torch.nn.functional as F

# Assumed helper: the original calls get_beta_schedule without defining it.
# A linear schedule is one common choice; any DDPM schedule works.
def get_beta_schedule(num_steps: int) -> torch.Tensor:
    return torch.linspace(1e-4, 0.02, num_steps)

@torch.no_grad()
def classifier_guided_sample(
    diffusion_model: nn.Module,
    classifier: nn.Module,
    target_class: int,
    num_steps: int = 1000,
    guidance_scale: float = 1.0,
    img_shape: tuple = (1, 3, 64, 64),
    device: str = "cuda",
) -> torch.Tensor:
    """
    Sample with classifier guidance.
    """
    # Start from pure noise
    x = torch.randn(img_shape, device=device)

    # Noise schedule
    betas = get_beta_schedule(num_steps).to(device)
    alphas = 1 - betas
    alphas_cumprod = torch.cumprod(alphas, dim=0)

    for i in reversed(range(num_steps)):
        t = torch.tensor([i], device=device)

        # Get unconditional noise prediction from the diffusion model
        eps_uncond = diffusion_model(x, t)

        # Get classifier gradient (re-enable gradients inside no_grad)
        with torch.enable_grad():
            x_in = x.detach().requires_grad_(True)
            logits = classifier(x_in, t)
            log_probs = F.log_softmax(logits, dim=-1)
            selected = log_probs[:, target_class].sum()
            grad = torch.autograd.grad(selected, x_in)[0]

        # Apply classifier guidance to the noise prediction
        # score = -eps / sqrt(1 - alpha_bar)
        # guided_score = score + w * grad_x log p(c|x)
        alpha_bar_t = alphas_cumprod[i]
        eps_guided = eps_uncond - guidance_scale * torch.sqrt(1 - alpha_bar_t) * grad

        # Standard DDPM update with guided noise prediction
        if i > 0:
            noise = torch.randn_like(x)
            alpha_t = alphas[i]
            alpha_bar_prev = alphas_cumprod[i - 1]
            # Posterior mean coefficients of q(x_{t-1} | x_t, x_0)
            coef1 = betas[i] * torch.sqrt(alpha_bar_prev) / (1 - alpha_bar_t)
            coef2 = (1 - alpha_bar_prev) * torch.sqrt(alpha_t) / (1 - alpha_bar_t)
            x_0_pred = (x - torch.sqrt(1 - alpha_bar_t) * eps_guided) / torch.sqrt(alpha_bar_t)
            mean = coef1 * x_0_pred + coef2 * x
            var = betas[i] * (1 - alpha_bar_prev) / (1 - alpha_bar_t)
            x = mean + torch.sqrt(var) * noise
        else:
            x_0_pred = (x - torch.sqrt(1 - alpha_bar_t) * eps_guided) / torch.sqrt(alpha_bar_t)
            x = x_0_pred

    return x

Algorithm Summary

  1. Start with noise \mathbf{x}_T \sim \mathcal{N}(0, I)
  2. For each timestep t = T, ..., 1:
    • Get the unconditional noise prediction \boldsymbol{\epsilon}_\theta(\mathbf{x}_t, t)
    • Compute the classifier gradient \nabla_{\mathbf{x}_t} \log p(c|\mathbf{x}_t)
    • Combine: \tilde{\boldsymbol{\epsilon}} = \boldsymbol{\epsilon}_\theta - w \sqrt{1-\bar{\alpha}_t} \nabla_{\mathbf{x}_t} \log p(c|\mathbf{x}_t)
    • Take a DDPM step using the guided noise prediction
  3. Return \mathbf{x}_0

Limitations and Trade-offs

While elegant, classifier guidance has significant practical limitations:

1. Requires Separate Classifier

  • Must train and maintain an additional model
  • Classifier must be trained on noisy images at all timesteps
  • Storage and compute costs are doubled

2. Gradient Quality Issues

  • Classifier gradients can be noisy, especially at high noise levels
  • Gradients may not align well with image quality
  • Adversarial-like artifacts can appear

3. Limited Conditioning Types

  • Works well for classification, but text conditioning is harder
  • Need a separate classifier for each conditioning type
  • No elegant way to do text-to-image with this approach

| Aspect            | Classifier Guidance        | Alternative Needed        |
|-------------------|----------------------------|---------------------------|
| Models required   | 2 (diffusion + classifier) | Ideally 1                 |
| Training          | Separate for each          | Joint training            |
| Gradient quality  | Can be noisy               | Implicit, smoother        |
| Text conditioning | Requires text classifier   | Natural language handling |

Motivation for CFG: These limitations led to Classifier-Free Guidance (next section), which eliminates the need for a separate classifier by training the diffusion model itself to handle both conditional and unconditional generation.

Key Takeaways

  1. Bayes' rule lets us factor conditional generation as p(\mathbf{x}|c) \propto p(c|\mathbf{x}) \cdot p(\mathbf{x})
  2. Score decomposition: The conditional score equals unconditional score plus classifier gradient
  3. Guidance scale w controls conditioning strength (higher = more adherence, less diversity)
  4. Noisy classifier must be trained on images at all noise levels with timestep conditioning
  5. Limitations: Requires separate model, gradient quality issues, poor for text conditioning
Looking Ahead: The next section introduces Classifier-Free Guidance (CFG), which elegantly solves these problems by training a single model that can do both conditional and unconditional generation.