Chapter 1

Scaled Dot-Product Attention

Learning Objectives

By the end of this chapter, you will be able to:

  1. Explain why RNNs and LSTMs hit a fundamental bottleneck for long sequences and why a parallel, all-pairs mechanism was needed.
  2. Derive the scaled dot-product attention formula $\text{softmax}\!\bigl(\tfrac{QK^\top}{\sqrt{d_k}}\bigr)V$ from first principles, understanding every symbol and operation.
  3. Prove why the $1/\sqrt{d_k}$ scaling factor is necessary and what happens without it.
  4. Compute attention weights and outputs by hand for our shared example sentence “The cat sat on mat”.
  5. Implement a complete, runnable Python class that you can use to simulate scaled dot-product attention on any input.
  6. Connect this foundational mechanism to its extensions: multi-head attention, Flash Attention, KV-cache, and positional encodings.

Where this appears: Scaled dot-product attention, or a close variant of it, is the computational core of most transformer-based models, including GPT-4, Claude, Gemini, LLaMA, Stable Diffusion, AlphaFold, Codex, and ViT. These systems evaluate this formula (or an optimized equivalent like Flash Attention) billions of times during inference. Understanding it deeply is the foundation for understanding modern AI architectures.

The Shared Example — Used in Every Chapter

Every mechanism in this book operates on the exact same sentence, the same matrices, and the same parameters. This allows you to directly compare what each mechanism does differently — the only thing that changes is the attention computation itself.

The Sentence and Tokens

| Parameter | Value |
| --- | --- |
| Sentence | "The cat sat on mat" |
| Tokens | [The, cat, sat, on, mat] |
| $N$ | 5 tokens |
| $d_{\text{model}}$ | 4 (kept tiny so every number is readable) |
| $d_k$ | 4 (in this chapter, single-head: $d_k = d_{\text{model}} = 4$) |
| $H$ | 2 heads (used from Chapter 2 onward; per-head dim $d_k = d_{\text{model}} / H = 2$) |

Why “mat” and not “the mat”?

The sentence is intentionally “The cat sat on mat” (dropping the second article) to keep the token count at exactly 5. This produces clean $5 \times 5$ attention matrices that are easy to compute by hand and visualize as heatmaps. Every chapter uses this same 5-token sentence so the matrices stay comparable.

The Query, Key, and Value Matrices

$Q$, $K$, and $V$ are the three fundamental matrices in every attention mechanism. In practice they are produced by learned linear projections of the input embeddings: $Q = XW^Q$, $K = XW^K$, $V = XW^V$, where $X \in \mathbb{R}^{N \times d_{\text{model}}}$ is the input embedding matrix, $W^Q, W^K \in \mathbb{R}^{d_{\text{model}} \times d_k}$ are the learned query and key projections, and $W^V \in \mathbb{R}^{d_{\text{model}} \times d_v}$ is the learned value projection (in practice $d_v = d_k$, but the formalism allows them to differ). Here we fix $Q$, $K$, and $V$ directly, rather than learning them, so all 15 mechanisms start from the same point.
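
To make the shapes concrete, the projections above can be sketched in NumPy. The embedding matrix `X` and the weights `W_Q`, `W_K`, `W_V` below are random stand-ins chosen only for illustration; they are assumptions and are not the fixed matrices used in the rest of this chapter.

```python
import numpy as np

rng = np.random.default_rng(0)  # fixed seed so the sketch is reproducible

N, d_model, d_k, d_v = 5, 4, 4, 4  # sizes from the parameter table, with d_v = d_k

# Hypothetical inputs and weights (random stand-ins, not learned values).
X = rng.normal(size=(N, d_model))      # input embeddings, one row per token
W_Q = rng.normal(size=(d_model, d_k))  # query projection
W_K = rng.normal(size=(d_model, d_k))  # key projection
W_V = rng.normal(size=(d_model, d_v))  # value projection

Q = X @ W_Q  # shape (N, d_k)
K = X @ W_K  # shape (N, d_k)
V = X @ W_V  # shape (N, d_v)

print(Q.shape, K.shape, V.shape)
```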

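$Q$ (Query Matrix), $5 \times 4$: each row encodes “what is this token looking for?”

| Token | dim-0 | dim-1 | dim-2 | dim-3 |
| --- | --- | --- | --- | --- |
| The | 1.0 | 0.0 | 1.0 | 0.0 |
| cat | 0.0 | 2.0 | 0.0 | 1.0 |
| sat | 1.0 | 1.0 | 1.0 | 0.0 |
| on | 0.0 | 0.0 | 1.0 | 1.0 |
| mat | 1.0 | 0.0 | 0.0 | 1.0 |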

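$K$ (Key Matrix), $5 \times 4$: each row encodes “what information does this token advertise?”

| Token | dim-0 | dim-1 | dim-2 | dim-3 |
| --- | --- | --- | --- | --- |
| The | 0.0 | 1.0 | 0.0 | 1.0 |
| cat | 1.0 | 0.0 | 1.0 | 0.0 |
| sat | 1.0 | 1.0 | 0.0 | 0.0 |
| on | 0.0 | 0.0 | 1.0 | 1.0 |
| mat | 1.0 | 0.0 | 0.5 | 0.5 |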

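$V$ (Value Matrix), $5 \times 4$: each row is the actual content that gets retrieved when a token is attended to

| Token | dim-0 | dim-1 | dim-2 | dim-3 |
| --- | --- | --- | --- | --- |
| The | 1.0 | 0.0 | 0.0 | 0.0 |
| cat | 0.0 | 1.0 | 0.0 | 0.0 |
| sat | 0.0 | 0.0 | 1.0 | 0.0 |
| on | 0.0 | 0.0 | 0.0 | 1.0 |
| mat | 0.5 | 0.5 | 0.5 | 0.5 |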
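
For readers who want to follow along in code, the three matrices can be written as NumPy arrays. The variable names (`TOKENS`, `Q`, `K`, `V`) are a choice made for this sketch; the values are copied from the tables.

```python
import numpy as np

TOKENS = ["The", "cat", "sat", "on", "mat"]

# Rows follow TOKENS order; columns are dim-0 .. dim-3.
Q = np.array([[1.0, 0.0, 1.0, 0.0],
              [0.0, 2.0, 0.0, 1.0],
              [1.0, 1.0, 1.0, 0.0],
              [0.0, 0.0, 1.0, 1.0],
              [1.0, 0.0, 0.0, 1.0]])

K = np.array([[0.0, 1.0, 0.0, 1.0],
              [1.0, 0.0, 1.0, 0.0],
              [1.0, 1.0, 0.0, 0.0],
              [0.0, 0.0, 1.0, 1.0],
              [1.0, 0.0, 0.5, 0.5]])

V = np.array([[1.0, 0.0, 0.0, 0.0],
              [0.0, 1.0, 0.0, 0.0],
              [0.0, 0.0, 1.0, 0.0],
              [0.0, 0.0, 0.0, 1.0],
              [0.5, 0.5, 0.5, 0.5]])

print(Q.shape, K.shape, V.shape)  # (5, 4) (5, 4) (5, 4)
```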

What Every Chapter Shows

For each mechanism you will find four sections:

  1. Problem & Intuition — what limitation this mechanism was designed to fix
  2. The Math — the exact formula with every variable defined
  3. Step-by-Step Calculation — the formula applied to “The” (row 0) using real numbers
  4. Python Code — a clean, runnable class implementation using only NumPy

Every chapter ends with the full Attention Weight Matrix ($5 \times 5$) and Output Matrix ($5 \times 4$) so you can immediately see how the pattern changes from one mechanism to the next.

How to use this book

By keeping the input identical across all 15 chapters, you can isolate and compare exactly what each mechanism contributes. When you see that causal masking zeros out future tokens, or that RoPE rotates the Q and K vectors, or that Flash Attention produces identical outputs with different memory access patterns — these differences become concrete and unmistakable, not abstract.

The Real Problem

Vaswani, A., Shazeer, N., Parmar, N., Uszkoreit, J., Jones, L., Gomez, A.N., Kaiser, L., & Polosukhin, I. (2017). “Attention Is All You Need.” Advances in Neural Information Processing Systems (NeurIPS), 30.

To understand why scaled dot-product attention exists, you must first understand what it replaced and why that predecessor was breaking under the weight of real-world language.

The RNN Bottleneck

Before 2017, the dominant architecture for sequence modeling was the Recurrent Neural Network (RNN) and its variants (LSTM, GRU). These models processed tokens sequentially (one at a time, left to right), accumulating information into a fixed-size hidden state vector $h_t \in \mathbb{R}^d$.

This created three critical problems:

  1. The information bottleneck. The entire history of a 10,000-token document had to be compressed into a single vector $h_t$ of, say, 512 dimensions. By the time the model reached the end of a long document, the information from the beginning had been overwritten or diluted beyond recognition. In seq2seq translation (Sutskever et al., 2014), the encoder's final hidden state was the only bridge between the source and target language, a catastrophic bottleneck.
  2. No parallelism. Because $h_t$ depended on $h_{t-1}$, every timestep had to wait for the previous one to finish. You could not process tokens 1 through 10,000 in parallel. Training on long sequences was agonizingly slow, even on GPUs designed for massive parallelism.
  3. Vanishing and exploding gradients. During backpropagation through time (BPTT), gradients had to flow through $T$ multiplicative steps. Even with LSTM gating, gradients for dependencies spanning hundreds of tokens degraded to near zero, making it nearly impossible to learn long-range patterns like coreference (“The cat... it...” separated by 50 tokens).

The Bahdanau breakthrough (2014)

Bahdanau, Cho, & Bengio introduced the first attention mechanism for neural machine translation. Instead of relying on a single bottleneck vector, their decoder could “look back” at every encoder hidden state and compute a weighted combination. This was transformative — but it was still layered on top of an RNN, so the sequential processing bottleneck remained. The attention scores were computed using a small feedforward network (additive attention), which was slower than a simple dot product.

The Fundamental Question

Vaswani et al. asked a radical question: what if we removed the recurrence entirely? What if, instead of processing tokens sequentially, every token could directly attend to every other token in a single parallel operation?

The answer required a mechanism with three properties:

  • All-pairs comparison — every token should be able to compute its relevance to every other token in the sequence, in parallel.
  • Differentiable selection — the model should learn which tokens are relevant via soft weights (probabilities), not hard selection, so gradients can flow.
  • Content-based addressing — relevance should be determined by what the tokens mean (their representations), not by their position alone.

Scaled dot-product attention satisfies all three. It is the mechanism that made the Transformer possible.


From Intuition to Mathematics

The Query-Key-Value Intuition

The central metaphor is a library lookup. Imagine you walk into a library with a question in mind (your query). Each book on the shelf has a label on its spine (its key) and content inside (its value). You scan all the spine labels, judge which ones are most relevant to your question, and then read a weighted blend of the most relevant books.

| Concept | In Attention | In Our Example |
| --- | --- | --- |
| Query (Q) | “What am I looking for?” Each token broadcasts a question about what context it needs. | $Q_{\text{The}} = [1, 0, 1, 0]$: “The” is looking for tokens with activity in dims 0 and 2 |
| Key (K) | “What do I have to offer?” Each token advertises the kind of information it carries. | $K_{\text{cat}} = [1, 0, 1, 0]$: “cat” advertises presence in dims 0 and 2 |
| Value (V) | “Here is my actual content.” The information that gets retrieved when a token is attended to. | $V_{\text{cat}} = [0, 1, 0, 0]$: the actual semantic content of “cat” |

The critical insight is that Q and K live in the same space so their dot product is meaningful as a similarity measure, while V can live in a different space — it carries the content that the query wants to retrieve.

Dot Product as Similarity

Why use the dot product to measure relevance? Consider two vectors $\mathbf{q} = [1, 0, 1, 0]$ and $\mathbf{k} = [1, 0, 1, 0]$:

$$\mathbf{q} \cdot \mathbf{k} = 1 \times 1 + 0 \times 0 + 1 \times 1 + 0 \times 0 = 2$$

The dot product is large because the vectors “agree”: they are active in the same dimensions. Now compare with $\mathbf{k}' = [0, 1, 0, 1]$:

$$\mathbf{q} \cdot \mathbf{k}' = 1 \times 0 + 0 \times 1 + 1 \times 0 + 0 \times 1 = 0$$

The dot product is zero because the vectors are orthogonal: they have nothing in common. This is exactly what we want: the dot product naturally measures the alignment between what the query is looking for and what the key advertises. Geometrically, $\mathbf{q} \cdot \mathbf{k} = \|\mathbf{q}\| \|\mathbf{k}\| \cos\theta$, so it captures both magnitude and direction.
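
The two dot products above, and the cosine term from the geometric identity, can be checked with a few lines of NumPy. The helper name `cosine_similarity` is introduced here only for this sketch.

```python
import numpy as np

q = np.array([1.0, 0.0, 1.0, 0.0])
k = np.array([1.0, 0.0, 1.0, 0.0])        # active in the same dimensions as q
k_prime = np.array([0.0, 1.0, 0.0, 1.0])  # orthogonal to q

def cosine_similarity(a, b):
    """cos(theta) = (a . b) / (|a| |b|)."""
    return float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b)))

print(float(q @ k), float(q @ k_prime))  # 2.0 0.0
print(cosine_similarity(q, k), cosine_similarity(q, k_prime))
```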


The Mathematical Definition

The complete formula for scaled dot-product attention is:

$$\text{Attention}(Q, K, V) = \text{softmax}\!\left(\frac{Q K^\top}{\sqrt{d_k}}\right) V$$

This single line is the computational core of the Transformer. Let us dissect every piece.

Symbol-by-Symbol Breakdown

| Symbol | Shape | Meaning |
| --- | --- | --- |
| $Q$ | $N \times d_k$ | Query matrix. Row $i$ is the query vector for token $i$. In our example: $5 \times 4$. |
| $K$ | $N \times d_k$ | Key matrix. Row $j$ is the key vector for token $j$. Same shape as $Q$ so dot products are valid. |
| $V$ | $N \times d_v$ | Value matrix. Row $j$ is the content vector for token $j$. Can have a different dimension than $d_k$, though in practice $d_v = d_k$. |
| $K^\top$ | $d_k \times N$ | Transpose of $K$, so that $QK^\top$ produces an $N \times N$ matrix of pairwise similarities. |
| $QK^\top$ | $N \times N$ | The raw score matrix. Entry $(i,j)$ is $Q_i \cdot K_j$: how much token $i$'s query aligns with token $j$'s key. |
| $\sqrt{d_k}$ | scalar | Scaling factor. In our example, $\sqrt{4}=2$. In GPT-2, $\sqrt{64}=8$. In GPT-3, $\sqrt{128} \approx 11.3$. |
| $\text{softmax}$ | row-wise | Applied independently to each row. Converts raw scores into a probability distribution: $\text{softmax}(z_i) = e^{z_i} / \sum_j e^{z_j}$. Each row sums to 1. |
| Output | $N \times d_v$ | The context-enriched representation. Each row is a weighted combination of all value vectors, where weights reflect learned relevance. |

What the Formula Says in Plain English

The formula performs four operations in sequence:

  1. Compare every query with every key via dot product ($Q K^\top$), producing an $N \times N$ matrix of raw similarity scores.
  2. Scale the scores by $1/\sqrt{d_k}$ to prevent gradient saturation in the softmax.
  3. Normalize each row with softmax to convert scores into attention weights (probabilities that sum to 1).
  4. Aggregate by multiplying the attention weights with the value matrix, producing a context-aware output for each token.

In short: every token asks “who is relevant to me?” (via Q · K), computes how relevant (via softmax), and collects a weighted blend of answers (via V). The result is that each token's output representation is enriched with contextual information from all other tokens, with more weight given to the most relevant ones.
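
The four operations map almost line for line onto NumPy. The following is a minimal sketch rather than a reference implementation; the function names `softmax_rows` and `scaled_dot_product_attention` and the random test inputs are assumptions made for this example.

```python
import numpy as np

def softmax_rows(scores):
    """Row-wise softmax. Subtracting the row max avoids overflow and leaves the result unchanged."""
    shifted = scores - scores.max(axis=-1, keepdims=True)
    exps = np.exp(shifted)
    return exps / exps.sum(axis=-1, keepdims=True)

def scaled_dot_product_attention(Q, K, V):
    d_k = Q.shape[-1]
    scores = Q @ K.T                # 1. compare every query with every key
    scaled = scores / np.sqrt(d_k)  # 2. scale by 1 / sqrt(d_k)
    weights = softmax_rows(scaled)  # 3. normalize each row to sum to 1
    output = weights @ V            # 4. aggregate the value vectors
    return output, weights

# Usage on random stand-in inputs (hypothetical sizes: N = 3, d_k = d_v = 4).
rng = np.random.default_rng(0)
Q, K, V = (rng.normal(size=(3, 4)) for _ in range(3))
output, weights = scaled_dot_product_attention(Q, K, V)
print(output.shape, weights.shape)
print(weights.sum(axis=-1))  # each row of weights sums to 1 (up to rounding)
```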


Why Scale? The Variance Argument

This is the most commonly asked question about the formula: why divide by $\sqrt{d_k}$ and not some other constant? The answer comes from a simple variance analysis.

The Mathematical Proof

Assume each element of $Q$ and $K$ is drawn independently from a distribution with mean 0 and variance 1. The dot product of two such vectors is:

$$Q_i \cdot K_j = \sum_{m=1}^{d_k} q_m \cdot k_m$$

Since each $q_m$ and $k_m$ are independent with mean 0 and variance 1:

  • $\mathbb{E}[q_m \cdot k_m] = \mathbb{E}[q_m] \cdot \mathbb{E}[k_m] = 0 \cdot 0 = 0$
  • $\text{Var}(q_m \cdot k_m) = \mathbb{E}[q_m^2 k_m^2] - (\mathbb{E}[q_m k_m])^2 = \mathbb{E}[q_m^2] \cdot \mathbb{E}[k_m^2] - 0 = 1 \cdot 1 = 1$

The dot product is a sum of $d_k$ independent terms, each with variance 1, so by the additivity of variance:

$$\text{Var}(Q_i \cdot K_j) = d_k$$

When $d_k = 512$, the standard deviation of the dot product is $\sqrt{512} \approx 22.6$. This means individual scores can easily reach values like $\pm 40$ or more.

Now consider what softmax does with such extreme values. If one score is 40 and the rest are near 0:

$$\text{softmax}([40, 0, 0, 0, 0]) \approx [1.000, 0.000, 0.000, 0.000, 0.000]$$

The output is essentially a one-hot vector. The gradient of softmax at these saturation points is near zero ($\partial\, \text{softmax}(z_i) / \partial z_j \to 0$), which means the model cannot learn to adjust these scores, so learning stalls.
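
The saturation effect can be seen numerically by evaluating the softmax Jacobian, whose entries are $p_i(\delta_{ij} - p_j)$, at a saturated input and at a moderate one. This is an illustrative sketch; the helper names are assumptions, and the "moderate" vector is simply the scaled scores for “The” from the worked example in this chapter.

```python
import numpy as np

def softmax(z):
    e = np.exp(z - z.max())  # shift by the max for numerical stability
    return e / e.sum()

def softmax_jacobian(z):
    """Matrix of d softmax(z)_i / d z_j = p_i * (delta_ij - p_j)."""
    p = softmax(z)
    return np.diag(p) - np.outer(p, p)

saturated = np.array([40.0, 0.0, 0.0, 0.0, 0.0])
moderate = np.array([0.0, 1.0, 0.5, 0.5, 0.75])

max_grad_saturated = np.abs(softmax_jacobian(saturated)).max()
max_grad_moderate = np.abs(softmax_jacobian(moderate)).max()

print(f"largest |gradient|, saturated input: {max_grad_saturated:.3e}")
print(f"largest |gradient|, moderate input:  {max_grad_moderate:.3e}")
```

The first value should be vanishingly small while the second is of ordinary size, which is the stall described above.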

The fix is elegant: dividing by $\sqrt{d_k}$ rescales the variance back to 1:

$$\text{Var}\!\left(\frac{Q_i \cdot K_j}{\sqrt{d_k}}\right) = \frac{\text{Var}(Q_i \cdot K_j)}{d_k} = \frac{d_k}{d_k} = 1$$

Now the scores have a standard deviation of ~1 regardless of the dimension. Softmax receives moderate inputs, produces smooth (non-saturated) probability distributions, and gradients flow properly for learning. This is why we scale by exactly $\sqrt{d_k}$: not 2, not $d_k$, not a learned parameter. Under the unit-variance assumption above, it is the normalization that keeps the softmax in its useful operating regime.
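
The variance argument can also be checked empirically. The sketch below assumes standard normal entries (one distribution that satisfies the mean-0, variance-1 assumption) and an arbitrary sample size `n_pairs`; the numbers it prints are Monte Carlo estimates, so they will only be close to $d_k$ and 1, not exact.

```python
import numpy as np

rng = np.random.default_rng(0)
n_pairs = 5_000  # number of random (q, k) pairs per dimension; an arbitrary choice

def empirical_variances(d_k):
    """Return (variance of raw dot products, variance after dividing by sqrt(d_k))."""
    q = rng.standard_normal((n_pairs, d_k))
    k = rng.standard_normal((n_pairs, d_k))
    dots = np.sum(q * k, axis=1)  # one dot product per pair
    return float(dots.var()), float((dots / np.sqrt(d_k)).var())

results = {d_k: empirical_variances(d_k) for d_k in (4, 64, 512)}
for d_k, (raw, scaled) in results.items():
    print(f"d_k={d_k:4d}  Var(raw)={raw:8.2f}  Var(scaled)={scaled:5.2f}")
```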

Additive vs. multiplicative attention

Vaswani et al. note in their paper that for small $d_k$, additive attention (Bahdanau style, using a learned feedforward network) and dot-product attention perform similarly. But for large $d_k$, unscaled dot-product attention degrades significantly. The scaling factor is introduced to close this gap, so that dot-product attention keeps its speed advantage (matrix multiply vs. feedforward) without losing quality at larger dimensions.

Step-by-Step Calculation for “The” (Row 0)

We now trace through every arithmetic step for the first token “The” to see exactly how scaled dot-product attention builds its output vector. Every number here can be verified with the Python class at the bottom of this chapter.

Step 1. Raw Dot Products: $S_{0,:} = Q_0 \cdot K^\top$

The query vector for “The” is $Q_0 = [1, 0, 1, 0]$. We compute its dot product with every key vector:

| Pair | Calculation | Raw Score |
| --- | --- | --- |
| $Q_{\text{The}} \cdot K_{\text{The}}$ | $1 \times 0 + 0 \times 1 + 1 \times 0 + 0 \times 1$ | 0.0 |
| $Q_{\text{The}} \cdot K_{\text{cat}}$ | $1 \times 1 + 0 \times 0 + 1 \times 1 + 0 \times 0$ | 2.0 |
| $Q_{\text{The}} \cdot K_{\text{sat}}$ | $1 \times 1 + 0 \times 1 + 1 \times 0 + 0 \times 0$ | 1.0 |
| $Q_{\text{The}} \cdot K_{\text{on}}$ | $1 \times 0 + 0 \times 0 + 1 \times 1 + 0 \times 1$ | 1.0 |
| $Q_{\text{The}} \cdot K_{\text{mat}}$ | $1 \times 1 + 0 \times 0 + 1 \times 0.5 + 0 \times 0.5$ | 1.5 |

Interpretation: “The” has the highest raw similarity with “cat” (score = 2.0) because $Q_{\text{The}} = [1, 0, 1, 0]$ and $K_{\text{cat}} = [1, 0, 1, 0]$ are identical, a perfect alignment. The lowest similarity is with itself (score = 0.0) because $K_{\text{The}} = [0, 1, 0, 1]$ is orthogonal to its query.
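
The five raw scores in this step are a single vector-matrix product. A short check, with the key matrix copied from the shared example and variable names chosen for this sketch:

```python
import numpy as np

K = np.array([[0.0, 1.0, 0.0, 1.0],   # The
              [1.0, 0.0, 1.0, 0.0],   # cat
              [1.0, 1.0, 0.0, 0.0],   # sat
              [0.0, 0.0, 1.0, 1.0],   # on
              [1.0, 0.0, 0.5, 0.5]])  # mat

q_the = np.array([1.0, 0.0, 1.0, 0.0])  # query row for "The"

raw_scores = q_the @ K.T    # one dot product per key
print(raw_scores.tolist())  # [0.0, 2.0, 1.0, 1.0, 1.5]
```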

Step 2. Scaling: $S_{0,:} / \sqrt{d_k}$

We divide each raw score by $\sqrt{4} = 2.0$:

| Token | Raw Score / $\sqrt{d_k}$ | Scaled Score |
| --- | --- | --- |
| The | 0.0 / 2.0 | 0.000 |
| cat | 2.0 / 2.0 | 1.000 |
| sat | 1.0 / 2.0 | 0.500 |
| on | 1.0 / 2.0 | 0.500 |
| mat | 1.5 / 2.0 | 0.750 |

In our small example ($d_k = 4$), scaling reduces scores by half. In a production model with $d_k = 128$, scores would be divided by ~11.3, a much more dramatic reduction.

Step 3. Softmax: $A_{0,:} = \text{softmax}(S_{0,:} / \sqrt{d_k})$

The softmax function converts the scaled scores into a probability distribution. For a vector $z = [z_1, \ldots, z_N]$, the softmax of element $i$ is $e^{z_i} / \sum_{j} e^{z_j}$.

Computing each exponential:

| Token | Scaled Score | Exponential |
| --- | --- | --- |
| The | 0.000 | $e^{0.000} = 1.000$ |
| cat | 1.000 | $e^{1.000} = 2.718$ |
| sat | 0.500 | $e^{0.500} = 1.649$ |
| on | 0.500 | $e^{0.500} = 1.649$ |
| mat | 0.750 | $e^{0.750} = 2.117$ |

Sum of exponentials: $1.000 + 2.718 + 1.649 + 1.649 + 2.117 = 9.133$

Dividing each exponential by the sum:

| Token | Attention Weight | Percentage |
| --- | --- | --- |
| The | 0.1095 | 10.95% |
| cat | 0.2976 | 29.76% |
| sat | 0.1805 | 18.05% |
| on | 0.1805 | 18.05% |
| mat | 0.2318 | 23.18% |

Interpretation: “The” pays 29.76% of its attention to “cat” — the token most aligned with its query. But notice that attention is not concentrated on a single token. “mat” also receives significant weight (23.18%) because its key partially overlaps with the query. This soft, distributed attention is what allows the mechanism to capture nuanced relationships.
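
Steps 2 and 3 can be reproduced together in a few lines. This sketch starts from the raw scores computed in Step 1; the variable names are chosen for illustration.

```python
import numpy as np

raw_scores = np.array([0.0, 2.0, 1.0, 1.0, 1.5])  # Step 1 result for "The"
d_k = 4

scaled = raw_scores / np.sqrt(d_k)  # Step 2: divide by sqrt(4) = 2
exps = np.exp(scaled)               # Step 3a: exponentiate
weights = exps / exps.sum()         # Step 3b: normalize to sum to 1

print(scaled.tolist())               # [0.0, 1.0, 0.5, 0.5, 0.75]
print(np.round(weights, 4).tolist()) # [0.1095, 0.2976, 0.1805, 0.1805, 0.2318]
```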

Step 4. Weighted Sum of Values: $O_0 = A_{0,:} \cdot V$

Each output dimension is the weighted average of that dimension across all value vectors:

$$O_0 = 0.1095 \times V_{\text{The}} + 0.2976 \times V_{\text{cat}} + 0.1805 \times V_{\text{sat}} + 0.1805 \times V_{\text{on}} + 0.2318 \times V_{\text{mat}}$$

Expanding each value vector:

0.1095 × [1.0, 0.0, 0.0, 0.0] = [0.1095, 0.0000, 0.0000, 0.0000]  (The)
0.2976 × [0.0, 1.0, 0.0, 0.0] = [0.0000, 0.2976, 0.0000, 0.0000]  (cat)
0.1805 × [0.0, 0.0, 1.0, 0.0] = [0.0000, 0.0000, 0.1805, 0.0000]  (sat)
0.1805 × [0.0, 0.0, 0.0, 1.0] = [0.0000, 0.0000, 0.0000, 0.1805]  (on)
0.2318 × [0.5, 0.5, 0.5, 0.5] = [0.1159, 0.1159, 0.1159, 0.1159]  (mat)
─────────────────────────────────────────────────────────────────────
Sum:                             [0.2254, 0.4135, 0.2964, 0.2964]  ← Output

Final output for “The”: $O_0 = [0.2254, 0.4135, 0.2964, 0.2964]$

Notice that dim-1 has the largest value (0.4135) because “cat”, the most-attended token, contributes all its weight to dim-1 ($V_{\text{cat}} = [0, 1, 0, 0]$). The output for “The” has been enriched with the semantic content of “cat”, reflecting the strong query-key alignment between these two tokens.
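
The weighted sum is one more vector-matrix product. The sketch below uses the rounded weights from Step 3, so it reproduces the hand calculation to four decimals; variable names are chosen for this example.

```python
import numpy as np

V = np.array([[1.0, 0.0, 0.0, 0.0],   # The
              [0.0, 1.0, 0.0, 0.0],   # cat
              [0.0, 0.0, 1.0, 0.0],   # sat
              [0.0, 0.0, 0.0, 1.0],   # on
              [0.5, 0.5, 0.5, 0.5]])  # mat

# Attention weights for "The", rounded to four decimals as in Step 3.
weights = np.array([0.1095, 0.2976, 0.1805, 0.1805, 0.2318])

output = weights @ V  # weighted sum of the value rows
print(np.round(output, 4).tolist())  # [0.2254, 0.4135, 0.2964, 0.2964]
```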


Full Attention Weights and Output

Below is the complete attention weight matrix and its corresponding output for all five tokens. Hover over any cell in the heatmap to see the full computation for that query-key pair.
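
The full $5 \times 5$ weight matrix and $5 \times 4$ output can be reproduced with a short NumPy sketch. It applies the same four steps as the walkthrough for “The” to every row at once; the variable names are assumptions made for this example, and the printed matrices are whatever the computation yields on the shared inputs.

```python
import numpy as np

TOKENS = ["The", "cat", "sat", "on", "mat"]
Q = np.array([[1, 0, 1, 0], [0, 2, 0, 1], [1, 1, 1, 0], [0, 0, 1, 1], [1, 0, 0, 1]], dtype=float)
K = np.array([[0, 1, 0, 1], [1, 0, 1, 0], [1, 1, 0, 0], [0, 0, 1, 1], [1, 0, 0.5, 0.5]], dtype=float)
V = np.array([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1], [0.5, 0.5, 0.5, 0.5]], dtype=float)

scaled = Q @ K.T / np.sqrt(Q.shape[-1])                     # scaled scores, 5 x 5
exps = np.exp(scaled - scaled.max(axis=-1, keepdims=True))  # stable exponentials
weights = exps / exps.sum(axis=-1, keepdims=True)           # attention weights, 5 x 5
output = weights @ V                                        # outputs, 5 x 4

np.set_printoptions(precision=4, suppress=True)
for token, row in zip(TOKENS, weights):
    print(f"{token:>3s} attends:", row)
print(output)
```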

Interpreting the Heatmap

Several patterns emerge from the attention weight matrix:

  • “cat” attends most to “The” (0.4026). $Q_{\text{cat}} = [0, 2, 0, 1]$ has strong overlap with $K_{\text{The}} = [0, 1, 0, 1]$ (dot product = 3.0, the largest in the entire matrix). In this toy setup the pattern happens to mirror the determiner that typically precedes the noun.
  • “on” attends most to itself (0.3137) because $Q_{\text{on}} = K_{\text{on}} = [0, 0, 1, 1]$. When the query and key of the same token are aligned, self-attention is high. Here that is a consequence of the hand-picked matrices rather than anything learned.
  • “mat” distributes attention nearly uniformly (weights between roughly 0.19 and 0.24): its query $[1, 0, 0, 1]$ partially matches many keys without strongly preferring one.

Reading attention patterns

In practice, attention weights are not easily interpretable as “this token is related to that token.” They reflect the learned projection space (Q, K), which may encode abstract features unrelated to surface-level word meaning. Attention visualization is useful for debugging and building intuition, but should not be over-interpreted as causal explanation (Jain & Wallace, 2019).

Applications Across Domains

Scaled dot-product attention is not limited to language. The same formula powers breakthroughs across every domain that involves sequences, sets, or structured data.

Natural Language Processing

In GPT-4, Claude, and LLaMA, each layer applies scaled dot-product attention so that every token can gather context from every other token. When the model processes “The cat sat on the mat because it was tired,” the attention mechanism for “it” assigns high weight to “cat” — resolving the pronoun coreference. Without attention, an RNN would need to carry the “cat” information through every intermediate hidden state.

Computer Vision

In Vision Transformers (ViT; Dosovitskiy et al., 2021), an image is split into 16×16 patches. Each patch becomes a token, and scaled dot-product attention lets every patch attend to every other patch. A patch containing a dog's eye can attend to the patch containing its tail — capturing long-range spatial relationships that CNNs can only reach through many stacked convolutional layers.

Code Generation

In Codex and GitHub Copilot, attention enables the model to connect a function's name (token 1) with its return type (token 50), its docstring (tokens 5-20), and the variable names used inside (tokens 30-100). When generating return total_price, the model's attention assigns high weight to the earlier definition total_price = base_price * quantity — even if that line is hundreds of tokens away.

Scientific Sequence Modeling

In AlphaFold2 (Jumper et al., 2021), attention operates over amino acid residues in a protein sequence. Each residue attends to every other residue, learning which pairs of amino acids are likely to be spatially close in the 3D folded structure, even when they are far apart in the linear sequence. The attention weights can act as a proxy for physical contact maps, even though the attention weights themselves are not directly trained against contact labels.

Real-World Scale

To appreciate the scale at which this formula operates, the table below shows illustrative parameters from published model architectures. The “Attention Ops” column is a back-of-the-envelope estimate using $\text{Layers} \times \text{Heads} \times N^2$ (the dominant term from $QK^\top$ across all heads and layers).

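| Model | Layers | Heads | $d_k$ | Max Seq Length | Attention Ops (est.) |
| --- | --- | --- | --- | --- | --- |
| GPT-3 (175B) | 96 | 96 | 128 | 2,048 | ~39 billion |
| LLaMA-2 70B | 80 | 64 | 128 | 4,096 | ~86 billion |
| GPT-4 (speculative) | ~120 | ~96 | 128 | 8,192 | ~770 billion |
| ViT-Large | 24 | 16 | 64 | 197 patches | ~15 million |
| AlphaFold2 | 48 | 8 | 32 | ~500 residues | ~96 million |
| Our example | 1 | 1 | 4 | 5 | 25 |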

About these numbers

GPT-3 and LLaMA-2 parameters are from their published papers. GPT-4's architecture is not publicly disclosed, so that row is speculative. ViT-Large and AlphaFold2 parameters are from Dosovitskiy et al. (2021) and Jumper et al. (2021) respectively. The “Attention Ops” column counts entries of the $QK^\top$ score matrices only (one per query-key pair, each of which costs $d_k$ multiply-adds) and ignores the value aggregation, softmax, and non-attention layers.

The last row is our toy example — the same formula, just orders of magnitude smaller. This is why understanding the mechanism at small scale transfers directly to understanding production systems.
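
The estimates in the table follow directly from the $\text{Layers} \times \text{Heads} \times N^2$ formula. A small sketch that recomputes them for the non-speculative rows; the function name `score_entries` is an assumption made for this example.

```python
def score_entries(layers, heads, seq_len):
    """Number of query-key score entries across all heads and layers."""
    return layers * heads * seq_len ** 2

# (layers, heads, sequence length) taken from the table above.
rows = {
    "GPT-3 (175B)": (96, 96, 2048),
    "LLaMA-2 70B": (80, 64, 4096),
    "ViT-Large": (24, 16, 197),
    "AlphaFold2": (48, 8, 500),
    "Our example": (1, 1, 5),
}

for name, args in rows.items():
    print(f"{name:14s} {score_entries(*args):>18,}")
```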


Connection to Modern Systems

Scaled dot-product attention is the atomic building block. Every subsequent chapter in this book modifies, extends, or optimizes it. Here is how the major variants relate to this foundation.

Multi-Head Attention (Chapter 2)

Instead of running one attention function with $d_{\text{model}}$-dimensional keys, multi-head attention runs $H$ independent scaled dot-product attention operations in parallel, each on a smaller $d_k = d_{\text{model}}/H$-dimensional subspace. This allows different heads to learn different types of relationships: one head might learn syntactic dependency, another might learn semantic similarity, another might learn positional patterns. The outputs are concatenated and linearly projected back to $d_{\text{model}}$.
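
As a rough preview, the head split can be sketched on the shared matrices with $H = 2$. This is a simplification made for illustration: it slices the already-fixed $Q$, $K$, $V$ into column blocks and omits the learned output projection, so it should not be read as the Chapter 2 implementation.

```python
import numpy as np

def attention(Q, K, V):
    """Single-head scaled dot-product attention."""
    s = Q @ K.T / np.sqrt(Q.shape[-1])
    w = np.exp(s - s.max(axis=-1, keepdims=True))
    w /= w.sum(axis=-1, keepdims=True)
    return w @ V

Q = np.array([[1, 0, 1, 0], [0, 2, 0, 1], [1, 1, 1, 0], [0, 0, 1, 1], [1, 0, 0, 1]], dtype=float)
K = np.array([[0, 1, 0, 1], [1, 0, 1, 0], [1, 1, 0, 0], [0, 0, 1, 1], [1, 0, 0.5, 0.5]], dtype=float)
V = np.array([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1], [0.5, 0.5, 0.5, 0.5]], dtype=float)

H = 2
d_head = Q.shape[-1] // H  # per-head dimension: 4 / 2 = 2

head_outputs = []
for h in range(H):
    cols = slice(h * d_head, (h + 1) * d_head)  # columns belonging to head h
    head_outputs.append(attention(Q[:, cols], K[:, cols], V[:, cols]))

# Concatenate head outputs back to d_model columns.
# A real layer would then multiply by a learned output projection (omitted here).
multi_head = np.concatenate(head_outputs, axis=-1)
print(multi_head.shape)  # (5, 4)
```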

Flash Attention (Chapter 13)

Flash Attention (Dao et al., 2022) computes mathematically identical results to scaled dot-product attention, but reorganizes the computation to minimize GPU memory reads/writes. The key insight is that the $N \times N$ attention matrix is never fully materialized in GPU HBM (high-bandwidth memory). Instead, it is computed in tiles that fit in SRAM (on-chip memory), achieving 2-4x speedup and reducing memory from $O(N^2)$ to $O(N)$. The math is the same; only the memory access pattern changes.
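
The reason tiling can give identical results is that softmax can be accumulated block by block with a running maximum and running sum. The NumPy sketch below illustrates only that recurrence, tiling over keys and values; it is not the Flash Attention kernel, does not model SRAM or HBM, and the names `tiled_attention` and `block_size` are assumptions made for this example.

```python
import numpy as np

def reference_attention(Q, K, V):
    s = Q @ K.T / np.sqrt(Q.shape[-1])
    w = np.exp(s - s.max(axis=-1, keepdims=True))
    return (w / w.sum(axis=-1, keepdims=True)) @ V

def tiled_attention(Q, K, V, block_size=2):
    """Process keys/values in blocks, never holding the full N x N matrix."""
    N, d_k = Q.shape
    out = np.zeros((N, V.shape[-1]))        # unnormalized output accumulator
    running_max = np.full((N, 1), -np.inf)  # row-wise max seen so far
    running_sum = np.zeros((N, 1))          # row-wise sum of exponentials so far
    for start in range(0, K.shape[0], block_size):
        K_blk = K[start:start + block_size]
        V_blk = V[start:start + block_size]
        s = Q @ K_blk.T / np.sqrt(d_k)      # scores for this block only
        new_max = np.maximum(running_max, s.max(axis=-1, keepdims=True))
        correction = np.exp(running_max - new_max)  # rescales earlier accumulators
        p = np.exp(s - new_max)
        running_sum = running_sum * correction + p.sum(axis=-1, keepdims=True)
        out = out * correction + p @ V_blk
        running_max = new_max
    return out / running_sum

rng = np.random.default_rng(0)
Q, K, V = (rng.normal(size=(7, 4)) for _ in range(3))  # random stand-in inputs
matches = np.allclose(tiled_attention(Q, K, V, block_size=3), reference_attention(Q, K, V))
print(matches)  # True
```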

KV-Cache Optimization

During autoregressive generation (producing one token at a time), the model computes attention for the new token's query against all previous keys and values. The KV-cache stores the $K$ and $V$ matrices from all previous tokens so they don't need to be recomputed from the input embeddings at each step. The per-token attention cost remains $O(N \cdot d_k)$ (the new query still scores against all $N$ cached keys), but without the cache the model would need to reproject and recompute attention for all previous tokens from scratch, an $O(N^2 \cdot d_k)$ cost per generated token. The cache trades $O(N \cdot d_k)$ memory for this saving. Multi-Query Attention (Chapter 5) and Grouped-Query Attention (Chapter 6) reduce the memory cost by sharing K and V across heads.
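
The caching idea can be sketched for a single head as follows. The random `Q`, `K`, `V` rows stand in for already-projected tokens, and the helper `attend` is a name chosen for this example; production systems typically preallocate the cache rather than growing it with `np.vstack`.

```python
import numpy as np

def attend(q, K_cache, V_cache):
    """Attention output for one query against all cached keys and values."""
    s = K_cache @ q / np.sqrt(q.shape[-1])
    w = np.exp(s - s.max())
    w /= w.sum()
    return w @ V_cache

rng = np.random.default_rng(0)
N, d_k = 6, 4
Q, K, V = (rng.normal(size=(N, d_k)) for _ in range(3))  # stand-ins for projected tokens

K_cache = np.empty((0, d_k))
V_cache = np.empty((0, d_k))
outputs = []
for t in range(N):
    # Only the new token's key and value are appended; earlier rows are reused as-is.
    K_cache = np.vstack([K_cache, K[t]])
    V_cache = np.vstack([V_cache, V[t]])
    outputs.append(attend(Q[t], K_cache, V_cache))  # O(t * d_k) work at step t
outputs = np.stack(outputs)

print(outputs.shape, K_cache.shape)  # (6, 4) (6, 4)
```

Each step's result should match the corresponding row of full attention with a causal mask, which is the equivalence the cache relies on.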

Positional Encodings (Chapters 7-9)

Scaled dot-product attention is permutation-equivariant: it treats the input as a set, not a sequence. Swapping the order of tokens in the input would produce the same attention weights and outputs, just reordered in the same way. To inject positional information, the original Transformer adds positional encodings to the input embeddings. Chapters 7-9 cover three approaches: learned relative position bias (T5), Rotary Position Embeddings (RoPE, used in LLaMA), and Attention with Linear Biases (ALiBi). Each modifies either the input representations or the attention scores to encode where tokens are, not just what they are.
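
The reordering property described above is easy to verify on the shared example: shuffling the rows of $Q$, $K$, and $V$ with the same permutation shuffles the output rows in exactly the same way. The permutation `perm` below is an arbitrary choice for this sketch.

```python
import numpy as np

def attention(Q, K, V):
    s = Q @ K.T / np.sqrt(Q.shape[-1])
    w = np.exp(s - s.max(axis=-1, keepdims=True))
    w /= w.sum(axis=-1, keepdims=True)
    return w @ V

Q = np.array([[1, 0, 1, 0], [0, 2, 0, 1], [1, 1, 1, 0], [0, 0, 1, 1], [1, 0, 0, 1]], dtype=float)
K = np.array([[0, 1, 0, 1], [1, 0, 1, 0], [1, 1, 0, 0], [0, 0, 1, 1], [1, 0, 0.5, 0.5]], dtype=float)
V = np.array([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1], [0.5, 0.5, 0.5, 0.5]], dtype=float)

perm = np.array([4, 2, 0, 3, 1])  # arbitrary reordering of the five tokens

out = attention(Q, K, V)
out_perm = attention(Q[perm], K[perm], V[perm])

# Same vectors come out, only in the permuted order: no position information is used.
print(np.allclose(out_perm, out[perm]))  # True
```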


Complexity Analysis

To understand why attention becomes expensive, do not start with formulas. Start with one simple question:

As the sequence gets longer, how much more work does the model have to do, and how much more memory does it need to hold things while doing that work?

That is all complexity analysis is trying to measure.

  • Time complexity tells us how the amount of work grows.
  • Space complexity tells us how the amount of memory grows.

In attention, both grow quickly because every token can interact with every other token.

A Simple Picture: Students in a Classroom

Imagine a classroom with $N$ students. Each student wants to decide: “Which other students should I listen to?”

If every student checks every other student, then student 1 checks $N$ students, student 2 checks $N$ students, and so on. The total number of checks is:

$$N \times N = N^2$$

That is the main time cost.

Now imagine writing every one of those scores on a big board. The board needs $N$ rows and $N$ columns, so it stores $N^2$ numbers. That is the main space cost. This is exactly what happens in attention.

Where the Cost Comes From in Attention

In scaled dot-product attention, each token asks: “How relevant is every other token to me?” The model answers that by building a giant score table:

$$QK^\top$$

If there are $N$ tokens, this table has shape $N \times N$. That means computing it takes a lot of work, and storing it takes a lot of memory. That single all-to-all interaction pattern is the reason attention becomes expensive.

Step-by-Step Intuition

Step 1: Compute $QK^\top$, the score matrix. This is the “everyone compares with everyone” step. Each token compares itself with every other token. If there are $N$ tokens, there are about $N^2$ token pairs. But each comparison is not just one tiny action: it is a dot product over vectors of size $d_k$. So each pair needs about $d_k$ small multiply-and-add operations.

Time: $O(N^2 d_k)$. Space: $O(N^2)$.

Intuition: a giant spreadsheet where every row is a token and every column is another token, and each cell stores “how much should I pay attention?”

Step 2: Scaling. After the scores are computed, each one is divided by $\sqrt{d_k}$. This keeps the numbers from getting too large. There are $N^2$ scores, so touching all of them costs $O(N^2)$ time. Usually this is done directly on the same score table, so it does not need another giant table; the extra memory cost is $O(1)$.

Intuition: you already filled the board with numbers, and now you go through and slightly shrink each one. That is work, but not extra board space.

Step 3: Softmax. Each row of the score matrix is turned into normalized attention weights. Each row has $N$ numbers, and there are $N$ rows, so the total work is $O(N^2)$. The model still stores the full weight table, so memory is still $O(N^2)$.

Intuition: the raw scores become “attention percentages.” Same giant board, same size, just cleaner numbers.

Step 4: Multiply weights by $V$. The model uses those attention weights to mix the value vectors and produce a new output vector for each token. For each of the $N$ output tokens, it looks across $N$ tokens, and each value vector has size $d_v$. So the work is $O(N^2 d_v)$. The output stores one vector of length $d_v$ per token, so the memory needed is only $O(N d_v)$.

Intuition: each token builds a summary of what it learned from everyone else.
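
The four steps can be turned into a simple operation counter to see how the numbers grow. This is a back-of-the-envelope sketch: the function name, the dictionary keys, and the choice of $d_k = d_v = 64$ for the larger sequence lengths are assumptions made for illustration, and constant factors are ignored.

```python
def attention_cost(N, d_k, d_v):
    """Approximate per-head, per-layer counts for each step of attention."""
    return {
        "scores_multiply_adds": N * N * d_k,     # Step 1: Q K^T
        "scaling_divisions": N * N,              # Step 2: divide every score
        "softmax_elements": N * N,               # Step 3: one exp/normalize per score
        "aggregate_multiply_adds": N * N * d_v,  # Step 4: weights times V
        "score_matrix_entries": N * N,           # memory for the N x N table
        "output_entries": N * d_v,               # memory for the output
    }

toy = attention_cost(N=5, d_k=4, d_v=4)  # the shared five-token example
print(toy)

for N in (1_000, 8_192):
    cost = attention_cost(N, d_k=64, d_v=64)
    print(f"N={N:>5,}: score entries={cost['score_matrix_entries']:,}, "
          f"Q K^T multiply-adds={cost['scores_multiply_adds']:,}")
```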

Interactive: Worked Example with 3 Tokens

The steps above describe what happens in words. But to truly understand why each operation costs what it does, nothing beats watching the numbers unfold. The walkthrough below uses 3 tokens with vectors of length 3 and walks through every single multiply-and-add.

Part 1: $QK^\top$, the Comparison Step

The Real Idea

Before we start, let's clear up a common confusion:

One token-to-token comparison is NOT a single check.

It is a small calculation made from many numbers.

Think of it like this: if one token is described by 3 features and another token is described by 3 features, then to compare them you must check feature 1 with feature 1, feature 2 with feature 2, feature 3 with feature 3, then add them together.

So one comparison is really 3 little comparisons plus addition, not one instant action. Let's see this with real numbers.

A Tiny Example

Suppose there are only 4 tokens. The score matrix has $4 \times 4 = 16$ entries. That is small. Easy to compute. Easy to store.

Now suppose there are 1,000 tokens. The score matrix has $1{,}000 \times 1{,}000 = 1{,}000{,}000$ entries.

Now suppose there are 8,192 tokens. The score matrix has $8{,}192 \times 8{,}192 = 67{,}108{,}864$ entries. That is about 67 million scores for just one head in one layer.

This is why attention becomes expensive so quickly: the sequence length grows linearly, but the pairwise interaction table grows quadratically.
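The growth can be checked with a few lines of Python. This is a rough sketch: the helper names are hypothetical, and the 4 bytes per entry assumes the scores are stored as float32, which the text does not specify.

```python
def score_matrix_entries(num_tokens: int) -> int:
    """Number of entries in the N x N score matrix for one head in one layer."""
    return num_tokens * num_tokens


def score_matrix_mebibytes(num_tokens: int, bytes_per_entry: int = 4) -> float:
    """Memory to hold the full score matrix; 4 bytes per entry assumes float32."""
    return score_matrix_entries(num_tokens) * bytes_per_entry / 2**20


for n in (4, 1_000, 8_192):
    print(f"N={n:>5}: {score_matrix_entries(n):>12,} entries, "
          f"{score_matrix_mebibytes(n):10.4f} MiB in float32")
```

Going from 4 to 8,192 tokens multiplies the sequence length by 2,048 but multiplies the table size by 2,048 squared.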

What “Bottleneck” Really Means Here

The bottleneck is the part that hurts the most as the input grows. In attention, that bottleneck is the $N^2$ interaction pattern. Both the time to compute all token-to-token interactions and the space to store all those interaction scores become large. So when people say standard attention has an $O(N^2)$ bottleneck, they mean:

The cost blows up because every token talks to every token.

Summary Table

| Operation | What It Does | Time | Space |
| --- | --- | --- | --- |
| $QK^\top$ | Compares everyone with everyone | $O(N^2 d_k)$ | $O(N^2)$ |
| Scaling | Adjusts every score once | $O(N^2)$ | $O(1)$ extra |
| Softmax | Turns scores into attention weights | $O(N^2)$ | $O(N^2)$ |
| weights × $V$ | Uses weights to gather information | $O(N^2 d_v)$ | $O(N d_v)$ |
| Total | | $O(N^2 d_k)$ | $O(N^2 + N d_v)$ |

The reason the total is dominated by quadratic behavior is simple: the giant $N \times N$ board keeps showing up.

At $N = 8{,}192$ tokens (a typical context window), the attention matrix has 67 million entries per head, per layer. At 32 heads and 32 layers, that is ~69 billion score computations per forward pass.
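The counts in the summary table and the 69 billion figure can be reproduced with a short sketch. The function names are hypothetical, the per-step counts are order-of-magnitude tallies rather than exact FLOP counts, and the head size $d_k = d_v = 128$ is an assumed value for illustration.

```python
def attention_cost(n: int, d_k: int, d_v: int) -> dict:
    """Rough per-step operation counts, following the summary table."""
    return {
        "qk_t": n * n * d_k,        # N^2 dot products, d_k multiply-adds each
        "scale": n * n,             # one division per score
        "softmax": n * n,           # work proportional to the number of scores
        "weights_v": n * n * d_v,   # each of N outputs mixes N value vectors of size d_v
    }


def total_scores(n: int, heads: int, layers: int) -> int:
    """Score-matrix entries computed in one forward pass across all heads and layers."""
    return n * n * heads * layers


costs = attention_cost(n=8_192, d_k=128, d_v=128)   # d_k = d_v = 128 is an assumed head size
for step, count in costs.items():
    print(f"{step:>10}: {count:,}")
print(f"total scores: {total_scores(8_192, heads=32, layers=32):,}")
```

The two matrix multiplications dominate the time, while the two $N \times N$ tables dominate the memory.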

Attention becomes expensive because every token compares itself with every other token. If there are $N$ tokens, that creates an $N \times N$ table of interaction scores. Building that table takes computation, and storing it takes memory. Time complexity measures how the amount of computation grows; space complexity measures how the required memory grows. In standard attention, both are dominated by this all-to-all interaction pattern, which is why the cost scales quadratically with sequence length.

Why this matters for the rest of the book

Linear Attention (Chapter 10) reduces time to $O(N d_k^2)$ by avoiding the $N \times N$ matrix entirely. Sliding Window (Chapter 11) restricts each token to a local window of size $w$, giving $O(Nw)$. Flash Attention (Chapter 13) keeps $O(N^2)$ time but reduces memory to $O(N)$ via tiling. Each chapter trades something different to escape the quadratic wall.

Python Implementation

Below is a complete, runnable implementation of scaled dot-product attention as a Python class. Each annotated line is followed by its explanation and the values it produces on the shared example; the full listing comes after the annotations. You can copy the full code and run it with python scaled_dot_product_attention.py. The only dependency is NumPy.

Scaled Dot-Product Attention — Full Implementation
🐍scaled_dot_product_attention.py
1import numpy as np

NumPy provides vectorized matrix operations. Q @ K.T runs as optimized C code, not Python loops.

2import math

Python standard library. We use math.sqrt() to precompute the scaling factor.

4class ScaledDotProductAttention

Wraps the mechanism in a reusable class. Every chapter follows this same structure so you can compare implementations side by side.

14def __init__(self, d_k: int)

Constructor. Takes one parameter d_k (dimension of query/key vectors). Stores it and precomputes the scaling factor. Called once when instantiated.

EXECUTION STATE
⬇ input: d_k = 4
19self.d_k = d_k

Store d_k as an instance attribute so other methods can access it (the explain method prints it).

EXECUTION STATE
self.d_k = 4
20self.scale = math.sqrt(d_k)

Precompute √d_k once. This constant divides every dot product in scale_scores(). Computing it once avoids repeated sqrt calls.

EXECUTION STATE
math.sqrt(4) = 2.0
self.scale = 2.0
22def _softmax(self, x) → np.ndarray

Function takes self and x (a matrix of scaled scores, shape 5×5). Applies numerically stable softmax to each row. Returns a matrix of the same shape where each row sums to 1.0.

EXECUTION STATE
⬇ input: x = shape (5, 5) — the full scaled_scores matrix
⬆ returns = np.ndarray (5, 5) — softmax probabilities per row
24x_shifted = x - np.max(x, axis=-1, keepdims=True)

Subtract the row-wise max from each element. This prevents exp() overflow: in float64, exp(1000) overflows to inf, but exp(1000-1000)=1. It does NOT change the softmax result because the constant factor cancels in the ratio. All 5 rows shown:

EXECUTION STATE
axis=-1 = operate along the LAST axis (columns within each row). For a 5×5 matrix, this means find the max of each row independently, not the global max.
keepdims=True = keep the reduced axis as size-1 dimension. max returns shape (5,1) instead of (5,), so broadcasting x (5×5) - max (5×1) works correctly — each row subtracts its own max.
── Row 0 (The) ── =
x = [0.000, 1.000, 0.500, 0.500, 0.750]
max(x) = 1.000
x_shifted = [-1.000, 0.000, -0.500, -0.500, -0.250]
── Row 1 (cat) ── =
x = [1.500, 0.000, 1.000, 0.500, 0.250]
max(x) = 1.500
x_shifted = [0.000, -1.500, -0.500, -1.000, -1.250]
── Row 2 (sat) ── =
x = [0.500, 1.000, 1.000, 0.500, 0.750]
max(x) = 1.000
x_shifted = [-0.500, 0.000, 0.000, -0.500, -0.250]
── Row 3 (on) ── =
x = [0.500, 0.500, 0.000, 1.000, 0.500]
max(x) = 1.000
x_shifted = [-0.500, -0.500, -1.000, 0.000, -0.500]
── Row 4 (mat) ── =
x = [0.500, 0.500, 0.500, 0.500, 0.750]
max(x) = 0.750
x_shifted = [-0.250, -0.250, -0.250, -0.250, 0.000]
25exp_x = np.exp(x_shifted)

Exponentiate every element. Because we subtracted the max, the largest value per row is exp(0)=1.0 — no overflow possible. All 5 rows:

EXECUTION STATE
exp_x (row 0 The) = [0.3679, 1.0000, 0.6065, 0.6065, 0.7788]
exp_x (row 1 cat) = [1.0000, 0.2231, 0.6065, 0.3679, 0.2865]
exp_x (row 2 sat) = [0.6065, 1.0000, 1.0000, 0.6065, 0.7788]
exp_x (row 3 on) = [0.6065, 0.6065, 0.3679, 1.0000, 0.6065]
exp_x (row 4 mat) = [0.7788, 0.7788, 0.7788, 0.7788, 1.0000]
26return exp_x / np.sum(exp_x, axis=-1, keepdims=True)

Divide each element by its row sum to normalize into probabilities. Each row sums to 1.0 (up to floating-point rounding). All 5 rows:

EXECUTION STATE
axis=-1 = sum along the last axis — sum each row independently, not the entire matrix.
keepdims=True = sum returns shape (5,1) so exp_x (5×5) / sum (5×1) broadcasts correctly — each row divides by its own sum.
── Row 0 (The) ── =
sum(exp_x) = 0.3679+1.0000+0.6065+0.6065+0.7788 = 3.3597
⬆ return = [0.1095, 0.2976, 0.1805, 0.1805, 0.2318]
── Row 1 (cat) ── =
sum(exp_x) = 1.0000+0.2231+0.6065+0.3679+0.2865 = 2.4840
⬆ return = [0.4026, 0.0898, 0.2442, 0.1481, 0.1153]
── Row 2 (sat) ── =
sum(exp_x) = 0.6065+1.0000+1.0000+0.6065+0.7788 = 3.9919
⬆ return = [0.1519, 0.2505, 0.2505, 0.1519, 0.1951]
── Row 3 (on) ── =
sum(exp_x) = 0.6065+0.6065+0.3679+1.0000+0.6065 = 3.1875
⬆ return = [0.1903, 0.1903, 0.1154, 0.3137, 0.1903]
── Row 4 (mat) ── =
sum(exp_x) = 0.7788+0.7788+0.7788+0.7788+1.0000 = 4.1152
⬆ return = [0.1892, 0.1892, 0.1892, 0.1892, 0.2430]
sum check = all 5 rows sum to 1.0000 ✓
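The effect of the max shift is easiest to see on deliberately extreme inputs. The sketch below compares a naive softmax with the shifted one; `naive_softmax`, `stable_softmax`, and the scores of 1000 and above are illustrative assumptions, not part of the chapter's listing.

```python
import numpy as np


def naive_softmax(x: np.ndarray) -> np.ndarray:
    """Softmax without the max shift; overflows for large inputs."""
    exp_x = np.exp(x)
    return exp_x / np.sum(exp_x, axis=-1, keepdims=True)


def stable_softmax(x: np.ndarray) -> np.ndarray:
    """Softmax with the row-wise max subtracted first, as in _softmax."""
    shifted = x - np.max(x, axis=-1, keepdims=True)
    exp_x = np.exp(shifted)
    return exp_x / np.sum(exp_x, axis=-1, keepdims=True)


large_scores = np.array([[1000.0, 1001.0, 1002.0]])    # illustrative, deliberately huge

with np.errstate(over="ignore", invalid="ignore"):
    naive = naive_softmax(large_scores)                # exp overflows to inf, inf / inf gives nan
stable = stable_softmax(large_scores)

print("naive :", naive)
print("stable:", stable)

small_scores = np.array([[0.0, 1.0, 0.5, 0.5, 0.75]])  # row 0 (The) of the scaled scores
print(np.allclose(naive_softmax(small_scores), stable_softmax(small_scores)))
```

On small inputs the two versions agree, so the shift costs nothing in accuracy and only removes the overflow risk.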
28def compute_scores(self, Q, K) → np.ndarray

Function takes self, Q (5×4 query matrix), and K (5×4 key matrix). Returns Q @ K.T — the 5×5 raw dot-product score matrix.

EXECUTION STATE
⬇ input: Q (5×4) =
      d0   d1   d2   d3
The  1.0  0.0  1.0  0.0
cat  0.0  2.0  0.0  1.0
sat  1.0  1.0  1.0  0.0
on   0.0  0.0  1.0  1.0
mat  1.0  0.0  0.0  1.0
⬇ input: K (5×4) =
      d0   d1   d2   d3
The  0.0  1.0  0.0  1.0
cat  1.0  0.0  1.0  0.0
sat  1.0  1.0  0.0  0.0
on   0.0  0.0  1.0  1.0
mat  1.0  0.0  0.5  0.5
⬆ returns = np.ndarray (5, 5) — pairwise scores
30return Q @ K.T

Matrix multiply Q (5×4) with K transposed (4×5). Entry (i,j) = dot product of query_i with key_j. Measures how much token i wants to attend to token j.

EXECUTION STATE
@ = Python matrix multiplication operator — equivalent to np.matmul(Q, K.T)
.T = NumPy transpose — flips rows and columns. K is (5×4), K.T is (4×5). This makes the matrix multiply valid: (5×4) @ (4×5) → (5×5).
K.T (4×5) =
     The   cat   sat    on   mat
d0  0.0   1.0   1.0   0.0   1.0
d1  1.0   0.0   1.0   0.0   0.0
d2  0.0   1.0   0.0   1.0   0.5
d3  1.0   0.0   0.0   1.0   0.5
⬆ return: Q @ K.T (5×5) =
      The   cat   sat    on   mat
The  0.00  2.00  1.00  1.00  1.50
cat  3.00  0.00  2.00  1.00  0.50
sat  1.00  2.00  2.00  1.00  1.50
on   1.00  1.00  0.00  2.00  1.00
mat  1.00  1.00  1.00  1.00  1.50
32def scale_scores(self, scores) → np.ndarray

Function takes self and scores (5×5 raw score matrix). Divides every element by self.scale=√d_k. Returns the scaled matrix.

EXECUTION STATE
⬇ input: scores (5×5) =
      The   cat   sat    on   mat
The  0.00  2.00  1.00  1.00  1.50
cat  3.00  0.00  2.00  1.00  0.50
sat  1.00  2.00  2.00  1.00  1.50
on   1.00  1.00  0.00  2.00  1.00
mat  1.00  1.00  1.00  1.00  1.50
self.scale = 2.0 (precomputed √4 in __init__)
⬆ returns = np.ndarray (5, 5) — each element ÷ 2.0
34return scores / self.scale

Every score divided by 2.0. This halves all values. In GPT-3 (d_k=128), scores get divided by 11.3 — much more compression.

EXECUTION STATE
⬆ return: scores / 2.0 =
       The     cat     sat      on     mat
The  0.000   1.000   0.500   0.500   0.750
cat  1.500   0.000   1.000   0.500   0.250
sat  0.500   1.000   1.000   0.500   0.750
on   0.500   0.500   0.000   1.000   0.500
mat  0.500   0.500   0.500   0.500   0.750
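Why the division matters more at larger d_k can be checked numerically. The sketch below assumes query and key components are independent draws with mean 0 and variance 1 (an idealization, not a property of any trained model); the seed and sample size are arbitrary choices.

```python
import math
import numpy as np

rng = np.random.default_rng(0)        # fixed seed so the run is repeatable
d_k = 128                             # head size quoted above for GPT-3
num_pairs = 20_000                    # illustrative sample size

# Assumption: query/key components are independent with mean 0 and variance 1.
queries = rng.standard_normal((num_pairs, d_k))
keys = rng.standard_normal((num_pairs, d_k))

raw = np.sum(queries * keys, axis=-1)     # one dot product per (query, key) pair
scaled = raw / math.sqrt(d_k)

print(f"variance of raw scores   : {raw.var():.1f}")     # expected to be close to d_k
print(f"variance of scaled scores: {scaled.var():.2f}")  # expected to be close to 1
```

Under that assumption the raw scores spread out in proportion to d_k, and dividing by the square root of d_k brings the spread back to roughly 1 before softmax.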
36def compute_weights(self, scaled_scores) → np.ndarray

Function takes self and scaled_scores (5×5). Applies softmax row-wise via self._softmax(). Returns the attention weight matrix where each row sums to 1.0.

EXECUTION STATE
⬇ input: scaled_scores (5×5) =
       The     cat     sat      on     mat
The  0.000   1.000   0.500   0.500   0.750
cat  1.500   0.000   1.000   0.500   0.250
sat  0.500   1.000   1.000   0.500   0.750
on   0.500   0.500   0.000   1.000   0.500
mat  0.500   0.500   0.500   0.500   0.750
⬆ returns = np.ndarray (5, 5) — each row sums to 1.0
38return self._softmax(scaled_scores)

Calls _softmax on the full 5×5 matrix. Each row independently becomes a probability distribution.

EXECUTION STATE
⬆ return: weights =
       The      cat      sat       on      mat
The  0.1095   0.2976   0.1805   0.1805   0.2318
cat  0.4026   0.0898   0.2442   0.1481   0.1153
sat  0.1519   0.2505   0.2505   0.1519   0.1951
on   0.1903   0.1903   0.1154   0.3137   0.1903
mat  0.1892   0.1892   0.1892   0.1892   0.2430
40def compute_output(self, weights, V) → np.ndarray

Function takes self, weights (5×5 attention probabilities), and V (5×4 value matrix). Returns weights @ V — each output row is the weighted average of all value vectors.

EXECUTION STATE
⬇ input: weights (5×5) =
       The      cat      sat       on      mat
The  0.1095   0.2976   0.1805   0.1805   0.2318
cat  0.4026   0.0898   0.2442   0.1481   0.1153
sat  0.1519   0.2505   0.2505   0.1519   0.1951
on   0.1903   0.1903   0.1154   0.3137   0.1903
mat  0.1892   0.1892   0.1892   0.1892   0.2430
⬇ input: V (5×4) =
      d0   d1   d2   d3
The  1.0  0.0  0.0  0.0
cat  0.0  1.0  0.0  0.0
sat  0.0  0.0  1.0  0.0
on   0.0  0.0  0.0  1.0
mat  0.5  0.5  0.5  0.5
⬆ returns = np.ndarray (5, 4) — weighted sum of values
42return weights @ V

Matrix multiply weights (5×5) with V (5×4). Each output row is a blend of all 5 value vectors, weighted by attention. Tokens with higher weights contribute more.

EXECUTION STATE
⬆ return: weights @ V =
        d0       d1       d2       d3
The  0.2254   0.4135   0.2964   0.2964
cat  0.4602   0.1475   0.3018   0.2058
sat  0.2495   0.3481   0.3481   0.2495
on   0.2854   0.2854   0.2106   0.4089
mat  0.3108   0.3108   0.3108   0.3108
44def forward(self, Q, K, V)

Main entry point. Receives the three matrices and chains all four steps. Returns (weights, output). The annotations for lines 57-61 below show each intermediate.

EXECUTION STATE
⬇ input: Q (5×4) =
      d0   d1   d2   d3
The  1.0  0.0  1.0  0.0
cat  0.0  2.0  0.0  1.0
sat  1.0  1.0  1.0  0.0
on   0.0  0.0  1.0  1.0
mat  1.0  0.0  0.0  1.0
⬇ input: K (5×4) =
      d0   d1   d2   d3
The  0.0  1.0  0.0  1.0
cat  1.0  0.0  1.0  0.0
sat  1.0  1.0  0.0  0.0
on   0.0  0.0  1.0  1.0
mat  1.0  0.0  0.5  0.5
⬇ input: V (5×4) =
      d0   d1   d2   d3
The  1.0  0.0  0.0  0.0
cat  0.0  1.0  0.0  0.0
sat  0.0  0.0  1.0  0.0
on   0.0  0.0  0.0  1.0
mat  0.5  0.5  0.5  0.5
⬆ returns = (weights, output) — shapes (5,5) and (5,4)
57raw_scores = self.compute_scores(Q, K)

Calls compute_scores() → returns Q @ K.T. The 5×5 matrix of all pairwise dot products.

EXECUTION STATE
raw_scores =
      The   cat   sat    on   mat
The  0.00  2.00  1.00  1.00  1.50
cat  3.00  0.00  2.00  1.00  0.50
sat  1.00  2.00  2.00  1.00  1.50
on   1.00  1.00  0.00  2.00  1.00
mat  1.00  1.00  1.00  1.00  1.50
58scaled_scores = self.scale_scores(raw_scores)

Calls scale_scores() → divides every element by 2.0.

EXECUTION STATE
scaled_scores =
       The     cat     sat      on     mat
The  0.000   1.000   0.500   0.500   0.750
cat  1.500   0.000   1.000   0.500   0.250
sat  0.500   1.000   1.000   0.500   0.750
on   0.500   0.500   0.000   1.000   0.500
mat  0.500   0.500   0.500   0.500   0.750
59weights = self.compute_weights(scaled_scores)

Calls compute_weights() → applies softmax row-wise. Each row is now probabilities summing to 1.0.

EXECUTION STATE
weights =
       The      cat      sat       on      mat
The  0.1095   0.2976   0.1805   0.1805   0.2318
cat  0.4026   0.0898   0.2442   0.1481   0.1153
sat  0.1519   0.2505   0.2505   0.1519   0.1951
on   0.1903   0.1903   0.1154   0.3137   0.1903
mat  0.1892   0.1892   0.1892   0.1892   0.2430
60output = self.compute_output(weights, V)

Calls compute_output() → returns weights @ V. Each row is the context-enriched representation.

EXECUTION STATE
output =
        d0       d1       d2       d3
The  0.2254   0.4135   0.2964   0.2964
cat  0.4602   0.1475   0.3018   0.2058
sat  0.2495   0.3481   0.3481   0.2495
on   0.2854   0.2854   0.2106   0.4089
mat  0.3108   0.3108   0.3108   0.3108
61return weights, output

Returns both matrices as a tuple. The caller gets attention weights (5×5) for visualization and the context-enriched output (5×4) for the next layer.

EXECUTION STATE
⬆ return: weights = shape (5, 5)
⬆ return: output = shape (5, 4)
63def explain(self, Q, K, V, tokens, query_idx=0)

Diagnostic function. Takes the same Q, K, V matrices plus token names and which token to trace. Recomputes all intermediates and prints a step-by-step trace. Returns nothing.

EXECUTION STATE
⬇ input: Q (5×4) =
      d0   d1   d2   d3
The  1.0  0.0  1.0  0.0
cat  0.0  2.0  0.0  1.0
sat  1.0  1.0  1.0  0.0
on   0.0  0.0  1.0  1.0
mat  1.0  0.0  0.0  1.0
⬇ input: K (5×4) =
      d0   d1   d2   d3
The  0.0  1.0  0.0  1.0
cat  1.0  0.0  1.0  0.0
sat  1.0  1.0  0.0  0.0
on   0.0  0.0  1.0  1.0
mat  1.0  0.0  0.5  0.5
⬇ input: V (5×4) =
      d0   d1   d2   d3
The  1.0  0.0  0.0  0.0
cat  0.0  1.0  0.0  0.0
sat  0.0  0.0  1.0  0.0
on   0.0  0.0  0.0  1.0
mat  0.5  0.5  0.5  0.5
⬇ input: tokens = ['The', 'cat', 'sat', 'on', 'mat']
⬇ input: query_idx = 0 → will trace token 'The'
⬆ returns = None — prints trace to stdout
69raw_scores = self.compute_scores(Q, K)

Recomputes Q @ K.T (same as forward). Needed because explain() is self-contained.

EXECUTION STATE
raw_scores = shape (5, 5) — same result as in forward()
70scaled_scores = self.scale_scores(raw_scores)

Divide all scores by 2.0.

EXECUTION STATE
scaled_scores = shape (5, 5)
71weights = self.compute_weights(scaled_scores)

Apply softmax row-wise.

EXECUTION STATE
weights = shape (5, 5)
72output = self.compute_output(weights, V)

Weighted sum of value vectors.

EXECUTION STATE
output = shape (5, 4)
74token = tokens[query_idx]

Look up the token name for the index we are tracing.

EXECUTION STATE
tokens[0] = 'The'
token = 'The'
76print Q[query_idx]

Print the query vector for 'The' so the reader sees what this token is 'looking for'.

EXECUTION STATE
Q[0] = [1.0, 0.0, 1.0, 0.0]
78for j, t in enumerate(tokens): — raw scores

Loop over all 5 tokens. For each, print the raw dot product Q[The] · K[token].

LOOP TRACE · 5 iterations
j=0, t='The'
Q[The]·K[The] = 1×0 + 0×1 + 1×0 + 0×1 = 0.0000
j=1, t='cat'
Q[The]·K[cat] = 1×1 + 0×0 + 1×1 + 0×0 = 2.0000
j=2, t='sat'
Q[The]·K[sat] = 1×1 + 0×1 + 1×0 + 0×0 = 1.0000
j=3, t='on'
Q[The]·K[on] = 1×0 + 0×0 + 1×1 + 0×1 = 1.0000
j=4, t='mat'
Q[The]·K[mat] = 1×1 + 0×0 + 1×0.5 + 0×0.5 = 1.5000
82for j, t in enumerate(tokens): — scaled scores

Same loop, now printing scaled scores (raw ÷ 2.0).

LOOP TRACE · 5 iterations
j=0, t='The'
S[The,The] = 0.00 / 2.0 = 0.0000
j=1, t='cat'
S[The,cat] = 2.00 / 2.0 = 1.0000
j=2, t='sat'
S[The,sat] = 1.00 / 2.0 = 0.5000
j=3, t='on'
S[The,on] = 1.00 / 2.0 = 0.5000
j=4, t='mat'
S[The,mat] = 1.50 / 2.0 = 0.7500
86for j, t in enumerate(tokens): — softmax weights

Same loop, now printing attention weights after softmax. '#' bars visualize magnitude.

LOOP TRACE · 5 iterations
j=0, t='The'
A[The,The] = 0.1095 |####|
j=1, t='cat'
A[The,cat] = 0.2976 |###########|
j=2, t='sat'
A[The,sat] = 0.1805 |#######|
j=3, t='on'
A[The,on] = 0.1805 |#######|
j=4, t='mat'
A[The,mat] = 0.2318 |#########|
91print O[token] and sum of weights

Print the final output vector for 'The' and verify weights sum to 1.0.

EXECUTION STATE
O[The] = [0.2254, 0.4135, 0.2964, 0.2964]
sum of weights[0] = 1.000000 ✓
96tokens = [...]

The 5 tokens used in every chapter. 5 tokens gives clean 5×5 attention matrices.

EXECUTION STATE
tokens = ['The', 'cat', 'sat', 'on', 'mat']
98Q = np.array([...])

Query matrix. Each row is what that token 'looks for'. Q[The]=[1,0,1,0] queries dims 0 and 2.

EXECUTION STATE
Q =
      d0   d1   d2   d3
The  1.0  0.0  1.0  0.0
cat  0.0  2.0  0.0  1.0
sat  1.0  1.0  1.0  0.0
on   0.0  0.0  1.0  1.0
mat  1.0  0.0  0.0  1.0
106K = np.array([...])

Key matrix. Each row is what that token 'advertises'. K[cat]=[1,0,1,0] matches Q[The] perfectly.

EXECUTION STATE
K =
      d0   d1   d2   d3
The  0.0  1.0  0.0  1.0
cat  1.0  0.0  1.0  0.0
sat  1.0  1.0  0.0  0.0
on   0.0  0.0  1.0  1.0
mat  1.0  0.0  0.5  0.5
114V = np.array([...])

Value matrix. The actual content retrieved when a token is attended to. V[cat]=[0,1,0,0] — all content in dim 1.

EXECUTION STATE
V =
      d0   d1   d2   d3
The  1.0  0.0  0.0  0.0
cat  0.0  1.0  0.0  0.0
sat  0.0  0.0  1.0  0.0
on   0.0  0.0  0.0  1.0
mat  0.5  0.5  0.5  0.5
123attn = ScaledDotProductAttention(d_k=4)

Instantiate the class. Calls __init__(d_k=4), setting self.d_k=4 and self.scale=2.0.

EXECUTION STATE
attn.d_k = 4
attn.scale = 2.0
124weights, output = attn.forward(Q, K, V)

Runs the full pipeline: compute_scores → scale_scores → compute_weights → compute_output. Returns both matrices.

EXECUTION STATE
weights =
       The      cat      sat       on      mat
The  0.1095   0.2976   0.1805   0.1805   0.2318
cat  0.4026   0.0898   0.2442   0.1481   0.1153
sat  0.1519   0.2505   0.2505   0.1519   0.1951
on   0.1903   0.1903   0.1154   0.3137   0.1903
mat  0.1892   0.1892   0.1892   0.1892   0.2430
output =
        d0       d1       d2       d3
The  0.2254   0.4135   0.2964   0.2964
cat  0.4602   0.1475   0.3018   0.2058
sat  0.2495   0.3481   0.3481   0.2495
on   0.2854   0.2854   0.2106   0.4089
mat  0.3108   0.3108   0.3108   0.3108
133attn.explain(Q, K, V, tokens, query_idx=0)

Print detailed trace for 'The' (token 0). The loop traces for lines 78, 82, and 86 above show each iteration.

EXECUTION STATE
query_idx = 0 → tracing 'The'
1import numpy as np
2import math
3
4class ScaledDotProductAttention:
5    """
6    Scaled Dot-Product Attention (Vaswani et al., 2017)
7
8    Computes: Attention(Q, K, V) = softmax(Q @ K^T / sqrt(d_k)) @ V
9
10    This is the foundational attention mechanism used in every transformer.
11    All 15 mechanisms in this book build upon or modify this core computation.
12    """
13
14    def __init__(self, d_k: int):
15        """
16        Args:
17            d_k: Dimension of query/key vectors (used for scaling)
18        """
19        self.d_k = d_k
20        self.scale = math.sqrt(d_k)
21
22    def _softmax(self, x: np.ndarray) -> np.ndarray:
23        """Numerically stable softmax along last axis."""
24        x_shifted = x - np.max(x, axis=-1, keepdims=True)
25        exp_x = np.exp(x_shifted)
26        return exp_x / np.sum(exp_x, axis=-1, keepdims=True)
27
28    def compute_scores(self, Q: np.ndarray, K: np.ndarray) -> np.ndarray:
29        """Step 1: Raw dot-product scores Q @ K^T."""
30        return Q @ K.T
31
32    def scale_scores(self, scores: np.ndarray) -> np.ndarray:
33        """Step 2: Divide by sqrt(d_k) to control variance."""
34        return scores / self.scale
35
36    def compute_weights(self, scaled_scores: np.ndarray) -> np.ndarray:
37        """Step 3: Apply softmax to get attention weights."""
38        return self._softmax(scaled_scores)
39
40    def compute_output(self, weights: np.ndarray, V: np.ndarray) -> np.ndarray:
41        """Step 4: Weighted sum of value vectors."""
42        return weights @ V
43
44    def forward(self, Q: np.ndarray, K: np.ndarray, V: np.ndarray):
45        """
46        Full forward pass.
47
48        Args:
49            Q: Query matrix  (N, d_k)
50            K: Key matrix    (N, d_k)
51            V: Value matrix  (N, d_v)
52
53        Returns:
54            weights: Attention weight matrix  (N, N)
55            output:  Context-enriched output  (N, d_v)
56        """
57        raw_scores    = self.compute_scores(Q, K)       # (N, N)
58        scaled_scores = self.scale_scores(raw_scores)    # (N, N)
59        weights       = self.compute_weights(scaled_scores)  # (N, N)
60        output        = self.compute_output(weights, V)  # (N, d_v)
61        return weights, output
62
63    def explain(self, Q: np.ndarray, K: np.ndarray, V: np.ndarray,
64                tokens: list, query_idx: int = 0):
65        """
66        Print a detailed trace of the attention computation
67        for a specific query token.
68        """
69        raw_scores    = self.compute_scores(Q, K)
70        scaled_scores = self.scale_scores(raw_scores)
71        weights       = self.compute_weights(scaled_scores)
72        output        = self.compute_output(weights, V)
73
74        token = tokens[query_idx]
75        print(f"\n=== Attention trace for '{token}' (row {query_idx}) ===")
76        print(f"Q[{query_idx}] = {Q[query_idx]}")
77        print(f"\nStep 1 - Raw dot products (Q @ K^T):")
78        for j, t in enumerate(tokens):
79            print(f"  Q[{token}] . K[{t}] = {raw_scores[query_idx, j]:.4f}")
80
81        print(f"\nStep 2 - Scaled (/ sqrt({self.d_k}) = / {self.scale:.1f}):")
82        for j, t in enumerate(tokens):
83            print(f"  S[{token},{t}] = {scaled_scores[query_idx, j]:.4f}")
84
85        print(f"\nStep 3 - Softmax weights:")
86        for j, t in enumerate(tokens):
87            bar = '#' * int(weights[query_idx, j] * 40)
88            print(f"  A[{token},{t}] = {weights[query_idx, j]:.4f} |{bar}|")
89
90        print(f"\nStep 4 - Output (weighted sum of V):")
91        print(f"  O[{token}] = {output[query_idx]}")
92        print(f"  Sum of weights = {weights[query_idx].sum():.6f}")
93
94
95# ── Shared Example (used in every chapter) ──
96tokens = ["The", "cat", "sat", "on", "mat"]
97
98Q = np.array([
99    [1.0, 0.0, 1.0, 0.0],   # The
100    [0.0, 2.0, 0.0, 1.0],   # cat
101    [1.0, 1.0, 1.0, 0.0],   # sat
102    [0.0, 0.0, 1.0, 1.0],   # on
103    [1.0, 0.0, 0.0, 1.0],   # mat
104])
105
106K = np.array([
107    [0.0, 1.0, 0.0, 1.0],   # The
108    [1.0, 0.0, 1.0, 0.0],   # cat
109    [1.0, 1.0, 0.0, 0.0],   # sat
110    [0.0, 0.0, 1.0, 1.0],   # on
111    [1.0, 0.0, 0.5, 0.5],   # mat
112])
113
114V = np.array([
115    [1.0, 0.0, 0.0, 0.0],   # The
116    [0.0, 1.0, 0.0, 0.0],   # cat
117    [0.0, 0.0, 1.0, 0.0],   # sat
118    [0.0, 0.0, 0.0, 1.0],   # on
119    [0.5, 0.5, 0.5, 0.5],   # mat
120])
121
122# ── Run ──
123attn = ScaledDotProductAttention(d_k=4)
124weights, output = attn.forward(Q, K, V)
125
126print("Attention Weight Matrix (5x5):")
127print(np.round(weights, 4))
128
129print("\nOutput Matrix (5x4):")
130print(np.round(output, 4))
131
132# Detailed trace for "The" (token 0)
133attn.explain(Q, K, V, tokens, query_idx=0)
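As a cross-check on the vectorized class, the same four steps can be written with explicit loops and compared against the matrix form. This is a reference sketch only; `attention_loops` is a hypothetical helper, and the matrices are the shared example copied from the listing above.

```python
import math
import numpy as np


def attention_loops(Q, K, V):
    """Scaled dot-product attention written with explicit loops (reference only)."""
    n, d_k = Q.shape
    d_v = V.shape[1]
    weights = np.zeros((n, n))
    output = np.zeros((n, d_v))
    for i in range(n):
        # Steps 1-2: scaled score of query i against every key j.
        scores = [sum(Q[i, c] * K[j, c] for c in range(d_k)) / math.sqrt(d_k)
                  for j in range(n)]
        # Step 3: stable softmax over the row.
        row_max = max(scores)
        exps = [math.exp(s - row_max) for s in scores]
        total = sum(exps)
        for j in range(n):
            weights[i, j] = exps[j] / total
            # Step 4: accumulate the weighted value vector.
            output[i] += weights[i, j] * V[j]
    return weights, output


Q = np.array([[1, 0, 1, 0], [0, 2, 0, 1], [1, 1, 1, 0], [0, 0, 1, 1], [1, 0, 0, 1]], dtype=float)
K = np.array([[0, 1, 0, 1], [1, 0, 1, 0], [1, 1, 0, 0], [0, 0, 1, 1], [1, 0, 0.5, 0.5]], dtype=float)
V = np.array([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1], [0.5, 0.5, 0.5, 0.5]], dtype=float)

loop_weights, loop_output = attention_loops(Q, K, V)

# Vectorized version of the same four steps, for comparison.
scaled = (Q @ K.T) / math.sqrt(Q.shape[1])
exp_scaled = np.exp(scaled - scaled.max(axis=-1, keepdims=True))
vec_weights = exp_scaled / exp_scaled.sum(axis=-1, keepdims=True)
vec_output = vec_weights @ V

print(np.allclose(loop_weights, vec_weights), np.allclose(loop_output, vec_output))
print(np.round(loop_weights[0], 4))
```

The loop version makes the three nested levels of work visible (query, key, feature), which is the same $N \times N \times d_k$ structure counted in the complexity table.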

PyTorch Implementation

The NumPy version above is ideal for understanding every arithmetic step. But in practice, you will use PyTorch — it gives you GPU acceleration, automatic differentiation, and a built-in optimized implementation. Below is the same scaled dot-product attention as an nn.Module, using the identical shared example.

Three things change from NumPy to PyTorch:

  1. Tensors replace arrays. A torch.tensor can record a computation graph for gradient computation when requires_grad=True. An np.array cannot.
  2. Batching is built in. torch.matmul and K.transpose(-2, -1) work for any number of leading batch dimensions. NumPy's @ also broadcasts over leading dimensions, but K.T reverses all axes, so it gives the intended transpose only for 2D matrices.
  3. Masking support. The PyTorch version accepts an optional mask tensor that sets masked positions to $-\infty$ before softmax, so $e^{-\infty} = 0$ and those tokens receive zero attention. This is how causal masking (Chapter 3) and padding masks are implemented.
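The batching point in item 2 can be checked on the NumPy side. In the sketch below, `.T` on a 4D array reverses every axis, while `np.swapaxes(K, -2, -1)` plays the role of `K.transpose(-2, -1)`. The batch and head sizes are arbitrary illustrative values, and random data stands in for real activations.

```python
import numpy as np

rng = np.random.default_rng(0)
batch, heads, n, d_k = 2, 3, 5, 4          # illustrative shapes, not from the shared example
Q = rng.standard_normal((batch, heads, n, d_k))
K = rng.standard_normal((batch, heads, n, d_k))

print(K.T.shape)                      # every axis reversed: (4, 5, 3, 2)
K_t = np.swapaxes(K, -2, -1)          # only the last two axes swapped: (2, 3, 4, 5)
print(K_t.shape)

scores = Q @ K_t                      # matmul broadcasts over the leading (batch, heads) dims
print(scores.shape)                   # (2, 3, 5, 5)

# Each (batch, head) slice equals the plain 2-D computation.
print(np.allclose(scores[1, 2], Q[1, 2] @ K[1, 2].T))
```

The single-sequence code in this chapter uses `K.T` safely because its inputs are always 2D.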
Scaled Dot-Product Attention — PyTorch Implementation
🐍scaled_dot_product_attention_torch.py
1Import PyTorch

torch is the core tensor library. torch.nn provides neural network building blocks (nn.Module). torch.nn.functional provides stateless operations like softmax. math is used for sqrt.

EXAMPLE
import torch; x = torch.tensor([1.0, 2.0])
6nn.Module subclass

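By subclassing nn.Module, our attention class plugs into the rest of PyTorch: it can be nested inside larger models, moved between devices with .to() or .cuda(), and called like a function. This particular module has no learnable parameters (the scale is a constant), but gradients still flow through its operations via autograd, so it can sit inside a trainable model. The NumPy version is a plain class with no autograd support.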

20Function: __init__(self, d_k: int)

Constructor. Receives d_k and precomputes the scale factor. nn.Module requires super().__init__() to register the module. self.scale is a plain float — it's not a learned parameter.

EXAMPLE
attn = ScaledDotProductAttention(d_k=64).cuda()  # ready for GPU
EXECUTION STATE
⬇ input: d_k = 4
Line 20 → super().__init__() = registers nn.Module
Line 21 → self.d_k = d_k = 4
Line 22 → self.scale = √d_k = √4 = 2.0
24Function: forward(self, Q, K, V, mask) → (weights, output)

The main entry point. Receives Q, K, V tensors and an optional mask. Chains 5 steps: matmul → scale → mask → softmax → weighted sum. Returns both attention weights and context-enriched output. The annotations for lines 44, 47, 50, 54, and 57 show each intermediate value.

EXECUTION STATE
⬇ input: Q = torch.Size([5, 4]) — query tensor
⬇ input: K = torch.Size([5, 4]) — key tensor
⬇ input: V = torch.Size([5, 4]) — value tensor
⬇ input: mask = None (no masking)
Line 44 → scores = torch.Size([5, 5]) via matmul(Q, K.T)
Line 47 → scores /= scale = torch.Size([5, 5]) — all values halved
Line 50 → mask = skipped (mask is None)
Line 54 → weights = torch.Size([5, 5]) via F.softmax
Line 57 → output = torch.Size([5, 4]) via matmul(weights, V)
⬆ return: (weights, output) = ([5,5], [5,4]) — both tensors
44Step 1: torch.matmul

torch.matmul is the PyTorch equivalent of NumPy's @ operator (np.matmul). K.transpose(-2, -1) swaps the last two dimensions, which works for any number of batch dims. Compare with NumPy: K.T reverses every axis, so Q @ K.T is correct only for 2D arrays; the batched NumPy equivalent is np.swapaxes(K, -2, -1).

EXAMPLE
torch.matmul(Q, K.transpose(-2, -1))  # same as Q @ K.mT in PyTorch 2.0+
EXECUTION STATE
torch.matmul() = PyTorch matrix multiply; broadcasts over leading batch dimensions, as NumPy's @ does
K.transpose(-2, -1) = swap the last two dimensions of K. -2 = second-to-last axis, -1 = last axis. For 2D: same as K.T. For 4D (batch,heads,N,d_k): transposes only the inner (N,d_k) → (d_k,N), leaving batch/head dims untouched.
K.transpose(-2, -1).shape = torch.Size([4, 5])
scores =
      The   cat   sat    on   mat
The  0.00  2.00  1.00  1.00  1.50
cat  3.00  0.00  2.00  1.00  0.50
sat  1.00  2.00  2.00  1.00  1.50
on   1.00  1.00  0.00  2.00  1.00
mat  1.00  1.00  1.00  1.00  1.50
47Step 2: Scale

Identical to NumPy version. PyTorch broadcasts the scalar division across all elements. Gradients flow through this operation automatically via autograd.

EXECUTION STATE
self.scale = 2.0
scores (after scaling) =
       The     cat     sat      on     mat
The  0.000   1.000   0.500   0.500   0.750
cat  1.500   0.000   1.000   0.500   0.250
sat  0.500   1.000   1.000   0.500   0.750
on   0.500   0.500   0.000   1.000   0.500
mat  0.500   0.500   0.500   0.500   0.750
50Step 3: Masking

masked_fill sets masked positions to -inf BEFORE softmax, so exp(-inf) = 0 and those positions get zero attention weight. This is how causal masking (Chapter 3) and padding masks work. The NumPy class above does not implement masking; the PyTorch version adds it as an optional argument.

EXAMPLE
# Causal mask: mask future tokens
mask = torch.triu(torch.ones(5, 5, dtype=torch.bool), diagonal=1)
EXECUTION STATE
scores.masked_fill(mask, val) = everywhere mask is True, replace score with val. Does NOT modify the original tensor — returns a new one.
float("-inf") = -∞. After softmax, exp(-∞) = 0, so masked positions get zero attention weight. This is why we use -inf and not just a large negative number. Caveat: a row in which every position is masked produces NaN, because softmax then divides 0 by 0.
mask = None → this line is skipped entirely (no masking in this example)
scores (unchanged) = same as Step 2 output
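Although the NumPy class above has no mask argument, the same masking idea can be sketched in NumPy, with `np.where` standing in for `masked_fill`. The scaled scores are copied from the shared example; the causal mask mirrors the `torch.triu` example above, and the variable names are illustrative.

```python
import numpy as np

scaled_scores = np.array([            # scaled scores from the shared example
    [0.00, 1.00, 0.50, 0.50, 0.75],
    [1.50, 0.00, 1.00, 0.50, 0.25],
    [0.50, 1.00, 1.00, 0.50, 0.75],
    [0.50, 0.50, 0.00, 1.00, 0.50],
    [0.50, 0.50, 0.50, 0.50, 0.75],
])

n = scaled_scores.shape[0]
# True above the diagonal marks "future" positions that must be hidden.
causal_mask = np.triu(np.ones((n, n), dtype=bool), k=1)

masked = np.where(causal_mask, -np.inf, scaled_scores)    # NumPy analogue of masked_fill

shifted = masked - masked.max(axis=-1, keepdims=True)
exp_scores = np.exp(shifted)                              # exp(-inf) evaluates to 0.0
weights = exp_scores / exp_scores.sum(axis=-1, keepdims=True)

print(np.round(weights, 4))
```

With a causal mask the diagonal is never hidden, so every row keeps at least one finite score and the softmax stays well defined. The first token can only attend to itself, so its row becomes a single 1.0.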
54Step 4: F.softmax

F.softmax is numerically stable by default (it subtracts the max internally). Compare with our NumPy version where we had to implement the max-subtraction trick manually.

EXECUTION STATE
F.softmax() = PyTorch's built-in softmax — numerically stable (subtracts max internally), unlike raw exp()/sum().
dim=-1 = normalize along the LAST dimension (columns within each row). Each row independently becomes a probability distribution summing to 1.0. Same as axis=-1 in NumPy.
weights =
       The      cat      sat       on      mat
The  0.1095   0.2976   0.1805   0.1805   0.2318
cat  0.4026   0.0898   0.2442   0.1481   0.1153
sat  0.1519   0.2505   0.2505   0.1519   0.1951
on   0.1903   0.1903   0.1154   0.3137   0.1903
mat  0.1892   0.1892   0.1892   0.1892   0.2430
row sums = [1.0, 1.0, 1.0, 1.0, 1.0] ✓
57Step 5: Weighted sum

Matrix multiply weights (N, S) with V (S, d_v) to get output (N, d_v), where S is the number of key/value positions (here S = N = 5). Identical math to the NumPy version. The gradient of this operation with respect to weights and V is computed automatically by autograd.

EXECUTION STATE
output =
        d0       d1       d2       d3
The  0.2254   0.4135   0.2964   0.2964
cat  0.4602   0.1475   0.3018   0.2058
sat  0.2495   0.3481   0.3481   0.2495
on   0.2854   0.2854   0.2106   0.4089
mat  0.3108   0.3108   0.3108   0.3108
64torch.tensor vs np.array

torch.tensor creates a PyTorch tensor. Unlike np.array, tensors can record computation graphs for automatic differentiation (when requires_grad=True) and can be moved to GPU with .cuda() or .to(device). Both libraries support broadcasting across batch dimensions.

EXECUTION STATE
Q =
      d0   d1   d2   d3
The  1.0  0.0  1.0  0.0
cat  0.0  2.0  0.0  1.0
sat  1.0  1.0  1.0  0.0
on   0.0  0.0  1.0  1.0
mat  1.0  0.0  0.0  1.0
Q.requires_grad = False (no gradient tracking for input data)
91Calling the module

attn(Q, K, V) calls attn.forward(Q, K, V) under the hood: nn.Module's __call__ wraps forward() so that any registered hooks run. Prefer calling the module over calling .forward() directly, which skips those hooks.

EXAMPLE
# Wrong: attn.forward(Q, K, V)
# Right: attn(Q, K, V)
EXECUTION STATE
d_k = 4
scale = 2.0
weights (5×5) =
       The      cat      sat       on      mat
The  0.1095   0.2976   0.1805   0.1805   0.2318
cat  0.4026   0.0898   0.2442   0.1481   0.1153
sat  0.1519   0.2505   0.2505   0.1519   0.1951
on   0.1903   0.1903   0.1154   0.3137   0.1903
mat  0.1892   0.1892   0.1892   0.1892   0.2430
output (5×4) =
        d0       d1       d2       d3
The  0.2254   0.4135   0.2964   0.2964
cat  0.4602   0.1475   0.3018   0.2058
sat  0.2495   0.3481   0.3481   0.2495
on   0.2854   0.2854   0.2106   0.4089
mat  0.3108   0.3108   0.3108   0.3108
98PyTorch built-in SDPA

Since PyTorch 2.0, F.scaled_dot_product_attention is a built-in that automatically selects the fastest backend: FlashAttention, Memory-Efficient Attention, or math fallback. It accepts 3D or higher input; the common production shape is 4D: (batch, heads, seq_len, d_k). Our manual implementation produces identical results.

EXECUTION STATE
Q_b.shape = torch.Size([1, 1, 5, 4])
builtin_out.shape = torch.Size([5, 4]) (after squeeze)
torch.allclose(output, builtin_out) = True ✓
Lines 102-104: unsqueeze for batch/head dims

unsqueeze(0) adds a dimension at position 0. We add two: one for batch (size 1) and one for heads (size 1). This converts (5, 4) → (1, 1, 5, 4), the standard 4D multi-head shape.

EXECUTION STATE
.unsqueeze(0) = insert a new dimension of size 1 at position 0. The 0 means 'before the first existing dimension'.
Q.shape = torch.Size([5, 4])
after 1st unsqueeze(0) = torch.Size([1, 5, 4]) — added batch dim
after 2nd unsqueeze(0) = torch.Size([1, 1, 5, 4]) — added head dim
meaning of (1,1,5,4) = 1 batch × 1 head × 5 tokens × 4 dims
Lines 112-118: GPU acceleration

Moving tensors to the GPU is one line: .cuda(). The same attention code runs on CPU or GPU without any changes. For our tiny 5×4 example the transfer and kernel-launch overhead outweighs any gain, but at production scale (for example seq_len=8192, d_k=128) a GPU is typically far faster than a CPU; the exact speedup depends on the hardware.

EXECUTION STATE
torch.cuda.is_available() = True/False (depends on hardware)
GPU output matches CPU? = True ✓ (identical computation)
Complete listing:

import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class ScaledDotProductAttention(nn.Module):
    """
    Scaled Dot-Product Attention (Vaswani et al., 2017), PyTorch version

    Computes: Attention(Q, K, V) = softmax(Q @ K^T / sqrt(d_k)) @ V

    This nn.Module version supports:
      - GPU acceleration (just move tensors to CUDA)
      - Automatic differentiation (gradients flow through)
      - Batched inputs (batch_size, N, d_k)
      - Optional attention mask for causal / padding masks
    """

    def __init__(self, d_k: int):
        super().__init__()
        self.d_k = d_k
        self.scale = math.sqrt(d_k)

    def forward(
        self,
        Q: torch.Tensor,
        K: torch.Tensor,
        V: torch.Tensor,
        mask: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            Q: Query tensor   (..., N, d_k)
            K: Key tensor     (..., S, d_k)
            V: Value tensor   (..., S, d_v)
            mask: Optional boolean mask (..., N, S)
                  True = position is MASKED (ignored)

        Returns:
            weights: Attention weights  (..., N, S)
            output:  Contextualized     (..., N, d_v)
        """
        # Step 1: Raw scores (matrix multiply Q with K transposed)
        scores = torch.matmul(Q, K.transpose(-2, -1))   # (..., N, S)

        # Step 2: Scale (divide by sqrt(d_k))
        scores = scores / self.scale

        # Step 3 (optional): Apply mask before softmax
        if mask is not None:
            scores = scores.masked_fill(mask, float("-inf"))

        # Step 4: Softmax (normalize each row to probabilities)
        weights = F.softmax(scores, dim=-1)              # (..., N, S)

        # Step 5: Weighted sum of value vectors
        output = torch.matmul(weights, V)                # (..., N, d_v)

        return weights, output


# ── Shared Example (same matrices as NumPy version) ──
tokens = ["The", "cat", "sat", "on", "mat"]

Q = torch.tensor([
    [1.0, 0.0, 1.0, 0.0],   # The
    [0.0, 2.0, 0.0, 1.0],   # cat
    [1.0, 1.0, 1.0, 0.0],   # sat
    [0.0, 0.0, 1.0, 1.0],   # on
    [1.0, 0.0, 0.0, 1.0],   # mat
])

K = torch.tensor([
    [0.0, 1.0, 0.0, 1.0],   # The
    [1.0, 0.0, 1.0, 0.0],   # cat
    [1.0, 1.0, 0.0, 0.0],   # sat
    [0.0, 0.0, 1.0, 1.0],   # on
    [1.0, 0.0, 0.5, 0.5],   # mat
])

V = torch.tensor([
    [1.0, 0.0, 0.0, 0.0],   # The
    [0.0, 1.0, 0.0, 0.0],   # cat
    [0.0, 0.0, 1.0, 0.0],   # sat
    [0.0, 0.0, 0.0, 1.0],   # on
    [0.5, 0.5, 0.5, 0.5],   # mat
])

# ── Run ──
attn = ScaledDotProductAttention(d_k=4)
weights, output = attn(Q, K, V)

print("Attention Weight Matrix (5x5):")
print(weights.round(decimals=4))

print("\nOutput Matrix (5x4):")
print(output.round(decimals=4))

# ── Verify: Use PyTorch's built-in SDPA ──
with torch.no_grad():
    # F.scaled_dot_product_attention expects (batch, heads, N, d_k)
    Q_b = Q.unsqueeze(0).unsqueeze(0)   # (1, 1, 5, 4)
    K_b = K.unsqueeze(0).unsqueeze(0)
    V_b = V.unsqueeze(0).unsqueeze(0)

    builtin_out = F.scaled_dot_product_attention(Q_b, K_b, V_b)
    builtin_out = builtin_out.squeeze(0).squeeze(0)  # back to (5, 4)

print("\nBuilt-in SDPA output matches ours?",
      torch.allclose(output, builtin_out, atol=1e-4))

# ── GPU: just move tensors ──
if torch.cuda.is_available():
    Q_gpu, K_gpu, V_gpu = Q.cuda(), K.cuda(), V.cuda()
    attn_gpu = attn.cuda()
    w_gpu, o_gpu = attn_gpu(Q_gpu, K_gpu, V_gpu)
    print("\nGPU output matches CPU?",
          torch.allclose(output, o_gpu.cpu(), atol=1e-4))
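
The GPU step at the end of the listing calls .cuda() inside an availability check. A common alternative, sketched here as an illustration rather than as the chapter's method, is to pick a device once and move everything with .to(device), so the same script runs unchanged on machines with or without a GPU. The helper name attention and the two-token slice of the shared example (the "The" and "cat" rows) are choices made for this sketch.

```python
import math

import torch
import torch.nn.functional as F

# Use the GPU when one is present, otherwise stay on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# First two rows ("The", "cat") of the shared example, to keep the sketch short.
Q = torch.tensor([[1.0, 0.0, 1.0, 0.0],
                  [0.0, 2.0, 0.0, 1.0]])
K = torch.tensor([[0.0, 1.0, 0.0, 1.0],
                  [1.0, 0.0, 1.0, 0.0]])
V = torch.tensor([[1.0, 0.0, 0.0, 0.0],
                  [0.0, 1.0, 0.0, 0.0]])


def attention(Q, K, V):
    """softmax(Q K^T / sqrt(d_k)) V on whatever device the inputs live on."""
    scores = Q @ K.transpose(-2, -1) / math.sqrt(Q.shape[-1])
    return F.softmax(scores, dim=-1) @ V


out_cpu = attention(Q, K, V)
out_dev = attention(Q.to(device), K.to(device), V.to(device))

print(out_dev.device.type)  # "cuda" if a GPU was found, otherwise "cpu"
print(torch.allclose(out_cpu, out_dev.cpu(), atol=1e-4))
```

If you time the GPU path, keep in mind that CUDA kernels launch asynchronously, so call torch.cuda.synchronize() before reading the clock.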

PyTorch Built-in: F.scaled_dot_product_attention

Since PyTorch 2.0, there is a single built-in function that computes exactly what our class does — but it automatically selects the fastest available backend:

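| Backend | When used | Speedup (approximate) |
|---|---|---|
| FlashAttention (FlashAttention-2 from PyTorch 2.2) | CUDA GPU, no custom attn_mask (is_causal is fine), fp16/bf16 | 2-4x |
| Memory-Efficient | CUDA GPU with masks, fp32/fp16/bf16 | 1.5-3x |
| Math fallback | CPU, or when the above backends are unavailable | 1x (baseline) |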

The call is a single line. The annotations below give the exact tensor shapes and output values, including what happens when you enable causal masking:

Built-in SDPA — One Line, Same Result
🐍builtin_sdpa.py
Line 1: Import functional API

torch.nn.functional contains stateless operations — functions that don't hold parameters. F.scaled_dot_product_attention is the built-in SDPA that auto-selects the fastest backend.

Line 4: Reshape Q for built-in API

unsqueeze(0) adds a dimension at position 0. Two unsqueezes convert (5, 4) → (1, 1, 5, 4), adding batch and head dimensions. The built-in also accepts 3D input directly.

EXECUTION STATE
Q.shape = torch.Size([5, 4])
Q.unsqueeze(0).shape = torch.Size([1, 5, 4])
Q_b.shape = torch.Size([1, 1, 5, 4])
Line 5: Reshape K

Same unsqueeze operation for K. All three tensors must have matching batch and head dimensions.

EXECUTION STATE
K_b.shape = torch.Size([1, 1, 5, 4])
Line 6: Reshape V

V gets the same treatment. Now all three are (1, 1, 5, 4) — 1 batch, 1 head, 5 tokens, 4 dims.

EXECUTION STATE
V_b.shape = torch.Size([1, 1, 5, 4])
Line 8: Built-in SDPA call

One line replaces the entire manual class. PyTorch auto-selects the backend: FlashAttention-2 on CUDA with fp16/bf16, Memory-Efficient for masked inputs, or math fallback on CPU. The output is mathematically identical to our manual implementation.

EXECUTION STATE
output.shape = torch.Size([1, 1, 5, 4])
output.squeeze() =
        d0       d1       d2       d3
The  0.2254   0.4135   0.2964   0.2964
cat  0.4602   0.1475   0.3018   0.2058
sat  0.2495   0.3481   0.3481   0.2495
on   0.2854   0.2854   0.2106   0.4089
mat  0.3108   0.3108   0.3108   0.3108
matches manual? = True ✓ (torch.allclose, atol=1e-4)
Lines 12-14: Causal (masked) SDPA

is_causal=True masks every position above the diagonal, so each token can attend only to itself and earlier tokens. Future tokens get -inf before softmax and therefore zero weight. This is the masking used by GPT-style autoregressive decoders.

EXECUTION STATE
is_causal=True = tells PyTorch to apply the causal mask internally, so you don't need to build it yourself. Note the convention: a boolean attn_mask passed to the built-in uses True = may attend (a lower-triangular matrix for causal attention), the opposite of our manual class, where True = masked.
causal mask =
     0     1     2     3     4
0  attn   -∞    -∞    -∞    -∞
1  attn  attn   -∞    -∞    -∞
2  attn  attn  attn   -∞    -∞
3  attn  attn  attn  attn   -∞
4  attn  attn  attn  attn  attn
causal weights =
       The      cat      sat       on      mat
The  1.0000   0.0000   0.0000   0.0000   0.0000
cat  0.8176   0.1824   0.0000   0.0000   0.0000
sat  0.2327   0.3837   0.3837   0.0000   0.0000
on   0.2350   0.2350   0.1425   0.3875   0.0000
mat  0.1892   0.1892   0.1892   0.1892   0.2430
output_causal.squeeze() =
        d0       d1       d2       d3
The  1.0000   0.0000   0.0000   0.0000
cat  0.8176   0.1824   0.0000   0.0000
sat  0.2327   0.3837   0.3837   0.0000
on   0.2350   0.2350   0.1425   0.3875
mat  0.3108   0.3108   0.3108   0.3108
Complete listing (builtin_sdpa.py):

import torch.nn.functional as F

# Common production shape is (batch, heads, seq_len, d_k); 3D also works
Q_b = Q.unsqueeze(0).unsqueeze(0)   # (1, 1, 5, 4)
K_b = K.unsqueeze(0).unsqueeze(0)
V_b = V.unsqueeze(0).unsqueeze(0)

output = F.scaled_dot_product_attention(Q_b, K_b, V_b)
# output.shape = (1, 1, 5, 4); squeeze to get (5, 4)

# With causal mask (autoregressive):
output_causal = F.scaled_dot_product_attention(
    Q_b, K_b, V_b, is_causal=True
)
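
One detail is easy to trip over when switching between the manual class and the built-in: the two use opposite boolean mask conventions. The following sketch, which assumes PyTorch 2.0 or later and uses illustrative variable names (masked, allowed, out_from_mask, out_from_flag), builds the causal mask both ways on the shared example and checks that all three routes agree.

```python
import math

import torch
import torch.nn.functional as F

# Shared example matrices from this chapter.
Q = torch.tensor([[1.0, 0.0, 1.0, 0.0], [0.0, 2.0, 0.0, 1.0], [1.0, 1.0, 1.0, 0.0],
                  [0.0, 0.0, 1.0, 1.0], [1.0, 0.0, 0.0, 1.0]])
K = torch.tensor([[0.0, 1.0, 0.0, 1.0], [1.0, 0.0, 1.0, 0.0], [1.0, 1.0, 0.0, 0.0],
                  [0.0, 0.0, 1.0, 1.0], [1.0, 0.0, 0.5, 0.5]])
V = torch.tensor([[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0],
                  [0.0, 0.0, 0.0, 1.0], [0.5, 0.5, 0.5, 0.5]])
N = Q.shape[0]

# Manual-class convention: True = position is MASKED (strictly above the diagonal).
masked = torch.ones(N, N).triu(diagonal=1).bool()
scores = Q @ K.T / math.sqrt(Q.shape[-1])
causal_weights = F.softmax(scores.masked_fill(masked, float("-inf")), dim=-1)
manual_out = causal_weights @ V

# Built-in convention: a boolean attn_mask uses True = position MAY be attended to.
allowed = ~masked
Q_b, K_b, V_b = Q[None, None], K[None, None], V[None, None]  # (1, 1, 5, 4)
out_from_mask = F.scaled_dot_product_attention(Q_b, K_b, V_b, attn_mask=allowed)[0, 0]
out_from_flag = F.scaled_dot_product_attention(Q_b, K_b, V_b, is_causal=True)[0, 0]

print(causal_weights[1])  # the "cat" row: nonzero weight only on "The" and "cat"
print(torch.allclose(manual_out, out_from_mask, atol=1e-4))
print(torch.allclose(manual_out, out_from_flag, atol=1e-4))
```

Passing the manual class's mask straight into attn_mask would invert the attention pattern, so negate it (as with allowed above) when moving code from one API to the other.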

When to use which

Use the built-in F.scaled_dot_product_attention in production code — it is faster, memory-efficient, and battle-tested. Use the manual nn.Module class when you need to inspect intermediate values (attention weights), add custom modifications, or learn how the mechanism works.

NumPy vs PyTorch — Side-by-Side

| Aspect | NumPy | PyTorch |
|---|---|---|
| Data type | np.ndarray | torch.Tensor |
| Matrix multiply | Q @ K.T | torch.matmul(Q, K.transpose(-2, -1)) |
| Softmax | Manual (exp, sum, divide) | F.softmax(scores, dim=-1) |
| Gradients | Not supported | Automatic (autograd) |
| GPU | Not supported | .cuda() / .to("cuda") |
| Batching | 2D only (N, d_k) | Any dims (..., N, d_k) |
| Masking | Not built in | masked_fill(mask, -inf) |
| Best for | Learning, debugging, prototyping | Training, inference, production |

Output: identical. Both produce the same attention weights and output matrices.
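
The "2D only" entry for NumPy describes an implementation written with Q @ K.T, not a limit of NumPy itself: .T reverses every axis, while np.swapaxes(K, -1, -2) transposes only the last two. The sketch below is a hypothetical variant (the function name attention_np is an assumption, and this is not the chapter's NumPy class) that makes that single change and reproduces the shared example for both 2D and batched input.

```python
import numpy as np


def attention_np(Q, K, V):
    """softmax(Q K^T / sqrt(d_k)) V for inputs shaped (..., N, d_k)."""
    d_k = Q.shape[-1]
    scores = Q @ np.swapaxes(K, -1, -2) / np.sqrt(d_k)     # (..., N, S)
    scores = scores - scores.max(axis=-1, keepdims=True)  # stabilise exp
    exp = np.exp(scores)
    weights = exp / exp.sum(axis=-1, keepdims=True)
    return weights, weights @ V


Q = np.array([[1.0, 0.0, 1.0, 0.0], [0.0, 2.0, 0.0, 1.0], [1.0, 1.0, 1.0, 0.0],
              [0.0, 0.0, 1.0, 1.0], [1.0, 0.0, 0.0, 1.0]])
K = np.array([[0.0, 1.0, 0.0, 1.0], [1.0, 0.0, 1.0, 0.0], [1.0, 1.0, 0.0, 0.0],
              [0.0, 0.0, 1.0, 1.0], [1.0, 0.0, 0.5, 0.5]])
V = np.array([[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0],
              [0.0, 0.0, 0.0, 1.0], [0.5, 0.5, 0.5, 0.5]])

weights, output = attention_np(Q, K, V)

# A batch of two identical copies, shape (2, 5, 4), goes through the same function.
batch_weights, batch_output = attention_np(
    np.stack([Q, Q]), np.stack([K, K]), np.stack([V, V])
)

print(np.round(weights[0], 4))  # row for "The": [0.1095 0.2976 0.1805 0.1805 0.2318]
print(batch_output.shape)       # (2, 5, 4)
```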

Key Takeaways

  1. The core formula $\text{softmax}(QK^\top / \sqrt{d_k})\,V$ performs four operations: compare (dot product), scale, normalize (softmax), and aggregate (weighted sum).
  2. Scaling by $1/\sqrt{d_k}$ is not arbitrary: for independent query and key components with zero mean and unit variance, it is the exact factor that keeps the variance of the dot products at 1, preventing softmax saturation and vanishing gradients.
  3. Every token talks to every other token in parallel, solving the RNN's sequential bottleneck. The computational cost is $O(N^2 d_k)$, which is the price of all-pairs comparison.
  4. Q, K, V serve distinct roles: Q asks the question, K advertises what's available, V provides the content. The learned projections $W^Q, W^K, W^V$ give the model flexibility to learn task-specific notions of relevance.
  5. This mechanism is the atom from which all 15 variants in this book are built. Multi-head attention runs it in parallel subspaces. Flash Attention computes it with better memory access. Causal masking restricts which keys each query can see. RoPE rotates Q and K to encode position. Understanding this chapter deeply makes every subsequent chapter easier to follow.
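
The variance claim in takeaway 2 can be checked empirically. The sketch below uses only the standard library and assumes query and key components drawn independently from a standard normal distribution; the function name dot_product_variances, the trial count, and the chosen d_k values are illustrative.

```python
import random
import statistics

random.seed(0)  # fixed seed so repeated runs give the same estimates


def dot_product_variances(d_k, trials=2000):
    """Estimate Var(q . k) and Var(q . k / sqrt(d_k)) for i.i.d. N(0, 1) components."""
    raw = []
    for _ in range(trials):
        dot = sum(random.gauss(0.0, 1.0) * random.gauss(0.0, 1.0) for _ in range(d_k))
        raw.append(dot)
    scaled = [s / d_k ** 0.5 for s in raw]
    return statistics.variance(raw), statistics.variance(scaled)


results = {d_k: dot_product_variances(d_k) for d_k in (4, 64, 256)}
for d_k, (raw_var, scaled_var) in results.items():
    # The raw variance should land near d_k, the scaled variance near 1.
    print(f"d_k={d_k:4d}  raw variance ~ {raw_var:7.1f}  scaled variance ~ {scaled_var:.2f}")
```

The unscaled variance grows roughly in proportion to d_k, so the typical score magnitude grows like the square root of d_k; dividing by that square root is what keeps the softmax inputs in a range where the output is not dominated by a single entry.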

Exercises

These exercises reinforce the concepts from this chapter. Work through them by hand first, then verify with the Python class above.

Exercise 1: Compute Attention for “sat” (Row 2)

Using $Q_{\text{sat}} = [1, 1, 1, 0]$ and the same $K$ and $V$ matrices, compute the four steps by hand: raw dot products, scaled scores, softmax weights, and the output vector. Which token or tokens does “sat” attend to most, and why?

Hint: “sat”'s query is active in dims 0, 1, and 2. Look for keys with the highest overlap in those same dimensions.

Exercise 2: What If We Skip Scaling?

Recompute the softmax weights for “The” (row 0) without dividing by $\sqrt{d_k}$. Compare the resulting weight distribution to the scaled version. How much sharper is the unscaled distribution? Now imagine $d_k = 512$: what would happen to the softmax output?

Exercise 3: Identical Q and K

Suppose we set $Q = K$ (every token's query equals its own key). What pattern would you expect in the attention weight matrix? Would every token attend most to itself? Test your prediction by modifying the code: set K = Q.copy() in the NumPy version (or K = Q.clone() in the PyTorch version) and run it.

Exercise 4: Orthogonal Keys

Design a $K$ matrix where every key vector is orthogonal to every other key vector (hint: use the identity matrix for the first 4 tokens). Because $d_k = 4$, at most four nonzero vectors can be mutually orthogonal, so the fifth key has to be the zero vector. What does the attention weight matrix look like? What does this tell you about how the model behaves when keys carry maximally distinct information?

Exercise 5: Scale Factor Derivation

The variance proof assumes $\mathbb{E}[q_m] = 0$ and $\text{Var}(q_m) = 1$. What if the elements have $\text{Var}(q_m) = \sigma^2$ instead of 1? Derive the new variance of the dot product and determine what scaling factor would be needed to normalize it back to 1.


References

  1. Vaswani, A., Shazeer, N., Parmar, N., Uszkoreit, J., Jones, L., Gomez, A.N., Kaiser, L., & Polosukhin, I. (2017). “Attention Is All You Need.” Advances in Neural Information Processing Systems, 30. The paper that introduced the Transformer and scaled dot-product attention.
  2. Bahdanau, D., Cho, K., & Bengio, Y. (2015). “Neural Machine Translation by Jointly Learning to Align and Translate.” ICLR 2015. The first attention mechanism for sequence-to-sequence models (additive attention).
  3. Sutskever, I., Vinyals, O., & Le, Q.V. (2014). “Sequence to Sequence Learning with Neural Networks.” NeurIPS 2014. The seq2seq architecture that exposed the encoder bottleneck problem.
  4. Dao, T., Fu, D.Y., Ermon, S., Rudra, A., & Ré, C. (2022). “FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness.” NeurIPS 2022. IO-aware tiling of the attention computation.
  5. Dosovitskiy, A., Beyer, L., Kolesnikov, A., Weissenborn, D., Zhai, X., Unterthiner, T., Dehghani, M., Minderer, M., Heigold, G., Gelly, S., Uszkoreit, J., & Houlsby, N. (2021). “An Image Is Worth 16x16 Words: Transformers for Image Recognition at Scale.” ICLR 2021. Vision Transformer (ViT).
  6. Jumper, J., Evans, R., Pritzel, A., et al. (2021). “Highly Accurate Protein Structure Prediction with AlphaFold.” Nature, 596, 583-589. Attention applied to protein folding.
  7. Jain, S. & Wallace, B.C. (2019). “Attention is not Explanation.” NAACL 2019. Cautionary analysis of over-interpreting attention weights.