Chapter 8
Section 40 of 76

Class-Conditional Diffusion

Conditional Generation

Learning Objectives

By the end of this section, you will be able to:

  1. Implement learnable class embedding tables for conditioning on discrete labels
  2. Apply label dropout during training to enable classifier-free guidance
  3. Build Adaptive Group Normalization (AdaGN) layers for powerful conditioning
  4. Train a class-conditional diffusion model from scratch

Class Embeddings

The simplest form of conditioning uses discrete class labels. Given a class index (e.g., 281 for "cat" in ImageNet), we need to convert it into a continuous embedding that the neural network can process.

Embedding Tables

An embedding table is a learnable matrix where each row corresponds to one class:

$$\mathbf{E} \in \mathbb{R}^{K \times d}$$

where $K$ is the number of classes and $d$ is the embedding dimension. Looking up class $c$ simply returns row $c$:

$$\mathbf{e}_c = \mathbf{E}[c, :] \in \mathbb{R}^d$$
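Since the lookup simply selects row $c$, it is mathematically the same as multiplying a one-hot vector by $\mathbf{E}$, only much cheaper. The small sketch below uses arbitrary toy sizes and checks both equivalences with PyTorch's `nn.Embedding`, whose `weight` attribute holds the table.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

K, d = 10, 4                 # toy sizes: 10 classes, 4-dim embeddings
table = nn.Embedding(K, d)   # table.weight is E, shape [K, d]

c = torch.tensor([3, 7])     # a batch of two class indices

# 1) Lookup through the module
e_lookup = table(c)                             # [2, d]

# 2) Direct row indexing: E[c, :]
e_index = table.weight[c]                       # [2, d]

# 3) One-hot row vector times the table: onehot(c) @ E
one_hot = F.one_hot(c, num_classes=K).float()   # [2, K]
e_matmul = one_hot @ table.weight               # [2, d]

print(torch.allclose(e_lookup, e_index), torch.allclose(e_lookup, e_matmul))
```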

Learnable Class Embeddings
class_embedding.py
Embedding Table

nn.Embedding creates a lookup table with num_classes rows, each containing an embed_dim-dimensional vector. During training, these vectors are learned to represent each class.

EXAMPLE
For ImageNet with 1000 classes and 512-dim embeddings: 1000 x 512 = 512K learnable parameters
Projection Network

An optional MLP that transforms the embedding. This adds expressivity and can help align the class embedding with the time embedding space.

Lookup Operation

The embedding layer performs a simple lookup: given class index 281, it returns row 281 of the table. The lookup is differentiable with respect to the table entries, so the selected rows receive gradients and are updated during backpropagation.

Embedding Dimension

512 is a common choice that balances expressivity with computational cost. It should match the time embedding dimension so that the two embeddings can be added directly.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class ClassEmbedding(nn.Module):
    """
    Learnable embedding table for class labels.
    Each class gets its own embedding vector.
    """
    def __init__(self, num_classes: int, embed_dim: int):
        super().__init__()
        # Embedding table: num_classes x embed_dim
        self.embedding = nn.Embedding(num_classes, embed_dim)

        # Optional: project to match time embedding dimension
        self.proj = nn.Sequential(
            nn.Linear(embed_dim, embed_dim),
            nn.SiLU(),
            nn.Linear(embed_dim, embed_dim),
        )

    def forward(self, class_labels: torch.Tensor) -> torch.Tensor:
        # class_labels: [batch_size] of integers 0..num_classes-1
        emb = self.embedding(class_labels)  # [B, embed_dim]
        return self.proj(emb)

# Example usage
num_classes = 1000  # ImageNet
embed_dim = 512

class_embed = ClassEmbedding(num_classes, embed_dim)
labels = torch.tensor([281, 0, 999])  # cat, tench, toilet tissue
embeddings = class_embed(labels)  # [3, 512]
```

Why Learn Embeddings?

Unlike one-hot encoding (which is sparse and doesn't capture relationships), learned embeddings can capture semantic similarities. If "cat" and "dog" are both pets, their embeddings might be closer together than "cat" and "airplane".
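To make the one-hot point concrete: distinct one-hot codes are always orthogonal, so every pair of classes looks equally unrelated. The sketch below compares that with three hand-picked vectors standing in for what a trained table *might* contain; the values are hypothetical, not learned.

```python
import torch
import torch.nn.functional as F

def cosine_matrix(vectors: torch.Tensor) -> torch.Tensor:
    """Pairwise cosine similarity between the rows of `vectors`."""
    unit = F.normalize(vectors, dim=1)
    return unit @ unit.T

classes = ["cat", "dog", "airplane"]

# One-hot codes: every pair of distinct classes has similarity 0
one_hot = torch.eye(len(classes))
sim_one_hot = cosine_matrix(one_hot)

# Hypothetical "learned" embeddings, hand-picked for illustration only
embeddings = torch.tensor([
    [0.9, 0.8, 0.1],  # cat
    [0.8, 0.9, 0.2],  # dog
    [0.1, 0.2, 0.9],  # airplane
])
sim_learned = cosine_matrix(embeddings)

print(sim_one_hot)
print(sim_learned)  # cat-dog similarity is much higher than cat-airplane
```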

Embedding Dimension Choice

| Dimension | Trade-off | Typical Use |
|---|---|---|
| 128-256 | Compact, less expressive | Small models, few classes |
| 512 | Good balance | Most diffusion models |
| 768-1024 | More expressive, costly | Large-scale models |

Label Dropout for CFG Training

To enable Classifier-Free Guidance (which we'll cover in Section 8.4), we need to train the model to work both conditionally and unconditionally. This is achieved through label dropout.

Key Insight: During training, we randomly replace some class labels with a special "null" token. This teaches the model to generate without conditions when needed.

The Null Embedding

We add one extra class to represent "no condition":

  • Classes 0 to K-1: Real class labels
  • Class K: The "null" or unconditional class

During training, with probability $p_{\text{drop}}$, we replace the true label with the null class. In expectation, this is equivalent to training on a mixture of the conditional and unconditional objectives:

$$\mathcal{L} = (1 - p_{\text{drop}}) \cdot \mathcal{L}_{\text{cond}} + p_{\text{drop}} \cdot \mathcal{L}_{\text{uncond}}$$
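Because each example independently lands in the unconditional branch with probability $p_{\text{drop}}$, the fraction of null labels in a large batch approaches $p_{\text{drop}}$, which is where the mixture weights come from. The quick simulation below checks this; the batch size and seed are arbitrary.

```python
import torch

torch.manual_seed(0)

num_classes = 1000
null_class = num_classes  # extra index reserved for "no condition"
p_drop = 0.1

labels = torch.randint(0, num_classes, (100_000,))
drop_mask = torch.rand(labels.shape[0]) < p_drop
dropped = torch.where(drop_mask, torch.full_like(labels, null_class), labels)

# Share of examples that contribute to the unconditional loss term
frac_uncond = (dropped == null_class).float().mean().item()
print(f"fraction trained unconditionally: {frac_uncond:.3f}")
```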

Label Dropout Implementation
label_dropout.py
Dropout Probability

dropout_prob controls how often we drop the condition during training. 10-20% is typical. Higher values make the unconditional model stronger but may hurt conditional quality.

Null Class Index

We add one extra embedding for the 'unconditional' case. When the label is dropped, we use this null embedding instead. The model learns to associate this embedding with unconditional generation.

Training Dropout

During training, we randomly replace some labels with the null class. This trains the model to work both conditionally and unconditionally.

Force Drop Mode

During inference with CFG, we need to explicitly get the unconditional prediction. force_drop=True replaces all labels with the null class.

```python
import torch
import torch.nn as nn

class ConditionalUNet(nn.Module):
    """
    U-Net with label dropout for classifier-free guidance training.
    """
    def __init__(
        self,
        num_classes: int,
        embed_dim: int,
        dropout_prob: float = 0.1,  # Probability of dropping condition
    ):
        super().__init__()
        self.num_classes = num_classes
        self.dropout_prob = dropout_prob

        # Class embedding (include null class)
        # Class indices: 0..num_classes-1 are real classes
        # Index num_classes is the "null" or "unconditional" class
        self.class_embed = nn.Embedding(num_classes + 1, embed_dim)
        self.null_class = num_classes  # Index for null embedding

        # ... rest of U-Net ...

    def forward(
        self,
        x: torch.Tensor,
        t: torch.Tensor,
        class_labels: torch.Tensor,
        force_drop: bool = False,  # For unconditional generation
    ) -> torch.Tensor:
        batch_size = x.shape[0]
        null_labels = torch.full_like(class_labels, self.null_class)

        if force_drop:
            # Force unconditional (for CFG inference); checked first so it
            # is honored even if the model is still in training mode
            labels = null_labels
        elif self.training:
            # During training: randomly replace labels with the null class
            drop_mask = torch.rand(batch_size, device=x.device) < self.dropout_prob
            labels = torch.where(drop_mask, null_labels, class_labels)
        else:
            labels = class_labels

        # Get class embedding
        c_emb = self.class_embed(labels)

        # Continue with U-Net forward pass...
        return self.unet_forward(x, t, c_emb)

    def unet_forward(self, x, t, c_emb):
        # Placeholder: the U-Net body is elided in this example
        raise NotImplementedError("Plug in the U-Net forward pass here")
```
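The sketch below isolates just the label-selection logic in a hypothetical `LabelDropout` module, so the three modes can be run without a U-Net. It checks `force_drop` before `self.training`, so an unconditional prediction can be requested in either mode.

```python
import torch
import torch.nn as nn

class LabelDropout(nn.Module):
    """Hypothetical helper: label dropout with a null class at index num_classes."""
    def __init__(self, num_classes: int, dropout_prob: float = 0.1):
        super().__init__()
        self.null_class = num_classes
        self.dropout_prob = dropout_prob

    def forward(self, class_labels: torch.Tensor, force_drop: bool = False) -> torch.Tensor:
        null = torch.full_like(class_labels, self.null_class)
        if force_drop:
            return null  # unconditional branch for CFG
        if self.training:
            mask = torch.rand(class_labels.shape[0], device=class_labels.device) < self.dropout_prob
            return torch.where(mask, null, class_labels)
        return class_labels  # eval without force_drop: labels pass through

labels = torch.tensor([281, 0, 999])

drop = LabelDropout(num_classes=1000, dropout_prob=0.1)
drop.eval()
cond = drop(labels)                      # tensor([281, 0, 999])
uncond = drop(labels, force_drop=True)   # tensor([1000, 1000, 1000])

# Training mode with dropout_prob=1.0 drops every label
drop_all = LabelDropout(num_classes=1000, dropout_prob=1.0)  # modules start in train mode
always_dropped = drop_all(labels)
```

At sampling time, CFG runs the network once with `cond` and once with `uncond` labels and combines the two predictions.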

Choosing Dropout Probability

| p_drop | Effect | Recommended Use |
|---|---|---|
| 0.05 | Weak unconditional, strong conditional | When CFG is rarely used |
| 0.1 | Good balance | Standard choice |
| 0.2 | Stronger unconditional | When diversity matters |
| 0.5 | Equal training of both | Maximum flexibility |

AdaGN Conditioning Mechanism

Adaptive Group Normalization (AdaGN) injects conditioning information at every normalization layer of the network. Instead of fixed normalization parameters, the scale and shift are predicted from the condition.

Standard Group Normalization

Regular GroupNorm normalizes features and then applies a learned affine transformation:

$$\text{GN}(\mathbf{h}) = \gamma \cdot \frac{\mathbf{h} - \mu}{\sigma} + \beta$$

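where $\mu$ and $\sigma$ are the mean and standard deviation computed per sample over each group of channels (and all spatial positions), and $\gamma, \beta$ are learned per-channel parameters that do not depend on the input.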
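As a sketch with toy tensor sizes and random $\gamma, \beta$, the code below evaluates the formula by hand, grouping channels per sample, and compares the result with `nn.GroupNorm`. Note that PyTorch adds a small `eps` (default `1e-5`) to the variance before taking the square root.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

B, C, H, W, G = 2, 8, 4, 4, 4  # toy sizes; C must be divisible by G
h = torch.randn(B, C, H, W)

# Group channels: statistics are per sample and per group of C // G channels
grouped = h.view(B, G, C // G, H, W)
mu = grouped.mean(dim=(2, 3, 4), keepdim=True)
var = ((grouped - mu) ** 2).mean(dim=(2, 3, 4), keepdim=True)  # biased variance
eps = 1e-5  # nn.GroupNorm's default epsilon
h_hat = ((grouped - mu) / torch.sqrt(var + eps)).view(B, C, H, W)

# Per-channel affine parameters (random values for the comparison)
gamma = torch.randn(C)
beta = torch.randn(C)
manual = gamma[None, :, None, None] * h_hat + beta[None, :, None, None]

gn = nn.GroupNorm(G, C)
with torch.no_grad():
    gn.weight.copy_(gamma)
    gn.bias.copy_(beta)
    reference = gn(h)

print(torch.allclose(manual, reference, atol=1e-4))
```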

Adaptive Group Normalization

AdaGN makes the scale and shift dynamic, predicting them from the conditioning embedding:

$$\text{AdaGN}(\mathbf{h}, \mathbf{c}) = \gamma(\mathbf{c}) \cdot \frac{\mathbf{h} - \mu}{\sigma} + \beta(\mathbf{c})$$

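where $\gamma(\mathbf{c})$ and $\beta(\mathbf{c})$ are outputs of a small neural network that takes the condition embedding as input. In the implementation below this network is a single linear layer, and $\gamma(\mathbf{c})$ is parameterized as $1 + \text{scale}$.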

Adaptive Group Normalization
adagn.py
Group Norm without Affine

We use GroupNorm with affine=False because we want the scale and shift to come from the conditioning embedding, not learned per-channel parameters.

Scale and Shift Prediction

A linear layer predicts both scale (gamma) and shift (beta) from the embedding. The output is 2*num_channels: half for scale, half for shift.

Chunk Operation

Split the predicted vector into scale and shift components. Each has shape [B, C].

Adaptive Modulation

The formula is: output = normalized * (1 + scale) + shift. Writing the multiplier as (1 + scale) centers it at 1: when the predicted scale and shift are near zero, the layer reduces to plain normalization, so identity modulation is the default.

ResBlock with AdaGN

A complete residual block using AdaGN. The conditioning embedding modulates both normalization layers, giving the condition strong influence over the features.

```python
import torch
import torch.nn as nn

class AdaGroupNorm(nn.Module):
    """
    Adaptive Group Normalization.
    The condition modulates the normalized features via learned scale and shift.
    """
    def __init__(
        self,
        num_channels: int,
        num_groups: int,
        embed_dim: int,
    ):
        super().__init__()
        self.gn = nn.GroupNorm(num_groups, num_channels, affine=False)

        # Learn scale (gamma) and shift (beta) from embedding
        self.scale_shift = nn.Linear(embed_dim, num_channels * 2)

    def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor:
        # x: [B, C, H, W], emb: [B, embed_dim]

        # Normalize
        x = self.gn(x)

        # Get scale and shift from embedding
        scale_shift = self.scale_shift(emb)  # [B, C*2]
        scale, shift = scale_shift.chunk(2, dim=1)  # Each [B, C]

        # Reshape for broadcasting
        scale = scale[:, :, None, None]  # [B, C, 1, 1]
        shift = shift[:, :, None, None]

        # Apply adaptive normalization
        return x * (1 + scale) + shift


class ResBlockWithAdaGN(nn.Module):
    """ResBlock that uses AdaGN for conditioning."""
    def __init__(self, channels: int, embed_dim: int, num_groups: int = 32):
        super().__init__()
        self.norm1 = AdaGroupNorm(channels, num_groups, embed_dim)
        self.conv1 = nn.Conv2d(channels, channels, 3, padding=1)
        self.norm2 = AdaGroupNorm(channels, num_groups, embed_dim)
        self.conv2 = nn.Conv2d(channels, channels, 3, padding=1)
        self.act = nn.SiLU()

    def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor:
        h = self.norm1(x, emb)
        h = self.act(h)
        h = self.conv1(h)
        h = self.norm2(h, emb)
        h = self.act(h)
        h = self.conv2(h)
        return x + h
```
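The `(1 + scale)` form has a useful consequence: if the `scale_shift` projection outputs zeros, AdaGN reduces exactly to GroupNorm without affine parameters. Zero-initializing that projection is one common choice for stable early training, but it is an assumption here; the code above keeps PyTorch's default `nn.Linear` initialization. The self-contained check below repeats the layer compactly.

```python
import torch
import torch.nn as nn

class AdaGroupNorm(nn.Module):
    """Same layer as above, repeated so this snippet runs on its own."""
    def __init__(self, num_channels: int, num_groups: int, embed_dim: int):
        super().__init__()
        self.gn = nn.GroupNorm(num_groups, num_channels, affine=False)
        self.scale_shift = nn.Linear(embed_dim, num_channels * 2)

    def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor:
        scale, shift = self.scale_shift(emb).chunk(2, dim=1)
        return self.gn(x) * (1 + scale[:, :, None, None]) + shift[:, :, None, None]

torch.manual_seed(0)
x = torch.randn(2, 64, 8, 8)
emb = torch.randn(2, 512)

ada = AdaGroupNorm(num_channels=64, num_groups=32, embed_dim=512)

# Assumption: zero-init the projection, so scale = shift = 0 for any embedding
nn.init.zeros_(ada.scale_shift.weight)
nn.init.zeros_(ada.scale_shift.bias)

plain = nn.GroupNorm(32, 64, affine=False)
print(torch.allclose(ada(x, emb), plain(x)))
```

Training then moves the projection away from zero, letting the condition gradually take control of each channel's scale and shift.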

Why AdaGN Works So Well

  • Global modulation: Every feature channel can be scaled/shifted based on condition
  • Lightweight: Only adds a linear layer per norm
  • Expressive: Per-channel scale and shift at every normalization layer let the class reshape feature statistics throughout the network

Combining Time and Class Conditioning

In practice, we typically add the time and class embeddings before feeding them to AdaGN:

$$\mathbf{e}_{\text{combined}} = \mathbf{e}_t + \mathbf{e}_c$$
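A minimal sketch of the combination, assuming a transformer-style sinusoidal timestep embedding followed by a small MLP (the function `timestep_embedding` and the MLP shape are illustrative choices, not taken from this section) and a class table with one extra row for the null class:

```python
import math
import torch
import torch.nn as nn

def timestep_embedding(t: torch.Tensor, dim: int, max_period: float = 10000.0) -> torch.Tensor:
    """Hypothetical helper: sinusoidal embedding of integer timesteps (dim must be even)."""
    half = dim // 2
    freqs = torch.exp(-math.log(max_period) * torch.arange(half, dtype=torch.float32) / half)
    args = t.float()[:, None] * freqs[None, :]                    # [B, half]
    return torch.cat([torch.cos(args), torch.sin(args)], dim=-1)  # [B, dim]

embed_dim = 512
num_classes = 1000

time_mlp = nn.Sequential(
    nn.Linear(embed_dim, embed_dim), nn.SiLU(), nn.Linear(embed_dim, embed_dim)
)
class_embed = nn.Embedding(num_classes + 1, embed_dim)  # +1 row for the null class

t = torch.tensor([10, 500, 999])
labels = torch.tensor([281, 0, 1000])  # last example uses the null class

e_t = time_mlp(timestep_embedding(t, embed_dim))  # [3, 512]
e_c = class_embed(labels)                         # [3, 512]
e_combined = e_t + e_c                            # [3, 512], fed to every AdaGN layer
```

Because both embeddings live in the same `embed_dim` space, combining them is a single elementwise sum, which is why the two dimensions should match.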

This combined embedding is then used in all AdaGN layers throughout the U-Net.


Training Procedure

Training a class-conditional diffusion model is almost identical to unconditional training, with a few key differences:

  1. Include labels in dataloader: Each batch contains (images, class_labels)
  2. Apply label dropout: Randomly replace labels with null during training
  3. Condition the model: Pass labels through embedding and inject via AdaGN
Class-Conditional Training Loop
train_conditional.py
Alpha Schedule

Pre-compute the cumulative product of alphas. These define the noise schedule and are used to create noisy images at each timestep.

Load Labels

Unlike unconditional training, we now need the class labels from the dataloader. These will be used for conditional noise prediction.

Sample Timesteps

Random timesteps are sampled uniformly. At each timestep, we corrupt the image and train the model to predict the added noise.

Create Noisy Images

The standard forward diffusion formula: $x_t = \sqrt{\bar{\alpha}_t}\, x_0 + \sqrt{1 - \bar{\alpha}_t}\, \epsilon$

Conditional Prediction

The model predicts noise conditioned on both timestep and class label. Label dropout happens inside the model during training.

Simple MSE Loss

The loss is the same as unconditional: predict the noise. The conditioning just changes what information the model has access to.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

def get_alphas_cumprod(num_timesteps: int) -> torch.Tensor:
    # Assumption: linear beta schedule with the DDPM defaults (1e-4 to 0.02);
    # substitute your own noise schedule here
    betas = torch.linspace(1e-4, 0.02, num_timesteps)
    return torch.cumprod(1.0 - betas, dim=0)

def train_class_conditional(
    model: nn.Module,
    dataloader: DataLoader,
    optimizer: torch.optim.Optimizer,
    num_timesteps: int = 1000,
    device: str = "cuda",
    num_epochs: int = 1,
):
    """Training loop for class-conditional diffusion."""
    model.train()  # Training mode enables label dropout inside the model
    alphas_cumprod = get_alphas_cumprod(num_timesteps).to(device)

    for epoch in range(num_epochs):
        for batch in dataloader:
            images, labels = batch
            images = images.to(device)
            labels = labels.to(device)
            batch_size = images.shape[0]

            # Sample random timesteps
            t = torch.randint(0, num_timesteps, (batch_size,), device=device)

            # Sample noise
            noise = torch.randn_like(images)

            # Get noisy images
            alpha_t = alphas_cumprod[t][:, None, None, None]
            x_t = torch.sqrt(alpha_t) * images + torch.sqrt(1 - alpha_t) * noise

            # Predict noise (with label dropout happening inside model)
            noise_pred = model(x_t, t, labels)

            # MSE loss
            loss = F.mse_loss(noise_pred, noise)

            # Backprop
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

    return model
```
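The section does not specify the noise schedule behind `get_alphas_cumprod`. Assuming the DDPM linear beta schedule (β from 1e-4 to 0.02), the sketch below builds $\bar{\alpha}_t$ and checks two properties of the forward process: $\bar{\alpha}_t$ decreases toward zero, and $x_t$ keeps unit variance when $x_0$ has unit variance, since $\bar{\alpha}_t + (1 - \bar{\alpha}_t) = 1$.

```python
import torch

def get_alphas_cumprod(num_timesteps: int, beta_start: float = 1e-4, beta_end: float = 0.02) -> torch.Tensor:
    # Assumed linear schedule; other schedules (e.g. cosine) are also common
    betas = torch.linspace(beta_start, beta_end, num_timesteps)
    return torch.cumprod(1.0 - betas, dim=0)

alphas_cumprod = get_alphas_cumprod(1000)

torch.manual_seed(0)
x0 = torch.randn(10_000)  # unit-variance stand-in for (normalized) data
stds = []
for t in [0, 250, 500, 999]:
    a = alphas_cumprod[t]
    noise = torch.randn_like(x0)
    x_t = a.sqrt() * x0 + (1 - a).sqrt() * noise  # forward diffusion at step t
    stds.append(x_t.std().item())
    print(f"t={t:4d}  alpha_bar={a.item():.4f}  std(x_t)={stds[-1]:.3f}")
```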

Complete Implementation

Here's how all the pieces fit together in a complete class-conditional U-Net architecture:

| Component | Purpose | Key Parameters |
|---|---|---|
| ClassEmbedding | Convert label to vector | num_classes + 1, embed_dim |
| SinusoidalEmbedding | Encode timestep | max_period, embed_dim |
| EmbeddingMLP | Project embeddings | embed_dim -> embed_dim |
| AdaGN | Condition normalization | channels, num_groups, embed_dim |
| ResBlockAdaGN | Conditioned residual block | in_ch, out_ch, embed_dim |
| AttentionBlock | Self-attention (optional) | channels, num_heads |
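The table lists the components without showing how they connect. As a rough sketch, assuming a transformer-style sinusoidal timestep embedding, the hypothetical `TinyConditionalDenoiser` below wires a class table (with null row), a time MLP, and AdaGN residual blocks together. It has no down/upsampling or attention, so it only illustrates the data flow and is not a real U-Net; the sizes (10 classes, 64 channels, 256-dim embeddings) are arbitrary.

```python
import math
import torch
import torch.nn as nn

def timestep_embedding(t, dim, max_period=10000.0):
    # Sinusoidal timestep embedding (assumed form; dim must be even)
    half = dim // 2
    freqs = torch.exp(-math.log(max_period) * torch.arange(half, dtype=torch.float32, device=t.device) / half)
    args = t.float()[:, None] * freqs[None, :]
    return torch.cat([torch.cos(args), torch.sin(args)], dim=-1)

class AdaGroupNorm(nn.Module):
    def __init__(self, channels, num_groups, embed_dim):
        super().__init__()
        self.gn = nn.GroupNorm(num_groups, channels, affine=False)
        self.scale_shift = nn.Linear(embed_dim, 2 * channels)

    def forward(self, x, emb):
        scale, shift = self.scale_shift(emb).chunk(2, dim=1)
        return self.gn(x) * (1 + scale[:, :, None, None]) + shift[:, :, None, None]

class ResBlockAdaGN(nn.Module):
    def __init__(self, in_ch, out_ch, embed_dim, num_groups=32):
        super().__init__()
        self.norm1 = AdaGroupNorm(in_ch, num_groups, embed_dim)
        self.conv1 = nn.Conv2d(in_ch, out_ch, 3, padding=1)
        self.norm2 = AdaGroupNorm(out_ch, num_groups, embed_dim)
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1)
        # 1x1 conv on the skip path when the channel count changes
        self.skip = nn.Conv2d(in_ch, out_ch, 1) if in_ch != out_ch else nn.Identity()
        self.act = nn.SiLU()

    def forward(self, x, emb):
        h = self.conv1(self.act(self.norm1(x, emb)))
        h = self.conv2(self.act(self.norm2(h, emb)))
        return self.skip(x) + h

class TinyConditionalDenoiser(nn.Module):
    """Hypothetical minimal model: illustrates wiring only, not a real U-Net."""
    def __init__(self, num_classes=10, img_ch=3, ch=64, embed_dim=256, dropout_prob=0.1):
        super().__init__()
        self.embed_dim = embed_dim
        self.null_class = num_classes
        self.dropout_prob = dropout_prob
        self.class_embed = nn.Embedding(num_classes + 1, embed_dim)  # +1 null row
        self.time_mlp = nn.Sequential(
            nn.Linear(embed_dim, embed_dim), nn.SiLU(), nn.Linear(embed_dim, embed_dim)
        )
        self.conv_in = nn.Conv2d(img_ch, ch, 3, padding=1)
        self.blocks = nn.ModuleList([ResBlockAdaGN(ch, ch, embed_dim) for _ in range(2)])
        self.conv_out = nn.Conv2d(ch, img_ch, 3, padding=1)

    def forward(self, x, t, class_labels, force_drop=False):
        null = torch.full_like(class_labels, self.null_class)
        if force_drop:
            class_labels = null
        elif self.training:
            drop = torch.rand(class_labels.shape[0], device=x.device) < self.dropout_prob
            class_labels = torch.where(drop, null, class_labels)
        # Combined embedding e_t + e_c conditions every AdaGN layer
        emb = self.time_mlp(timestep_embedding(t, self.embed_dim)) + self.class_embed(class_labels)
        h = self.conv_in(x)
        for block in self.blocks:
            h = block(h, emb)
        return self.conv_out(h)  # predicted noise, same shape as x

model = TinyConditionalDenoiser()
x = torch.randn(4, 3, 32, 32)
t = torch.randint(0, 1000, (4,))
y = torch.randint(0, 10, (4,))
eps_pred = model(x, t, y)
```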

Architecture Choices

  • Embedding dimension: 512-1024 for most models
  • Number of groups: 32 groups for GroupNorm is standard
  • Label dropout: 0.1 is a safe default
  • Attention: Only at lower resolutions (16x16, 8x8)

Key Takeaways

  1. Class embeddings convert discrete labels to continuous vectors via learnable lookup tables
  2. Label dropout randomly replaces labels with a null token, training the model for both conditional and unconditional generation
  3. AdaGN allows the condition to modulate normalization scale and shift throughout the network
  4. Time and class embeddings are typically added together and used jointly in AdaGN layers
  5. Training is nearly identical to unconditional, just with labels included and dropout applied
Looking Ahead: Now that we can train class-conditional models, the next section introduces Classifier Guidance: an alternative approach that uses a separate classifier to steer generation.