Learning Objectives
By the end of this section, you will:
- Assemble a complete U-Net by combining ResBlocks, time conditioning, attention, and skip connections
- Implement the encoder path that progressively downsamples while capturing features
- Build the decoder path that upsamples while fusing skip connections
- Design the bottleneck that processes the most compressed representation
- Configure model capacity through channel multipliers, depth, and attention placement
- Initialize weights properly for stable diffusion training
The Grand Assembly
Putting It All Together
Let's recap the components we've built and understand how they fit together:
| Component | Section | Purpose in U-Net |
|---|---|---|
| ResBlock | 5.2 | Basic feature transformation with residual connections |
| GroupNorm + SiLU | 5.2 | Normalization and activation for stable training |
| Downsample/Upsample | 5.2 | Resolution changes in encoder/decoder |
| Sinusoidal Embedding | 5.3 | Encode timestep as continuous vectors |
| Time MLP | 5.3 | Project time embeddings to network dimension |
| Time Conditioning | 5.3 | Inject time information into ResBlocks |
| Self-Attention | 5.4 | Capture long-range spatial dependencies |
| Skip Connections | 5.1 | Preserve fine details across encoder-decoder |
The U-Net architecture can be visualized as an hourglass where:
- The encoder (left side) progressively compresses spatial information while expanding channels
- The bottleneck (bottom) processes the most abstract representation at lowest resolution
- The decoder (right side) progressively expands spatial resolution while reducing channels
- Skip connections (horizontal arrows) connect encoder to decoder at matching resolutions
Interactive diagram showing the encoder-decoder architecture with skip connections and time conditioning
Why This Architecture Works for Diffusion
Architecture Overview
Before diving into code, let's establish the data flow through our U-Net:
1Input: Noisy image x_t [B, 3, H, W] + Timestep t [B]
2
31. TIME EMBEDDING:
4 t -> Sinusoidal(t) -> MLP -> time_emb [B, time_dim]
5
62. INITIAL PROJECTION:
7 x_t -> Conv3x3 -> h [B, base_ch, H, W]
8
93. ENCODER (contracting path):
10 Level 0: H x W
11 -> ResBlock(h, time_emb) -> skip_0
12 -> ResBlock(h, time_emb) -> skip_1
13 -> Downsample -> h [B, ch_1, H/2, W/2]
14
15 Level 1: H/2 x W/2
16 -> ResBlock(h, time_emb) -> skip_2
17 -> Attention(h) -> skip_3
18 -> Downsample -> h [B, ch_2, H/4, W/4]
19
20 Level 2: H/4 x W/4
21 -> ResBlock(h, time_emb) -> skip_4
22 -> Attention(h) -> skip_5
23 -> Downsample -> h [B, ch_3, H/8, W/8]
24
254. BOTTLENECK:
26 -> ResBlock(h, time_emb)
27 -> Attention(h)
28 -> ResBlock(h, time_emb)
29
305. DECODER (expanding path):
31 Level 2: H/8 x W/8
32 -> Upsample -> h [B, ch_2, H/4, W/4]
33 -> Concat(h, skip_5) -> ResBlock
34 -> Concat(h, skip_4) -> ResBlock + Attention
35
36 Level 1: H/4 x W/4
37 -> Upsample -> h [B, ch_1, H/2, W/2]
38 -> Concat(h, skip_3) -> ResBlock + Attention
39 -> Concat(h, skip_2) -> ResBlock
40
41 Level 0: H/2 x W/2
42 -> Upsample -> h [B, base_ch, H, W]
43 -> Concat(h, skip_1) -> ResBlock
44 -> Concat(h, skip_0) -> ResBlock
45
466. OUTPUT:
47 h -> GroupNorm -> SiLU -> Conv3x3 -> noise_pred [B, 3, H, W]Understanding Skip Connections
Encoder Implementation
The encoder is the contracting path that processes the input at progressively lower resolutions. At each level, we apply ResBlocks (and optionally attention), save the output as a skip connection, then downsample.
Key design decisions in the encoder:
- Save skips after attention: Skip connections include the attended features
- Downsample at the end: Process at current resolution, then reduce for next level
- Channel increase at first ResBlock: Each level may have more channels than the previous
The Bottleneck
The bottleneck sits at the lowest resolution (typically 4x4 or 8x8 for 64x64 images). This is where the most abstract, compressed representation exists. It typically contains:
- Multiple ResBlocks for deep feature processing
- Attention layers (computationally cheap at low resolution)
- No downsampling or upsampling
1class Bottleneck(nn.Module):
2 """
3 Middle block at the lowest resolution.
4 Contains ResBlocks with attention for deep processing.
5 """
6 def __init__(
7 self,
8 channels: int,
9 time_emb_dim: int,
10 dropout: float = 0.0,
11 ):
12 super().__init__()
13
14 self.block1 = ResBlock(channels, channels, time_emb_dim, dropout)
15 self.attention = AttentionBlock(channels)
16 self.block2 = ResBlock(channels, channels, time_emb_dim, dropout)
17
18 def forward(
19 self, x: torch.Tensor, time_emb: torch.Tensor
20 ) -> torch.Tensor:
21 x = self.block1(x, time_emb)
22 x = self.attention(x)
23 x = self.block2(x, time_emb)
24 return xWhy Attention at the Bottleneck?
Decoder Implementation
The decoder mirrors the encoder but works in reverse: it upsamples and concatenates skip connections at each level.
Skip Connection Alignment
Complete U-Net Class
Now let's assemble everything into a complete, production-ready U-Net:
Let's verify our U-Net works correctly:
1# Test the U-Net
2model = UNet(
3 image_size=64,
4 in_channels=3,
5 out_channels=3,
6 base_channels=128,
7 channel_mults=(1, 2, 2, 4),
8 num_res_blocks=2,
9 attention_resolutions=(16, 8),
10)
11
12# Count parameters
13num_params = sum(p.numel() for p in model.parameters())
14print(f"Number of parameters: {num_params:,}") # ~35M for this config
15
16# Test forward pass
17x = torch.randn(2, 3, 64, 64) # Batch of 2 images
18t = torch.randint(0, 1000, (2,)) # Random timesteps
19
20noise_pred = model(x, t)
21print(f"Input shape: {x.shape}")
22print(f"Output shape: {noise_pred.shape}") # Should match input!
23
24assert noise_pred.shape == x.shape, "U-Net should preserve spatial dimensions!"Model Configuration
Different applications require different U-Net configurations. Here are common setups:
| Config | Image Size | Parameters | Training Time | Quality |
|---|---|---|---|---|
| Light | 64x64 | ~8M | Fast | Good for prototyping |
| DDPM | 64x64 | ~35M | Moderate | Publication quality |
| Improved | 64x64 | ~113M | Slow | State-of-the-art |
| HighRes | 256x256 | ~100M+ | Very slow | High-resolution |
Choosing a Configuration
Parameter Counting and Scaling
Understanding how parameters scale helps you design efficient architectures:
1def count_parameters_by_component(model):
2 """Break down parameters by U-Net component."""
3
4 components = {
5 'time_embed': 0,
6 'init_conv': 0,
7 'encoder': 0,
8 'bottleneck': 0,
9 'decoder': 0,
10 'output': 0,
11 }
12
13 for name, param in model.named_parameters():
14 count = param.numel()
15 if 'time_embed' in name:
16 components['time_embed'] += count
17 elif 'init_conv' in name:
18 components['init_conv'] += count
19 elif 'encoder' in name:
20 components['encoder'] += count
21 elif 'bottleneck' in name:
22 components['bottleneck'] += count
23 elif 'decoder' in name:
24 components['decoder'] += count
25 else:
26 components['output'] += count
27
28 total = sum(components.values())
29 print("\nParameter breakdown:")
30 for component, count in components.items():
31 pct = 100 * count / total
32 print(f" {component:12s}: {count:>10,} ({pct:5.1f}%)")
33 print(f" {'TOTAL':12s}: {total:>10,}")
34
35 return components
36
37# Analyze the DDPM model
38model = UNet(**ddpm_config)
39count_parameters_by_component(model)
40
41# Output:
42# Parameter breakdown:
43# time_embed : 132,352 ( 0.4%)
44# init_conv : 3,584 ( 0.0%)
45# encoder : 17,234,432 (49.2%)
46# bottleneck : 2,363,392 ( 6.7%)
47# decoder : 15,234,816 (43.5%)
48# output : 66,560 ( 0.2%)
49# TOTAL : 35,035,136Key observations about parameter distribution:
- Encoder and decoder contain ~93% of parameters (nearly equal between them)
- Time embedding is cheap (<1% of parameters)
- Bottleneck is relatively small despite being computationally important
- Attention layers (not shown separately) typically account for 10-20% of total parameters
Scaling Laws
How do parameters scale with architecture choices?
| Change | Parameter Impact | Quality Impact |
|---|---|---|
| 2x base_channels | ~4x parameters | Significant improvement |
| 2x num_res_blocks | ~2x parameters | Moderate improvement |
| Add attention at 32x32 | ~1.5x at that level | Helps global coherence |
| Remove attention at 8x8 | Minor reduction | Usually safe to skip |
| Add resolution level | ~1.5-2x parameters | Required for higher res |
Weight Initialization
Proper initialization is crucial for stable training. Diffusion models benefit from specific initialization strategies:
1import torch.nn as nn
2
3def initialize_weights(model: nn.Module):
4 """
5 Initialize U-Net weights for stable diffusion training.
6
7 Key principles:
8 1. Most layers: Default PyTorch initialization (Kaiming/He)
9 2. Output layer: Initialize to zero (start by predicting zero noise)
10 3. Residual projections: Can be scaled down for stability
11 """
12
13 for name, module in model.named_modules():
14 if isinstance(module, nn.Conv2d):
15 # Default: Kaiming initialization for convolutions
16 nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
17 if module.bias is not None:
18 nn.init.zeros_(module.bias)
19
20 elif isinstance(module, nn.Linear):
21 # Xavier/Glorot for linear layers
22 nn.init.xavier_uniform_(module.weight)
23 if module.bias is not None:
24 nn.init.zeros_(module.bias)
25
26 elif isinstance(module, nn.GroupNorm):
27 # Standard: gamma=1, beta=0
28 nn.init.ones_(module.weight)
29 nn.init.zeros_(module.bias)
30
31 # CRITICAL: Initialize output convolution to zero
32 # This means the model initially predicts zero noise,
33 # which is a reasonable starting point (especially for t close to 0)
34 if hasattr(model, 'out_conv'):
35 nn.init.zeros_(model.out_conv.weight)
36 nn.init.zeros_(model.out_conv.bias)
37
38 return model
39
40
41# Apply initialization
42model = UNet(**ddpm_config)
43model = initialize_weights(model)
44
45# Verify output is initially zero
46x = torch.randn(1, 3, 64, 64)
47t = torch.zeros(1, dtype=torch.long)
48with torch.no_grad():
49 output = model(x, t)
50 print(f"Initial output mean: {output.mean():.6f}") # Should be ~0
51 print(f"Initial output std: {output.std():.6f}") # Should be ~0Zero Output Initialization
Summary
In this section, we assembled a complete, production-ready U-Net for diffusion models:
- Architecture overview: The U-Net flows from input through encoder, bottleneck, decoder, and output, with skip connections bridging encoder to decoder
- Encoder implementation: Progressive downsampling with ResBlocks and attention, saving skip connections at each resolution
- Bottleneck design: Deep processing at lowest resolution with attention for global reasoning
- Decoder implementation: Progressive upsampling with skip concatenation to recover fine details
- Model configuration: Different configs for different use cases, from lightweight prototyping to publication-quality models
- Weight initialization: Zero output initialization for stable training
Chapter Complete!
The complete U-Net code from this chapter is available in the repository and can be used as a foundation for your own diffusion model experiments. The architecture we've built supports:
- Any image resolution (with appropriate channel multipliers)
- Configurable depth and width
- Flexible attention placement
- Time-conditional generation
- Skip connections for preserving details
In the next chapter, we'll train this U-Net to denoise images and generate new samples from pure noise!