One attention head learns one kind of relationship - say, “late cycles attend to early cycles when there is a spike”. Multi-head attention runs H parallel attention computations with separate Q/K/V projections, then concatenates and re-projects the results. Each head can specialise in a different relationship; the concat lets the downstream layer use all of them.
Eight heads is the standard choice from Vaswani et al. 2017 and what the paper's reference architecture uses. With $d_{\text{model}} = 512$, each head sees $d_k = d_{\text{model}}/H = 64$ dims.
Mental model. Eight separate attention sub-experts, each with its own Q/K/V, voting on what each cycle's output should look like.
The Math: H Parallel Heads
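Per head $h = 1, \dots, H$:

$$\text{head}_h = \text{Attention}(XW_h^Q,\; XW_h^K,\; XW_h^V)$$

with each $W_h^Q, W_h^K, W_h^V \in \mathbb{R}^{d_{\text{model}} \times d_k}$. Then concatenate and project:

$$\text{MHA}(X) = \text{Concat}(\text{head}_1, \dots, \text{head}_H)\,W^O$$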
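with $W^O \in \mathbb{R}^{d_{\text{model}} \times d_{\text{model}}}$ (the concatenation has $H d_k = d_{\text{model}}$ columns). In practice we don't maintain H separate projection matrices. Instead we use ONE combined $W^Q \in \mathbb{R}^{d_{\text{model}} \times d_{\text{model}}}$, whose column blocks are the per-head $W_h^Q$, and split the projected result into heads with a reshape (likewise for $W^K$ and $W^V$), as the Python walkthrough below shows. On a GPU, one large matmul runs faster than H small ones.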
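The equivalence can be checked numerically. This is a minimal sketch that assumes bias-free projections and random weights. The per-head matrix $W_h^Q$ is taken to be the $h$-th block of $d_k$ columns of the combined $W^Q$, which is exactly how the reshape splits the projected result.

```python
import torch

torch.manual_seed(0)
B, T, d_model, H = 2, 30, 512, 8
d_k = d_model // H  # 64

X = torch.randn(B, T, d_model)
W_q = torch.randn(d_model, d_model) / d_model ** 0.5  # combined W^Q (no bias, as in the equations)

# Combined path: one matmul, then split the last dim into H heads of d_k and move H next to B
Q_combined = (X @ W_q).view(B, T, H, d_k).transpose(1, 2)  # (B, H, T, d_k)

# Per-head path: H separate matmuls, W_h^Q = columns [h*d_k, (h+1)*d_k) of W^Q
Q_per_head = torch.stack(
    [X @ W_q[:, h * d_k:(h + 1) * d_k] for h in range(H)], dim=1
)  # (B, H, T, d_k)

print(Q_combined.shape)                                    # torch.Size([2, 8, 30, 64])
print(torch.allclose(Q_combined, Q_per_head, atol=1e-5))   # True
```

Both paths compute the same numbers. The combined path replaces eight narrow matmuls with a single wide one.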
Parameter Accounting
| Component | Shape | Params |
|---|---|---|
| W^Q | (512, 512) | 262,144 + 512 bias = 262,656 |
| W^K | (512, 512) | 262,656 |
| W^V | (512, 512) | 262,656 |
| W^O | (512, 512) | 262,656 |
| Total | - | 1,050,624 |
About 1.05M parameters - second-largest component of the backbone after the BiLSTM. Add layer-norm and a tiny FFN (Section 10.3) and you have ~1.1M total for the attention block.
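The table's arithmetic fits in a few lines. `mha_param_count` is a hypothetical helper written for illustration. Note that the head count H never appears in it: splitting into heads only reshapes the projections and adds no parameters.

```python
def mha_param_count(d_model: int, bias: bool = True) -> int:
    """Parameters of standard multi-head attention (hypothetical helper).

    Four d_model x d_model matrices (W^Q, W^K, W^V, W^O), each optionally
    with a d_model bias vector. The head count H does not enter the formula.
    """
    per_matrix = d_model * d_model + (d_model if bias else 0)
    return 4 * per_matrix


print(mha_param_count(512))              # 1050624
print(mha_param_count(512, bias=False))  # 1048576
print(mha_param_count(768))              # 2362368 (a 768-wide model)
```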
Python: Multi-Head From Scratch
Eight heads via reshape + transpose: one matmul per projection.
Project and split. Computing Q (and likewise K and V) takes four ops in one line: (1) project X to d_model dims; (2) reshape to split the last dimension into H heads of d_k; (3) transpose to put H next to B so the attention math broadcasts over heads; (4) the result has shape (B, H, T, d_k).
Why reshape, not separate matmuls? One big (B, T, d_model) @ (d_model, d_model) matmul is faster than H separate (B, T, d_model) @ (d_model, d_k) matmuls, and the reshape gives the same logical result.
Merge heads. Reverse the head split: transpose to bring T back next to B, then merge H × d_k back into d_model. For the example batch, out.shape = (2, 30, 512) = (B, T, d_model).
Output projection. The final line is `return out @ Wo`. It maps the concatenated heads (H × d_k = d_model dims in total) back to d_model, and Wo lets the model decide how to mix the heads.
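Here is a minimal sketch of these steps. It makes several assumptions: bias-free projections, no masking and no dropout. The function name `multi_head_attention` and the `X @ W` weight layout are illustrative choices, not the book's own listing.

```python
import math
import torch


def multi_head_attention(X, Wq, Wk, Wv, Wo, num_heads):
    """Minimal multi-head self-attention sketch (no biases, masks, or dropout).

    X: (B, T, d_model); Wq, Wk, Wv, Wo: (d_model, d_model), applied as X @ W.
    """
    B, T, d_model = X.shape
    assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
    d_k = d_model // num_heads

    # Project, split into heads, move H next to B: (B, H, T, d_k)
    Q = (X @ Wq).view(B, T, num_heads, d_k).transpose(1, 2)
    K = (X @ Wk).view(B, T, num_heads, d_k).transpose(1, 2)
    V = (X @ Wv).view(B, T, num_heads, d_k).transpose(1, 2)

    # Scaled dot-product attention for all heads at once
    scores = Q @ K.transpose(-2, -1) / math.sqrt(d_k)  # (B, H, T, T)
    weights = torch.softmax(scores, dim=-1)
    heads = weights @ V                                 # (B, H, T, d_k)

    # Reverse the split: T back next to B, then merge H x d_k into d_model
    out = heads.transpose(1, 2).contiguous().view(B, T, d_model)
    return out @ Wo                                     # (B, T, d_model)


torch.manual_seed(0)
d_model, H = 512, 8
X = torch.randn(2, 30, d_model)  # stand-in for the BiLSTM output
Wq, Wk, Wv, Wo = (torch.randn(d_model, d_model) / math.sqrt(d_model) for _ in range(4))

y = multi_head_attention(X, Wq, Wk, Wv, Wo, num_heads=H)
print(tuple(y.shape))  # (2, 30, 512)
```

PyTorch's `nn.Linear`-style weights are stored as (out_features, in_features) and applied as `x @ W.T`. Loading these matrices into a PyTorch module would therefore require transposing them.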
The built-in nn.MultiheadAttention module packages the same computation, plus biases and attention dropout:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

mha = nn.MultiheadAttention(
    embed_dim=512,
    num_heads=8,
    batch_first=True,
    dropout=0.1,
)

x = torch.randn(2, 30, 512)  # BiLSTM output

# self-attention: Q, K, V all from x
y, attn_w = mha(x, x, x, need_weights=True)

print("input   :", tuple(x.shape))       # (2, 30, 512)
print("output  :", tuple(y.shape))       # (2, 30, 512)
print("attn_w  :", tuple(attn_w.shape))  # (2, 30, 30), averaged over the 8 heads
print("# params:", sum(p.numel() for p in mha.parameters()))
# # params: 1050624  (= 4 * (512*512 + 512))
```
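To see where the 1,050,624 parameters live, list the module's named parameters. When query, key and value share `embed_dim`, PyTorch packs the three input projections into a single `in_proj_weight`, which mirrors the combined-$W^Q$ idea above. The packed weight follows the `nn.Linear` convention of (out_features, in_features).

```python
import torch.nn as nn

mha = nn.MultiheadAttention(embed_dim=512, num_heads=8, batch_first=True)

# W^Q, W^K, W^V are stacked row-wise into one (3 * d_model, d_model) tensor
for name, p in mha.named_parameters():
    print(name, tuple(p.shape))
# in_proj_weight (1536, 512)
# in_proj_bias (1536,)
# out_proj.weight (512, 512)
# out_proj.bias (512,)

# Recover the three projections; each is (512, 512) and is applied as x @ W.T
W_q, W_k, W_v = mha.in_proj_weight.chunk(3, dim=0)
```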
How Many Heads Across Architectures
| Architecture | d_model | Heads | d_k |
|---|---|---|---|
| RUL backbone (this book) | 512 | 8 | 64 |
| Original Transformer | 512 | 8 | 64 |
| BERT-base | 768 | 12 | 64 |
| BERT-large | 1024 | 16 | 64 |
| GPT-3 (small) | 768 | 12 | 64 |
| ViT-base | 768 | 12 | 64 |
| AlphaFold 2 evoformer | 256 | 8 | 32 |
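The d_k = 64 lottery. Most transformers in the literature pick d_model and num_heads so that $d_k = 64$; AlphaFold 2's evoformer, at $d_k = 32$, is an exception in the table above. Read 64 as an empirical convention rather than a derived optimum. The $1/\sqrt{d_k}$ scaling keeps the variance of the attention logits near 1 for any $d_k$, so 64 is simply a head size that has worked well in practice.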
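Here is a quick numerical check of the scaling argument. It assumes q and k have independent unit-variance components, which is the standard idealization and not a property of trained weights. Under that assumption, the raw dot product's variance grows with $d_k$, while the scaled one stays near 1 for every head size.

```python
import torch

torch.manual_seed(0)
n = 20_000  # random (q, k) pairs per setting

var_scaled = {}
for d_k in (16, 32, 64, 128, 256):
    q = torch.randn(n, d_k)  # assumption: i.i.d. unit-variance components
    k = torch.randn(n, d_k)
    raw = (q * k).sum(dim=-1)   # q . k: variance grows like d_k
    scaled = raw / d_k ** 0.5   # q . k / sqrt(d_k): variance stays near 1
    var_scaled[d_k] = scaled.var().item()
    print(f"d_k={d_k:3d}  var(raw)={raw.var().item():7.1f}  var(scaled)={var_scaled[d_k]:.3f}")
```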
Two Multi-Head Pitfalls
Pitfall 1: d_model not divisible by H. 512 / 7 is not an integer, so nn.MultiheadAttention rejects it at construction time (it asserts that embed_dim is divisible by num_heads). Always ensure d_model % num_heads == 0.
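A small guard makes the constraint explicit. `valid_head_counts` is a hypothetical helper that lists the head counts that split d_model evenly. The `try` block shows PyTorch's own constructor check rejecting 7 heads.

```python
import torch.nn as nn


def valid_head_counts(d_model: int, max_heads: int = 16) -> list[int]:
    """Every H in 1..max_heads that divides d_model evenly (hypothetical helper)."""
    return [h for h in range(1, max_heads + 1) if d_model % h == 0]


print(valid_head_counts(512))  # [1, 2, 4, 8, 16]
print(valid_head_counts(768))  # [1, 2, 3, 4, 6, 8, 12, 16]

# nn.MultiheadAttention enforces the same condition in its constructor
try:
    nn.MultiheadAttention(embed_dim=512, num_heads=7)
except AssertionError as err:
    print("rejected:", err)
```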
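Pitfall 2: Misreading the need_weights flag. nn.MultiheadAttention defaults to need_weights=True. That setting materializes the full attention-weight tensor and, unless average_attn_weights=False, averages it over heads into shape (B, T, T). Passing need_weights=False lets PyTorch use its optimized scaled_dot_product_attention path and returns (out, None). This saves compute but breaks interpretability code that expected a weights tensor. nn.TransformerEncoderLayer calls its self-attention with need_weights=False, so it does not expose the weights this way.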
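The following sketch shows the three return forms on a randomly initialized module. Only the shapes are meaningful, since the weights are untrained. The default dropout of 0.0 keeps the calls deterministic.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
mha = nn.MultiheadAttention(embed_dim=512, num_heads=8, batch_first=True)
x = torch.randn(2, 30, 512)

with torch.no_grad():
    _, w_avg = mha(x, x, x)                                # need_weights=True is the default
    _, w_heads = mha(x, x, x, average_attn_weights=False)  # one attention map per head
    _, w_none = mha(x, x, x, need_weights=False)           # weights not returned

print(tuple(w_avg.shape))    # (2, 30, 30): averaged over heads
print(tuple(w_heads.shape))  # (2, 8, 30, 30): (B, H, T, T)
print(w_none)                # None
```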
The point. Eight parallel attention heads running on the same input give the model eight different kinds of cross-cycle relationships to combine. The concat + output projection (W^O) lets the network choose how to weight the heads.
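Splitting $W^O$ into $H$ row blocks $W^O_h \in \mathbb{R}^{d_k \times d_{\text{model}}}$ (rows $(h-1)d_k + 1$ through $h\,d_k$) makes the mixing explicit. By block matrix multiplication, the output projection is a sum of per-head contributions, so $W^O$ decides how much each head adds to every output dimension:

$$\text{Concat}(\text{head}_1, \dots, \text{head}_H)\,W^O = \sum_{h=1}^{H} \text{head}_h\, W^O_h$$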
Takeaway
Eight heads, d_k = 64. Standard Vaswani choice; the paper uses it too.
Use ONE combined W^Q matrix, reshape to split into heads. Faster than maintaining 8 separate matrices.
~1.05M parameters. Second-biggest backbone component after the BiLSTM.
Output is (B, T, d_model) - same shape as input. Fits cleanly between BiLSTM and the FC stack.