Radford, A., Narasimhan, K., Salimans, T., & Sutskever, I. (2018). “Improving Language Understanding by Generative Pre-Training.” OpenAI.
Learning Objectives
After completing this chapter, you will be able to:
Explain why standard bidirectional attention is illegal for language generation and how causal masking enforces the autoregressive constraint.
Write the mathematical formula for causal attention and describe the role of every symbol, including the mask function, −∞, and the lower-triangular structure.
Compute causal attention weights and outputs by hand for our shared example “The cat sat on the mat,” and contrast them with the standard (bidirectional) results from Chapter 1.
Implement causal attention from scratch in both NumPy and PyTorch, including the production-ready F.scaled_dot_product_attention API with is_causal=True.
Describe how causal masking interacts with Flash Attention, KV-cache optimization, and prefix-LM architectures in modern systems like GPT-4, Claude, and Gemini.
Where this appears: Every autoregressive language model in production today — GPT-4, Claude, Gemini, LLaMA, Mistral, Command R — uses causal masking. It is also the basis for code generation (Codex, Copilot), music generation (MusicLM), time series forecasting (TimesFM), and any task where the model must predict the next element in a sequence.
The Real Problem
The Autoregressive Constraint
In Chapter 1, we computed scaled dot-product attention where every token could attend to every other token — including tokens that appear later in the sequence. This bidirectional view is perfect for tasks like sentence classification or named entity recognition, where the model has access to the complete input before making predictions.
But language generation is fundamentally different. When you type a prompt and GPT generates a response, it produces tokens one at a time, left to right. At the moment it is generating the 50th token, the 51st token does not exist yet. The model cannot look into the future because the future has not been written.
This is the autoregressive constraint: at position i, the model may only condition on positions 0,1,…,i. The probability of a sequence is factored as:
$$P(x_1, x_2, \ldots, x_N) = \prod_{i=1}^{N} P(x_i \mid x_1, x_2, \ldots, x_{i-1})$$
Each token's probability depends only on the tokens that came before it. This is the chain rule of probability applied to sequential generation, and it is the mathematical foundation of every autoregressive language model.
The Cheating Problem
During training, the model sees the entire sequence at once (for parallelism). If we used standard bidirectional attention, token i could attend to token i+1 — the very token it is supposed to predict. This creates a devastating shortcut: the model learns to copy the answer from the future instead of learning the patterns of language.
Consider our example. When predicting what comes after “The cat,” the model should learn that a verb is likely. But if “cat” can attend to “sat” during training, the model simply learns to read off the answer. At inference time, “sat” does not exist yet, and the model collapses — it was trained to cheat, and now the cheat sheet is gone.
Think of it like a student taking an exam. If they can see the answer key during practice tests, they will ace every practice test. But when the real exam arrives without an answer key, they fail completely because they never learned the material — they only learned to copy. Causal masking is the exam proctor that hides the answer key during training, forcing the model to genuinely learn how language works.
The Core Insight: Causal masking does not add new computation. It removes information by blocking future positions. This single constraint is what makes autoregressive language generation possible. Without it, you cannot train a model to generate text.
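One way to see the leak concretely is to change a future token and check whether an earlier position's output moves. The sketch below uses random Q, K, V (hypothetical data, not the chapter's example), and the helper name `attention` is introduced here purely for illustration.

```python
import numpy as np

def attention(Q, K, V, causal):
    """Scaled dot-product attention; optionally blocks future positions."""
    N, d_k = Q.shape
    scores = Q @ K.T / np.sqrt(d_k)
    if causal:
        future = np.triu(np.ones((N, N), dtype=bool), k=1)  # True where j > i
        scores = np.where(future, -np.inf, scores)
    scores = scores - scores.max(axis=-1, keepdims=True)    # stable softmax
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ V

rng = np.random.default_rng(0)
Q, K, V = (rng.normal(size=(5, 4)) for _ in range(3))

# Perturb only the LAST token's key and value: a "future" token for rows 0-3.
K2, V2 = K.copy(), V.copy()
K2[-1] += 1.0
V2[-1] += 1.0

leak_bidirectional = not np.allclose(
    attention(Q, K, V, False)[:4], attention(Q, K2, V2, False)[:4])
leak_causal = not np.allclose(
    attention(Q, K, V, True)[:4], attention(Q, K2, V2, True)[:4])
print(leak_bidirectional, leak_causal)
```

With the mask off, the first four outputs shift when the last token changes; with the mask on, they are untouched.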
The Story Behind Causal Masking
GPT and the Decoder-Only Model
The idea of masking future tokens in self-attention was introduced in the original Transformer paper (Vaswani et al., 2017), where the decoder used “masked self-attention” to preserve the autoregressive property. But it was Alec Radford's GPT (2018) that demonstrated the full power of this approach as a standalone architecture.
Radford's key insight was radical in its simplicity: you do not need an encoder at all. While BERT (Devlin et al., 2018) achieved remarkable results with bidirectional attention and masked language modeling, GPT showed that a decoder-only model — using nothing but causal self-attention — could learn powerful language representations through the simple objective of next-token prediction.
The original GPT used 12 transformer decoder layers, each with causal self-attention. The training objective was straightforward: given a sequence of tokens x_1, …, x_N, maximize the log-likelihood:
$$L = \sum_{i=1}^{N} \log P(x_i \mid x_1, \ldots, x_{i-1}; \theta)$$
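As a rough illustration of this objective, the sketch below sums log P(x_i | x_<i) from a matrix of per-position logits. The function name, the three-token vocabulary, and the all-zero logits are hypothetical; in a real model, row i of the logits would come from a causal network that has seen only the tokens before position i.

```python
import numpy as np

def sequence_log_likelihood(logits, token_ids):
    """Sum of log P(x_i | x_<i) given per-position next-token logits.

    logits[i] is assumed to be the score vector used to predict token_ids[i],
    computed from earlier tokens only.
    """
    logits = logits - logits.max(axis=-1, keepdims=True)   # stable log-softmax
    log_probs = logits - np.log(np.exp(logits).sum(axis=-1, keepdims=True))
    return log_probs[np.arange(len(token_ids)), token_ids].sum()

# Hypothetical toy setup: vocabulary of 3 tokens, sequence of 4 tokens.
logits = np.zeros((4, 3))             # uniform prediction at every position
token_ids = np.array([0, 2, 1, 1])
ll = sequence_log_likelihood(logits, token_ids)
print(ll, 4 * np.log(1 / 3))          # the two values agree
```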
This was the birth of the paradigm that would scale to GPT-2 (1.5B parameters), GPT-3 (175B), GPT-4, and eventually Claude, Gemini, and LLaMA. Wherever the architecture has been published, the core recipe is the same: stacked layers of causal self-attention followed by feedforward networks. What changed was mainly scale, data, and training recipes.
The remarkable consequence is that the single most important mechanism in modern AI — the one powering every chatbot, code assistant, and reasoning engine — is not a complex invention. It is standard attention with a single modification: set the upper triangle to −∞ before softmax.
The Mathematical Definition
Causal attention is identical to the scaled dot-product attention from Chapter 1, with one additional step: a mask is applied to the score matrix before softmax.
$$\text{CausalAttn}(Q, K, V) = \text{softmax}\left(\text{mask}\left(\frac{QK^\top}{\sqrt{d_k}}\right)\right) V$$
where the mask function is defined as:
$$\text{mask}(S)_{ij} = \begin{cases} S_{ij} & \text{if } j \le i \text{ (past or present)} \\ -\infty & \text{if } j > i \text{ (future, blocked)} \end{cases}$$
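The mask function can equivalently be written as adding a matrix that is 0 where j ≤ i and −∞ where j > i, which is how many implementations apply it. A minimal sketch, with the helper name `additive_causal_mask` chosen here for illustration and a small 3×3 score matrix as input:

```python
import numpy as np

def additive_causal_mask(N):
    """(N, N) matrix: 0.0 where j <= i, -inf where j > i."""
    M = np.zeros((N, N))
    M[np.triu_indices(N, k=1)] = -np.inf   # strictly above the diagonal
    return M

S = np.array([[0.0, 1.0, 0.5],
              [1.5, 0.0, 1.0],
              [0.5, 1.0, 1.0]])            # example scaled scores

masked = S + additive_causal_mask(3)       # finite + 0 stays, finite + (-inf) = -inf
print(masked)
```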
Symbol-by-Symbol Breakdown
| Symbol | Shape | Meaning |
|---|---|---|
| Q | (N, d_k) | Query matrix. Each row is what a token is looking for. |
| K | (N, d_k) | Key matrix. Each row is what a token advertises. |
| V | (N, d_v) | Value matrix. Each row is the content to retrieve. |
| K⊤ | (d_k, N) | K transposed. Needed for the matrix product. |
| QK⊤ | (N, N) | Raw score matrix. Entry (i, j) = dot(Q_i, K_j). |
| √d_k | scalar | Scaling factor. Keeps softmax gradients healthy. |
| mask(⋅) | (N, N) → (N, N) | Sets upper triangle (j > i) to −∞. |
| −∞ | scalar | Negative infinity. exp(−∞) = 0 in softmax. |
| softmax | (N, N) → (N, N) | Row-wise normalization. Each row sums to 1. |
The Mask Matrix
The causal mask is a binary lower-triangular matrix. Entry (i,j) is 1 if token i is allowed to attend to token j, and 0 otherwise. The rule is simple: a token can see itself and everything before it.
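| | The | cat | sat | on | mat |
|---|---|---|---|---|---|
| The | 1 | 0 | 0 | 0 | 0 |
| cat | 1 | 1 | 0 | 0 | 0 |
| sat | 1 | 1 | 1 | 0 | 0 |
| on | 1 | 1 | 1 | 1 | 0 |
| mat | 1 | 1 | 1 | 1 | 1 |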
This is a lower-triangular matrix: everything on or below the diagonal is 1 (visible), and everything above the diagonal is 0 (blocked). In implementation, we represent the blocked positions with a boolean mask and set those positions to −∞ in the score matrix. The softmax then naturally assigns zero weight to those positions.
Notice that each successive row gains one more visible token. “The” can only see itself. “cat” can see “The” and itself. “mat” (the last token) can see the entire sequence — for the last token, causal attention is identical to standard attention.
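The visibility matrix above can be built in one call. The sketch below uses `np.tril` for the 1/0 form and derives the boolean "blocked" form from it; the variable names are illustrative.

```python
import numpy as np

tokens = ["The", "cat", "sat", "on", "mat"]
N = len(tokens)

visible = np.tril(np.ones((N, N), dtype=int))   # 1 = may attend, 0 = blocked
blocked = visible == 0                          # boolean form used to write -inf

print(visible)
for i, tok in enumerate(tokens):
    seen = [tokens[j] for j in range(N) if visible[i, j]]
    print(f"{tok!r} sees {seen}")
```

Row i of `visible` contains i + 1 ones, which is the "each successive row gains one more visible token" pattern.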
Step-by-Step Calculation
We use the same Q, K, V matrices from Chapter 1. The first two steps (dot products and scaling) are identical. Causal masking changes only Step 3 — everything else follows mechanically.
Step 1: Raw Dot Products (QK⊤)
Compute the 5×5 matrix of all pairwise dot products between queries and keys. This is identical to Chapter 1:
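| | The | cat | sat | on | mat |
|---|---|---|---|---|---|
| The | 0.00 | 2.00 | 1.00 | 1.00 | 1.50 |
| cat | 3.00 | 0.00 | 2.00 | 1.00 | 0.50 |
| sat | 1.00 | 2.00 | 2.00 | 1.00 | 1.50 |
| on | 1.00 | 1.00 | 0.00 | 2.00 | 1.00 |
| mat | 1.00 | 1.00 | 1.00 | 1.00 | 1.50 |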
Step 2: Scaling (÷ √d_k = ÷ 2.0)
Divide every entry by √4 = 2.0:
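| | The | cat | sat | on | mat |
|---|---|---|---|---|---|
| The | 0.000 | 1.000 | 0.500 | 0.500 | 0.750 |
| cat | 1.500 | 0.000 | 1.000 | 0.500 | 0.250 |
| sat | 0.500 | 1.000 | 1.000 | 0.500 | 0.750 |
| on | 0.500 | 0.500 | 0.000 | 1.000 | 0.500 |
| mat | 0.500 | 0.500 | 0.500 | 0.500 | 0.750 |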
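Steps 1 and 2 amount to two lines of NumPy. The sketch below uses the chapter's shared Q and K matrices and reproduces the raw and scaled score tables.

```python
import numpy as np

# Shared example matrices (rows: The, cat, sat, on, mat)
Q = np.array([[1, 0, 1, 0],
              [0, 2, 0, 1],
              [1, 1, 1, 0],
              [0, 0, 1, 1],
              [1, 0, 0, 1]], dtype=float)
K = np.array([[0, 1, 0, 1],
              [1, 0, 1, 0],
              [1, 1, 0, 0],
              [0, 0, 1, 1],
              [1, 0, 0.5, 0.5]], dtype=float)

raw = Q @ K.T                        # Step 1: all pairwise dot products, shape (5, 5)
scaled = raw / np.sqrt(Q.shape[1])   # Step 2: divide by sqrt(d_k) = 2.0

print(raw)
print(scaled)
```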
Step 3: Apply Causal Mask
This is the new step. Set every entry where j>i (future positions) to −∞:
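| | The | cat | sat | on | mat |
|---|---|---|---|---|---|
| The | 0.000 | −∞ | −∞ | −∞ | −∞ |
| cat | 1.500 | 0.000 | −∞ | −∞ | −∞ |
| sat | 0.500 | 1.000 | 1.000 | −∞ | −∞ |
| on | 0.500 | 0.500 | 0.000 | 1.000 | −∞ |
| mat | 0.500 | 0.500 | 0.500 | 0.500 | 0.750 |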
The lower triangle and diagonal retain their original scaled values. The upper triangle is now −∞. When we apply softmax next, exp(−∞)=0, so these positions will receive exactly zero attention weight.
Step 4: Softmax Over Visible Tokens Only
Softmax is applied row-by-row. Because the −∞ entries contribute zero to both numerator and denominator, the softmax effectively operates over only the visible tokens in each row:
| Token | Visible entries | Softmax computation | Weights |
|---|---|---|---|
| The | [0.000] | e^0 / e^0 = 1.0 | [1.0000, 0, 0, 0, 0] |
| cat | [1.500, 0.000] | e^1.5 / (e^1.5 + e^0) = 4.482 / 5.482 | [0.8176, 0.1824, 0, 0, 0] |
| sat | [0.500, 1.000, 1.000] | e^0.5 / (e^0.5 + e^1 + e^1) = 1.649 / 7.085 | [0.2327, 0.3837, 0.3837, 0, 0] |
| on | [0.500, 0.500, 0.000, 1.000] | e^1 / (e^0.5 + e^0.5 + e^0 + e^1) = 2.718 / 7.016 | [0.2350, 0.2350, 0.1425, 0.3875, 0] |
| mat | [0.500, 0.500, 0.500, 0.500, 0.750] | Full row (no masking for last token) | [0.1892, 0.1892, 0.1892, 0.1892, 0.2430] |
Key observations:
“The” gets weight 1.0 on itself because it is the only visible token. Its representation is entirely self-referential.
“cat” allocates 81.8% of its attention to “The” and only 18.2% to itself. This is because Q_cat · K_The has a high score (1.500 after scaling).
“mat” (the last token) sees the entire sequence, so its causal weights are identical to its standard attention weights from Chapter 1.
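The observations above follow directly if softmax is taken over each row's visible slice only, which is mathematically the same as filling the future with −∞. A minimal sketch starting from the scaled score matrix of Step 2:

```python
import numpy as np

scaled = np.array([
    [0.0, 1.0, 0.5, 0.5, 0.75],   # The
    [1.5, 0.0, 1.0, 0.5, 0.25],   # cat
    [0.5, 1.0, 1.0, 0.5, 0.75],   # sat
    [0.5, 0.5, 0.0, 1.0, 0.5],    # on
    [0.5, 0.5, 0.5, 0.5, 0.75],   # mat
])

N = scaled.shape[0]
weights = np.zeros_like(scaled)
for i in range(N):
    visible = scaled[i, : i + 1]            # positions 0..i only
    e = np.exp(visible - visible.max())     # stable exponentials
    weights[i, : i + 1] = e / e.sum()       # future positions stay exactly 0

print(np.round(weights, 4))
```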
Step 5: Weighted Sum of Values
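Each output row is a weighted sum of the visible value vectors:
$$O_i = \sum_{j \le i} A_{ij} V_j$$
For “cat”, this gives O_cat = 0.8176·V_The + 0.1824·V_cat = [0.8176, 0.1824, 0.0000, 0.0000].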
Notice that dimensions 2 and 3 of “cat”'s output are exactly zero. In standard attention (Chapter 1), they were 0.3018 and 0.2058 respectively. Those values came from “sat”, “on”, and “mat” through their value vectors, but causal masking blocks “cat” from seeing those tokens entirely.
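The weighted sum for “cat” can be checked in a few lines. The sketch below takes the rounded causal weights from Step 4 and the shared V matrix, and computes the output both as an explicit sum and as a vector-matrix product.

```python
import numpy as np

V = np.array([[1.0, 0.0, 0.0, 0.0],    # The
              [0.0, 1.0, 0.0, 0.0],    # cat
              [0.0, 0.0, 1.0, 0.0],    # sat
              [0.0, 0.0, 0.0, 1.0],    # on
              [0.5, 0.5, 0.5, 0.5]])   # mat

w_cat = np.array([0.8176, 0.1824, 0.0, 0.0, 0.0])   # causal weights for "cat"

o_explicit = sum(w * v for w, v in zip(w_cat, V))   # sum_j w_j * V_j
o_matmul = w_cat @ V                                # same thing as one product

print(o_explicit, o_matmul)
```

Because the weights on “sat”, “on”, and “mat” are zero, their value vectors contribute nothing, and dimensions 2 and 3 stay at zero.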
Worked Example: What “sat” Sees
Let us trace through the full calculation for “sat” (position 2), which can see “The” (position 0), “cat” (position 1), and itself (position 2).
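Q_sat = [1.0, 1.0, 1.0, 0.0]

| Pair | Dot product | Scaled | Status |
|---|---|---|---|
| Q_sat · K_The | 1×0 + 1×1 + 1×0 + 0×1 = 1 | 0.500 | VISIBLE |
| Q_sat · K_cat | 1×1 + 1×0 + 1×1 + 0×0 = 2 | 1.000 | VISIBLE |
| Q_sat · K_sat | 1×1 + 1×1 + 1×0 + 0×0 = 2 | 1.000 | VISIBLE |
| Q_sat · K_on | 1×0 + 1×0 + 1×1 + 0×1 = 1 | −∞ | BLOCKED |
| Q_sat · K_mat | 1×1 + 1×0 + 1×0.5 + 0×0.5 = 1.5 | −∞ | BLOCKED |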
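Softmax over the three visible entries [0.500, 1.000, 1.000]:
$$A_{\text{sat}} = \frac{[e^{0.5},\, e^{1},\, e^{1}]}{e^{0.5} + e^{1} + e^{1}} \approx \frac{[1.649,\, 2.718,\, 2.718]}{7.085} \approx [0.2327,\, 0.3837,\, 0.3837]$$
The output is the weighted sum of the three visible value vectors:
$$O_{\text{sat}} = 0.2327\,V_{\text{The}} + 0.3837\,V_{\text{cat}} + 0.3837\,V_{\text{sat}} = [0.2327,\, 0.3837,\, 0.3837,\, 0.0000]$$
Compare to standard attention (Chapter 1): O_sat^std = [0.2495, 0.3481, 0.3481, 0.2495]. Dimension 3 is 0.2495 in standard attention (contributed by “on” and “mat” through their value vectors) but exactly 0.0 in causal attention because those future tokens are blocked.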
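The comparison for “sat” can be reproduced with the shared Q, K, V matrices. This is a small sketch in which the local `softmax` helper is defined here only for illustration.

```python
import numpy as np

Q = np.array([[1, 0, 1, 0], [0, 2, 0, 1], [1, 1, 1, 0],
              [0, 0, 1, 1], [1, 0, 0, 1]], dtype=float)
K = np.array([[0, 1, 0, 1], [1, 0, 1, 0], [1, 1, 0, 0],
              [0, 0, 1, 1], [1, 0, 0.5, 0.5]], dtype=float)
V = np.array([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0],
              [0, 0, 0, 1], [0.5, 0.5, 0.5, 0.5]], dtype=float)

def softmax(x):
    e = np.exp(x - x.max())
    return e / e.sum()

i = 2                                            # "sat"
scores = Q[i] @ K.T / np.sqrt(Q.shape[1])        # [0.5, 1.0, 1.0, 0.5, 0.75]

w_std = softmax(scores)                          # attends to all five tokens
w_causal = np.zeros(5)
w_causal[: i + 1] = softmax(scores[: i + 1])     # only "The", "cat", "sat"

o_std = w_std @ V
o_causal = w_causal @ V
print(np.round(o_std, 4))
print(np.round(o_causal, 4))
```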
Full Attention Weights and Output
Causal Attention Weight Matrix (5×5)
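| | The | cat | sat | on | mat |
|---|---|---|---|---|---|
| The | 1.0000 | 0.0000 | 0.0000 | 0.0000 | 0.0000 |
| cat | 0.8176 | 0.1824 | 0.0000 | 0.0000 | 0.0000 |
| sat | 0.2327 | 0.3837 | 0.3837 | 0.0000 | 0.0000 |
| on | 0.2350 | 0.2350 | 0.1425 | 0.3875 | 0.0000 |
| mat | 0.1892 | 0.1892 | 0.1892 | 0.1892 | 0.2430 |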
Causal Output Matrix (5×4)
| | d0 | d1 | d2 | d3 |
|---|---|---|---|---|
| The | 1.0000 | 0.0000 | 0.0000 | 0.0000 |
| cat | 0.8176 | 0.1824 | 0.0000 | 0.0000 |
| sat | 0.2327 | 0.3837 | 0.3837 | 0.0000 |
| on | 0.2350 | 0.2350 | 0.1425 | 0.3875 |
| mat | 0.3108 | 0.3108 | 0.3108 | 0.3108 |
Applications Across Domains
Language Generation
Every modern large language model — GPT-4, Claude, Gemini, LLaMA, Mistral, Qwen — uses causal self-attention as its core mechanism. The entire forward pass of these models is a stack of causal attention layers followed by feedforward networks. When Claude generates a response to your prompt, each new token is produced by attending causally to all previous tokens.
Code Generation
Code is a sequence, and code generation is autoregressive. Tools like GitHub Copilot, Cursor, and Claude's code generation use causal attention to predict the next token of code given the preceding context. The causal structure naturally captures the sequential dependencies in code: a variable must be declared before it is used, a function must be defined before it is called, and an import must appear before the module is referenced.
Time Series Forecasting
In time series modeling, the future is literally unknown at prediction time. Models like TimesFM (Google, 2024) and Lag-Llama apply causal attention to temporal sequences, ensuring that predictions at time t depend only on observations at times 1,2,…,t−1. The causal mask prevents the temporal leakage that would make forecasting trivially (and falsely) accurate during training.
Music and Audio Generation
MusicLM (Google, 2023), Jukebox (OpenAI, 2020), and AudioLM use causal attention to generate audio tokens sequentially. Music has a natural temporal structure: a note at beat 16 should be influenced by the melody at beats 1–15, but not by what comes at beat 17. Causal masking enforces this temporal causality, producing musically coherent output.
Connection to Modern Systems
Flash Attention with Causal Masks
Flash Attention (Dao et al., 2022) computes attention using IO-aware tiling to minimize GPU memory transfers. It provides a dedicated causal mode (causal=True in the flash-attn package, reached through is_causal=True in PyTorch's F.scaled_dot_product_attention) that is faster than passing a general mask because it can skip entire tiles in the upper triangle, whose weights would be zero anyway. This makes causal attention cheaper than standard attention in practice: for long sequences, close to half of the score computations are never performed.
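To get a feel for how much work tile skipping saves, the sketch below counts tiles under a causal mask. It is a counting exercise only, not Flash Attention's implementation: it assumes square tiles with the same block size for queries and keys, and the function name and the sizes (N = 4096, block = 128) are hypothetical.

```python
import math

def causal_tile_counts(N, block):
    """Count (query-block, key-block) tiles by how the causal mask hits them."""
    n_blocks = math.ceil(N / block)
    skipped = partial = full = 0
    for bi in range(n_blocks):          # query block index
        for bj in range(n_blocks):      # key block index
            if bj > bi:
                skipped += 1            # every key lies in the future of every query
            elif bj == bi:
                partial += 1            # diagonal tile: mask applied inside the tile
            else:
                full += 1               # entirely visible, no masking needed
    return skipped, partial, full

skipped, partial, full = causal_tile_counts(4096, 128)
print(skipped, partial, full)
print(skipped / (skipped + partial + full))   # fraction of tiles never computed
```

With these sizes a little under half of the tiles are skipped; the fraction approaches one half as the number of blocks grows.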
KV-Cache and Causal Structure
The causal mask creates a critical property: the attention computation for position i depends only on keys and values at positions 0,…,i. This means that once we have computed K and V for a token, they never change — future tokens do not affect past computations.
This enables the KV-cache optimization: during autoregressive generation, we store the K and V tensors for all previously generated tokens. When generating token i+1, we only need to compute Q_{i+1}, K_{i+1}, and V_{i+1} for the new token and append the new key and value to the cache. The attention for the new token is computed against all cached keys and values. Without causal masking, the KV-cache would be invalid because earlier tokens' representations could change when new tokens are added.
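The equivalence between cached, one-token-at-a-time decoding and full causal attention can be checked numerically. The sketch below covers a single attention layer with random, precomputed per-token Q, K, V (hypothetical data); in a real model, each step would first project the new token's hidden state to obtain them.

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def causal_attention(Q, K, V):
    """Full-sequence causal attention, used here as the reference."""
    N, d_k = Q.shape
    scores = Q @ K.T / np.sqrt(d_k)
    future = np.triu(np.ones((N, N), dtype=bool), k=1)
    return softmax(np.where(future, -np.inf, scores)) @ V

rng = np.random.default_rng(0)
Q, K, V = (rng.normal(size=(5, 4)) for _ in range(3))

k_cache, v_cache, outputs = [], [], []
for t in range(5):                       # one decoding step per token
    k_cache.append(K[t])                 # append only the new key and value
    v_cache.append(V[t])
    Kc, Vc = np.stack(k_cache), np.stack(v_cache)
    w = softmax(Q[t] @ Kc.T / np.sqrt(Q.shape[1]))   # new query vs. cached keys
    outputs.append(w @ Vc)

incremental = np.stack(outputs)
full = causal_attention(Q, K, V)
print(np.allclose(incremental, full))    # True
```

No mask is needed inside the loop: at step t the cache simply does not contain any later tokens.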
Prefix-LM: A Hybrid Approach
Some models (e.g., T5, UL2, PaLM 2 in certain modes) use a prefix-LM strategy: the prompt (prefix) uses bidirectional attention (no mask), while the generated continuation uses causal attention. This gives the model full context understanding of the input while maintaining the autoregressive constraint for generation. Mathematically, the mask becomes:
$$M_{ij} = \begin{cases} 1 & \text{if } j < P \text{ or } j \le i \\ 0 & \text{otherwise} \end{cases}$$
where P is the length of the prefix. This is a strictly more general mask than pure causal — it reduces to standard attention when P=N and to pure causal when P=0.
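A prefix-LM mask can be sketched with two broadcasted index arrays. The code assumes the usual convention that every query may attend to any key inside the prefix (j < P) and otherwise only to keys at or before its own position; the function name is illustrative.

```python
import numpy as np

def prefix_lm_mask(N, P):
    """1 where query i may attend to key j: key inside the prefix, or j <= i."""
    i = np.arange(N)[:, None]   # query positions as a column
    j = np.arange(N)[None, :]   # key positions as a row
    return ((j < P) | (j <= i)).astype(int)

print(prefix_lm_mask(5, 2))
# [[1 1 0 0 0]
#  [1 1 0 0 0]
#  [1 1 1 0 0]
#  [1 1 1 1 0]
#  [1 1 1 1 1]]
```

Setting P = 0 recovers the lower-triangular causal mask, and P = N gives the all-ones mask of standard attention.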
Complexity Analysis
| Metric | Standard Attention | Causal Attention |
|---|---|---|
| Time complexity | O(N²d) | O(N²d) (same) |
| Memory complexity | O(N²) | O(N²) (same) |
| Actual FLOPs | 100% | ~50% (upper triangle skipped by Flash Attention) |
| KV-cache compatible | No | Yes |
| Extra parameters | 0 | 0 (mask is computed, not learned) |
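The asymptotic complexity is unchanged because roughly half of the N² score entries (the lower triangle and the diagonal) must still be computed, and a constant factor of one half does not change the O(N²d) bound. In practice, Flash Attention's causal mode can approach a 2x speedup on the attention computation for long sequences by never computing the fully masked tiles. The mask itself requires zero extra parameters: it is purely a function of the sequence positions.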
Python Implementation
The complete NumPy implementation below includes the CausalAttention class with the masking and softmax steps separated into individual methods, plus a detailed explain() method that traces the computation for any query token. The annotations that follow walk through the code and show the exact values flowing through memory at each point.
Causal Attention: NumPy Implementation (causal_attention.py)
import numpy as np
NumPy provides vectorized matrix operations. Q @ K.T runs as optimized C code, not Python loops.
import math
Standard library math module. We use math.sqrt() for the scaling factor.
class CausalAttention
A self-contained class implementing causal (masked) self-attention. This is the attention mechanism used in GPT, Claude, Gemini, and all autoregressive language models. The only difference from standard attention is the mask applied before softmax.
def __init__(self, d_k: int)
Constructor takes d_k (dimension of Q/K vectors). Pre-computes the scaling factor sqrt(d_k) once.
EXECUTION STATE
⬇ input: d_k = 4 (dimension of each query/key vector)
⬆ sets = self.d_k = 4, self.scale = sqrt(4) = 2.0
self.d_k = d_k
Store d_k for later reference.
self.scale = math.sqrt(d_k)
Pre-compute sqrt(d_k) = sqrt(4) = 2.0. Dividing scores by this prevents softmax saturation.
EXECUTION STATE
self.scale = 2.0
def _softmax(self, x) → np.ndarray
Numerically stable softmax that correctly handles -inf values (from the causal mask). Operates row-by-row: each row sums to 1.0.
EXECUTION STATE
⬇ input: x (5×5) =
The cat sat on mat
The 0.000 -inf -inf -inf -inf
cat 1.500 0.000 -inf -inf -inf
sat 0.500 1.000 1.000 -inf -inf
on 0.500 0.500 0.000 1.000 -inf
mat 0.500 0.500 0.500 0.500 0.750
⬆ returns = np.ndarray (5,5) — attention weights, each row sums to 1.0
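x_safe = np.where(np.isfinite(x), x, -1e9)
Replace -inf with a very large negative number (-1e9). np.isfinite() returns False for -inf, inf, and NaN. This guards against a row that is entirely -inf, where subtracting the row maximum would give -inf − (-inf) = NaN. With a causal mask the diagonal is always finite, so this is a safety measure rather than a necessity.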
EXECUTION STATE
np.isfinite(x) = True for real numbers, False for -inf entries
-1e9 = -1000000000.0 — large enough that exp(-1e9) ≈ 0
x_shifted = x_safe - np.max(x_safe, axis=-1, keepdims=True)
Subtract the row-wise maximum from each element. This is the numerical stability trick: exp(x - max) avoids overflow while producing identical softmax results.
EXECUTION STATE
axis=-1 = operate along columns (within each row). For a (5,5) matrix, find the max of each row independently.
keepdims=True = keep the reduced axis as size-1 dimension. Returns shape (5,1) not (5,), so broadcasting works: x(5×5) - max(5×1).
exp_x = np.exp(x_shifted)
Element-wise exponentiation. Entries from masked positions have x_shifted ≈ -1e9, so exp(-1e9) ≈ 0. The max entry in each row has x_shifted = 0, so exp(0) = 1.
The cat sat on mat
The 1.0000 0.0000 0.0000 0.0000 0.0000
cat 0.8176 0.1824 0.0000 0.0000 0.0000
sat 0.2327 0.3837 0.3837 0.0000 0.0000
on 0.2350 0.2350 0.1425 0.3875 0.0000
mat 0.1892 0.1892 0.1892 0.1892 0.2430
def build_causal_mask(self, N) → np.ndarray
Builds the upper-triangular boolean mask. True means the position is blocked (future token). This is the ONLY difference from standard attention.
EXECUTION STATE
⬇ input: N = 5 (number of tokens)
⬆ returns = np.ndarray (5,5) of bool — True where j > i
return np.triu(np.ones((N, N), dtype=bool), k=1)
np.triu extracts the upper triangle. k=1 means start 1 diagonal above the main diagonal (exclude the diagonal itself, since a token should see itself).
EXECUTION STATE
np.triu(..., k=1) = k=0 would include diagonal. k=1 starts ABOVE diagonal, so diagonal (self-attention) is NOT masked.
⬆ return: mask (5×5) =
The cat sat on mat
The False True True True True
cat False False True True True
sat False False False True True
on False False False False True
mat False False False False False
def apply_mask(self, scores, mask) → np.ndarray
Sets all True (blocked) positions to -inf. After softmax, exp(-inf) = 0, giving these positions exactly zero attention weight.
EXECUTION STATE
⬇ input: scores (5×5) =
The cat sat on mat
The 0.000 1.000 0.500 0.500 0.750
cat 1.500 0.000 1.000 0.500 0.250
sat 0.500 1.000 1.000 0.500 0.750
on 0.500 0.500 0.000 1.000 0.500
mat 0.500 0.500 0.500 0.500 0.750
⬆ returns = np.ndarray (5,5) with -inf in upper triangle
masked_scores = scores.copy()
Make a copy to avoid modifying the input array in-place. The original scores are preserved.
masked_scores[mask] = -np.inf
Boolean indexing: everywhere mask is True (j > i), set the score to -inf. This is the core operation of causal masking.
EXECUTION STATE
masked_scores (5×5) =
The cat sat on mat
The 0.000 -inf -inf -inf -inf
cat 1.500 0.000 -inf -inf -inf
sat 0.500 1.000 1.000 -inf -inf
on 0.500 0.500 0.000 1.000 -inf
mat 0.500 0.500 0.500 0.500 0.750
return masked_scores
Return the masked score matrix. The lower triangle and diagonal retain their original values.
def forward(self, Q, K, V)
Full forward pass: computes raw scores, scales, masks, applies softmax, and computes the weighted sum.
EXECUTION STATE
⬇ input: Q (5×4) =
d0 d1 d2 d3
The 1.0 0.0 1.0 0.0
cat 0.0 2.0 0.0 1.0
sat 1.0 1.0 1.0 0.0
on 0.0 0.0 1.0 1.0
mat 1.0 0.0 0.0 1.0
⬇ input: K (5×4) =
d0 d1 d2 d3
The 0.0 1.0 0.0 1.0
cat 1.0 0.0 1.0 0.0
sat 1.0 1.0 0.0 0.0
on 0.0 0.0 1.0 1.0
mat 1.0 0.0 0.5 0.5
⬇ input: V (5×4) =
d0 d1 d2 d3
The 1.0 0.0 0.0 0.0
cat 0.0 1.0 0.0 0.0
sat 0.0 0.0 1.0 0.0
on 0.0 0.0 0.0 1.0
mat 0.5 0.5 0.5 0.5
⬆ returns = (weights (5×5), output (5×4))
N = Q.shape[0]
Number of tokens in the sequence.
EXECUTION STATE
N = 5
raw_scores = Q @ K.T
Matrix multiply Q (5×4) with K transposed (4×5). Each entry [i,j] is the dot product of query i with key j — how well token i's question matches token j's advertisement.
EXECUTION STATE
K.T = K transposed from (5×4) to (4×5)
raw_scores (5×5) =
The cat sat on mat
The 0.00 2.00 1.00 1.00 1.50
cat 3.00 0.00 2.00 1.00 0.50
sat 1.00 2.00 2.00 1.00 1.50
on 1.00 1.00 0.00 2.00 1.00
mat 1.00 1.00 1.00 1.00 1.50
scaled_scores = raw_scores / self.scale
Divide every score by sqrt(d_k) = 2.0. This keeps the variance of scores at ~1, preventing softmax from saturating.
EXECUTION STATE
self.scale = 2.0
scaled_scores (5×5) =
The cat sat on mat
The 0.000 1.000 0.500 0.500 0.750
cat 1.500 0.000 1.000 0.500 0.250
sat 0.500 1.000 1.000 0.500 0.750
on 0.500 0.500 0.000 1.000 0.500
mat 0.500 0.500 0.500 0.500 0.750
mask = self.build_causal_mask(N)
Build the 5×5 boolean mask. True in upper triangle (j > i), False elsewhere.
EXECUTION STATE
mask (5×5) =
The cat sat on mat
The False True True True True
cat False False True True True
sat False False False True True
on False False False False True
mat False False False False False
masked_scores = self.apply_mask(scaled_scores, mask)
Set all upper-triangle entries to -inf. This is the ONLY step that differs from Chapter 1 (standard attention).
EXECUTION STATE
masked_scores (5×5) =
The cat sat on mat
The 0.000 -inf -inf -inf -inf
cat 1.500 0.000 -inf -inf -inf
sat 0.500 1.000 1.000 -inf -inf
on 0.500 0.500 0.000 1.000 -inf
mat 0.500 0.500 0.500 0.500 0.750
weights = self._softmax(masked_scores)
Softmax over the masked scores. Blocked positions get weight 0.0000. Each row sums to 1.0.
EXECUTION STATE
weights (5×5) =
The cat sat on mat
The 1.0000 0.0000 0.0000 0.0000 0.0000
cat 0.8176 0.1824 0.0000 0.0000 0.0000
sat 0.2327 0.3837 0.3837 0.0000 0.0000
on 0.2350 0.2350 0.1425 0.3875 0.0000
mat 0.1892 0.1892 0.1892 0.1892 0.2430
output = weights @ V
Multiply weights (5×5) by V (5×4). Each output row is a weighted sum of value vectors, using ONLY values from visible (past + present) tokens.
EXECUTION STATE
output (5×4) =
d0 d1 d2 d3
The 1.0000 0.0000 0.0000 0.0000
cat 0.8176 0.1824 0.0000 0.0000
sat 0.2327 0.3837 0.3837 0.0000
on 0.2350 0.2350 0.1425 0.3875
mat 0.3108 0.3108 0.3108 0.3108
return weights, output
Return both the attention weights (for visualization/debugging) and the output (fed to the next layer).
def explain(self, Q, K, V, tokens, query_idx=0)
Utility method that prints a human-readable trace of the causal attention computation for one specific query token. Used for debugging and education.
tokens = ["The", "cat", "sat", "on", "mat"]
The shared example sentence used throughout all 15 chapters. 5 tokens with d_k = 4.
EXECUTION STATE
tokens = ["The", "cat", "sat", "on", "mat"]
N = 5 tokens
Q = np.array([...])
Query matrix (5×4). Each row is what that token is 'asking for' — the features it wants to find in the sequence.
EXECUTION STATE
Q (5×4) =
d0 d1 d2 d3
The 1.0 0.0 1.0 0.0
cat 0.0 2.0 0.0 1.0
sat 1.0 1.0 1.0 0.0
on 0.0 0.0 1.0 1.0
mat 1.0 0.0 0.0 1.0
K = np.array([...])
Key matrix (5×4). Each row is what that token 'advertises' about itself — the features it exposes for matching.
EXECUTION STATE
K (5×4) =
d0 d1 d2 d3
The 0.0 1.0 0.0 1.0
cat 1.0 0.0 1.0 0.0
sat 1.0 1.0 0.0 0.0
on 0.0 0.0 1.0 1.0
mat 1.0 0.0 0.5 0.5
V = np.array([...])
Value matrix (5×4). Each row is the content that token contributes when selected by attention. Q and K decide WHO to attend to; V decides WHAT to retrieve.
EXECUTION STATE
V (5×4) =
d0 d1 d2 d3
The 1.0 0.0 0.0 0.0
cat 0.0 1.0 0.0 0.0
sat 0.0 0.0 1.0 0.0
on 0.0 0.0 0.0 1.0
mat 0.5 0.5 0.5 0.5
attn = CausalAttention(d_k=4)
Instantiate the causal attention module with d_k=4. This sets self.scale = sqrt(4) = 2.0.
EXECUTION STATE
attn.d_k = 4
attn.scale = 2.0
weights, output = attn.forward(Q, K, V)
Run the full causal attention forward pass. Returns the 5×5 weight matrix and 5×4 output matrix.
EXECUTION STATE
weights (5×5) =
The cat sat on mat
The 1.0000 0.0000 0.0000 0.0000 0.0000
cat 0.8176 0.1824 0.0000 0.0000 0.0000
sat 0.2327 0.3837 0.3837 0.0000 0.0000
on 0.2350 0.2350 0.1425 0.3875 0.0000
mat 0.1892 0.1892 0.1892 0.1892 0.2430
output (5×4) =
d0 d1 d2 d3
The 1.0000 0.0000 0.0000 0.0000
cat 0.8176 0.1824 0.0000 0.0000
sat 0.2327 0.3837 0.3837 0.0000
on 0.2350 0.2350 0.1425 0.3875
mat 0.3108 0.3108 0.3108 0.3108
attn.explain(Q, K, V, tokens, query_idx=2)
Print a detailed trace for 'sat' (row 2). This shows all 5 steps of the causal attention computation for this specific token.
EXECUTION STATE
query_idx = 2 → 'sat'
visible tokens = ['The', 'cat', 'sat']
blocked tokens = ['on', 'mat']
import numpy as np
import math


class CausalAttention:
    """
    Causal (Masked) Self-Attention (Radford et al., 2018)

    Attention(Q, K, V) = softmax(mask(Q @ K^T / sqrt(d_k))) @ V

    The upper triangle of the score matrix is set to -inf
    before softmax, preventing tokens from attending to
    future positions in the sequence.
    """

    def __init__(self, d_k: int):
        """
        Args:
            d_k: Dimension of query/key vectors (used for scaling)
        """
        self.d_k = d_k
        self.scale = math.sqrt(d_k)

    def _softmax(self, x: np.ndarray) -> np.ndarray:
        """Numerically stable softmax that handles -inf values."""
        x_safe = np.where(np.isfinite(x), x, -1e9)
        x_shifted = x_safe - np.max(x_safe, axis=-1, keepdims=True)
        exp_x = np.exp(x_shifted)
        return exp_x / np.sum(exp_x, axis=-1, keepdims=True)

    def build_causal_mask(self, N: int) -> np.ndarray:
        """Build upper-triangular boolean mask (True = blocked)."""
        return np.triu(np.ones((N, N), dtype=bool), k=1)

    def apply_mask(self, scores: np.ndarray, mask: np.ndarray) -> np.ndarray:
        """Set masked (future) positions to -inf."""
        masked_scores = scores.copy()
        masked_scores[mask] = -np.inf
        return masked_scores

    def forward(self, Q: np.ndarray, K: np.ndarray, V: np.ndarray):
        """
        Full forward pass of causal attention.

        Args:
            Q: Query matrix (N, d_k)
            K: Key matrix (N, d_k)
            V: Value matrix (N, d_v)

        Returns:
            weights: Causal attention weight matrix (N, N)
            output: Context-enriched output (N, d_v)
        """
        N = Q.shape[0]
        raw_scores = Q @ K.T
        scaled_scores = raw_scores / self.scale
        mask = self.build_causal_mask(N)
        masked_scores = self.apply_mask(scaled_scores, mask)
        weights = self._softmax(masked_scores)
        output = weights @ V
        return weights, output

    def explain(self, Q, K, V, tokens, query_idx=0):
        """Print a detailed trace for a specific query token."""
        N = Q.shape[0]
        raw = Q @ K.T
        scaled = raw / self.scale
        mask = self.build_causal_mask(N)
        masked = self.apply_mask(scaled, mask)
        weights = self._softmax(masked)
        output = weights @ V

        t = tokens[query_idx]
        print(f"\n=== Causal attention trace for '{t}' (row {query_idx}) ===")
        print(f"Q[{query_idx}] = {Q[query_idx]}")
        print(f"\nVisible tokens: {tokens[:query_idx+1]}")
        print(f"Blocked tokens: {tokens[query_idx+1:]}")

        print(f"\nStep 1-2: Raw scores -> scaled (/ {self.scale:.1f}):")
        for j, tk in enumerate(tokens):
            s = "BLOCKED" if j > query_idx else f"{scaled[query_idx, j]:.4f}"
            print(f"  S[{t},{tk}] = {s}")

        print(f"\nStep 3-4: Softmax weights:")
        for j, tk in enumerate(tokens):
            w = weights[query_idx, j]
            bar = '#' * int(w * 40)
            print(f"  A[{t},{tk}] = {w:.4f} |{bar}|")

        print(f"\nStep 5: Output = weighted sum of V:")
        print(f"  O[{t}] = {output[query_idx]}")


# ── Shared Example (same Q, K, V as every chapter) ──
tokens = ["The", "cat", "sat", "on", "mat"]

Q = np.array([
    [1.0, 0.0, 1.0, 0.0],  # The
    [0.0, 2.0, 0.0, 1.0],  # cat
    [1.0, 1.0, 1.0, 0.0],  # sat
    [0.0, 0.0, 1.0, 1.0],  # on
    [1.0, 0.0, 0.0, 1.0],  # mat
])

K = np.array([
    [0.0, 1.0, 0.0, 1.0],  # The
    [1.0, 0.0, 1.0, 0.0],  # cat
    [1.0, 1.0, 0.0, 0.0],  # sat
    [0.0, 0.0, 1.0, 1.0],  # on
    [1.0, 0.0, 0.5, 0.5],  # mat
])

V = np.array([
    [1.0, 0.0, 0.0, 0.0],  # The
    [0.0, 1.0, 0.0, 0.0],  # cat
    [0.0, 0.0, 1.0, 0.0],  # sat
    [0.0, 0.0, 0.0, 1.0],  # on
    [0.5, 0.5, 0.5, 0.5],  # mat
])

# ── Run ──
attn = CausalAttention(d_k=4)
weights, output = attn.forward(Q, K, V)

print("Causal Attention Weights (5x5):")
print(np.round(weights, 4))

print("\nCausal Output (5x4):")
print(np.round(output, 4))

# Detailed trace for "sat"
attn.explain(Q, K, V, tokens, query_idx=2)
PyTorch Implementation
The PyTorch implementation below includes both a trainable module (with learned W_Q, W_K, W_V projections) and a static method for our manual simulation. The final section demonstrates the production-ready F.scaled_dot_product_attention API with is_causal=True, which can dispatch to a Flash Attention kernel when one is available.
Causal Attention: PyTorch Implementation (causal_attention_torch.py)
import torch
PyTorch provides GPU-accelerated tensor operations and autograd for backpropagation.
class CausalAttention(nn.Module)
PyTorch module for causal self-attention. Inherits nn.Module for parameter tracking and GPU support.
def __init__(self, d_model, d_k)
Initialize with d_model (input dimension) and d_k (projection dimension). Creates three learned projection matrices.
EXECUTION STATE
⬇ input: d_model = embedding dimension (e.g. 512)
⬇ input: d_k = 4 (query/key dimension)
self.d_k = d_k
Store d_k for scaling.
self.W_Q = nn.Linear(d_model, d_k, bias=False)
Learned query projection. Transforms input from d_model to d_k dimensions. bias=False because the original transformer omits bias in attention projections.
EXECUTION STATE
W_Q.weight = shape (d_k, d_model) — learned during training
self.W_K = nn.Linear(d_model, d_k, bias=False)
Learned key projection. Same shape as W_Q.
self.W_V = nn.Linear(d_model, d_k, bias=False)
Learned value projection. Transforms inputs into the 'content' representation.
def forward(self, x) → tuple
Forward pass with learned projections. Takes raw embeddings x, projects to Q/K/V, and applies causal attention.
EXECUTION STATE
⬇ input: x = torch.Tensor shape (N, d_model)
N = x.size(0)
Number of tokens in the sequence.
EXECUTION STATE
N = 5
Q = self.W_Q(x)
Project x through W_Q to get queries. Shape: (N, d_k).
K = self.W_K(x)
Project x through W_K to get keys. Shape: (N, d_k).
V = self.W_V(x)
Project x through W_V to get values. Shape: (N, d_k).
Scaled dot-product scores. K.transpose(-2, -1) swaps the last two dims. Divide by sqrt(d_k) = 2.0.
EXECUTION STATE
.transpose(-2, -1) = swap last two dimensions: (N, d_k) → (d_k, N). Equivalent to .T for 2D tensors.
30mask = torch.triu(..., diagonal=1)
Upper-triangular boolean mask. diagonal=1 means start 1 above the main diagonal, so the diagonal itself (self-attention) is NOT masked.
EXECUTION STATE
diagonal=1 = equivalent to NumPy k=1. Start masking 1 position above the main diagonal.
mask (5×5) =
       The     cat     sat     on      mat
The    False   True    True    True    True
cat    False   False   True    True    True
sat    False   False   False   True    True
on     False   False   False   False   True
mat    False   False   False   False   False
scores (5×5), scaled, before masking =
       The     cat     sat     on      mat
The    0.000   1.000   0.500   0.500   0.750
cat    1.500   0.000   1.000   0.500   0.250
sat    0.500   1.000   1.000   0.500   0.750
on     0.500   0.500   0.000   1.000   0.500
mat    0.500   0.500   0.500   0.500   0.750
weights (5×5), after masking and softmax =
       The     cat     sat     on      mat
The    1.0000  0.0000  0.0000  0.0000  0.0000
cat    0.8176  0.1824  0.0000  0.0000  0.0000
sat    0.2327  0.3837  0.3837  0.0000  0.0000
on     0.2350  0.2350  0.1425  0.3875  0.0000
mat    0.1892  0.1892  0.1892  0.1892  0.2430
output = weights @ V
Final output: weighted sum of visible value vectors.
EXECUTION STATE
output (5×4) =
       d0      d1      d2      d3
The    1.0000  0.0000  0.0000  0.0000
cat    0.8176  0.1824  0.0000  0.0000
sat    0.2327  0.3837  0.3837  0.0000
on     0.2350  0.2350  0.1425  0.3875
mat    0.3108  0.3108  0.3108  0.3108
return output, weights
Return (output, weights) tuple.
tokens = ["The", "cat", "sat", "on", "mat"]
Shared example sentence.
Q = torch.tensor([...])
Query matrix as a PyTorch tensor. Same values as the NumPy version.
F.scaled_dot_product_attention(..., is_causal=True)
PyTorch 2.0+ built-in attention with is_causal=True. This is the recommended production API: it fuses the operations and can use Flash Attention when available.
is_causal=True = automatically applies causal mask internally. No need to build/pass the mask manually.
.squeeze(0) = remove the batch dimension: (1,5,4) → (5,4).
import torch
import torch.nn as nn
import torch.nn.functional as F
import math


class CausalAttention(nn.Module):
    """
    Causal (Masked) Self-Attention in PyTorch.

    Supports both:
      1. Learnable projections (W_Q, W_K, W_V) for training
      2. Pre-computed Q, K, V for manual simulation
    """

    def __init__(self, d_model: int, d_k: int):
        super().__init__()
        self.d_k = d_k
        self.W_Q = nn.Linear(d_model, d_k, bias=False)
        self.W_K = nn.Linear(d_model, d_k, bias=False)
        self.W_V = nn.Linear(d_model, d_k, bias=False)

    def forward(self, x: torch.Tensor) -> tuple:
        """Forward pass with learned projections. Expects x of shape (N, d_model)."""
        N = x.size(0)
        Q = self.W_Q(x)
        K = self.W_K(x)
        V = self.W_V(x)
        scores = Q @ K.transpose(-2, -1) / math.sqrt(self.d_k)
        # True above the diagonal marks future positions to block.
        mask = torch.triu(
            torch.ones(N, N, dtype=torch.bool, device=x.device),
            diagonal=1
        )
        scores = scores.masked_fill(mask, float('-inf'))
        weights = F.softmax(scores, dim=-1)
        output = weights @ V
        return output, weights

    @staticmethod
    def manual(Q, K, V):
        """Run with pre-computed Q, K, V (no learned weights)."""
        d_k = K.size(-1)
        N = Q.size(0)
        scores = Q @ K.transpose(-2, -1) / math.sqrt(d_k)
        # Build the mask on the same device as Q so this also works on GPU.
        mask = torch.triu(
            torch.ones(N, N, dtype=torch.bool, device=Q.device), diagonal=1
        )
        scores = scores.masked_fill(mask, float('-inf'))
        weights = F.softmax(scores, dim=-1)
        output = weights @ V
        return output, weights


# ── Manual simulation with shared example ──
tokens = ["The", "cat", "sat", "on", "mat"]

Q = torch.tensor([
    [1.0, 0.0, 1.0, 0.0],
    [0.0, 2.0, 0.0, 1.0],
    [1.0, 1.0, 1.0, 0.0],
    [0.0, 0.0, 1.0, 1.0],
    [1.0, 0.0, 0.0, 1.0],
])

K = torch.tensor([
    [0.0, 1.0, 0.0, 1.0],
    [1.0, 0.0, 1.0, 0.0],
    [1.0, 1.0, 0.0, 0.0],
    [0.0, 0.0, 1.0, 1.0],
    [1.0, 0.0, 0.5, 0.5],
])

V = torch.tensor([
    [1.0, 0.0, 0.0, 0.0],
    [0.0, 1.0, 0.0, 0.0],
    [0.0, 0.0, 1.0, 0.0],
    [0.0, 0.0, 0.0, 1.0],
    [0.5, 0.5, 0.5, 0.5],
])

output, weights = CausalAttention.manual(Q, K, V)
print("Causal Weights:\n", weights.round(decimals=4))
print("Causal Output:\n", output.round(decimals=4))

# Using PyTorch built-in (recommended for production)
with torch.no_grad():
    builtin_output = F.scaled_dot_product_attention(
        Q.unsqueeze(0), K.unsqueeze(0), V.unsqueeze(0),
        is_causal=True
    ).squeeze(0)
print("\nBuilt-in Output:\n", builtin_output.round(decimals=4))
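One property worth checking empirically is that the mask really makes earlier positions independent of later ones. The following is a minimal NumPy sketch, not part of the listing above: it recomputes the masked attention for a single unbatched sequence, then overwrites the key and value of the last token, "mat", and confirms that the outputs for "The" through "on" do not move. The helper name causal_attention and the perturbation values are illustrative assumptions.

```python
import numpy as np


def causal_attention(Q, K, V):
    """Causal scaled dot-product attention for one unbatched sequence."""
    n, d_k = Q.shape
    scores = Q @ K.T / np.sqrt(d_k)
    # True above the diagonal marks future positions.
    future = np.triu(np.ones((n, n), dtype=bool), k=1)
    scores = np.where(future, -np.inf, scores)
    # Subtract the row max for numerical stability; exp(-inf) is exactly 0.
    scores = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    return weights @ V, weights


# Same Q, K, V as the shared example.
Q = np.array([[1, 0, 1, 0], [0, 2, 0, 1], [1, 1, 1, 0],
              [0, 0, 1, 1], [1, 0, 0, 1]], dtype=float)
K = np.array([[0, 1, 0, 1], [1, 0, 1, 0], [1, 1, 0, 0],
              [0, 0, 1, 1], [1, 0, 0.5, 0.5]], dtype=float)
V = np.array([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0],
              [0, 0, 0, 1], [0.5, 0.5, 0.5, 0.5]], dtype=float)

out, w = causal_attention(Q, K, V)

# Perturb only the last token's key and value (arbitrary values).
K2, V2 = K.copy(), V.copy()
K2[-1] = [9.0, -3.0, 2.0, 7.0]
V2[-1] = [5.0, 5.0, 5.0, 5.0]
out2, _ = causal_attention(Q, K2, V2)

earlier_unchanged = np.allclose(out[:-1], out2[:-1])
last_changed = not np.allclose(out[-1], out2[-1])
print("earlier rows unchanged:", earlier_unchanged)
print("last row changed:", last_changed)
```

Only the last row changes, because position 4 is the only query that is allowed to see key and value 4.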
Key Takeaways
Causal masking is the only difference from standard attention. It sets the upper triangle of the score matrix to −∞ before softmax. No new parameters, no new computation — just information removal.
It enforces the autoregressive constraint. Token i can only attend to tokens 0,1,…,i. This prevents the model from “cheating” by looking at future tokens during training.
Training and inference are consistent. Because the mask blocks future tokens during training (where the full sequence is available), the model learns to operate the same way it must at inference time (where future tokens do not exist).
The last token sees everything. For the last position in the sequence, causal attention is identical to standard attention. Earlier tokens are progressively more constrained.
It enables the KV-cache. Because past computations never change, keys and values can be cached and reused during generation, turning each step from O(N²) to O(N).
Flash Attention makes it faster, not slower. Despite the mask, Flash Attention's causal mode skips the fully masked part of the upper triangle, so it performs roughly 50% fewer attention FLOPs than unmasked attention.
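The cost claims in the last two takeaways can be made concrete by counting query-key pairs. The sketch below is a rough proxy, assuming cost is proportional to the number of score entries computed; it ignores the d_k factor, the projection cost, and kernel tiling. The helper names score_entries and generation_pairs are hypothetical.

```python
def score_entries(n: int, causal: bool) -> int:
    """Query-key pairs scored for a length-n sequence in one full pass."""
    # Causal attention keeps the diagonal and everything below it.
    return n * (n + 1) // 2 if causal else n * n


def generation_pairs(n: int, use_cache: bool) -> int:
    """Query-key pairs scored to produce the output at sequence length n."""
    if use_cache:
        # Only the newest query is scored, against the n cached keys.
        return n
    # Without a cache, the whole causal score matrix is recomputed.
    return score_entries(n, causal=True)


for n in (5, 512, 4096):
    ratio = score_entries(n, causal=True) / score_entries(n, causal=False)
    print(
        f"N={n}: causal/full = {ratio:.4f}, "
        f"step cost cached = {generation_pairs(n, True)}, "
        f"uncached = {generation_pairs(n, False)}"
    )
```

The ratio is (N+1)/(2N), so it approaches one half as N grows; for the five-token example it is 0.6, and the saving only reaches about 50% for long sequences.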
Exercises
Exercise 1: Compute “on”'s Output by Hand
Token “on” (position 3) can see “The”, “cat”, “sat”, and itself. Using the causal attention weights from this chapter, compute O_on by taking the weighted sum of the four visible value vectors. Verify that your result matches [0.2350, 0.2350, 0.1425, 0.3875].
Exercise 2: What If “The” Could See “cat”?
In causal attention, “The” gets output [1,0,0,0] (it only sees itself). In standard attention, it gets [0.2254,0.4135,0.2964,0.2964]. Explain conceptually why this difference matters for a language model during training.
Exercise 3: Bidirectional Mask for Position 2
Suppose you wanted token “sat” to use standard (bidirectional) attention while all other tokens remain causal. Modify the mask matrix for this scenario and recompute the weights for row 2. How does it differ from pure causal?
Exercise 4: Implement Prefix-LM Masking
Write a function prefix_lm_mask(N, P) that returns a mask where positions 0,…,P−1 (the prefix) use bidirectional attention and positions P,…,N−1 use causal attention. Apply it to our example with P=2 (so “The” and “cat” are bidirectional, and “sat”, “on”, “mat” are causal).
Exercise 5: KV-Cache Simulation
Simulate autoregressive generation with KV-caching. Starting with only “The,” compute the output for each position as you add tokens one at a time. At each step, append the new K and V vectors to the cache and compute attention for only the new query against all cached keys. Verify that your results match the full causal attention output.
References
Vaswani, A. et al. (2017). “Attention Is All You Need.” Advances in Neural Information Processing Systems, 30. arXiv:1706.03762.
Radford, A., Narasimhan, K., Salimans, T., & Sutskever, I. (2018). “Improving Language Understanding by Generative Pre-Training.” OpenAI Technical Report.
Radford, A. et al. (2019). “Language Models are Unsupervised Multitask Learners.” OpenAI Technical Report (GPT-2).
Brown, T. B. et al. (2020). “Language Models are Few-Shot Learners.” Advances in Neural Information Processing Systems, 33 (GPT-3). arXiv:2005.14165.
Devlin, J., Chang, M., Lee, K., & Toutanova, K. (2018). “BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding.” arXiv:1810.04805.
Dao, T., Fu, D., Ermon, S., Rudra, A., & Ré, C. (2022). “FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness.” Advances in Neural Information Processing Systems, 35. arXiv:2205.14135.
Tay, Y. et al. (2023). “UL2: Unifying Language Learning Paradigms.” ICLR 2023. arXiv:2205.05131.
Das, A. et al. (2024). “A Decoder-only Foundation Model for Time-Series Forecasting.” ICML 2024 (TimesFM). arXiv:2310.10688.