Chapter 3: Multi-Head Attention

Self-Attention vs Cross-Attention

Introduction

Multi-head attention can operate in two modes: self-attention (where query, key, and value come from the same source) and cross-attention (where query comes from one source and key/value from another). Understanding this distinction is crucial for building encoder-decoder transformers.


The Two Modes

Self-Attention

Definition: Query, Key, and Value are all derived from the same input sequence.

Use cases:

  • Encoder self-attention
  • Decoder self-attention (masked)

Self-Attention Example

# Self-attention: Q, K, V from same source
x = encoder_input  # [batch, src_len, d_model]
output = attention(query=x, key=x, value=x)

What it captures: Relationships within a sequence

  • How tokens relate to each other
  • Syntactic dependencies (subject-verb)
  • Semantic relationships (coreference)

Cross-Attention

Definition: Query comes from one sequence, Key and Value from another.

Use cases:

  • Encoder-decoder attention (decoder queries, encoder provides K/V)
  • Vision-language models (text queries, image provides K/V)

Cross-Attention Example

# Cross-attention: Q from decoder, K/V from encoder
decoder_state = ...    # [batch, tgt_len, d_model]
encoder_output = ...   # [batch, src_len, d_model]
output = attention(
    query=decoder_state,
    key=encoder_output,
    value=encoder_output
)

What it captures: Relationships between sequences

  • How decoder tokens relate to encoder tokens
  • Alignment in translation (which source word to translate)
  • Grounding in multimodal tasks

Visual Comparison

Self-Attention Pattern

Input: "The cat sat on the mat"

Attention matrix (6×6):
          The  cat  sat   on  the  mat
The      [■    ·    ·    ·    ·    ·  ]
cat      [·    ■    ·    ·    ·    ·  ]
sat      [·    ■    ■    ·    ·    ■  ]  ← "sat" attends to "cat" and "mat"
on       [·    ·    ·    ■    ·    ·  ]
the      [·    ·    ·    ·    ■    ·  ]
mat      [·    ·    ·    ·    ·    ■  ]

Shape: [seq_len_q=6, seq_len_k=6]
Query and Key from same sequence → square matrix

Cross-Attention Pattern

Source (German): "Der Hund ist schwarz"
Target (English): "The dog is black"

Attention matrix (4×4):
          Der  Hund  ist  schwarz
The      [■    ·     ·    ·      ]  ← "The" attends to "Der"
dog      [·    ■     ·    ·      ]  ← "dog" attends to "Hund"
is       [·    ·     ■    ·      ]  ← "is" attends to "ist"
black    [·    ·     ·    ■      ]  ← "black" attends to "schwarz"

Shape: [seq_len_q=4, seq_len_k=4]
Query from target, Key from source
(Could be different lengths!)

Cross-Attention with Different Lengths

Source: "Der schwarze Hund" (3 tokens)
Target: "The black dog runs" (4 tokens)

Attention matrix (4×3):
          Der  schwarze  Hund
The      [■       ·       ·   ]
black    [·      ■        ·   ]
dog      [·      ·       ■    ]
runs     [·      ·       ■    ]  ← No source word for "runs"; attends to "Hund"

Shape: [seq_len_q=4, seq_len_k=3]
Rectangular matrix!
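
The rectangular shape falls straight out of the attention math. A standalone PyTorch sketch (tensor sizes are illustrative) makes it concrete:

```python
import math

import torch
import torch.nn.functional as F

d_model = 8
Q = torch.randn(1, 4, d_model)  # query from a 4-token target
K = torch.randn(1, 3, d_model)  # key from a 3-token source
V = torch.randn(1, 3, d_model)  # value from the same source

# Scaled dot-product attention
scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_model)
weights = F.softmax(scores, dim=-1)
output = torch.matmul(weights, V)

print(weights.shape)  # torch.Size([1, 4, 3]) -- rectangular
print(output.shape)   # torch.Size([1, 4, 8]) -- one vector per query token
```

Each row of the weight matrix is a distribution over the 3 source tokens, so every row still sums to 1 even though the matrix is not square.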

Architecture Usage

In the Original Transformer

Transformer Architecture Overview

ENCODER (N layers):
├── Self-Attention: Q=K=V=encoder_input
└── Feed-Forward

DECODER (N layers):
├── Masked Self-Attention: Q=K=V=decoder_input (causal mask)
├── Cross-Attention: Q=decoder_state, K=V=encoder_output
└── Feed-Forward

Data Flow Diagram

                           ENCODER
                              │
Source tokens ──→ Embedding ──→ Self-Attn ──→ FFN ──→ encoder_output
                              │                            │
                              └────────────────────────────┤
                                                           │
                           DECODER                         │
                              │                            │
Target tokens ──→ Embedding ──→ Masked      ──→ Cross   ──┴──→ FFN ──→ Output
                              │ Self-Attn      │ Attn
                              │                │
                           (Q=K=V)          (Q from decoder,
                                            K=V from encoder)

Implementation Comparison

Minimal Code Difference

class MultiHeadAttention(nn.Module):
    def forward(self, query, key, value, mask=None):
        # Projections
        Q = self.W_Q(query)  # From query input
        K = self.W_K(key)    # From key input
        V = self.W_V(value)  # From value input

        # Rest is identical:
        # split heads, attention, combine heads, output projection

        return output, weights


# Self-attention: same input for all
output = mha(x, x, x)  # Q=K=V=x

# Cross-attention: different inputs
output = mha(decoder_state, encoder_output, encoder_output)  # Q≠K=V

The same module handles both modes: the only difference is what inputs you pass!
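
PyTorch's built-in `nn.MultiheadAttention` works the same way: one module serves both modes depending on what you pass in. A quick shape check (dimensions are illustrative; the returned weights are averaged over heads by default):

```python
import torch
import torch.nn as nn

# One module, two modes
mha = nn.MultiheadAttention(embed_dim=64, num_heads=4, batch_first=True)

x = torch.randn(2, 10, 64)   # a single sequence
dec = torch.randn(2, 8, 64)  # decoder state
enc = torch.randn(2, 12, 64) # encoder output

# Self-attention: same tensor three times
self_out, self_w = mha(x, x, x)

# Cross-attention: query from one sequence, key/value from another
cross_out, cross_w = mha(dec, enc, enc)

print(self_w.shape)   # torch.Size([2, 10, 10]) -- square
print(cross_w.shape)  # torch.Size([2, 8, 12])  -- rectangular
```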

Shape Implications

# Self-attention
x = torch.randn(batch, seq_len, d_model)
output, weights = mha(x, x, x)
# output: [batch, seq_len, d_model]
# weights: [batch, heads, seq_len, seq_len]  ← square!

# Cross-attention
query = torch.randn(batch, tgt_len, d_model)   # e.g., 20 tokens
kv = torch.randn(batch, src_len, d_model)      # e.g., 30 tokens
output, weights = mha(query, kv, kv)
# output: [batch, tgt_len, d_model]            # 20 tokens
# weights: [batch, heads, tgt_len, src_len]    # 20×30, rectangular!

Masking Differences

Self-Attention Masks

Encoder self-attention: Only padding mask

Encoder Padding Mask

# Mask padding tokens in source
padding_mask = (src != PAD_TOKEN).unsqueeze(1).unsqueeze(2)
# Shape: [batch, 1, 1, src_len]

Decoder self-attention: Causal + padding mask

Decoder Causal Mask

# Prevent attending to future tokens (bool dtype so `&` can combine masks)
causal_mask = torch.tril(torch.ones(tgt_len, tgt_len, dtype=torch.bool))
padding_mask = (tgt != PAD_TOKEN).unsqueeze(1).unsqueeze(2)
combined_mask = causal_mask & padding_mask  # broadcasts over batch
# Shape: [batch, 1, tgt_len, tgt_len]

Cross-Attention Masks

Only mask padding in the source (encoder output):

Cross-Attention Mask

# Decoder queries can see all non-padding encoder positions
cross_mask = (src != PAD_TOKEN).unsqueeze(1).unsqueeze(2)
# Shape: [batch, 1, 1, src_len]
# Broadcasts to [batch, heads, tgt_len, src_len]

No causal mask needed: the decoder can look at the entire source!
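
The broadcast behavior claimed in the comment can be verified with a toy example (a standalone sketch; the PAD id and tensor sizes are illustrative):

```python
import torch
import torch.nn.functional as F

PAD = 0
batch, heads, tgt_len, src_len = 2, 4, 5, 6

# Second sequence in the batch has two padding tokens at the end
src = torch.tensor([[4, 7, 2, 9, 3, 5],
                    [6, 1, 8, 2, PAD, PAD]])

cross_mask = (src != PAD).unsqueeze(1).unsqueeze(2)  # [batch, 1, 1, src_len]

scores = torch.randn(batch, heads, tgt_len, src_len)
scores = scores.masked_fill(cross_mask == 0, float('-inf'))
weights = F.softmax(scores, dim=-1)  # mask broadcast over heads and tgt_len

print(weights.shape)                     # torch.Size([2, 4, 5, 6])
print(weights[1, :, :, 4:].max().item()) # 0.0 -- padded columns get no weight
```

Every decoder position, in every head, assigns zero weight to the padded source columns, while each row still sums to 1 over the remaining positions.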


Flexible Attention Module

Here's a module that explicitly handles both modes:

FlexibleMultiHeadAttention

import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class FlexibleMultiHeadAttention(nn.Module):
    """
    Multi-head attention supporting both self-attention and cross-attention.
    """

    def __init__(
        self,
        d_model: int,
        num_heads: int,
        dropout: float = 0.0,
        bias: bool = True
    ):
        super().__init__()

        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        self.scale = math.sqrt(self.d_k)

        # Query projection (always from query input)
        self.W_Q = nn.Linear(d_model, d_model, bias=bias)

        # Key/Value projections (from key/value inputs)
        self.W_K = nn.Linear(d_model, d_model, bias=bias)
        self.W_V = nn.Linear(d_model, d_model, bias=bias)

        # Output projection
        self.W_O = nn.Linear(d_model, d_model, bias=bias)

        self.dropout = nn.Dropout(dropout)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor = None,
        value: torch.Tensor = None,
        mask: torch.Tensor = None,
        is_self_attention: bool = None
    ):
        """
        Forward pass.

        Args:
            query: [batch, seq_len_q, d_model]
            key: [batch, seq_len_k, d_model] (optional, defaults to query)
            value: [batch, seq_len_k, d_model] (optional, defaults to key)
            mask: Attention mask
            is_self_attention: Explicit flag (auto-detected if None)

        Returns:
            output: [batch, seq_len_q, d_model]
            weights: [batch, num_heads, seq_len_q, seq_len_k]
        """
        # Auto-detect mode if not specified (informational only:
        # the computation below is identical in both modes)
        if is_self_attention is None:
            is_self_attention = (key is None) or (key is query)

        # Default key/value to query for self-attention
        if key is None:
            key = query
        if value is None:
            value = key

        batch_size, seq_len_q, _ = query.shape
        seq_len_k = key.size(1)

        # Project
        Q = self.W_Q(query)
        K = self.W_K(key)
        V = self.W_V(value)

        # Split heads
        Q = self._split_heads(Q, batch_size)
        K = self._split_heads(K, batch_size)
        V = self._split_heads(V, batch_size)

        # Attention
        scores = torch.matmul(Q, K.transpose(-2, -1)) / self.scale

        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))

        weights = F.softmax(scores, dim=-1)
        weights = torch.nan_to_num(weights, nan=0.0)  # zero out fully-masked rows
        weights = self.dropout(weights)

        # Output
        out = torch.matmul(weights, V)
        out = self._combine_heads(out, batch_size)
        out = self.W_O(out)

        return out, weights

    def _split_heads(self, x, batch_size):
        return x.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)

    def _combine_heads(self, x, batch_size):
        return x.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)


# Usage examples
mha = FlexibleMultiHeadAttention(d_model=512, num_heads=8)

# Self-attention (three equivalent ways)
x = torch.randn(2, 10, 512)
out1, _ = mha(x)                          # Auto-detect: key=None
out2, _ = mha(x, x, x)                    # Explicit: Q=K=V
out3, _ = mha(x, is_self_attention=True)  # Flag

# Cross-attention
dec = torch.randn(2, 8, 512)   # Decoder state
enc = torch.randn(2, 12, 512)  # Encoder output
out4, weights = mha(dec, enc, enc)  # Q from dec, K/V from enc
print(f"Cross-attention weights: {weights.shape}")  # [2, 8, 8, 12]

When to Use Each

Use Self-Attention When:

Scenario                  Example
Encoding a sequence       Understanding source sentence
Language modeling         Predicting next word (causal)
Bidirectional context     BERT-style encoding
Single-sequence tasks     Classification, NER

Use Cross-Attention When:

Scenario                  Example
Sequence-to-sequence      Translation, summarization
Multimodal fusion         Image + text
Retrieval augmentation    Query + retrieved docs
Encoder-decoder models    T5, BART

Practical Considerations

Memory Efficiency

Self-attention: O(n²), where n is the sequence length

Cross-attention: O(n × m), where n is the query length and m is the key length

For cross-attention with long encoder output:

  • Consider sparse attention patterns
  • Use chunked processing
  • KV caching for generation
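
A quick back-of-envelope count makes the asymptotics above concrete (lengths are illustrative; fp32, per head, per batch element):

```python
# Number of entries in the attention weight matrix
def attn_entries(q_len: int, k_len: int) -> int:
    return q_len * k_len

# Self-attention over a 4096-token sequence
self_entries = attn_entries(4096, 4096)   # 16,777,216 entries

# Cross-attention: 128 decoder tokens over the same 4096-token source
cross_entries = attn_entries(128, 4096)   # 524,288 entries

bytes_per_float = 4  # fp32
print(f"self:  {self_entries * bytes_per_float / 2**20:.0f} MiB")   # 64 MiB
print(f"cross: {cross_entries * bytes_per_float / 2**20:.0f} MiB")  # 2 MiB
```

The cross-attention matrix here is 32× smaller because only the query side grows with the target length; the source side is fixed.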

Caching for Generation

During autoregressive generation, encoder output doesn't change:

CachedCrossAttention

# Wraps a module exposing W_Q/W_K/W_V/W_O, scale, and the
# _split_heads/_combine_heads helpers (e.g., FlexibleMultiHeadAttention above)
class CachedCrossAttention(nn.Module):
    def __init__(self, base_attention):
        super().__init__()
        self.attention = base_attention
        self.cached_K = None
        self.cached_V = None

    def forward(self, query, encoder_output=None, use_cache=True):
        attn = self.attention
        if encoder_output is not None:
            # First call: compute and cache the projected K, V
            batch_size = encoder_output.size(0)
            self.cached_K = attn._split_heads(attn.W_K(encoder_output), batch_size)
            self.cached_V = attn._split_heads(attn.W_V(encoder_output), batch_size)

        if use_cache and self.cached_K is not None:
            # Reuse cached K, V; only the query is projected each step
            batch_size = query.size(0)
            Q = attn._split_heads(attn.W_Q(query), batch_size)
            scores = torch.matmul(Q, self.cached_K.transpose(-2, -1)) / attn.scale
            weights = F.softmax(scores, dim=-1)
            out = attn._combine_heads(torch.matmul(weights, self.cached_V), batch_size)
            return attn.W_O(out), weights

        # No cache available: normal forward pass
        return attn(query, encoder_output, encoder_output)
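
As a sanity check on the caching idea, here is a standalone single-head miniature (all names and sizes are illustrative): caching the projected K/V once gives exactly the same output as recomputing them at every step.

```python
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
d = 16
W_Q, W_K, W_V = nn.Linear(d, d), nn.Linear(d, d), nn.Linear(d, d)

enc = torch.randn(1, 6, d)  # encoder output (fixed during generation)
q1 = torch.randn(1, 1, d)   # decoder query at one generation step

def cross_attn(query, K, V):
    scores = torch.matmul(W_Q(query), K.transpose(-2, -1)) / math.sqrt(d)
    return torch.matmul(F.softmax(scores, dim=-1), V)

# Cache the encoder projections once...
K_cache, V_cache = W_K(enc), W_V(enc)

# ...then every generation step reuses them
out_cached = cross_attn(q1, K_cache, V_cache)

# Recomputing the projections from scratch gives the same result
out_full = cross_attn(q1, W_K(enc), W_V(enc))
print(torch.allclose(out_cached, out_full))  # True
```

The saving grows with the number of generated tokens: the two O(src_len × d²) projections are paid once instead of once per decoding step.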

Summary

Key Differences

Aspect           Self-Attention               Cross-Attention
Q source         Same as K, V                 Different from K, V
Matrix shape     Square (n×n)                 Rectangular (n×m)
Typical use      Encoder, decoder self-attn   Encoder-decoder bridge
Masking          Causal for decoder           Source padding only
What it learns   Intra-sequence relations     Inter-sequence relations

Implementation Insight

The same module handles both; the distinction is purely in what inputs you provide:

Self vs Cross Attention Usage

# Self-attention
output = mha(x, x, x)

# Cross-attention
output = mha(decoder_state, encoder_output, encoder_output)

Exercises

Conceptual Questions

  1. Why doesn't cross-attention need a causal mask?
  2. In a translation model, what does high cross-attention weight between "dog" (English) and "Hund" (German) indicate?
  3. Could you use cross-attention between two unrelated sequences? What would the model learn?

Implementation Exercises

  1. Implement a KV-caching wrapper for cross-attention that avoids recomputing encoder projections.
  2. Create a visualization that shows self-attention vs cross-attention patterns for a translation example.
  3. Implement "bi-directional cross-attention" where both sequences query each other.

Chapter Summary

In this chapter on Multi-Head Attention, you learned:

  1. Why multiple heads: Specialization for different relationship types
  2. Linear projections: W_Q, W_K, W_V transform inputs to Q, K, V
  3. Shape transformations: split_heads and combine_heads for parallel computation
  4. Complete implementation: Production-ready MultiHeadAttention module
  5. Self vs cross attention: Same mechanism, different input patterns

You now have a complete, reusable multi-head attention module that forms the core of every transformer layer.


Next Chapter Preview

In Chapter 4: Positional Encoding and Embeddings, we'll solve the "position problem": transformers are permutation invariant, but language is order-dependent. We'll implement:

  • Token embeddings
  • Sinusoidal positional encoding
  • Learned positional embeddings
  • Combined embedding layers

This will complete the input processing pipeline before we move on to building full encoder and decoder layers.