Learning Objectives
By the end of this section, you will be able to:
- Configure channel multipliers to control model capacity at different resolutions
- Choose attention resolutions that balance quality and computational cost
- Select appropriate timestep counts and noise schedules for your dataset
- Size your model appropriately for your GPU memory and training budget
The Big Picture
A diffusion model's architecture determines both its capacity to learn complex patterns and its computational requirements. The same U-Net backbone can range from a 5M parameter model trainable on a laptop to a 1B+ parameter model requiring a cluster of A100s.
The Key Insight: Model configuration is about finding the sweet spot between model capacity and your available compute. A larger model isn't always better: training a properly sized model longer often beats training an oversized model briefly.
In this section, we'll learn how to configure each component of the U-Net architecture to match your computational resources and quality targets.
Channel Multipliers
Channel multipliers control the width of the network at each resolution level. The U-Net processes images at progressively lower resolutions, with the channel count typically increasing as resolution decreases.
Understanding Channel Progression
For a base channel count of 64 with multipliers (1, 2, 4, 8), the channel counts at each level are:
| Level | Resolution (64x64 input) | Multiplier | Channels |
|---|---|---|---|
| 0 | 64x64 | 1 | 64 |
| 1 | 32x32 | 2 | 128 |
| 2 | 16x16 | 4 | 256 |
| 3 | 8x8 | 8 | 512 |
```python
from dataclasses import dataclass
from typing import List, Optional, Tuple

@dataclass
class UNetConfig:
    """Configuration for diffusion U-Net architecture."""

    # Input/output
    image_size: int = 64
    in_channels: int = 3
    out_channels: int = 3

    # Architecture
    base_channels: int = 64
    channel_multipliers: Tuple[int, ...] = (1, 2, 4, 8)

    # Attention
    attention_resolutions: Tuple[int, ...] = (16, 8)
    num_attention_heads: int = 4

    # Blocks
    num_res_blocks: int = 2
    dropout: float = 0.0

    # Time embedding
    time_embed_dim: Optional[int] = None  # Defaults to 4 * base_channels

    def __post_init__(self):
        if self.time_embed_dim is None:
            self.time_embed_dim = self.base_channels * 4

    @property
    def channel_counts(self) -> List[int]:
        """Get channel counts at each resolution level."""
        return [self.base_channels * m for m in self.channel_multipliers]

    @property
    def num_parameters(self) -> int:
        """Rough lower-bound parameter estimate.

        Counts only the two 3x3 convolutions per res block, mirrored
        across encoder and decoder; attention layers, time embeddings,
        and skip convolutions are not included.
        """
        total = 0
        for c in self.channel_counts:
            conv = 3 * 3 * c * c + c  # one 3x3 conv (c -> c) with bias
            total += 2 * self.num_res_blocks * 2 * conv
        return total


# Example configurations
small_config = UNetConfig(
    base_channels=32,
    channel_multipliers=(1, 2, 4),
)

medium_config = UNetConfig(
    base_channels=64,
    channel_multipliers=(1, 2, 4, 8),
)

large_config = UNetConfig(
    base_channels=128,
    channel_multipliers=(1, 2, 4, 8),
)

print(f"Small: {small_config.channel_counts}")   # [32, 64, 128]
print(f"Medium: {medium_config.channel_counts}") # [64, 128, 256, 512]
print(f"Large: {large_config.channel_counts}")   # [128, 256, 512, 1024]
```

Guidelines for Channel Selection
Choosing channel counts depends on your image resolution and compute budget:
| Image Size | Base Channels | Multipliers | Est. Params |
|---|---|---|---|
| 28x28 (MNIST) | 32 | (1, 2, 4) | ~2M |
| 32x32 (CIFAR) | 64 | (1, 2, 4) | ~15M |
| 64x64 | 64 | (1, 2, 4, 8) | ~35M |
| 128x128 | 96 | (1, 2, 4, 8) | ~100M |
| 256x256 | 128 | (1, 2, 3, 4) | ~200M |
| 512x512 | 160 | (1, 2, 2, 4) | ~400M |
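As a sanity check on the Est. Params column, here is a back-of-envelope count (a rough sketch, not the exact architecture): it tallies only the two 3x3 convolutions in each residual block, mirrored across the encoder and decoder paths, and ignores attention, time embeddings, and skip convolutions, so the real counts in the table run higher:

```python
def conv_params_millions(base: int, mults: tuple, num_res_blocks: int = 2) -> float:
    """Lower-bound parameter estimate in millions: two 3x3 convs (C -> C)
    per res block, doubled for the mirrored decoder path."""
    total = 0
    for m in mults:
        c = base * m
        conv = 3 * 3 * c * c + c                 # one 3x3 conv with bias
        total += 2 * num_res_blocks * 2 * conv   # 2 convs/block, enc + dec
    return total / 1e6

for name, base, mults in [("CIFAR 32x32", 64, (1, 2, 4)),
                          ("64x64", 64, (1, 2, 4, 8))]:
    print(f"{name}: at least ~{conv_params_millions(base, mults):.0f}M parameters")
```

The gap between this lower bound and the table is the attention layers, embeddings, and skip-connection convs; the point is that channel counts, not depth, dominate the budget.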
Attention Resolution
Self-attention is the most expensive operation in the U-Net. We typically only apply it at lower resolutions where the feature maps are small:
```python
def should_use_attention(
    resolution: int,
    attention_resolutions: Tuple[int, ...],
) -> bool:
    """
    Determine if attention should be used at this resolution.

    Args:
        resolution: Current feature map resolution (e.g., 32, 16, 8)
        attention_resolutions: Resolutions where attention is applied

    Returns:
        True if attention should be used
    """
    return resolution in attention_resolutions


# Cost analysis of attention
def attention_memory_cost(resolution: int, channels: int, batch_size: int) -> float:
    """
    Estimate memory cost of self-attention in GB (fp32, single head).

    Self-attention has O(n^2) memory where n = resolution^2.
    """
    seq_len = resolution ** 2  # Number of spatial positions
    # Q, K, V matrices: 3 * batch * seq * channels * 4 bytes
    qkv_memory = 3 * batch_size * seq_len * channels * 4
    # Attention score matrix: batch * seq * seq * 4 bytes
    # (multi-head attention multiplies this by the number of heads)
    attn_memory = batch_size * seq_len * seq_len * 4

    total_bytes = qkv_memory + attn_memory
    return total_bytes / (1024 ** 3)  # Convert to GB


# Example: Why we avoid attention at high resolutions
for res in [64, 32, 16, 8]:
    cost = attention_memory_cost(res, channels=256, batch_size=32)
    print(f"Resolution {res}x{res}: {cost:.2f} GB")

# Output:
# Resolution 64x64: 2.38 GB
# Resolution 32x32: 0.22 GB
# Resolution 16x16: 0.03 GB
# Resolution 8x8: 0.01 GB
```

Attention Trade-offs
Attention Configuration Patterns
```python
# Common attention patterns for different image sizes

def get_attention_config(image_size: int) -> Tuple[int, ...]:
    """
    Get recommended attention resolutions for image size.

    Returns resolutions where attention should be applied. Resolutions
    deeper than the model actually downsamples to are never reached.
    """
    if image_size <= 32:
        # Small images: attention at all levels
        return (16, 8, 4) if image_size == 32 else (8, 4)

    elif image_size == 64:
        # Medium images: attention at 16x16 and below
        return (16, 8)

    elif image_size == 128:
        # Larger images: attention at 32x32 and below
        return (32, 16, 8)

    elif image_size == 256:
        # Large images: attention at 32x32 and below only
        return (32, 16)

    else:
        # Very large: be conservative
        return (32, 16)


# Number of attention heads
def get_num_heads(channels: int) -> int:
    """
    Determine number of attention heads based on channel count.

    Each head should have at least 32-64 dimensions.
    """
    head_dim = 64  # Typical head dimension
    num_heads = max(1, channels // head_dim)
    return num_heads

print(f"256 channels -> {get_num_heads(256)} heads")  # 4 heads
print(f"512 channels -> {get_num_heads(512)} heads")  # 8 heads
```

Timesteps and Schedule
The number of diffusion timesteps and the noise schedule affect both training stability and generation quality:
```python
import torch
import numpy as np
from dataclasses import dataclass

@dataclass
class DiffusionConfig:
    """Configuration for diffusion process."""

    # Number of timesteps
    timesteps: int = 1000

    # Noise schedule type
    schedule: str = "linear"  # "linear", "cosine", "sqrt"

    # Schedule parameters
    beta_start: float = 0.0001
    beta_end: float = 0.02

    # Loss weighting
    loss_type: str = "mse"  # "mse", "l1", "huber"
    prediction_type: str = "epsilon"  # "epsilon", "v", "x0"


def create_beta_schedule(config: DiffusionConfig) -> torch.Tensor:
    """Create noise schedule based on config."""
    T = config.timesteps

    if config.schedule == "linear":
        return torch.linspace(config.beta_start, config.beta_end, T)

    elif config.schedule == "cosine":
        # Cosine schedule from "Improved Denoising Diffusion Probabilistic Models"
        s = 0.008
        steps = torch.linspace(0, T, T + 1)
        alphas_cumprod = torch.cos((steps / T + s) / (1 + s) * np.pi / 2) ** 2
        alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
        betas = 1 - alphas_cumprod[1:] / alphas_cumprod[:-1]
        return torch.clamp(betas, 0.0001, 0.9999)

    elif config.schedule == "sqrt":
        return torch.linspace(config.beta_start**0.5, config.beta_end**0.5, T) ** 2

    else:
        raise ValueError(f"Unknown schedule: {config.schedule}")


# Compare schedules
linear_config = DiffusionConfig(schedule="linear")
cosine_config = DiffusionConfig(schedule="cosine")

linear_betas = create_beta_schedule(linear_config)
cosine_betas = create_beta_schedule(cosine_config)

# Compute alpha_bar (cumulative product of 1 - beta)
linear_alpha_bar = torch.cumprod(1 - linear_betas, dim=0)
cosine_alpha_bar = torch.cumprod(1 - cosine_betas, dim=0)

print(f"Linear schedule: alpha_bar ranges from {linear_alpha_bar[-1]:.4f} to {linear_alpha_bar[0]:.4f}")
print(f"Cosine schedule: alpha_bar ranges from {cosine_alpha_bar[-1]:.4f} to {cosine_alpha_bar[0]:.4f}")
```

Schedule Selection Guidelines
| Schedule | Best For | Characteristics |
|---|---|---|
| Linear | Quick experiments | Simple, fast to compute, may struggle with details |
| Cosine | Most applications | Better detail preservation, slower at high noise |
| Sqrt | Large images | Gentler at low noise, good for high-res |
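The detail-preservation claim for the cosine schedule can be made concrete with a small dependency-free sketch: compare how much signal (alpha_bar) each schedule retains partway through the chain, using the standard linear beta range and the closed-form cosine alpha_bar:

```python
import math

def alpha_bar_linear(t_frac: float, T: int = 1000,
                     beta_start: float = 1e-4, beta_end: float = 0.02) -> float:
    """Signal level remaining at fraction t_frac through a linear schedule."""
    log_ab = 0.0
    for i in range(int(t_frac * T)):
        beta = beta_start + (beta_end - beta_start) * i / (T - 1)
        log_ab += math.log(1.0 - beta)
    return math.exp(log_ab)

def alpha_bar_cosine(t_frac: float, s: float = 0.008) -> float:
    """Closed-form alpha_bar for the cosine schedule."""
    f = lambda u: math.cos((u + s) / (1 + s) * math.pi / 2) ** 2
    return f(t_frac) / f(0.0)

for frac in (0.25, 0.5, 0.75):
    print(f"t/T={frac}: linear alpha_bar={alpha_bar_linear(frac):.3f}, "
          f"cosine alpha_bar={alpha_bar_cosine(frac):.3f}")
```

Halfway through the chain the linear schedule has already destroyed most of the signal, while the cosine schedule still retains roughly half of it, which is why cosine tends to preserve fine detail better.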
The number of timesteps affects the trade-off between quality and training cost:
- T=1000: Standard for most applications, good quality
- T=500: Faster training, slight quality reduction
- T=2000+: Marginal quality gains, 2x compute cost
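To compare timestep counts on an equal footing, some reference implementations rescale the linear beta range by 1000/T so the total amount of noise injected stays roughly constant. A small sketch (assuming that convention) shows why shorter or longer chains end up with nearly the same terminal signal level:

```python
import math

def final_alpha_bar(T: int, beta_start: float = 1e-4, beta_end: float = 0.02) -> float:
    """Final signal level for a linear schedule whose beta range is
    rescaled by 1000/T, a common convention when changing T."""
    scale = 1000 / T
    log_ab = 0.0
    for i in range(T):
        beta = scale * (beta_start + (beta_end - beta_start) * i / (T - 1))
        log_ab += math.log(1.0 - beta)
    return math.exp(log_ab)

for T in (500, 1000, 2000):
    print(f"T={T}: final alpha_bar ~ {final_alpha_bar(T):.1e}")
```

All three chains end at essentially the same (near-zero) alpha_bar; what changes with T is how finely the denoising trajectory is discretized, and hence training and sampling cost.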
Model Sizing Guidelines
Choosing the right model size depends on your dataset and compute budget:
```python
@dataclass
class TrainingBudget:
    """Estimate training requirements for different configurations."""

    gpu_memory_gb: float
    batch_size: int
    training_time_hours: float  # For ~100 epochs on CIFAR-10


# Approximate requirements (single GPU)
budgets = {
    "tiny": TrainingBudget(gpu_memory_gb=4, batch_size=64, training_time_hours=2),
    "small": TrainingBudget(gpu_memory_gb=8, batch_size=64, training_time_hours=6),
    "medium": TrainingBudget(gpu_memory_gb=16, batch_size=128, training_time_hours=12),
    "large": TrainingBudget(gpu_memory_gb=24, batch_size=64, training_time_hours=24),
    "xl": TrainingBudget(gpu_memory_gb=40, batch_size=32, training_time_hours=48),
}


def recommend_config(
    gpu_memory_gb: float,
    image_size: int,
    dataset_size: int,
) -> UNetConfig:
    """
    Recommend model configuration based on hardware and data.

    Args:
        gpu_memory_gb: Available GPU memory
        image_size: Target image resolution
        dataset_size: Number of training samples

    Returns:
        Recommended UNetConfig
    """
    # Scale base channels by GPU memory
    if gpu_memory_gb >= 40:
        base = 128
    elif gpu_memory_gb >= 16:
        base = 64
    elif gpu_memory_gb >= 8:
        base = 48
    else:
        base = 32

    # Cap base channels for large images (activations dominate memory)
    if image_size >= 256:
        base = min(base, 96)
    elif image_size >= 128:
        base = min(base, 128)

    # Select channel multipliers based on resolution
    if image_size <= 32:
        mults = (1, 2, 4)
    elif image_size <= 64:
        mults = (1, 2, 4, 8)
    elif image_size <= 128:
        mults = (1, 2, 3, 4)
    else:
        mults = (1, 2, 2, 4)

    # Select attention resolutions
    attention_res = get_attention_config(image_size)

    # Small datasets benefit from regularization
    dropout = 0.1 if dataset_size < 50_000 else 0.0

    return UNetConfig(
        image_size=image_size,
        base_channels=base,
        channel_multipliers=mults,
        attention_resolutions=attention_res,
        dropout=dropout,
    )


# Example
config = recommend_config(gpu_memory_gb=16, image_size=64, dataset_size=50000)
print(f"Recommended: {config.base_channels} base channels")
print(f"Channel progression: {config.channel_counts}")
```

Configuration Dataclass
Here's a complete configuration class that combines all settings:
```python
import json
from dataclasses import dataclass
from typing import Tuple

@dataclass
class DiffusionModelConfig:
    """
    Complete configuration for diffusion model training.

    This dataclass captures all hyperparameters needed to:
    - Define the U-Net architecture
    - Configure the diffusion process
    - Set up training parameters
    """

    # ===== Data =====
    image_size: int = 64
    in_channels: int = 3
    dataset: str = "cifar10"

    # ===== U-Net Architecture =====
    base_channels: int = 64
    channel_multipliers: Tuple[int, ...] = (1, 2, 4, 8)
    num_res_blocks: int = 2
    attention_resolutions: Tuple[int, ...] = (16, 8)
    num_attention_heads: int = 4
    dropout: float = 0.0

    # ===== Diffusion Process =====
    timesteps: int = 1000
    noise_schedule: str = "linear"
    beta_start: float = 0.0001
    beta_end: float = 0.02
    prediction_type: str = "epsilon"  # "epsilon", "v", "x0"

    # ===== Training =====
    batch_size: int = 64
    learning_rate: float = 2e-4
    num_epochs: int = 100
    warmup_steps: int = 1000
    gradient_accumulation: int = 1
    use_ema: bool = True
    ema_decay: float = 0.9999

    # ===== Optimization =====
    optimizer: str = "adamw"
    weight_decay: float = 0.01
    use_amp: bool = True

    # ===== Checkpointing =====
    checkpoint_every: int = 10
    sample_every: int = 5
    num_samples: int = 16

    def __post_init__(self):
        """Validate and adjust configuration."""
        # Keep only attention resolutions the model actually reaches
        current_res = self.image_size
        valid_res = []
        for _ in self.channel_multipliers:
            if current_res in self.attention_resolutions:
                valid_res.append(current_res)
            current_res //= 2
        if valid_res != list(self.attention_resolutions):
            print(f"Warning: Adjusted attention_resolutions from "
                  f"{self.attention_resolutions} to {tuple(valid_res)}")
            self.attention_resolutions = tuple(valid_res)

    @property
    def model_channels(self) -> Tuple[int, ...]:
        """Get channel count at each resolution level."""
        return tuple(self.base_channels * m for m in self.channel_multipliers)

    def to_dict(self) -> dict:
        """Convert to dictionary for serialization."""
        return dict(self.__dict__)

    def save(self, path: str):
        """Save configuration to JSON file."""
        with open(path, 'w') as f:
            json.dump(self.to_dict(), f, indent=2)

    @classmethod
    def load(cls, path: str) -> 'DiffusionModelConfig':
        """Load configuration from JSON file."""
        with open(path, 'r') as f:
            data = json.load(f)
        # JSON has no tuple type, so coerce list fields back to tuples
        for key in ("channel_multipliers", "attention_resolutions"):
            if key in data:
                data[key] = tuple(data[key])
        return cls(**data)

    @classmethod
    def cifar10(cls) -> 'DiffusionModelConfig':
        """Standard configuration for CIFAR-10."""
        return cls(
            image_size=32,
            base_channels=64,
            channel_multipliers=(1, 2, 4),
            attention_resolutions=(16, 8),
            batch_size=128,
            num_epochs=200,
        )

    @classmethod
    def celeba64(cls) -> 'DiffusionModelConfig':
        """Standard configuration for CelebA 64x64."""
        return cls(
            image_size=64,
            base_channels=64,
            channel_multipliers=(1, 2, 4, 8),
            attention_resolutions=(16, 8),
            batch_size=64,
            num_epochs=100,
        )

    @classmethod
    def imagenet256(cls) -> 'DiffusionModelConfig':
        """Configuration for ImageNet 256x256."""
        return cls(
            image_size=256,
            base_channels=128,
            channel_multipliers=(1, 2, 3, 4),
            # With four levels the deepest feature map is 32x32, so that
            # is the only resolution where attention can be applied.
            attention_resolutions=(32,),
            num_attention_heads=8,
            batch_size=8,
            num_epochs=400,
            gradient_accumulation=8,
        )


# Usage examples
cifar_config = DiffusionModelConfig.cifar10()
print(f"CIFAR-10 config: {cifar_config.model_channels}")

celeba_config = DiffusionModelConfig.celeba64()
print(f"CelebA config: {celeba_config.model_channels}")
```

Key Takeaways
- Channel multipliers control capacity: Start with base channels of 32-64 for small images, 64-128 for larger. Double or quadruple at lower resolutions.
- Attention is expensive: Only apply it at resolutions 32x32 and below. The attention map grows quadratically in the number of spatial positions, i.e. with the fourth power of resolution.
- Use 1000 timesteps: This is the standard choice. Cosine schedule often outperforms linear for generation quality.
- Match model to compute: A smaller model trained longer often beats a large model trained briefly. Size your model to fit comfortably in GPU memory with room for decent batch sizes.
- Use configuration classes: Centralize all hyperparameters in a dataclass for reproducibility and easy experimentation.
Looking Ahead: In the next section, we'll put together a complete training script with all the components: model, diffusion process, optimizer, logging, and checkpointing.