Chapter 5

Attention in Diffusion U-Net

U-Net Architecture for Diffusion

Learning Objectives

By the end of this section, you will:

  1. Understand why attention is crucial for high-quality diffusion models
  2. Adapt self-attention for images by treating spatial positions as tokens
  3. Implement multi-head attention for diverse attention patterns
  4. Know where to place attention in the U-Net architecture
  5. Balance quality vs computation by strategic attention placement

Why This Matters

Attention enables global context aggregation that convolutions cannot achieve efficiently. A 3x3 convolution only sees 9 pixels; attention can relate any pixel to any other pixel in the image. For diffusion models, this is critical for generating coherent global structure while maintaining local details.

Why Attention in U-Net?

Convolutions are local operations: each output pixel depends only on a small neighborhood of input pixels. To relate distant parts of an image, information must propagate through many layers. With 3x3 kernels and no downsampling, the receptive field grows by only 2 pixels per layer, so this path is long and the signal can be diluted along the way.
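To make the "many layers" point concrete, here is a small receptive-field calculator using the standard recurrence r_l = r_{l-1} + (k_l - 1) · j_{l-1}, where j is the cumulative stride. It is a minimal sketch: the helper names are hypothetical, and padding and dilation are ignored.

```python
def receptive_field(layers):
    """Receptive field (in input pixels) of a stack of conv layers.

    layers: list of (kernel_size, stride) tuples, applied in order.
    Dilation and padding are ignored in this sketch.
    """
    field = 1  # a single input pixel
    jump = 1   # input-pixel distance between adjacent output positions
    for kernel, stride in layers:
        field += (kernel - 1) * jump
        jump *= stride
    return field


def layers_needed(target, kernel=3):
    """Stride-1 convs needed before the receptive field spans `target` pixels."""
    n = 0
    while receptive_field([(kernel, 1)] * n) < target:
        n += 1
    return n


print(receptive_field([(3, 1)]))        # 3: one 3x3 conv
print(receptive_field([(3, 1)] * 10))   # 21: ten stacked 3x3 convs
print(layers_needed(256))               # 128: 3x3 convs needed to span 256 pixels
# A stride-2 conv after every pair of 3x3 convs grows the field much faster
print(receptive_field([(3, 1), (3, 1), (3, 2)] * 4))  # 91 after 12 layers
```

Downsampling helps, but even then twelve layers cover only 91 pixels, which is one reason U-Nets combine downsampling with attention at low resolutions.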

Self-attention is a global operation: each output pixel can directly attend to every input pixel in a single layer. This enables:

  • Long-range dependencies: A face can "see" the background to ensure color consistency
  • Global coherence: All parts of an object can coordinate their appearance
  • Semantic reasoning: The network can relate semantically similar regions
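| Property | Convolution | Self-Attention |
|---|---|---|
| Receptive field | Local (kernel size) | Global (entire feature map) |
| Computation | $O(H \cdot W \cdot k^2 \cdot C^2)$ | $O((H \cdot W)^2 \cdot C)$, plus $O(H \cdot W \cdot C^2)$ for the projections |
| Memory | $O(H \cdot W \cdot C)$ | $O((H \cdot W)^2)$ for the attention matrix |
| Spatial inductive bias | Locality and translation equivariance (built-in) | None built in: without positional encodings it is permutation-equivariant, so spatial structure must be learned from the features |
| Parameter sharing | Across positions | Across positions |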
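The receptive-field row can be checked numerically. The NumPy sketch below perturbs one pixel in the far corner of a 16x16 feature map and measures how much the output at the opposite corner changes. It assumes a fixed 3x3 averaging filter as a stand-in for a learned convolution and identity Q/K/V projections for attention; both are simplifications for illustration.

```python
import numpy as np

def conv3x3_mean(x):
    """Fixed 3x3 averaging filter with zero padding (stand-in for a learned conv).

    x: [H, W, C] feature map; each output position averages its 3x3 neighborhood.
    """
    H, W, _ = x.shape
    padded = np.pad(x, ((1, 1), (1, 1), (0, 0)))
    out = np.zeros_like(x)
    for dh in range(3):
        for dw in range(3):
            out += padded[dh:dh + H, dw:dw + W]
    return out / 9.0

def self_attention(x):
    """Single-head self-attention with identity Q/K/V projections (an assumption)."""
    H, W, C = x.shape
    tokens = x.reshape(H * W, C)
    scores = tokens @ tokens.T / np.sqrt(C)
    scores -= scores.max(axis=-1, keepdims=True)  # numerical stability
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)
    return (weights @ tokens).reshape(H, W, C)

rng = np.random.default_rng(0)
x = rng.standard_normal((16, 16, 8))
x_far = x.copy()
x_far[15, 15] += 5.0  # perturb only the bottom-right corner

# How much does the output at the top-left corner (0, 0) change?
conv_change = np.abs(conv3x3_mean(x_far)[0, 0] - conv3x3_mean(x)[0, 0]).max()
attn_change = np.abs(self_attention(x_far)[0, 0] - self_attention(x)[0, 0]).max()
print(f"conv: {conv_change:.3g}  attention: {attn_change:.3g}")
```

The convolution output at (0, 0) is unaffected, while the attention output changes after a single layer.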

Computational Cost

Self-attention has $O((HW)^2)$ complexity in memory and computation. For a 256x256 image, that's 65,536 tokens and over 4 billion attention pairs! This is why we only use attention at lower resolutions (16x16, 32x32) in U-Net.
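A quick back-of-the-envelope calculation shows what this means for memory. The sketch below assumes float32 scores and counts only the (H·W) x (H·W) score matrix for one head of one sample; batch size, number of heads, and intermediate copies multiply it further, unless a memory-efficient attention kernel avoids materializing the matrix. The helper name is hypothetical.

```python
def attention_cost(height, width, bytes_per_element=4):
    """Return (tokens, attention-matrix entries, bytes) for one head of one sample."""
    tokens = height * width
    entries = tokens ** 2
    return tokens, entries, entries * bytes_per_element


for size in (256, 128, 64, 32, 16):
    tokens, entries, nbytes = attention_cost(size, size)
    print(f"{size}x{size}: {tokens:>6,} tokens, {entries:>14,} entries, "
          f"{nbytes / 2**20:>10,.1f} MiB in float32")
```

At 256x256 the score matrix alone is 16 GiB per head per sample, versus 4 MiB at 32x32.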

Self-Attention Review

Self-attention from Transformers computes weighted combinations of values based on query-key similarity:

$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right) V$$

where:

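  • Q (Queries): What am I looking for?
  • K (Keys): What do I contain?
  • V (Values): What content do I pass on when attended to?
  • $d_k$: Key (and query) dimension; dividing by $\sqrt{d_k}$ keeps the scores in a range where the softmax does not saturate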
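Why divide by $\sqrt{d_k}$? If query and key entries are roughly independent with unit variance, their dot product has variance $d_k$, so for large $d_k$ the softmax collapses toward a one-hot distribution. The NumPy sketch below checks this with random Gaussian vectors; that distribution is an assumption, since learned features are not Gaussian.

```python
import numpy as np

rng = np.random.default_rng(0)
d_k = 256
n_queries, n_keys = 1000, 64

q = rng.standard_normal((n_queries, d_k))
k = rng.standard_normal((n_keys, d_k))

raw = q @ k.T                # unscaled dot products, variance close to d_k
scaled = raw / np.sqrt(d_k)  # scaled dot products, variance close to 1

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)  # subtract max for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

# Average weight given to the single most-attended key, per query
peak_raw = softmax(raw).max(axis=-1).mean()
peak_scaled = softmax(scaled).max(axis=-1).mean()

print(f"variance: raw={raw.var():.1f}  scaled={scaled.var():.2f}")
print(f"mean peak weight: raw={peak_raw:.2f}  scaled={peak_scaled:.2f}")
```

Without scaling, most queries put nearly all their weight on one key, and the gradients through the other keys become tiny.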

The softmax turns each query's scores into a probability distribution over all key positions, and the output for that query is the corresponding weighted sum of the values.
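The formula translates almost line for line into NumPy. This is a minimal sketch on plain [tokens, dim] matrices. The function name `attention` is our own and no batching is assumed. It returns the weights so the "one distribution per query" property can be inspected.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def attention(Q, K, V):
    """Scaled dot-product attention.

    Q: [n_queries, d_k], K: [n_keys, d_k], V: [n_keys, d_v]
    Returns (output [n_queries, d_v], weights [n_queries, n_keys]).
    """
    d_k = Q.shape[-1]
    weights = softmax(Q @ K.T / np.sqrt(d_k), axis=-1)  # one distribution per query
    return weights @ V, weights

rng = np.random.default_rng(0)
Q = rng.standard_normal((4, 8))
K = rng.standard_normal((6, 8))
V = rng.standard_normal((6, 3))
out, weights = attention(Q, K, V)
print(out.shape, weights.shape)  # (4, 3) (4, 6)
print(weights.sum(axis=-1))      # each row sums to 1 (up to rounding)
```

Because each output row is a convex combination of value rows, every output entry lies between the minimum and maximum of the corresponding value column.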

Intuition

Think of attention as a soft dictionary lookup. The query asks "what relevant information is in this image?" Keys answer "here's what each position contains" and values provide the actual content to aggregate.
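The dictionary analogy can be made literal. In the sketch below, the keys are assumed to be one-hot vectors, one per "dictionary entry". A query that strongly matches a single key returns almost exactly that key's value, like a hard lookup. A query that weakly matches two keys returns a blend.

```python
import numpy as np

def attention(Q, K, V):
    scores = Q @ K.T / np.sqrt(Q.shape[-1])
    scores -= scores.max(axis=-1, keepdims=True)  # numerical stability
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ V, weights

K = np.eye(4)                         # four "dictionary keys", one per axis
V = np.array([[1.0, 0.0],             # value stored under key 0
              [0.0, 1.0],             # value stored under key 1
              [5.0, 5.0],             # value stored under key 2
              [-1.0, 2.0]])           # value stored under key 3

sharp_query = 40.0 * K[2:3]           # strongly matches key 2 only
soft_query = 0.5 * (K[1:2] + K[2:3])  # weakly matches keys 1 and 2

sharp_out, sharp_w = attention(sharp_query, K, V)
soft_out, soft_w = attention(soft_query, K, V)
print(np.round(sharp_w, 4), sharp_out)
print(np.round(soft_w, 4), soft_out)
```

The sharpness of the lookup is controlled entirely by the magnitude of the query-key scores.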

Spatial Self-Attention for Images

To apply self-attention to images, we treat each spatial position $(h, w)$ as a token with a $C$-dimensional embedding:

  1. Flatten: Reshape from $[B, C, H, W]$ to $[B, C, H \cdot W]$
  2. Transpose: Get $[B, H \cdot W, C]$ for standard attention format
  3. Attention: Compute attention over the $H \cdot W$ tokens
  4. Reshape: Back to $[B, C, H, W]$
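The four steps above can be checked with a small NumPy sketch, with NumPy standing in for PyTorch (NumPy's `transpose(0, 2, 1)` plays the role of PyTorch's `transpose(1, 2)`). It confirms that token index t corresponds to spatial position (h, w) = divmod(t, W), and that the reshape round-trips exactly.

```python
import numpy as np

B, C, H, W = 2, 3, 4, 5
x = np.arange(B * C * H * W, dtype=np.float32).reshape(B, C, H, W)

tokens = x.reshape(B, C, H * W)     # 1. flatten: [B, C, H*W]
tokens = tokens.transpose(0, 2, 1)  # 2. transpose: [B, H*W, C]

# Token index t corresponds to spatial position (h, w) = divmod(t, W)
h, w = 2, 3
t = h * W + w
assert np.array_equal(tokens[:, t, :], x[:, :, h, w])

# 3. attention would operate on `tokens` here
# 4. undo the transpose and reshape back to [B, C, H, W]
restored = tokens.transpose(0, 2, 1).reshape(B, C, H, W)
print(tokens.shape, np.array_equal(restored, x))  # (2, 20, 3) True
```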
Spatial Self-Attention Implementation

spatial_attention.py

  • Import modules: We use PyTorch's nn module for building layers and torch.nn.functional for the softmax. All tensor reshaping uses built-in view and transpose calls.
  • SpatialSelfAttention class: Self-attention adapted for 2D images. We treat each spatial location (h, w) as a token and compute attention across all locations.
  • QKV projections: Three separate 1x1 convolutions project the input to queries, keys, and values. A 1x1 convolution is a linear layer applied independently at every position, so it works directly on the [B, C, H, W] layout without any reshaping.
  • Output projection: After attention, a 1x1 convolution projects back to the original channel dimension. This allows the attention block to be used as a residual module.
  • Normalization: GroupNorm before attention stabilizes training. This is the pre-norm pattern common in modern Transformers.
  • Compute QKV: Apply the projections to get queries, keys, and values. Each has shape [B, C, H, W].
  • Flatten spatial dimensions: Reshape from [B, C, H, W] to [B, C, H*W]. Now each of the H*W positions is a token with a C-dimensional embedding.
  • Attention scores: Compute QK^T / sqrt(C) to get an [H*W, H*W] weight matrix, then apply softmax over the keys. Because the tensors are stored channels-first, the code computes the scores as q.transpose(1, 2) @ k. The scaling prevents softmax saturation for large channel counts. Per entry: attn[i, j] = softmax_j(q_i · k_j / sqrt(C)).
  • Apply attention: Multiply the values by the attention weights, output = softmax(QK^T / sqrt(C)) V. Each position aggregates information from all other positions.
  • Reshape and residual: Reshape back to [B, C, H, W], apply the output projection, and add the residual connection. The attention output is added to the input, not substituted for it.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SpatialSelfAttention(nn.Module):
    """
    Self-attention for 2D feature maps.

    Each spatial position attends to all other positions,
    enabling global context aggregation.
    """

    def __init__(self, channels: int, num_groups: int = 32):
        super().__init__()

        # QKV projections using 1x1 convolutions
        self.q_proj = nn.Conv2d(channels, channels, 1)
        self.k_proj = nn.Conv2d(channels, channels, 1)
        self.v_proj = nn.Conv2d(channels, channels, 1)

        # Output projection
        self.out_proj = nn.Conv2d(channels, channels, 1)

        # Pre-norm (channels must be divisible by num_groups)
        self.norm = nn.GroupNorm(num_groups, channels)

        self.scale = channels ** -0.5

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, C, H, W = x.shape

        # Normalize and compute QKV
        h = self.norm(x)
        q = self.q_proj(h)
        k = self.k_proj(h)
        v = self.v_proj(h)

        # Flatten spatial dimensions: [B, C, H, W] -> [B, C, H*W]
        q = q.view(B, C, H * W)
        k = k.view(B, C, H * W)
        v = v.view(B, C, H * W)

        # Attention scores: [B, H*W, C] @ [B, C, H*W] -> [B, H*W, H*W]
        attn = torch.bmm(q.transpose(1, 2), k) * self.scale
        attn = F.softmax(attn, dim=-1)  # normalize over keys

        # Apply attention to values: [B, C, H*W] @ [B, H*W, H*W] -> [B, C, H*W]
        out = torch.bmm(v, attn.transpose(1, 2))

        # Reshape back to spatial
        out = out.view(B, C, H, W)

        # Output projection and residual
        out = self.out_proj(out)
        return x + out
```
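The forward pass stores tokens channels-first ([B, C, H*W]), which is why it computes the scores as `q.transpose(1, 2) @ k` and the output as `torch.bmm(v, attn.transpose(1, 2))`. The NumPy sketch below uses a single sample, random tensors, and no projections, all simplifying assumptions. It checks that this ordering gives the same result as the textbook softmax(QK^T / sqrt(d)) V on tokens-first matrices.

```python
import numpy as np

rng = np.random.default_rng(0)
C, N = 8, 16                     # channels, tokens (H*W)
q = rng.standard_normal((C, N))  # channels-first, as after .view(B, C, H*W)
k = rng.standard_normal((C, N))
v = rng.standard_normal((C, N))

def softmax(x):
    x = x - x.max(axis=-1, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=-1, keepdims=True)

# Channels-first formulation, as in SpatialSelfAttention.forward
attn = softmax(q.T @ k / np.sqrt(C))  # [N, N], rows index queries
out_channels_first = v @ attn.T       # [C, N], like torch.bmm(v, attn.transpose(1, 2))

# Textbook tokens-first formulation: softmax(Q K^T / sqrt(d)) V
Q, K, V = q.T, k.T, v.T               # each [N, C]
out_tokens_first = softmax(Q @ K.T / np.sqrt(C)) @ V  # [N, C]

print(np.allclose(out_channels_first, out_tokens_first.T))  # True
```

In PyTorch 2.0 and later, `torch.nn.functional.scaled_dot_product_attention` implements the tokens-first version on [..., N, d] tensors and can dispatch to memory-efficient kernels.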

Multi-Head Attention

Multi-head attention splits the channels into multiple "heads", each computing attention independently:

$$\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \ldots, \text{head}_h) W^O$$

where each head is:
$$\text{head}_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)$$
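In practice the per-head projections $W_i^Q, W_i^K, W_i^V$ are usually stacked into one projection, so selecting head i amounts to slicing its block of channels. The NumPy sketch below assumes Q, K, V are already projected and omits $W^O$. It splits the channels into heads, runs attention per head, concatenates the results, and checks that one head reduces to plain attention.

```python
import numpy as np

def softmax(x):
    x = x - x.max(axis=-1, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=-1, keepdims=True)

def multi_head_attention(Q, K, V, num_heads):
    """Q, K, V: [N, C] token matrices that are already projected.

    Splits C into num_heads chunks of size C // num_heads, runs scaled
    dot-product attention per head, and concatenates the results.
    The final W^O projection is omitted in this sketch.
    """
    N, C = Q.shape
    assert C % num_heads == 0, "C must be divisible by num_heads"
    d = C // num_heads
    heads = []
    for i in range(num_heads):
        sl = slice(i * d, (i + 1) * d)  # channels owned by head i
        w = softmax(Q[:, sl] @ K[:, sl].T / np.sqrt(d))
        heads.append(w @ V[:, sl])      # [N, d]
    return np.concatenate(heads, axis=-1)  # [N, C]

rng = np.random.default_rng(0)
Q, K, V = (rng.standard_normal((10, 16)) for _ in range(3))

single = softmax(Q @ K.T / np.sqrt(16)) @ V
print(np.allclose(multi_head_attention(Q, K, V, num_heads=1), single))  # True
print(multi_head_attention(Q, K, V, num_heads=4).shape)                 # (10, 16)
```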

Multi-Head Self-Attention

multihead_attention.py

  • Multi-head attention: Split channels into multiple heads, each performing independent attention. This allows the model to attend to different types of patterns simultaneously.
  • Head dimension: Each head operates on channels / num_heads dimensions. For 256 channels with 8 heads, each head has 32 dimensions.
  • Reshape to heads: A single 1x1 convolution produces Q, K, and V together as [B, 3*C, H, W], which is viewed as [B, 3, num_heads, head_dim, H*W]. Each head then processes its own subset of channels.
  • Per-head attention: Attention is computed for each head separately. Different heads can learn different attention patterns (e.g., local vs global).
  • Concatenate heads: Merge the outputs from all heads back to [B, C, H, W]. The output projection then mixes information across heads.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiHeadSpatialAttention(nn.Module):
    """
    Multi-head self-attention for 2D feature maps.

    Multiple heads allow attending to different patterns simultaneously.
    """

    def __init__(
        self,
        channels: int,
        num_heads: int = 8,
        num_groups: int = 32,
    ):
        super().__init__()

        assert channels % num_heads == 0, "channels must be divisible by num_heads"
        self.num_heads = num_heads
        self.head_dim = channels // num_heads

        # QKV projection (all heads at once)
        self.qkv_proj = nn.Conv2d(channels, channels * 3, 1)
        self.out_proj = nn.Conv2d(channels, channels, 1)
        self.norm = nn.GroupNorm(num_groups, channels)

        self.scale = self.head_dim ** -0.5

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, C, H, W = x.shape

        # Normalize and compute QKV
        h = self.norm(x)
        qkv = self.qkv_proj(h)  # [B, 3*C, H, W]

        # Split into Q, K, V
        qkv = qkv.view(B, 3, self.num_heads, self.head_dim, H * W)
        q, k, v = qkv[:, 0], qkv[:, 1], qkv[:, 2]
        # Each: [B, num_heads, head_dim, H*W]

        # Attention per head: [B, num_heads, H*W, H*W]
        attn = torch.einsum("bhdn,bhdm->bhnm", q, k) * self.scale
        attn = F.softmax(attn, dim=-1)

        # Apply attention: [B, num_heads, head_dim, H*W]
        out = torch.einsum("bhnm,bhdm->bhdn", attn, v)

        # Merge heads: [B, num_heads, head_dim, H*W] -> [B, C, H, W]
        out = out.reshape(B, C, H, W)
        out = self.out_proj(out)

        return x + out
```
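The `view(B, 3, self.num_heads, self.head_dim, H * W)` call fixes how the 3*C output channels of `qkv_proj` are interpreted. The first C channels are queries, the next C keys, and the last C values. Within each block, every head owns a contiguous run of head_dim channels. Since the projection is learned, any fixed assignment works as long as it is used consistently. The NumPy sketch below labels each channel with its own index to show where it ends up. The small sizes are arbitrary choices.

```python
import numpy as np

B, C, num_heads, N = 1, 8, 2, 3  # batch, channels, heads, tokens (H*W)
head_dim = C // num_heads

# Stand-in for the qkv_proj output flattened to [B, 3*C, N]; every entry
# holds its own channel index so we can see where each channel lands.
qkv = np.repeat(np.arange(3 * C)[None, :, None], N, axis=2)

split = qkv.reshape(B, 3, num_heads, head_dim, N)  # mirrors the .view(...) call
q, k, v = split[:, 0], split[:, 1], split[:, 2]    # each [B, num_heads, head_dim, N]

for name, t in (("q", q), ("k", k), ("v", v)):
    print(name, t[0, :, :, 0].tolist())
# q [[0, 1, 2, 3], [4, 5, 6, 7]]
# k [[8, 9, 10, 11], [12, 13, 14, 15]]
# v [[16, 17, 18, 19], [20, 21, 22, 23]]
```

The final `out.reshape(B, C, H, W)` undoes the same grouping, so head i's output occupies channels i*head_dim through (i+1)*head_dim - 1 before `out_proj` mixes them.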

Number of Heads

Common choices: 4, 8, or 16 heads. More heads give more diverse attention patterns, but each head has fewer dimensions. A common heuristic is to keep head_dim at least 32, and often 64, for sufficient capacity.
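The heuristic can be turned into a quick check. The helpers below are hypothetical. The first lists head counts that divide the channel width and keep head_dim at or above a minimum. The second shows an alternative convention used in some codebases: fix head_dim and derive the number of heads from the channel width.

```python
def valid_head_counts(channels, min_head_dim=32):
    """Head counts that divide `channels` evenly and keep head_dim >= min_head_dim."""
    return [h for h in range(1, channels + 1)
            if channels % h == 0 and channels // h >= min_head_dim]


def heads_for_fixed_head_dim(channels, head_dim=64):
    """Alternative convention: fix head_dim and derive the number of heads."""
    assert channels % head_dim == 0, "channels must be a multiple of head_dim"
    return channels // head_dim


for channels in (128, 256, 512):
    print(channels, valid_head_counts(channels), heads_for_fixed_head_dim(channels))
```

For example, 8 heads on a 128-channel level gives head_dim 16, below the heuristic. A fixed head_dim instead scales the head count with the channel width.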

Where to Place Attention

Due to the $O((HW)^2)$ cost, we cannot use attention at every layer. The standard strategy:

| Resolution | Tokens (H·W) | Attention-matrix entries ((H·W)², per head) | Use Attention? |
|---|---|---|---|
| 256x256 | 65,536 | ~4.3 billion | No - too expensive |
| 128x128 | 16,384 | ~268 million | Optional |
| 64x64 | 4,096 | ~16.8 million | Optional |
| 32x32 | 1,024 | ~1 million | Yes - affordable |
| 16x16 | 256 | 65,536 | Yes - cheap and critical |

The common pattern in diffusion U-Nets:

  • Bottleneck (lowest resolution): Always use attention here for global reasoning
  • 16x16 and 32x32 levels: Use attention for semantic coherence
  • 64x64 and above: Skip attention, rely on convolutions for local patterns
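The placement rules above can be sketched as a small planning helper. The sketch assumes that each encoder level halves the resolution and that the bottleneck runs at the lowest level's resolution. The helper name and its `attn_resolutions` parameter are hypothetical.

```python
def attention_plan(image_size, num_levels, attn_resolutions=(16, 32)):
    """Decide where to place attention in a U-Net encoder (hypothetical helper).

    Assumes level i runs at image_size // 2**i and the bottleneck runs at the
    lowest level's resolution. The bottleneck always gets attention; other
    levels get attention only if their resolution is in `attn_resolutions`.
    """
    plan = []
    for i in range(num_levels):
        res = image_size // 2 ** i
        plan.append((f"level {i}", res, res in attn_resolutions))
    lowest = image_size // 2 ** (num_levels - 1)
    plan.append(("bottleneck", lowest, True))
    return plan


for name, res, use_attn in attention_plan(256, num_levels=5):
    print(f"{name:>10}: {res}x{res}  attention={use_attn}")
```

Passing `attn_resolutions=(16,)` gives a sparser layout with attention only at 16x16 and in the bottleneck.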

Architecture Variants

Some models (like Stable Diffusion) use attention at several resolutions. Others (like the original DDPM) use self-attention only at the 16x16 level and in the bottleneck. The trade-off is quality vs speed vs memory.

Complete Attention Block

Here's a production-ready attention block that can be inserted into U-Net:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class AttentionBlock(nn.Module):
    """
    Complete attention block for diffusion U-Net.

    Includes:
    - Pre-normalization (GroupNorm)
    - Multi-head self-attention
    - Residual connection
    """

    def __init__(
        self,
        channels: int,
        num_heads: int = 8,
        num_groups: int = 32,
    ):
        super().__init__()

        assert channels % num_heads == 0, "channels must be divisible by num_heads"
        self.num_heads = num_heads
        self.head_dim = channels // num_heads

        # Pre-norm
        self.norm = nn.GroupNorm(num_groups, channels)

        # QKV and output projections
        self.qkv = nn.Conv2d(channels, channels * 3, 1, bias=False)
        self.proj = nn.Conv2d(channels, channels, 1)

        # Zero-initialize the output projection (weight and bias) so the
        # block is an exact identity at initialization
        nn.init.zeros_(self.proj.weight)
        nn.init.zeros_(self.proj.bias)

        self.scale = self.head_dim ** -0.5

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, C, H, W = x.shape
        n_tokens = H * W

        # Pre-norm
        h = self.norm(x)

        # QKV projection: [B, 3*C, H, W]
        qkv = self.qkv(h)

        # Reshape for multi-head attention
        # [B, 3*C, H, W] -> [B, 3, num_heads, head_dim, H*W]
        qkv = qkv.reshape(B, 3, self.num_heads, self.head_dim, n_tokens)

        # Split Q, K, V
        q = qkv[:, 0]  # [B, heads, dim, tokens]
        k = qkv[:, 1]
        v = qkv[:, 2]

        # Compute attention: [B, heads, tokens, tokens]
        attn = torch.einsum("bhdn,bhdm->bhnm", q, k) * self.scale
        attn = F.softmax(attn, dim=-1)

        # Apply to values: [B, heads, dim, tokens]
        out = torch.einsum("bhnm,bhdm->bhdn", attn, v)

        # Reshape: [B, C, H, W]
        out = out.reshape(B, C, H, W)

        # Output projection
        out = self.proj(out)

        # Residual
        return x + out


# Usage in U-Net
# ResBlock is the time-conditioned residual block built earlier in this chapter;
# the signature ResBlock(in_channels, out_channels, time_emb_dim) is assumed here.
class UNetBlockWithAttention(nn.Module):
    def __init__(self, channels: int, time_emb_dim: int, use_attention: bool = True):
        super().__init__()

        self.res_block = ResBlock(channels, channels, time_emb_dim)

        if use_attention:
            self.attn = AttentionBlock(channels)
        else:
            self.attn = nn.Identity()

    def forward(self, x: torch.Tensor, time_emb: torch.Tensor) -> torch.Tensor:
        x = self.res_block(x, time_emb)
        x = self.attn(x)
        return x
```

Zero Initialization

Notice we initialize the output projection to zeros. This makes the attention block act like identity at initialization, which helps training stability. The network can gradually learn to use attention as needed.
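The identity-at-initialization argument can be checked with a toy version of the block. The NumPy sketch below omits GroupNorm and the multi-head split, both simplifying assumptions. It zeroes the weight and bias of the output projection; PyTorch's `nn.Conv2d` has a bias by default, so an exact identity requires zeroing both. The sketch also shows that the gradient reaching the zeroed projection is non-zero, which is why training can move it away from zero.

```python
import numpy as np

rng = np.random.default_rng(0)
C, N = 8, 16                     # channels, tokens (H*W)
x = rng.standard_normal((C, N))  # one feature map, channels-first

# Random (non-zero) QKV projection, zero-initialized output projection
W_qkv = rng.standard_normal((3 * C, C)) * 0.1
W_proj = np.zeros((C, C))
b_proj = np.zeros((C, 1))

def attention_features(x):
    q, k, v = np.split(W_qkv @ x, 3, axis=0)  # each [C, N]
    scores = q.T @ k / np.sqrt(C)
    scores -= scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)
    return v @ weights.T                      # [C, N]

a = attention_features(x)
y = x + W_proj @ a + b_proj  # residual block output

# For a loss L = sum(g * y), dL/dW_proj = g @ a.T, non-zero even though W_proj is zero
g = rng.standard_normal((C, N))
grad_W_proj = g @ a.T

print(np.array_equal(y, x), np.abs(grad_W_proj).max() > 0)  # True True
```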

Summary

In this section, we added attention to our diffusion U-Net:

  1. Why attention: Enables global context that convolutions cannot efficiently capture
  2. Spatial self-attention: Treat each pixel as a token, compute attention over all positions
  3. Multi-head attention: Multiple attention patterns learned simultaneously
  4. Strategic placement: Use attention at low resolutions (16x16, 32x32) to manage computational cost
  5. Implementation details: Pre-normalization, residual connections, zero initialization

Coming Up Next

In the final section of this chapter, we'll assemble all the pieces—ResBlocks, time conditioning, downsampling, upsampling, skip connections, and attention—into a complete, production-ready U-Net for diffusion models.