Chapter 9

Cross-Attention Mechanism

Text-to-Image Foundations

Learning Objectives

By the end of this section, you will be able to:

  1. Explain why cross-attention is essential for text-to-image generation
  2. Derive the cross-attention formula and understand the Q/K/V asymmetry
  3. Interpret attention maps showing word-region alignment
  4. Implement cross-attention in PyTorch from scratch
  5. Integrate cross-attention into U-Net transformer blocks

Why Cross-Attention?

In the previous section, we saw how text is encoded into embeddings. But how do these embeddings actually guide image generation? The answer is cross-attention: a mechanism that lets each part of the image "look at" the prompt and selectively attend to the relevant words.

The Core Idea: Cross-attention creates a dynamic, learned mapping between spatial image positions and text tokens. An image position that is part of a "red car" should attend strongly to the words "red" and "car" in the prompt.

Why Not Just Concatenate?

A simpler approach might be to concatenate text embeddings to image features. However, this has limitations:

| Approach | Pros | Cons |
|---|---|---|
| Concatenation | Simple, fast | No spatial alignment; fixed relationship |
| Global conditioning (AdaGN) | Efficient, works well | Only global style; no word-region mapping |
| Cross-attention | Spatial alignment, interpretable | More compute: O(H·W·L) |

Cross-attention provides spatial alignment: different image regions can attend to different words. When generating "a cat on the left and a dog on the right," positions on the left can attend mainly to "cat" and positions on the right mainly to "dog."


Cross-Attention Mathematics

Cross-attention is a variant of the standard attention mechanism where queries come from one sequence and keys/values come from another.

Standard Attention Review

Recall the attention formula:

$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$

In self-attention, Q, K, and V all come from the same sequence. In cross-attention, they come from different sources:

Cross-Attention Setup

  • Queries (Q): from image features $\mathbf{x} \in \mathbb{R}^{N \times d_{\text{img}}}$, where $N = H \times W$ is the number of spatial positions
  • Keys (K): from text embeddings $\mathbf{c} \in \mathbb{R}^{L \times d_{\text{text}}}$, where $L$ is the sequence length
  • Values (V): also from the text embeddings $\mathbf{c}$

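With learned projection matrices $W^Q \in \mathbb{R}^{d_{\text{img}} \times d_k}$, $W^K \in \mathbb{R}^{d_{\text{text}} \times d_k}$, and $W^V \in \mathbb{R}^{d_{\text{text}} \times d_v}$, which map both sources into a shared space even when $d_{\text{img}} \neq d_{\text{text}}$: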

$$Q = \mathbf{x} W^Q, \quad K = \mathbf{c} W^K, \quad V = \mathbf{c} W^V$$

The attention output is:

$$\text{CrossAttn}(\mathbf{x}, \mathbf{c}) = \text{softmax}\left(\frac{(\mathbf{x} W^Q)(\mathbf{c} W^K)^T}{\sqrt{d_k}}\right)(\mathbf{c} W^V)$$

Dimensionality

The attention matrix has shape $[N, L]$: each of the $N$ image positions has attention weights over the $L$ text tokens (one such matrix per head and per batch element). For a 64x64 feature map and 77 text tokens, this is a 4096 x 77 matrix.
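A minimal single-head sketch of the formula, using random features and random projection matrices (the sizes d_img = 320, d_text = 768, and d_k = 64 are illustrative assumptions), confirms the [N, L] shape of the attention matrix and that each row is a probability distribution over text tokens:

```python
import torch

torch.manual_seed(0)

H, W, L = 64, 64, 77               # 64x64 feature map, 77 text tokens
d_img, d_text, d_k = 320, 768, 64  # illustrative feature sizes

x = torch.randn(H * W, d_img)      # image features: one row per spatial position
c = torch.randn(L, d_text)         # text embeddings: one row per token

# "Learned" projections (random here): W^Q acts on image features, W^K and W^V on text
W_q = torch.randn(d_img, d_k) / d_img ** 0.5
W_k = torch.randn(d_text, d_k) / d_text ** 0.5
W_v = torch.randn(d_text, d_k) / d_text ** 0.5

Q, K, V = x @ W_q, c @ W_k, c @ W_v                  # [N, d_k], [L, d_k], [L, d_k]
attn = torch.softmax(Q @ K.T / d_k ** 0.5, dim=-1)   # [N, L], rows sum to 1
out = attn @ V                                       # [N, d_k]

print(attn.shape)  # torch.Size([4096, 77])
print(out.shape)   # torch.Size([4096, 64])
```

Note that the output has d_v = d_k columns rather than d_img; the PyTorch module later in this section adds an output projection that maps back to the image feature dimension.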

Q from Image, K/V from Text

The asymmetry in cross-attention is crucial to understand:

Why Queries from Image?

Each image position "asks a question" by generating a query vector. The question is essentially: "What text information is relevant to me?"

  • A position generating a dog asks: "Which words describe me?" and attends to "fluffy," "golden retriever"
  • A position generating sky asks: "What should I look like?" and attends to "sunset," "orange," "clouds"

Why Keys/Values from Text?

Text tokens provide keys for matching (what they represent) and values for content (what information they carry):

  • Keys: Enable relevance computation. "Car" has a key that matches well with queries from car-like regions
  • Values: Carry semantic information to inject. The value for "red" carries color information
Intuition: Image positions are asking questions (queries), text tokens are offering answers (values), and the matching is determined by key-query similarity.
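A toy illustration of the matching intuition, with hand-picked 4-dimensional keys, values, and a query (purely illustrative numbers, not learned ones): a query built to resemble the key of "car," with a little of "red," puts its largest weight on "car," followed by "red," and the output is a mix of values dominated by the "car" value.

```python
import torch

tokens = ["a", "red", "car", "sky"]

# Hand-picked keys: one direction per content word, a weak generic key for "a"
K = torch.tensor([
    [0.1, 0.1, 0.1, 0.1],   # "a"
    [1.0, 0.0, 0.0, 0.0],   # "red"
    [0.0, 1.0, 0.0, 0.0],   # "car"
    [0.0, 0.0, 1.0, 0.0],   # "sky"
])
# Toy values: each token carries a distinct one-hot "content" vector
V = torch.eye(4)

# Query from a car-like image position: mostly "car", a bit of "red"
q = torch.tensor([1.0, 3.0, 0.0, 0.0])

d_k = K.shape[-1]
weights = torch.softmax(K @ q / d_k ** 0.5, dim=-1)  # one weight per token
out = weights @ V                                     # weighted mix of the values

for tok, w in zip(tokens, weights.tolist()):
    print(f"{tok:>4}: {w:.3f}")
```

Because V is the identity here, `out` equals `weights`, which makes it easy to see that highly matched tokens contribute most of the injected content.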

Understanding Attention Maps

The attention weights form interpretable maps showing which image regions attend to which words:

Interpreting Attention Maps

For each text token, we can visualize its attention pattern as a heatmap over the image:

  • Object words ("cat," "car") typically show localized attention on the corresponding object region
  • Attribute words ("red," "fluffy") attend to regions where that attribute applies
  • Spatial words ("left," "background") show attention in the corresponding image areas
  • Style words ("impressionist," "photorealistic") often show diffuse, global attention
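Producing one token's heatmap amounts to selecting a column of the attention matrix and reshaping it back to the feature-map grid. A sketch, assuming weights shaped [B, n_heads, H*W, L] like the `attn_weights` tensor computed inside `CrossAttention.forward` later in this section (that method does not return it as written, so it would need to be stored or returned). Averaging over heads is one common choice, not the only one; `token_heatmap` is a hypothetical helper name:

```python
import torch


def token_heatmap(attn_weights: torch.Tensor, token_idx: int, height: int, width: int) -> torch.Tensor:
    """Return a [B, height, width] map of how strongly each position attends to one token.

    attn_weights: softmax-normalized weights of shape [B, n_heads, height*width, L].
    """
    per_token = attn_weights[..., token_idx]        # [B, n_heads, H*W]
    per_token = per_token.mean(dim=1)               # average over heads -> [B, H*W]
    heatmap = per_token.reshape(-1, height, width)  # back to the spatial grid
    # Rescale each map to [0, 1] for display
    flat = heatmap.flatten(1)
    lo = flat.min(dim=1).values[:, None, None]
    hi = flat.max(dim=1).values[:, None, None]
    return (heatmap - lo) / (hi - lo + 1e-8)


# Example with random scores: batch 2, 8 heads, 16x16 map, 77 tokens
scores = torch.randn(2, 8, 16 * 16, 77)
weights = scores.softmax(dim=-1)
hm = token_heatmap(weights, token_idx=5, height=16, width=16)
print(hm.shape)  # torch.Size([2, 16, 16])
```

The resulting map can then be upsampled to image resolution and overlaid on the generated image.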

Attention Map Applications

| Application | How it uses attention maps |
|---|---|
| Prompt debugging | Visualize which words affect which regions |
| Attention editing | Modify maps to change object placement |
| Prompt weighting | Strengthen or weaken attention to specific words |
| Layout control | Constrain which regions attend to specific words |
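A sketch of the prompt-weighting idea applied directly to attention weights: scale one token's column by a factor, then renormalize so each position's weights still sum to 1. This is one simple variant under our own assumptions; methods differ in the details (for example, whether the weights are renormalized, or whether the text embeddings are scaled instead). `reweight_token` is a hypothetical helper:

```python
import torch


def reweight_token(attn_weights: torch.Tensor, token_idx: int, factor: float) -> torch.Tensor:
    """Scale attention to one text token and renormalize over tokens.

    attn_weights: softmax output with text tokens on the last axis, e.g. [B, n_heads, H*W, L].
    factor > 1 strengthens the token; 0 <= factor < 1 weakens it.
    """
    scale = torch.ones(attn_weights.shape[-1], dtype=attn_weights.dtype, device=attn_weights.device)
    scale[token_idx] = factor
    reweighted = attn_weights * scale               # broadcasts over the last axis
    return reweighted / reweighted.sum(dim=-1, keepdim=True)


weights = torch.randn(1, 8, 64, 77).softmax(dim=-1)
boosted = reweight_token(weights, token_idx=3, factor=2.0)
print(torch.allclose(boosted.sum(dim=-1), torch.ones(1, 8, 64)))  # True
print(bool((boosted[..., 3] > weights[..., 3]).all()))              # True
```

With factor 2, a weight w becomes 2w / (1 + w), which is larger than w whenever w < 1, so every position shifts some attention toward the chosen token.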

PyTorch Implementation

Let's implement cross-attention from scratch:

The implementation is organized as follows:

  • Imports: standard PyTorch modules for building neural networks.
  • Class definition: CrossAttention is the module in which image features attend to text embeddings; it is the key mechanism for text-to-image alignment.
  • Dimension parameters: query_dim is the image feature dimension and context_dim is the text embedding dimension. They can differ (e.g., 320 for image features, 768 for CLIP text embeddings).
  • Query projection (to_q): projects image features to queries. Each spatial position becomes a query that "asks" which text tokens are relevant.
  • Key/value projections (to_k, to_v): project text embeddings to keys and values. Keys determine relevance (attention scores); values carry the information to inject.
  • Output projection (to_out): projects the attention output back to the image feature dimension, with optional dropout for regularization.
  • Scaling factor: multiplying scores by 1/sqrt(d_k) keeps them from growing with the head dimension, which would make the softmax overly peaked.
  • Forward signature: x holds image features with flattened spatial dimensions, and context holds text token embeddings. The output has the same shape as x.
  • Projections: Q comes from the image, K and V from the text. This is the key asymmetry of cross-attention.
  • Multi-head reshape: tensors are split into n_heads heads of size head_dim, so each head can learn its own text-image relationship (for example, one head might focus on objects and another on style).
  • Attention computation: Q multiplied by K^T gives scores of shape [H*W, L] per head, one score per (image position, text token) pair; softmax normalizes across the text tokens.
  • Value aggregation: the attention weights multiply the values, so highly attended text tokens contribute more to each image position's output.
  • Output reshape: the heads are merged back to [B, H*W, inner_dim], then projected to query_dim.
```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class CrossAttention(nn.Module):
    """Cross-attention: image features attend to text embeddings."""

    def __init__(
        self,
        query_dim: int,      # Dimension of image features
        context_dim: int,    # Dimension of text embeddings
        n_heads: int = 8,
        head_dim: int = 64,
        dropout: float = 0.0,
    ):
        super().__init__()
        self.n_heads = n_heads
        self.head_dim = head_dim
        inner_dim = n_heads * head_dim

        # Q projection from image features
        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)

        # K, V projections from text embeddings
        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)

        # Output projection
        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, query_dim),
            nn.Dropout(dropout)
        )

        self.scale = head_dim ** -0.5

    def forward(
        self,
        x: torch.Tensor,        # Image features: [B, H*W, query_dim]
        context: torch.Tensor,  # Text embeddings: [B, L, context_dim]
    ) -> torch.Tensor:
        """
        Args:
            x: Image features with shape [batch, spatial_tokens, query_dim]
            context: Text embeddings with shape [batch, seq_len, context_dim]
        Returns:
            Output with same shape as x
        """
        batch_size, seq_len_q, _ = x.shape
        seq_len_kv = context.shape[1]

        # Project to Q, K, V
        q = self.to_q(x)      # [B, H*W, inner_dim]
        k = self.to_k(context)  # [B, L, inner_dim]
        v = self.to_v(context)  # [B, L, inner_dim]

        # Reshape for multi-head attention
        # [B, seq, inner_dim] -> [B, n_heads, seq, head_dim]
        q = q.view(batch_size, seq_len_q, self.n_heads, self.head_dim)
        q = q.transpose(1, 2)  # [B, n_heads, H*W, head_dim]

        k = k.view(batch_size, seq_len_kv, self.n_heads, self.head_dim)
        k = k.transpose(1, 2)  # [B, n_heads, L, head_dim]

        v = v.view(batch_size, seq_len_kv, self.n_heads, self.head_dim)
        v = v.transpose(1, 2)  # [B, n_heads, L, head_dim]

        # Compute attention scores
        # [B, n_heads, H*W, head_dim] @ [B, n_heads, head_dim, L]
        # -> [B, n_heads, H*W, L]
        attn_scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale
        attn_weights = F.softmax(attn_scores, dim=-1)

        # Apply attention to values
        # [B, n_heads, H*W, L] @ [B, n_heads, L, head_dim]
        # -> [B, n_heads, H*W, head_dim]
        out = torch.matmul(attn_weights, v)

        # Reshape back
        out = out.transpose(1, 2)  # [B, H*W, n_heads, head_dim]
        out = out.reshape(batch_size, seq_len_q, -1)  # [B, H*W, inner_dim]

        return self.to_out(out)
```

Usage Example

```python
# Example dimensions (Stable Diffusion-like)
batch_size = 2
height, width = 64, 64  # Latent spatial size
query_dim = 320         # U-Net channel dimension
context_dim = 768       # CLIP text embedding dimension
seq_len = 77            # CLIP max sequence length

# Create module
cross_attn = CrossAttention(
    query_dim=query_dim,
    context_dim=context_dim,
    n_heads=8,
    head_dim=40,
)

# Image features (flattened spatial)
x = torch.randn(batch_size, height * width, query_dim)

# Text embeddings
context = torch.randn(batch_size, seq_len, context_dim)

# Apply cross-attention
out = cross_attn(x, context)
print(f"Input shape: {x.shape}")
print(f"Output shape: {out.shape}")
# Input shape: torch.Size([2, 4096, 320])
# Output shape: torch.Size([2, 4096, 320])
```
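The explicit matmul, softmax, matmul steps in `forward` compute the same result as PyTorch's fused `torch.nn.functional.scaled_dot_product_attention` (available since PyTorch 2.0), which applies the same 1/sqrt(head_dim) scaling by default and can dispatch to more memory-efficient kernels. A quick numerical check on per-head tensors with different query and key lengths, as in cross-attention (a sketch; the original module does not use the fused call):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

B, n_heads, head_dim = 2, 8, 40
n_query, n_text = 32 * 32, 77   # image positions, text tokens

q = torch.randn(B, n_heads, n_query, head_dim)
k = torch.randn(B, n_heads, n_text, head_dim)
v = torch.randn(B, n_heads, n_text, head_dim)

# Explicit version, mirroring CrossAttention.forward
scale = head_dim ** -0.5
weights = torch.softmax(q @ k.transpose(-2, -1) * scale, dim=-1)  # [B, n_heads, n_query, n_text]
manual = weights @ v                                              # [B, n_heads, n_query, head_dim]

# Fused version: no mask, no dropout, default scale 1/sqrt(head_dim)
fused = F.scaled_dot_product_attention(q, k, v)

print(manual.shape)  # torch.Size([2, 8, 1024, 40])
print(torch.allclose(manual, fused, atol=1e-4))
```

Swapping the fused call into `forward` would keep the module's interface unchanged, but the fused function does not return the attention weights, so the explicit version remains useful when you want to visualize attention maps.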

Multi-Head Cross-Attention

Like self-attention, cross-attention benefits from multiple heads:

Why Multiple Heads?

Each attention head can learn a different type of text-image relationship. The assignment below is illustrative; what each head actually learns emerges from training:

  • Head 1: Object-noun alignment ("dog" -> dog pixels)
  • Head 2: Attribute-region alignment ("red" -> colored regions)
  • Head 3: Spatial relationships ("above" -> upper regions)
  • Head 4: Style/texture patterns ("oil painting" -> all regions)

The outputs of all heads are concatenated and projected:

$$\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \ldots, \text{head}_h) W^O$$
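The view/transpose reshape in the CrossAttention code is an efficient way to evaluate this formula: slicing the columns of one large projection matrix gives each head its own projection matrices. A small check with random weights and arbitrary small sizes (illustrative assumptions) that the batched computation equals running each head separately and concatenating the results, before the final $W^O$ projection:

```python
import torch

torch.manual_seed(0)

N, L = 16, 5                 # image positions, text tokens
d_img, d_text = 32, 24
n_heads, head_dim = 4, 8
inner = n_heads * head_dim

x, c = torch.randn(N, d_img), torch.randn(L, d_text)
W_q = torch.randn(d_img, inner) / d_img ** 0.5
W_k = torch.randn(d_text, inner) / d_text ** 0.5
W_v = torch.randn(d_text, inner) / d_text ** 0.5


def attention(q, k, v):
    """Scaled dot-product attention over the last two axes."""
    return torch.softmax(q @ k.transpose(-2, -1) / q.shape[-1] ** 0.5, dim=-1) @ v


# (a) One head at a time: head i uses columns [i*head_dim, (i+1)*head_dim)
heads = []
for i in range(n_heads):
    cols = slice(i * head_dim, (i + 1) * head_dim)
    heads.append(attention(x @ W_q[:, cols], c @ W_k[:, cols], c @ W_v[:, cols]))
concat = torch.cat(heads, dim=-1)                          # [N, inner]

# (b) Batched: project once, then split into heads with view/transpose
q = (x @ W_q).view(N, n_heads, head_dim).transpose(0, 1)   # [n_heads, N, head_dim]
k = (c @ W_k).view(L, n_heads, head_dim).transpose(0, 1)   # [n_heads, L, head_dim]
v = (c @ W_v).view(L, n_heads, head_dim).transpose(0, 1)   # [n_heads, L, head_dim]
batched = attention(q, k, v).transpose(0, 1).reshape(N, inner)

print(torch.allclose(concat, batched, atol=1e-5))  # True
```

In the module, `to_out` then plays the role of $W^O$.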

Head Dimension Trade-off

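With a fixed total dimension, more heads means a smaller head dimension. Stable Diffusion v1 uses 8 heads at every level, so at the 320-channel level head_dim = 40 (8 × 40 = 320), while at the 640- and 1280-channel levels it is 80 and 160. More heads allow more diverse relationships, but each head has less capacity.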
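The trade-off is simple arithmetic: n_heads × head_dim stays equal to the inner dimension, and the projection matrices have the same size however that width is split, so changing the head count alone does not change the parameter count. A short sketch (hypothetical helper names) that enumerates the splits of a 320-wide inner dimension, shows head_dim per level when the head count is fixed at 8, and counts the projection parameters of the CrossAttention module defined earlier:

```python
def head_splits(inner_dim: int, candidates=(1, 2, 4, 5, 8, 10, 16)):
    """(n_heads, head_dim) pairs from `candidates` that divide inner_dim exactly."""
    return [(h, inner_dim // h) for h in candidates if inner_dim % h == 0]


def projection_params(query_dim: int, context_dim: int, inner_dim: int) -> int:
    """Parameters of to_q, to_k, to_v (no bias) plus to_out (with bias), as in CrossAttention."""
    qkv = query_dim * inner_dim + 2 * context_dim * inner_dim
    out = inner_dim * query_dim + query_dim
    return qkv + out


print(head_splits(320))
# [(1, 320), (2, 160), (4, 80), (5, 64), (8, 40), (10, 32), (16, 20)]

# Head count fixed at 8: head_dim grows with the channel width
for channels in (320, 640, 1280):
    print(channels, channels // 8)

# Depends only on inner_dim, not on how it is split into heads
print(projection_params(320, 768, 320))
```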

Integration with U-Net

Cross-attention is integrated into U-Net via transformer blocks placed at specific resolution levels:

TransformerBlock is the building block of a text-conditioned U-Net. It combines three sublayers, each in a pre-norm residual arrangement (normalize the input, apply the sublayer, add the result back):

  • Self-attention: image features attend to themselves, enabling spatial reasoning (e.g., understanding that a hat should sit above a head).
  • Cross-attention: image features attend to the text embeddings. This is where text conditioning enters the generation process: the residual adds text information into the image features.
  • Feed-forward network: an MLP with GELU activation that processes each position independently, expanding and then contracting the width (dim -> 4*dim -> dim) to further transform the text-conditioned features.
```python
import torch
import torch.nn as nn


class SelfAttention(CrossAttention):
    """Minimal self-attention: the CrossAttention module above with context = x."""

    def __init__(self, dim: int, n_heads: int = 8, head_dim: int = 64, dropout: float = 0.0):
        # Queries, keys, and values all come from image features of size dim
        super().__init__(dim, dim, n_heads, head_dim, dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return super().forward(x, x)


class TransformerBlock(nn.Module):
    """Transformer block with self-attention and cross-attention."""

    def __init__(
        self,
        dim: int,
        n_heads: int = 8,
        head_dim: int = 64,
        context_dim: int = 768,
        dropout: float = 0.0,
    ):
        super().__init__()

        # Self-attention on image features
        self.self_attn = SelfAttention(dim, n_heads, head_dim, dropout)
        self.norm1 = nn.LayerNorm(dim)

        # Cross-attention to text
        self.cross_attn = CrossAttention(dim, context_dim, n_heads, head_dim, dropout)
        self.norm2 = nn.LayerNorm(dim)

        # Feed-forward network
        self.ff = nn.Sequential(
            nn.Linear(dim, dim * 4),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(dim * 4, dim),
            nn.Dropout(dropout),
        )
        self.norm3 = nn.LayerNorm(dim)

    def forward(self, x: torch.Tensor, context: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: Image features [B, H*W, dim]
            context: Text embeddings [B, L, context_dim]
        """
        # Self-attention (image attends to itself)
        x = x + self.self_attn(self.norm1(x))

        # Cross-attention (image attends to text)
        x = x + self.cross_attn(self.norm2(x), context)

        # Feed-forward
        x = x + self.ff(self.norm3(x))

        return x
```
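For comparison, the same pre-norm block can be sketched with PyTorch's built-in `nn.MultiheadAttention`, whose `kdim` and `vdim` arguments let keys and values come from a sequence with a different feature size (the text). This is an alternative assumed for illustration, not the implementation above: the head size here is `dim // n_heads` rather than a separate `head_dim`, and the module also returns head-averaged attention weights that can be reshaped into heatmaps. `MHATransformerBlock` is a hypothetical name.

```python
import torch
import torch.nn as nn


class MHATransformerBlock(nn.Module):
    """Pre-norm block: self-attention, cross-attention, feed-forward (built-in attention)."""

    def __init__(self, dim: int, n_heads: int = 8, context_dim: int = 768, dropout: float = 0.0):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.self_attn = nn.MultiheadAttention(dim, n_heads, dropout=dropout, batch_first=True)
        self.norm2 = nn.LayerNorm(dim)
        # kdim/vdim let keys and values come from text embeddings of size context_dim
        self.cross_attn = nn.MultiheadAttention(
            dim, n_heads, dropout=dropout, kdim=context_dim, vdim=context_dim, batch_first=True
        )
        self.norm3 = nn.LayerNorm(dim)
        self.ff = nn.Sequential(
            nn.Linear(dim, dim * 4), nn.GELU(), nn.Dropout(dropout),
            nn.Linear(dim * 4, dim), nn.Dropout(dropout),
        )

    def forward(self, x: torch.Tensor, context: torch.Tensor):
        h = self.norm1(x)
        x = x + self.self_attn(h, h, h, need_weights=False)[0]
        # Cross-attention: queries from image features, keys/values from text
        out, weights = self.cross_attn(self.norm2(x), context, context)  # weights: [B, H*W, L]
        x = x + out
        x = x + self.ff(self.norm3(x))
        return x, weights


block = MHATransformerBlock(dim=320, n_heads=8, context_dim=768)
x = torch.randn(2, 16 * 16, 320)    # flattened 16x16 feature map
context = torch.randn(2, 77, 768)   # text embeddings
y, attn = block(x, context)
print(y.shape)     # torch.Size([2, 256, 320])
print(attn.shape)  # torch.Size([2, 256, 77])
```

By default the returned weights are averaged over heads; passing `average_attn_weights=False` to the attention call returns them per head instead.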

Typical U-Net Integration

```python
import torch.nn as nn
from einops import rearrange  # third-party package: pip install einops


class TextConditionedUNet(nn.Module):
    """Skeleton of a U-Net with cross-attention for text conditioning.

    Only the placement of the transformer blocks is shown; convolutions,
    time embedding, and skip connections are omitted, so forward is a sketch.
    """

    def __init__(self, context_dim: int = 768):
        super().__init__()
        # ... encoder and decoder convolutions ...

        # Add transformer blocks at lower resolutions
        # where cross-attention is computationally feasible
        self.transformer_blocks = nn.ModuleDict({
            "down_32": TransformerBlock(dim=320, context_dim=context_dim),
            "down_16": TransformerBlock(dim=640, context_dim=context_dim),
            "down_8":  TransformerBlock(dim=1280, context_dim=context_dim),
            "mid_8":   TransformerBlock(dim=1280, context_dim=context_dim),
            "up_8":    TransformerBlock(dim=1280, context_dim=context_dim),
            "up_16":   TransformerBlock(dim=640, context_dim=context_dim),
            "up_32":   TransformerBlock(dim=320, context_dim=context_dim),
        })

    def forward(self, x, t, context):
        """
        Args:
            x: Noisy image [B, C, H, W]
            t: Timestep [B]
            context: Text embeddings [B, L, context_dim]
        """
        # ... time embedding, downsampling (produces the feature map h) ...

        # At 32x32 resolution:
        h = self.down_conv_32(h)  # conv stage from the omitted encoder code
        _, _, height, width = h.shape
        h = rearrange(h, 'b c h w -> b (h w) c')  # Flatten spatial
        h = self.transformer_blocks["down_32"](h, context)
        h = rearrange(h, 'b (h w) c -> b c h w', h=height, w=width)

        # ... continue through U-Net ...
```
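The flatten, attend, unflatten pattern can also be written with plain tensor operations instead of einops. A sketch with a hypothetical helper `apply_tokenwise` that reads the spatial size from its input rather than hard-coding it; `h.flatten(2).transpose(1, 2)` produces the same ordering as `'b c h w -> b (h w) c'`, and a round trip through an identity function returns the input unchanged:

```python
import torch
from typing import Callable


def apply_tokenwise(
    fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
    h: torch.Tensor,
    context: torch.Tensor,
) -> torch.Tensor:
    """Run a token-sequence module on a [B, C, H, W] feature map.

    fn receives [B, H*W, C] features plus the text context and must return
    the same shape (as TransformerBlock does).
    """
    b, c, height, width = h.shape
    tokens = h.flatten(2).transpose(1, 2)   # [B, C, H*W] -> [B, H*W, C]
    tokens = fn(tokens, context)
    return tokens.transpose(1, 2).reshape(b, c, height, width)


h = torch.randn(2, 320, 32, 32)
context = torch.randn(2, 77, 768)

# Identity "block": the round trip must reproduce the input exactly
restored = apply_tokenwise(lambda t, ctx: t, h, context)
print(torch.equal(restored, h))  # True
```

With the classes from this section, `apply_tokenwise(TransformerBlock(dim=320), h, context)` would run a real block at this resolution.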

Where to Add Cross-Attention?

| Resolution | Cross-attention? | Reason |
|---|---|---|
| 64x64 | Optional (expensive) | 4096 positions, high memory |
| 32x32 | Yes | Good balance of detail and efficiency |
| 16x16 | Yes | Important for semantic content |
| 8x8 | Yes | Captures high-level composition |

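Lower resolutions have fewer spatial positions, so attention is cheaper there: the cross-attention map grows linearly with the number of positions N, while the self-attention map in the same transformer block grows as N². Text semantics are also thought to matter most at these levels, where high-level structure is determined. Exact placement varies by model; Stable Diffusion v1, for example, applies cross-attention at the 64x64, 32x32, and 16x16 latent resolutions and in the 8x8 middle block.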
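Some rough arithmetic behind the table, counting attention-map entries per head for L = 77 text tokens. This is a sketch: it ignores activations and projections, and fused kernels can avoid materializing the full map. `attention_map_entries` is a hypothetical helper:

```python
def attention_map_entries(side: int, n_tokens: int = 77) -> tuple[int, int]:
    """Entries in one head's cross-attention [N, L] and self-attention [N, N] maps."""
    n = side * side
    return n * n_tokens, n * n


for side in (64, 32, 16, 8):
    cross, self_attn = attention_map_entries(side)
    mib = self_attn * 2 / 2**20   # fp16 = 2 bytes per entry
    print(f"{side}x{side}: cross {cross:,}  self {self_attn:,}  (self-attention map: {mib:.2f} MiB fp16 per head)")
```

At 64x64 the cross-attention map has 315,392 entries per head, while the self-attention map has 16,777,216 (32 MiB in fp16 per head and per image), which is why the highest resolution is the expensive one.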


Key Takeaways

  1. Cross-attention enables spatial text-image alignment: each image position can attend to the words relevant to it
  2. Q from image, K/V from text: image positions "ask" which text tokens are relevant
  3. Attention maps are interpretable: they show which words affect which image regions
  4. Multiple heads can learn diverse relationships (objects, attributes, spatial layout, style)
  5. Cross-attention sits inside transformer blocks alongside self-attention and feed-forward layers
  6. It is typically applied at lower resolutions (e.g., 8x8, 16x16, 32x32) for efficiency; exact placement varies by model
Looking Ahead: Cross-attention works only as well as the text embeddings it attends to. In the next section, we'll explore CLIP, the contrastive model that learns to align text and image representations, which makes it well suited to text-to-image conditioning.