Chapter 5

Complete U-Net Implementation

U-Net Architecture for Diffusion

Learning Objectives

By the end of this section, you will:

  1. Assemble a complete U-Net by combining ResBlocks, time conditioning, attention, and skip connections
  2. Implement the encoder path that progressively downsamples while capturing features
  3. Build the decoder path that upsamples while fusing skip connections
  4. Design the bottleneck that processes the most compressed representation
  5. Configure model capacity through channel multipliers, depth, and attention placement
  6. Initialize weights properly for stable diffusion training

The Grand Assembly

This section brings together everything from Chapter 5. We've learned about skip connections, ResBlocks, time conditioning, and attention individually. Now we assemble them into a production-ready U-Net that can serve as the backbone for any diffusion model. The code here is what you'll actually use to train models on real datasets.

Putting It All Together

Let's recap the components we've built and understand how they fit together:

| Component | Section | Purpose in U-Net |
|---|---|---|
| ResBlock | 5.2 | Basic feature transformation with residual connections |
| GroupNorm + SiLU | 5.2 | Normalization and activation for stable training |
| Downsample/Upsample | 5.2 | Resolution changes in encoder/decoder |
| Sinusoidal Embedding | 5.3 | Encode timestep as continuous vectors |
| Time MLP | 5.3 | Project time embeddings to network dimension |
| Time Conditioning | 5.3 | Inject time information into ResBlocks |
| Self-Attention | 5.4 | Capture long-range spatial dependencies |
| Skip Connections | 5.1 | Preserve fine details across encoder-decoder |
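The code in this section calls ResBlock, AttentionBlock, Downsample, and Upsample without redefining them. To run it standalone, you can use the minimal sketch below. It assumes the constructor signatures used later in this section (for example `ResBlock(in_channels, out_channels, time_emb_dim, dropout)`) and injects time as a per-channel bias after the first convolution, which is one common choice. The versions built in Sections 5.2-5.4 may differ in details such as scale-and-shift conditioning or multi-head attention.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ResBlock(nn.Module):
    """Sketch: GroupNorm -> SiLU -> Conv, plus a time bias, with a residual path."""

    def __init__(self, in_channels, out_channels, time_emb_dim, dropout=0.0):
        super().__init__()
        self.norm1 = nn.GroupNorm(32, in_channels)
        self.conv1 = nn.Conv2d(in_channels, out_channels, 3, padding=1)
        # Assumed conditioning: project time_emb to one bias per output channel
        self.time_proj = nn.Linear(time_emb_dim, out_channels)
        self.norm2 = nn.GroupNorm(32, out_channels)
        self.dropout = nn.Dropout(dropout)
        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, padding=1)
        # 1x1 conv matches channels on the residual path when they differ
        self.skip = (
            nn.Conv2d(in_channels, out_channels, 1)
            if in_channels != out_channels
            else nn.Identity()
        )

    def forward(self, x, time_emb):
        h = self.conv1(F.silu(self.norm1(x)))
        h = h + self.time_proj(F.silu(time_emb))[:, :, None, None]
        h = self.conv2(self.dropout(F.silu(self.norm2(h))))
        return h + self.skip(x)


class AttentionBlock(nn.Module):
    """Sketch: single-head self-attention over all spatial positions."""

    def __init__(self, channels):
        super().__init__()
        self.norm = nn.GroupNorm(32, channels)
        self.qkv = nn.Conv2d(channels, 3 * channels, 1)
        self.proj = nn.Conv2d(channels, channels, 1)

    def forward(self, x):
        b, c, height, width = x.shape
        q, k, v = self.qkv(self.norm(x)).reshape(b, 3, c, height * width).unbind(dim=1)
        # [B, HW, HW] weights: row i says how much position i attends to each position
        attn = torch.softmax(q.transpose(1, 2) @ k / c**0.5, dim=-1)
        out = (v @ attn.transpose(1, 2)).reshape(b, c, height, width)
        return x + self.proj(out)


class Downsample(nn.Module):
    """Halve H and W with a stride-2 3x3 convolution."""

    def __init__(self, channels):
        super().__init__()
        self.conv = nn.Conv2d(channels, channels, 3, stride=2, padding=1)

    def forward(self, x):
        return self.conv(x)


class Upsample(nn.Module):
    """Double H and W with nearest-neighbour interpolation, then a 3x3 conv."""

    def __init__(self, channels):
        super().__init__()
        self.conv = nn.Conv2d(channels, channels, 3, padding=1)

    def forward(self, x):
        return self.conv(F.interpolate(x, scale_factor=2, mode="nearest"))


x = torch.randn(2, 64, 16, 16)   # features at 16x16
time_emb = torch.randn(2, 256)   # output of the time MLP
h = ResBlock(64, 128, time_emb_dim=256)(x, time_emb)
h = AttentionBlock(128)(h)
print(Downsample(128)(h).shape)  # torch.Size([2, 128, 8, 8])
print(Upsample(128)(h).shape)    # torch.Size([2, 128, 32, 32])
```

GroupNorm with 32 groups means every channel count passed to these blocks must be divisible by 32.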

The U-Net architecture can be visualized as an hourglass where:

  • The encoder (left side) progressively compresses spatial information while expanding channels
  • The bottleneck (bottom) processes the most abstract representation at lowest resolution
  • The decoder (right side) progressively expands spatial resolution while reducing channels
  • Skip connections (horizontal arrows) connect encoder to decoder at matching resolutions
Figure: U-Net Architecture for Diffusion Models. The noisy input x_t [B, 3, H, W] passes through ConvIn, three downsampling stages (Down1: 64 to 128 channels at H/2 x W/2; Down2: 128 to 256 at H/4 x W/4; Down3: 256 to 512 at H/8 x W/8), a 512-channel bottleneck, and three upsampling stages (Up1: 512 to 256; Up2: 256 to 128; Up3: 128 to 64) before ConvOut outputs the predicted noise epsilon. Skip connections join encoder and decoder stages at matching resolutions, and a time embedding conditions the network. Legend: Encoder (Downsampling), Bottleneck, Decoder (Upsampling), Skip Connections.
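The hourglass shapes can be tabulated with a small illustrative helper, `trace_shapes` (not part of the U-Net code). It follows the convention of the code later in this section: each encoder level except the last halves the resolution, and the bottleneck keeps the deepest level's shape. With `base_channels=64` and `channel_mults=(1, 2, 4, 8)` on a 64x64 input, it reproduces the 64-to-512 channel progression and the H/8 bottleneck of the diagram.

```python
def trace_shapes(image_size, base_channels, channel_mults):
    """Return (resolution, channels) per encoder level, the bottleneck, and the decoder levels."""
    encoder = []
    res = image_size
    for level, mult in enumerate(channel_mults):
        encoder.append((res, base_channels * mult))
        if level < len(channel_mults) - 1:
            res //= 2  # every level except the last halves H and W
    bottleneck = encoder[-1]  # same resolution and width as the deepest level
    decoder = encoder[::-1]   # the decoder revisits the same shapes in reverse
    return encoder, bottleneck, decoder


enc, mid, dec = trace_shapes(image_size=64, base_channels=64, channel_mults=(1, 2, 4, 8))
print(enc)  # [(64, 64), (32, 128), (16, 256), (8, 512)]
print(mid)  # (8, 512)
print(dec)  # [(8, 512), (16, 256), (32, 128), (64, 64)]
```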

Why This Architecture Works for Diffusion

Diffusion models need to predict noise at every pixel while understanding global image structure. The multi-scale processing captures both local textures (high resolution) and global composition (low resolution). Skip connections ensure the fine details from the encoder directly inform the decoder's output, which is critical for predicting pixel-accurate noise.

Architecture Overview

Before diving into code, let's establish the data flow through our U-Net:

```text
Input: Noisy image x_t [B, 3, H, W] + Timestep t [B]

1. TIME EMBEDDING:
   t -> Sinusoidal(t) -> MLP -> time_emb [B, time_dim]

2. INITIAL PROJECTION:
   x_t -> Conv3x3 -> h [B, base_ch, H, W]

3. ENCODER (contracting path):
   Level 0: H x W
     -> ResBlock(h, time_emb) -> skip_0
     -> ResBlock(h, time_emb) -> skip_1
     -> Downsample -> h [B, ch_1, H/2, W/2]

   Level 1: H/2 x W/2
     -> ResBlock(h, time_emb) -> Attention(h) -> skip_2
     -> ResBlock(h, time_emb) -> Attention(h) -> skip_3
     -> Downsample -> h [B, ch_2, H/4, W/4]

   Level 2: H/4 x W/4
     -> ResBlock(h, time_emb) -> Attention(h) -> skip_4
     -> ResBlock(h, time_emb) -> Attention(h) -> skip_5
     -> Downsample -> h [B, ch_3, H/8, W/8]

4. BOTTLENECK:
   -> ResBlock(h, time_emb)
   -> Attention(h)
   -> ResBlock(h, time_emb)

5. DECODER (expanding path):
   Level 2: H/8 x W/8
     -> Upsample -> h [B, ch_2, H/4, W/4]
     -> Concat(h, skip_5) -> ResBlock -> Attention
     -> Concat(h, skip_4) -> ResBlock -> Attention

   Level 1: H/4 x W/4
     -> Upsample -> h [B, ch_1, H/2, W/2]
     -> Concat(h, skip_3) -> ResBlock -> Attention
     -> Concat(h, skip_2) -> ResBlock -> Attention

   Level 0: H/2 x W/2
     -> Upsample -> h [B, base_ch, H, W]
     -> Concat(h, skip_1) -> ResBlock
     -> Concat(h, skip_0) -> ResBlock

6. OUTPUT:
   h -> GroupNorm -> SiLU -> Conv3x3 -> noise_pred [B, 3, H, W]
```

Understanding Skip Connections

Skip connections are stored during encoding and consumed during decoding in reverse order. The last skip saved (lowest resolution) is the first one used. This ensures each decoder level receives features from the corresponding encoder level at the same spatial resolution.
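A plain Python list used as a stack gives exactly this last-in, first-out order: append while encoding, pop while decoding. In this toy illustration, string labels stand in for feature maps:

```python
# Encoder: push skips from high to low resolution
skips = []
for res in (64, 32, 16, 8):
    skips.append(f"skip_{res}x{res}")

# Decoder: pop() returns the most recently pushed (lowest-resolution) skip first
used = [skips.pop() for _ in range(4)]
print(used)   # ['skip_8x8', 'skip_16x16', 'skip_32x32', 'skip_64x64']
print(skips)  # [] (every skip is consumed exactly once)
```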

Encoder Implementation

The encoder is the contracting path that processes the input at progressively lower resolutions. At each level, we apply ResBlocks (and optionally attention), save the output as a skip connection, then downsample.

Encoder: The Contracting Path
  • Encoder Module: The encoder is the contracting path of the U-Net. It progressively reduces spatial resolution while increasing channel count to learn hierarchical features.
  • Initial Projection: Before the first encoder level, a 3x3 convolution (init_conv in the UNet class) projects the input image from 3 RGB channels to the base channel count (e.g., 128), lifting the image into feature space.
  • Building Encoder Levels: We iterate through each resolution level, adding ResBlocks and optionally attention. Each level processes features at its spatial resolution before downsampling.
  • ResBlocks per Level: Typically 2-3 ResBlocks per resolution level. More blocks add capacity but also computation. Only the first ResBlock of a level changes the channel count.
  • Optional Attention: Attention is typically added only at lower resolutions (e.g., 16x16, 8x8), where it is computationally feasible and most useful for global context.
  • Downsampling: After processing each level (except the last), we downsample by 2x. This creates the contracting structure characteristic of U-Net.

```python
# encoder.py
import torch
import torch.nn as nn
from typing import List, Tuple

# ResBlock and Downsample (Section 5.2) and AttentionBlock (Section 5.4)
# are assumed to be importable from the earlier sections.


class EncoderBlock(nn.Module):
    """
    Single encoder level: ResBlocks + optional Attention + Downsample.
    """
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        time_emb_dim: int,
        num_res_blocks: int = 2,
        use_attention: bool = False,
        downsample: bool = True,
        dropout: float = 0.0,
    ):
        super().__init__()

        self.res_blocks = nn.ModuleList()
        self.attention_blocks = nn.ModuleList()

        # First ResBlock may change channels
        self.res_blocks.append(
            ResBlock(in_channels, out_channels, time_emb_dim, dropout)
        )
        if use_attention:
            self.attention_blocks.append(AttentionBlock(out_channels))
        else:
            self.attention_blocks.append(nn.Identity())

        # Subsequent ResBlocks maintain channels
        for _ in range(num_res_blocks - 1):
            self.res_blocks.append(
                ResBlock(out_channels, out_channels, time_emb_dim, dropout)
            )
            if use_attention:
                self.attention_blocks.append(AttentionBlock(out_channels))
            else:
                self.attention_blocks.append(nn.Identity())

        # Downsampling (skipped at the deepest level)
        self.downsample = Downsample(out_channels) if downsample else nn.Identity()

    def forward(
        self, x: torch.Tensor, time_emb: torch.Tensor
    ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        """
        Returns:
            x: Downsampled features
            skips: List of intermediate features for skip connections
        """
        skips = []

        for res_block, attn in zip(self.res_blocks, self.attention_blocks):
            x = res_block(x, time_emb)
            x = attn(x)
            skips.append(x)  # Save before downsampling

        x = self.downsample(x)
        return x, skips
```

Key design decisions in the encoder:

  • Save skips after attention: Skip connections include the attended features
  • Downsample at the end: Process at current resolution, then reduce for next level
  • Channel increase at first ResBlock: Each level may have more channels than the previous
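To see these three decisions in isolation, here is a toy level in which plain convolutions stand in for ResBlocks and attention is omitted. `ToyEncoderLevel` is a hypothetical name, not part of the U-Net code. Only the first conv changes channels, every skip is saved at the level's full resolution, and the single downsample comes last.

```python
import torch
import torch.nn as nn


class ToyEncoderLevel(nn.Module):
    """Stand-in for EncoderBlock: 3x3 convs replace ResBlocks, no attention."""

    def __init__(self, in_ch, out_ch, num_blocks=2):
        super().__init__()
        # Only the first conv changes the channel count
        chans = [in_ch] + [out_ch] * num_blocks
        self.blocks = nn.ModuleList(
            [nn.Conv2d(c_in, c_out, 3, padding=1) for c_in, c_out in zip(chans[:-1], chans[1:])]
        )
        self.down = nn.Conv2d(out_ch, out_ch, 3, stride=2, padding=1)

    def forward(self, x):
        skips = []
        for block in self.blocks:
            x = torch.relu(block(x))
            skips.append(x)  # saved at the current (full) resolution
        return self.down(x), skips  # downsample only after all skips are saved


level = ToyEncoderLevel(64, 128)
out, skips = level(torch.randn(1, 64, 32, 32))
print(out.shape)                   # torch.Size([1, 128, 16, 16])
print([s.shape for s in skips])    # two skips of torch.Size([1, 128, 32, 32])
```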

The Bottleneck

The bottleneck sits at the lowest resolution (typically 4x4 or 8x8 for 64x64 images). This is where the most abstract, compressed representation exists. It typically contains:

  • Multiple ResBlocks for deep feature processing
  • Attention layers (computationally cheap at low resolution)
  • No downsampling or upsampling
```python
import torch
import torch.nn as nn

# ResBlock (Section 5.2) and AttentionBlock (Section 5.4) are assumed importable.


class Bottleneck(nn.Module):
    """
    Middle block at the lowest resolution.
    Contains ResBlocks with attention for deep processing.
    """
    def __init__(
        self,
        channels: int,
        time_emb_dim: int,
        dropout: float = 0.0,
    ):
        super().__init__()

        self.block1 = ResBlock(channels, channels, time_emb_dim, dropout)
        self.attention = AttentionBlock(channels)
        self.block2 = ResBlock(channels, channels, time_emb_dim, dropout)

    def forward(
        self, x: torch.Tensor, time_emb: torch.Tensor
    ) -> torch.Tensor:
        x = self.block1(x, time_emb)
        x = self.attention(x)
        x = self.block2(x, time_emb)
        return x
```

Why Attention at the Bottleneck?

At 4x4 or 8x8 resolution, self-attention is cheap (16 or 64 tokens vs 4096 at 64x64). This is the ideal place for the network to reason about global image structure, understanding relationships between distant parts of the image that are now represented as nearby spatial locations in the compressed feature map.
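The token counts follow from treating every spatial position as a token: an H x W map has H*W tokens, and the attention weight matrix QK^T has (H*W)^2 entries per head per image. A short calculation with an illustrative helper, `attention_cost`, shows how quickly that grows:

```python
def attention_cost(resolution: int) -> tuple:
    """Tokens and attention-matrix entries for self-attention over a resolution x resolution map."""
    tokens = resolution * resolution  # one token per spatial position
    return tokens, tokens * tokens    # QK^T is tokens x tokens


for res in (4, 8, 16, 32, 64):
    tokens, entries = attention_cost(res)
    print(f"{res}x{res}: {tokens:,} tokens, {entries:,} attention weights")
```

At 64x64 that is 16,777,216 weights per head per image (64 MiB in float32), 4,096 times the 4,096 weights needed at 8x8.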

Decoder Implementation

The decoder mirrors the encoder but works in reverse: it upsamples and concatenates skip connections at each level.

Decoder: The Expanding Path
  • Decoder Module: The decoder is the expanding path. It progressively upsamples while concatenating skip connections from the encoder to recover fine details.
  • Skip Connection Handling: Skip connections are crucial for diffusion quality. The decoder receives feature maps from the corresponding encoder level and concatenates them before processing.
  • Channel Growth from Skips: Concatenating a skip adds its channels to the current features (doubling the count when the two match), so the first convolution of each decoder ResBlock is sized for the combined channel count.
  • ResBlocks with Concatenated Features: After concatenation, ResBlocks process the combined features. This allows the decoder to leverage both high-level (from the bottleneck) and low-level (from the encoder) information.
  • Upsampling: After processing each level (except the last), we upsample by 2x to increase spatial resolution. This continues until we reach the original image size.

```python
# decoder.py
import torch
import torch.nn as nn
from typing import List

# ResBlock and Upsample (Section 5.2) and AttentionBlock (Section 5.4)
# are assumed to be importable from the earlier sections.


class DecoderBlock(nn.Module):
    """
    Single decoder level: (Concat skip + ResBlock + optional Attention) per block,
    then Upsample. This mirrors EncoderBlock, which saves its skips before
    downsampling, so every skip popped here matches x's spatial size.
    """
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        skip_channels: int,  # Channels of each encoder skip at this level
        time_emb_dim: int,
        num_res_blocks: int = 2,
        use_attention: bool = False,
        upsample: bool = True,
        dropout: float = 0.0,
    ):
        super().__init__()

        self.res_blocks = nn.ModuleList()
        self.attention_blocks = nn.ModuleList()

        # Each ResBlock receives concatenated features
        # First block: in_channels + skip_channels -> out_channels
        self.res_blocks.append(
            ResBlock(in_channels + skip_channels, out_channels, time_emb_dim, dropout)
        )
        if use_attention:
            self.attention_blocks.append(AttentionBlock(out_channels))
        else:
            self.attention_blocks.append(nn.Identity())

        # Subsequent blocks: out_channels + skip_channels -> out_channels
        for _ in range(num_res_blocks - 1):
            self.res_blocks.append(
                ResBlock(out_channels + skip_channels, out_channels, time_emb_dim, dropout)
            )
            if use_attention:
                self.attention_blocks.append(AttentionBlock(out_channels))
            else:
                self.attention_blocks.append(nn.Identity())

        # Upsample last, after all skips at this resolution have been consumed
        self.upsample = Upsample(out_channels) if upsample else nn.Identity()

    def forward(
        self,
        x: torch.Tensor,
        skips: List[torch.Tensor],
        time_emb: torch.Tensor,
    ) -> torch.Tensor:
        """
        Args:
            x: Features from previous decoder level (or bottleneck)
            skips: Shared list of encoder skips; pop() removes the most
                recently saved (lowest-resolution) one first
            time_emb: Time embedding
        """
        for res_block, attn in zip(self.res_blocks, self.attention_blocks):
            skip = skips.pop()  # Consume skips in reverse order
            x = torch.cat([x, skip], dim=1)  # Concatenate along channels
            x = res_block(x, time_emb)
            x = attn(x)

        return self.upsample(x)
```

Skip Connection Alignment

The decoder must receive skip connections in the reverse order they were saved. If the encoder saved [skip_32x32, skip_16x16, skip_8x8], the decoder needs them as [skip_8x8, skip_16x16, skip_32x32]. This is typically handled by reversing the list or using pop() to consume from the end.
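A shape-only sketch makes the alignment concrete. It assumes one skip per level and a decoder that upsamples after processing each level, and it uses a freshly created 1x1 convolution as a stand-in for the ResBlocks that reduce channels. All tensors are random placeholders.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
# Placeholder encoder skips, one per level, saved from high to low resolution
skips = [
    torch.randn(1, 128, 32, 32),
    torch.randn(1, 256, 16, 16),
    torch.randn(1, 512, 8, 8),
]
decoder_out_channels = [512, 256, 128]  # channels each decoder level produces

h = torch.randn(1, 512, 8, 8)  # placeholder bottleneck output
concat_shapes = []
for out_ch in decoder_out_channels:
    skip = skips.pop()  # lowest remaining resolution comes off the end first
    assert skip.shape[-2:] == h.shape[-2:], "skip and features must share H, W"
    h = torch.cat([h, skip], dim=1)  # channel count becomes h + skip
    concat_shapes.append(tuple(h.shape))
    # Stand-in for the level's ResBlocks: a fresh 1x1 conv back to out_ch
    h = nn.Conv2d(h.shape[1], out_ch, kernel_size=1)(h)
    if skips:  # upsample after every level except the last
        h = F.interpolate(h, scale_factor=2, mode="nearest")

print(concat_shapes)   # [(1, 1024, 8, 8), (1, 768, 16, 16), (1, 384, 32, 32)]
print(tuple(h.shape))  # (1, 128, 32, 32)
```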

Complete U-Net Class

Now let's assemble everything into a complete, production-ready U-Net:

Complete Diffusion U-Net
  • U-Net Class Definition: The main U-Net class orchestrates all components: time embedding, encoder, bottleneck, decoder, and output projection. It manages skip connections internally.
  • Configuration Parameters: Key architecture choices are image size, base channels, channel multipliers per level, number of ResBlocks, attention resolutions, and dropout rate.
  • Time Embedding Network: Sinusoidal embeddings are created and projected through an MLP. This creates a time-dependent modulation vector used throughout the network.
  • Initial Convolution: Projects the 3-channel RGB input to the base channel dimension. This is where the image enters the U-Net feature space.
  • Building the Encoder: We construct encoder levels programmatically, recording the skip channel count of each level for the decoder.
  • Bottleneck: The bottleneck at the lowest resolution contains ResBlocks and typically attention. This is where the most compressed representation exists.
  • Building the Decoder: The decoder mirrors the encoder in reverse order. Each decoder level receives the skip connections of the matching encoder level.
  • Output Projection: A final GroupNorm, SiLU, and convolution project back to out_channels (3 for RGB). This produces the predicted noise.
  • Forward Pass: Embed time, apply the initial conv, encode (saving skips), run the bottleneck, decode (using skips), and apply the output projection.
  • Encoder with Skip Storage: During encoding, we save intermediate features as skip connections. They are stored in a list and later popped in reverse order by the decoder.
  • Decoder with Skip Retrieval: The decoder pops skip connections and concatenates them with its current features. Because pop() takes from the end of the list, skips are used in reverse order, so resolutions match.

```python
# unet.py
import math
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

# EncoderBlock, Bottleneck, and DecoderBlock are the classes defined earlier in
# this section; they rely on ResBlock, AttentionBlock, Downsample, and Upsample
# from Sections 5.2-5.4.


class UNet(nn.Module):
    """
    U-Net architecture for diffusion models.

    Combines:
    - Sinusoidal time embeddings with MLP projection
    - Encoder with ResBlocks, optional attention, and downsampling
    - Bottleneck with attention
    - Decoder with skip connections, ResBlocks, and upsampling
    - Output projection to noise prediction
    """

    def __init__(
        self,
        image_size: int = 64,
        in_channels: int = 3,
        out_channels: int = 3,
        base_channels: int = 128,
        channel_mults: tuple = (1, 2, 3, 4),
        num_res_blocks: int = 2,
        attention_resolutions: tuple = (16, 8),
        dropout: float = 0.0,
        time_emb_dim: Optional[int] = None,
    ):
        super().__init__()

        self.image_size = image_size
        time_emb_dim = time_emb_dim or base_channels * 4

        # === Time Embedding ===
        self.time_embed = nn.Sequential(
            SinusoidalPositionEmbeddings(base_channels),
            nn.Linear(base_channels, time_emb_dim),
            nn.SiLU(),
            nn.Linear(time_emb_dim, time_emb_dim),
        )

        # === Initial Convolution ===
        self.init_conv = nn.Conv2d(in_channels, base_channels, 3, padding=1)

        # === Encoder ===
        self.encoder_blocks = nn.ModuleList()
        self.encoder_channels = []  # Skip channels per level, consumed by the decoder

        ch = base_channels
        current_res = image_size

        for level, mult in enumerate(channel_mults):
            out_ch = base_channels * mult
            is_last = level == len(channel_mults) - 1

            block = EncoderBlock(
                in_channels=ch,
                out_channels=out_ch,
                time_emb_dim=time_emb_dim,
                num_res_blocks=num_res_blocks,
                use_attention=current_res in attention_resolutions,
                downsample=not is_last,
                dropout=dropout,
            )
            self.encoder_blocks.append(block)

            # All num_res_blocks skips saved at this level have out_ch channels,
            # so one entry per level matches the single skip_channels argument
            # that each DecoderBlock takes
            self.encoder_channels.append(out_ch)

            ch = out_ch
            if not is_last:
                current_res //= 2

        # === Bottleneck (runs at the deepest encoder level's resolution) ===
        self.bottleneck = Bottleneck(ch, time_emb_dim, dropout)

        # === Decoder ===
        self.decoder_blocks = nn.ModuleList()

        for level, mult in reversed(list(enumerate(channel_mults))):
            out_ch = base_channels * mult
            is_last = level == 0

            block = DecoderBlock(
                in_channels=ch,
                out_channels=out_ch,
                skip_channels=self.encoder_channels.pop(),
                time_emb_dim=time_emb_dim,
                num_res_blocks=num_res_blocks,
                use_attention=current_res in attention_resolutions,
                upsample=not is_last,
                dropout=dropout,
            )
            self.decoder_blocks.append(block)

            ch = out_ch
            if not is_last:
                current_res *= 2

        # === Output ===
        # GroupNorm(32, ch) requires ch to be divisible by 32
        self.out_norm = nn.GroupNorm(32, ch)
        self.out_conv = nn.Conv2d(ch, out_channels, 3, padding=1)

        # Initialize output to zero
        nn.init.zeros_(self.out_conv.weight)
        nn.init.zeros_(self.out_conv.bias)

    def forward(
        self, x: torch.Tensor, t: torch.Tensor
    ) -> torch.Tensor:
        """
        Forward pass of U-Net.

        Args:
            x: Noisy input images [B, C, H, W]
            t: Timesteps [B]

        Returns:
            Predicted noise [B, C, H, W]
        """
        # Time embedding
        time_emb = self.time_embed(t)

        # Initial convolution
        h = self.init_conv(x)

        # Encoder (collect skip connections)
        skips = []
        for block in self.encoder_blocks:
            h, block_skips = block(h, time_emb)
            skips.extend(block_skips)

        # Bottleneck
        h = self.bottleneck(h, time_emb)

        # Decoder (each block pops num_res_blocks skips, lowest resolution first)
        for block in self.decoder_blocks:
            h = block(h, skips, time_emb)

        # Output
        h = self.out_norm(h)
        h = F.silu(h)
        h = self.out_conv(h)

        return h


class SinusoidalPositionEmbeddings(nn.Module):
    """Sinusoidal embeddings for timesteps."""

    def __init__(self, dim: int):
        super().__init__()
        self.dim = dim

    def forward(self, t: torch.Tensor) -> torch.Tensor:
        device = t.device
        half_dim = self.dim // 2
        # Frequencies geometrically spaced from 1 down to 1/10000
        embeddings = math.log(10000) / (half_dim - 1)
        embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings)
        embeddings = t[:, None] * embeddings[None, :]
        embeddings = torch.cat([torch.sin(embeddings), torch.cos(embeddings)], dim=-1)
        return embeddings  # [B, dim]
```

Let's verify our U-Net works correctly:

```python
import torch

# Test the U-Net
model = UNet(
    image_size=64,
    in_channels=3,
    out_channels=3,
    base_channels=128,
    channel_mults=(1, 2, 2, 4),
    num_res_blocks=2,
    attention_resolutions=(16, 8),
)

# Count parameters (the exact total depends on the ResBlock and
# AttentionBlock definitions from Sections 5.2-5.4)
num_params = sum(p.numel() for p in model.parameters())
print(f"Number of parameters: {num_params:,}")

# Test forward pass
x = torch.randn(2, 3, 64, 64)  # Batch of 2 images
t = torch.randint(0, 1000, (2,))  # Random timesteps

noise_pred = model(x, t)
print(f"Input shape: {x.shape}")
print(f"Output shape: {noise_pred.shape}")  # Should match input!

assert noise_pred.shape == x.shape, "U-Net should preserve spatial dimensions!"
```

Model Configuration

Different applications require different U-Net configurations. Here are common setups:

Common U-Net Configurations
  • DDPM-Style Configuration: This configuration follows the original DDPM paper. It uses progressive channel multiplication and attention at specific resolutions.
  • Channel Multipliers: The base channels (128) are multiplied at each level: 128, 256, 256, 512. This follows the common pattern of growing the channel count as spatial size halves.
  • Attention Resolutions: Attention is expensive at high resolutions, so standard practice is to use it only at low resolutions such as 16x16 and 8x8, where it is both feasible and beneficial.
  • Improved DDPM Configuration: The Improved DDPM paper (Nichol & Dhariwal 2021) found that deeper networks with more attention perform better, at the cost of more computation.

```python
# configs.py
# Configuration 1: DDPM-style (original paper)
# Good for 64x64 images, ~35M parameters
ddpm_config = dict(
    image_size=64,
    base_channels=128,
    channel_mults=(1, 2, 2, 4),  # 128, 256, 256, 512
    num_res_blocks=2,
    attention_resolutions=(16,),  # Only at 16x16
    dropout=0.1,
)

# Configuration 2: Improved DDPM (Nichol & Dhariwal 2021)
# More capacity, better quality, ~113M parameters
improved_config = dict(
    image_size=64,
    base_channels=192,
    channel_mults=(1, 2, 3, 4),  # 192, 384, 576, 768
    num_res_blocks=3,
    attention_resolutions=(32, 16, 8),  # More attention
    dropout=0.0,
)

# Configuration 3: Lightweight (for quick experiments)
# Fast training, ~8M parameters
light_config = dict(
    image_size=64,
    base_channels=64,
    channel_mults=(1, 2, 4),  # 64, 128, 256
    num_res_blocks=2,
    # 3 levels on 64x64 visit 64, 32, 16; an 8x8 entry would never be used
    attention_resolutions=(16,),
    dropout=0.0,
)

# Configuration 4: High-resolution (256x256)
# Needs more downsampling levels
highres_config = dict(
    image_size=256,
    base_channels=128,
    channel_mults=(1, 1, 2, 2, 4, 4),  # 6 levels: 256 down to 8
    num_res_blocks=2,
    attention_resolutions=(32, 16, 8),
    dropout=0.0,
)

# Create models
model_ddpm = UNet(**ddpm_config)
model_improved = UNet(**improved_config)
model_light = UNet(**light_config)
model_highres = UNet(**highres_config)

for name, model in [
    ("DDPM", model_ddpm),
    ("Improved", model_improved),
    ("Light", model_light),
    ("HighRes", model_highres),
]:
    params = sum(p.numel() for p in model.parameters()) / 1e6
    print(f"{name}: {params:.1f}M parameters")
```
| Config | Image Size | Parameters | Training Time | Quality |
|---|---|---|---|---|
| Light | 64x64 | ~8M | Fast | Good for prototyping |
| DDPM | 64x64 | ~35M | Moderate | Publication quality |
| Improved | 64x64 | ~113M | Slow | State-of-the-art |
| HighRes | 256x256 | ~100M+ | Very slow | High-resolution |

Choosing a Configuration

Start with the Light config for debugging and quick experiments. Move to DDPM for actual training runs. Use Improved only when you need maximum quality and have the compute budget.
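Two constraints in these configs are easy to violate silently. Attention is only built at resolutions the encoder actually visits (UNet checks `current_res in attention_resolutions`), and `GroupNorm(32, ch)` needs every channel count divisible by 32. A hypothetical `check_config` helper (not part of the UNet code) can catch both before you build a model. For example, three levels on a 64x64 input only visit 64, 32, and 16, so requesting attention at 8x8 adds nothing beyond the Bottleneck's own attention layer.

```python
def check_config(image_size, base_channels, channel_mults, attention_resolutions=(),
                 num_groups=32, **unused):
    """Return a list of problems with a UNet config (an empty list means none found)."""
    problems = []
    # Resolutions the encoder visits: it halves after every level but the last
    resolutions = [image_size // 2**level for level in range(len(channel_mults))]
    for res in attention_resolutions:
        if res not in resolutions:
            problems.append(f"attention at {res}x{res} never applies (levels: {resolutions})")
    for mult in channel_mults:
        ch = base_channels * mult
        if ch % num_groups != 0:
            problems.append(f"{ch} channels not divisible by {num_groups} GroupNorm groups")
    return problems


ok = dict(image_size=64, base_channels=128, channel_mults=(1, 2, 2, 4),
          num_res_blocks=2, attention_resolutions=(16,), dropout=0.1)
three_levels = dict(image_size=64, base_channels=64, channel_mults=(1, 2, 4),
                    num_res_blocks=2, attention_resolutions=(8,), dropout=0.0)
print(check_config(**ok))            # []
print(check_config(**three_levels))  # ['attention at 8x8 never applies (levels: [64, 32, 16])']
```

Extra keys such as `num_res_blocks` and `dropout` are accepted and ignored, so the same dicts can be passed to both `check_config` and `UNet`.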

Parameter Counting and Scaling

Understanding how parameters scale helps you design efficient architectures:

```python
def count_parameters_by_component(model):
    """Break down parameters by top-level U-Net component."""

    components = {
        'time_embed': 0,
        'init_conv': 0,
        'encoder': 0,
        'bottleneck': 0,
        'decoder': 0,
        'output': 0,
    }

    for name, param in model.named_parameters():
        count = param.numel()
        # Match on the top-level attribute name only, so a submodule inside a
        # ResBlock (e.g. its time projection) is not mistaken for time_embed
        if name.startswith('time_embed'):
            components['time_embed'] += count
        elif name.startswith('init_conv'):
            components['init_conv'] += count
        elif name.startswith('encoder'):
            components['encoder'] += count
        elif name.startswith('bottleneck'):
            components['bottleneck'] += count
        elif name.startswith('decoder'):
            components['decoder'] += count
        else:  # out_norm and out_conv
            components['output'] += count

    total = sum(components.values())
    print("\nParameter breakdown:")
    for component, count in components.items():
        pct = 100 * count / total
        print(f"  {component:12s}: {count:>10,} ({pct:5.1f}%)")
    print(f"  {'TOTAL':12s}: {total:>10,}")

    return components

# Analyze the DDPM model
model = UNet(**ddpm_config)
count_parameters_by_component(model)

# Three rows follow directly from the code in this section (time_emb_dim = 512):
#   time_embed: Linear(128, 512) + Linear(512, 512)    = 66,048 + 262,656 = 328,704
#   init_conv:  Conv2d(3, 128, 3)                      = 3,456 + 128      = 3,584
#   output:     GroupNorm(32, 128) + Conv2d(128, 3, 3) = 256 + 3,459      = 3,715
# The encoder, bottleneck, and decoder rows depend on the ResBlock and
# AttentionBlock definitions from Sections 5.2-5.4.
```

Key observations about parameter distribution:

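  • Encoder and decoder together hold the large majority of parameters. The decoder is typically the larger of the two, because the first convolution of each decoder ResBlock reads the concatenation of its input and a skip connection
  • Time embedding is cheap: its two Linear layers hold 328,704 parameters for this config, under 1% of the total
  • Bottleneck is relatively small despite being computationally important
  • Attention layers (not shown separately) add a share that depends on how many levels use attention and how wide those levels are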

Scaling Laws

How do parameters scale with architecture choices?

| Change | Parameter Impact | Quality Impact |
|---|---|---|
| 2x base_channels | ~4x parameters | Significant improvement |
| 2x num_res_blocks | ~2x parameters | Moderate improvement |
| Add attention at 32x32 | ~1.5x at that level | Helps global coherence |
| Remove attention at 8x8 | Minor reduction | Usually safe to skip |
| Add resolution level | ~1.5-2x parameters | Required for higher res |
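The ~4x rule for width comes from convolution weights scaling with in_channels * out_channels. The same arithmetic applies to skip concatenation: a conv that reads [h, skip] has roughly twice the weights of one that reads h alone. A quick check with a hypothetical `conv_params` helper:

```python
import torch.nn as nn


def conv_params(in_ch: int, out_ch: int, kernel_size: int = 3) -> int:
    """Parameters of a Conv2d with bias: in_ch * out_ch * k * k + out_ch."""
    return sum(p.numel() for p in nn.Conv2d(in_ch, out_ch, kernel_size).parameters())


base = conv_params(128, 128)    # 147,584
wider = conv_params(256, 256)   # 590,080: doubling width roughly quadruples weights
concat = conv_params(256, 128)  # 295,040: reading [h, skip] roughly doubles them
print(f"{wider / base:.2f}x, {concat / base:.2f}x")  # 4.00x, 2.00x
```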

Weight Initialization

Proper initialization is crucial for stable training. Diffusion models benefit from specific initialization strategies:

```python
import torch
import torch.nn as nn

def initialize_weights(model: nn.Module) -> nn.Module:
    """
    Initialize U-Net weights for stable diffusion training.

    Key principles:
    1. Convolutions: Kaiming (He) normal; Linear layers: Xavier uniform.
       (PyTorch's built-in defaults differ slightly; these are explicit choices.)
    2. Output layer: Initialize to zero (start by predicting zero noise)
    3. Residual projections: Can additionally be scaled down for stability
       (not done in this function)

    Note: this loop overwrites every Conv2d and Linear, including any special
    initialization that ResBlock or AttentionBlock applied in their constructors.
    """

    for module in model.modules():
        if isinstance(module, nn.Conv2d):
            # Kaiming initialization for convolutions
            nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
            if module.bias is not None:
                nn.init.zeros_(module.bias)

        elif isinstance(module, nn.Linear):
            # Xavier/Glorot for linear layers
            nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                nn.init.zeros_(module.bias)

        elif isinstance(module, nn.GroupNorm):
            # Standard: gamma=1, beta=0
            nn.init.ones_(module.weight)
            nn.init.zeros_(module.bias)

    # CRITICAL: Initialize output convolution to zero.
    # The target noise is epsilon ~ N(0, I), whose mean is zero, so predicting
    # zero is the best constant guess under the MSE loss at every timestep.
    if hasattr(model, 'out_conv'):
        nn.init.zeros_(model.out_conv.weight)
        nn.init.zeros_(model.out_conv.bias)

    return model


# Apply initialization
model = UNet(**ddpm_config)
model = initialize_weights(model)

# Verify output is initially zero
x = torch.randn(1, 3, 64, 64)
t = torch.zeros(1, dtype=torch.long)
with torch.no_grad():
    output = model(x, t)
    print(f"Initial output mean: {output.mean():.6f}")  # Zero: out_conv is all zeros
    print(f"Initial output std: {output.std():.6f}")    # Zero as well
```

Zero Output Initialization

Initializing the output layer to zero is a common practice in diffusion models. The model starts by predicting zero noise, which is the mean of the target ε ~ N(0, I) and therefore the best constant prediction under the MSE loss. During training, the model quickly learns to predict actual noise. Starting from zero helps avoid the large, arbitrary noise predictions a randomly initialized output layer would produce early in training.
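The effect is easy to check on the output layer alone. In this sketch, random tensors stand in for the decoder's features and the target noise. The zero-initialized convolution predicts exactly zero, its own weights still receive a nonzero gradient, and the gradient flowing back into the earlier features is zero on the first step. The rest of the network therefore starts adapting once out_conv has moved away from zero after the first optimizer update.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
features = torch.randn(4, 128, 8, 8, requires_grad=True)  # stand-in for decoder output
out_conv = nn.Conv2d(128, 3, 3, padding=1)
nn.init.zeros_(out_conv.weight)
nn.init.zeros_(out_conv.bias)

pred = out_conv(features)
target_noise = torch.randn_like(pred)        # stand-in for epsilon ~ N(0, I)
loss = ((pred - target_noise) ** 2).mean()   # standard noise-prediction MSE
loss.backward()

print(pred.abs().max().item())                    # 0.0: initial prediction is exactly zero
print(out_conv.weight.grad.abs().sum().item() > 0)  # True: the output layer still learns
print(features.grad.abs().max().item())           # 0.0: no gradient upstream on step one
```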

Summary

In this section, we assembled a complete, production-ready U-Net for diffusion models:

  1. Architecture overview: The U-Net flows from input through encoder, bottleneck, decoder, and output, with skip connections bridging encoder to decoder
  2. Encoder implementation: Progressive downsampling with ResBlocks and attention, saving skip connections at each resolution
  3. Bottleneck design: Deep processing at lowest resolution with attention for global reasoning
  4. Decoder implementation: Progressive upsampling with skip concatenation to recover fine details
  5. Model configuration: Different configs for different use cases, from lightweight prototyping to publication-quality models
  6. Weight initialization: Zero output initialization for stable training

Chapter Complete!

Congratulations! You've completed Chapter 5: U-Net Architecture for Diffusion. You now understand every component of the diffusion U-Net and can implement it from scratch. In Chapter 6, we'll put this U-Net to work by building the complete diffusion model: the training loop, loss functions, and the forward diffusion process.

The complete U-Net code from this chapter is available in the repository and can be used as a foundation for your own diffusion model experiments. The architecture we've built supports:

  • Any image resolution divisible by 2^(number of levels - 1), with appropriate channel multipliers
  • Configurable depth and width
  • Flexible attention placement
  • Time-conditional generation
  • Skip connections for preserving details

In the next chapter, we'll train this U-Net to denoise images and generate new samples from pure noise!