Chapter 7

Multi-Head Self-Attention

Learning Objectives

By the end of this section, you will:

  1. Understand why multiple attention heads improve representation learning
  2. Master the multi-head attention formula with projection and concatenation
  3. Trace parallel head computation through all steps
  4. Configure 8 heads for our 256-dimensional BiLSTM output
  5. Calculate the parameter count for multi-head attention

Why This Matters: A single attention head learns one type of query-key relationship. Different aspects of degradation may require different attention patterns—one head might focus on sudden spikes, another on gradual trends. Multi-head attention enables learning these complementary patterns simultaneously.

Why Multiple Heads?

A single attention head is limited to learning one attention pattern. Multiple heads provide diversity.

Limitation of Single Head

With one head, the model must choose one way to relate positions:

  • Focus on recent timesteps? Or distant ones?
  • Attend to similar values? Or contrasting values?
  • Track one sensor pattern? Or another?

A single head cannot do all of these simultaneously.

Multi-Head Diversity

With 8 heads, different heads can specialize:

Head | Possible Specialization
Head 1 | Attend to recent timesteps (t-1, t-2, t-3)
Head 2 | Attend to transition points (stable → degrading)
Head 3 | Track temperature-related patterns
Head 4 | Track pressure-related patterns
Head 5 | Focus on spike events
Head 6 | Focus on gradual trend changes
Head 7 | Compare early vs. late window regions
Head 8 | Uniform attention for smoothing

Learned Specialization

The specializations above are illustrative. In practice, heads learn their patterns from data. Analyzing trained attention weights often reveals interpretable specializations, but they emerge through training rather than being hand-designed.


Multi-Head Attention Formula

Multi-head attention runs several attention mechanisms in parallel and combines their outputs.

Formula

MultiHead(Q, K, V) = Concat(head_1, ..., head_h) W^O

Where each head is:

head_i = Attention(Q W^Q_i, K W^K_i, V W^V_i)

Components

  • W^Q_i, W^K_i ∈ ℝ^(d_model × d_k): Per-head query and key projections
  • W^V_i ∈ ℝ^(d_model × d_v): Per-head value projection
  • W^O ∈ ℝ^(h·d_v × d_model): Output projection (combines all heads)
  • h: Number of heads
  • d_k = d_v = d_model / h: Per-head dimensions (typically equal)

Dimension Flow

```text
Input X: (T, d_model) = (30, 256)

    ┌───────────────┼───────────────┐
    ↓               ↓               ↓
  Head 1          Head 2    ...   Head 8
    ↓               ↓               ↓
(30, 32)        (30, 32)        (30, 32)   ← Each head: d_v = 256/8 = 32
    ↓               ↓               ↓
    └───────────────┼───────────────┘

              Concatenate

            (30, 256)  ← 8 × 32 = 256

              W^O projection

            Output: (30, 256)
```

Parallel Head Computation

Each head performs independent attention with its own projections.

Step-by-Step for One Head

```text
For head i:

1. Project to head's subspace:
   Q_i = X · W^Q_i    (30, 256) × (256, 32) → (30, 32)
   K_i = X · W^K_i    (30, 256) × (256, 32) → (30, 32)
   V_i = X · W^V_i    (30, 256) × (256, 32) → (30, 32)

2. Compute scaled dot-product attention:
   scores_i = Q_i · K_i^T / √32    (30, 32) × (32, 30) → (30, 30)
   A_i = softmax(scores_i)         (30, 30)
   head_i = A_i · V_i              (30, 30) × (30, 32) → (30, 32)
```
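The per-head steps above can be sketched in NumPy. Shapes follow the text (T = 30, d_model = 256, d_k = 32); the random weight matrices are stand-ins for learned parameters, so this is a shape-level sketch, not a trained model:

```python
import numpy as np

rng = np.random.default_rng(0)
T, d_model, d_k = 30, 256, 32  # window length, model dim, per-head dim

X = rng.standard_normal((T, d_model))  # e.g. BiLSTM output

# Per-head projection matrices (learned in practice; random here).
W_Q = rng.standard_normal((d_model, d_k))
W_K = rng.standard_normal((d_model, d_k))
W_V = rng.standard_normal((d_model, d_k))

# 1. Project to the head's subspace.
Q = X @ W_Q  # (30, 32)
K = X @ W_K  # (30, 32)
V = X @ W_V  # (30, 32)

# 2. Scaled dot-product attention with a numerically stable softmax.
scores = Q @ K.T / np.sqrt(d_k)                        # (30, 30)
A = np.exp(scores - scores.max(axis=-1, keepdims=True))
A /= A.sum(axis=-1, keepdims=True)                     # rows sum to 1
head = A @ V                                           # (30, 32)

print(head.shape)  # (30, 32)
```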

Combining Heads

```text
3. Concatenate all heads:
   concat = [head_1; head_2; ...; head_8]
          = [(30, 32); (30, 32); ...; (30, 32)]
          = (30, 256)

4. Final projection:
   output = concat · W^O    (30, 256) × (256, 256) → (30, 256)
```
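Steps 3 and 4 can be sketched as follows. The helper `single_head` and all weight matrices are illustrative (random stand-ins for learned parameters); the point is the shape bookkeeping: 8 heads of (30, 32) concatenate to (30, 256), which W^O maps back to (30, 256):

```python
import numpy as np

rng = np.random.default_rng(0)
T, d_model, h = 30, 256, 8
d_k = d_model // h  # 32

def single_head(X, W_Q, W_K, W_V):
    """One head of scaled dot-product attention."""
    Q, K, V = X @ W_Q, X @ W_K, X @ W_V
    scores = Q @ K.T / np.sqrt(d_k)
    A = np.exp(scores - scores.max(axis=-1, keepdims=True))
    A /= A.sum(axis=-1, keepdims=True)
    return A @ V  # (T, d_k)

X = rng.standard_normal((T, d_model))

# Run all 8 heads, each with its own projections.
heads = [
    single_head(X,
                rng.standard_normal((d_model, d_k)),
                rng.standard_normal((d_model, d_k)),
                rng.standard_normal((d_model, d_k)))
    for _ in range(h)
]

# 3. Concatenate: 8 × (30, 32) → (30, 256).
concat = np.concatenate(heads, axis=-1)

# 4. Final projection with W^O.
W_O = rng.standard_normal((d_model, d_model))
output = concat @ W_O  # (30, 256)

print(concat.shape, output.shape)
```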

Why Final Projection?

The output projection W^O serves several purposes:

  • Combines head information: Allows heads to interact
  • Learns relative importance: Which head outputs matter more
  • Matches dimension: Ensures output matches input for residual connection

Our Configuration: 8 Heads

For the AMNL model, we use 8 attention heads on the 256-dimensional BiLSTM output.

Dimension Breakdown

Parameter | Value | Calculation
d_model | 256 | BiLSTM output dimension
h (num_heads) | 8 | Design choice
d_k (head dimension) | 32 | 256 / 8 = 32
d_v (value dimension) | 32 | Same as d_k
Scaling factor | √32 ≈ 5.66 | For stable softmax
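The derived values in the table follow from d_model and h alone, as a quick check:

```python
import math

d_model, h = 256, 8       # BiLSTM output dim, number of heads
d_k = d_model // h        # per-head dimension
scale = math.sqrt(d_k)    # softmax scaling factor

print(d_k, round(scale, 2))  # 32 5.66
```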

Why 8 Heads?

  • Divisibility: 256 / 8 = 32 (clean division)
  • Sufficient diversity: 8 different attention patterns
  • Not too many: Each head has meaningful capacity (32 dims)
  • Standard choice: Common in Transformer literature

Parameter Count

Let us calculate the number of learnable parameters in multi-head attention.

Projection Matrices

Component | Shape | Parameters
W^Q (all heads) | (256, 256) | 65,536
W^K (all heads) | (256, 256) | 65,536
W^V (all heads) | (256, 256) | 65,536
W^O (output) | (256, 256) | 65,536

Efficient Implementation

Although conceptually we have 8 separate W^Q_i matrices of size (256, 32), implementations typically use a single W^Q of size (256, 256) and reshape/split for efficiency. The parameter count is identical.
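This equivalence can be checked directly: projecting with one fused (256, 256) matrix and then reshaping into heads gives the same result as projecting with each (256, 32) column slice separately. The variable names below are illustrative:

```python
import numpy as np

rng = np.random.default_rng(0)
T, d_model, h = 30, 256, 8
d_k = d_model // h  # 32

X = rng.standard_normal((T, d_model))
W_Q = rng.standard_normal((d_model, d_model))  # fused: all 8 heads at once

# Fused projection, then reshape/split: (30, 256) → (8, 30, 32).
Q_fused = (X @ W_Q).reshape(T, h, d_k).transpose(1, 0, 2)

# Per-head projection using head 0's (256, 32) slice of W_Q.
Q_head_0 = X @ W_Q[:, :d_k]

print(np.allclose(Q_fused[0], Q_head_0))  # True
```

The same reshape trick applies to W^K and W^V, which is why frameworks store one matrix per projection rather than h small ones.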

Bias Terms

Component | Shape | Parameters
b^Q | (256,) | 256
b^K | (256,) | 256
b^V | (256,) | 256
b^O | (256,) | 256

Total Parameters

Projection weights: 4 × 65,536 = 262,144
Bias terms: 4 × 256 = 1,024
Total: 262,144 + 1,024 = 263,168 ≈ 263K
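The ~263K figure can be verified with a few lines of arithmetic:

```python
d_model = 256

# Four projection matrices: W^Q, W^K, W^V, W^O, each (256, 256).
weights = 4 * d_model * d_model   # 262,144
# Four bias vectors: b^Q, b^K, b^V, b^O, each (256,).
biases = 4 * d_model              # 1,024

total = weights + biases
print(total)  # 263168
```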

Summary

In this section, we designed multi-head attention with 8 heads:

  1. Why multi-head: Different heads learn complementary attention patterns
  2. Formula: Concat(head₁, ..., head₈) × W^O
  3. Per-head computation: Independent Q, K, V projections and attention
  4. Our configuration: 8 heads × 32 dimensions = 256
  5. Parameters: ~263K (4 projection matrices + biases)

Property | Value
Input dimension | 256 (from BiLSTM)
Number of heads | 8
Per-head dimension | 32
Output dimension | 256
Scaling factor | √32 ≈ 5.66
Total parameters | ~263K

Looking Ahead: Multi-head attention produces outputs of the same dimension as inputs. This enables residual connections—adding the attention output to the original input. The next section explains why residual connections are essential for deep network training.

With multi-head attention designed, we now examine residual connections that enable deeper architectures.