Chapter 3

Multi-Head Attention: Why Multiple Heads?

Introduction

In the previous chapter, we implemented single-head attention. It works well, but the original Transformer paper uses multi-head attention instead. Why?

This section explores the theoretical motivation for using multiple attention heads: how they allow the model to capture different types of relationships simultaneously and why this matters for language understanding.


Limitations of Single-Head Attention

The Problem: One Perspective Only

Single-head attention computes one set of attention weights:

πŸ“text
1For "The cat sat on the mat":
2
3Query "sat" β†’ Attention weights β†’ [0.05, 0.35, 0.15, 0.10, 0.05, 0.30]
4                                   The  cat   sat   on   the   mat

But language has multiple types of relationships:

  1. Syntactic: "sat" needs its subject ("cat") and object ("mat")
  2. Semantic: "cat" and "mat" might be related (domestic context)
  3. Positional: Adjacent words often matter for grammar
  4. Long-range: Pronouns need to find their referents

Can one attention pattern capture all of this? Not effectively.
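To make the "one pattern" limitation concrete, here is a compact NumPy sketch of the single-head attention from the previous chapter. The dimensions and random weights are purely illustrative toys, not trained values:

```python
import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def single_head_attention(X, W_Q, W_K, W_V):
    """Scaled dot-product attention: one attention pattern per query."""
    Q, K, V = X @ W_Q, X @ W_K, X @ W_V
    d_k = Q.shape[-1]
    weights = softmax(Q @ K.T / np.sqrt(d_k))  # (seq, seq): a single pattern
    return weights @ V, weights

rng = np.random.default_rng(0)
seq_len, d_model = 6, 8                 # 6 tokens: "The cat sat on the mat"
X = rng.normal(size=(seq_len, d_model))
W_Q, W_K, W_V = rng.normal(size=(3, d_model, d_model))
out, weights = single_head_attention(X, W_Q, W_K, W_V)
print(weights[2].round(2))              # the one attention row for "sat"
```

Whatever the projection matrices learn, each query token gets exactly one row of attention weights; there is no second pattern to fall back on.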

Mathematical Limitation

Single-head attention projects into one subspace:

$$
\begin{aligned}
Q &= X \cdot W_Q \quad \rightarrow \quad \text{Single query representation} \\
K &= X \cdot W_K \quad \rightarrow \quad \text{Single key representation}
\end{aligned}
$$

This single projection must balance:

  • Finding syntactic dependencies
  • Capturing semantic similarity
  • Maintaining positional relationships

One projection can't optimize for all of these simultaneously.

The "Averaging" Problem

When a single head tries to capture multiple relationship types, it compromises:

πŸ“text
1Ideal for syntax:     [0.0, 0.8, 0.0, 0.0, 0.0, 0.2]  (subject + object)
2Ideal for position:   [0.3, 0.3, 0.2, 0.2, 0.0, 0.0]  (nearby words)
3Ideal for semantics:  [0.0, 0.3, 0.0, 0.0, 0.0, 0.7]  (related concepts)
4
5Compromise (averaged): [0.1, 0.5, 0.1, 0.1, 0.0, 0.3]  (muddy signal)

The compromise is worse than any specialized pattern.
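A quick numeric check of the averaging effect, using the three hypothetical "ideal" patterns above (the averaged result is close to the compromise row shown, and is visibly less peaked than any specialist):

```python
import numpy as np

syntax    = np.array([0.0, 0.8, 0.0, 0.0, 0.0, 0.2])
position  = np.array([0.3, 0.3, 0.2, 0.2, 0.0, 0.0])
semantics = np.array([0.0, 0.3, 0.0, 0.0, 0.0, 0.7])

compromise = (syntax + position + semantics) / 3
print(compromise.round(2))
# Still a valid distribution (sums to 1), but its strongest weight is far
# weaker than the 0.8 the syntax specialist could put on the subject.
```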


The Multi-Head Solution

Core Idea: Multiple Perspectives

Instead of one attention pattern, compute multiple patterns in parallel:

πŸ“text
1Head 1: Focuses on syntactic relationships
2Head 2: Focuses on semantic similarity
3Head 3: Focuses on adjacent positions
4Head 4: Focuses on long-range dependencies
5...

Each head can specialize for different relationship types.

Visual Representation

πŸ“text
1Input: "The cat sat on the mat"
2
3Head 1 (Syntactic):     sat β†’ [Β·, β– , Β·, Β·, Β·, β– ]  (subject, object)
4Head 2 (Positional):    sat β†’ [Β·, β– , β– , β– , Β·, Β·]  (nearby words)
5Head 3 (Semantic):      cat β†’ [Β·, β– , Β·, Β·, Β·, β– ]  (cat ↔ mat, domestic)
6Head 4 (Article):       cat β†’ [β– , β– , Β·, Β·, Β·, Β·]  (article-noun pair)
7
8Concatenate all heads β†’ Rich, multi-faceted representation

The Mathematical Framework

Multi-head attention:

$$
\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \ldots, \text{head}_h) \cdot W^O
$$

Where each head is computed as:

$$
\text{head}_i = \text{Attention}(Q \cdot W^Q_i,\; K \cdot W^K_i,\; V \cdot W^V_i)
$$

Each head has its own projection matrices:

  • W^Q_i: What each head looks for
  • W^K_i: What each position advertises
  • W^V_i: What information to extract
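Putting the two formulas together: a common implementation projects once with large matrices and slices the result into h heads of width d_k. The following NumPy sketch is illustrative (toy scaling and names, not the chapter's final implementation):

```python
import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def multi_head_attention(X, W_Q, W_K, W_V, W_O, n_heads):
    seq_len, d_model = X.shape
    d_k = d_model // n_heads

    def split(W):  # project, then reshape (seq, d_model) -> (h, seq, d_k)
        return (X @ W).reshape(seq_len, n_heads, d_k).transpose(1, 0, 2)

    Q, K, V = split(W_Q), split(W_K), split(W_V)
    # head_i = Attention(Q_i, K_i, V_i), computed for all heads in parallel
    weights = softmax(Q @ K.transpose(0, 2, 1) / np.sqrt(d_k))  # (h, seq, seq)
    heads = weights @ V                                          # (h, seq, d_k)
    # Concat(head_1, ..., head_h) · W^O
    concat = heads.transpose(1, 0, 2).reshape(seq_len, d_model)
    return concat @ W_O, weights

rng = np.random.default_rng(0)
X = rng.normal(size=(6, 512))
W_Q, W_K, W_V, W_O = rng.normal(size=(4, 512, 512)) * 0.02
out, weights = multi_head_attention(X, W_Q, W_K, W_V, W_O, n_heads=8)
print(out.shape, weights.shape)  # (6, 512) (8, 6, 6)
```

Note that `weights` now holds eight independent attention patterns over the same six tokens, one per head, instead of the single pattern of the previous section.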

Specialization of Attention Heads

Empirical Observations

Research analyzing trained transformers reveals that heads do specialize:

| Head Type | What It Learns | Example Pattern |
|---|---|---|
| Positional | Attend to adjacent tokens | Diagonal or band pattern |
| Syntactic | Subject-verb, verb-object | Sparse, specific connections |
| Semantic | Similar/related concepts | Cluster-like patterns |
| Delimiter | Attend to [CLS], [SEP], periods | Column pattern |
| Previous Token | Always attend to position i-1 | Shifted diagonal |
| Rare Token | Attend to unusual/important words | Sparse, content-specific |

Visualization from Research

BERT attention analysis shows distinct patterns across heads:

πŸ“text
1Layer 1, Head 1:           Layer 1, Head 2:           Layer 1, Head 3:
2  [β–  Β· Β· Β· Β·]                [β–  β–  Β· Β· Β·]                [Β· Β· Β· Β· β– ]
3  [β–  β–  Β· Β· Β·]                [Β· β–  β–  Β· Β·]                [Β· Β· Β· Β· β– ]
4  [Β· β–  β–  Β· Β·]                [Β· Β· β–  β–  Β·]                [Β· Β· Β· Β· β– ]
5  [Β· Β· β–  β–  Β·]                [Β· Β· Β· β–  β– ]                [Β· Β· Β· Β· β– ]
6  [Β· Β· Β· β–  β– ]                [Β· Β· Β· Β· β– ]                [Β· Β· Β· Β· β– ]
7
8  Previous token            Next token               Last position

The Redundancy-Expressiveness Tradeoff

Not all heads are equally useful:

  • Some heads learn similar patterns (redundancy)
  • Some heads learn seemingly random patterns (noise)
  • Some heads are critical (pruning them hurts performance)

This observation has led to:

  • Head pruning: Removing less important heads
  • Adaptive attention: Learning which heads to use
  • Mixture of experts: Routing to relevant heads

Subspace Projection Intuition

High-Dimensional Embedding Space

Word embeddings live in high-dimensional space (e.g., 512 dimensions).

Different dimensions encode different information:

  • Dimensions 1-50: Maybe gender/animacy
  • Dimensions 51-100: Maybe part-of-speech
  • Dimensions 101-200: Maybe semantic category
  • ...

Projecting into Subspaces

Each attention head projects into a subspace:

$$
\begin{aligned}
d_{\text{model}} &= 512 \quad \text{(full embedding dimension)} \\
n_{\text{heads}} &= 8 \quad \text{(number of attention heads)} \\
d_k &= \frac{d_{\text{model}}}{n_{\text{heads}}} = \frac{512}{8} = 64 \quad \text{(per-head dimension)}
\end{aligned}
$$

Head 1: Projects into a 64-dim subspace → Captures one aspect
Head 2: Projects into a different 64-dim subspace → Captures another aspect

The projection matrices W^Q_i and W^K_i learn which aspects to focus on.
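The divisibility requirement implied by this arithmetic can be checked in two lines (a minimal sketch, not specific to any framework):

```python
# The per-head width must divide the model width evenly, so that the
# concatenated heads reassemble to exactly d_model dimensions.
d_model, n_heads = 512, 8
assert d_model % n_heads == 0, "d_model must be divisible by n_heads"
d_k = d_model // n_heads
print(d_k)  # 64
```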

Geometric Intuition

Imagine embeddings in 3D (simplified):

πŸ“text
1Full 3D space:
2  Words live here with multiple properties
3
4Head 1 projects onto XY plane:
5  Captures relationships in X-Y dimensions
6  (e.g., syntactic features)
7
8Head 2 projects onto XZ plane:
9  Captures relationships in X-Z dimensions
10  (e.g., semantic features)
11
12Combining both:
13  Richer understanding than either alone
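The plane analogy translates directly into matrix projections. Here is a toy version with hypothetical 3D "embeddings", where each head's projection matrix keeps a different pair of axes:

```python
import numpy as np

# Two toy 3D "embeddings" (hypothetical values)
points = np.array([[1.0, 2.0, 3.0],
                   [4.0, 5.0, 6.0]])

P_xy = np.array([[1.0, 0.0], [0.0, 1.0], [0.0, 0.0]])  # Head 1: keep X and Y
P_xz = np.array([[1.0, 0.0], [0.0, 0.0], [0.0, 1.0]])  # Head 2: keep X and Z

head1_view = points @ P_xy  # sees only (x, y)
head2_view = points @ P_xz  # sees only (x, z)
print(head1_view.tolist())  # [[1.0, 2.0], [4.0, 5.0]]
print(head2_view.tolist())  # [[1.0, 3.0], [4.0, 6.0]]
```

Real heads learn dense projection matrices rather than axis-aligned ones, so their subspaces are arbitrary learned directions, but the principle is the same: each head sees a lower-dimensional view of the full embedding.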

One Large Head vs Multiple Small Heads

Parameter Comparison

One large head (d_k = d_model = 512):

$$
\begin{aligned}
W^Q &: [512 \times 512] \rightarrow 262{,}144 \text{ parameters} \\
W^K &: [512 \times 512] \rightarrow 262{,}144 \text{ parameters} \\
W^V &: [512 \times 512] \rightarrow 262{,}144 \text{ parameters} \\
\textbf{Total} &: 786{,}432 \text{ parameters}
\end{aligned}
$$

Eight small heads (d_k = 64 each):

$$
\begin{aligned}
W^Q &: [512 \times 512] \rightarrow 262{,}144 \text{ parameters (projects to all heads)} \\
W^K &: [512 \times 512] \rightarrow 262{,}144 \text{ parameters} \\
W^V &: [512 \times 512] \rightarrow 262{,}144 \text{ parameters} \\
W^O &: [512 \times 512] \rightarrow 262{,}144 \text{ parameters (output projection)} \\
\textbf{Total} &: 1{,}048{,}576 \text{ parameters}
\end{aligned}
$$

Multi-head uses ~33% more parameters (for W^O), but gains:

  • Multiple attention patterns
  • Specialization capability
  • Better gradient flow (parallel heads)
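The parameter arithmetic above checks out directly:

```python
d_model = 512

# One large head: W^Q, W^K, W^V, each d_model x d_model
single_head_params = 3 * d_model * d_model
print(single_head_params)  # 786432

# Eight small heads: the same three projections (sliced across heads) plus W^O
multi_head_params = 4 * d_model * d_model
print(multi_head_params)   # 1048576

print(multi_head_params / single_head_params)  # ~1.33: the extra third is W^O
```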

Empirical Results

The original paper compared configurations:

| Config | BLEU Score | Note |
|---|---|---|
| 1 head, d_k = 512 | 24.9 | Baseline |
| 8 heads, d_k = 64 | 25.8 | +0.9 improvement |
| 16 heads, d_k = 32 | 25.5 | Diminishing returns |
| 32 heads, d_k = 16 | 25.0 | Too many heads |

Sweet spot: 8-16 heads for typical models.


How Heads Learn to Specialize

Random Initialization Breaks Symmetry

At initialization:

  • All heads have random weights
  • Random differences lead to different gradients
  • Heads diverge during training

Specialization Emerges from Data

The data contains different types of patterns:

  • Syntactic patterns β†’ Some heads specialize here
  • Semantic patterns β†’ Other heads specialize here
  • The loss function rewards useful specialization

Not Explicitly Designed

Heads are NOT told what to specialize in:

  • No "syntax head" label
  • No "semantic head" constraint
  • Specialization emerges from training

This is learned, not engineered.


Attention Head Redundancy

The Dropout Perspective

If we apply dropout to attention heads:

  • Randomly zero out entire heads during training
  • Model must learn redundant representations
  • No single head becomes critical
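One way to sketch this head-level dropout (a hypothetical training-time mask written in NumPy; the function name and rescaling choice are illustrative, not from the original paper):

```python
import numpy as np

def drop_heads(head_outputs, p=0.25, rng=None):
    """Zero out entire attention heads with probability p (training only).

    head_outputs: array of shape (n_heads, seq_len, d_k)
    """
    rng = rng or np.random.default_rng()
    keep = rng.random(head_outputs.shape[0]) >= p        # one coin flip per head
    mask = keep[:, None, None].astype(head_outputs.dtype)
    # Inverted-dropout rescaling keeps the expected output magnitude unchanged
    return head_outputs * mask / (1.0 - p)

heads = np.ones((8, 6, 64))
dropped = drop_heads(heads, p=0.25, rng=np.random.default_rng(0))
# Dropped heads are all-zero; surviving heads are scaled up by 1/(1-p)
```

Because any head can vanish on a given step, downstream layers cannot rely on one head alone, which encourages the redundancy described above.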

Head Pruning Research

Studies show:

  • 30-40% of heads can be removed with minimal loss
  • Some heads are redundant (similar patterns)
  • Some heads are critical (removing them hurts a lot)

Implications

  1. Robustness: Multiple heads provide backup
  2. Efficiency opportunity: Can prune for inference
  3. Interpretability challenge: Hard to assign meaning to heads

Multi-Head Attention in Different Layers

Layer-Dependent Patterns

Different layers learn different patterns:

Early Layers (1-4):

  • Local patterns (adjacent tokens)
  • Syntactic patterns (POS, grammar)
  • Surface-level features

Middle Layers (5-8):

  • Longer-range dependencies
  • Semantic relationships
  • Entity tracking

Late Layers (9-12):

  • Task-specific patterns
  • Abstract representations
  • Output preparation

Why This Matters

Understanding layer roles helps with:

  • Transfer learning: Which layers to freeze
  • Interpretability: Where to look for specific patterns
  • Efficiency: Which layers are most important

Summary

Why Multi-Head Attention?

| Single Head | Multiple Heads |
|---|---|
| One attention pattern | Multiple patterns |
| One subspace projection | Multiple subspaces |
| Must compromise | Can specialize |
| Less expressive | More expressive |
| Lower parameter count | Slightly more parameters |

Key Takeaways

  1. Specialization: Heads learn to capture different relationship types
  2. Parallel Subspaces: Each head projects to a different subspace
  3. Empirical Benefit: Multi-head consistently outperforms single-head
  4. Emergent Behavior: Specialization emerges from training, not design
  5. Redundancy: Some heads are redundant, others are critical

The Formula Preview

Multi-Head Attention:

$$
\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \ldots, \text{head}_h) \cdot W^O
$$

Each head:

$$
\text{head}_i = \text{Attention}(Q \cdot W^Q_i,\; K \cdot W^K_i,\; V \cdot W^V_i)
$$

Exercises

Conceptual Questions

  1. Explain in your own words why a single attention head can't simultaneously optimize for syntax and semantics.
  2. If you had 1024-dimensional embeddings and 16 heads, what would be the dimension per head (d_k)? Why must d_model be divisible by n_heads?
  3. Why might increasing heads beyond 16 show diminishing returns or even hurt performance?
  4. If all heads learned identical patterns, what would be lost compared to diverse specialization?

Thought Experiments

  1. Design an experiment to test whether attention heads specialize. What would you measure?
  2. If you could manually assign head functions (Head 1 = syntax, Head 2 = semantics, etc.), how would you do it? Why might learned specialization be better?
  3. How might the optimal number of heads change for:
    • A very small model (d_model = 128)?
    • A very large model (d_model = 4096)?
    • A model for very long sequences?