Chapter 3

Self-Attention

Mathematical Preliminaries

The Library Analogy

Walk into a library with a question you cannot quite phrase — you know it has something to do with nineteenth-century French opera but you do not remember more. The librarian listens, walks the shelves, and pulls a stack of books that overlap with your fuzzy query. From those books she produces a single answer that is a weighted blend of the most relevant pages.

That is what self-attention does to a sequence. Each timestep produces a query; the system compares it against every timestep's key; the matches determine how much each timestep's value should contribute to the answer. For RUL prediction this lets cycle 27 attend to cycle 5 if cycle 5 carried the early-warning signal, something neither a convolution (which only sees a kernel-width window of cycles) nor a vanilla LSTM (which forgets gradually) can reliably do.

The mental model. Self-attention is differentiable lookup. Every cycle reads from every other cycle, with weights the network learns.

Queries, Keys, Values

Given an input sequence $\mathbf{X} \in \mathbb{R}^{T \times d}$ ($T$ cycles, $d$-dimensional features), three projection matrices produce three new sequences:

$$\mathbf{Q} = \mathbf{X} W^Q, \quad \mathbf{K} = \mathbf{X} W^K, \quad \mathbf{V} = \mathbf{X} W^V.$$

$W^Q, W^K \in \mathbb{R}^{d \times d_k}$ and $W^V \in \mathbb{R}^{d \times d_v}$ are learnable. After the projections, $\mathbf{Q}, \mathbf{K} \in \mathbb{R}^{T \times d_k}$ and $\mathbf{V} \in \mathbb{R}^{T \times d_v}$.

| Symbol | Role | Analogy |
| --- | --- | --- |
| $\mathbf{Q}$ | What this position is looking for | The fuzzy query |
| $\mathbf{K}$ | What this position offers | The book index entry |
| $\mathbf{V}$ | What this position contributes | The book content |
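To make the shapes concrete, here is a minimal NumPy sketch of the three projections. The sizes (14 input features, `d_k = 8`, `d_v = 16`) and the random weight matrices are illustrative assumptions, not values from this book's model; in a trained network the weights are learned.

```python
import numpy as np

rng = np.random.default_rng(0)

T, d = 30, 14        # assumed: 30 cycles, 14 input features per cycle
d_k, d_v = 8, 16     # assumed projection widths; d_v need not equal d_k

X = rng.standard_normal((T, d))

# Learnable in a real model; random here purely for illustration.
W_Q = rng.standard_normal((d, d_k))
W_K = rng.standard_normal((d, d_k))
W_V = rng.standard_normal((d, d_v))

Q = X @ W_Q   # (T, d_k): one query per cycle
K = X @ W_K   # (T, d_k): one key per cycle
V = X @ W_V   # (T, d_v): one value per cycle

print(Q.shape, K.shape, V.shape)   # (30, 8) (30, 8) (30, 16)
```

Queries and keys must share the width `d_k` because they are compared by dot product; the value width `d_v` only determines the width of the output.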

Scaled Dot-Product Attention

Compute pairwise similarity between every query and every key, then normalise:

$$\text{Attention}(\mathbf{Q}, \mathbf{K}, \mathbf{V}) \;=\; \text{softmax}\!\left(\frac{\mathbf{Q}\mathbf{K}^{\top}}{\sqrt{d_k}}\right) \mathbf{V}.$$

Three things in one expression. (1) $\mathbf{Q}\mathbf{K}^{\top} \in \mathbb{R}^{T \times T}$ gives every query-key dot product. (2) Divide by $\sqrt{d_k}$ to keep the variance roughly 1 regardless of $d_k$; without this, large $d_k$ pushes softmax into a near-one-hot regime where gradients vanish. (3) Apply softmax row-wise to produce a valid probability distribution, then multiply by $\mathbf{V}$ to compute the weighted sum.

Why $\sqrt{d_k}$? If the components of $\mathbf{q}$ and $\mathbf{k}$ are independent random variables with zero mean and unit variance, the dot product $\mathbf{q} \cdot \mathbf{k}$ has variance $d_k$. Dividing by $\sqrt{d_k}$ brings it back to unit variance, which keeps the softmax in its useful, non-saturated regime. Vaswani et al. (2017) give this argument in the original Transformer paper.
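The variance argument can be checked empirically. The sketch below draws standard-normal query and key vectors (an assumption that matches the zero-mean, unit-variance setting of the argument, not the statistics of a trained model) and compares the variance of the raw and scaled dot products for a few values of `d_k`.

```python
import numpy as np

rng = np.random.default_rng(0)
n_samples = 10_000   # number of (q, k) pairs per setting


def dot_product_variances(d_k):
    """Return (variance of q.k, variance of q.k / sqrt(d_k)) over random pairs."""
    q = rng.standard_normal((n_samples, d_k))
    k = rng.standard_normal((n_samples, d_k))
    dots = (q * k).sum(axis=1)
    return dots.var(), (dots / np.sqrt(d_k)).var()


results = {d_k: dot_product_variances(d_k) for d_k in (4, 64, 256)}
for d_k, (raw, scaled) in results.items():
    print(f"d_k={d_k:4d}  var(q.k)={raw:8.2f}  var(q.k / sqrt(d_k))={scaled:.3f}")
```

The raw variance should come out close to `d_k` in each row, while the scaled variance stays close to 1.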

Interactive: Pick a Query, See Where It Looks

Below is a $6 \times 6$ attention map computed on random Q, K, V matrices (matching the NumPy run further down so you can verify the numbers). Click any row to pick a query; the bar chart on the right shows where its probability mass goes. Toggle between raw scores and softmax-normalised attention.


Try sliding the scale factor down to 0.5: the scores are divided by a smaller number, so the softmax row sharpens toward one-hot. The model fixates on a single key, which starves the others of gradient. Push it up to 8: the row becomes near-uniform, and self-attention degenerates into plain averaging. The choice $\text{scale} = \sqrt{d_k}$ sits between these extremes.
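The same experiment can be sketched offline. The code below assumes the scale is the divisor applied to the raw scores and reuses the seed-0 random `Q` and `K` from the NumPy listing in this section (in float64 rather than float32, so the last digits may differ slightly). It reports the largest weight and the entropy of row 0 for three scale values.

```python
import numpy as np


def softmax(x):
    shifted = x - x.max(axis=-1, keepdims=True)   # numerical stability
    e = np.exp(shifted)
    return e / e.sum(axis=-1, keepdims=True)


np.random.seed(0)
T, d_k = 6, 4
Q = np.random.randn(T, d_k)
K = np.random.randn(T, d_k)

raw = Q @ K.T   # unscaled scores, shape (T, T)

summary = []    # (scale, largest weight, entropy) for query 0
for scale in (0.5, float(np.sqrt(d_k)), 8.0):
    row = softmax(raw[0] / scale)
    entropy = float(-(row * np.log(row)).sum())
    summary.append((scale, float(row.max()), entropy))
    print(f"scale={scale:3.1f}  max weight={row.max():.3f}  entropy={entropy:.3f}")

print("entropy of a uniform row:", round(float(np.log(T)), 3))
```

As the scale grows, the largest weight shrinks and the entropy rises toward the uniform limit of log(T); as it shrinks, the row concentrates on a single key.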

Interactive: The Macro Flow

Zooming out, the four-stage flow:

Attention Flow: how information flows through the attention mechanism

1. Compare. Q compares with K: QKᵀ gives the similarity scores.
2. Weight. Softmax normalizes: softmax(·) produces weights that sum to 1.
3. Combine. Multiply by V: A × V gives the weighted values.
4. Output. The result is the context: an enriched, context-aware representation.

Multi-Head Attention

One attention head learns one kind of relationship. Multi-head attention runs $H$ attention computations in parallel with separate $W^Q_h, W^K_h, W^V_h$ for each head $h$, then concatenates the outputs and projects:

$$\text{MHA}(X) \;=\; \text{Concat}(\text{head}_1, \ldots, \text{head}_H) \, W^O,$$

with $\text{head}_h = \text{Attention}(X W^Q_h,\, X W^K_h,\, X W^V_h)$. Each head sees a lower-dimensional slice ($d_k = d_{\text{model}} / H$) and can specialise: one head might learn “late cycles attend to early cycles” while another learns “adjacent-cycle smoothing”. The paper and our backbone use 8 heads.
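The formula above can be sketched in NumPy as follows. This is an illustrative implementation, not the backbone's code: the function name `multi_head_attention`, the random weights, and the sizes (`T = 30`, `d_model = 64`, `H = 8`) are assumptions. The per-head matrices $W^Q_h$ are stored as column blocks of a single `(d_model, d_model)` matrix, which is a common layout and mathematically equivalent to separate per-head projections.

```python
import numpy as np


def softmax(x):
    shifted = x - x.max(axis=-1, keepdims=True)   # numerical stability
    e = np.exp(shifted)
    return e / e.sum(axis=-1, keepdims=True)


def multi_head_attention(X, W_Q, W_K, W_V, W_O, num_heads):
    """X: (T, d_model). Returns the output (T, d_model) and per-head maps (H, T, T)."""
    T, d_model = X.shape
    d_head = d_model // num_heads

    def split_heads(M):
        # (T, d_model) -> (H, T, d_head): each head gets its own column block
        return M.reshape(T, num_heads, d_head).transpose(1, 0, 2)

    Q = split_heads(X @ W_Q)
    K = split_heads(X @ W_K)
    V = split_heads(X @ W_V)

    scores = Q @ K.transpose(0, 2, 1) / np.sqrt(d_head)   # (H, T, T)
    attn = softmax(scores)                                # rows sum to 1 per head
    heads = attn @ V                                      # (H, T, d_head)

    # Concatenate the heads back along the feature axis, then project.
    concat = heads.transpose(1, 0, 2).reshape(T, d_model)
    return concat @ W_O, attn


rng = np.random.default_rng(0)
T, d_model, H = 30, 64, 8
X = rng.standard_normal((T, d_model))
W_Q, W_K, W_V, W_O = (
    rng.standard_normal((d_model, d_model)) / np.sqrt(d_model) for _ in range(4)
)

out, attn = multi_head_attention(X, W_Q, W_K, W_V, W_O, num_heads=H)
print(out.shape, attn.shape)   # (30, 64) (8, 30, 30)
```

Each of the 8 slices of `attn` is a separate 30 x 30 attention map, which is what allows different heads to specialise.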

Python: Attention From Scratch

Twenty lines of NumPy implement scaled-dot-product attention end to end. The output values match what the interactive heatmap displays when you select query 0 with scale = sqrt(4) = 2.

Scaled-dot-product attention from first principles
attention_numpy.py
import numpy as np

Same NumPy alias.

def softmax_rowwise(x):

Numerically stable row-wise softmax. The max-subtract trick keeps exp() in safe range; the result is mathematically identical because softmax is shift-invariant.

EXECUTION STATE
input: x (..., T) = Last-axis softmax. (T, T) for our attention scores.
shifted = x - x.max(axis=-1, keepdims=True)

Subtract the per-row max. After this, every row's max element is 0; exp() of non-positive numbers stays in (0, 1]. Without this you can hit np.exp(1000) = inf when scores are large.

EXECUTION STATE
axis=-1 = Operate along the LAST axis. Equivalent to axis=1 for a 2-D matrix.
keepdims=True = Result keeps shape (T, 1) instead of (T,) so broadcasting against (T, T) works.
expd = np.exp(shifted)

Element-wise exponential. With shifted values ≤ 0, expd ∈ (0, 1].

return expd / expd.sum(axis=-1, keepdims=True)

Normalise each row to sum to 1. That is the softmax property: every row of the output is a valid probability distribution.

def attention(Q, K, V):

The full scaled-dot-product attention in three lines.

EXECUTION STATE
input: Q (T, d_k) = Query matrix - one query vector per timestep
input: K (T, d_k) = Key matrix - one key vector per timestep
input: V (T, d_v) = Value matrix - one value vector per timestep
d_k = Q.shape[-1]

Read the per-token Q/K dimension from the input. Used in the scaling factor below.

EXECUTION STATE
Q.shape = (T=6, d_k=4)
Q.shape[-1] = 4 - the LAST axis
scores = Q @ K.T / np.sqrt(d_k)

Compute all T x T pairwise dot products in one matrix multiply, then divide by sqrt(d_k). The division is the SCALING in 'scaled dot-product attention' - it keeps variance ~1 regardless of d_k.

EXECUTION STATE
Q @ K.T = (T, d_k) @ (d_k, T) = (T, T) - score[i, j] = Q[i] dot K[j]
→ why sqrt(d_k)? = Without it, dot-product variance grows linearly with d_k; large d_k pushes softmax into saturating regions where gradients vanish. Vaswani et al. 2017.
Example: scores[0] = [1.524, 2.145, -1.174, 0.797, 0.142, -0.279]
attn = softmax_rowwise(scores)

Convert scores to a row-stochastic attention map. Each row is a probability distribution over the T positions.

EXECUTION STATE
attn.shape = (T, T)
Example: attn[0] = [0.2611, 0.4863, 0.0176, 0.1263, 0.0656, 0.043]
→ row sums to 1 = 0.2611 + 0.4863 + 0.0176 + ... = 1.000
→ which key won? = Position 1 with 48.6% - query 0 attends most to key 1
out = attn @ V

Weighted average of all value vectors. Row i of out is the attn[i]-weighted combination of all rows of V.

EXECUTION STATE
(T, T) @ (T, d_v) = (T, d_v) = Output has same shape as V
Example: out[0] = [-0.781, -0.694, -0.437, 0.121]
→ interpretation = out[0] is mostly V[1] (weight 0.486), plus a bit of V[0] (0.261), small contributions from V[3..5]
return out, attn

Return both the output AND the attention map - the latter is what we visualise to interpret the model.

np.random.seed(0)

Lock RNG.

T, d_k = 6, 4

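Toy: 6 tokens, 4-dimensional Q/K/V. The RUL windows in this book use T=30; for scale, the original Transformer uses d_k=64 (= 512/8 heads), whereas the nn.MultiheadAttention example in this section uses 64/8 = 8 dims per head.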

EXECUTION STATE
T = 6 - sequence length
d_k = 4 - per-head Q/K dimension
Q = np.random.randn(T, d_k).astype(np.float32)

Random query matrix.

K = np.random.randn(T, d_k).astype(np.float32)

Random key matrix.

V = np.random.randn(T, d_k).astype(np.float32)

Random value matrix.

out, attn = attention(Q, K, V)

One call returns both the contextualised output and the interpretable attention map.

EXECUTION STATE
out.shape = (6, 4)
attn.shape = (6, 6)
print("scores[0] :", ...)

Inspect the raw scaled scores for query 0.

EXECUTION STATE
Output = scores[0]: [1.524, 2.145, -1.174, 0.797, 0.142, -0.279]
print("attn[0] :", ...)

After softmax: a probability distribution.

EXECUTION STATE
Output = attn[0]: [0.2611, 0.4863, 0.0176, 0.1263, 0.0656, 0.043]
print("attn[0].sum() :", round(attn[0].sum(), 6))

Sanity check - softmax outputs sum to 1.

EXECUTION STATE
Output = attn[0].sum(): 1.0
print("out[0] :", ...)

Final contextualised output for query 0 - a 4-vector.

EXECUTION STATE
Output = out[0]: [-0.781, -0.694, -0.437, 0.121]
```python
import numpy as np

# ----- Scaled dot-product attention from scratch -----
def softmax_rowwise(x: np.ndarray) -> np.ndarray:
    shifted = x - x.max(axis=-1, keepdims=True)    # numerical stability
    expd = np.exp(shifted)
    return expd / expd.sum(axis=-1, keepdims=True)


def attention(Q, K, V):
    """Q, K: (T, d_k). V: (T, d_v). Returns (T, d_v) and (T, T) attention map."""
    d_k = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(d_k)                # (T, T)
    attn   = softmax_rowwise(scores)               # (T, T), rows sum to 1
    out    = attn @ V                              # (T, d_v)
    return out, attn


# ----- Run on a tiny synthetic batch -----
np.random.seed(0)
T, d_k = 6, 4
Q = np.random.randn(T, d_k).astype(np.float32)
K = np.random.randn(T, d_k).astype(np.float32)
V = np.random.randn(T, d_k).astype(np.float32)

out, attn = attention(Q, K, V)

print("scores[0]      :", (Q @ K.T / np.sqrt(d_k))[0].round(3).tolist())
print("attn[0]        :", attn[0].round(4).tolist())
print("attn[0].sum()  :", round(attn[0].sum(), 6))
print("out[0]         :", out[0].round(3).tolist())

# scores[0]      : [1.524, 2.145, -1.174, 0.797, 0.142, -0.279]
# attn[0]        : [0.2611, 0.4863, 0.0176, 0.1263, 0.0656, 0.043]
# attn[0].sum()  : 1.0
# out[0]         : [-0.781, -0.694, -0.437, 0.121]
```

PyTorch: scaled_dot_product_attention

PyTorch 2.0+ ships a single fused kernel for attention, optionally backed by Flash-Attention on CUDA. For multi-head usage the nn.MultiheadAttention module wraps the projections and head-splitting boilerplate.

F.scaled_dot_product_attention + nn.MultiheadAttention
attention_torch.py
import torch

Top-level PyTorch.

import torch.nn.functional as F

Functional API. F.scaled_dot_product_attention is the modern, fused, optionally Flash-Attention implementation.

torch.manual_seed(0)

Determinism.

B, T, d_k = 1, 6, 4

Same toy shape from the NumPy block plus a batch dimension.

Q = torch.randn(B, T, d_k)

Random query tensor. Note the leading batch axis - PyTorch's attention API expects (B, T, d).

EXECUTION STATE
Q.shape = torch.Size([1, 6, 4])
K = torch.randn(B, T, d_k)

Random key tensor.

V = torch.randn(B, T, d_k)

Random value tensor.

out = F.scaled_dot_product_attention(Q, K, V)

PyTorch 2.0+ ships a fused, optionally Flash-Attention implementation. Equivalent to softmax(Q @ K^T / sqrt(d_k)) @ V but, on GPU, often several times faster (the 3-10x range depends on hardware, dtype, and sequence length) and more memory-efficient.

EXECUTION STATE
F.scaled_dot_product_attention(Q, K, V, ...) = Optional kwargs: attn_mask (causal / padding), dropout_p, is_causal. Replaces the 6-line NumPy version.
out.shape = torch.Size([1, 6, 4])
print("Q.shape :", tuple(Q.shape))

Verify input shape.

EXECUTION STATE
Output = Q.shape : (1, 6, 4)
print("out.shape:", tuple(out.shape))

Output is the same shape as Q (when V has the same last-dim).

EXECUTION STATE
Output = out.shape: (1, 6, 4)
import torch.nn as nn

For the multi-head wrapper.

mha = nn.MultiheadAttention(embed_dim=64, num_heads=8, batch_first=True)

Multi-head attention as one Module. Internally projects Q/K/V to (B, T, num_heads, d_head), runs attention per head in parallel, concatenates, then projects back to embed_dim.

EXECUTION STATE
embed_dim=64 = Input/output feature dim. Must equal num_heads * d_head.
num_heads=8 = Parallel attention heads. Each sees its own (B, T, 8) Q/K/V.
batch_first=True = (B, T, F) shape - same trap as nn.LSTM.
x = torch.randn(2, 30, 64)

Realistic input shape: 2 engines, 30 cycles, 64-dim feature space (after the Conv1D frontend).

EXECUTION STATE
x.shape = torch.Size([2, 30, 64])
y, attn_weights = mha(x, x, x, need_weights=True)

Self-attention - same tensor used for Q, K, AND V. need_weights=True returns the (B, T, T) attention map for visualisation; by default PyTorch averages it over the heads.

EXECUTION STATE
self-attention = Q = K = V means the layer learns to relate every cycle to every other cycle - exactly the long-range dependency capability we lacked with Conv1D.
y.shape = torch.Size([2, 30, 64])
attn_weights.shape = torch.Size([2, 30, 30])
print("input x :", tuple(x.shape))

Confirm input.

EXECUTION STATE
Output = input x : (2, 30, 64)
print("output y :", tuple(y.shape))

Same shape as input - attention preserves (B, T, F) by construction.

EXECUTION STATE
Output = output y : (2, 30, 64)
print("attn_weights :", tuple(attn_weights.shape))

(B, T, T) - one TxT attention map per batch element. Each row sums to 1.

EXECUTION STATE
Output = attn_weights : (2, 30, 30)
```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, T, d_k = 1, 6, 4

# Same shapes as the NumPy version, with a leading batch axis
Q = torch.randn(B, T, d_k)
K = torch.randn(B, T, d_k)
V = torch.randn(B, T, d_k)

# Single line - PyTorch 2.0+ ships an optimised kernel (Flash-Attention on
# CUDA, plain math fallback on CPU)
out = F.scaled_dot_product_attention(Q, K, V)

print("Q.shape  :", tuple(Q.shape))    # (1, 6, 4)
print("out.shape:", tuple(out.shape))  # (1, 6, 4)


# ----- Multi-head attention via nn.MultiheadAttention -----
import torch.nn as nn

mha = nn.MultiheadAttention(
    embed_dim=64, num_heads=8, batch_first=True   # 64 / 8 = 8 dims per head
)

# (B, T, embed_dim)
x = torch.randn(2, 30, 64)
y, attn_weights = mha(x, x, x, need_weights=True)

print("input  x        :", tuple(x.shape))           # (2, 30, 64)
print("output y        :", tuple(y.shape))           # (2, 30, 64)
print("attn_weights    :", tuple(attn_weights.shape)) # (2, 30, 30)
```
Flash-Attention. On modern NVIDIA GPUs, F.scaled_dot_product_attention can dispatch to the Flash-Attention kernel automatically when the dtype and shapes allow it (the fused kernel generally needs half-precision inputs), giving speedups of roughly 3-10x over the naive softmax(Q K^T / sqrt(d_k)) @ V. No code changes are needed: PyTorch picks the fastest available backend.
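Before relying on the fused kernel, it can be reassuring to check that it agrees with a hand-written reference. The sketch below is a hypothetical parity test on small random inputs: `attention_numpy` is an illustrative helper that restates the from-scratch formula, and the tolerance is an assumption suitable for float32 on CPU.

```python
import numpy as np
import torch
import torch.nn.functional as F


def attention_numpy(Q, K, V):
    """Reference softmax(Q K^T / sqrt(d_k)) V for 2-D inputs."""
    scores = Q @ K.T / np.sqrt(Q.shape[-1])
    scores = scores - scores.max(axis=-1, keepdims=True)   # numerical stability
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    return weights @ V


rng = np.random.default_rng(0)
T, d_k = 6, 4
Q, K, V = (rng.standard_normal((T, d_k)).astype(np.float32) for _ in range(3))

ref = attention_numpy(Q, K, V)

# Add a leading batch axis for the PyTorch call, then drop it again.
out = F.scaled_dot_product_attention(
    torch.from_numpy(Q)[None],
    torch.from_numpy(K)[None],
    torch.from_numpy(V)[None],
)[0].numpy()

max_abs_diff = float(np.abs(out - ref).max())
print("max abs difference:", max_abs_diff)
```

With no mask and no dropout, the two results should agree up to float32 rounding; a larger gap usually points to a shape or scaling mismatch.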

Attention Beyond RUL

| Domain | What attends to what | Famous architecture |
| --- | --- | --- |
| RUL prediction (this book) | Cycles within a 30-cycle window | CNN-BiLSTM-Attention |
| Language modelling | Tokens within a context window | GPT, BERT, T5 |
| Image recognition | Patches within an image | ViT, Swin Transformer |
| Speech recognition | Audio frames | Whisper, Conformer |
| Protein folding | Residues within a chain | AlphaFold 2 |
| Music generation | Notes within a phrase | Music Transformer |
| Recommendation | Items in a user history | SASRec |
| Medical imaging | Voxels in a 3-D scan | TransUNet |

The same few-line attention computation underpins all of them. The mathematics is identical; what changes is the modality and the size of the (T, d) tensor.

The Three Pitfalls

Pitfall 1: Forgetting the scale. Drop the division by $\sqrt{d_k}$ and softmax saturates as $d_k$ grows. The model trains for a few epochs then plateaus because gradients through softmax are nearly zero.
Pitfall 2: Softmax overflow. Without the max-subtract trick, np.exp(1000) = inf. PyTorch's functional softmax handles it; if you implement softmax yourself, always shift first.
Pitfall 3: Interpreting attention as explanation. High attention does NOT prove a causal relationship between two tokens. The model can attend to a position for many reasons that have nothing to do with feature importance. Treat attention maps as a diagnostic, not a proof.
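Pitfall 2 is easy to reproduce. The sketch below compares a naive softmax with the max-subtract version on deliberately large scores; the function names and the score values are illustrative assumptions.

```python
import numpy as np


def softmax_naive(x):
    e = np.exp(x)                                   # overflows for large x
    return e / e.sum(axis=-1, keepdims=True)


def softmax_stable(x):
    shifted = x - x.max(axis=-1, keepdims=True)     # largest entry becomes 0
    e = np.exp(shifted)
    return e / e.sum(axis=-1, keepdims=True)


scores = np.array([1000.0, 999.0, 998.0])

# Silence the overflow warnings so the nan result is easy to see.
with np.errstate(over="ignore", invalid="ignore"):
    naive = softmax_naive(scores)    # exp gives inf, and inf / inf gives nan

stable = softmax_stable(scores)

print("naive :", naive)              # [nan nan nan]
print("stable:", stable.round(4))    # [0.6652 0.2447 0.09  ]
```

Because softmax is shift-invariant, the stable version returns the same distribution as softmax([0, -1, -2]) while never exponentiating a positive number.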
The point. Self-attention lets every cycle see every other cycle in O(T^2) time, learn which to weight, and produce a context-rich output. Combined with Conv1D (local patterns) and BiLSTM (sequential dynamics), it gives our backbone its long-range modelling power.

Takeaway

  • Self-attention is differentiable lookup. Each query selects a weighted blend of values via similarity to keys.
  • The whole formula is one line. $\text{softmax}(QK^{\top}/\sqrt{d_k})\, V$: two matrix multiplies plus a scaled softmax.
  • Scale by sqrt(d_k). Keeps softmax in its well-behaved regime regardless of the per-head dimension.
  • Multi-head runs H attentions in parallel. Each head sees a lower-dim slice and learns its own type of relationship.
  • PyTorch ships a fused kernel. F.scaled_dot_product_attention can dispatch to Flash-Attention automatically on CUDA, so avoid hand-writing softmax(Q K^T / sqrt(d)) V in production code.