Learning Objectives
By the end of this section, you will be able to:
- Explain why cross-attention is essential for text-to-image generation
- Derive the cross-attention formula and understand the Q/K/V asymmetry
- Interpret attention maps showing word-region alignment
- Implement cross-attention in PyTorch from scratch
- Integrate cross-attention into U-Net transformer blocks
Why Cross-Attention?
In the previous section, we saw how text is encoded into embeddings. But how do these embeddings actually guide image generation? The answer is cross-attention: a mechanism that lets each part of the image "look at" the prompt and selectively attend to the words most relevant to it.
The Core Idea: Cross-attention creates a dynamic, learned mapping between spatial image positions and text tokens. A pixel generating a "red car" should attend strongly to the words "red" and "car" in the prompt.
Why Not Just Concatenate?
A simpler approach might be to concatenate text embeddings to image features. However, this has limitations:
| Approach | Pros | Cons |
|---|---|---|
| Concatenation | Simple, fast | No spatial alignment, fixed relationship |
| Global conditioning (AdaGN) | Efficient, works well | Only global style, no word-region mapping |
| Cross-attention | Spatial alignment, interpretable | More compute, O(H·W·L) attention matrix |
Cross-attention provides spatial alignment - different image regions can attend to different words. When generating "a cat on the left and a dog on the right," left pixels attend to "cat" and right pixels attend to "dog."
Cross-Attention Mathematics
Cross-attention is a variant of the standard attention mechanism where queries come from one sequence and keys/values come from another.
Standard Attention Review
Recall the standard attention formula:

$$\text{Attention}(Q, K, V) = \text{softmax}\!\left(\frac{QK^\top}{\sqrt{d_k}}\right)V$$
In self-attention, Q, K, and V all come from the same sequence. In cross-attention, they come from different sources:
Cross-Attention Setup
- Queries (Q): From image features $X \in \mathbb{R}^{(H \cdot W) \times d}$, where $H \cdot W$ is the number of spatial positions
- Keys (K): From text embeddings $C \in \mathbb{R}^{L \times d_{\text{text}}}$, where $L$ is the text sequence length
- Values (V): Also from the text embeddings $C$

With learned projection matrices $W_Q$, $W_K$, $W_V$:

$$Q = X W_Q, \quad K = C W_K, \quad V = C W_V$$

The attention output is:

$$\text{Attention}(Q, K, V) = \text{softmax}\!\left(\frac{QK^\top}{\sqrt{d_k}}\right)V$$
Dimensionality

With projection dimension $d_k$, the shapes are $Q \in \mathbb{R}^{(H \cdot W) \times d_k}$, $K \in \mathbb{R}^{L \times d_k}$, and $V \in \mathbb{R}^{L \times d_k}$. The attention matrix $QK^\top$ therefore has shape $(H \cdot W) \times L$: one weight for every (image position, text token) pair. The output has shape $(H \cdot W) \times d_k$, one vector per image position, so it can be added back to the image features.
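The formulas above can be sketched directly with plain tensors. The dimensions below are illustrative (loosely Stable Diffusion-like); the projection matrices are random stand-ins for learned weights:

```python
import torch

# Illustrative dimensions
HW, L = 32 * 32, 77           # image positions, text tokens
d, d_text, d_k = 320, 768, 320

X = torch.randn(HW, d)        # image features (source of queries)
C = torch.randn(L, d_text)    # text embeddings (source of keys/values)

# Random stand-ins for the learned projections W_Q, W_K, W_V
W_Q = torch.randn(d, d_k)
W_K = torch.randn(d_text, d_k)
W_V = torch.randn(d_text, d_k)

Q, K, V = X @ W_Q, C @ W_K, C @ W_V

# Attention weights: one row per image position, one column per text token
weights = torch.softmax(Q @ K.T / d_k**0.5, dim=-1)   # [HW, L]
out = weights @ V                                     # [HW, d_k]
print(weights.shape, out.shape)
```

Each row of `weights` sums to 1: every image position distributes its attention over the text tokens.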
Q from Image, K/V from Text
The asymmetry in cross-attention is crucial to understand:
Why Queries from Image?
Each image position "asks a question" by generating a query vector. The question is essentially: "What text information is relevant to me?"
- A position generating a dog asks: "Which words describe me?" and attends to "fluffy," "golden retriever"
- A position generating sky asks: "What should I look like?" and attends to "sunset," "orange," "clouds"
Why Keys/Values from Text?
Text tokens provide keys for matching (what they represent) and values for content (what information they carry):
- Keys: Enable relevance computation. "Car" has a key that matches well with queries from car-like regions
- Values: Carry semantic information to inject. The value for "red" carries color information
Intuition: Image positions are asking questions (queries), text tokens are offering answers (values), and the matching is determined by key-query similarity.
Understanding Attention Maps
The attention weights $\text{softmax}(QK^\top/\sqrt{d_k})$ form an $(H \cdot W) \times L$ matrix of interpretable maps showing which image regions attend to which words:
Interpreting Attention Maps
For each text token, we can visualize its attention pattern as a heatmap over the image:
- Object words ("cat," "car") typically show localized attention on the corresponding object region
- Attribute words ("red," "fluffy") attend to regions where that attribute applies
- Spatial words ("left," "background") show attention in the corresponding image areas
- Style words ("impressionist," "photorealistic") often show diffuse, global attention
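Extracting such a heatmap is a one-line slice once the attention weights are available. A minimal sketch, assuming weights of shape `[H*W, L]` as computed earlier (the function name and normalization are illustrative choices):

```python
import torch

def attention_heatmap(weights: torch.Tensor, token_idx: int, h: int, w: int) -> torch.Tensor:
    """Extract one token's attention over image positions as an h x w map.

    weights: attention weights of shape [h*w, L], rows summing to 1 over tokens.
    """
    heatmap = weights[:, token_idx].reshape(h, w)
    # Normalize to [0, 1] for display
    heatmap = (heatmap - heatmap.min()) / (heatmap.max() - heatmap.min() + 1e-8)
    return heatmap

# Example with random weights in place of a real model's attention
weights = torch.softmax(torch.randn(32 * 32, 77), dim=-1)
hm = attention_heatmap(weights, token_idx=5, h=32, w=32)
print(hm.shape)  # torch.Size([32, 32])
```

In practice the weights would be captured from a cross-attention layer during generation (e.g., via a forward hook) and the heatmap overlaid on the decoded image.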
Attention Map Applications
| Application | How It Uses Attention Maps |
|---|---|
| Prompt debugging | Visualize which words affect which regions |
| Attention editing | Modify maps to change object placement |
| Prompt weighting | Strengthen/weaken specific word attention |
| Layout control | Force specific words to attend to specific regions |
PyTorch Implementation
Let's implement cross-attention from scratch:
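A minimal from-scratch sketch follows. The constructor interface (`query_dim`, `context_dim`, `n_heads`, `head_dim`) matches the usage example below; the bias-free projections and final output projection are standard design choices, assumed here rather than prescribed by the text:

```python
import torch
import torch.nn as nn

class CrossAttention(nn.Module):
    """Multi-head cross-attention: queries from image features, keys/values from text."""

    def __init__(self, query_dim: int, context_dim: int, n_heads: int = 8, head_dim: int = 64):
        super().__init__()
        inner_dim = n_heads * head_dim
        self.n_heads = n_heads
        self.scale = head_dim ** -0.5

        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
        self.to_out = nn.Linear(inner_dim, query_dim)

    def forward(self, x: torch.Tensor, context: torch.Tensor) -> torch.Tensor:
        """
        x:       image features [B, H*W, query_dim]
        context: text embeddings [B, L, context_dim]
        """
        B, N, _ = x.shape
        L = context.shape[1]
        h = self.n_heads

        # Project, then split into heads: [B, n_heads, tokens, head_dim]
        q = self.to_q(x).view(B, N, h, -1).transpose(1, 2)
        k = self.to_k(context).view(B, L, h, -1).transpose(1, 2)
        v = self.to_v(context).view(B, L, h, -1).transpose(1, 2)

        # Attention weights over text tokens: [B, n_heads, H*W, L]
        attn = torch.softmax(q @ k.transpose(-2, -1) * self.scale, dim=-1)

        # Weighted sum of values, then merge heads back: [B, H*W, inner_dim]
        out = (attn @ v).transpose(1, 2).reshape(B, N, -1)
        return self.to_out(out)
```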
Usage Example
```python
# Example dimensions (Stable Diffusion-like)
batch_size = 2
height, width = 64, 64   # Latent spatial size
query_dim = 320          # U-Net channel dimension
context_dim = 768        # CLIP text embedding dimension
seq_len = 77             # CLIP max sequence length

# Create module
cross_attn = CrossAttention(
    query_dim=query_dim,
    context_dim=context_dim,
    n_heads=8,
    head_dim=40,
)

# Image features (flattened spatial)
x = torch.randn(batch_size, height * width, query_dim)

# Text embeddings
context = torch.randn(batch_size, seq_len, context_dim)

# Apply cross-attention
out = cross_attn(x, context)
print(f"Input shape: {x.shape}")
print(f"Output shape: {out.shape}")
# Input shape: torch.Size([2, 4096, 320])
# Output shape: torch.Size([2, 4096, 320])
```

Multi-Head Cross-Attention
Like self-attention, cross-attention benefits from multiple heads:
Why Multiple Heads?
Each attention head can learn different types of text-image relationships:
- Head 1: Object-noun alignment ("dog" -> dog pixels)
- Head 2: Attribute-region alignment ("red" -> colored regions)
- Head 3: Spatial relationships ("above" -> upper regions)
- Head 4: Style/texture patterns ("oil painting" -> all regions)
The outputs of all heads are concatenated and projected:

$$\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \ldots, \text{head}_h)\, W_O$$
Head Dimension Trade-off

More heads with a smaller head_dim allow more diverse attention patterns, but each head has less representational capacity; fewer, wider heads trade in the opposite direction. The total inner dimension is n_heads * head_dim (e.g., 8 heads of dimension 40 give 320 in the example above), so for a fixed inner dimension the compute cost is roughly constant either way.
Integration with U-Net
Cross-attention is integrated into U-Net via transformer blocks placed at specific resolution levels:
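The transformer block used in the U-Net code below can be sketched as a standard pre-norm stack of self-attention, cross-attention, and a feed-forward layer, each with a residual connection. This sketch uses PyTorch's built-in `nn.MultiheadAttention` (with `kdim`/`vdim` for the text dimension) for brevity; the exact layer ordering and feed-forward width are common conventions, not taken from a specific codebase:

```python
import torch
import torch.nn as nn

class TransformerBlock(nn.Module):
    """Pre-norm block: self-attention, then cross-attention to text, then feed-forward."""

    def __init__(self, dim: int, context_dim: int, n_heads: int = 8):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.self_attn = nn.MultiheadAttention(dim, n_heads, batch_first=True)
        self.norm2 = nn.LayerNorm(dim)
        # kdim/vdim let keys and values come from the text embedding dimension
        self.cross_attn = nn.MultiheadAttention(
            dim, n_heads, kdim=context_dim, vdim=context_dim, batch_first=True
        )
        self.norm3 = nn.LayerNorm(dim)
        self.ff = nn.Sequential(
            nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim)
        )

    def forward(self, x: torch.Tensor, context: torch.Tensor) -> torch.Tensor:
        h = self.norm1(x)
        x = x + self.self_attn(h, h, h)[0]               # image attends to itself
        h = self.norm2(x)
        x = x + self.cross_attn(h, context, context)[0]  # image attends to text
        x = x + self.ff(self.norm3(x))                   # position-wise MLP
        return x
```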
Typical U-Net Integration
```python
class TextConditionedUNet(nn.Module):
    """U-Net with cross-attention for text conditioning."""

    def __init__(self, ...):
        # ... encoder and decoder convolutions ...

        # Add transformer blocks at lower resolutions
        # where cross-attention is computationally feasible
        self.transformer_blocks = nn.ModuleDict({
            "down_32": TransformerBlock(dim=320, context_dim=768),
            "down_16": TransformerBlock(dim=640, context_dim=768),
            "down_8": TransformerBlock(dim=1280, context_dim=768),
            "mid_8": TransformerBlock(dim=1280, context_dim=768),
            "up_8": TransformerBlock(dim=1280, context_dim=768),
            "up_16": TransformerBlock(dim=640, context_dim=768),
            "up_32": TransformerBlock(dim=320, context_dim=768),
        })

    def forward(self, x, t, context):
        """
        Args:
            x: Noisy image [B, C, H, W]
            t: Timestep [B]
            context: Text embeddings [B, L, context_dim]
        """
        # ... time embedding, downsampling ...

        # At 32x32 resolution:
        h = self.down_conv_32(h)
        h = rearrange(h, 'b c h w -> b (h w) c')  # Flatten spatial
        h = self.transformer_blocks["down_32"](h, context)
        h = rearrange(h, 'b (h w) c -> b c h w', h=32, w=32)

        # ... continue through U-Net ...
```

Where to Add Cross-Attention?
| Resolution | Cross-Attention? | Reason |
|---|---|---|
| 64x64 | Optional (expensive) | 4096 positions, high memory |
| 32x32 | Yes | Good balance of detail and efficiency |
| 16x16 | Yes | Important for semantic content |
| 8x8 | Yes | Captures high-level composition |
Lower resolutions have fewer spatial positions, making cross-attention more efficient. The semantic information from text is most useful at these levels where high-level structure is determined.
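The cost argument can be made concrete by counting entries in the (H·W) × L attention matrix at each resolution (per head and per batch element; L = 77 text tokens assumed):

```python
L = 77  # text tokens (CLIP max sequence length)
costs = {}
for res in (64, 32, 16, 8):
    positions = res * res              # number of image-side queries
    costs[res] = positions * L         # entries in the (H*W) x L attention matrix
    print(f"{res}x{res}: {positions:5d} positions -> {costs[res]:7d} attention entries")
```

Going from 8x8 to 64x64 multiplies the cost by 64, which is why cross-attention at the highest resolution is often treated as optional.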
Key Takeaways
- Cross-attention enables spatial text-image alignment - each image position can attend to relevant words
- Q from image, K/V from text - image positions "ask" which text tokens are relevant
- Attention maps are interpretable - visualize which words affect which image regions
- Multi-head attention learns diverse relationships (objects, attributes, spatial, style)
- Integrated in transformer blocks with self-attention and feed-forward layers
- Applied at lower resolutions (8x8, 16x16, 32x32) for efficiency
Looking Ahead: Cross-attention requires good text embeddings to work well. In the next section, we'll explore CLIP - the contrastive model that learns to align text and image representations, making it ideal for text-to-image conditioning.