Chapter 8
Section 41 of 75

Encoder-Decoder Cross-Attention

Transformer Decoder

Introduction

Cross-attention is the bridge between encoder and decoder. It allows the decoder to "look at" the encoded source sentence when generating each target token. Without cross-attention, the decoder would have no way to know what it's supposed to translate.


Self-Attention vs Cross-Attention

Self-Attention (Same Sequence)

πŸ“text
1Q, K, V all come from the SAME sequence:
2
3Input: "The dog runs"
4Q = Input @ W_q
5K = Input @ W_k  ← Same source
6V = Input @ W_v
7
8Each token attends to all tokens in the same sequence.

Cross-Attention (Different Sequences)

πŸ“text
1Q from decoder, K/V from encoder:
2
3Decoder: "The dog"
4Encoder Memory: "Der Hund lΓ€uft"
5
6Q = Decoder @ W_q       ← From target
7K = Memory @ W_k        ← From source
8V = Memory @ W_v        ← From source
9
10Decoder tokens attend to encoder tokens!
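A minimal single-head sketch of the recipe above, assuming PyTorch and random (untrained) projection matrices; the names `decoder` and `memory` and the sizes are illustrative only. Queries are built from the two target tokens, keys and values from the three source tokens, so the score matrix is 2 × 3.

```python
import math
import torch

torch.manual_seed(0)
d_model = 8

# Two target tokens ("The dog") and three source tokens ("Der Hund läuft")
decoder = torch.randn(2, d_model)   # [tgt_len, d_model]
memory = torch.randn(3, d_model)    # [src_len, d_model]

# Projection matrices (random here; learned in a real model)
W_q = torch.randn(d_model, d_model)
W_k = torch.randn(d_model, d_model)
W_v = torch.randn(d_model, d_model)

Q = decoder @ W_q   # queries come from the target side
K = memory @ W_k    # keys come from the source side
V = memory @ W_v    # values come from the source side

scores = Q @ K.T / math.sqrt(d_model)    # [tgt_len, src_len] = [2, 3]
weights = torch.softmax(scores, dim=-1)  # each row sums to 1
context = weights @ V                    # [tgt_len, d_model] = [2, 8]

print(scores.shape, context.shape)  # torch.Size([2, 3]) torch.Size([2, 8])
```

Each row of `context` is a weighted blend of the three source value vectors, one row per target token.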

Visual Comparison

πŸ“text
1Self-Attention:                 Cross-Attention:
2
3  The dog runs                   Decoder: The dog
4   ↓   ↓   ↓                          ↓   ↓
5  [Q] [Q] [Q]                        [Q] [Q]
6   ↓   ↓   ↓                          β”‚   β”‚
7  β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”                         β”‚   β”‚
8  β”‚ Attentionβ”‚                        ↓   ↓
9  β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜                    β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”
10   ↑   ↑   ↑                     β”‚ Attentionβ”‚
11  [K] [K] [K]                    β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜
12  [V] [V] [V]                     ↑   ↑   ↑
13   ↑   ↑   ↑                     [K] [K] [K]
14  The dog runs                   [V] [V] [V]
15                                  ↑   ↑   ↑
16                                 Der Hund lΓ€uft
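In formula form, the two diagrams differ only in which matrix supplies the keys and values. Writing X for the single input sequence, Y for the decoder hidden states and M for the encoder memory (symbols introduced here for illustration), the single-head versions are:

```latex
\mathrm{SelfAttn}(X) = \mathrm{softmax}\!\left(\frac{(X W_q)(X W_k)^\top}{\sqrt{d_k}}\right) X W_v

\mathrm{CrossAttn}(Y, M) = \mathrm{softmax}\!\left(\frac{(Y W_q)(M W_k)^\top}{\sqrt{d_k}}\right) M W_v
```

Because Y has tgt_len rows and M has src_len rows, the matrix inside the cross-attention softmax is tgt_len × src_len, while the self-attention one is square.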

Shape Considerations

Different Sequence Lengths

The key insight: source and target can have different lengths!

πŸ“text
1Source (German): "Der schnelle braune Fuchs" (4 tokens)
2Target (English): "The quick brown fox jumps" (5 tokens)
3
4Self-attention shapes:
5  Q, K, V: [batch, tgt_len, d_model] = [B, 5, 512]
6  Scores: [batch, heads, 5, 5]
7
8Cross-attention shapes:
9  Q: [batch, tgt_len, d_model] = [B, 5, 512]  ← from decoder
10  K: [batch, src_len, d_model] = [B, 4, 512]  ← from encoder
11  V: [batch, src_len, d_model] = [B, 4, 512]  ← from encoder
12  Scores: [batch, heads, 5, 4]  ← tgt_len Γ— src_len!
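The score shapes can be checked directly with random tensors standing in for projected, head-split queries and keys; only the key length changes between the two cases. B = 2 and d_k = 64 (so 8 × 64 = 512) are assumed values.

```python
import torch

B, heads, d_k = 2, 8, 64   # d_model = heads * d_k = 512
src_len, tgt_len = 4, 5    # "Der schnelle braune Fuchs" / "The quick brown fox jumps"

q = torch.randn(B, heads, tgt_len, d_k)        # decoder queries (after head split)
k_self = torch.randn(B, heads, tgt_len, d_k)   # decoder keys for self-attention
k_cross = torch.randn(B, heads, src_len, d_k)  # encoder keys for cross-attention

self_scores = q @ k_self.transpose(-2, -1)    # [B, heads, 5, 5]
cross_scores = q @ k_cross.transpose(-2, -1)  # [B, heads, 5, 4]

print(self_scores.shape)   # torch.Size([2, 8, 5, 5])
print(cross_scores.shape)  # torch.Size([2, 8, 5, 4])
```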

Shape Flow Through Cross-Attention

πŸ“text
1Decoder hidden:    [batch, tgt_len, d_model]
2                      β”‚
3                      β–Ό
4        Q = decoder @ W_q: [batch, tgt_len, d_model]
5        Reshape: [batch, heads, tgt_len, d_k]
6
7Encoder memory:    [batch, src_len, d_model]
8                      β”‚
9        K = memory @ W_k: [batch, src_len, d_model]
10        V = memory @ W_v: [batch, src_len, d_model]
11        Reshape K: [batch, heads, src_len, d_k]
12        Reshape V: [batch, heads, src_len, d_k]
13
14Attention scores:
15        Q @ K^T: [batch, heads, tgt_len, d_k]
16               @ [batch, heads, d_k, src_len]
17               = [batch, heads, tgt_len, src_len]
18
19Output:
20        softmax(scores) @ V:
21        [batch, heads, tgt_len, src_len]
22      @ [batch, heads, src_len, d_k]
23      = [batch, heads, tgt_len, d_k]
24
25Final:     [batch, tgt_len, d_model]
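The same flow can be traced in code with the sizes from the German-English example; batch = 2 and d_model = 512 with 8 heads are assumed, and `split_heads` is a helper introduced here for illustration.

```python
import math
import torch
import torch.nn as nn

batch, src_len, tgt_len, d_model, heads = 2, 4, 5, 512, 8
d_k = d_model // heads

W_q, W_k, W_v, W_o = (nn.Linear(d_model, d_model) for _ in range(4))
decoder = torch.randn(batch, tgt_len, d_model)
memory = torch.randn(batch, src_len, d_model)


def split_heads(x: torch.Tensor) -> torch.Tensor:
    # [batch, seq, d_model] -> [batch, heads, seq, d_k]
    b, s, _ = x.shape
    return x.view(b, s, heads, d_k).transpose(1, 2)


Q = split_heads(W_q(decoder))  # [2, 8, 5, 64]
K = split_heads(W_k(memory))   # [2, 8, 4, 64]
V = split_heads(W_v(memory))   # [2, 8, 4, 64]

scores = Q @ K.transpose(-2, -1) / math.sqrt(d_k)  # [2, 8, 5, 4]
context = torch.softmax(scores, dim=-1) @ V         # [2, 8, 5, 64]

# Merge heads: [batch, tgt_len, heads * d_k] = [2, 5, 512], then apply W_o
merged = context.transpose(1, 2).reshape(batch, tgt_len, d_model)
out = W_o(merged)
print(out.shape)  # torch.Size([2, 5, 512])
```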

Implementation

MultiHeadAttention with Cross-Attention Support

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from typing import Optional


class MultiHeadAttention(nn.Module):
    """
    Multi-Head Attention supporting both self-attention and cross-attention.

    For self-attention: query = key = value
    For cross-attention: query from decoder, key/value from encoder

    Args:
        d_model: Model dimension
        num_heads: Number of attention heads
        dropout: Dropout probability

    Example:
        >>> mha = MultiHeadAttention(512, 8)
        >>> # Self-attention
        >>> out = mha(x, x, x)
        >>> # Cross-attention
        >>> out = mha(decoder_state, encoder_memory, encoder_memory)
    """

    def __init__(
        self,
        d_model: int,
        num_heads: int,
        dropout: float = 0.1
    ):
        super().__init__()

        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"

        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads

        # Projection layers
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

        self.dropout = nn.Dropout(dropout)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """
        Compute multi-head attention.

        Args:
            query: [batch, query_len, d_model]
            key: [batch, key_len, d_model]
            value: [batch, key_len, d_model]
            mask: [batch, 1, query_len, key_len] or [batch, 1, 1, key_len];
                  1 = attend, 0 = blocked

        Returns:
            output: [batch, query_len, d_model]
        """
        batch_size = query.size(0)
        query_len = query.size(1)  # tgt_len in cross-attention
        key_len = key.size(1)      # src_len in cross-attention (may differ)

        # Linear projections
        Q = self.W_q(query)  # [batch, query_len, d_model]
        K = self.W_k(key)    # [batch, key_len, d_model]
        V = self.W_v(value)  # [batch, key_len, d_model]

        # Reshape for multi-head: [batch, heads, seq_len, d_k]
        Q = Q.view(batch_size, query_len, self.num_heads, self.d_k).transpose(1, 2)
        K = K.view(batch_size, key_len, self.num_heads, self.d_k).transpose(1, 2)
        V = V.view(batch_size, key_len, self.num_heads, self.d_k).transpose(1, 2)

        # Attention scores: [batch, heads, query_len, key_len]
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)

        # Apply mask: blocked positions get -inf before the softmax
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))

        # Attention weights
        attn_weights = F.softmax(scores, dim=-1)
        # A fully masked row gives softmax(all -inf) = NaN; replace with zeros
        attn_weights = torch.nan_to_num(attn_weights, nan=0.0)
        attn_weights = self.dropout(attn_weights)

        # Apply to values: [batch, heads, query_len, d_k]
        context = torch.matmul(attn_weights, V)

        # Combine heads: [batch, query_len, d_model]
        context = context.transpose(1, 2).contiguous()
        context = context.view(batch_size, query_len, self.d_model)

        # Final projection
        output = self.W_o(context)

        return output


# Test both modes
def test_attention_modes():
    batch, src_len, tgt_len, d_model, heads = 2, 10, 8, 512, 8

    mha = MultiHeadAttention(d_model, heads)

    # Self-attention (same sequence)
    x = torch.randn(batch, tgt_len, d_model)
    self_attn_out = mha(x, x, x)
    print(f"Self-attention: input {x.shape} → output {self_attn_out.shape}")

    # Cross-attention (different sequences)
    decoder_state = torch.randn(batch, tgt_len, d_model)
    encoder_memory = torch.randn(batch, src_len, d_model)
    cross_attn_out = mha(decoder_state, encoder_memory, encoder_memory)
    print(f"Cross-attention: Q {decoder_state.shape}, K/V {encoder_memory.shape}")
    print(f"                 → output {cross_attn_out.shape}")

    # Verify shapes: output length always follows the query length
    assert self_attn_out.shape == (batch, tgt_len, d_model)
    assert cross_attn_out.shape == (batch, tgt_len, d_model)

    print("\n✓ Both attention modes work correctly!")


test_attention_modes()
```
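As a sanity check, the manual score, mask, softmax and weighted-sum recipe used in `MultiHeadAttention.forward` can be compared with PyTorch's fused `torch.nn.functional.scaled_dot_product_attention` (available from PyTorch 2.0). This sketch works on tensors that are assumed to be already projected and split into heads, and it skips dropout. A boolean `attn_mask` there uses True to mean "attend", the same convention as the 1/0 masks in this section.

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch, heads, tgt_len, src_len, d_k = 2, 4, 3, 6, 16

Q = torch.randn(batch, heads, tgt_len, d_k)  # from the decoder
K = torch.randn(batch, heads, src_len, d_k)  # from the encoder
V = torch.randn(batch, heads, src_len, d_k)  # from the encoder

# Boolean mask, True = may attend; the last two source positions act as padding
keep = torch.ones(batch, 1, 1, src_len, dtype=torch.bool)
keep[..., -2:] = False

# Manual computation, as in MultiHeadAttention.forward (without dropout)
scores = Q @ K.transpose(-2, -1) / math.sqrt(d_k)
scores = scores.masked_fill(~keep, float('-inf'))
manual = torch.softmax(scores, dim=-1) @ V

# Fused built-in; the default scale is 1 / sqrt(d_k)
fused = F.scaled_dot_product_attention(Q, K, V, attn_mask=keep)

print(manual.shape, torch.allclose(manual, fused, atol=1e-5))
```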

Cross-Attention Masking

Only Padding Mask Needed

```python
import torch


def create_cross_attention_mask(
    encoder_input_ids: torch.Tensor,
    decoder_seq_len: int,
    pad_id: int = 0
) -> torch.Tensor:
    """
    Create mask for cross-attention.

    No causal masking needed: the decoder can see the entire source.
    Only padding tokens in the source are masked.

    Args:
        encoder_input_ids: [batch, src_len]
        decoder_seq_len: Length of decoder sequence
        pad_id: Padding token ID

    Returns:
        mask: [batch, 1, decoder_seq_len, src_len] (1 = attend, 0 = padding)
    """
    batch_size, src_len = encoder_input_ids.shape

    # Padding mask: [batch, src_len], True for real tokens
    src_padding_mask = (encoder_input_ids != pad_id)

    # Reshape: [batch, 1, 1, src_len]
    mask = src_padding_mask.unsqueeze(1).unsqueeze(2)

    # Expand for all decoder positions: [batch, 1, decoder_seq_len, src_len]
    mask = mask.expand(-1, -1, decoder_seq_len, -1)

    return mask.float()


# Test
def test_cross_mask():
    encoder_ids = torch.tensor([
        [10, 20, 30, 40, 0, 0],  # 4 real tokens
        [10, 20, 30, 0, 0, 0],   # 3 real tokens
    ])
    decoder_len = 5

    mask = create_cross_attention_mask(encoder_ids, decoder_len, pad_id=0)

    print("Cross-Attention Mask")
    print(f"Encoder shape: {encoder_ids.shape}")
    print(f"Decoder length: {decoder_len}")
    print(f"Mask shape: {mask.shape}")

    print("\nMask for sentence 1 (4 source tokens):")
    print(mask[0, 0])

    print("\nMask for sentence 2 (3 source tokens):")
    print(mask[1, 0])


test_cross_mask()
```

Output:

πŸ“text
1Cross-Attention Mask
2Encoder shape: torch.Size([2, 6])
3Decoder length: 5
4Mask shape: torch.Size([2, 1, 5, 6])
5
6Mask for sentence 1 (4 source tokens):
7tensor([[1., 1., 1., 1., 0., 0.],
8        [1., 1., 1., 1., 0., 0.],
9        [1., 1., 1., 1., 0., 0.],
10        [1., 1., 1., 1., 0., 0.],
11        [1., 1., 1., 1., 0., 0.]])
12
13Mask for sentence 2 (3 source tokens):
14tensor([[1., 1., 1., 0., 0., 0.],
15        [1., 1., 1., 0., 0., 0.],
16        [1., 1., 1., 0., 0., 0.],
17        [1., 1., 1., 0., 0., 0.],
18        [1., 1., 1., 0., 0., 0.]])
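For comparison, PyTorch's built-in `torch.nn.MultiheadAttention` handles cross-attention the same way (different query and key/value tensors), but its `key_padding_mask` argument marks positions to ignore with True, the inverse of the 1 = attend convention used by `create_cross_attention_mask`. A minimal sketch with the same padded batch as above; the model sizes are arbitrary assumptions.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_model, heads = 16, 4
mha = nn.MultiheadAttention(d_model, heads, batch_first=True)
mha.eval()  # no dropout (the default dropout is 0.0 anyway)

encoder_ids = torch.tensor([[10, 20, 30, 40, 0, 0],
                            [10, 20, 30, 0, 0, 0]])
memory = torch.randn(2, 6, d_model)         # [batch, src_len, d_model]
decoder_state = torch.randn(2, 5, d_model)  # [batch, tgt_len, d_model]

# OPPOSITE convention to this section: True marks a position to IGNORE
key_padding_mask = encoder_ids == 0         # [batch, src_len]

out, weights = mha(decoder_state, memory, memory,
                   key_padding_mask=key_padding_mask)
print(out.shape, weights.shape)  # torch.Size([2, 5, 16]) torch.Size([2, 5, 6])
print(weights[1, 0])             # source positions 3, 4, 5 (padding) get weight 0
```

By default the returned weights are averaged over heads, which is why they have no head dimension.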

Cross-Attention in Practice

How Translation Uses Cross-Attention

```python
def demonstrate_cross_attention():
    """
    Show how cross-attention enables translation.
    """
    print("Cross-Attention for Translation")
    print("=" * 50)

    # Simulated scenario
    source = "Der Hund läuft"  # German
    target = "The dog runs"    # English

    print(f"\nSource: '{source}'")
    print(f"Target: '{target}'")

    print("\nDuring generation of 'dog':")
    print("  Decoder query: representation of 'The'")
    print("  Encoder keys: representations of [Der, Hund, läuft]")
    print("  Attention weights (idealized):")
    print("    'Der':   0.1  (article, not directly relevant)")
    print("    'Hund':  0.8  (dog! high attention)")
    print("    'läuft': 0.1  (runs, will be next)")
    print("  Output: blend of encoder values, weighted by attention")
    print("  Result: decoder 'sees' mainly 'Hund' → predicts 'dog'")

    print("\nDuring generation of 'runs':")
    print("  Decoder query: representation of 'dog'")
    print("  Attention weights (idealized):")
    print("    'Der':   0.05")
    print("    'Hund':  0.15 (already translated)")
    print("    'läuft': 0.80 (runs! high attention)")
    print("  Result: decoder 'sees' mainly 'läuft' → predicts 'runs'")


demonstrate_cross_attention()
```
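To see how weights like these arise numerically, here is a toy sketch with hand-picked 3-dimensional keys and a query chosen to point mostly toward "Hund". The vectors are invented for illustration and are not produced by a trained model.

```python
import torch

# Toy key vectors for the source tokens (hand-picked, not learned)
tokens = ["Der", "Hund", "läuft"]
keys = torch.tensor([[1.0, 0.0, 0.0],   # Der
                     [0.0, 1.0, 0.0],   # Hund
                     [0.0, 0.0, 1.0]])  # läuft

# A hypothetical decoder query that mostly matches "Hund"
query = torch.tensor([0.5, 3.0, 0.5])

scores = keys @ query / keys.shape[-1] ** 0.5  # scaled dot products
weights = torch.softmax(scores, dim=-1)        # sums to 1

for tok, w in zip(tokens, weights.tolist()):
    print(f"{tok:>6}: {w:.2f}")
```

The largest weight lands on "Hund" because its key has the biggest dot product with the query; the softmax keeps some weight on the other tokens rather than picking one outright.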

Encoder Output Caching

Why Cache Encoder Output?

During autoregressive generation:

- Encoder output is constant (the source doesn't change)

- Decoder runs multiple times (once per output token)

- Cross-attention K and V depend only on the encoder output, so they can be projected once and reused at every step
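The saving can be checked with a single-head sketch (head splitting omitted for brevity; the sizes and the random stand-in decoder states are assumptions). K and V are projected once before the loop and reused at every step, which gives the same result as recomputing them each step.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_model, src_len, steps = 16, 7, 4
W_q, W_k, W_v = (nn.Linear(d_model, d_model) for _ in range(3))

memory = torch.randn(1, src_len, d_model)  # encoder output, fixed during decoding

# Computed ONCE before decoding starts
K_cache, V_cache = W_k(memory), W_v(memory)

outputs_cached, outputs_fresh = [], []
for step in range(steps):
    y = torch.randn(1, 1, d_model)  # stand-in for the newest decoder state
    q = W_q(y)                      # only the query is new at each step

    # Cached: reuse the precomputed K/V
    w = torch.softmax(q @ K_cache.transpose(-2, -1) / d_model ** 0.5, dim=-1)
    outputs_cached.append(w @ V_cache)

    # Uncached: recompute K/V from the same memory (redundant work)
    K, V = W_k(memory), W_v(memory)
    w = torch.softmax(q @ K.transpose(-2, -1) / d_model ** 0.5, dim=-1)
    outputs_fresh.append(w @ V)

same = all(torch.allclose(a, b) for a, b in zip(outputs_cached, outputs_fresh))
print(same)  # True
```

With caching, the two encoder projections run once instead of once per generated token.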

```python
class CachedCrossAttention(nn.Module):
    """
    Cross-attention with encoder output caching for efficient inference.

    During generation, encoder K/V are computed once and reused.
    Call clear_cache() before processing a new source batch; otherwise
    stale K/V from the previous source would be reused.
    """

    def __init__(self, d_model: int, num_heads: int, dropout: float = 0.1):
        super().__init__()
        self.attention = MultiHeadAttention(d_model, num_heads, dropout)

        # Cache for encoder projections: [batch, src_len, d_model]
        self.cached_k = None
        self.cached_v = None

    def forward(
        self,
        decoder_state: torch.Tensor,
        encoder_output: torch.Tensor,
        cross_mask: Optional[torch.Tensor] = None,
        use_cache: bool = False
    ) -> torch.Tensor:
        """
        Args:
            decoder_state: [batch, tgt_len, d_model]
            encoder_output: [batch, src_len, d_model]
            cross_mask: [batch, 1, tgt_len, src_len] or [batch, 1, 1, src_len]
            use_cache: Whether to cache (and reuse) encoder projections

        Returns:
            output: [batch, tgt_len, d_model]
        """
        mha = self.attention

        # The query always depends on the current decoder state
        Q = mha.W_q(decoder_state)

        if use_cache and self.cached_k is not None:
            # Reuse encoder projections computed on an earlier step
            K, V = self.cached_k, self.cached_v
        else:
            # Project the encoder output (only once per source when caching)
            K = mha.W_k(encoder_output)
            V = mha.W_v(encoder_output)
            if use_cache:
                self.cached_k, self.cached_v = K, V

        # Same attention computation as MultiHeadAttention.forward
        batch_size = Q.size(0)
        Q = Q.view(batch_size, -1, mha.num_heads, mha.d_k).transpose(1, 2)
        K = K.view(batch_size, -1, mha.num_heads, mha.d_k).transpose(1, 2)
        V = V.view(batch_size, -1, mha.num_heads, mha.d_k).transpose(1, 2)

        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(mha.d_k)
        if cross_mask is not None:
            scores = scores.masked_fill(cross_mask == 0, float('-inf'))

        attn = F.softmax(scores, dim=-1)
        attn = torch.nan_to_num(attn, nan=0.0)  # fully masked rows
        attn = mha.dropout(attn)
        context = torch.matmul(attn, V)
        context = context.transpose(1, 2).contiguous().view(batch_size, -1, mha.d_model)

        return mha.W_o(context)

    def clear_cache(self):
        """Clear cached encoder projections."""
        self.cached_k = None
        self.cached_v = None
```

Summary

Cross-Attention Key Points

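| Aspect | Self-Attention | Cross-Attention |
|---|---|---|
| Q source | Same sequence | Decoder |
| K, V source | Same sequence | Encoder output |
| Causal mask | Decoder only (required there) | No |
| Padding mask | Yes | Yes (source padding) |
| Caching at inference | Incremental (K/V grow by one token per step) | Encoder K/V computed once |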

Implementation Checklist

- [ ] MultiHeadAttention supports different Q vs K/V sources

- [ ] Handle different sequence lengths (tgt_len ≠ src_len)

- [ ] Apply source padding mask (no causal mask)

- [ ] Cache encoder K/V projections for inference


Exercises

Implementation

1. Add attention weight visualization to cross-attention.

2. Implement relative position bias for cross-attention.

3. Create a "gated" cross-attention that can learn to ignore the encoder.

Analysis

4. Visualize cross-attention patterns during translation.

5. What happens if the cross-attention mask is all zeros for some decoder position?


Next Section Preview

In the next section, we'll combine masked self-attention, cross-attention, and FFN into the complete TransformerDecoderLayer.