## Introduction
Multi-head attention can operate in two modes: self-attention (where query, key, and value come from the same source) and cross-attention (where query comes from one source and key/value from another). Understanding this distinction is crucial for building encoder-decoder transformers.
## The Two Modes

### Self-Attention
**Definition:** Query, Key, and Value are all derived from the same input sequence.

**Use cases:**
- Encoder self-attention
- Decoder self-attention (masked)
**What it captures:** Relationships within a sequence
- How tokens relate to each other
- Syntactic dependencies (subject-verb)
- Semantic relationships (coreference)
### Cross-Attention

**Definition:** Query comes from one sequence; Key and Value come from another.

**Use cases:**
- Encoder-decoder attention (decoder queries, encoder provides K/V)
- Vision-language models (text queries, image provides K/V)
**What it captures:** Relationships between sequences
- How decoder tokens relate to encoder tokens
- Alignment in translation (which source word to translate)
- Grounding in multimodal tasks
## Visual Comparison

### Self-Attention Pattern

Each token attends over every token of the same sequence, so the attention-weight matrix is square (n×n).

### Cross-Attention Pattern

Each query-side token attends over the other sequence's tokens, so the matrix is rectangular (n×m).

### Cross-Attention with Different Lengths

Because queries and keys come from different sequences, their lengths need not match: a length-n decoder state can attend over a length-m encoder output directly.
## Architecture Usage

### In the Original Transformer

The original Transformer uses both modes: self-attention in every encoder layer, masked self-attention in every decoder layer, and cross-attention in every decoder layer, where the decoder supplies the queries and the final encoder output supplies the keys and values.

### Data Flow Diagram

Source tokens flow through the encoder's self-attention layers to produce the encoder output (the "memory"); target tokens flow through masked self-attention, then cross-attention over that memory, and on to the decoder's output.
## Implementation Comparison

### Minimal Code Difference

The same module handles both modes; the only difference is what inputs you pass.
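A sketch of what that looks like in practice, using PyTorch's built-in `nn.MultiheadAttention` as a stand-in for the module built earlier in the chapter:

```python
import torch
import torch.nn as nn

# One attention module, two modes; the only difference is the inputs.
attn = nn.MultiheadAttention(embed_dim=64, num_heads=8, batch_first=True)

x = torch.randn(2, 10, 64)      # decoder/target sequence (n = 10)
memory = torch.randn(2, 7, 64)  # encoder output (m = 7)

self_out, _ = attn(x, x, x)             # self-attention: Q = K = V = x
cross_out, _ = attn(x, memory, memory)  # cross-attention: Q = x, K = V = memory

print(self_out.shape)   # torch.Size([2, 10, 64])
print(cross_out.shape)  # torch.Size([2, 10, 64])
```

Both calls produce output with the query's sequence length: the output always lives on the query side.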
### Shape Implications
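A minimal shape check (raw dot products only, no projections or head splitting) illustrates the difference:

```python
import torch

n, m, d = 10, 7, 64
q = torch.randn(2, n, d)  # queries (e.g. decoder side)
k = torch.randn(2, m, d)  # keys (e.g. encoder side)

# Attention scores: self-attention is square, cross-attention rectangular.
self_scores = q @ q.transpose(-2, -1)   # (2, n, n)
cross_scores = q @ k.transpose(-2, -1)  # (2, n, m)

print(self_scores.shape)   # torch.Size([2, 10, 10])
print(cross_scores.shape)  # torch.Size([2, 10, 7])
```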
## Masking Differences

### Self-Attention Masks

- **Encoder self-attention**: padding mask only
- **Decoder self-attention**: causal mask + padding mask
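A sketch of how these masks might be built (assuming pad id 0, and boolean masks where `True` means "may attend"):

```python
import torch

seq_len = 5
tokens = torch.tensor([[7, 3, 9, 0, 0]])  # one sequence; 0 is the (assumed) pad id

# Causal mask: position i may attend only to positions j <= i.
causal = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))  # (seq, seq)

# Padding mask: True wherever a real (non-pad) token exists.
padding = tokens != 0  # (1, seq)

# Combined decoder self-attention mask: (1, seq, seq)
combined = causal.unsqueeze(0) & padding.unsqueeze(1)
print(combined.shape)  # torch.Size([1, 5, 5])
```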
### Cross-Attention Masks

Only padding in the source (the encoder output) is masked. No causal mask is needed; the decoder may attend to the entire source.
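A sketch of the corresponding mask (same assumptions: pad id 0, `True` means "may attend"):

```python
import torch

src_tokens = torch.tensor([[5, 8, 2, 0, 0]])  # encoder input; 0 is the (assumed) pad id
tgt_len = 4                                   # decoder sequence length

# Every target position may see every real source position;
# padded source positions are masked for all target positions.
src_padding = src_tokens != 0                                  # (1, src_len)
cross_mask = src_padding.unsqueeze(1).expand(-1, tgt_len, -1)  # (1, tgt_len, src_len)
print(cross_mask.shape)  # torch.Size([1, 4, 5])
```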
## Flexible Attention Module

Here's a module that explicitly handles both modes:
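A minimal sketch of such a module; the class name `FlexibleAttention` and the `key_value` parameter are illustrative, and the internals mirror the multi-head attention built earlier in the chapter:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class FlexibleAttention(nn.Module):
    """Multi-head attention usable for both self- and cross-attention.

    Pass only `query` for self-attention, or a separate `key_value`
    sequence (e.g. encoder output) for cross-attention.
    """

    def __init__(self, d_model: int, num_heads: int):
        super().__init__()
        assert d_model % num_heads == 0
        self.num_heads = num_heads
        self.d_head = d_model // num_heads
        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)
        self.w_o = nn.Linear(d_model, d_model)

    def _split(self, x):
        # (batch, seq, d_model) -> (batch, heads, seq, d_head)
        b, s, _ = x.shape
        return x.view(b, s, self.num_heads, self.d_head).transpose(1, 2)

    def forward(self, query, key_value=None, mask=None):
        if key_value is None:       # self-attention: K and V come from the query
            key_value = query
        q = self._split(self.w_q(query))
        k = self._split(self.w_k(key_value))
        v = self._split(self.w_v(key_value))

        scores = q @ k.transpose(-2, -1) / self.d_head ** 0.5  # (b, h, n, m)
        if mask is not None:
            # mask: bool tensor broadcastable to (b, h, n, m); True = may attend
            scores = scores.masked_fill(~mask, float("-inf"))
        out = F.softmax(scores, dim=-1) @ v                    # (b, h, n, d_head)

        b, _, n, _ = out.shape
        out = out.transpose(1, 2).reshape(b, n, -1)            # combine heads
        return self.w_o(out)

attn = FlexibleAttention(d_model=64, num_heads=8)
x = torch.randn(2, 10, 64)
memory = torch.randn(2, 7, 64)
print(attn(x).shape)          # self-attention:  torch.Size([2, 10, 64])
print(attn(x, memory).shape)  # cross-attention: torch.Size([2, 10, 64])
```

Defaulting `key_value` to `query` keeps a single code path for both modes.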
## When to Use Each

### Use Self-Attention When:
| Scenario | Example |
|---|---|
| Encoding a sequence | Understanding source sentence |
| Language modeling | Predicting next word (causal) |
| Bidirectional context | BERT-style encoding |
| Single sequence tasks | Classification, NER |
### Use Cross-Attention When:
| Scenario | Example |
|---|---|
| Sequence-to-sequence | Translation, summarization |
| Multimodal fusion | Image + text |
| Retrieval augmentation | Query + retrieved docs |
| Encoder-decoder models | T5, BART |
## Practical Considerations

### Memory Efficiency

- **Self-attention**: O(n²), where n is the sequence length
- **Cross-attention**: O(n × m), where n is the query length and m is the key/value length
For cross-attention with long encoder output:
- Consider sparse attention patterns
- Use chunked processing
- KV caching for generation
### Caching for Generation

During autoregressive generation, the encoder output doesn't change:
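A sketch of the idea (names like `cached_k` are illustrative; only the decoder-side query changes between steps, so the encoder-side K/V projections can be computed once before the decoding loop):

```python
import torch
import torch.nn as nn

d_model = 64
w_k = nn.Linear(d_model, d_model)  # key projection (per layer, per head group)
w_v = nn.Linear(d_model, d_model)  # value projection

memory = torch.randn(1, 7, d_model)  # encoder output: fixed during generation

# Compute once, before the decoding loop.
cached_k = w_k(memory)
cached_v = w_v(memory)

generated = torch.randn(1, 1, d_model)  # current decoder state (one step)
for step in range(3):
    # Each step reuses cached_k / cached_v instead of re-projecting memory.
    scores = generated @ cached_k.transpose(-2, -1) / d_model ** 0.5
    context = torch.softmax(scores, dim=-1) @ cached_v
    generated = context  # stand-in for the rest of the decoder layer
print(context.shape)  # torch.Size([1, 1, 64])
```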
## Summary

### Key Differences
| Aspect | Self-Attention | Cross-Attention |
|---|---|---|
| Q source | Same as K, V | Different from K, V |
| Matrix shape | Square (n×n) | Rectangular (n×m) |
| Typical use | Encoder, decoder self-attn | Encoder-decoder bridge |
| Masking | Causal for decoder | Source padding only |
| What it learns | Intra-sequence relations | Inter-sequence relations |
### Implementation Insight

The same module handles both; the distinction is purely in what inputs you provide:
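For example (a sketch using PyTorch's `nn.MultiheadAttention` in place of the chapter's module):

```python
import torch
import torch.nn as nn

mha = nn.MultiheadAttention(embed_dim=32, num_heads=4, batch_first=True)
x = torch.randn(1, 6, 32)       # one sequence
memory = torch.randn(1, 9, 32)  # another sequence (e.g. encoder output)

self_out, _ = mha(x, x, x)            # self-attention
cross_out, _ = mha(x, memory, memory) # cross-attention
print(self_out.shape == cross_out.shape == x.shape)  # True
```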
## Exercises

### Conceptual Questions
- Why doesn't cross-attention need a causal mask?
- In a translation model, what does high cross-attention weight between "dog" (English) and "Hund" (German) indicate?
- Could you use cross-attention between two unrelated sequences? What would the model learn?
### Implementation Exercises
- Implement a KV-caching wrapper for cross-attention that avoids recomputing encoder projections.
- Create a visualization that shows self-attention vs cross-attention patterns for a translation example.
- Implement "bi-directional cross-attention" where both sequences query each other.
## Chapter Summary
In this chapter on Multi-Head Attention, you learned:
- Why multiple heads: Specialization for different relationship types
- Linear projections: W_Q, W_K, W_V transform inputs to Q, K, V
- Shape transformations: split_heads and combine_heads for parallel computation
- Complete implementation: Production-ready MultiHeadAttention module
- Self vs cross attention: Same mechanism, different input patterns
You now have a complete, reusable multi-head attention module that forms the core of every transformer layer.
## Next Chapter Preview

In Chapter 4: Positional Encoding and Embeddings, we'll solve the "position problem": transformers are permutation invariant, but language is order-dependent. We'll implement:
- Token embeddings
- Sinusoidal positional encoding
- Learned positional embeddings
- Combined embedding layers
This will complete the input processing pipeline before we move on to building full encoder and decoder layers.