Chapter 10

Multi-Head Attention with 8 Heads

Multi-Head Self-Attention

A Panel of Specialists

One attention head learns one kind of relationship - say, “late cycles attend to early cycles when there is a spike”. Multi-head attention runs H parallel attention computations with separate Q/K/V projections, then concatenates and re-projects the results. Each head can specialise in a different relationship; the concat lets the downstream layer use all of them.

Eight heads is the standard choice from Vaswani et al. 2017 and what the paper's reference architecture uses. With $d_{\text{model}} = 512$, each head sees $d_k = d_{\text{model}} / H = 64$ dims.

Mental model. Eight separate attention sub-experts, each with its own Q/K/V, voting on what each cycle's output should look like.

The Math: H Parallel Heads

Per head $h = 1, \ldots, H$:

$$\text{head}_h = \text{Attention}\!\bigl(\mathbf{X} W^Q_h,\, \mathbf{X} W^K_h,\, \mathbf{X} W^V_h\bigr)$$

with each $W^Q_h, W^K_h, W^V_h \in \mathbb{R}^{d_{\text{model}} \times d_k}$. Then concatenate and project:

$$\text{MHA}(\mathbf{X}) = \text{Concat}(\text{head}_1, \ldots, \text{head}_H) \, W^O$$

with $W^O \in \mathbb{R}^{d_{\text{model}} \times d_{\text{model}}}$. In practice we don't maintain H separate projection matrices - we use ONE combined $W^Q \in \mathbb{R}^{d_{\text{model}} \times d_{\text{model}}}$ (and likewise one $W^K$ and one $W^V$) and split the result into heads via reshape (as the Python code below shows). One large matmul is typically faster on GPU than H small ones.
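To make the "one combined matrix" claim concrete, here is a minimal NumPy sketch. It assumes head h corresponds to the column block [h*d_k, (h+1)*d_k) of the combined matrix, which is what the reshape implies; the small sizes and the variable names are illustrative only.

```python
import numpy as np

rng = np.random.default_rng(0)
B, T, d_model, H = 2, 5, 16, 4        # small illustrative sizes (assumption)
d_k = d_model // H

X = rng.standard_normal((B, T, d_model))
Wq = rng.standard_normal((d_model, d_model))   # combined projection

# Combined route: one matmul, then split the last axis into (H, d_k).
Q_combined = (X @ Wq).reshape(B, T, H, d_k).transpose(0, 2, 1, 3)  # (B, H, T, d_k)

# Per-head route: head h uses a d_k-wide column block of Wq as its W^Q_h.
Q_per_head = np.stack(
    [X @ Wq[:, h * d_k:(h + 1) * d_k] for h in range(H)],
    axis=1,
)                                                                  # (B, H, T, d_k)

print(np.allclose(Q_combined, Q_per_head))  # True
```

Both routes produce the same per-head queries; the combined route simply does the work in a single matmul.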

Parameter Accounting

| Component | Shape | Params |
|---|---|---|
| W^Q | (512, 512) | 262,144 + 512 bias = 262,656 |
| W^K | (512, 512) | 262,656 |
| W^V | (512, 512) | 262,656 |
| W^O | (512, 512) | 262,656 |
| Total | | 1,050,624 |
About 1.05M parameters - second-largest component of the backbone after the BiLSTM. Add layer-norm and a tiny FFN (Section 10.3) and you have ~1.1M total for the attention block.
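The arithmetic behind the table can be reproduced in a few lines. This is a sketch that assumes each projection is a dense layer with a bias, as in the table; `linear_params` is a hypothetical helper, not a library function.

```python
d_model = 512


def linear_params(d_in: int, d_out: int, bias: bool = True) -> int:
    """Parameter count of one dense projection (hypothetical helper)."""
    return d_in * d_out + (d_out if bias else 0)


per_projection = linear_params(d_model, d_model)
total = 4 * per_projection            # W^Q, W^K, W^V, W^O

print(per_projection)  # 262656
print(total)           # 1050624
```

Note that the count does not depend on the number of heads: splitting d_model into H heads only changes how the columns are grouped.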

Python: Multi-Head From Scratch

Eight heads via reshape + transpose - one matmul each
🐍multi_head_attention_numpy.py
import numpy as np

Standard alias.

def softmax_rowwise(x):

Same numerically stable softmax from §10.1.

def multi_head_attention(X, Wq, Wk, Wv, Wo, num_heads):

End-to-end multi-head attention. ONE forward pass; H heads in parallel via reshape + transpose.

EXECUTION STATE
input: X (B, T, d_model) = BiLSTM output
input: Wq, Wk, Wv (d_model, d_model) = Combined projections - the d_k = d_model/H split happens via reshape
input: Wo (d_model, d_model) = Output projection after concat
input: num_heads = 8 in our backbone
returns = (B, T, d_model) - same shape as input
B, T, d_model = X.shape

Unpack input shape.

d_k = d_model // num_heads

Per-head dimension. 512 / 8 = 64. Must divide evenly.

EXECUTION STATE
d_k = 64
Q = (X @ Wq).reshape(B, T, num_heads, d_k).transpose(0, 2, 1, 3)

THREE ops in one line: (1) project X to d_model dims; (2) reshape to split the channel axis into H heads of d_k; (3) transpose so H sits next to B and both act as batch axes in the matmuls that follow. The result has shape (B, H, T, d_k).

EXECUTION STATE
→ why reshape, not separate matmuls? = ONE big matmul (B, T, d_model) @ (d_model, d_model) is faster than H separate (B, T, d_model) @ (d_model, d_k) matmuls. Reshape gives the same logical result.
Q.shape after = (2, 8, 30, 64) = (B, H, T, d_k)
K = (X @ Wk).reshape(...).transpose(...)

Same trick for K.

V = (X @ Wv).reshape(...).transpose(...)

Same trick for V.

scores = Q @ K.transpose(0, 1, 3, 2) / np.sqrt(d_k)

Per-head scaled dot products. K.transpose(0, 1, 3, 2) swaps the LAST TWO axes (T and d_k) so we can matmul. The leading (B, H) axes are treated as batch dimensions, so every (batch, head) pair gets its own (T, T) score matrix.

EXECUTION STATE
scores.shape = (2, 8, 30, 30)
attn = softmax_rowwise(scores)

Per-head, per-query attention distributions.

head_outs = attn @ V

Per-head weighted blends.

EXECUTION STATE
head_outs.shape = (2, 8, 30, 64)
out = head_outs.transpose(0, 2, 1, 3).reshape(B, T, d_model)

REVERSE the head split: bring T next to B, then merge H × d_k back into d_model.

EXECUTION STATE
out.shape = (2, 30, 512) = (B, T, d_model)
return out @ Wo

FINAL output projection. Maps the concatenated heads (each computed independently in its own d_k = 64-dim subspace) back to d_model. Wo lets the model decide how to mix the heads.

B, T, d_model, H = 2, 30, 512, 8

Backbone shapes.

X = np.random.randn(B, T, d_model).astype(np.float32) * 0.1

Fake BiLSTM output.

out = multi_head_attention(X, Wq, Wk, Wv, Wo, num_heads=H)

Run.

print("X.shape :", X.shape)

Verify input.

EXECUTION STATE
Output = X.shape : (2, 30, 512)
print("d_k per head:", d_model // H)

64 dims per head.

EXECUTION STATE
Output = d_k per head: 64
print("out.shape :", out.shape)

Same shape as input.

EXECUTION STATE
Output = out.shape : (2, 30, 512)
import numpy as np


def softmax_rowwise(x):
    s = x - x.max(axis=-1, keepdims=True)
    e = np.exp(s)
    return e / e.sum(axis=-1, keepdims=True)


def multi_head_attention(X: np.ndarray,
                         Wq: np.ndarray,        # (d_model, d_model)
                         Wk: np.ndarray,
                         Wv: np.ndarray,
                         Wo: np.ndarray,         # (d_model, d_model)
                         num_heads: int) -> np.ndarray:
    """X: (B, T, d_model). Returns (B, T, d_model)."""
    B, T, d_model = X.shape
    d_k = d_model // num_heads

    # Project, then RESHAPE to split the channel dim into heads
    Q = (X @ Wq).reshape(B, T, num_heads, d_k).transpose(0, 2, 1, 3)  # (B, H, T, d_k)
    K = (X @ Wk).reshape(B, T, num_heads, d_k).transpose(0, 2, 1, 3)
    V = (X @ Wv).reshape(B, T, num_heads, d_k).transpose(0, 2, 1, 3)

    # Per-head attention
    scores = Q @ K.transpose(0, 1, 3, 2) / np.sqrt(d_k)              # (B, H, T, T)
    attn   = softmax_rowwise(scores)
    head_outs = attn @ V                                              # (B, H, T, d_k)

    # Concat heads back: (B, H, T, d_k) → (B, T, d_model)
    out = head_outs.transpose(0, 2, 1, 3).reshape(B, T, d_model)
    return out @ Wo                                                   # (B, T, d_model)


# ----- Run on a (B, 30, 512) sequence with 8 heads -----
np.random.seed(0)
B, T, d_model, H = 2, 30, 512, 8

X  = np.random.randn(B, T, d_model).astype(np.float32) * 0.1
Wq = np.random.randn(d_model, d_model).astype(np.float32) * (1 / np.sqrt(d_model))
Wk = np.random.randn(d_model, d_model).astype(np.float32) * (1 / np.sqrt(d_model))
Wv = np.random.randn(d_model, d_model).astype(np.float32) * (1 / np.sqrt(d_model))
Wo = np.random.randn(d_model, d_model).astype(np.float32) * (1 / np.sqrt(d_model))

out = multi_head_attention(X, Wq, Wk, Wv, Wo, num_heads=H)

print("X.shape   :", X.shape)         # (2, 30, 512)
print("d_k per head:", d_model // H)   # 64
print("out.shape :", out.shape)        # (2, 30, 512)
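The head split and the head merge in the listing above are exact inverses of each other. The sketch below isolates the two steps; `split_heads` and `merge_heads` are hypothetical helper names introduced only for this illustration.

```python
import numpy as np


def split_heads(x: np.ndarray, num_heads: int) -> np.ndarray:
    """(B, T, d_model) -> (B, H, T, d_k). Hypothetical helper."""
    B, T, d_model = x.shape
    d_k = d_model // num_heads
    return x.reshape(B, T, num_heads, d_k).transpose(0, 2, 1, 3)


def merge_heads(x: np.ndarray) -> np.ndarray:
    """(B, H, T, d_k) -> (B, T, H * d_k). Hypothetical helper."""
    B, H, T, d_k = x.shape
    return x.transpose(0, 2, 1, 3).reshape(B, T, H * d_k)


x = np.arange(2 * 30 * 512, dtype=np.float32).reshape(2, 30, 512)
heads = split_heads(x, num_heads=8)
restored = merge_heads(heads)

print(heads.shape)                                 # (2, 8, 30, 64)
print(np.array_equal(heads[:, 0], x[..., :64]))    # True: head 0 is channels 0..63
print(np.array_equal(restored, x))                 # True: merge undoes split
```

The transpose before the final reshape matters: reshaping (B, H, T, d_k) straight to (B, T, d_model) would run without error but would interleave time steps and heads.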

PyTorch: nn.MultiheadAttention

Single Module - one call, all heads, output projection
🐍multi_head_attention_torch.py
import torch

Top-level PyTorch.

import torch.nn as nn

Layer container.

torch.manual_seed(0)

Determinism.

mha = nn.MultiheadAttention(embed_dim=512, num_heads=8, batch_first=True, dropout=0.1)

Single Module wraps Q/K/V projections, head split, attention, concat, and output projection.

EXECUTION STATE
embed_dim=512 = Input/output channel dim - matches BiLSTM output
num_heads=8 = Eight parallel heads, each seeing d_model/8 = 64
batch_first=True = (B, T, F) order - book convention
dropout=0.1 = Applied to the attention weights (not to the projected output); active only in training mode
x = torch.randn(2, 30, 512)

Fake BiLSTM output.

y, attn_w = mha(x, x, x, need_weights=True)

Self-attention: same tensor as Q, K, AND V. need_weights=True returns the (B, T, T) attention map for inspection.

EXECUTION STATE
→ can also do cross-attention = mha(query, key, value) with different tensors. Useful in encoder-decoder; we don't need it here.
print("input :", tuple(x.shape))

Verify input.

EXECUTION STATE
Output = input : (2, 30, 512)
print("output :", tuple(y.shape))

Same shape as input.

EXECUTION STATE
Output = output : (2, 30, 512)
print("attn_w :", tuple(attn_w.shape))

Attention map. Default is averaged across heads; pass average_attn_weights=False to get per-head (B, H, T, T).

EXECUTION STATE
Output = attn_w : (2, 30, 30)
print("# params:", sum(p.numel() for p in mha.parameters()))

Four (d_model, d_model) projections + biases. 4 × (512×512 + 512) = 1,050,624.

EXECUTION STATE
Output = # params: 1050624
import torch
import torch.nn as nn

torch.manual_seed(0)

mha = nn.MultiheadAttention(
    embed_dim=512,
    num_heads=8,
    batch_first=True,
    dropout=0.1,
)

x = torch.randn(2, 30, 512)               # BiLSTM output

# self-attention: Q, K, V all from x
y, attn_w = mha(x, x, x, need_weights=True)

print("input  :", tuple(x.shape))           # (2, 30, 512)
print("output :", tuple(y.shape))            # (2, 30, 512)
print("attn_w :", tuple(attn_w.shape))       # (2, 30, 30)
print("# params:", sum(p.numel() for p in mha.parameters()))
# prints "# params: 1050624" (= 4 * (512*512 + 512))
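For interpretability work it can help to look at each head separately rather than at the head-averaged map. The sketch below assumes a PyTorch version whose `nn.MultiheadAttention.forward` accepts `average_attn_weights` (the flag mentioned in the walkthrough above), and calls `eval()` so that dropout does not perturb the returned weights.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

mha = nn.MultiheadAttention(embed_dim=512, num_heads=8, batch_first=True, dropout=0.1)
mha.eval()  # disable dropout so attention rows are proper distributions

x = torch.randn(2, 30, 512)

with torch.no_grad():
    # Default: weights averaged over heads -> (B, T, T)
    _, w_avg = mha(x, x, x, need_weights=True)
    # Per-head weights -> (B, H, T, T)
    _, w_heads = mha(x, x, x, need_weights=True, average_attn_weights=False)

print(tuple(w_avg.shape))    # (2, 30, 30)
print(tuple(w_heads.shape))  # (2, 8, 30, 30)

# Each query's weights sum to 1 within every head, and the head-averaged
# map is the mean of the per-head maps.
rows_ok = torch.allclose(w_heads.sum(dim=-1), torch.ones(2, 8, 30), atol=1e-5)
mean_ok = torch.allclose(w_heads.mean(dim=1), w_avg, atol=1e-6)
print(rows_ok, mean_ok)
```

In training mode the same call would return weights with dropout already applied, so rows would no longer sum to 1.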

How Many Heads Across Architectures

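| Architecture | d_model | Heads | d_k |
|---|---|---|---|
| RUL backbone (this book) | 512 | 8 | 64 |
| Original Transformer | 512 | 8 | 64 |
| BERT-base | 768 | 12 | 64 |
| BERT-large | 1024 | 16 | 64 |
| GPT-3 (small) | 768 | 12 | 64 |
| ViT-base | 768 | 12 | 64 |
| AlphaFold 2 evoformer | 256 | 8 | 32 |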
The d_k = 64 lottery. Many well-known transformers pick d_model and num_heads such that $d_k = 64$; the AlphaFold 2 row, with $d_k = 32$, is one exception. The $\sqrt{d_k}$ scaling keeps the variance of the pre-softmax scores roughly constant for any $d_k$, so 64 is best read as an empirical sweet spot that has worked well across models, not a derived optimum.
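The role of the $\sqrt{d_k}$ scaling can be checked numerically. The sketch below assumes query and key entries are independent with zero mean and unit variance, which is an idealisation and not a property of trained weights: under that assumption the raw dot product has variance close to d_k, and the scaled one has variance close to 1.

```python
import numpy as np

rng = np.random.default_rng(0)
n_pairs = 20_000          # random (q, k) pairs per setting (illustrative)
variances = {}

for d_k in (16, 64, 256):
    q = rng.standard_normal((n_pairs, d_k))
    k = rng.standard_normal((n_pairs, d_k))
    raw = (q * k).sum(axis=1)          # unscaled dot products
    scaled = raw / np.sqrt(d_k)        # scaled as in attention
    variances[d_k] = (raw.var(), scaled.var())
    print(f"d_k={d_k:3d}  var(raw)={raw.var():7.1f}  var(scaled)={scaled.var():.2f}")
```

Without the scaling, larger heads would feed the softmax larger-magnitude scores and push it toward near one-hot outputs.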

Two Multi-Head Pitfalls

Pitfall 1: d_model not divisible by H. 512 / 7 is not an integer; nn.MultiheadAttention will raise an error. Always ensure d_model % num_heads == 0.
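Pitfall 2: Misreading the need_weights flag. By default (need_weights=True) nn.MultiheadAttention returns (out, attn_weights), with the weights averaged across heads. Passing need_weights=False returns (out, None): the attention map is not handed back, which lets PyTorch use its faster attention path, but it breaks interpretability code that expects the weights tensor.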
The point. Eight parallel attention heads running on the same input give the model eight different kinds of cross-cycle relationships to combine. The concat + output projection (W^O) lets the network choose how to weight the heads.
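Pitfall 1 is cheap to guard against in your own code. Here is a minimal sketch of such a check; `head_dim` is a hypothetical helper, not part of NumPy or PyTorch.

```python
def head_dim(d_model: int, num_heads: int) -> int:
    """Return d_k, or raise if d_model cannot be split evenly (hypothetical helper)."""
    if num_heads <= 0:
        raise ValueError(f"num_heads must be positive, got {num_heads}")
    if d_model % num_heads != 0:
        raise ValueError(
            f"d_model={d_model} is not divisible by num_heads={num_heads}"
        )
    return d_model // num_heads


print(head_dim(512, 8))  # 64

try:
    head_dim(512, 7)
except ValueError as err:
    print("rejected:", err)
```

Calling a check like this before building the projection matrices turns a confusing reshape failure into an explicit configuration error.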

Takeaway

  • Eight heads, d_k = 64. Standard Vaswani choice; the paper uses it too.
  • Use ONE combined W^Q matrix, reshape to split into heads. Faster than maintaining 8 separate matrices.
  • ~1.05M parameters. Second-biggest backbone component after the BiLSTM.
  • Output is (B, T, d_model) - same shape as input. Fits cleanly between BiLSTM and the FC stack.