Chapter 11

Model Configuration

Training the Model

Learning Objectives

By the end of this section, you will be able to:

  1. Configure channel multipliers to control model capacity at different resolutions
  2. Choose attention resolutions that balance quality and computational cost
  3. Select appropriate timestep counts and noise schedules for your dataset
  4. Size your model appropriately for your GPU memory and training budget

The Big Picture

A diffusion model's architecture determines both its capacity to learn complex patterns and its computational requirements. The same U-Net backbone can range from a 5M parameter model trainable on a laptop to a 1B+ parameter model requiring a cluster of A100s.

The Key Insight: Model configuration is about finding the sweet spot between model capacity and your available compute. A larger model isn't always better: training a properly sized model for longer often beats training an oversized model briefly.

In this section, we'll learn how to configure each component of the U-Net architecture to match your computational resources and quality targets.


Channel Multipliers

Channel multipliers control the width of the network at each resolution level. The U-Net processes images at progressively lower resolutions, with the channel count typically increasing as resolution decreases.

Understanding Channel Progression

For a base channel count of C = 64 with multipliers (1, 2, 4, 8), the channel counts at each level are:

| Level | Resolution (64x64 input) | Multiplier | Channels |
|-------|--------------------------|------------|----------|
| 0     | 64x64                    | 1          | 64       |
| 1     | 32x32                    | 2          | 128      |
| 2     | 16x16                    | 4          | 256      |
| 3     | 8x8                      | 8          | 512      |
```python
from dataclasses import dataclass
from typing import List, Optional, Tuple

@dataclass
class UNetConfig:
    """Configuration for diffusion U-Net architecture."""

    # Input/output
    image_size: int = 64
    in_channels: int = 3
    out_channels: int = 3

    # Architecture
    base_channels: int = 64
    channel_multipliers: Tuple[int, ...] = (1, 2, 4, 8)

    # Attention
    attention_resolutions: Tuple[int, ...] = (16, 8)
    num_attention_heads: int = 4

    # Blocks
    num_res_blocks: int = 2
    dropout: float = 0.0

    # Time embedding
    time_embed_dim: Optional[int] = None  # Defaults to 4 * base_channels

    def __post_init__(self):
        if self.time_embed_dim is None:
            self.time_embed_dim = self.base_channels * 4

    @property
    def channel_counts(self) -> List[int]:
        """Get channel counts at each resolution level."""
        return [self.base_channels * m for m in self.channel_multipliers]

    @property
    def num_parameters(self) -> int:
        """Rough parameter estimate counting only the 3x3 residual convs."""
        # Two 3x3 convs per res block, mirrored across encoder and decoder;
        # attention, skip, and time-embedding weights are ignored, so real
        # models run larger.
        return sum(
            2 * self.num_res_blocks * 2 * 9 * (self.base_channels * m) ** 2
            for m in self.channel_multipliers
        )


# Example configurations
small_config = UNetConfig(
    base_channels=32,
    channel_multipliers=(1, 2, 4),
)

medium_config = UNetConfig(
    base_channels=64,
    channel_multipliers=(1, 2, 4, 8),
)

large_config = UNetConfig(
    base_channels=128,
    channel_multipliers=(1, 2, 4, 8),
)

print(f"Small: {small_config.channel_counts}")   # [32, 64, 128]
print(f"Medium: {medium_config.channel_counts}") # [64, 128, 256, 512]
print(f"Large: {large_config.channel_counts}")   # [128, 256, 512, 1024]
```

Guidelines for Channel Selection

Choosing channel counts depends on your image resolution and compute budget:

| Image Size    | Base Channels | Multipliers  | Est. Params |
|---------------|---------------|--------------|-------------|
| 28x28 (MNIST) | 32            | (1, 2, 4)    | ~2M         |
| 32x32 (CIFAR) | 64            | (1, 2, 4)    | ~15M        |
| 64x64         | 64            | (1, 2, 4, 8) | ~35M        |
| 128x128       | 96            | (1, 2, 4, 8) | ~100M       |
| 256x256       | 128           | (1, 2, 3, 4) | ~200M       |
| 512x512       | 160           | (1, 2, 2, 4) | ~400M       |
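
One constraint the table leaves implicit: the number of multipliers fixes the downsampling depth, so the bottleneck feature map sits at image_size / 2^(levels - 1) and should stay a sensible whole number. A quick standalone check (the helper name here is ours, for illustration only):

```python
def bottleneck_resolution(image_size: int, num_levels: int) -> int:
    """Spatial size of the deepest U-Net level (each extra level halves resolution)."""
    return image_size // 2 ** (num_levels - 1)

for size, mults in [(28, (1, 2, 4)), (32, (1, 2, 4)), (64, (1, 2, 4, 8)), (256, (1, 2, 3, 4))]:
    res = bottleneck_resolution(size, len(mults))
    print(f"{size}x{size} with {len(mults)} levels -> {res}x{res} bottleneck")
```

This is why 28x28 MNIST pairs with three levels (a 7x7 bottleneck) while 64x64 comfortably supports four (an 8x8 bottleneck, matching the channel progression table above).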

Attention Resolution

Self-attention is the most expensive operation in the U-Net. We typically only apply it at lower resolutions where the feature maps are small:

```python
from typing import Tuple

def should_use_attention(
    resolution: int,
    attention_resolutions: Tuple[int, ...],
) -> bool:
    """
    Determine if attention should be used at this resolution.

    Args:
        resolution: Current feature map resolution (e.g., 32, 16, 8)
        attention_resolutions: Resolutions where attention is applied

    Returns:
        True if attention should be used
    """
    return resolution in attention_resolutions


# Cost analysis of attention
def attention_memory_cost(resolution: int, channels: int, batch_size: int) -> float:
    """
    Estimate memory cost of self-attention in GB.

    Self-attention has O(n^2) memory where n = resolution^2
    """
    seq_len = resolution ** 2  # Number of spatial positions
    # Q, K, V matrices: 3 * batch * seq * channels * 4 bytes
    qkv_memory = 3 * batch_size * seq_len * channels * 4
    # Attention scores: batch * seq * seq * 4 bytes (single-head view;
    # splitting into heads does not change the total)
    attn_memory = batch_size * seq_len * seq_len * 4

    total_bytes = qkv_memory + attn_memory
    return total_bytes / (1024 ** 3)  # Convert to GB


# Example: Why we avoid attention at high resolutions
for res in [64, 32, 16, 8]:
    cost = attention_memory_cost(res, channels=256, batch_size=32)
    print(f"Resolution {res}x{res}: {cost:.2f} GB")

# Output:
# Resolution 64x64: 2.38 GB
# Resolution 32x32: 0.22 GB
# Resolution 16x16: 0.03 GB
# Resolution 8x8: 0.01 GB
```

Attention Trade-offs

More attention improves quality, but memory and time grow quadratically with resolution. For 256x256 images, attention at 32x32 and below is typical; attention at the 64x64 level is usually too expensive.
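
The quadratic blow-up applies to compute as well as memory: the score matrix has seq_len^2 entries with seq_len = resolution^2. The back-of-envelope FLOP count below (our own simplification, ignoring projections and softmax) makes the cliff between 16x16 and 64x64 concrete:

```python
def attention_score_flops(resolution: int, channels: int) -> int:
    """Approximate matmul FLOPs for QK^T plus the attention-weighted V.

    Each of the two matmuls touches seq_len^2 * channels entries at
    ~2 FLOPs (multiply + add) per entry, with seq_len = resolution^2.
    """
    seq_len = resolution ** 2
    return 2 * 2 * seq_len ** 2 * channels

base = attention_score_flops(16, channels=256)
for res in [16, 32, 64]:
    flops = attention_score_flops(res, channels=256)
    print(f"{res}x{res}: {flops / 1e9:6.2f} GFLOPs ({flops // base}x the 16x16 cost)")
```

Doubling the resolution multiplies the cost by 16, so attention at 64x64 costs 256x what it does at 16x16.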

Attention Configuration Patterns

```python
# Common attention patterns for different image sizes

def get_attention_config(image_size: int) -> Tuple[int, ...]:
    """
    Get recommended attention resolutions for image size.

    Returns resolutions where attention should be applied.
    """
    if image_size <= 32:
        # Small images: attention at all levels
        return (16, 8, 4) if image_size == 32 else (8, 4)

    elif image_size == 64:
        # Medium images: attention at 16x16 and below
        return (16, 8)

    elif image_size == 128:
        # Larger images: attention at 32x32 and below
        return (32, 16, 8)

    elif image_size == 256:
        # Large images: attention at 32x32 and 16x16 only
        return (32, 16)

    else:
        # Very large: be conservative
        return (32, 16)


# Number of attention heads
def get_num_heads(channels: int) -> int:
    """
    Determine number of attention heads based on channel count.

    Each head should have at least 32-64 dimensions.
    """
    head_dim = 64  # Typical head dimension
    num_heads = max(1, channels // head_dim)
    return num_heads

print(f"256 channels -> {get_num_heads(256)} heads")  # 4 heads
print(f"512 channels -> {get_num_heads(512)} heads")  # 8 heads
```

Timesteps and Schedule

The number of diffusion timesteps and the noise schedule affect both training stability and generation quality:

```python
import torch
import numpy as np
from dataclasses import dataclass

@dataclass
class DiffusionConfig:
    """Configuration for diffusion process."""

    # Number of timesteps
    timesteps: int = 1000

    # Noise schedule type
    schedule: str = "linear"  # "linear", "cosine", "sqrt"

    # Schedule parameters
    beta_start: float = 0.0001
    beta_end: float = 0.02

    # Loss weighting
    loss_type: str = "mse"  # "mse", "l1", "huber"
    prediction_type: str = "epsilon"  # "epsilon", "v", "x0"


def create_beta_schedule(config: DiffusionConfig) -> torch.Tensor:
    """Create noise schedule based on config."""
    T = config.timesteps

    if config.schedule == "linear":
        return torch.linspace(config.beta_start, config.beta_end, T)

    elif config.schedule == "cosine":
        # Cosine schedule from "Improved Denoising Diffusion" paper
        s = 0.008
        steps = torch.linspace(0, T, T + 1)
        alphas_cumprod = torch.cos((steps / T + s) / (1 + s) * np.pi / 2) ** 2
        alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
        betas = 1 - alphas_cumprod[1:] / alphas_cumprod[:-1]
        return torch.clamp(betas, 0.0001, 0.9999)

    elif config.schedule == "sqrt":
        return torch.linspace(config.beta_start**0.5, config.beta_end**0.5, T) ** 2

    else:
        raise ValueError(f"Unknown schedule: {config.schedule}")


# Compare schedules
linear_config = DiffusionConfig(schedule="linear")
cosine_config = DiffusionConfig(schedule="cosine")

linear_betas = create_beta_schedule(linear_config)
cosine_betas = create_beta_schedule(cosine_config)

# Compute alpha_bar (cumulative product of 1 - beta)
linear_alpha_bar = torch.cumprod(1 - linear_betas, dim=0)
cosine_alpha_bar = torch.cumprod(1 - cosine_betas, dim=0)

print(f"Linear schedule: alpha_bar ranges from {linear_alpha_bar[-1]:.4f} to {linear_alpha_bar[0]:.4f}")
print(f"Cosine schedule: alpha_bar ranges from {cosine_alpha_bar[-1]:.4f} to {cosine_alpha_bar[0]:.4f}")
```

Schedule Selection Guidelines

| Schedule | Best For           | Characteristics                                    |
|----------|--------------------|----------------------------------------------------|
| Linear   | Quick experiments  | Simple, fast to compute, may struggle with details |
| Cosine   | Most applications  | Better detail preservation, slower at high noise   |
| Sqrt     | Large images       | Gentler at low noise, good for high-res            |
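
To make "better detail preservation" concrete, compare how much signal each schedule retains halfway through the forward process. The sketch below re-derives both schedules in plain Python (the same formulas as `create_beta_schedule` above, written without torch so it stands alone):

```python
import math

T = 1000
s = 0.008  # Small offset from the cosine-schedule paper

# Linear schedule: betas evenly spaced between 1e-4 and 0.02
linear_betas = [1e-4 + (0.02 - 1e-4) * i / (T - 1) for i in range(T)]

# Cosine schedule: alpha_bar follows a squared cosine; betas are its ratios
def cosine_f(t: float) -> float:
    return math.cos((t / T + s) / (1 + s) * math.pi / 2) ** 2

cosine_betas = [
    min(max(1 - cosine_f(t + 1) / cosine_f(t), 0.0001), 0.9999) for t in range(T)
]

def alpha_bar(betas, t):
    """Cumulative product of (1 - beta) through timestep t."""
    prod = 1.0
    for beta in betas[: t + 1]:
        prod *= 1 - beta
    return prod

for name, betas in [("linear", linear_betas), ("cosine", cosine_betas)]:
    print(f"{name}: alpha_bar at t=500 is {alpha_bar(betas, 500):.3f}")
```

Halfway through, the linear schedule has already destroyed most of the signal (alpha_bar near 0.08) while the cosine schedule retains roughly half (near 0.49), which is the "slower at high noise" behavior the table describes.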

The number of timesteps affects the trade-off between quality and training cost:

  • T=1000: Standard for most applications, good quality
  • T=500: Faster training, slight quality reduction
  • T=2000+: Marginal quality gains, 2x compute cost

Model Sizing Guidelines

Choosing the right model size depends on your dataset and compute budget:

```python
@dataclass
class TrainingBudget:
    """Estimate training requirements for different configurations."""

    gpu_memory_gb: float
    batch_size: int
    training_time_hours: float  # For ~100 epochs on CIFAR-10


# Approximate requirements (single GPU)
budgets = {
    "tiny": TrainingBudget(gpu_memory_gb=4, batch_size=64, training_time_hours=2),
    "small": TrainingBudget(gpu_memory_gb=8, batch_size=64, training_time_hours=6),
    "medium": TrainingBudget(gpu_memory_gb=16, batch_size=128, training_time_hours=12),
    "large": TrainingBudget(gpu_memory_gb=24, batch_size=64, training_time_hours=24),
    "xl": TrainingBudget(gpu_memory_gb=40, batch_size=32, training_time_hours=48),
}


def recommend_config(
    gpu_memory_gb: float,
    image_size: int,
    dataset_size: int,
) -> UNetConfig:
    """
    Recommend model configuration based on hardware and data.

    Args:
        gpu_memory_gb: Available GPU memory
        image_size: Target image resolution
        dataset_size: Number of training samples (reserved for future
            heuristics; unused by the current rules)

    Returns:
        Recommended UNetConfig
    """
    # Scale base channels by GPU memory
    if gpu_memory_gb >= 40:
        base = 128
    elif gpu_memory_gb >= 16:
        base = 64
    elif gpu_memory_gb >= 8:
        base = 48
    else:
        base = 32

    # Adjust for image size (larger images need more memory for activations)
    if image_size >= 256:
        base = min(base, 96)
    elif image_size >= 128:
        base = min(base, 128)

    # Select channel multipliers based on resolution
    if image_size <= 32:
        mults = (1, 2, 4)
    elif image_size <= 64:
        mults = (1, 2, 4, 8)
    elif image_size <= 128:
        mults = (1, 2, 3, 4)
    else:
        mults = (1, 2, 2, 4)

    # Select attention resolutions
    attention_res = get_attention_config(image_size)

    return UNetConfig(
        image_size=image_size,
        base_channels=base,
        channel_multipliers=mults,
        attention_resolutions=attention_res,
    )


# Example
config = recommend_config(gpu_memory_gb=16, image_size=64, dataset_size=50000)
print(f"Recommended: {config.base_channels} base channels")
print(f"Channel progression: {config.channel_counts}")
```
Configuration Dataclass

Here's a complete configuration class that combines all settings:

```python
import json
from dataclasses import dataclass
from typing import Tuple

@dataclass
class DiffusionModelConfig:
    """
    Complete configuration for diffusion model training.

    This dataclass captures all hyperparameters needed to:
    - Define the U-Net architecture
    - Configure the diffusion process
    - Set up training parameters
    """

    # ===== Data =====
    image_size: int = 64
    in_channels: int = 3
    dataset: str = "cifar10"

    # ===== U-Net Architecture =====
    base_channels: int = 64
    channel_multipliers: Tuple[int, ...] = (1, 2, 4, 8)
    num_res_blocks: int = 2
    attention_resolutions: Tuple[int, ...] = (16, 8)
    num_attention_heads: int = 4
    dropout: float = 0.0

    # ===== Diffusion Process =====
    timesteps: int = 1000
    noise_schedule: str = "linear"
    beta_start: float = 0.0001
    beta_end: float = 0.02
    prediction_type: str = "epsilon"  # "epsilon", "v", "x0"

    # ===== Training =====
    batch_size: int = 64
    learning_rate: float = 2e-4
    num_epochs: int = 100
    warmup_steps: int = 1000
    gradient_accumulation: int = 1
    use_ema: bool = True
    ema_decay: float = 0.9999

    # ===== Optimization =====
    optimizer: str = "adamw"
    weight_decay: float = 0.01
    use_amp: bool = True

    # ===== Checkpointing =====
    checkpoint_every: int = 10
    sample_every: int = 5
    num_samples: int = 16

    def __post_init__(self):
        """Validate and adjust configuration."""
        # Keep only the attention resolutions the U-Net actually reaches
        current_res = self.image_size
        valid_res = []
        for _ in self.channel_multipliers:
            if current_res in self.attention_resolutions:
                valid_res.append(current_res)
            current_res //= 2
        if tuple(valid_res) != tuple(self.attention_resolutions):
            print(f"Warning: adjusting attention_resolutions from "
                  f"{self.attention_resolutions} to {tuple(valid_res)}")
            self.attention_resolutions = tuple(valid_res)

    @property
    def model_channels(self) -> Tuple[int, ...]:
        """Get channel count at each resolution level."""
        return tuple(self.base_channels * m for m in self.channel_multipliers)

    def to_dict(self) -> dict:
        """Convert to dictionary for serialization."""
        return {k: v for k, v in self.__dict__.items()}

    def save(self, path: str):
        """Save configuration to JSON file."""
        with open(path, 'w') as f:
            json.dump(self.to_dict(), f, indent=2)

    @classmethod
    def load(cls, path: str) -> 'DiffusionModelConfig':
        """Load configuration from JSON file."""
        with open(path, 'r') as f:
            data = json.load(f)
        return cls(**data)

    @classmethod
    def cifar10(cls) -> 'DiffusionModelConfig':
        """Standard configuration for CIFAR-10."""
        return cls(
            image_size=32,
            base_channels=64,
            channel_multipliers=(1, 2, 4),
            attention_resolutions=(16, 8),
            batch_size=128,
            num_epochs=200,
        )

    @classmethod
    def celeba64(cls) -> 'DiffusionModelConfig':
        """Standard configuration for CelebA 64x64."""
        return cls(
            image_size=64,
            base_channels=64,
            channel_multipliers=(1, 2, 4, 8),
            attention_resolutions=(16, 8),
            batch_size=64,
            num_epochs=100,
        )

    @classmethod
    def imagenet256(cls) -> 'DiffusionModelConfig':
        """Configuration for ImageNet 256x256."""
        return cls(
            image_size=256,
            base_channels=128,
            channel_multipliers=(1, 2, 3, 4),
            attention_resolutions=(32, 16, 8),
            num_attention_heads=8,
            batch_size=8,
            num_epochs=400,
            gradient_accumulation=8,
        )


# Usage examples
cifar_config = DiffusionModelConfig.cifar10()
print(f"CIFAR-10 config: {cifar_config.model_channels}")

celeba_config = DiffusionModelConfig.celeba64()
print(f"CelebA config: {celeba_config.model_channels}")
```
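
One subtlety with the JSON-backed `save`/`load` pair: JSON has no tuple type, so tuple-typed fields such as `channel_multipliers` come back as lists, and code that compares them against tuples will silently mismatch. A minimal standalone illustration (using a stripped-down stand-in dataclass, not the full config):

```python
import json
from dataclasses import asdict, dataclass
from typing import Tuple

@dataclass
class MiniConfig:
    channel_multipliers: Tuple[int, ...] = (1, 2, 4)

    @classmethod
    def from_dict(cls, data: dict) -> "MiniConfig":
        # JSON round-trips tuples as lists, so coerce back explicitly
        data["channel_multipliers"] = tuple(data["channel_multipliers"])
        return cls(**data)

blob = json.dumps(asdict(MiniConfig()))
restored = MiniConfig.from_dict(json.loads(blob))
assert restored.channel_multipliers == (1, 2, 4)  # a tuple again, not [1, 2, 4]
```

The same coercion can be applied inside `load` for any tuple-typed field before calling the constructor.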

Key Takeaways

  1. Channel multipliers control capacity: Start with base channels of 32-64 for small images, 64-128 for larger. Double or quadruple at lower resolutions.
  2. Attention is expensive: Only apply at resolutions 32x32 and below. Memory scales quadratically with resolution.
  3. Use 1000 timesteps: This is the standard choice. Cosine schedule often outperforms linear for generation quality.
  4. Match model to compute: A smaller model trained longer often beats a large model trained briefly. Size your model to fit comfortably in GPU memory with room for decent batch sizes.
  5. Use configuration classes: Centralize all hyperparameters in a dataclass for reproducibility and easy experimentation.

Looking Ahead: In the next section, we'll put together a complete training script with all the components: model, diffusion process, optimizer, logging, and checkpointing.