Learning Objectives
By the end of this section, you will:
- Understand why attention is crucial for high-quality diffusion models
- Adapt self-attention for images by treating spatial positions as tokens
- Implement multi-head attention for diverse attention patterns
- Know where to place attention in the U-Net architecture
- Balance quality vs computation by strategic attention placement
Why Attention in U-Net?
Convolutions are local operations: each output pixel depends only on a small neighborhood of input pixels. To understand an entire image, information must propagate through many layers, which is slow and can lead to information loss.
Self-attention is a global operation: each output pixel can directly attend to every input pixel in a single layer. This enables:
- Long-range dependencies: A face can "see" the background to ensure color consistency
- Global coherence: All parts of an object can coordinate their appearance
- Semantic reasoning: The network can relate semantically similar regions
| Property | Convolution | Self-Attention |
|---|---|---|
| Receptive field | Local (kernel size) | Global (entire image) |
| Computation | O(H * W * k^2 * C^2) | O((H*W)^2 * C) |
| Memory | O(H * W * C) | O((H*W)^2) |
| Translation equivariance | Yes (built-in) | No (learned) |
| Parameter sharing | Across positions | Across positions |
Computational Cost
Because the attention score matrix has one entry per pair of spatial positions, its cost grows quadratically with the number of tokens. This quadratic term is why attention placement in the network matters so much, as we discuss below.
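A back-of-envelope sketch of the quadratic growth (the helper name is illustrative, not a standard API):

```python
# The attention score matrix has (H*W)^2 entries, so memory and compute
# grow quadratically in the token count -- quartically in the resolution.
def attention_matrix_entries(h: int, w: int) -> int:
    n_tokens = h * w
    return n_tokens ** 2

for res in (16, 32, 64, 128, 256):
    print(f"{res}x{res}: {attention_matrix_entries(res, res):,} entries")
```

Doubling the resolution multiplies the score-matrix size by 16, which is why attention at 256x256 is impractical while attention at 16x16 is nearly free.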
Self-Attention Review
Self-attention from Transformers computes weighted combinations of values based on query-key similarity:

$$\text{Attention}(Q, K, V) = \text{softmax}\!\left(\frac{QK^\top}{\sqrt{d_k}}\right)V$$

where:
- Q (Queries): What am I looking for?
- K (Keys): What do I contain?
- V (Values): What information do I have?
- d_k: Key dimension (for scaling)
The softmax creates a probability distribution over positions, and the output is a weighted sum of values based on these probabilities.
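The formula above can be sketched in a few lines of PyTorch (a minimal single-head version; shapes are illustrative):

```python
import torch
import torch.nn.functional as F

def scaled_dot_product_attention(q, k, v):
    # q, k, v: [batch, tokens, dim]
    d_k = q.shape[-1]
    scores = q @ k.transpose(-2, -1) / d_k ** 0.5  # [batch, tokens, tokens]
    weights = F.softmax(scores, dim=-1)            # each row sums to 1
    return weights @ v                             # weighted sum of values

# e.g. a 32x32 feature map with 64 channels -> 1024 tokens of dimension 64
q = k = v = torch.randn(2, 1024, 64)
out = scaled_dot_product_attention(q, k, v)
print(out.shape)  # torch.Size([2, 1024, 64])
```

Each output token is a convex combination of all value tokens, which is exactly what gives attention its global receptive field.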
Intuition
Think of attention as a soft dictionary lookup: each position issues a query, compares it against every key, and retrieves a blend of the values weighted by how well they match.
Spatial Self-Attention for Images
To apply self-attention to images, we treat each spatial position as a token with a $C$-dimensional embedding:
- Flatten: Reshape from $[B, C, H, W]$ to $[B, C, H \cdot W]$
- Transpose: Get $[B, H \cdot W, C]$ for standard attention format
- Attention: Compute attention over the $H \cdot W$ tokens
- Reshape: Back to $[B, C, H, W]$
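The four steps above can be sketched as a lossless round trip (shapes are illustrative):

```python
import torch

B, C, H, W = 2, 64, 32, 32
x = torch.randn(B, C, H, W)

# 1. Flatten spatial dims: [B, C, H, W] -> [B, C, H*W]
tokens = x.reshape(B, C, H * W)
# 2. Transpose to [B, H*W, C]: one C-dimensional token per spatial position
tokens = tokens.transpose(1, 2)
# 3. (attention over the H*W tokens would go here)
# 4. Transpose back and restore the spatial grid
y = tokens.transpose(1, 2).reshape(B, C, H, W)

assert torch.equal(x, y)  # the reshape round trip changes nothing
```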
Multi-Head Attention
Multi-head attention splits the channels into multiple "heads", each computing attention independently:

$$\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \ldots, \text{head}_h)\,W^O$$

where each head is:

$$\text{head}_i = \text{Attention}(QW_i^Q,\ KW_i^K,\ VW_i^V)$$
Number of Heads
The channel count must be divisible by the number of heads; diffusion U-Nets commonly use 4-8 heads, giving head dimensions of roughly 32-64.
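A minimal sketch of how channels are split into heads and merged back (shapes are illustrative):

```python
import torch

B, n_tokens, C, num_heads = 2, 256, 128, 8
head_dim = C // num_heads  # channels must divide evenly: 128 / 8 = 16

x = torch.randn(B, n_tokens, C)
# Split channels into heads: [B, tokens, C] -> [B, heads, tokens, head_dim]
heads = x.reshape(B, n_tokens, num_heads, head_dim).transpose(1, 2)
print(heads.shape)  # torch.Size([2, 8, 256, 16])

# After per-head attention, merge the heads back into C channels
merged = heads.transpose(1, 2).reshape(B, n_tokens, C)
assert torch.equal(x, merged)
```

Each head attends over the full set of tokens but only sees its own slice of the channels, letting different heads specialize in different relationships.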
Where to Place Attention
Due to the cost, we cannot use attention at every layer. The standard strategy:
| Resolution | Tokens (H·W) | Attention Cost (≈ tokens²) | Use Attention? |
|---|---|---|---|
| 256x256 | 65,536 | 4.3 billion ops | No - too expensive |
| 128x128 | 16,384 | 268 million ops | Optional |
| 64x64 | 4,096 | 16.8 million ops | Optional |
| 32x32 | 1,024 | 1 million ops | Yes - affordable |
| 16x16 | 256 | 65,536 ops | Yes - cheap and critical |
The common pattern in diffusion U-Nets:
- Bottleneck (lowest resolution): Always use attention here for global reasoning
- 16x16 and 32x32 levels: Use attention for semantic coherence
- 64x64 and above: Skip attention, rely on convolutions for local patterns
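This placement rule can be captured in a small helper (a sketch; the function name and default resolution list are illustrative, not a standard API):

```python
def use_attention_at(resolution: int, attention_resolutions=(16, 32)) -> bool:
    # Attend only where the token count keeps the quadratic cost affordable
    return resolution in attention_resolutions

# e.g., deciding per level while building the encoder path of a U-Net:
for resolution in (256, 128, 64, 32, 16):
    print(resolution, use_attention_at(resolution))
```

Frameworks typically expose this as a list of "attention resolutions" in the model config, so the trade-off can be tuned without touching the architecture code.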
Architecture Variants
Published models differ on exactly where attention goes: the original DDPM used self-attention only at the 16x16 level, while later architectures such as ADM also add it at 32x32 and 8x8.
Complete Attention Block
Here's a production-ready attention block that can be inserted into U-Net:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class AttentionBlock(nn.Module):
    """
    Complete attention block for diffusion U-Net.

    Includes:
    - Pre-normalization (GroupNorm)
    - Multi-head self-attention
    - Residual connection
    """

    def __init__(
        self,
        channels: int,
        num_heads: int = 8,
        num_groups: int = 32,
    ):
        super().__init__()

        self.num_heads = num_heads
        self.head_dim = channels // num_heads

        # Pre-norm
        self.norm = nn.GroupNorm(num_groups, channels)

        # QKV and output projections
        self.qkv = nn.Conv2d(channels, channels * 3, 1, bias=False)
        self.proj = nn.Conv2d(channels, channels, 1)

        # Initialize output projection to zero for stable training
        nn.init.zeros_(self.proj.weight)

        self.scale = self.head_dim ** -0.5

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, C, H, W = x.shape
        n_tokens = H * W

        # Pre-norm
        h = self.norm(x)

        # QKV projection: [B, 3*C, H, W]
        qkv = self.qkv(h)

        # Reshape for multi-head attention
        # [B, 3*C, H, W] -> [B, 3, num_heads, head_dim, H*W]
        qkv = qkv.reshape(B, 3, self.num_heads, self.head_dim, n_tokens)

        # Split Q, K, V
        q = qkv[:, 0]  # [B, heads, dim, tokens]
        k = qkv[:, 1]
        v = qkv[:, 2]

        # Compute attention: [B, heads, tokens, tokens]
        attn = torch.einsum("bhdn,bhdm->bhnm", q, k) * self.scale
        attn = F.softmax(attn, dim=-1)

        # Apply to values: [B, heads, dim, tokens]
        out = torch.einsum("bhnm,bhdm->bhdn", attn, v)

        # Reshape: [B, C, H, W]
        out = out.reshape(B, C, H, W)

        # Output projection
        out = self.proj(out)

        # Residual
        return x + out


# Usage in U-Net (ResBlock is defined in the previous section)
class UNetBlockWithAttention(nn.Module):
    def __init__(self, channels: int, time_emb_dim: int, use_attention: bool = True):
        super().__init__()

        self.res_block = ResBlock(channels, channels, time_emb_dim)

        if use_attention:
            self.attn = AttentionBlock(channels)
        else:
            self.attn = nn.Identity()

    def forward(self, x: torch.Tensor, time_emb: torch.Tensor) -> torch.Tensor:
        x = self.res_block(x, time_emb)
        x = self.attn(x)
        return x
```

Zero Initialization
With the output projection zero-initialized, the attention block is an identity mapping at the start of training: the residual branch contributes nothing, so inserting the block never disturbs the rest of the network before it has learned useful attention patterns.
Summary
In this section, we added attention to our diffusion U-Net:
- Why attention: Enables global context that convolutions cannot efficiently capture
- Spatial self-attention: Treat each pixel as a token, compute attention over all positions
- Multi-head attention: Multiple attention patterns learned simultaneously
- Strategic placement: Use attention at low resolutions (16x16, 32x32) to manage computational cost
- Implementation details: Pre-normalization, residual connections, zero initialization