Learning Objectives
By the end of this section, you will:
- Understand why multiple attention heads improve representation learning
- Master the multi-head attention formula with projection and concatenation
- Trace parallel head computation through all steps
- Configure 8 heads for our 256-dimensional BiLSTM output
- Calculate the parameter count for multi-head attention
Why This Matters: A single attention head learns one type of query-key relationship. Different aspects of degradation may require different attention patterns—one head might focus on sudden spikes, another on gradual trends. Multi-head attention enables learning these complementary patterns simultaneously.
Why Multiple Heads?
A single attention head is limited to learning one attention pattern. Multiple heads provide diversity.
Limitation of Single Head
With one head, the model must choose one way to relate positions:
- Focus on recent timesteps? Or distant ones?
- Attend to similar values? Or contrasting values?
- Track one sensor pattern? Or another?
A single head cannot do all of these simultaneously.
Multi-Head Diversity
With 8 heads, different heads can specialize:
| Head | Possible Specialization |
|---|---|
| Head 1 | Attend to recent timesteps (t-1, t-2, t-3) |
| Head 2 | Attend to transition points (stable → degrading) |
| Head 3 | Track temperature-related patterns |
| Head 4 | Track pressure-related patterns |
| Head 5 | Focus on spike events |
| Head 6 | Focus on gradual trend changes |
| Head 7 | Compare early vs late window regions |
| Head 8 | Uniform attention for smoothing |
Learned Specialization
The specializations above are illustrative. In practice, heads learn their patterns from data. Analyzing trained attention weights often reveals interpretable specializations, but they emerge through training rather than being hand-designed.
Multi-Head Attention Formula
Multi-head attention runs several attention mechanisms in parallel and combines their outputs.
Formula

```
MultiHead(X) = Concat(head_1, head_2, ..., head_h) · W^O
```

Where each head is:

```
head_i = Attention(X · W^Q_i, X · W^K_i, X · W^V_i)
       = softmax((X · W^Q_i)(X · W^K_i)^T / √d_k) · (X · W^V_i)
```

Components
- W^Q_i, W^K_i (d_model × d_k): Per-head query and key projections
- W^V_i (d_model × d_v): Per-head value projection
- W^O (h·d_v × d_model): Output projection (combines all heads)
- h: Number of heads
- d_k, d_v: Per-head dimensions (typically equal)
Dimension Flow
```
Input X: (T, d_model) = (30, 256)
                ↓
┌───────────────┼───────────────┐
↓               ↓               ↓
Head 1       Head 2    ...   Head 8
↓               ↓               ↓
(30, 32)    (30, 32)       (30, 32)   ← Each head: d_v = 256/8 = 32
↓               ↓               ↓
└───────────────┼───────────────┘
                ↓
          Concatenate
                ↓
            (30, 256)   ← 8 × 32 = 256
                ↓
        W^O projection
                ↓
      Output: (30, 256)
```
Parallel Head Computation
Each head performs independent attention with its own projections.
Step-by-Step for One Head
```
For head i:

1. Project to head's subspace:
   Q_i = X · W^Q_i    (30, 256) × (256, 32) → (30, 32)
   K_i = X · W^K_i    (30, 256) × (256, 32) → (30, 32)
   V_i = X · W^V_i    (30, 256) × (256, 32) → (30, 32)

2. Compute scaled dot-product attention:
   scores_i = Q_i · K_i^T / √32    (30, 32) × (32, 30) → (30, 30)
   A_i = softmax(scores_i)         (30, 30)
   head_i = A_i · V_i              (30, 30) × (30, 32) → (30, 32)
```
Combining Heads
```
3. Concatenate all heads:
   concat = [head_1; head_2; ...; head_8]
          = [(30, 32); (30, 32); ...; (30, 32)]
          = (30, 256)

4. Final projection:
   output = concat · W^O    (30, 256) × (256, 256) → (30, 256)
```
Why Final Projection?
The output projection serves several purposes:
- Combines head information: Allows heads to interact
- Learns relative importance: Which head outputs matter more
- Matches dimension: Ensures output matches input for residual connection
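As a concrete illustration of steps 1–4, here is a minimal NumPy sketch of multi-head self-attention with this section's shapes (T = 30, d_model = 256, h = 8). The weight matrices are randomly initialized stand-ins for learned parameters, and the variable names are illustrative rather than taken from the AMNL codebase:

```python
import numpy as np

T, d_model, h = 30, 256, 8     # sequence length, model dim, number of heads
d_k = d_model // h             # per-head dimension: 32

rng = np.random.default_rng(0)
X = rng.standard_normal((T, d_model))   # stand-in for BiLSTM output

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)   # subtract max for stability
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

# Random stand-ins for the learned projections W^Q_i, W^K_i, W^V_i, W^O.
W_Q = [rng.standard_normal((d_model, d_k)) for _ in range(h)]
W_K = [rng.standard_normal((d_model, d_k)) for _ in range(h)]
W_V = [rng.standard_normal((d_model, d_k)) for _ in range(h)]
W_O = rng.standard_normal((d_model, d_model))

heads = []
for i in range(h):
    Q_i, K_i, V_i = X @ W_Q[i], X @ W_K[i], X @ W_V[i]   # each (30, 32)
    A_i = softmax(Q_i @ K_i.T / np.sqrt(d_k))            # (30, 30)
    heads.append(A_i @ V_i)                              # (30, 32)

concat = np.concatenate(heads, axis=-1)   # (30, 256)
output = concat @ W_O                     # (30, 256), same shape as X
```

The final shape equality between `output` and `X` is what makes the residual connection in the next section possible.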
Our Configuration: 8 Heads
For the AMNL model, we use 8 attention heads on the 256-dimensional BiLSTM output.
Dimension Breakdown
| Parameter | Value | Calculation |
|---|---|---|
| d_model | 256 | BiLSTM output dimension |
| h (num_heads) | 8 | Design choice |
| d_k (head dimension) | 32 | 256 / 8 = 32 |
| d_v (value dimension) | 32 | Same as d_k |
| Scaling factor | √32 ≈ 5.66 | For stable softmax |
Why 8 Heads?
- Divisibility: 256 / 8 = 32 (clean division)
- Sufficient diversity: 8 different attention patterns
- Not too many: Each head has meaningful capacity (32 dims)
- Standard choice: Common in Transformer literature
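These choices can be expressed as a small configuration check (plain Python; the variable names are illustrative):

```python
import math

d_model = 256       # BiLSTM output dimension
num_heads = 8       # design choice
assert d_model % num_heads == 0, "heads must divide d_model evenly"

d_k = d_model // num_heads   # per-head query/key dimension: 32
d_v = d_k                    # value dimension, same as d_k
scale = math.sqrt(d_k)       # softmax scaling factor, ≈ 5.66
```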
Parameter Count
Let us calculate the number of learnable parameters in multi-head attention.
Projection Matrices
| Component | Shape | Parameters |
|---|---|---|
| W^Q (all heads) | (256, 256) | 65,536 |
| W^K (all heads) | (256, 256) | 65,536 |
| W^V (all heads) | (256, 256) | 65,536 |
| W^O (output) | (256, 256) | 65,536 |
Efficient Implementation
Although conceptually we have 8 separate W^Q_i matrices of size (256, 32), implementations typically use a single W^Q of size (256, 256) and reshape/split for efficiency. The parameter count is identical.
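The equivalence can be checked directly: one fused (256, 256) projection, reshaped into heads, produces the same values as 8 separate (256, 32) projections built from slices of the same weights. A NumPy sketch with random weights:

```python
import numpy as np

T, d_model, h = 30, 256, 8
d_k = d_model // h   # 32

rng = np.random.default_rng(0)
X = rng.standard_normal((T, d_model))
W_Q_fused = rng.standard_normal((d_model, d_model))   # single fused W^Q

# Fused path: one matmul, then split the result into heads.
Q_fused = (X @ W_Q_fused).reshape(T, h, d_k)          # (30, 8, 32)

# Conceptual path: 8 separate (256, 32) column slices of the same weights.
Q_separate = np.stack(
    [X @ W_Q_fused[:, i * d_k:(i + 1) * d_k] for i in range(h)], axis=1
)                                                      # (30, 8, 32)

assert np.allclose(Q_fused, Q_separate)   # identical results
```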
Bias Terms
| Component | Shape | Parameters |
|---|---|---|
| b^Q | (256,) | 256 |
| b^K | (256,) | 256 |
| b^V | (256,) | 256 |
| b^O | (256,) | 256 |
Total Parameters

- Projection matrices: 4 × 65,536 = 262,144
- Bias terms: 4 × 256 = 1,024
- Total: 262,144 + 1,024 = 263,168 ≈ 263K
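The total can be verified with a quick calculation (plain Python):

```python
d_model = 256

# Four (256, 256) projection matrices: W^Q, W^K, W^V, W^O.
weight_params = 4 * d_model * d_model   # 262,144

# Four (256,) bias vectors: b^Q, b^K, b^V, b^O.
bias_params = 4 * d_model               # 1,024

total_params = weight_params + bias_params   # 263,168 ≈ 263K
```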
Summary
In this section, we designed multi-head attention with 8 heads:
- Why multi-head: Different heads learn complementary attention patterns
- Formula: Concat(head₁, ..., head₈) × W^O
- Per-head computation: Independent Q, K, V projections and attention
- Our configuration: 8 heads × 32 dimensions = 256
- Parameters: ~263K (4 projection matrices + biases)
| Property | Value |
|---|---|
| Input dimension | 256 (from BiLSTM) |
| Number of heads | 8 |
| Per-head dimension | 32 |
| Output dimension | 256 |
| Scaling factor | √32 ≈ 5.66 |
| Total parameters | ~263K |
Looking Ahead: Multi-head attention produces outputs of the same dimension as inputs. This enables residual connections—adding the attention output to the original input. The next section explains why residual connections are essential for deep network training.
With multi-head attention designed, we now examine residual connections that enable deeper architectures.