Chapter 5

U-Net Building Blocks

U-Net Architecture for Diffusion

Learning Objectives

By the end of this section, you will:

  1. Implement ResBlocks with Group Normalization and residual connections
  2. Understand why GroupNorm is preferred over BatchNorm for diffusion models
  3. Master SiLU activation and why it outperforms ReLU for generation tasks
  4. Build downsampling and upsampling blocks for the encoder and decoder paths
  5. Combine blocks into complete encoder and decoder modules

Why This Matters

The ResBlock is the atomic unit of modern diffusion U-Nets. Understanding its components—normalization, activation, convolution, and residual connections—is essential for both implementing and debugging diffusion models. Each design choice has been carefully validated through extensive experimentation.

ResBlock Fundamentals

The ResBlock (Residual Block) is the core building block of diffusion U-Nets. It combines several key innovations:

  • Residual connections: Allow gradients to flow directly through the network
  • Normalization: Stabilize training by normalizing intermediate activations
  • Time conditioning: Inject timestep information to modulate network behavior
  • Nonlinearity: SiLU activations for smooth, non-saturating gradients

The general structure of a ResBlock for diffusion models is:

\text{Output} = \text{ResBlock}(x, t) = x + F(\text{Norm}(x), t)

where F is the learned transformation and t is the time embedding. The addition of x is the residual (skip) connection.

Pre-activation vs Post-activation

Modern diffusion U-Nets use pre-activation ResBlocks: the normalization and activation come before the convolution, not after. This design, from the He et al. Identity Mappings paper, improves gradient flow and training stability.
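The difference in ordering can be sketched in a few lines. This is a minimal illustration, not library code; the names post_act and pre_act are our own:

```python
import torch
import torch.nn as nn

post_act = nn.Sequential(  # classic ResNet ordering: Conv -> Norm -> Act
    nn.Conv2d(64, 64, 3, padding=1),
    nn.GroupNorm(32, 64),
    nn.SiLU(),
)

pre_act = nn.Sequential(   # diffusion U-Nets: Norm -> Act -> Conv
    nn.GroupNorm(32, 64),
    nn.SiLU(),
    nn.Conv2d(64, 64, 3, padding=1),
)

x = torch.randn(2, 64, 16, 16)
# With pre-activation, the residual sum x + F(...) is left untouched by any
# norm or activation, keeping a clean identity path for gradients.
out = x + pre_act(x)
print(out.shape)  # torch.Size([2, 64, 16, 16])
```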

Group Normalization

Normalization is critical for training deep networks. In diffusion models, we use Group Normalization (GroupNorm) instead of Batch Normalization.

Why Not BatchNorm?

BatchNorm computes statistics across the batch dimension, which has several problems for diffusion:

  • Small batch sizes: GPU memory limits batch sizes, making batch statistics noisy
  • Inference inconsistency: Running statistics differ from training, causing issues during sampling
  • Sample independence: Each image should be processed independently during generation
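The train/eval inconsistency is easy to observe directly. A minimal sketch with freshly initialized layers and a tiny batch:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
bn = nn.BatchNorm2d(8)
gn = nn.GroupNorm(4, 8)
x = torch.randn(2, 8, 4, 4)  # tiny batch, as is common in diffusion training

# BatchNorm: training mode normalizes with batch statistics, eval mode with
# running averages -- the same input maps to different outputs.
bn.train(); y_train = bn(x)
bn.eval();  y_eval = bn(x)
print(torch.allclose(y_train, y_eval))  # False

# GroupNorm: statistics are computed per sample, so both modes agree exactly.
gn.train(); z_train = gn(x)
gn.eval();  z_eval = gn(x)
print(torch.allclose(z_train, z_eval))  # True
```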

GroupNorm Explained

GroupNorm divides channels into G groups and normalizes within each group, independently for each sample:

\hat{x}_{n,c,h,w} = \frac{x_{n,c,h,w} - \mu_{n,g}}{\sqrt{\sigma^2_{n,g} + \epsilon}} \cdot \gamma_c + \beta_c

where g = \lfloor c \cdot G / C \rfloor is the group index for channel c.
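To confirm the formula, this sketch compares a manual implementation against nn.GroupNorm with its default parameters (gamma = 1, beta = 0 at initialization); small sizes are chosen for clarity:

```python
import torch
import torch.nn as nn

N, C, H, W, G = 2, 8, 4, 4, 4
x = torch.randn(N, C, H, W)

# Reference implementation
gn = nn.GroupNorm(G, C)
y_ref = gn(x)

# Manual computation following the formula above
xg = x.view(N, G, C // G, H, W)                       # split channels into groups
mu = xg.mean(dim=(2, 3, 4), keepdim=True)             # mean per (sample, group)
var = xg.var(dim=(2, 3, 4), keepdim=True, unbiased=False)
y = ((xg - mu) / torch.sqrt(var + gn.eps)).view(N, C, H, W)

print(torch.allclose(y, y_ref, atol=1e-5))  # True
```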

Understanding Group Normalization
group_norm_example.py

GroupNorm Overview

GroupNorm divides channels into groups and normalizes within each group. Unlike BatchNorm, it doesn't depend on batch size.

Compute Group Statistics

For each group, compute mean and variance across the group's channels and spatial dimensions. This is done independently per sample in the batch.

Normalize

Subtract the mean and divide by the standard deviation (with epsilon for numerical stability). This centers and scales the features.

Affine Transform

Apply learnable scale (gamma) and shift (beta) parameters. This allows the network to undo the normalization if needed.
```python
# GroupNorm normalizes channels in groups, not across batch

import torch
import torch.nn as nn

# Example: 512 channels, 32 groups = 16 channels per group
group_norm = nn.GroupNorm(num_groups=32, num_channels=512)

# Input shape: [batch, channels, height, width]
x = torch.randn(4, 512, 32, 32)

# For each sample independently:
# 1. Reshape to [batch, num_groups, channels_per_group, H, W]
# 2. Compute mean and variance within each group
# 3. Normalize
# 4. Apply learnable scale and shift

y = group_norm(x)  # Same shape: [4, 512, 32, 32]

# Verify: normalized within groups, not batch
print(f"Input mean: {x.mean():.4f}, std: {x.std():.4f}")
print(f"Output mean: {y.mean():.4f}, std: {y.std():.4f}")
```
Normalization   Computed Over                    Best For                     Batch Size Sensitivity
BatchNorm       N, H, W (per channel)            Classification, fixed batch  High
LayerNorm       C, H, W (per sample)             Transformers, NLP            None
InstanceNorm    H, W (per channel, per sample)   Style transfer               None
GroupNorm       Groups of C, H, W (per sample)   Diffusion, small batch       None

Choosing Number of Groups

The standard choice is 32 groups. This works well for most channel counts (64, 128, 256, 512). The number of channels must be divisible by the number of groups. For very small channel counts, use min(32, C) groups.
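A small helper makes the rule concrete. The name norm_groups is our own, not a standard API; the divisor fallback is one reasonable way to handle channel counts that 32 does not divide:

```python
def norm_groups(channels: int, max_groups: int = 32) -> int:
    """Pick min(32, C) groups, falling back to the largest divisor of C."""
    g = min(max_groups, channels)
    while channels % g != 0:  # GroupNorm requires C divisible by G
        g -= 1
    return g

print(norm_groups(512))  # 32 (16 channels per group)
print(norm_groups(64))   # 32 (2 channels per group)
print(norm_groups(24))   # 24 (small channel count: one channel per group)
print(norm_groups(48))   # 24 (32 does not divide 48)
```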

Activation Functions

Modern diffusion models use SiLU (Sigmoid Linear Unit), also known as Swish, instead of ReLU:

\text{SiLU}(x) = x \cdot \sigma(x) = \frac{x}{1 + e^{-x}}

Why SiLU Over ReLU?

Property          ReLU              SiLU
Formula           max(0, x)         x * sigmoid(x)
Smoothness        Not smooth at 0   Smooth everywhere
Gradient at x=0   Undefined         0.5
Negative values   Always 0          Slightly negative
Self-gating       No                Yes (sigmoid gates)

The key advantages of SiLU for diffusion:

  • Smooth gradients: No sharp transition at zero, leading to more stable training
  • Non-monotonic: Can output small negative values, enabling richer representations
  • Self-gating: The sigmoid modulates the linear part, similar to attention mechanisms
```python
import torch
import torch.nn.functional as F

# SiLU activation (built into PyTorch)
x = torch.randn(4, 64, 32, 32)

# Method 1: Using nn.SiLU module
silu = torch.nn.SiLU()
y = silu(x)

# Method 2: Using functional API
y = F.silu(x)

# Method 3: Manual implementation
y = x * torch.sigmoid(x)

# All three are equivalent!
```

GELU vs SiLU

GELU (Gaussian Error Linear Unit) is another popular smooth activation used in Transformers. While similar to SiLU, most diffusion models stick with SiLU following the original DDPM paper. Both work well in practice.
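That the two are close is easy to check numerically. A rough comparison, not a benchmark:

```python
import torch
import torch.nn.functional as F

x = torch.linspace(-3.0, 3.0, 601)

silu = F.silu(x)   # x * sigmoid(x)
gelu = F.gelu(x)   # x * Phi(x), with a Gaussian-CDF gate

# Both vanish for large negative inputs and approach the identity for
# large positive inputs; their curves differ only modestly in between.
max_gap = (silu - gelu).abs().max().item()
print(f"Max |SiLU - GELU| on [-3, 3]: {max_gap:.3f}")
```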

Convolution Layers

The convolutions in diffusion U-Nets are standard 2D convolutions with specific choices:

  • Kernel size: 3x3 is the standard choice, providing a good balance between receptive field and computational cost
  • Padding: padding=1 preserves spatial dimensions (for 3x3 kernels)
  • Stride: stride=1 for feature processing, stride=2 for downsampling
  • Bias: Often disabled when followed by normalization (which has its own bias)
```python
import torch.nn as nn

# Standard 3x3 convolution preserving spatial size
conv = nn.Conv2d(
    in_channels=128,
    out_channels=256,
    kernel_size=3,
    stride=1,
    padding=1,
    bias=False  # Disabled when using GroupNorm
)

# Spatial dimensions: H_out = H_in for padding=1, stride=1

# 1x1 convolution for channel mixing / projection
proj = nn.Conv2d(128, 256, kernel_size=1, bias=False)

# Strided convolution for downsampling (reduces size by 2)
downsample = nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1)
```

Complete ResBlock Implementation

Now let's implement a complete ResBlock for diffusion models with time conditioning:

Production-Ready ResBlock for Diffusion
resblock.py

Imports

We import PyTorch modules for building neural network layers. nn.functional provides operations like SiLU activation.

ResBlock Class Definition

The ResBlock is the fundamental building block of diffusion U-Nets. Each block processes features while maintaining gradient flow through residual connections.

First Normalization Layer

GroupNorm normalizes features across channel groups. We use 32 groups as the standard choice, which balances computation and stability.

First Convolution

A 3x3 convolution processes the normalized features. padding=1 ensures the spatial dimensions are preserved.

Time Embedding Projection

The time embedding is projected to match the channel dimension. This allows the network to modulate its behavior based on the noise level.

Second Normalization and Convolution

The pattern repeats: normalize, then convolve. This double-convolution structure provides sufficient capacity for feature transformation.

Skip Connection Projection

If input and output channels differ, we need a 1x1 convolution to match dimensions for the residual addition.

Forward Pass: Normalization and Activation

First, apply GroupNorm followed by SiLU activation. SiLU (Swish) is preferred over ReLU for smoother gradients.

Add Time Embedding

The time embedding is added to the features after the first convolution. This is how the network learns time-dependent behavior. For example, h = h + time_emb[:, :, None, None] broadcasts the embedding to the spatial dimensions.

Residual Connection

The skip connection adds the input (possibly projected) to the output. This enables training of very deep networks by allowing gradients to flow directly.
```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class ResBlock(nn.Module):
    """
    Residual block with time conditioning for diffusion models.

    Architecture:
    1. GroupNorm -> SiLU -> Conv3x3
    2. + time_embedding (projected)
    3. GroupNorm -> SiLU -> Dropout -> Conv3x3
    4. + skip_connection (optionally projected)
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        time_emb_dim: int,
        dropout: float = 0.0,
        num_groups: int = 32,
    ):
        super().__init__()

        # First normalization and convolution
        self.norm1 = nn.GroupNorm(num_groups, in_channels)
        self.conv1 = nn.Conv2d(in_channels, out_channels, 3, padding=1)

        # Time embedding projection
        self.time_mlp = nn.Sequential(
            nn.SiLU(),
            nn.Linear(time_emb_dim, out_channels),
        )

        # Second normalization and convolution
        self.norm2 = nn.GroupNorm(num_groups, out_channels)
        self.dropout = nn.Dropout(dropout)
        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, padding=1)

        # Skip connection projection if channels change
        if in_channels != out_channels:
            self.skip_proj = nn.Conv2d(in_channels, out_channels, 1)
        else:
            self.skip_proj = nn.Identity()

    def forward(
        self, x: torch.Tensor, time_emb: torch.Tensor
    ) -> torch.Tensor:
        """
        Args:
            x: Input features [B, C_in, H, W]
            time_emb: Time embedding [B, time_emb_dim]
        Returns:
            Output features [B, C_out, H, W]
        """
        # First block: Norm -> Activation -> Conv
        h = self.norm1(x)
        h = F.silu(h)
        h = self.conv1(h)

        # Add time embedding (broadcast to spatial dimensions)
        time_emb = self.time_mlp(time_emb)
        h = h + time_emb[:, :, None, None]

        # Second block: Norm -> Activation -> Dropout -> Conv
        h = self.norm2(h)
        h = F.silu(h)
        h = self.dropout(h)
        h = self.conv2(h)

        # Residual connection
        return h + self.skip_proj(x)
```

Time Embedding Broadcasting

The time embedding has shape [B, D] but we need to add it to features of shape [B, C, H, W]. The trick is to add dimensions: time_emb[:, :, None, None] broadcasts to all spatial locations. This means the same time modulation applies uniformly across the image.
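A shape check makes the broadcasting concrete; here time_emb is assumed to have already been projected to C channels, as in the ResBlock above:

```python
import torch

B, C, H, W = 4, 128, 16, 16
h = torch.randn(B, C, H, W)    # feature map
time_emb = torch.randn(B, C)   # already projected to C channels

# [B, C] -> [B, C, 1, 1]; broadcasting repeats it over all H*W positions
t = time_emb[:, :, None, None]
print(t.shape)     # torch.Size([4, 128, 1, 1])
out = h + t
print(out.shape)   # torch.Size([4, 128, 16, 16])

# Every spatial location of a given (sample, channel) gets the same offset
same_offset = torch.allclose(out[0, 0] - h[0, 0], t[0, 0].expand(H, W), atol=1e-5)
print(same_offset)  # True
```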

Downsampling Blocks

The encoder path of U-Net reduces spatial resolution at each level. There are two common approaches:

Downsampling Options
downsample.py

Downsample Block

Reduces spatial resolution by 2x. This is used in the encoder path of the U-Net to create the contracting structure.

Strided Convolution Option

A 3x3 convolution with stride=2 reduces resolution while learning the downsampling. This is more flexible than pooling.

Average Pooling Option

Alternatively, average pooling followed by a 1x1 convolution. This is computationally cheaper and sometimes works equally well.
```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class Downsample(nn.Module):
    """Reduce spatial resolution by 2x."""

    def __init__(
        self,
        channels: int,
        use_conv: bool = True,
    ):
        super().__init__()

        if use_conv:
            # Strided convolution - learnable downsampling
            self.down = nn.Conv2d(
                channels, channels,
                kernel_size=3, stride=2, padding=1
            )
        else:
            # Average pooling - fixed downsampling
            self.down = nn.Sequential(
                nn.AvgPool2d(kernel_size=2, stride=2),
                nn.Conv2d(channels, channels, kernel_size=1),
            )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.down(x)


# Alternative: Interpolation-based downsampling
class InterpolateDownsample(nn.Module):
    def __init__(self, channels: int):
        super().__init__()
        self.conv = nn.Conv2d(channels, channels, kernel_size=3, padding=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = F.interpolate(x, scale_factor=0.5, mode="bilinear", align_corners=False)
        return self.conv(x)
```

Strided Conv vs Pooling

Strided convolution (stride=2) is more common in modern architectures because it's learnable and combines downsampling with feature transformation. Pooling is fixed and may discard useful information. Most diffusion models use strided convolution.
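The output size of the strided option follows the standard convolution arithmetic, which this short sketch verifies:

```python
import torch
import torch.nn as nn

x = torch.randn(1, 128, 32, 32)
down = nn.Conv2d(128, 128, kernel_size=3, stride=2, padding=1)
y = down(x)

# H_out = floor((H_in + 2*padding - kernel) / stride) + 1
#       = floor((32 + 2 - 3) / 2) + 1 = 15 + 1 = 16
print(y.shape)  # torch.Size([1, 128, 16, 16])
```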

Upsampling Blocks

The decoder path increases spatial resolution. Again, there are multiple approaches:

Upsampling Options
upsample.py

Upsample Block

Increases spatial resolution by 2x. This is used in the decoder path to expand back to the original resolution.

Interpolation + Convolution

First, nearest-neighbor or bilinear interpolation doubles the spatial size. Then, a convolution refines the upsampled features.

Transposed Convolution Alternative

ConvTranspose2d learns the upsampling. It is more flexible, but can produce checkerboard artifacts if not used carefully.
```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class Upsample(nn.Module):
    """Increase spatial resolution by 2x."""

    def __init__(
        self,
        channels: int,
        use_conv_transpose: bool = False,
    ):
        super().__init__()

        if use_conv_transpose:
            # Transposed convolution - learnable upsampling
            # Note: Can produce checkerboard artifacts!
            self.up = nn.ConvTranspose2d(
                channels, channels,
                kernel_size=4, stride=2, padding=1
            )
        else:
            # Interpolation + convolution - cleaner results
            self.up = nn.Sequential(
                nn.Upsample(scale_factor=2, mode="nearest"),
                nn.Conv2d(channels, channels, kernel_size=3, padding=1),
            )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.up(x)


# Recommended approach for diffusion
class InterpolateUpsample(nn.Module):
    """
    Interpolation followed by convolution.
    This avoids checkerboard artifacts from transposed convolution.
    """
    def __init__(self, channels: int, out_channels: int = None):
        super().__init__()
        out_channels = out_channels or channels
        self.conv = nn.Conv2d(channels, out_channels, kernel_size=3, padding=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = F.interpolate(x, scale_factor=2, mode="nearest")
        return self.conv(x)
```

Checkerboard Artifacts

Transposed convolutions can produce "checkerboard" patterns in the output when the kernel size is not divisible by the stride, because the kernel windows then overlap unevenly. The safer approach is interpolation + convolution: first upsample with nearest-neighbor or bilinear interpolation, then refine with a regular convolution.
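The uneven overlap is easy to visualize: with a constant input and all-ones weights, every output value simply counts how many kernel windows cover it. The helper name flat_convT is ours, for illustration:

```python
import torch
import torch.nn as nn

# Constant input: any unevenness in the output comes purely from how the
# kernel windows overlap during transposed convolution.
x = torch.ones(1, 1, 8, 8)

def flat_convT(kernel_size: int) -> torch.Tensor:
    ct = nn.ConvTranspose2d(1, 1, kernel_size, stride=2, padding=1, bias=False)
    nn.init.constant_(ct.weight, 1.0)  # all-ones weights
    with torch.no_grad():
        return ct(x)

# kernel 3, stride 2: kernel not divisible by stride -> alternating values
y3 = flat_convT(3)
print(y3[0, 0, 3:5, 3:5])  # interior shows a 2x2 checkerboard of values

# kernel 4, stride 2: even overlap -> uniform interior
y4 = flat_convT(4)
print(y4[0, 0, 3:5, 3:5])  # interior is constant
```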

Summary

In this section, we implemented the fundamental building blocks of diffusion U-Nets:

  1. ResBlock: The core module combining normalization, activation, convolution, time conditioning, and residual connections
  2. GroupNorm: Normalization that works with any batch size, essential for diffusion training and inference
  3. SiLU activation: Smooth, self-gating activation that improves gradient flow compared to ReLU
  4. Downsampling: Strided convolutions or pooling to reduce spatial resolution in the encoder
  5. Upsampling: Interpolation + convolution to increase resolution in the decoder while avoiding artifacts

Coming Up Next

In the next section, we'll implement time conditioning: how to encode the timestep t into a vector and inject it into every ResBlock. This is what allows the network to adapt its behavior from removing heavy noise to fine-tuning details.