Chapter 16

Final Comparison


Learning Objectives

By the end of this chapter, you will be able to:

  1. Compare the mathematical formulas of all 15 attention mechanisms and explain what each modification does to the standard $\text{softmax}(QK^\top/\sqrt{d_k})\,V$ formula.
  2. Predict how each mechanism will change the output vector for a given token based on its structural constraints (masking, position encoding, sparsity, compression).
  3. Select the right mechanism for a given deployment scenario by reasoning about the trade-offs between quality, memory, speed, and context length.
  4. Implement all 15 mechanisms in a single unified Python/PyTorch class and compare the outputs they produce from the same input, including the cases where two mechanisms agree exactly.
  5. Explain how modern systems like LLaMA 3, Qwen-2, Mistral, and DeepSeek-V2 combine multiple mechanisms (e.g., GQA + RoPE + Flash + Sliding Window) into a single architecture.

The Story: Why Fifteen Mechanisms?

In 2017, Vaswani et al. published "Attention Is All You Need" and introduced a single formula: $\text{Attention}(Q, K, V) = \text{softmax}\!\left(\frac{QK^\top}{\sqrt{d_k}}\right) V$. That one equation launched the transformer revolution. It also had limitations. It costs $O(N^2)$ in sequence length. The formula itself is blind to position, so order had to be injected separately through positional encodings. Every head stores its own keys and values. And every token pair is scored whether or not it is relevant.

Over the next seven years (2017–2024), researchers attacked each of these limitations independently. The result is not a single "best" attention mechanism, but a toolkit of fifteen complementary designs, each making a different trade-off. Understanding all of them — and knowing when to combine them — is what separates an engineer who uses transformers from one who designs them.

The Four Forces of Attention Design

Every attention mechanism navigates four competing forces:

Force | Wants | Costs
Quality | Full pairwise interaction between all N tokens | O(N²) compute and memory
Speed | Sub-quadratic complexity or a hardware-efficient implementation | May approximate or restrict the attention pattern
Memory | Small KV-cache for long-context inference | Sharing or compressing keys/values loses information
Position | Tokens that know their relative or absolute positions | Extra parameters, computation, or inductive bias

No mechanism achieves all four simultaneously. The fifteen chapters of this book represent fifteen different points in this trade-off space.


The Mathematical Landscape

Every mechanism modifies the same base equation. Here are all fifteen formulas in one place, organized by what they change. In every case, $Q, K, V \in \mathbb{R}^{N \times d}$, with $N = 5$ tokens and $d = 4$ dimensions.

Foundation Mechanisms (1–4)

1. Scaled Dot-Product (Vaswani et al., 2017): The baseline. $\text{Output} = \text{softmax}\!\left(\frac{QK^\top}{\sqrt{d}}\right) V$. Every token attends to every other token, at a full $O(N^2 d)$ compute cost.

2. Multi-Head Attention (Vaswani et al., 2017): Split $Q, K, V$ into $H$ heads of dimension $d_h = d/H$. Compute attention independently per head, then concatenate: $\text{Output} = [\text{head}_1; \text{head}_2; \ldots; \text{head}_H]$, where $\text{head}_h = \text{softmax}(Q_h K_h^\top / \sqrt{d_h})\, V_h$. In a trained model, each head has its own learned projections, and the concatenation is followed by an output projection $W_O$. The shared example simply slices the feature dimension.

3. Causal (Masked) Self-Attention (Radford et al., 2018): Add $-\infty$ to future positions: $\text{Output} = \text{softmax}\!\left(\frac{QK^\top}{\sqrt{d}} + M_{\text{causal}}\right) V$, where $M_{ij} = 0$ if $j \leq i$ and $-\infty$ otherwise. Token $i$ can only attend to tokens $0, 1, \ldots, i$.
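The mask can be checked directly on the shared example. The following is a minimal NumPy sketch that assumes the chapter's Q and K matrices. It also confirms that the last token's weights are unchanged by the mask. The helper names `softmax` and `causal_weights` are illustrative and are not taken from the book's listing.

```python
import numpy as np

def softmax(x):
    # Row-wise softmax; subtracting the max keeps exp() from overflowing.
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def causal_weights(Q, K):
    N, d = Q.shape
    scores = Q @ K.T / np.sqrt(d)
    # M_ij = 0 where j <= i and -inf where j > i (strict upper triangle).
    mask = np.triu(np.full((N, N), -np.inf), k=1)
    return softmax(scores + mask)

# Shared example (rows: The, cat, sat, on, mat).
Q = np.array([[1, 0, 1, 0], [0, 2, 0, 1], [1, 1, 1, 0], [0, 0, 1, 1], [1, 0, 0, 1]], dtype=float)
K = np.array([[0, 1, 0, 1], [1, 0, 1, 0], [1, 1, 0, 0], [0, 0, 1, 1], [1, 0, 0.5, 0.5]], dtype=float)

A_causal = causal_weights(Q, K)
A_full = softmax(Q @ K.T / np.sqrt(Q.shape[1]))

print(np.round(A_causal[1], 4))                # "cat": about 0.8176 on The, 0.1824 on cat, 0 elsewhere
print(np.allclose(A_causal[-1], A_full[-1]))   # True: the last token already sees everything
```

The diagonal is never masked, so every row keeps at least one finite score and the softmax is always well defined.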

4. Cross-Attention (Vaswani et al., 2017): Queries come from one sequence (the decoder), while keys and values come from another (the encoder): $\text{Output} = \text{softmax}\!\left(\frac{Q_{\text{dec}} K_{\text{enc}}^\top}{\sqrt{d}}\right) V_{\text{enc}}$. This is the bridge between encoder and decoder.
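In the shared example, both sequences have five tokens. In general, the decoder and encoder lengths differ, so the weight matrix is rectangular. The sketch below uses toy random matrices, which are an assumption and not chapter data, to show the shapes. No causal mask is applied, because each decoder position may see the whole source.

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def cross_attention(Q_dec, K_enc, V_enc):
    # Q_dec: (N_q, d) from the decoder; K_enc, V_enc: (N_k, d) from the encoder.
    d = Q_dec.shape[-1]
    weights = softmax(Q_dec @ K_enc.T / np.sqrt(d))   # (N_q, N_k): rectangular
    return weights, weights @ V_enc                   # one output row per decoder position

rng = np.random.default_rng(0)
Q_dec = rng.standard_normal((3, 4))   # 3 decoder positions (toy values)
K_enc = rng.standard_normal((5, 4))   # 5 encoder positions
V_enc = rng.standard_normal((5, 4))

weights, out = cross_attention(Q_dec, K_enc, V_enc)
print(weights.shape, out.shape)   # (3, 5) (3, 4)
```

The $N_q \times N_k$ weight matrix is where the $O(N_q N_k d)$ cost in the complexity table comes from.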

Efficiency Mechanisms (5–6, 10, 13)

5. Multi-Query Attention (Shazeer, 2019): All $H$ heads share a single K and V: $\text{head}_h = \text{softmax}(Q_h K_{\text{shared}}^\top / \sqrt{d_h})\, V_{\text{shared}}$. The KV-cache shrinks by $H\times$, and quality drops by roughly 1%.

6. Grouped-Query Attention (Ainslie et al., 2023): A compromise between MHA and MQA. Partition the $H$ query heads into $G$ groups. Heads in the same group share one K and one V, so query head $h$ uses the keys and values of group $\lfloor h / (H/G) \rfloor$. When $G = 1$ it reduces to MQA; when $G = H$ it reduces to MHA.
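The group mapping can be written as one function that covers both extremes. The minimal sketch below slices the raw shared-example matrices and uses no learned projections, the same simplification the chapter makes. The function name `grouped_query_attention` is hypothetical. With $G = H$ it reproduces the multi-head row of the results table, and with $G = 1$ it reproduces the multi-query row.

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def grouped_query_attention(Q, K, V, H, G):
    """Slice-based GQA on raw matrices (no learned projections)."""
    N, d = Q.shape
    assert d % H == 0 and H % G == 0
    d_h = d // H
    heads_per_group = H // G
    outputs = []
    for h in range(H):
        g = h // heads_per_group                  # KV group used by query head h
        q = Q[:, h * d_h:(h + 1) * d_h]           # each query head has its own slice
        k = K[:, g * d_h:(g + 1) * d_h]           # keys and values are shared per group
        v = V[:, g * d_h:(g + 1) * d_h]
        w = softmax(q @ k.T / np.sqrt(d_h))
        outputs.append(w @ v)
    return np.hstack(outputs)

Q = np.array([[1, 0, 1, 0], [0, 2, 0, 1], [1, 1, 1, 0], [0, 0, 1, 1], [1, 0, 0, 1]], dtype=float)
K = np.array([[0, 1, 0, 1], [1, 0, 1, 0], [1, 1, 0, 0], [0, 0, 1, 1], [1, 0, 0.5, 0.5]], dtype=float)
V = np.array([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1], [1, 0, 1, 0]], dtype=float)

out_mha = grouped_query_attention(Q, K, V, H=2, G=2)   # G = H: every head has its own K, V
out_mqa = grouped_query_attention(Q, K, V, H=2, G=1)   # G = 1: one K, V for all heads
print(np.round(out_mha[1], 4))   # "cat": matches rows 02 and 06 of the table
print(np.round(out_mqa[1], 4))   # "cat": matches row 05
```

Only $G$ key/value slices need to be cached per token, which is where the $G/H$ memory ratio comes from.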

10. Linear Attention (Katharopoulos et al., 2020): Replace softmax with a feature map $\phi$: $\text{Output}_i = \frac{\phi(Q_i)^\top (\phi(K)^\top V)}{\phi(Q_i)^\top \sum_j \phi(K_j)}$. Matrix multiplication is associative, so $\phi(K)^\top V$ (a $d \times d$ matrix) can be computed once and reused for every query. This removes the $N \times N$ attention matrix entirely. Complexity becomes $O(Nd^2)$ instead of $O(N^2 d)$.
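The reassociation trick is easy to check numerically. The sketch below assumes the $\phi(x) = \text{elu}(x) + 1$ feature map from Katharopoulos et al., which is also what the chapter's `m10_linear` uses. It compares the explicit $N \times N$ form with the linear-time form on toy random data.

```python
import numpy as np

def phi(x):
    # elu(x) + 1: x + 1 for x > 0, exp(x) otherwise; always positive.
    return np.where(x > 0, x + 1.0, np.exp(np.minimum(x, 0.0)))

def linear_attention_quadratic(Q, K, V):
    # Explicit N x N form, kept only for comparison.
    S = phi(Q) @ phi(K).T                          # (N, N)
    return (S / S.sum(axis=1, keepdims=True)) @ V

def linear_attention(Q, K, V):
    # Reassociated form: phi(Q) (phi(K)^T V) never builds an N x N matrix.
    KV = phi(K).T @ V                              # (d, d_v), size independent of N
    z = phi(K).sum(axis=0)                         # (d,), the normalizer
    return (phi(Q) @ KV) / (phi(Q) @ z)[:, None]

rng = np.random.default_rng(0)
Q, K, V = (rng.standard_normal((1000, 4)) for _ in range(3))
print(np.allclose(linear_attention(Q, K, V), linear_attention_quadratic(Q, K, V)))   # True
```

Both functions compute the same sums in a different order, so they agree up to rounding. Only the second one scales linearly in $N$.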

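13. Flash Attention (Dao et al., 2022): Same formula as #1, but computed with IO-aware tiling and an online softmax that never materializes the $N \times N$ attention matrix in GPU HBM. The result is exact (no approximation, identical up to floating-point rounding) and 3–8$\times$ faster, depending on hardware and sequence length.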
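The core of the tiling idea is an online softmax. Keys and values are processed one block at a time, while a running row maximum and denominator are kept and earlier partial sums are rescaled whenever the maximum changes. The NumPy sketch below shows only this math. It tiles keys and values but not queries, and it says nothing about GPU SRAM or HBM traffic, which is where FlashAttention's real speedup comes from. The function names are illustrative.

```python
import numpy as np

def standard_attention(Q, K, V):
    S = Q @ K.T / np.sqrt(Q.shape[1])
    P = np.exp(S - S.max(axis=1, keepdims=True))
    return (P / P.sum(axis=1, keepdims=True)) @ V

def tiled_attention(Q, K, V, block=2):
    """Online softmax over key/value blocks: only an (N, block) tile of
    scores exists at a time."""
    N, d = Q.shape
    scale = 1.0 / np.sqrt(d)
    m = np.full(N, -np.inf)           # running max of each score row
    l = np.zeros(N)                   # running softmax denominator
    acc = np.zeros((N, V.shape[1]))   # running unnormalized output
    for start in range(0, K.shape[0], block):
        Kb, Vb = K[start:start + block], V[start:start + block]
        S = Q @ Kb.T * scale                      # scores for this tile only
        m_new = np.maximum(m, S.max(axis=1))
        correction = np.exp(m - m_new)            # rescales earlier partial sums
        P = np.exp(S - m_new[:, None])
        l = l * correction + P.sum(axis=1)
        acc = acc * correction[:, None] + P @ Vb
        m = m_new
    return acc / l[:, None]

rng = np.random.default_rng(0)
Q, K, V = (rng.standard_normal((7, 4)) for _ in range(3))
print(np.allclose(tiled_attention(Q, K, V, block=3), standard_attention(Q, K, V)))   # True
```

The outputs agree to floating-point tolerance for any block size, which is why Flash Attention is listed as a change in execution rather than in the formula.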

Position Mechanisms (7–9)

7. Relative Position Bias (Shaw et al., 2018; Raffel et al., 2020): Add a bias that depends on relative distance: $\text{Output} = \text{softmax}\!\left(\frac{QK^\top}{\sqrt{d}} + B\right) V$, where $B_{ij} = b(|i - j|)$. In T5 (Raffel et al.), the values of $b$ are learned. The shared example uses the fixed stand-in $b(|i - j|) = -0.5 \cdot |i - j|$.

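8. RoPE (Su et al., 2021): Rotate each query and key by a position-dependent angle: $q'_m = R_{\theta,m}\, q_m$ and $k'_n = R_{\theta,n}\, k_n$, where $R_{\theta,m}$ is a block-diagonal matrix of $2 \times 2$ rotations by angles $m\theta_p$. Because $R_{\theta,m}^\top R_{\theta,n} = R_{\theta,n-m}$, the score $q'_m \cdot k'_n = q_m^\top R_{\theta,n-m}\, k_n$ depends on position only through the offset $n - m$. This encodes relative position without extra parameters.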
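The relative-position property can be verified in a few lines. The sketch below rotates consecutive pairs with $\theta_p = 10000^{-2p/d}$, the same pairing and frequencies as `m08_rope` in the chapter's listing, and applies them to toy random vectors. It checks that two query/key pairs with the same offset but very different absolute positions get the same score.

```python
import numpy as np

def rope(x, pos, base=10000.0):
    """Rotate pairs (x[2p], x[2p+1]) of one vector by angle pos * theta_p."""
    d = x.shape[-1]
    theta = base ** (-2.0 * np.arange(d // 2) / d)
    c, s = np.cos(pos * theta), np.sin(pos * theta)
    out = np.empty_like(x)
    out[0::2] = x[0::2] * c - x[1::2] * s
    out[1::2] = x[0::2] * s + x[1::2] * c
    return out

rng = np.random.default_rng(0)
q, k = rng.standard_normal(8), rng.standard_normal(8)

# Same offset (3) at two different absolute positions gives the same score.
s_near = rope(q, 5) @ rope(k, 2)
s_far = rope(q, 103) @ rope(k, 100)
print(np.isclose(s_near, s_far))   # True
```

Rotations also preserve vector norms, so RoPE changes only the angles between queries and keys, not their magnitudes.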

9. ALiBi (Press et al., 2022): Add a linear distance penalty, $\text{Output} = \text{softmax}\!\left(\frac{QK^\top}{\sqrt{d}} - m \cdot |i - j|\right) V$, where $m$ is a fixed slope per head. There are no learned position parameters and no extra memory. Press et al. report faster training and better generalization to sequences longer than those seen during training.
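The chapter's example uses a single slope of 1.0. For multi-head models, Press et al. give each head a slope from a geometric sequence. The sketch below assumes a power-of-two head count, where the slopes are $2^{-8h/H}$ for $h = 1, \ldots, H$. The paper treats other head counts differently, and that case is not covered here. The function names are illustrative.

```python
import numpy as np

def alibi_slopes(H):
    # Geometric sequence for a power-of-two head count: m_h = 2 ** (-8 h / H), h = 1..H.
    assert H > 0 and H & (H - 1) == 0, "this sketch only handles powers of two"
    return 2.0 ** (-8.0 * np.arange(1, H + 1) / H)

def alibi_bias(N, H):
    # bias[h, i, j] = -m_h * |i - j|, added to head h's scores before the softmax.
    pos = np.arange(N)
    dist = np.abs(pos[:, None] - pos[None, :])
    return -alibi_slopes(H)[:, None, None] * dist[None, :, :]

print(alibi_slopes(8))          # 1/2, 1/4, ..., 1/256
print(alibi_bias(5, 8).shape)   # (8, 5, 5)
```

Heads with steep slopes focus on nearby tokens, and heads with shallow slopes can look further back. Nothing here is learned or cached.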

Sparse Mechanisms (11–12)

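11. Sliding Window Attention (Beltagy et al., 2020): Each token attends only to positions within distance $W$, a window of at most $2W + 1$ tokens: $\alpha_{ij} = 0$ if $|i - j| > W$. Complexity is $O(NWd)$. Mistral 7B uses a causal variant with $W = 4096$, in which each token sees only the preceding tokens inside the window.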
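For the causal variant, the window also bounds the KV-cache. Once a key is older than $W$ positions it can never be attended to again, so a fixed buffer of $W$ slots is enough. The sketch below is a minimal illustration under two assumptions. First, the window covers the current token plus the $W - 1$ before it, and off-by-one conventions differ between implementations. Second, the cache uses a rolling buffer that stores position $t$ in slot $t \bmod W$, in the style of the rolling cache described for Mistral. `RollingKVCache` is a hypothetical name.

```python
import numpy as np

def causal_sliding_window_mask(N, W):
    # Allowed iff 0 <= i - j < W: the current token plus the W - 1 before it.
    i = np.arange(N)[:, None]
    j = np.arange(N)[None, :]
    allowed = (j <= i) & (i - j < W)
    return np.where(allowed, 0.0, -np.inf)

class RollingKVCache:
    """Fixed-size cache of W slots; the entry for position t lives in slot t % W."""

    def __init__(self, W, d):
        self.W = W
        self.k = np.zeros((W, d))
        self.v = np.zeros((W, d))
        self.t = 0                                 # number of tokens seen so far

    def append(self, k, v):
        slot = self.t % self.W
        self.k[slot], self.v[slot] = k, v          # overwrites the oldest entry
        self.t += 1

    def positions(self):
        # Absolute positions currently held, oldest first.
        return list(range(max(0, self.t - self.W), self.t))

mask = causal_sliding_window_mask(N=6, W=3)
print(np.isfinite(mask).sum(axis=1))   # [1 2 3 3 3 3]: at most W keys per query

cache = RollingKVCache(W=3, d=2)
for t in range(5):
    cache.append(np.full(2, float(t)), np.full(2, float(t)))
print(cache.positions())   # [2, 3, 4]: only the last W tokens survive
```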

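12. Sparse Attention (BigBird) (Zaheer et al., 2020): Combines a local window, global tokens, and random connections. Global tokens (e.g., [CLS]) attend to, and are attended by, all positions. With a fixed number of window, global, and random connections per token, the cost grows linearly in $N$. The sparse pattern remains a universal approximator of sequence functions.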
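The linear growth comes from the mask having a bounded number of entries per row, apart from the few global rows. The sketch below builds a token-level BigBird-style pattern. This is a simplification, since the real BigBird works on blocks of tokens, and `bigbird_mask` and its parameters are hypothetical. As $N$ grows, the allowed-entry count grows roughly linearly while the density falls.

```python
import numpy as np

def bigbird_mask(N, W, n_global, n_random, seed=0):
    """Boolean pattern: sliding window + global tokens + random links per row."""
    rng = np.random.default_rng(seed)
    i = np.arange(N)[:, None]
    j = np.arange(N)[None, :]
    allowed = np.abs(i - j) <= W              # local window
    allowed[:n_global, :] = True              # global tokens attend everywhere
    allowed[:, :n_global] = True              # and every token attends to them
    for row in range(N):
        allowed[row, rng.choice(N, size=n_random, replace=False)] = True   # random links
    return allowed

for N in (64, 256, 1024):
    allowed = bigbird_mask(N, W=3, n_global=2, n_random=3)
    # N, number of allowed pairs, fraction of the full N x N matrix
    print(N, int(allowed.sum()), round(float(allowed.mean()), 4))
```

Each non-global row has at most $(2W + 1) + n_{\text{global}} + n_{\text{random}}$ allowed entries, so the total is $O(N)$ for fixed settings.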

Advanced Mechanisms (14–15)

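14. Differential Attention (Microsoft Research, 2024): Split $Q$ and $K$ into two halves, compute two softmax attention maps $A_1$ and $A_2$, and subtract: $(A_1 - \lambda A_2)\, V$. The subtraction cancels attention noise common to both maps and concentrates weight on the relevant tokens. This book uses a simplified variant, $\text{DiffAttn} = \text{renorm}\big(\max(A_1 - \lambda A_2,\, 0)\big)\, V$, which clamps negative weights to zero and renormalizes each row. The original paper keeps the signed difference, learns $\lambda$, and normalizes each head's output instead.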
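The difference between the signed map and the clamped variant is visible on the shared example. The sketch below assumes the chapter's Q and K and a fixed $\lambda = 0.5$ (not learned). Each signed row sums to $1 - \lambda$, and some entries go negative. The clamped, renormalized row is what produces the "cat" output in the results table. The function name is illustrative.

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def differential_weights(Q, K, lam=0.5):
    d_h = Q.shape[1] // 2
    A1 = softmax(Q[:, :d_h] @ K[:, :d_h].T / np.sqrt(d_h))
    A2 = softmax(Q[:, d_h:] @ K[:, d_h:].T / np.sqrt(d_h))
    raw = A1 - lam * A2                               # signed; each row sums to 1 - lam
    clamped = np.maximum(raw, 0.0)                    # this book's variant: drop negatives...
    clamped /= clamped.sum(axis=1, keepdims=True)     # ...then renormalize (row sum > 0 when lam < 1)
    return raw, clamped

Q = np.array([[1, 0, 1, 0], [0, 2, 0, 1], [1, 1, 1, 0], [0, 0, 1, 1], [1, 0, 0, 1]], dtype=float)
K = np.array([[0, 1, 0, 1], [1, 0, 1, 0], [1, 1, 0, 0], [0, 0, 1, 1], [1, 0, 0.5, 0.5]], dtype=float)

raw, clamped = differential_weights(Q, K)
print(np.round(raw[1], 4))       # "cat": negative entries for "on" and "mat"
print(np.round(clamped[1], 4))   # "cat": about 54% of the weight on "sat"
```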

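15. Multi-Head Latent Attention (MLA) (DeepSeek-V2, 2024): Compress keys and values into one shared low-rank latent $c_{KV} = X W_{\text{down}} \in \mathbb{R}^{N \times d_c}$, then reconstruct $K' = c_{KV} W_{UK}$ and $V' = c_{KV} W_{UV}$. Only $c_{KV}$ is cached, so the per-token cache shrinks from $2d$ floats (keys plus values) to $d_c$, a reduction of $2d/d_c$. In DeepSeek-V2, $X$ is the layer's hidden state, and a small decoupled RoPE key is cached alongside the latent. The shared example compresses $K$ as a stand-in for $X$.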
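MLA does not have to rebuild full keys at inference time. Because $q \cdot (c\, W_{UK})^\top = (W_{UK}\, q) \cdot c$ for a single key, the key up-projection can be folded into the query, and scores can be computed directly against the cached latent. The toy sketch below checks this identity with random matrices. The sizes and names (`X`, `W_down`, `W_UK`, `W_UV`) are illustrative assumptions. It models a single head and omits DeepSeek-V2's decoupled RoPE key, which exists because a position-dependent rotation between $W_{UK}$ and the query would break this folding.

```python
import numpy as np

rng = np.random.default_rng(0)
N, d, d_c = 6, 8, 2                      # toy sizes
X = rng.standard_normal((N, d))          # per-token hidden states (toy values)
W_down = rng.standard_normal((d, d_c))   # shared down-projection
W_UK = rng.standard_normal((d_c, d))     # key up-projection
W_UV = rng.standard_normal((d_c, d))     # value up-projection (unused in scoring)
q = rng.standard_normal(d)               # one query vector

c_kv = X @ W_down                        # (N, d_c): the only thing cached

# Naive: reconstruct full keys, then score.
scores_naive = (c_kv @ W_UK) @ q
# Absorbed: fold W_UK into the query once, then score against the latent.
scores_absorbed = c_kv @ (W_UK @ q)

print(np.allclose(scores_naive, scores_absorbed))   # True
print(c_kv.size, 2 * N * d)                         # 12 96: cached floats vs full K and V
```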


The Shared Example: Results Side by Side

Every chapter processed the same five tokens from "The cat sat on the mat" (The, cat, sat, on, mat) with the same $Q$, $K$, $V$ matrices. The table below shows the output each mechanism produced for the token "cat" (row 1). The values are computed exactly and rounded to four decimals.

Method | dim-0 | dim-1 | dim-2 | dim-3 | Key Difference
01. Scaled Dot-Product | 0.5179 | 0.0898 | 0.3595 | 0.1481 | Baseline: full attention
02. Multi-Head (H=2) | 0.4555 | 0.0891 | 0.3241 | 0.2711 | Independent sub-spaces, dim-3 lifted
03. Causal (Masked) | 0.8176 | 0.1824 | 0.0000 | 0.0000 | No future tokens: dims 2, 3 collapse
04. Cross-Attention | 0.4822 | 0.1205 | 0.3534 | 0.1986 | Decoder Q shifts weight
05. Multi-Query (MQA) | 0.4555 | 0.0891 | 0.4291 | 0.1417 | Shared K, V changes dim-2
06. GQA (H=G=2) | 0.4555 | 0.0891 | 0.3241 | 0.2711 | = MHA when H = G
07. Rel. Position Bias | 0.4800 | 0.1597 | 0.3091 | 0.0969 | Distance penalty: dim-3 drops
08. RoPE | 0.3939 | 0.0912 | 0.4989 | 0.1518 | Rotation shifts to dim-2 (sat)
09. ALiBi | 0.4351 | 0.2541 | 0.2703 | 0.0567 | Strong local focus: dim-1 rises
10. Linear | 0.4175 | 0.1748 | 0.3981 | 0.1942 | Flattest: most uniform blend
11. Sliding Window (W=1) | 0.5465 | 0.1220 | 0.3315 | 0.0000 | No on/mat: dim-3 zero
12. Sparse BigBird | 0.5465 | 0.1220 | 0.3315 | 0.0000 | Same as window for cat row
13. Flash Attention | 0.5179 | 0.0898 | 0.3595 | 0.1481 | = #01 (hardware optimization)
14. Differential | 0.4177 | 0.0402 | 0.5421 | 0.0000 | Sharpest: noise cancelled
15. MLA | 0.3726 | 0.6074 | 0.3726 | 0.6074 | Compressed K, V: mirrored output

Six Key Insights from the Comparison

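  1. Flash = Standard (#13 = #01). Flash Attention is purely an implementation optimization. It computes exact attention, so its output matches the standard formula up to floating-point rounding. In our NumPy code, m13 simply calls m01, so the numbers are identical. If you only care about the math, Flash is Chapter 1. The engineering lesson: identical math can have wildly different runtime performance.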
  2. Causal masking dominates (#03). Restricting "cat" to only see "The" and itself changes the output dramatically — dims 2 and 3 collapse to zero because only V[The]V[\text{The}] and V[cat]V[\text{cat}] contribute, and those rows of VV have zeros in dims 2–3. This reveals a fundamental principle: what you prevent a token from seeing matters as much as what you let it see.
  3. Excluded tokens produce exact zeros (#03, #11, #12, #14). Whenever "cat" puts no weight on "on", dim-3 is exactly zero. The causal mask (#03) and the window (#11, #12) exclude "on" outright, and the differential clamp (#14) zeroes its weight. Why does this happen? Because $V[\text{on}] = [0, 0, 0, 1]$ is the only value row with a non-zero dim-3. The output is a direct readout of which tokens were attended to.
  4. Differential attention is sharpest (#14). The noise cancellation puts 54.2% of the weight on "sat" (vs 24.4% in standard attention) and none on "on" or "mat". "Cat" focuses on its verb instead of spreading credit across the sentence. This sharpening is the intuition behind the retrieval gains reported for Differential Attention.
  5. Linear attention is flattest (#10). Without softmax to sharpen the distribution, the weights stay close to uniform: for "cat" they range only from 0.17 to 0.23, against a uniform 0.20. Every output dimension stays between 0.17 and 0.42, an almost-equal blend of all value vectors. Linear attention trades sharpness for $O(Nd^2)$ speed.
  6. MLA creates mirrored outputs (#15). In our example, K′ and V′ are reconstructed from the same compressed latent $c_{KV}$ with the same up-projection matrix, so K′ and V′ are exactly equal. The structure of that projection also makes dims 0 and 2 equal and dims 1 and 3 equal. The bottleneck caps the rank of V′, and hence of the output, at $d_c = 2$. In trained models, separate $W_{UK}$ and $W_{UV}$ break this symmetry.

Interactive: Comparison Dashboard

The dashboard below lets you explore all 15 mechanisms interactively. Select a query token to see how each mechanism distributes attention weight across the five key tokens, and compare the resulting output vectors side by side.

Try this: Switch between tokens and observe how the attention patterns change. Notice that "mat" (the last token) has identical attention weights across Causal (#3) and Standard (#1) — because the last token can already see everything, the causal mask has no effect.

Computational Complexity Comparison

The following table summarizes the computational cost of each mechanism. Here $N$ is the sequence length, $d$ is the model dimension, $H$ is the number of heads, $W$ is the window size, and $d_c$ is the compressed latent dimension.

Mechanism | Time Complexity | Changes What? | Year
Scaled Dot-Product | $O(N^2 d)$ | Baseline | 2017
Multi-Head | $O(N^2 d)$ | Representational capacity | 2017
Causal | $O(N^2 d)$ | Attention pattern (mask) | 2018
Cross-Attention | $O(N_q N_k d)$ | Source of K, V | 2017
MQA | $O(N^2 d)$ | KV-cache (shared K, V) | 2019
GQA | $O(N^2 d)$ | KV-cache ($G$ groups) | 2023
Relative Pos. Bias | $O(N^2 d)$ | Score (additive bias) | 2018
RoPE | $O(N^2 d)$ | Q, K (position-dependent rotation) | 2021
ALiBi | $O(N^2 d)$ | Score (linear penalty) | 2022
Linear | $O(N d^2)$ | Replaces softmax with a feature map | 2020
Sliding Window | $O(NWd)$ | Restricts attention span | 2020
Sparse BigBird | $O(Nd)$ for fixed window, global, and random counts | Sparse attention pattern | 2020
Flash Attention | $O(N^2 d)$ | GPU memory access pattern | 2022
Differential | $O(N^2 d)$ | Noise cancellation | 2024
MLA | $O(N^2 d)$ | KV-cache (compress to $d_c$) | 2024

Memory Footprint Comparison

For inference with long contexts, KV-cache memory is the bottleneck. The table below gives the per-token, per-layer cache cost of each mechanism, assuming $d = 4096$, $H = 32$ heads ($d_h = 128$), and 16-bit (FP16) storage:

Mechanism | Cache per Token (per layer) | Relative to MHA | Savings
MHA (standard) | $2 \times H \times d_h = 2d$ = 8,192 values = 16 KB | 1 | Baseline
MQA | $2 \times d_h$ = 256 values = 0.5 KB | $1/H$ | 32× reduction
GQA (G = 8) | $2 \times G \times d_h$ = 2,048 values = 4 KB | $G/H$ | 4× reduction
MLA | $d_c$ (shared latent) | $d_c / 2d$ | 64× reduction at $d_c = 128$
Sliding Window | $2 \times W \times d$ in total (keys and values of the last $W$ tokens) | $W/N$ | Cache stops growing once $N > W$
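Practical implication: with these settings, a 128K-token context needs 2 GB of KV-cache per layer with MHA. For a single sequence across 32 layers (the depth of LLaMA 3 8B), that is 64 GB. GQA ($G = 8$) reduces this to 16 GB, and MLA with $d_c = 128$ reduces it to about 1 GB. That can be the difference between sharding the cache across several GPUs and fitting it next to the weights on a single GPU.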
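The arithmetic behind these numbers fits in a small calculator. The sketch below assumes one sequence, FP16 (2 bytes per value), a 128K-token context, and 32 layers. The MLA latent size $d_c = 128$ is an illustrative choice, not a DeepSeek-V2 setting. The helper name `kv_cache_bytes` is hypothetical.

```python
def kv_cache_bytes(n_tokens, n_layers, floats_per_token_per_layer, bytes_per_float=2):
    # Total cache = tokens x layers x cached values per token per layer x bytes per value.
    return n_tokens * n_layers * floats_per_token_per_layer * bytes_per_float

d, H = 4096, 32
d_h = d // H                       # 128
per_token = {
    "MHA": 2 * H * d_h,            # keys and values for all 32 heads: 8192 values
    "GQA (G=8)": 2 * 8 * d_h,      # 2048 values
    "MQA": 2 * d_h,                # 256 values
    "MLA (d_c=128)": 128,          # one shared latent per token (illustrative d_c)
}
ctx, layers = 128 * 1024, 32       # assumed: 128K context, 32 layers, one sequence
GiB = 1024 ** 3
for name, floats in per_token.items():
    per_layer_kib = floats * 2 / 1024
    total_gib = kv_cache_bytes(ctx, layers, floats) / GiB
    print(f"{name:14s} {per_layer_kib:6.2f} KiB/token/layer  {total_gib:7.3f} GiB total")
```

Multiplying by the batch size gives the cache for several concurrent sequences, which is usually what decides the GPU count in serving.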

Decision Framework

Given a deployment scenario, use this decision table to choose the right combination of mechanisms:

Scenario | Recommended Stack | Why
Training an LLM from scratch | MHA + RoPE + Flash | Best quality; RoPE for relative position (extendable with scaling); Flash for training speed
Fast inference, long context | GQA + RoPE + Flash | 4–8× KV-cache reduction with minimal quality loss
Extreme memory constraint | MLA + RoPE + Flash | Up to 64× KV-cache reduction (used in production by DeepSeek-V2)
Very long documents (>32K) | GQA + RoPE + Sliding Window | $O(NWd)$ complexity; the recipe used by Mistral 7B
Autoregressive generation | Causal + GQA + RoPE + Flash | Standard recipe for GPT-style models (LLaMA 3, Qwen-2)
Encoder-decoder (translation) | Cross-Attention + MHA + RoPE | Decoder attends to encoder output via cross-attention
Long-context retrieval | Differential + GQA + RoPE | Noise cancellation improves recall in long documents
Linear-time processing | Linear Attention | $O(Nd^2)$, no attention matrix at all
The meta-lesson: There is no single "best" attention mechanism. Modern production systems combine 3–5 mechanisms simultaneously. The key skill is understanding each mechanism's trade-off well enough to compose them correctly.

Connection to Modern Systems

The LLaMA/Qwen Recipe

The most successful open-source LLMs of 2024–2025 (LLaMA 3, Qwen-2, Mistral) all converged on a remarkably similar recipe:

  1. GQA (Chapter 6) for KV-cache efficiency, typically with $H = 32$ query heads and $G = 8$ KV groups (a 4$\times$ reduction).
  2. RoPE (Chapter 8) for position encoding, combined with RoPE scaling methods (such as NTK-aware interpolation) to extend context windows up to 128K tokens.
  3. Causal masking (Chapter 3) for autoregressive generation.
  4. Flash Attention 2/3 (Chapter 13) for training and inference speed.
  5. Sliding Window (Chapter 11) in Mistral 7B, for efficient long-context attention with $W = 4096$.

This means a single forward pass combines four mechanisms, or five in Mistral 7B. GQA defines the head-sharing structure, RoPE rotates Q and K before scoring, the causal mask prevents attending to future tokens, and Flash Attention handles the GPU memory access pattern. In Mistral 7B, the sliding window also limits the attention span.
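The composition can be sketched in one function. The NumPy sketch below combines GQA head sharing, RoPE on queries and keys (not values), a causal mask, and an optional causal sliding window. Flash Attention is not modeled, because it changes how the same math is executed rather than what is computed. This is an illustration of how the pieces fit together, not code from LLaMA, Qwen, or Mistral. The function names, the window convention (current token plus $W - 1$ before it), and the toy sizes are assumptions.

```python
import numpy as np

def rope_rows(x, base=10000.0):
    # Rotate row t of x (N x d_h) by position t, pairing dims (2p, 2p+1).
    N, d = x.shape
    theta = base ** (-2.0 * np.arange(d // 2) / d)
    ang = np.arange(N)[:, None] * theta[None, :]         # (N, d/2)
    c, s = np.cos(ang), np.sin(ang)
    out = np.empty_like(x)
    out[:, 0::2] = x[:, 0::2] * c - x[:, 1::2] * s
    out[:, 1::2] = x[:, 0::2] * s + x[:, 1::2] * c
    return out

def llama_style_attention(Q, K, V, H, G, window=None):
    """Q: (N, H*d_h); K, V: (N, G*d_h). GQA + RoPE + causal (+ optional window)."""
    N = Q.shape[0]
    d_h = Q.shape[1] // H
    i, j = np.arange(N)[:, None], np.arange(N)[None, :]
    allowed = j <= i                                      # causal mask
    if window is not None:
        allowed &= (i - j) < window                       # causal sliding window
    bias = np.where(allowed, 0.0, -np.inf)
    outs = []
    for h in range(H):
        g = h // (H // G)                                 # KV group for this query head
        q = rope_rows(Q[:, h * d_h:(h + 1) * d_h])        # RoPE on queries...
        k = rope_rows(K[:, g * d_h:(g + 1) * d_h])        # ...and keys, not values
        s = q @ k.T / np.sqrt(d_h) + bias
        w = np.exp(s - s.max(axis=1, keepdims=True))
        w /= w.sum(axis=1, keepdims=True)
        outs.append(w @ V[:, g * d_h:(g + 1) * d_h])
    return np.hstack(outs)                                # (N, H*d_h)

rng = np.random.default_rng(0)
N, H, G, d_h = 6, 4, 2, 4          # toy sizes: 4 query heads sharing 2 KV heads
Q = rng.standard_normal((N, H * d_h))
K = rng.standard_normal((N, G * d_h))
V = rng.standard_normal((N, G * d_h))
print(llama_style_attention(Q, K, V, H, G).shape)             # (6, 16)
print(llama_style_attention(Q, K, V, H, G, window=3).shape)   # (6, 16)
```

Each mechanism touches a different stage. GQA picks which K/V slice a head reads, RoPE transforms Q and K before scoring, and the masks edit the scores before the softmax.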

The Research Frontier

Two mechanisms from 2024 point toward the next generation of attention:

  • Differential Attention (Chapter 14) is being explored for retrieval-augmented generation (RAG), where noise cancellation improves the model's ability to find relevant passages in long contexts. Microsoft's results show improvements on key-value retrieval and in-context learning tasks.
  • MLA (Chapter 15) achieved the most extreme KV-cache compression in production. DeepSeek-V2 (236B total parameters, 21B active per token) uses MLA to serve 128K context windows. Its authors report a 93.3% smaller KV-cache than their earlier DeepSeek 67B model, with quality competitive with LLaMA 3.

Python Implementation

The following unified class implements all 15 attention mechanisms. You can run it directly to reproduce the comparison table above. Each method corresponds to one chapter of this book.

Python: All 15 Attention Mechanisms (attention_comparison.py)
import numpy as np

NumPy provides vectorized matrix operations. All Q @ K.T, softmax, and output computations run as optimized C code.

import math

math.sqrt computes the scaling factor √d_k for dot-product attention.

class AttentionComparison:

Unified class implementing all 15 attention mechanisms from Chapters 1–15. Each method (m01 through m15) corresponds to one chapter and returns the output matrix (5×4).

def __init__(self):

The constructor initializes the shared example: 5 tokens with 4-dimensional Q, K, V matrices. These same matrices were used in every chapter.

Initial state: self.N = 5 (number of tokens), self.d = 4 (embedding dimension), self.scale = 2.0 (√4).

self.tokens = ["The", "cat", "sat", "on", "mat"]

The five tokens of our running example sentence. Every mechanism processes these same tokens.

self.Q = np.array([...])

Query matrix Q (5×4). Each row is a token's query vector: what it is looking for.

Q (5×4) =
      d0   d1   d2   d3
The  1.0  0.0  1.0  0.0
cat  0.0  2.0  0.0  1.0
sat  1.0  1.0  1.0  0.0
on   0.0  0.0  1.0  1.0
mat  1.0  0.0  0.0  1.0

self.K = np.array([...])

Key matrix K (5×4). Each row is a token's key vector: what it offers to queries.

K (5×4) =
      d0   d1   d2   d3
The  0.0  1.0  0.0  1.0
cat  1.0  0.0  1.0  0.0
sat  1.0  1.0  0.0  0.0
on   0.0  0.0  1.0  1.0
mat  1.0  0.0  0.5  0.5

self.V = np.array([...])

Value matrix V (5×4). Each row is the information a token contributes when attended to. The first four rows form the identity, so each of those tokens writes to its own dimension. The mat row, [1, 0, 1, 0], overlaps The and sat, so weight on mat shows up in dims 0 and 2.

V (5×4) =
      d0   d1   d2   d3
The  1.0  0.0  0.0  0.0
cat  0.0  1.0  0.0  0.0
sat  0.0  0.0  1.0  0.0
on   0.0  0.0  0.0  1.0
mat  1.0  0.0  1.0  0.0

def _softmax(self, x) -> np.ndarray:

Numerically stable softmax: subtract the row max before exp to prevent overflow. Every mechanism except linear attention (#10) uses it, either directly or through another method.

Input: x, a score matrix (N×N). Returns an array of the same shape whose rows each sum to 1.0.

def m01_scaled_dot_product(self):

Chapter 1: the foundation. Attention(Q,K,V) = softmax(QKᵀ/√d_k) · V. Every other mechanism modifies this formula.

Returns an np.ndarray (5×4): the attended output.

def m02_multi_head(self, H=2):

Chapter 2: split Q, K, V into H = 2 heads, compute attention per head, and concatenate. With learned per-head projections, each head can specialize in a different relationship. In this example, the heads are plain slices of the feature dimension.

Input: H = 2 attention heads, giving d_h = d/H = 4/2 = 2 dimensions per head.

def m03_causal(self):

Chapter 3: add a -∞ mask to future positions, so token i can only attend to tokens 0..i. Used in decoder-only autoregressive language models (GPT, LLaMA).

The mask is 5×5, with -1e9 (a finite stand-in for -∞) in the strict upper triangle.

def m04_cross(self):

Chapter 4: decoder queries attend to encoder keys and values. Only the queries change. Q_dec (5×4) holds decoder queries that differ from self.Q, while self.K and self.V play the role of the encoder outputs.

def m05_multi_query(self, H=2):

Chapter 5 (Shazeer 2019): all H heads share the same K and V. Here the shared pair is the first head's slice: K_shared = K[:, :d_h], and likewise for V_shared. The KV-cache shrinks by H× at the cost of a small quality drop (roughly 1%), and inference is much faster.

def m06_grouped_query(self, H=2, G=2):

Chapter 6 (Ainslie 2023): with H = G = 2, every group has exactly one head, so GQA reduces to standard MHA. In practice, 1 < G < H gives the sweet spot between MHA quality and MQA speed.

def m07_relative_pos_bias(self):

Chapter 7 (Shaw 2018 / T5 2020): add a bias based on the relative distance |i-j|, so nearby tokens get higher scores. This can help length generalization. In T5 the bias is learned; here it is the fixed stand-in -0.5·|i-j|.

def m08_rope(self):

Chapter 8 (Su 2021): Rotate Q and K vectors by position-dependent angles. The dot product Q_rot · K_rot encodes relative position. Used in LLaMA, Qwen, Mistral.

def m09_alibi(self, slope=1.0):

Chapter 9 (Press 2022): a linear penalty -m·|i-j| added to the scores. There are no learned position parameters, just a fixed slope per head (a single slope of 1.0 in this example). Zero extra memory, and it trains faster.

def m10_linear(self):

Chapter 10 (Katharopoulos 2020): replace softmax with the feature map φ(x) = elu(x) + 1, implemented as elu_plus. Rearranging the computation as φ(Q) · (φ(K)ᵀV) gives O(Nd²) complexity instead of O(N²d), with no explicit attention matrix.

def m11_sliding_window(self, W=1):

Chapter 11 (Beltagy 2020): each token only attends to positions [i-W, i+W], for O(NW) complexity. Mistral 7B uses a causal version (previous tokens only) with W = 4096 for long contexts.

def m12_sparse_bigbird(self, W=1):

Chapter 12 (Zaheer 2020): Local window + global tokens (token 0 attends to all and all attend to it). Random connections omitted for clarity.

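def m13_flash(self):

Chapter 13 (Dao 2022): exactly the same math as #01. The real algorithm uses IO-aware tiling on the GPU and is typically 3–8× faster, with output identical up to floating-point rounding. It is an implementation change, not a change to the attention formula. This NumPy method simply calls m01.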

def m14_differential(self, lam=0.5):

Chapter 14 (Microsoft 2024): compute two attention maps A₁ and A₂ from the two halves of Q and K, take max(A₁ - λA₂, 0), and renormalize each row. This cancels noise and sharpens the signal.

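def m15_mla(self, d_c=2):

Chapter 15 (DeepSeek-V2 2024): compress into a shared latent c_KV of dimension d_c, and cache only those d_c floats per token instead of the full keys and values (2d floats). K′ and V′ are reconstructed for attention. In this sketch the latent is computed from K, and K′ and V′ use the same up-projection, so they come out identical.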

def compare_all(self):

Run all 15 mechanisms from a list of (name, function) tuples and print the output vector for the token "cat" (row 1). This produces the comparison table shown in the text.

print(f"{name:<24} {row[0]:7.4f} ...")

Print each mechanism's output for "cat" with 4 decimal places. The printed table matches the comparison table in the chapter.

ac = AttentionComparison()

Create the comparison object with shared Q, K, V matrices.

ac.compare_all()

Run all 15 mechanisms and print the comparison table.

Output:
Method                   dim-0   dim-1   dim-2   dim-3
------------------------------------------------------------
01. Scaled Dot-Product   0.5179  0.0898  0.3595  0.1481
02. Multi-Head (H=2)     0.4555  0.0891  0.3241  0.2711
03. Causal (Masked)      0.8176  0.1824  0.0000  0.0000
04. Cross-Attention      0.4822  0.1205  0.3534  0.1986
05. MQA                  0.4555  0.0891  0.4291  0.1417
06. GQA (=MHA)           0.4555  0.0891  0.3241  0.2711
07. Rel. Pos. Bias       0.4800  0.1597  0.3091  0.0969
08. RoPE                 0.3939  0.0912  0.4989  0.1518
09. ALiBi                0.4351  0.2541  0.2703  0.0567
10. Linear               0.4175  0.1748  0.3981  0.1942
11. Sliding Window       0.5465  0.1220  0.3315  0.0000
12. Sparse BigBird       0.5465  0.1220  0.3315  0.0000
13. Flash Attention      0.5179  0.0898  0.3595  0.1481
14. Differential         0.4177  0.0402  0.5421  0.0000
15. MLA                  0.3726  0.6074  0.3726  0.6074

The complete listing (attention_comparison.py):

import numpy as np
import math


class AttentionComparison:
    """
    All 15 attention mechanisms on the shared example.
    Tokens: ["The", "cat", "sat", "on", "mat"], d_model=4.
    Each mNN_* method returns the (5 x 4) output matrix.
    """

    def __init__(self):
        self.tokens = ["The", "cat", "sat", "on", "mat"]
        self.Q = np.array([
            [1.0, 0.0, 1.0, 0.0],
            [0.0, 2.0, 0.0, 1.0],
            [1.0, 1.0, 1.0, 0.0],
            [0.0, 0.0, 1.0, 1.0],
            [1.0, 0.0, 0.0, 1.0],
        ])
        self.K = np.array([
            [0.0, 1.0, 0.0, 1.0],
            [1.0, 0.0, 1.0, 0.0],
            [1.0, 1.0, 0.0, 0.0],
            [0.0, 0.0, 1.0, 1.0],
            [1.0, 0.0, 0.5, 0.5],
        ])
        self.V = np.array([
            [1.0, 0.0, 0.0, 0.0],
            [0.0, 1.0, 0.0, 0.0],
            [0.0, 0.0, 1.0, 0.0],
            [0.0, 0.0, 0.0, 1.0],
            [1.0, 0.0, 1.0, 0.0],
        ])
        self.N = 5
        self.d = 4
        self.scale = math.sqrt(self.d)

    def _softmax(self, x: np.ndarray) -> np.ndarray:
        # Subtract the row max before exp() so large scores cannot overflow.
        s = x - np.max(x, axis=-1, keepdims=True)
        e = np.exp(s)
        return e / np.sum(e, axis=-1, keepdims=True)

    def m01_scaled_dot_product(self) -> np.ndarray:
        # softmax(Q K^T / sqrt(d)) V over all token pairs.
        scores = self.Q @ self.K.T / self.scale
        weights = self._softmax(scores)
        return weights @ self.V

    def m02_multi_head(self, H: int = 2) -> np.ndarray:
        # Each head works on its own d_h-wide slice of Q, K and V.
        d_h = self.d // H
        scale_h = math.sqrt(d_h)
        heads = []
        for h in range(H):
            s = h * d_h
            Qh = self.Q[:, s:s+d_h]
            Kh = self.K[:, s:s+d_h]
            Vh = self.V[:, s:s+d_h]
            w = self._softmax(Qh @ Kh.T / scale_h)
            heads.append(w @ Vh)
        return np.hstack(heads)

    def m03_causal(self) -> np.ndarray:
        # -1e9 above the diagonal blocks attention to future tokens.
        mask = np.triu(np.ones((self.N, self.N)) * -1e9, k=1)
        scores = self.Q @ self.K.T / self.scale + mask
        weights = self._softmax(scores)
        return weights @ self.V

    def m04_cross(self) -> np.ndarray:
        # Decoder queries; self.K and self.V stand in for encoder outputs.
        Q_dec = np.array([
            [0.5, 0.5, 0.5, 0.5],
            [0.0, 1.0, 0.0, 1.0],
            [1.0, 0.0, 1.0, 0.0],
            [0.5, 0.5, 0.0, 0.0],
            [0.0, 0.0, 0.5, 0.5],
        ])
        scores = Q_dec @ self.K.T / self.scale
        weights = self._softmax(scores)
        return weights @ self.V

    def m05_multi_query(self, H: int = 2) -> np.ndarray:
        # One K/V slice shared by every query head.
        d_h = self.d // H
        scale_h = math.sqrt(d_h)
        K_shared = self.K[:, :d_h]
        V_shared = self.V[:, :d_h]
        heads = []
        for h in range(H):
            s = h * d_h
            Qh = self.Q[:, s:s+d_h]
            w = self._softmax(Qh @ K_shared.T / scale_h)
            heads.append(w @ V_shared)
        return np.hstack(heads)

    def m06_grouped_query(self, H: int = 2, G: int = 2) -> np.ndarray:
        # Query head h uses the K/V slice of its group g = h // (H // G).
        # G == H reproduces m02 (MHA); G == 1 reproduces m05 (MQA).
        d_h = self.d // H
        scale_h = math.sqrt(d_h)
        heads_per_group = H // G
        heads = []
        for h in range(H):
            s = h * d_h
            g = (h // heads_per_group) * d_h
            Qh = self.Q[:, s:s+d_h]
            Kg = self.K[:, g:g+d_h]
            Vg = self.V[:, g:g+d_h]
            w = self._softmax(Qh @ Kg.T / scale_h)
            heads.append(w @ Vg)
        return np.hstack(heads)

    def m07_relative_pos_bias(self) -> np.ndarray:
        # Fixed stand-in for a learned bias: -0.5 per step of distance.
        bias = np.zeros((self.N, self.N))
        for i in range(self.N):
            for j in range(self.N):
                bias[i, j] = -0.5 * abs(i - j)
        scores = self.Q @ self.K.T / self.scale + bias
        weights = self._softmax(scores)
        return weights @ self.V

    def m08_rope(self) -> np.ndarray:
        def apply_rope(x):
            # Rotate each pair (2p, 2p+1) of row i by angle i * theta_p.
            N, d = x.shape
            r = x.copy()
            for p in range(d // 2):
                theta = 1.0 / (10000.0 ** (2 * p / d))
                for i in range(N):
                    a = i * theta
                    c, s = math.cos(a), math.sin(a)
                    x1, x2 = x[i, 2*p], x[i, 2*p+1]
                    r[i, 2*p] = x1 * c - x2 * s
                    r[i, 2*p+1] = x1 * s + x2 * c
            return r
        # Only Q and K are rotated; V is left unchanged.
        Qr, Kr = apply_rope(self.Q), apply_rope(self.K)
        scores = Qr @ Kr.T / self.scale
        weights = self._softmax(scores)
        return weights @ self.V

    def m09_alibi(self, slope: float = 1.0) -> np.ndarray:
        # Linear distance penalty with a single slope.
        bias = np.zeros((self.N, self.N))
        for i in range(self.N):
            for j in range(self.N):
                bias[i, j] = -slope * abs(i - j)
        scores = self.Q @ self.K.T / self.scale + bias
        weights = self._softmax(scores)
        return weights @ self.V

    def m10_linear(self) -> np.ndarray:
        def elu_plus(x):
            # phi(x) = elu(x) + 1: x + 1 for x > 0, exp(x) otherwise.
            return np.where(x > 0, x + 1, np.exp(x))
        phi_Q = elu_plus(self.Q)
        phi_K = elu_plus(self.K)
        KV = phi_K.T @ self.V          # (d x d), computed once for all queries
        K_sum = phi_K.sum(axis=0)      # normalizer
        out = np.zeros_like(self.V)
        for i in range(self.N):
            out[i] = (phi_Q[i] @ KV) / (phi_Q[i] @ K_sum)
        return out

    def m11_sliding_window(self, W: int = 1) -> np.ndarray:
        # Allow only |i - j| <= W.
        mask = np.full((self.N, self.N), -1e9)
        for i in range(self.N):
            for j in range(max(0, i-W), min(self.N, i+W+1)):
                mask[i, j] = 0.0
        scores = self.Q @ self.K.T / self.scale + mask
        weights = self._softmax(scores)
        return weights @ self.V

    def m12_sparse_bigbird(self, W: int = 1) -> np.ndarray:
        # Window plus global token 0 (random links omitted for clarity).
        mask = np.full((self.N, self.N), -1e9)
        for i in range(self.N):
            for j in range(max(0, i-W), min(self.N, i+W+1)):
                mask[i, j] = 0.0
        mask[0, :] = 0.0
        mask[:, 0] = 0.0
        scores = self.Q @ self.K.T / self.scale + mask
        weights = self._softmax(scores)
        return weights @ self.V

    def m13_flash(self) -> np.ndarray:
        # Same math as m01; FlashAttention changes only how it is computed.
        return self.m01_scaled_dot_product()

    def m14_differential(self, lam: float = 0.5) -> np.ndarray:
        # Two softmax maps from the two halves of Q and K.
        d_h = self.d // 2
        sc = math.sqrt(d_h)
        A1 = self._softmax(self.Q[:, :d_h] @ self.K[:, :d_h].T / sc)
        A2 = self._softmax(self.Q[:, d_h:] @ self.K[:, d_h:].T / sc)
        # Clamp negatives to 0, then renormalize rows (guarding all-zero rows).
        diff = np.maximum(A1 - lam * A2, 0)
        sums = diff.sum(axis=-1, keepdims=True)
        sums = np.where(sums == 0, 1, sums)
        weights = diff / sums
        return weights @ self.V

    def m15_mla(self, d_c: int = 2) -> np.ndarray:
        # Down-projection: dims 0, 2 -> latent 0 and dims 1, 3 -> latent 1.
        W_D = np.zeros((self.d, d_c))
        for i in range(self.d):
            W_D[i, i % d_c] = 0.7
        # Up-projection back to d dims (same matrix for K' and V').
        W_U = np.zeros((d_c, self.d))
        for i in range(d_c):
            for j in range(self.d):
                if j % d_c == i:
                    W_U[i, j] = 0.7
        c = self.K @ W_D    # latent; stand-in: compressed from K
        Kp = c @ W_U
        Vp = c @ W_U        # same up-projection, so Vp == Kp here
        scores = self.Q @ Kp.T / self.scale
        weights = self._softmax(scores)
        return weights @ Vp

    def compare_all(self):
        methods = [
            ("01. Scaled Dot-Product", self.m01_scaled_dot_product),
            ("02. Multi-Head (H=2)", self.m02_multi_head),
            ("03. Causal (Masked)", self.m03_causal),
            ("04. Cross-Attention", self.m04_cross),
            ("05. MQA", self.m05_multi_query),
            ("06. GQA (=MHA)", self.m06_grouped_query),
            ("07. Rel. Pos. Bias", self.m07_relative_pos_bias),
            ("08. RoPE", self.m08_rope),
            ("09. ALiBi", self.m09_alibi),
            ("10. Linear", self.m10_linear),
            ("11. Sliding Window", self.m11_sliding_window),
            ("12. Sparse BigBird", self.m12_sparse_bigbird),
            ("13. Flash Attention", self.m13_flash),
            ("14. Differential", self.m14_differential),
            ("15. MLA", self.m15_mla),
        ]
        print(f"{'Method':<24} {'dim-0':>7} {'dim-1':>7}"
              f" {'dim-2':>7} {'dim-3':>7}")
        print("-" * 60)
        for name, fn in methods:
            out = fn()
            row = out[1]  # "cat" row
            print(f"{name:<24} {row[0]:7.4f} {row[1]:7.4f}"
                  f" {row[2]:7.4f} {row[3]:7.4f}")


ac = AttentionComparison()
ac.compare_all()

PyTorch Implementation

The PyTorch version uses F.softmax and torch.Tensor operations. In production, you would use F.scaled_dot_product_attention, which selects a backend (FlashAttention, memory-efficient attention, or the reference math implementation) based on your hardware and inputs.

PyTorch: Key Mechanisms Compared (attention_comparison_pytorch.py)
import torch

PyTorch provides GPU-accelerated tensor operations. In production, all 15 mechanisms run on CUDA tensors for parallel computation.

import torch.nn.functional as F

F.softmax is the PyTorch equivalent of our NumPy _softmax. It handles numerical stability and GPU acceleration automatically.

import math

math.sqrt for the √d_k scaling factor.

class AttentionComparisonPyTorch:

PyTorch version of the comparison class. Uses torch.Tensor instead of np.ndarray, and F.softmax instead of manual softmax.

__init__(self): Same shared example matrices as the NumPy version, stored as torch.Tensor. On a GPU you would move them with .cuda() or .to("cuda"). Execution state: self.N = 5 (number of tokens), self.d = 4 (embedding dimension).
m01_scaled_dot_product(): The foundation, softmax(QKᵀ/√d)·V. dim=-1 applies softmax along the last axis, so each query's row is normalized independently. In production PyTorch, use F.scaled_dot_product_attention(), which dispatches to a FlashAttention kernel when the hardware and inputs support it.
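FlashAttention returns the same result as this standard formula because of the "online softmax" trick. Keys are processed in blocks, and the computation carries a running maximum and a running normalizer. Earlier partial sums are rescaled whenever the maximum grows. The pure-Python sketch below illustrates only that arithmetic for a single query row. It is not the real kernel, which also tiles Q, keeps blocks in on-chip SRAM, and has a custom backward pass. The function names are hypothetical.

```python
import math

def scores_for(q, keys, scale):
    return [scale * sum(a * b for a, b in zip(q, k)) for k in keys]

def standard_row(q, keys, values, scale):
    """Reference: full softmax over all scores, then a weighted sum of values."""
    s = scores_for(q, keys, scale)
    m = max(s)
    w = [math.exp(x - m) for x in s]
    z = sum(w)
    return [sum(w[j] * values[j][c] for j in range(len(keys))) / z
            for c in range(len(values[0]))]

def online_row(q, keys, values, scale, block=2):
    """Blockwise pass with running max m, normalizer l and unnormalized acc."""
    dim = len(values[0])
    m, l, acc = float("-inf"), 0.0, [0.0] * dim
    for start in range(0, len(keys), block):
        s = scores_for(q, keys[start:start + block], scale)
        vs = values[start:start + block]
        m_new = max(m, max(s))
        corr = math.exp(m - m_new)  # rescales earlier partial sums; 0.0 on the first block
        p = [math.exp(x - m_new) for x in s]
        l = l * corr + sum(p)
        acc = [acc[c] * corr + sum(p[j] * vs[j][c] for j in range(len(p)))
               for c in range(dim)]
        m = m_new
    return [a / l for a in acc]

q_cat = [0.0, 2.0, 0.0, 1.0]
K = [[0.0, 1.0, 0.0, 1.0], [1.0, 0.0, 1.0, 0.0], [1.0, 1.0, 0.0, 0.0],
     [0.0, 0.0, 1.0, 1.0], [1.0, 0.0, 0.5, 0.5]]
V = [[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0],
     [0.0, 0.0, 0.0, 1.0], [1.0, 0.0, 1.0, 0.0]]
ref = standard_row(q_cat, K, V, scale=0.5)
tiled = online_row(q_cat, K, V, scale=0.5, block=2)
print([round(x, 4) for x in ref])    # [0.5179, 0.0898, 0.3595, 0.1481]
print([round(x, 4) for x in tiled])  # same values, computed block by block
```

Both rows reproduce the 01. SDP "cat" values from the comparison output. The full row of weights is never stored at once, which is the memory saving FlashAttention exploits.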
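m02_multi_head(H=2): Splits Q, K, and V into H heads along the feature dimension and concatenates the head outputs with torch.cat. Here the split is a plain slice. nn.MultiheadAttention instead applies learned W_Q, W_K, W_V projections before the split and a learned output projection after the concatenation.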
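Production code usually avoids the per-head Python loop. It reshapes (N, d) into (H, N, d_h) and runs every head in one batched matmul. Below is a minimal sketch, assuming no learned projections (as in m02_multi_head) and random inputs with hypothetical sizes. It checks that the batched version matches the loop.

```python
import math
import torch
import torch.nn.functional as F

def multi_head_loop(Q, K, V, H):
    """Per-head slicing, as in m02_multi_head (no learned projections)."""
    d_h = Q.shape[-1] // H
    outs = []
    for h in range(H):
        s = slice(h * d_h, (h + 1) * d_h)
        w = F.softmax(Q[:, s] @ K[:, s].T / math.sqrt(d_h), dim=-1)
        outs.append(w @ V[:, s])
    return torch.cat(outs, dim=-1)

def multi_head_batched(Q, K, V, H):
    """Same computation with all heads in a single batched matmul."""
    N, d = Q.shape
    d_h = d // H
    # (N, d) -> (N, H, d_h) -> (H, N, d_h): head h owns features h*d_h .. (h+1)*d_h - 1
    q, k, v = (x.view(N, H, d_h).transpose(0, 1) for x in (Q, K, V))
    w = F.softmax(q @ k.transpose(-2, -1) / math.sqrt(d_h), dim=-1)  # (H, N, N)
    return (w @ v).transpose(0, 1).reshape(N, d)  # back to (N, d)

torch.manual_seed(0)
Q, K, V = (torch.randn(5, 8) for _ in range(3))
print(torch.allclose(multi_head_loop(Q, K, V, 2), multi_head_batched(Q, K, V, 2), atol=1e-6))
```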

m03_causal(): torch.triu(..., diagonal=1) fills every position strictly above the main diagonal with -inf and keeps the diagonal itself. Each token can therefore attend to itself and to earlier tokens. Because exp(-∞) = 0, softmax gives the future positions zero weight. This is equivalent to is_causal=True in F.scaled_dot_product_attention().
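m07_relative_pos_bias(): The broadcast pos.unsqueeze(1) - pos.unsqueeze(0) builds the N×N matrix of distances i - j in a single step, which is more efficient than nested loops. This toy uses a fixed bias of -0.5·|i - j|. In T5 the bias is instead a learned scalar for each relative-distance bucket and each head.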

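m08_rope(): Rotary embeddings, written with explicit Python loops for readability. Production code applies the same rotation with vectorized or fused tensor operations, such as the rotary helpers in Hugging Face Transformers model code or in the flash-attn package.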
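RoPE's key property is that the score between a rotated query at position m and a rotated key at position n depends only on the offset n - m. Shifting both positions by the same amount therefore leaves the score unchanged. The sketch below checks this with plain Python. It uses the same adjacent-pair layout and frequency schedule as m08_rope (θ_p = 10000^(-2p/d)), and the vectors are the "cat" query and "sat" key from the shared example. The helper names are hypothetical.

```python
import math

def rotate_pairs(x, pos, base=10000.0):
    """Rotate pairs (x[2p], x[2p+1]) by angle pos * base**(-2p/d), as in m08_rope."""
    d = len(x)
    out = list(x)
    for p in range(d // 2):
        a = pos * base ** (-2 * p / d)
        c, s = math.cos(a), math.sin(a)
        out[2 * p] = x[2 * p] * c - x[2 * p + 1] * s
        out[2 * p + 1] = x[2 * p] * s + x[2 * p + 1] * c
    return out

def dot(u, v):
    return sum(a * b for a, b in zip(u, v))

q = [0.0, 2.0, 0.0, 1.0]  # "cat" query (position 1)
k = [1.0, 1.0, 0.0, 0.0]  # "sat" key (position 2)
scores = [dot(rotate_pairs(q, 1 + shift), rotate_pairs(k, 2 + shift))
          for shift in (0, 3, 100)]
print([round(s, 6) for s in scores])  # three identical values: only n - m = 1 matters
```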

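m09_alibi(slope=1.0): Adds a linear distance penalty to the scores. This toy uses a single slope and a symmetric |i - j|. In multi-head models each head gets a different slope drawn from a geometric sequence, and the original ALiBi penalty is combined with a causal mask.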
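For a head count that is a power of two, the ALiBi paper's slopes form a geometric sequence. The sequence starts at 2^(-8/H) and uses that same value as its ratio, so 8 heads get 1/2, 1/4, ..., 1/256. The sketch below covers only the power-of-two case. Non-power-of-two head counts need an extra interpolation step that is not shown here. The function name is hypothetical.

```python
def alibi_slopes(num_heads):
    """Per-head ALiBi slopes for num_heads a power of two: start**1, ..., start**H."""
    if num_heads <= 0 or num_heads & (num_heads - 1):
        raise ValueError("this sketch only handles powers of two")
    start = 2.0 ** (-8.0 / num_heads)
    return [start ** (h + 1) for h in range(num_heads)]

print(alibi_slopes(8))  # [0.5, 0.25, 0.125, ..., 0.00390625]
# Head h then adds -slope_h * (i - j) to score[i][j] for keys j <= i.
```

Steep slopes make a head focus on nearby tokens, while shallow slopes let another head look much further back.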

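m11_sliding_window(W=1): Each query may attend only to keys within W positions on either side, a symmetric, Longformer-style band. Mistral's sliding-window attention is causal instead: each token looks only back over a fixed window of previous tokens. It runs in attention kernels with built-in window support (flash-attn, for instance, exposes a window_size argument), not through an explicit N×N mask.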
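The difference between the symmetric window in m11_sliding_window and a causal, Mistral-style window is easiest to see in the mask itself. In the sketch below, the causal variant lets query i see keys i - w through i. That off-by-one convention is an assumption for illustration, and real implementations differ in exactly how they count the window. The function name is hypothetical.

```python
def sliding_window_mask(n, w, causal=True):
    """n x n booleans; True means query i may attend to key j.
    causal=True:  i - w <= j <= i (look back only).
    causal=False: |i - j| <= w (symmetric band, as in m11_sliding_window)."""
    if causal:
        return [[i - w <= j <= i for j in range(n)] for i in range(n)]
    return [[abs(i - j) <= w for j in range(n)] for i in range(n)]

for row in sliding_window_mask(5, 1):
    print("".join("x" if ok else "." for ok in row))
# x....
# xx...
# .xx..
# ..xx.
# ...xx
```

In a causal window, information from beyond w tokens back still reaches a token indirectly, because each layer extends the effective receptive field by another window.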

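m14_differential(lam=0.5): Computes two softmax maps from the two halves of Q and K and subtracts them, A1 - λ·A2, which cancels attention noise common to both maps. The torch.clamp(..., min=0) and the renormalization are simplifications used in this chapter so that each row is again a probability distribution. The .clamp(min=1e-9) on the row sums prevents division by zero. The Differential Transformer itself uses the difference directly, so weights may be negative, and normalizes each head's output instead.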
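For the "cat" row, the sketch below contrasts the raw difference A1 - λ·A2 with this chapter's clamped and renormalized version. The raw row sums to 1 - λ and has small negative entries. This is an illustration in plain Python. It omits the paper's learnable λ reparameterization and per-head normalization.

```python
import math

def softmax(xs):
    m = max(xs)
    e = [math.exp(x - m) for x in xs]
    z = sum(e)
    return [x / z for x in e]

# "cat" query and all keys; dims 0-1 feed map A1, dims 2-3 feed map A2 (as in m14_differential).
q = [0.0, 2.0, 0.0, 1.0]
keys = [[0.0, 1.0, 0.0, 1.0], [1.0, 0.0, 1.0, 0.0], [1.0, 1.0, 0.0, 0.0],
        [0.0, 0.0, 1.0, 1.0], [1.0, 0.0, 0.5, 0.5]]
lam, scale = 0.5, 1 / math.sqrt(2)
a1 = softmax([scale * (q[0] * k[0] + q[1] * k[1]) for k in keys])
a2 = softmax([scale * (q[2] * k[2] + q[3] * k[3]) for k in keys])

raw = [x - lam * y for x, y in zip(a1, a2)]  # paper-style difference, may be negative
positive = [max(r, 0.0) for r in raw]
total = sum(positive)
clamped = [p / total for p in positive]      # this chapter's simplified variant

print([round(r, 4) for r in raw], round(sum(raw), 6))  # sums to 1 - lam = 0.5
print([round(c, 4) for c in clamped])                  # weights behind the 14. Diff row
```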

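m15_mla(d_c=2): Compresses K into a d_c-dimensional latent with W_D and reconstructs keys and values with W_U. In this toy, the latent is computed from K and a single W_U produces both, so K' and V' are identical. In DeepSeek-V2, W_D and the up-projections are learned by backpropagation. The latent is computed from each token's hidden state, separate up-projections produce keys and values, and d_c is much larger than 2 (512).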
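Below is a closer (but still simplified) picture of what MLA caches. One latent vector per token is computed from the hidden state, and two separate up-projections rebuild multi-head keys and values from it. All sizes and random weights below are hypothetical. The sketch ignores DeepSeek-V2's decoupled RoPE key, its query compression, and the trick of absorbing the up-projections into neighboring matrices at inference time.

```python
import torch

torch.manual_seed(0)
N, d_model, d_c, n_heads, d_h = 5, 16, 4, 2, 8  # hypothetical toy sizes

W_DKV = torch.randn(d_model, d_c) / d_model ** 0.5   # shared down-projection
W_UK = torch.randn(d_c, n_heads * d_h) / d_c ** 0.5  # latent -> keys
W_UV = torch.randn(d_c, n_heads * d_h) / d_c ** 0.5  # latent -> values (separate)

h = torch.randn(N, d_model)  # hidden states of the N tokens seen so far
c_kv = h @ W_DKV             # (N, d_c): the only tensor stored in the KV cache
K = (c_kv @ W_UK).view(N, n_heads, d_h)  # rebuilt on demand
V = (c_kv @ W_UV).view(N, n_heads, d_h)

print(c_kv.shape, K.shape, V.shape)
print("cached values per token:", d_c, "vs. MHA:", 2 * n_heads * d_h)  # 4 vs. 32
```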

compare_all(): Runs the representative mechanisms and prints the "cat" row of each output. The values are identical to those from the NumPy implementation.

ac.compare_all(): Executes the comparison. The PyTorch outputs match NumPy to floating-point precision. Output ("cat" row):

01. SDP: [0.5179, 0.0898, 0.3595, 0.1481]
02. MHA: [0.4555, 0.0891, 0.3241, 0.2711]
03. Causal: [0.8176, 0.1824, 0.0000, 0.0000]
07. RelPos: [0.4800, 0.1597, 0.3091, 0.0969]
08. RoPE: [0.3939, 0.0912, 0.4989, 0.1518]
09. ALiBi: [0.4351, 0.2541, 0.2703, 0.0567]
11. Window: [0.5465, 0.1220, 0.3315, 0.0000]
14. Diff: [0.4177, 0.0402, 0.5421, 0.0000]
15. MLA: [0.3726, 0.6074, 0.3726, 0.6074]
import torch
import torch.nn.functional as F
import math

class AttentionComparisonPyTorch:
    """
    Representative subset (9 of the 15) attention mechanisms in PyTorch.
    Same shared example: "The cat sat on the mat".
    """

    def __init__(self):
        self.tokens = ["The", "cat", "sat", "on", "mat"]
        self.Q = torch.tensor([
            [1.0, 0.0, 1.0, 0.0],
            [0.0, 2.0, 0.0, 1.0],
            [1.0, 1.0, 1.0, 0.0],
            [0.0, 0.0, 1.0, 1.0],
            [1.0, 0.0, 0.0, 1.0],
        ])
        self.K = torch.tensor([
            [0.0, 1.0, 0.0, 1.0],
            [1.0, 0.0, 1.0, 0.0],
            [1.0, 1.0, 0.0, 0.0],
            [0.0, 0.0, 1.0, 1.0],
            [1.0, 0.0, 0.5, 0.5],
        ])
        self.V = torch.tensor([
            [1.0, 0.0, 0.0, 0.0],
            [0.0, 1.0, 0.0, 0.0],
            [0.0, 0.0, 1.0, 0.0],
            [0.0, 0.0, 0.0, 1.0],
            [1.0, 0.0, 1.0, 0.0],
        ])
        self.N, self.d = 5, 4
        self.scale = math.sqrt(self.d)

    def m01_scaled_dot_product(self) -> torch.Tensor:
        scores = self.Q @ self.K.T / self.scale
        weights = F.softmax(scores, dim=-1)
        return weights @ self.V

    def m02_multi_head(self, H: int = 2) -> torch.Tensor:
        d_h = self.d // H
        scale_h = math.sqrt(d_h)
        heads = []
        for h in range(H):
            s = h * d_h
            w = F.softmax(
                self.Q[:, s:s+d_h] @ self.K[:, s:s+d_h].T / scale_h,
                dim=-1
            )
            heads.append(w @ self.V[:, s:s+d_h])
        return torch.cat(heads, dim=-1)

    def m03_causal(self) -> torch.Tensor:
        # -inf strictly above the diagonal: no attending to future tokens.
        mask = torch.triu(
            torch.full((self.N, self.N), float("-inf")), diagonal=1
        )
        scores = self.Q @ self.K.T / self.scale + mask
        return F.softmax(scores, dim=-1) @ self.V

    def m07_relative_pos_bias(self) -> torch.Tensor:
        pos = torch.arange(self.N)
        # Fixed toy bias -0.5 * |i - j| (learned per bucket in T5).
        bias = -0.5 * torch.abs(pos.unsqueeze(1) - pos.unsqueeze(0)).float()
        scores = self.Q @ self.K.T / self.scale + bias
        return F.softmax(scores, dim=-1) @ self.V

    def m08_rope(self) -> torch.Tensor:
        def apply(x: torch.Tensor) -> torch.Tensor:
            # Rotate each adjacent feature pair (2p, 2p+1) by angle i * theta_p.
            N, d = x.shape
            r = x.clone()
            for p in range(d // 2):
                theta = 1.0 / (10000.0 ** (2 * p / d))
                for i in range(N):
                    a = i * theta
                    c, s = math.cos(a), math.sin(a)
                    r[i, 2*p] = x[i, 2*p] * c - x[i, 2*p+1] * s
                    r[i, 2*p+1] = x[i, 2*p] * s + x[i, 2*p+1] * c
            return r
        Qr, Kr = apply(self.Q), apply(self.K)
        scores = Qr @ Kr.T / self.scale
        return F.softmax(scores, dim=-1) @ self.V

    def m09_alibi(self, slope: float = 1.0) -> torch.Tensor:
        pos = torch.arange(self.N)
        bias = -slope * torch.abs(
            pos.unsqueeze(1) - pos.unsqueeze(0)
        ).float()
        scores = self.Q @ self.K.T / self.scale + bias
        return F.softmax(scores, dim=-1) @ self.V

    def m11_sliding_window(self, W: int = 1) -> torch.Tensor:
        # Symmetric band |i - j| <= W (bidirectional, not causal).
        mask = torch.full((self.N, self.N), float("-inf"))
        for i in range(self.N):
            lo, hi = max(0, i - W), min(self.N, i + W + 1)
            mask[i, lo:hi] = 0.0
        scores = self.Q @ self.K.T / self.scale + mask
        return F.softmax(scores, dim=-1) @ self.V

    def m14_differential(self, lam: float = 0.5) -> torch.Tensor:
        d_h = self.d // 2
        sc = math.sqrt(d_h)
        A1 = F.softmax(
            self.Q[:, :d_h] @ self.K[:, :d_h].T / sc, dim=-1
        )
        A2 = F.softmax(
            self.Q[:, d_h:] @ self.K[:, d_h:].T / sc, dim=-1
        )
        # Clamp + renormalize: a simplification so each row sums to 1 again.
        diff = torch.clamp(A1 - lam * A2, min=0)
        sums = diff.sum(dim=-1, keepdim=True).clamp(min=1e-9)
        return (diff / sums) @ self.V

    def m15_mla(self, d_c: int = 2) -> torch.Tensor:
        W_D = torch.zeros(self.d, d_c)
        for i in range(self.d):
            W_D[i, i % d_c] = 0.7
        W_U = torch.zeros(d_c, self.d)
        for i in range(d_c):
            for j in range(self.d):
                if j % d_c == i:
                    W_U[i, j] = 0.7
        c = self.K @ W_D  # latent, (N, d_c)
        # Toy simplification: K and V share one up-projection, so Kp == Vp.
        Kp, Vp = c @ W_U, c @ W_U
        scores = self.Q @ Kp.T / self.scale
        return F.softmax(scores, dim=-1) @ Vp

    def compare_all(self):
        methods = [
            ("01. SDP", self.m01_scaled_dot_product),
            ("02. MHA", self.m02_multi_head),
            ("03. Causal", self.m03_causal),
            ("07. RelPos", self.m07_relative_pos_bias),
            ("08. RoPE", self.m08_rope),
            ("09. ALiBi", self.m09_alibi),
            ("11. Window", self.m11_sliding_window),
            ("14. Diff", self.m14_differential),
            ("15. MLA", self.m15_mla),
        ]
        for name, fn in methods:
            out = fn()
            cat_row = out[1]
            print(f"{name}: [{cat_row[0]:.4f}, "
                  f"{cat_row[1]:.4f}, {cat_row[2]:.4f}, "
                  f"{cat_row[3]:.4f}]")


ac = AttentionComparisonPyTorch()
ac.compare_all()

Key Takeaways

  1. All attention mechanisms modify the same base formula. The original $\text{softmax}(QK^\top/\sqrt{d_k})\,V$ has four places you can intervene: what goes into Q and K (position), what pattern is allowed (masking/sparsity), how scores become weights (softmax alternative), and what is stored in the cache (compression).
  2. Flash Attention is not a mechanism; it is an implementation. It computes the same output as standard attention (up to floating-point rounding) without ever materializing the full N×N score matrix, and it is typically 3–8× faster. Understanding the distinction between algorithmic changes and hardware optimizations is crucial.
  3. Modern systems stack mechanisms. LLaMA 3 uses Causal + GQA + RoPE + Flash simultaneously. Each mechanism handles a different concern: autoregressive ordering, memory efficiency, position encoding, and GPU throughput.
  4. The output tells you which tokens were attended to. Our identity-like V matrix makes this visible: dim-3 in the output is exactly the attention weight on "on" (whose V row is [0,0,0,1]). When a mechanism blocks "on" from being attended to, dim-3 drops to zero.
  5. The right choice depends on deployment constraints. For training: maximize quality (MHA + RoPE). For inference: minimize KV-cache (GQA or MLA). For long documents: use sparse patterns (Sliding Window, BigBird). No single mechanism is universally optimal.
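To make the KV-cache trade-off in point 5 concrete, the sketch below counts cache bytes for one sequence. MHA, GQA and MQA store K and V for every KV head in every layer. MLA stores one latent of size d_c per token per layer, ignoring DeepSeek-V2's extra decoupled RoPE key. The configuration (32 layers, 32 query heads of size 128, 8K tokens, FP16) is hypothetical and does not describe any specific released model. The function names are also hypothetical.

```python
def kv_cache_bytes(n_layers, seq_len, n_kv_heads, head_dim, bytes_per_elem=2):
    """MHA / GQA / MQA: K and V, for every layer, token, KV head and feature."""
    return 2 * n_layers * seq_len * n_kv_heads * head_dim * bytes_per_elem

def mla_cache_bytes(n_layers, seq_len, d_c, bytes_per_elem=2):
    """MLA: one d_c-dimensional latent per token per layer (RoPE key ignored)."""
    return n_layers * seq_len * d_c * bytes_per_elem

GiB = 2 ** 30
cfg = dict(n_layers=32, seq_len=8192, head_dim=128)  # hypothetical model, FP16
print("MHA (32 KV heads):", kv_cache_bytes(n_kv_heads=32, **cfg) / GiB, "GiB")  # 4.0
print("GQA (8 KV heads): ", kv_cache_bytes(n_kv_heads=8, **cfg) / GiB, "GiB")   # 1.0
print("MQA (1 KV head):  ", kv_cache_bytes(n_kv_heads=1, **cfg) / GiB, "GiB")   # 0.125
print("MLA (d_c = 512):  ", mla_cache_bytes(32, 8192, 512) / GiB, "GiB")       # 0.25
```

The cache grows linearly with both context length and batch size, so these per-sequence numbers multiply quickly at serving time.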

Exercises

  1. Output prediction. Without running the code, predict the output vector for "mat" (the last token) under Causal Attention (#3). Then verify it by printing row 4 of m03_causal() and m01_scaled_dot_product(), since the comparison table shows only the "cat" row. Explain why the two rows match.
  2. Memory calculation. A model has $d = 8192$, $H = 64$ heads, 128K context, and uses FP16. Calculate the per-layer KV-cache size in GB for a single sequence with: (a) MHA, (b) GQA with $G = 8$, (c) MQA, and (d) MLA with $d_c = 512$.
  3. Mechanism combination. Design an attention stack for a model that must (a) generate text autoregressively, (b) handle 256K context, (c) fit inference on a single A100 (80 GB). Which mechanisms would you combine and why?
  4. Differential advantage. Using the shared example, identify which token "cat" attends to most strongly under Differential Attention vs. Standard Attention. Calculate the ratio of the two max weights. Why does Differential Attention produce a sharper distribution?
  5. Code extension. Add a new method m16_hybrid to the AttentionComparison class that combines Causal masking + RoPE + Differential Attention. Run it on the shared example and compare the output to each individual mechanism.

References

  1. Vaswani, A., et al. (2017). "Attention Is All You Need." NeurIPS 2017. [Chapters 1, 2, 4]
  2. Radford, A., et al. (2018). "Improving Language Understanding by Generative Pre-Training." OpenAI. [Chapter 3]
  3. Shazeer, N. (2019). "Fast Transformer Decoding: One Write-Head is All You Need." arXiv:1911.02150. [Chapter 5]
  4. Ainslie, J., et al. (2023). "GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints." EMNLP 2023. [Chapter 6]
  5. Shaw, P., et al. (2018). "Self-Attention with Relative Position Representations." NAACL 2018. [Chapter 7]
  6. Raffel, C., et al. (2020). "Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer." JMLR 2020. [Chapter 7]
  7. Su, J., et al. (2021). "RoFormer: Enhanced Transformer with Rotary Position Embedding." arXiv:2104.09864. [Chapter 8]
  8. Press, O., et al. (2022). "Train Short, Test Long: Attention with Linear Biases Enables Input Length Extrapolation." ICLR 2022. [Chapter 9]
  9. Katharopoulos, A., et al. (2020). "Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention." ICML 2020. [Chapter 10]
  10. Beltagy, I., et al. (2020). "Longformer: The Long-Document Transformer." arXiv:2004.05150. [Chapter 11]
  11. Zaheer, M., et al. (2020). "Big Bird: Transformers for Longer Sequences." NeurIPS 2020. [Chapter 12]
  12. Dao, T., et al. (2022). "FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness." NeurIPS 2022. [Chapter 13]
  13. Ye, T., et al. (2024). "Differential Transformer." Microsoft Research, arXiv:2410.05258. [Chapter 14]
  14. DeepSeek-AI (2024). "DeepSeek-V2: A Strong, Economical, and Efficient Mixture-of-Experts Language Model." arXiv:2405.04434. [Chapter 15]

You have now seen the same five tokens from "The cat sat on the mat" (The, cat, sat, on, mat) processed 15 different ways. You understand not just what each mechanism computes, but why it was invented, what trade-off it makes, and when to reach for it in practice. The next time you encounter a transformer architecture that combines GQA + RoPE + Flash + Sliding Window, you will understand exactly what each piece contributes and why it was chosen.
