Introduction
Cross-attention is the bridge between encoder and decoder. It allows the decoder to "look at" the encoded source sentence when generating each target token. Without cross-attention, the decoder would have no way to know what it's supposed to translate.
Self-Attention vs Cross-Attention
Self-Attention (Same Sequence)
Q, K, V all come from the SAME sequence:

Input: "The dog runs"
Q = Input @ W_q
K = Input @ W_k   ← Same source
V = Input @ W_v

Each token attends to all tokens in the same sequence.
Cross-Attention (Different Sequences)
Q from decoder, K/V from encoder:

Decoder: "The dog"
Encoder Memory: "Der Hund läuft"

Q = Decoder @ W_q   ← From target
K = Memory @ W_k    ← From source
V = Memory @ W_v    ← From source

Decoder tokens attend to encoder tokens!
Visual Comparison
Self-Attention:              Cross-Attention:

 The dog runs                Decoder: The dog
  │   │   │                        │    │
 [Q] [Q] [Q]                      [Q]  [Q]
  │   │   │                        │    │
 ┌──────────┐                      ▼    ▼
 │ Attention│                 ┌──────────┐
 └──────────┘                 │ Attention│
  │   │   │                   └──────────┘
 [K] [K] [K]                   │    │    │
 [V] [V] [V]                  [K]  [K]  [K]
  │   │   │                   [V]  [V]  [V]
 The dog runs                  │    │    │
                              Der Hund läuft
Shape Considerations
Different Sequence Lengths
The key insight: source and target can have different lengths!
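This is easy to check directly. The sketch below uses a single head and random tensors (the batch size and dimensions are arbitrary examples, not values from a real model):

```python
import torch

B, src_len, tgt_len, d_model = 1, 4, 5, 512

decoder_hidden = torch.randn(B, tgt_len, d_model)   # target side → queries
encoder_memory = torch.randn(B, src_len, d_model)   # source side → keys/values

# Every target position scores every source position:
# [1, 5, 512] @ [1, 512, 4] → [1, 5, 4]
scores = decoder_hidden @ encoder_memory.transpose(-2, -1)
print(scores.shape)  # torch.Size([1, 5, 4])
```

The score matrix is rectangular whenever the two lengths differ, which is the normal case in translation.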
Source (German): "Der schnelle braune Fuchs" (4 tokens)
Target (English): "The quick brown fox jumps" (5 tokens)

Self-attention shapes:
  Q, K, V: [batch, tgt_len, d_model] = [B, 5, 512]
  Scores:  [batch, heads, 5, 5]

Cross-attention shapes:
  Q: [batch, tgt_len, d_model] = [B, 5, 512]  ← from decoder
  K: [batch, src_len, d_model] = [B, 4, 512]  ← from encoder
  V: [batch, src_len, d_model] = [B, 4, 512]  ← from encoder
  Scores: [batch, heads, 5, 4]  ← tgt_len × src_len!
Shape Flow Through Cross-Attention
Decoder hidden: [batch, tgt_len, d_model]
        │
        ▼
  Q = decoder @ W_q: [batch, tgt_len, d_model]
  Reshape: [batch, heads, tgt_len, d_k]

Encoder memory: [batch, src_len, d_model]
        │
        ▼
  K = memory @ W_k: [batch, src_len, d_model]
  V = memory @ W_v: [batch, src_len, d_model]
  Reshape K: [batch, heads, src_len, d_k]
  Reshape V: [batch, heads, src_len, d_k]

Attention scores:
  Q @ K^T: [batch, heads, tgt_len, d_k]
         @ [batch, heads, d_k, src_len]
         = [batch, heads, tgt_len, src_len]

Output:
  softmax(scores) @ V:
      [batch, heads, tgt_len, src_len]
    @ [batch, heads, src_len, d_k]
    = [batch, heads, tgt_len, d_k]

Final: [batch, tgt_len, d_model]
Implementation
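Before building the full module, the shape flow above can be verified step by step. This is a throwaway sketch: the sizes are arbitrary examples, and the three `Linear` layers stand in for W_q/W_k/W_v.

```python
import torch
import torch.nn as nn

batch, heads, d_model = 2, 8, 512
d_k = d_model // heads
src_len, tgt_len = 4, 5

decoder_hidden = torch.randn(batch, tgt_len, d_model)
encoder_memory = torch.randn(batch, src_len, d_model)

W_q = nn.Linear(d_model, d_model)
W_k = nn.Linear(d_model, d_model)
W_v = nn.Linear(d_model, d_model)

# Project, then split into heads: [batch, heads, seq_len, d_k]
Q = W_q(decoder_hidden).view(batch, tgt_len, heads, d_k).transpose(1, 2)
K = W_k(encoder_memory).view(batch, src_len, heads, d_k).transpose(1, 2)
V = W_v(encoder_memory).view(batch, src_len, heads, d_k).transpose(1, 2)

# Rectangular score matrix: tgt_len × src_len
scores = Q @ K.transpose(-2, -1) / d_k ** 0.5
assert scores.shape == (batch, heads, tgt_len, src_len)

# Weighted sum over source positions
context = torch.softmax(scores, dim=-1) @ V
assert context.shape == (batch, heads, tgt_len, d_k)

# Merge heads back to d_model
output = context.transpose(1, 2).reshape(batch, tgt_len, d_model)
assert output.shape == (batch, tgt_len, d_model)
```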
MultiHeadAttention with Cross-Attention Support
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from typing import Optional


class MultiHeadAttention(nn.Module):
    """
    Multi-Head Attention supporting both self-attention and cross-attention.

    For self-attention: query = key = value
    For cross-attention: query from decoder, key/value from encoder

    Args:
        d_model: Model dimension
        num_heads: Number of attention heads
        dropout: Dropout probability

    Example:
        >>> mha = MultiHeadAttention(512, 8)
        >>> # Self-attention
        >>> out = mha(x, x, x)
        >>> # Cross-attention
        >>> out = mha(decoder_state, encoder_memory, encoder_memory)
    """

    def __init__(
        self,
        d_model: int,
        num_heads: int,
        dropout: float = 0.1
    ):
        super().__init__()

        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"

        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads

        # Projection layers
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

        self.dropout = nn.Dropout(dropout)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """
        Compute multi-head attention.

        Args:
            query: [batch, query_len, d_model]
            key: [batch, key_len, d_model]
            value: [batch, key_len, d_model]
            mask: [batch, 1, query_len, key_len] or [batch, 1, 1, key_len]

        Returns:
            output: [batch, query_len, d_model]
        """
        batch_size = query.size(0)
        query_len = query.size(1)
        key_len = key.size(1)

        # Linear projections
        Q = self.W_q(query)   # [batch, query_len, d_model]
        K = self.W_k(key)     # [batch, key_len, d_model]
        V = self.W_v(value)   # [batch, key_len, d_model]

        # Reshape for multi-head: [batch, heads, seq_len, d_k]
        Q = Q.view(batch_size, query_len, self.num_heads, self.d_k).transpose(1, 2)
        K = K.view(batch_size, key_len, self.num_heads, self.d_k).transpose(1, 2)
        V = V.view(batch_size, key_len, self.num_heads, self.d_k).transpose(1, 2)

        # Attention scores: [batch, heads, query_len, key_len]
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)

        # Apply mask
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))

        # Attention weights
        attn_weights = F.softmax(scores, dim=-1)
        # A fully masked row is softmax over all -inf, which yields NaN;
        # replace those rows with zeros instead of propagating NaN
        attn_weights = torch.nan_to_num(attn_weights, nan=0.0)
        attn_weights = self.dropout(attn_weights)

        # Apply to values: [batch, heads, query_len, d_k]
        context = torch.matmul(attn_weights, V)

        # Combine heads: [batch, query_len, d_model]
        context = context.transpose(1, 2).contiguous()
        context = context.view(batch_size, query_len, self.d_model)

        # Final projection
        output = self.W_o(context)

        return output


# Test both modes
def test_attention_modes():
    batch, src_len, tgt_len, d_model, heads = 2, 10, 8, 512, 8

    mha = MultiHeadAttention(d_model, heads)

    # Self-attention (same sequence)
    x = torch.randn(batch, tgt_len, d_model)
    self_attn_out = mha(x, x, x)
    print(f"Self-attention: input {x.shape} → output {self_attn_out.shape}")

    # Cross-attention (different sequences)
    decoder_state = torch.randn(batch, tgt_len, d_model)
    encoder_memory = torch.randn(batch, src_len, d_model)
    cross_attn_out = mha(decoder_state, encoder_memory, encoder_memory)
    print(f"Cross-attention: Q {decoder_state.shape}, K/V {encoder_memory.shape}")
    print(f"  → output {cross_attn_out.shape}")

    # Verify shapes
    assert self_attn_out.shape == (batch, tgt_len, d_model)
    assert cross_attn_out.shape == (batch, tgt_len, d_model)

    print("\n✓ Both attention modes work correctly!")


test_attention_modes()
Cross-Attention Masking
Only Padding Mask Needed
def create_cross_attention_mask(
    encoder_input_ids: torch.Tensor,
    decoder_seq_len: int,
    pad_id: int = 0
) -> torch.Tensor:
    """
    Create mask for cross-attention.

    No causal masking needed - decoder can see entire source.
    Only mask padding tokens in source.

    Args:
        encoder_input_ids: [batch, src_len]
        decoder_seq_len: Length of decoder sequence
        pad_id: Padding token ID

    Returns:
        mask: [batch, 1, decoder_seq_len, src_len]
    """
    batch_size, src_len = encoder_input_ids.shape

    # Padding mask: [batch, src_len]
    src_padding_mask = (encoder_input_ids != pad_id)

    # Reshape: [batch, 1, 1, src_len]
    mask = src_padding_mask.unsqueeze(1).unsqueeze(2)

    # Expand for all decoder positions: [batch, 1, decoder_seq_len, src_len]
    mask = mask.expand(-1, -1, decoder_seq_len, -1)

    return mask.float()


# Test
def test_cross_mask():
    encoder_ids = torch.tensor([
        [10, 20, 30, 40, 0, 0],  # 4 real tokens
        [10, 20, 30, 0, 0, 0],   # 3 real tokens
    ])
    decoder_len = 5

    mask = create_cross_attention_mask(encoder_ids, decoder_len, pad_id=0)

    print("Cross-Attention Mask")
    print(f"Encoder shape: {encoder_ids.shape}")
    print(f"Decoder length: {decoder_len}")
    print(f"Mask shape: {mask.shape}")

    print("\nMask for sentence 1 (4 source tokens):")
    print(mask[0, 0])

    print("\nMask for sentence 2 (3 source tokens):")
    print(mask[1, 0])


test_cross_mask()
Output:
Cross-Attention Mask
Encoder shape: torch.Size([2, 6])
Decoder length: 5
Mask shape: torch.Size([2, 1, 5, 6])

Mask for sentence 1 (4 source tokens):
tensor([[1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 0., 0.]])

Mask for sentence 2 (3 source tokens):
tensor([[1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0.]])
Cross-Attention in Practice
How Translation Uses Cross-Attention
def demonstrate_cross_attention():
    """
    Show how cross-attention enables translation.
    """
    print("Cross-Attention for Translation")
    print("=" * 50)

    # Simulated scenario
    source = "Der Hund läuft"   # German
    target = "The dog runs"    # English

    print(f"\nSource: '{source}'")
    print(f"Target: '{target}'")

    print("\nDuring generation of 'dog':")
    print("  Decoder query: representation of 'The'")
    print("  Encoder keys: representations of [Der, Hund, läuft]")
    print("  Attention weights (idealized):")
    print("    'Der':   0.1  (article, not directly relevant)")
    print("    'Hund':  0.8  (dog! high attention)")
    print("    'läuft': 0.1  (runs, will be next)")
    print("  Output: blend of encoder values, weighted by attention")
    print("  Result: decoder 'sees' mainly 'Hund' → predicts 'dog'")

    print("\nDuring generation of 'runs':")
    print("  Decoder query: representation of 'dog'")
    print("  Attention weights (idealized):")
    print("    'Der':   0.05")
    print("    'Hund':  0.15  (already translated)")
    print("    'läuft': 0.80  (runs! high attention)")
    print("  Result: decoder 'sees' mainly 'läuft' → predicts 'runs'")


demonstrate_cross_attention()
Encoder Output Caching
Why Cache Encoder Output?
During autoregressive generation:
- Encoder output is constant (source doesn't change)
- Decoder runs multiple times (once per output token)
- No need to recompute K, V for cross-attention
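Stripped to its essentials, the pattern looks like this. The sketch below is single-head, uses made-up sizes, and replaces the real decoder step with a random placeholder; it only illustrates that K and V are projected once, outside the decode loop:

```python
import torch
import torch.nn as nn

d_model = 512
W_q = nn.Linear(d_model, d_model)
W_k = nn.Linear(d_model, d_model)
W_v = nn.Linear(d_model, d_model)

encoder_memory = torch.randn(1, 4, d_model)  # fixed for the whole generation

# Encoder projections computed ONCE, before the decode loop
K = W_k(encoder_memory)
V = W_v(encoder_memory)

generated = torch.randn(1, 1, d_model)  # stand-in for the growing decoder state
for step in range(3):
    # Per step, only the decoder query needs a fresh projection
    Q = W_q(generated)
    scores = Q @ K.transpose(-2, -1) / d_model ** 0.5
    context = torch.softmax(scores, dim=-1) @ V
    # ... a real decoder would predict the next token from `context` here ...
    generated = torch.cat([generated, torch.randn(1, 1, d_model)], dim=1)
```

With src_len fixed, this saves two [src_len, d_model] matrix multiplies per decode step; the class below wraps the same idea in a module.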
class CachedCrossAttention(nn.Module):
    """
    Cross-attention with encoder output caching for efficient inference.

    During generation, encoder K/V are computed once and reused.
    """

    def __init__(self, d_model: int, num_heads: int, dropout: float = 0.1):
        super().__init__()
        self.attention = MultiHeadAttention(d_model, num_heads, dropout)

        # Cache for encoder projections
        self.cached_k = None
        self.cached_v = None

    def forward(
        self,
        decoder_state: torch.Tensor,
        encoder_output: torch.Tensor,
        cross_mask: Optional[torch.Tensor] = None,
        use_cache: bool = False
    ) -> torch.Tensor:
        """
        Args:
            decoder_state: [batch, tgt_len, d_model]
            encoder_output: [batch, src_len, d_model]
            cross_mask: [batch, 1, tgt_len, src_len]
            use_cache: Whether to cache encoder projections

        Returns:
            output: [batch, tgt_len, d_model]
        """
        if use_cache and self.cached_k is not None:
            # Use cached encoder projections
            # Only need to project decoder query
            Q = self.attention.W_q(decoder_state)
            K = self.cached_k
            V = self.cached_v
        else:
            # Project both decoder and encoder
            Q = self.attention.W_q(decoder_state)
            K = self.attention.W_k(encoder_output)
            V = self.attention.W_v(encoder_output)

            if use_cache:
                self.cached_k = K
                self.cached_v = V

        # Continue with attention computation
        # (simplified - actual implementation would use full attention)
        batch_size = Q.size(0)
        Q = Q.view(batch_size, -1, self.attention.num_heads, self.attention.d_k).transpose(1, 2)
        K = K.view(batch_size, -1, self.attention.num_heads, self.attention.d_k).transpose(1, 2)
        V = V.view(batch_size, -1, self.attention.num_heads, self.attention.d_k).transpose(1, 2)

        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.attention.d_k)
        if cross_mask is not None:
            scores = scores.masked_fill(cross_mask == 0, float('-inf'))

        attn = F.softmax(scores, dim=-1)
        context = torch.matmul(attn, V)
        context = context.transpose(1, 2).contiguous().view(batch_size, -1, self.attention.d_model)

        return self.attention.W_o(context)

    def clear_cache(self):
        """Clear cached encoder projections."""
        self.cached_k = None
        self.cached_v = None
Summary
Cross-Attention Key Points
| Aspect | Self-Attention | Cross-Attention |
|---|---|---|
| Q source | Same sequence | Decoder |
| K, V source | Same sequence | Encoder |
| Causal mask | Optional (decoder only) | No |
| Padding mask | Yes | Yes (source padding) |
| Caching | Incremental | Encoder K/V once |
Implementation Checklist
- [ ] MultiHeadAttention supports different Q vs K/V sources
- [ ] Handle different sequence lengths (tgt_len ≠ src_len)
- [ ] Apply source padding mask (no causal mask)
- [ ] Cache encoder K/V projections for inference
Exercises
Implementation
1. Add attention weight visualization to cross-attention.
2. Implement relative position bias for cross-attention.
3. Create a "gated" cross-attention that can learn to ignore the encoder.
Analysis
4. Visualize cross-attention patterns during translation.
5. What happens if the cross-attention mask is all zeros for some query position?
Next Section Preview
In the next section, we'll combine masked self-attention, cross-attention, and FFN into the complete TransformerDecoderLayer.