Chapter 3

Causal (Masked) Self-Attention

Radford, A., Narasimhan, K., Salimans, T., & Sutskever, I. (2018). “Improving Language Understanding by Generative Pre-Training.” OpenAI.


Learning Objectives

After completing this chapter, you will be able to:

  1. Explain why standard bidirectional attention cannot be used for autoregressive language generation and how causal masking enforces the autoregressive constraint.
  2. Write the mathematical formula for causal attention and describe the role of every symbol, including the mask function, $-\infty$, and the lower-triangular structure.
  3. Compute causal attention weights and outputs by hand for our shared example “The cat sat on the mat,” and contrast them with the standard (bidirectional) results from Chapter 1.
  4. Implement causal attention from scratch in both NumPy and PyTorch, including the production-ready `F.scaled_dot_product_attention` API with `is_causal=True`.
  5. Describe how causal masking interacts with Flash Attention, KV-cache optimization, and prefix-LM architectures in modern systems like GPT-4, Claude, and Gemini.
Where this appears: Every autoregressive language model in production today — GPT-4, Claude, Gemini, LLaMA, Mistral, Command R — uses causal masking. It is also the basis for code generation (Codex, Copilot), music generation (MusicLM), time series forecasting (TimesFM), and any task where the model must predict the next element in a sequence.

The Real Problem

The Autoregressive Constraint

In Chapter 1, we computed scaled dot-product attention where every token could attend to every other token — including tokens that appear later in the sequence. This bidirectional view is perfect for tasks like sentence classification or named entity recognition, where the model has access to the complete input before making predictions.

But language generation is fundamentally different. When you type a prompt and GPT generates a response, it produces tokens one at a time, left to right. At the moment it is generating the 50th token, the 51st token does not exist yet. The model cannot look into the future because the future has not been written.

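This is the autoregressive constraint: the representation computed at position $i$, which is used to predict the next token, may only depend on positions $0, 1, \ldots, i$. Equivalently, each token is predicted from the tokens before it, so the probability of a sequence (indexing tokens from 1) is factored as: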

$$P(x_1, x_2, \ldots, x_N) = \prod_{i=1}^{N} P(x_i \mid x_1, x_2, \ldots, x_{i-1})$$

Each token's probability depends only on the tokens that came before it. This is the chain rule of probability applied to sequential generation, and it is the mathematical foundation of every autoregressive language model.


The Cheating Problem

During training, the model sees the entire sequence at once (for parallelism). If we used standard bidirectional attention, token ii could attend to token i+1i+1 — the very token it is supposed to predict. This creates a devastating shortcut: the model learns to copy the answer from the future instead of learning the patterns of language.

Consider our example. When predicting what comes after “The cat,” the model should learn that a verb is likely. But if “cat” can attend to “sat” during training, the model simply learns to read off the answer. At inference time, “sat” does not exist yet, and the model collapses — it was trained to cheat, and now the cheat sheet is gone.

Think of it like a student taking an exam. If they can see the answer key during practice tests, they will ace every practice test. But when the real exam arrives without an answer key, they fail completely because they never learned the material — they only learned to copy. Causal masking is the exam proctor that hides the answer key during training, forcing the model to genuinely learn how language works.

The Core Insight: Causal masking does not add new computation. It removes information by blocking future positions. This single constraint is what makes autoregressive language generation possible. Without it, you cannot train a model to generate text.
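The leakage described above can be checked directly. The sketch below is a hypothetical test, not part of the chapter's implementation: it reuses the running example's Q, K, V, changes only the key and value of the last token “mat”, and reports which output rows change. Under bidirectional attention every row changes, because every token can see “mat”; under causal attention only the row for “mat” itself changes.

```python
import numpy as np

# Running example from the chapter ("The cat sat on the mat", d_k = 4).
Q = np.array([[1, 0, 1, 0], [0, 2, 0, 1], [1, 1, 1, 0], [0, 0, 1, 1], [1, 0, 0, 1]], dtype=float)
K = np.array([[0, 1, 0, 1], [1, 0, 1, 0], [1, 1, 0, 0], [0, 0, 1, 1], [1, 0, 0.5, 0.5]])
V = np.array([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1], [0.5, 0.5, 0.5, 0.5]])


def attention(Q, K, V, causal):
    """Scaled dot-product attention; with causal=True, future positions are blocked."""
    N, d_k = Q.shape
    scores = Q @ K.T / np.sqrt(d_k)
    if causal:
        future = np.triu(np.ones((N, N), dtype=bool), k=1)  # True where j > i
        scores = np.where(future, -np.inf, scores)
    scores = scores - scores.max(axis=-1, keepdims=True)    # numerical stability
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ V


# Change only the LAST token ("mat"): give it a different key and value.
K2, V2 = K.copy(), V.copy()
K2[-1] = [3.0, -1.0, 2.0, 0.0]
V2[-1] = [9.0, 9.0, 9.0, 9.0]

for causal in (False, True):
    before = attention(Q, K, V, causal)
    after = attention(Q, K2, V2, causal)
    changed = ~np.all(np.isclose(before, after), axis=1)
    # Bidirectional: every row changes. Causal: only the row for "mat" changes.
    print("causal       " if causal else "bidirectional", changed)
```

If earlier rows changed under the causal setting, information from the future would be leaking into them; this is exactly the shortcut the mask exists to prevent.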

The Story Behind Causal Masking

GPT and the Decoder-Only Model

The idea of masking future tokens in self-attention was introduced in the original Transformer paper (Vaswani et al., 2017), where the decoder used “masked self-attention” to preserve the autoregressive property. But it was Alec Radford's GPT (2018) that demonstrated the full power of this approach as a standalone architecture.

Radford's key insight was radical in its simplicity: you do not need an encoder at all. While BERT (Devlin et al., 2018) achieved remarkable results with bidirectional attention and masked language modeling, GPT showed that a decoder-only model — using nothing but causal self-attention — could learn powerful language representations through the simple objective of next-token prediction.

The original GPT used 12 transformer decoder layers, each with causal self-attention. The training objective was straightforward: given a sequence of tokens $x_1, \ldots, x_N$, maximize the log-likelihood:

$$\mathcal{L} = \sum_{i=1}^{N} \log P(x_i \mid x_1, \ldots, x_{i-1}; \theta)$$

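This was the birth of the paradigm that would scale to GPT-2 (1.5B parameters), GPT-3 (175B), GPT-4, and eventually Claude, Gemini, and LLaMA. The core design carries through the publicly documented members of this family: stacked layers of causal self-attention followed by feedforward networks. The biggest changes were scale, data, and training recipes, along with incremental refinements such as moving layer normalization to the input of each block (GPT-2) and rotary position embeddings (LLaMA).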

The remarkable consequence is that the single most important mechanism in modern AI, the one powering every chatbot, code assistant, and reasoning engine, is not a complex invention. It is standard attention with a single modification: set the upper triangle to $-\infty$ before softmax.
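The next-token objective $\mathcal{L}$ above can be sketched in a few lines of NumPy. This is a minimal sketch that assumes some causal model has already produced a `logits` array whose row $t$ scores the token that follows position $t$; the random logits below are placeholders, not real model output, and the helper names are illustrative. Because row $t$ is scored against token $t+1$, the targets are the input shifted left by one, and the very first token is only scored if a start-of-sequence token is prepended (omitted here).

```python
import numpy as np


def log_softmax(logits):
    """Row-wise log-softmax, stabilized by subtracting each row's maximum."""
    shifted = logits - logits.max(axis=-1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))


def next_token_log_likelihood(logits, tokens):
    """
    Sum of log P(x_i | x_1, ..., x_{i-1}) over one sequence.

    logits: (N, vocab) array; row t scores the token that FOLLOWS position t.
    tokens: (N,) integer token ids.
    """
    log_probs = log_softmax(logits[:-1])   # predictions made at positions 0..N-2
    targets = tokens[1:]                   # the tokens those positions must predict
    return log_probs[np.arange(len(targets)), targets].sum()


rng = np.random.default_rng(0)
vocab, N = 10, 6
tokens = rng.integers(0, vocab, size=N)
logits = rng.normal(size=(N, vocab))       # placeholder for a causal model's output
print(next_token_log_likelihood(logits, tokens))
```

Training maximizes this quantity (in practice, minimizes its negative averaged over tokens). The shift by one only works because causal masking guarantees that row $t$ never saw token $t+1$.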


The Mathematical Definition

Causal attention is identical to the scaled dot-product attention from Chapter 1, with one additional step: a mask is applied to the score matrix before softmax.

$$\text{CausalAttn}(Q, K, V) = \text{softmax}\!\left(\text{mask}\!\left(\frac{QK^\top}{\sqrt{d_k}}\right)\right) V$$

where the mask function is defined as:

$$\text{mask}(S)_{ij} = \begin{cases} S_{ij} & \text{if } j \leq i \quad \text{(past or present)} \\ -\infty & \text{if } j > i \quad \text{(future, blocked)} \end{cases}$$
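The mask function translates directly into NumPy. The sketch below (helper names are illustrative) compares a literal transcription of $\text{mask}(S)$ with the additive form many implementations use: a fixed matrix holding 0 on and below the diagonal and $-\infty$ above it, which is simply added to the scores. For finite scores the two give identical results.

```python
import numpy as np


def mask(S):
    """Literal transcription: mask(S)[i, j] = S[i, j] if j <= i, else -inf."""
    N = S.shape[0]
    i = np.arange(N)[:, None]   # query (row) index, shape (N, 1)
    j = np.arange(N)[None, :]   # key (column) index, shape (1, N)
    return np.where(j <= i, S, -np.inf)


def additive_causal_bias(N):
    """Additive form: 0 on/below the diagonal, -inf strictly above it."""
    return np.triu(np.full((N, N), -np.inf), k=1)


S = np.array([[0.0, 1.0, 0.5], [1.5, 0.0, 1.0], [0.5, 1.0, 1.0]])
print(mask(S))
print(np.array_equal(mask(S), S + additive_causal_bias(3)))
```

The additive form is convenient because the bias depends only on $N$: it can be built once and added to every score matrix of that length.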

Symbol-by-Symbol Breakdown

| Symbol | Shape | Meaning |
|--------|-------|---------|
| $Q$ | $(N, d_k)$ | Query matrix. Each row is what a token is looking for. |
| $K$ | $(N, d_k)$ | Key matrix. Each row is what a token advertises. |
| $V$ | $(N, d_v)$ | Value matrix. Each row is the content to retrieve. |
| $K^\top$ | $(d_k, N)$ | $K$ transposed. Needed for the matrix product. |
| $QK^\top$ | $(N, N)$ | Raw score matrix. Entry $(i, j) = Q_i \cdot K_j$. |
| $\sqrt{d_k}$ | scalar | Scaling factor. Keeps softmax gradients healthy. |
| $\text{mask}(\cdot)$ | $(N, N) \to (N, N)$ | Sets the upper triangle ($j > i$) to $-\infty$. |
| $-\infty$ | scalar | Negative infinity. $\exp(-\infty) = 0$ in softmax. |
| softmax | $(N, N) \to (N, N)$ | Row-wise normalization. Each row sums to 1. |

The Mask Matrix

The causal mask is a binary lower-triangular matrix. Entry $(i, j)$ is 1 if token $i$ is allowed to attend to token $j$, and 0 otherwise. The rule is simple: a token can see itself and everything before it.

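|     | The | cat | sat | on | mat |
|-----|-----|-----|-----|----|-----|
| The | 1 | 0 | 0 | 0 | 0 |
| cat | 1 | 1 | 0 | 0 | 0 |
| sat | 1 | 1 | 1 | 0 | 0 |
| on  | 1 | 1 | 1 | 1 | 0 |
| mat | 1 | 1 | 1 | 1 | 1 |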

This is a lower-triangular matrix: everything on or below the diagonal is 1 (visible), and everything above the diagonal is 0 (blocked). In implementation, we represent the blocked positions with a boolean mask and set those positions to $-\infty$ in the score matrix. The softmax then naturally assigns zero weight to those positions.

Notice that each successive row gains one more visible token. “The” can only see itself. “cat” can see “The” and itself. “mat” (the last token) can see the entire sequence — for the last token, causal attention is identical to standard attention.
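Both views of the mask take one NumPy call each. This short sketch builds the binary “allowed” matrix from the table with `np.tril`, derives the boolean “blocked” matrix that implementations typically use, and lists what each token can see.

```python
import numpy as np

tokens = ["The", "cat", "sat", "on", "mat"]
N = len(tokens)

allowed = np.tril(np.ones((N, N), dtype=int))   # 1 = may attend, 0 = blocked
blocked = allowed == 0                           # boolean view: True where j > i

for tok, row in zip(tokens, allowed):
    visible = [t for t, ok in zip(tokens, row) if ok]
    print(f"{tok:>3} sees {visible}")
```

Each row gains exactly one visible token, and the last row sees the whole sentence.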



Step-by-Step Calculation

We use the same Q, K, V matrices from Chapter 1. The first two steps (dot products and scaling) are identical. Causal masking changes only Step 3 — everything else follows mechanically.

Step 1: Raw Dot Products ($QK^\top$)

Compute the $5 \times 5$ matrix of all pairwise dot products between queries and keys. This is identical to Chapter 1:

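|     | The | cat | sat | on | mat |
|-----|-----|-----|-----|----|-----|
| The | 0.00 | 2.00 | 1.00 | 1.00 | 1.50 |
| cat | 3.00 | 0.00 | 2.00 | 1.00 | 0.50 |
| sat | 1.00 | 2.00 | 2.00 | 1.00 | 1.50 |
| on  | 1.00 | 1.00 | 0.00 | 2.00 | 1.00 |
| mat | 1.00 | 1.00 | 1.00 | 1.00 | 1.50 |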

Step 2: Scaling ($\div \sqrt{d_k} = \div 2.0$)

With $d_k = 4$, divide every entry by $\sqrt{4} = 2.0$:

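|     | The | cat | sat | on | mat |
|-----|-----|-----|-----|----|-----|
| The | 0.000 | 1.000 | 0.500 | 0.500 | 0.750 |
| cat | 1.500 | 0.000 | 1.000 | 0.500 | 0.250 |
| sat | 0.500 | 1.000 | 1.000 | 0.500 | 0.750 |
| on  | 0.500 | 0.500 | 0.000 | 1.000 | 0.500 |
| mat | 0.500 | 0.500 | 0.500 | 0.500 | 0.750 |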
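Steps 1 and 2 can be reproduced from the chapter's Q and K matrices (the same values used in the implementation later in this section). The sketch recomputes both tables so any entry can be checked.

```python
import numpy as np

Q = np.array([[1, 0, 1, 0], [0, 2, 0, 1], [1, 1, 1, 0], [0, 0, 1, 1], [1, 0, 0, 1]], dtype=float)
K = np.array([[0, 1, 0, 1], [1, 0, 1, 0], [1, 1, 0, 0], [0, 0, 1, 1], [1, 0, 0.5, 0.5]])

raw_scores = Q @ K.T                          # Step 1: (5, 4) @ (4, 5) -> (5, 5)
d_k = Q.shape[1]                              # 4
scaled_scores = raw_scores / np.sqrt(d_k)     # Step 2: divide by sqrt(4) = 2.0

print(raw_scores)
print(scaled_scores)
```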

Step 3: Apply Causal Mask

This is the new step. Set every entry where $j > i$ (future positions) to $-\infty$:

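|     | The | cat | sat | on | mat |
|-----|-----|-----|-----|----|-----|
| The | 0.000 | −∞ | −∞ | −∞ | −∞ |
| cat | 1.500 | 0.000 | −∞ | −∞ | −∞ |
| sat | 0.500 | 1.000 | 1.000 | −∞ | −∞ |
| on  | 0.500 | 0.500 | 0.000 | 1.000 | −∞ |
| mat | 0.500 | 0.500 | 0.500 | 0.500 | 0.750 |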

The lower triangle and diagonal retain their original scaled values. The upper triangle is now $-\infty$. When we apply softmax next, $\exp(-\infty) = 0$, so these positions will receive exactly zero attention weight.
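Step 3 in NumPy: a sketch that starts from the Step 2 table, writes $-\infty$ into every future position through a boolean mask, and confirms that $\exp(-\infty)$ evaluates to exactly 0.0 in floating point.

```python
import numpy as np

# Step 2 output, copied from the table above.
scaled_scores = np.array([
    [0.0, 1.0, 0.5, 0.5, 0.75],
    [1.5, 0.0, 1.0, 0.5, 0.25],
    [0.5, 1.0, 1.0, 0.5, 0.75],
    [0.5, 0.5, 0.0, 1.0, 0.5],
    [0.5, 0.5, 0.5, 0.5, 0.75],
])

N = scaled_scores.shape[0]
future = np.triu(np.ones((N, N), dtype=bool), k=1)      # True where j > i
masked_scores = np.where(future, -np.inf, scaled_scores)

print(masked_scores)
print(np.exp(-np.inf))   # exp(-inf) is exactly 0.0 in IEEE-754 arithmetic
```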

Step 4: Softmax Over Visible Tokens Only

Softmax is applied row-by-row. Because the $-\infty$ entries contribute zero to both numerator and denominator, the softmax effectively operates over only the visible tokens in each row:

| Token | Visible entries | Softmax for the first weight | Weights |
|-------|-----------------|------------------------------|---------|
| The | $[0.000]$ | $\frac{e^{0}}{e^{0}} = 1.0$ | [1.0000, 0, 0, 0, 0] |
| cat | $[1.500, 0.000]$ | $\frac{e^{1.5}}{e^{1.5}+e^{0}} = \frac{4.482}{5.482}$ | [0.8176, 0.1824, 0, 0, 0] |
| sat | $[0.500, 1.000, 1.000]$ | $\frac{e^{0.5}}{e^{0.5}+e^{1}+e^{1}} = \frac{1.649}{7.085}$ | [0.2327, 0.3837, 0.3837, 0, 0] |
| on | $[0.500, 0.500, 0.000, 1.000]$ | $\frac{e^{0.5}}{e^{0.5}+e^{0.5}+e^{0}+e^{1}} = \frac{1.649}{7.016}$ | [0.2350, 0.2350, 0.1425, 0.3875, 0] |
| mat | $[0.500, 0.500, 0.500, 0.500, 0.750]$ | Full row (no masking for the last token) | [0.1892, 0.1892, 0.1892, 0.1892, 0.2430] |

Key observations:

  • “The” gets weight 1.0 on itself because it is the only visible token. Its representation is entirely self-referential.
  • “cat” allocates 81.8% of its attention to “The” and only 18.2% to itself. This is because $Q_{\text{cat}} \cdot K_{\text{The}}$ has a high score (1.500 after scaling).
  • “mat” (the last token) sees the entire sequence, so its causal weights are identical to its standard attention weights from Chapter 1.
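The claim that softmax over a masked row equals softmax over only its visible entries can be checked numerically. The sketch below (helper names are illustrative) computes the weights both ways and compares them with the table.

```python
import numpy as np

scaled_scores = np.array([
    [0.0, 1.0, 0.5, 0.5, 0.75],
    [1.5, 0.0, 1.0, 0.5, 0.25],
    [0.5, 1.0, 1.0, 0.5, 0.75],
    [0.5, 0.5, 0.0, 1.0, 0.5],
    [0.5, 0.5, 0.5, 0.5, 0.75],
])
N = len(scaled_scores)
masked_scores = np.where(np.triu(np.ones((N, N), dtype=bool), k=1), -np.inf, scaled_scores)


def softmax(x):
    """Softmax of a 1-D array; -inf entries receive exactly zero weight."""
    e = np.exp(x - x.max())
    return e / e.sum()


# Route 1: softmax over each full masked row.
full_rows = np.array([softmax(row) for row in masked_scores])

# Route 2: softmax over only the visible prefix row[: i + 1], zeros elsewhere.
prefix_rows = np.zeros((N, N))
for i, row in enumerate(masked_scores):
    prefix_rows[i, : i + 1] = softmax(row[: i + 1])

print(np.round(full_rows, 4))
print(np.allclose(full_rows, prefix_rows))
```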

Step 5: Weighted Sum of Values

Each output row is a weighted sum of the visible value vectors:

$$O_{\text{cat}} = 0.8176 \cdot V_{\text{The}} + 0.1824 \cdot V_{\text{cat}}$$

$$= 0.8176 \cdot [1, 0, 0, 0] + 0.1824 \cdot [0, 1, 0, 0] = [0.8176, 0.1824, 0.0, 0.0]$$

Notice that dimensions 2 and 3 of “cat”'s output are exactly zero. In standard attention (Chapter 1), they were 0.3018 and 0.2058 respectively. Those values came through the value vectors of later tokens (“sat” contributes to dimension 2, “on” to dimension 3, and “mat” to both), but causal masking blocks “cat” from seeing any of them.
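The contrast for “cat” can be reproduced in a few lines. This sketch reuses the Step 2 scores for “cat” and the running example's V, and recomputes the Chapter 1 (bidirectional) output alongside the causal one.

```python
import numpy as np

# Value matrix from the running example (one row per token).
V = np.array([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1], [0.5, 0.5, 0.5, 0.5]])
scaled_cat = np.array([1.5, 0.0, 1.0, 0.5, 0.25])  # row "cat" of the Step 2 table


def softmax(x):
    e = np.exp(x - x.max())
    return e / e.sum()


w_causal = np.zeros(5)
w_causal[:2] = softmax(scaled_cat[:2])   # causal: only "The" and "cat" are visible
w_std = softmax(scaled_cat)              # standard (Chapter 1): all five are visible

print("causal:  ", np.round(w_causal @ V, 4))
print("standard:", np.round(w_std @ V, 4))
```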


Worked Example: What “sat” Sees

Let us trace through the full calculation for “sat” (position 2), which can see “The” (position 0), “cat” (position 1), and itself (position 2).

$$Q_{\text{sat}} = [1.0, 1.0, 1.0, 0.0]$$

| Pair | Dot product | Scaled | Status |
|------|-------------|--------|--------|
| $Q_{\text{sat}} \cdot K_{\text{The}}$ | $1{\times}0 + 1{\times}1 + 1{\times}0 + 0{\times}1 = 1$ | 0.500 | VISIBLE |
| $Q_{\text{sat}} \cdot K_{\text{cat}}$ | $1{\times}1 + 1{\times}0 + 1{\times}1 + 0{\times}0 = 2$ | 1.000 | VISIBLE |
| $Q_{\text{sat}} \cdot K_{\text{sat}}$ | $1{\times}1 + 1{\times}1 + 1{\times}0 + 0{\times}0 = 2$ | 1.000 | VISIBLE |
| $Q_{\text{sat}} \cdot K_{\text{on}}$ | $1{\times}0 + 1{\times}0 + 1{\times}1 + 0{\times}1 = 1$ | $-\infty$ | BLOCKED |
| $Q_{\text{sat}} \cdot K_{\text{mat}}$ | $1{\times}1 + 1{\times}0 + 1{\times}0.5 + 0{\times}0.5 = 1.5$ | $-\infty$ | BLOCKED |

Softmax over the three visible entries $[0.500, 1.000, 1.000]$:

$$\exp(0.5) = 1.6487, \quad \exp(1.0) = 2.7183, \quad \exp(1.0) = 2.7183$$

$$\text{Sum} = 1.6487 + 2.7183 + 2.7183 = 7.0853$$

$$A_{\text{sat}} = \left[\frac{1.6487}{7.0853}, \frac{2.7183}{7.0853}, \frac{2.7183}{7.0853}, 0, 0\right] = [0.2327, 0.3837, 0.3837, 0, 0]$$

Output:

$$O_{\text{sat}} = 0.2327 \cdot [1,0,0,0] + 0.3837 \cdot [0,1,0,0] + 0.3837 \cdot [0,0,1,0] = [0.2327, 0.3837, 0.3837, 0.0]$$

Compare to standard attention (Chapter 1): $O_{\text{sat}}^{\text{std}} = [0.2495, 0.3481, 0.3481, 0.2495]$. Dimension 3 is 0.2495 in standard attention (contributed by “on” and “mat” through their value vectors) but exactly 0.0 in causal attention because those future tokens are blocked.
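The worked example for “sat” condenses to a short script: compute one scaled score per key, mark positions after index 2 as blocked, and normalize over the visible scores only. Dropping the mask from the same scores gives the standard (bidirectional) output for comparison.

```python
import numpy as np

tokens = ["The", "cat", "sat", "on", "mat"]
K = np.array([[0, 1, 0, 1], [1, 0, 1, 0], [1, 1, 0, 0], [0, 0, 1, 1], [1, 0, 0.5, 0.5]])
V = np.array([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1], [0.5, 0.5, 0.5, 0.5]])
q_sat = np.array([1.0, 1.0, 1.0, 0.0])
i = tokens.index("sat")                     # position 2

scaled = K @ q_sat / np.sqrt(len(q_sat))    # one scaled score per key
visible = np.arange(len(tokens)) <= i       # positions 0..i are visible

weights = np.zeros(len(tokens))
e = np.exp(scaled[visible] - scaled[visible].max())
weights[visible] = e / e.sum()              # softmax over visible keys only

for tok, s, ok, w in zip(tokens, scaled, visible, weights):
    print(f"{tok:>3}  scaled={s:.3f}  {'VISIBLE' if ok else 'BLOCKED'}  weight={w:.4f}")
print("O_sat =", np.round(weights @ V, 4))
```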

Full Attention Weights and Output

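Causal Attention Weight Matrix ($5 \times 5$)

|     | The | cat | sat | on | mat |
|-----|-----|-----|-----|----|-----|
| The | 1.0000 | 0.0000 | 0.0000 | 0.0000 | 0.0000 |
| cat | 0.8176 | 0.1824 | 0.0000 | 0.0000 | 0.0000 |
| sat | 0.2327 | 0.3837 | 0.3837 | 0.0000 | 0.0000 |
| on  | 0.2350 | 0.2350 | 0.1425 | 0.3875 | 0.0000 |
| mat | 0.1892 | 0.1892 | 0.1892 | 0.1892 | 0.2430 |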

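Causal Output Matrix ($5 \times 4$)

|     | d0 | d1 | d2 | d3 |
|-----|----|----|----|----|
| The | 1.0000 | 0.0000 | 0.0000 | 0.0000 |
| cat | 0.8176 | 0.1824 | 0.0000 | 0.0000 |
| sat | 0.2327 | 0.3837 | 0.3837 | 0.0000 |
| on  | 0.2350 | 0.2350 | 0.1425 | 0.3875 |
| mat | 0.3108 | 0.3108 | 0.3108 | 0.3108 |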


Applications Across Domains

Language Generation

Every modern large language model (GPT-4, Claude, Gemini, LLaMA, Mistral, Qwen) uses causal self-attention as its core mechanism. Between the token embeddings and the output layer, the forward pass of these models is a stack of blocks, each pairing a causal attention layer with a feedforward network. When Claude generates a response to your prompt, each new token is produced by attending causally to all previous tokens.

Code Generation

Code is a sequence, and code generation is autoregressive. Tools like GitHub Copilot, Cursor, and Claude's code generation use causal attention to predict the next token of code given the preceding context. The causal structure fits the sequential dependencies common in code: a variable is usually declared before it is used, a function is often defined before it is called, and an import typically appears before the imported module is referenced.

Time Series Forecasting

In time series modeling, the future is literally unknown at prediction time. Models like TimesFM (Google, 2024) and Lag-Llama apply causal attention to temporal sequences, ensuring that predictions at time $t$ depend only on observations at times $1, 2, \ldots, t-1$. The causal mask prevents the temporal leakage that would make forecasting trivially (and falsely) accurate during training.

Music and Audio Generation

MusicLM (Google, 2023), Jukebox (OpenAI, 2020), and AudioLM use causal attention to generate audio tokens sequentially. Music has a natural temporal structure: a note at beat 16 should be influenced by the melody at beats 1–15, but not by what comes at beat 17. Causal masking enforces this temporal causality, producing musically coherent output.


Connection to Modern Systems

Flash Attention with Causal Masks

Flash Attention (Dao et al., 2022) computes attention using IO-aware tiling to minimize GPU memory transfers. Its kernels provide a dedicated causal mode (for example, PyTorch's `F.scaled_dot_product_attention` accepts `is_causal=True` and can dispatch to a Flash Attention kernel) that is faster than applying a general mask, because it skips every tile lying entirely in the upper triangle instead of computing it and masking it to zero. Only tiles that straddle the diagonal still need element-wise masking. For long sequences this skips roughly half of the score computation, which makes causal attention cheaper than standard attention in practice.
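To make tile skipping concrete, here is a NumPy sketch of the idea. It is not Flash Attention: it models neither the GPU memory hierarchy nor kernel fusion, and `tiled_causal_attention` is a hypothetical helper. It walks over (query tile, key tile) pairs with an online (running) softmax, stops as soon as a key tile lies entirely in the future of the current query tile, and applies the element-wise mask only where a tile straddles the diagonal. With $T$ tiles per side it computes $T(T+1)/2$ of the $T^2$ tiles, a fraction that approaches one half as $T$ grows.

```python
import numpy as np


def dense_causal_attention(Q, K, V):
    """Reference: full (N, N) score matrix with the causal mask applied."""
    N, d_k = Q.shape
    scores = Q @ K.T / np.sqrt(d_k)
    scores = np.where(np.tril(np.ones((N, N), dtype=bool)), scores, -np.inf)
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    return (weights / weights.sum(axis=-1, keepdims=True)) @ V


def tiled_causal_attention(Q, K, V, block=2):
    """Hypothetical tile-skipping sketch. Returns (output, number_of_tiles_computed)."""
    N, d_k = Q.shape
    out = np.zeros((N, V.shape[1]))
    tiles = 0
    for qs in range(0, N, block):
        qe = min(qs + block, N)
        m = np.full(qe - qs, -np.inf)          # running max of each query row
        l = np.zeros(qe - qs)                  # running softmax denominator
        acc = np.zeros((qe - qs, V.shape[1]))  # running weighted sum of values
        for ks in range(0, N, block):
            if ks >= qe:   # every key in this tile is in every query's future: skip the rest
                break
            ke = min(ks + block, N)
            tiles += 1
            s = Q[qs:qe] @ K[ks:ke].T / np.sqrt(d_k)
            rows = np.arange(qs, qe)[:, None]
            cols = np.arange(ks, ke)[None, :]
            s = np.where(cols <= rows, s, -np.inf)   # only bites on diagonal tiles
            m_new = np.maximum(m, s.max(axis=1))
            p = np.exp(s - m_new[:, None])
            rescale = np.exp(m - m_new)               # bring earlier sums to the new max
            l = l * rescale + p.sum(axis=1)
            acc = acc * rescale[:, None] + p @ V[ks:ke]
            m = m_new
        out[qs:qe] = acc / l[:, None]
    return out, tiles


rng = np.random.default_rng(0)
Q, K, V = (rng.normal(size=(8, 4)) for _ in range(3))
out, tiles = tiled_causal_attention(Q, K, V, block=2)
print(np.allclose(out, dense_causal_attention(Q, K, V)), tiles, "of", 4 * 4, "tiles computed")
```

The first key tile always contains position 0, which every query can see, so each running maximum becomes finite before any fully masked row appears; that is what keeps the online softmax free of NaNs here.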

KV-Cache and Causal Structure

The causal mask creates a critical property: the attention computation for position $i$ depends only on keys and values at positions $0, \ldots, i$. This means that once we have computed K and V for a token, they never change: future tokens do not affect past computations.

This enables the KV-cache optimization: during autoregressive generation, we store the K and V tensors for all previously generated tokens. When generating token $i+1$, we only need to compute $Q_{i+1}$, $K_{i+1}$, and $V_{i+1}$ for the new token and append $K_{i+1}$ and $V_{i+1}$ to the cache. The attention for the new token is computed against all cached keys and values. Without causal masking, the KV-cache would be invalid because earlier tokens' representations could change when new tokens are added.
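The KV-cache argument can be checked with a single attention layer. In the sketch below (one head, random stand-in embeddings and projection weights, and a hypothetical `KVCache` helper), each decoding step projects only the newest token, appends its key and value to the cache, and attends over the cache; stacking the per-step outputs reproduces full-sequence causal attention. In a multi-layer model the same reasoning applies layer by layer, since each layer's inputs are themselves produced causally.

```python
import numpy as np


def causal_attention(Q, K, V):
    """Full-sequence causal attention (the computation used during training)."""
    N, d_k = Q.shape
    scores = Q @ K.T / np.sqrt(d_k)
    scores = np.where(np.tril(np.ones((N, N), dtype=bool)), scores, -np.inf)
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    return (weights / weights.sum(axis=-1, keepdims=True)) @ V


class KVCache:
    """Hypothetical minimal cache for one layer and one head."""

    def __init__(self):
        self.keys = []
        self.values = []


def decode_step(x_new, W_q, W_k, W_v, cache):
    """Project only the newest token, cache its K and V, attend over the cache."""
    q = x_new @ W_q
    cache.keys.append(x_new @ W_k)        # earlier keys/values are reused as-is
    cache.values.append(x_new @ W_v)
    K = np.stack(cache.keys)              # (t + 1, d_k)
    V = np.stack(cache.values)            # (t + 1, d_v)
    scores = K @ q / np.sqrt(q.shape[0])
    w = np.exp(scores - scores.max())
    return (w / w.sum()) @ V              # no mask needed: the cache holds only past + present


rng = np.random.default_rng(0)
N, d_model, d_k = 6, 8, 4
X = rng.normal(size=(N, d_model))                            # stand-in token embeddings
W_q, W_k, W_v = (rng.normal(size=(d_model, d_k)) for _ in range(3))

cache = KVCache()
incremental = np.stack([decode_step(X[t], W_q, W_k, W_v, cache) for t in range(N)])
full = causal_attention(X @ W_q, X @ W_k, X @ W_v)
print(np.allclose(incremental, full))
```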

Prefix-LM: A Hybrid Approach

Some models (e.g., T5, UL2, PaLM 2 in certain modes) use a prefix-LM strategy: tokens in the prompt (prefix) attend to each other bidirectionally, while the generated continuation uses causal attention. Every position can see the entire prefix, but no position, prefix included, can see a continuation token that lies in its future. This gives the model full context understanding of the input while maintaining the autoregressive constraint for generation. Mathematically, the mask becomes:

$$M_{ij} = \begin{cases} 1 & \text{if } j < P \text{ or } j \leq i \\ 0 & \text{otherwise} \end{cases}$$

where $P$ is the length of the prefix. This is a strictly more general mask than pure causal: it reduces to standard attention when $P = N$ and to pure causal attention when $P = 0$.
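A prefix-LM mask is a one-line change to the causal mask. This sketch (the function name is illustrative) follows the convention that every position may attend to the whole prefix ($j < P$) and otherwise only to positions $j \le i$, so no position, prefix included, sees a continuation token in its future.

```python
import numpy as np


def prefix_lm_mask(N, P):
    """1 where position i may attend to position j: j < P (prefix) or j <= i (causal)."""
    i = np.arange(N)[:, None]
    j = np.arange(N)[None, :]
    return ((j < P) | (j <= i)).astype(int)


print(prefix_lm_mask(5, P=2))
```

Setting `P=0` recovers the lower-triangular causal mask, and `P=N` gives the all-ones bidirectional mask.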


Complexity Analysis

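| Metric | Standard Attention | Causal Attention |
|--------|--------------------|------------------|
| Time complexity | $O(N^2 d)$ | $O(N^2 d)$ (same) |
| Memory complexity (materialized score matrix) | $O(N^2)$ | $O(N^2)$ (same) |
| Actual FLOPs | 100% | ~50% (fully masked tiles skipped, e.g., by Flash Attention's causal mode) |
| KV-cache compatible | No | Yes |
| Extra parameters | 0 | 0 (mask is computed, not learned) |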

The asymptotic complexity is unchanged because the lower triangle still holds $N(N+1)/2$ scores, which is $O(N^2)$; masking only halves the constant factor. In practice, Flash Attention's causal mode can approach a 2x speedup over its non-causal mode on long sequences by never computing tiles that lie entirely above the diagonal. The mask itself requires zero extra parameters: it is purely a function of the sequence positions.


Python Implementation

The complete NumPy implementation below includes the `CausalAttention` class, with mask construction, masking, softmax, and the forward pass separated into individual methods, plus an `explain()` method that traces the computation for any query token. The annotations that follow walk through the key lines and show the values flowing through memory at each point.

Causal Attention: NumPy Implementation (causal_attention.py)
`import numpy as np`

NumPy provides vectorized matrix operations. Q @ K.T runs as optimized C code, not Python loops.

`import math`

Standard library math module. We use math.sqrt() for the scaling factor.

`class CausalAttention:`

A self-contained class implementing causal (masked) self-attention. This is the attention mechanism used in GPT, Claude, Gemini, and all autoregressive language models. The only difference from standard attention is the mask applied before softmax.

`def __init__(self, d_k: int):`

Constructor takes d_k (dimension of Q/K vectors). Pre-computes the scaling factor sqrt(d_k) once.

EXECUTION STATE
⬇ input: d_k = 4 (dimension of each query/key vector)
⬆ sets = self.d_k = 4, self.scale = sqrt(4) = 2.0
`self.d_k = d_k`

Store d_k for later reference.

`self.scale = math.sqrt(d_k)`

Pre-compute sqrt(d_k) = sqrt(4) = 2.0. Dividing scores by this prevents softmax saturation.

EXECUTION STATE
self.scale = 2.0
`def _softmax(self, x: np.ndarray) -> np.ndarray:`

Numerically stable softmax that correctly handles -inf values (from the causal mask). Operates row-by-row: each row sums to 1.0.

EXECUTION STATE
⬇ input: x (5×5) =
       The     cat     sat      on     mat
The  0.000    -inf    -inf    -inf    -inf
cat  1.500   0.000    -inf    -inf    -inf
sat  0.500   1.000   1.000    -inf    -inf
on   0.500   0.500   0.000   1.000    -inf
mat  0.500   0.500   0.500   0.500   0.750
⬆ returns = np.ndarray (5,5) — attention weights, each row sums to 1.0
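`x_safe = np.where(np.isfinite(x), x, -1e9)`

Replace -inf with a very large negative number (-1e9). np.isfinite() returns False for -inf, +inf, and NaN. This is a safety net for rows in which every entry is masked: there, np.max() would return -inf and the subtraction on the next line would produce NaN. With a causal mask the diagonal is always visible, so every row has at least one finite score.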

EXECUTION STATE
np.isfinite(x) = True for real numbers, False for -inf entries
-1e9 = -1000000000.0 — large enough that exp(-1e9) ≈ 0
`x_shifted = x_safe - np.max(x_safe, axis=-1, keepdims=True)`

Subtract the row-wise maximum from each element. This is the numerical stability trick: exp(x - max) avoids overflow while producing identical softmax results.

EXECUTION STATE
axis=-1 = reduce along the last axis, i.e., across the columns of each row. For a (5,5) matrix, this finds the max of each row independently.
keepdims=True = keep the reduced axis as size-1 dimension. Returns shape (5,1) not (5,), so broadcasting works: x(5×5) - max(5×1).
row maxima = [0.000, 1.500, 1.000, 1.000, 0.750]
Row 0 (The): x_shifted[0] = [0.000, -1e9, -1e9, -1e9, -1e9]
Row 1 (cat): x_shifted[1] = [0.000, -1.500, -1e9, -1e9, -1e9]
Row 2 (sat): x_shifted[2] = [-0.500, 0.000, 0.000, -1e9, -1e9]
Row 3 (on): x_shifted[3] = [-0.500, -0.500, -1.000, 0.000, -1e9]
Row 4 (mat): x_shifted[4] = [-0.250, -0.250, -0.250, -0.250, 0.000]
`exp_x = np.exp(x_shifted)`

Element-wise exponentiation. Entries from masked positions have x_shifted ≈ -1e9, so exp(-1e9) ≈ 0. The max entry in each row has x_shifted = 0, so exp(0) = 1.

EXECUTION STATE
Row 0 (The): exp_x[0] = [1.0000, ≈0, ≈0, ≈0, ≈0]
Row 1 (cat): exp_x[1] = [1.0000, 0.2231, ≈0, ≈0, ≈0]
Row 2 (sat): exp_x[2] = [0.6065, 1.0000, 1.0000, ≈0, ≈0]
Row 3 (on): exp_x[3] = [0.6065, 0.6065, 0.3679, 1.0000, ≈0]
Row 4 (mat): exp_x[4] = [0.7788, 0.7788, 0.7788, 0.7788, 1.0000]
`return exp_x / np.sum(exp_x, axis=-1, keepdims=True)`

Divide each element by its row sum. This produces a probability distribution for each row. Blocked positions get weight ≈ 0.

EXECUTION STATE
row sums = [1.0000, 1.2231, 2.6065, 2.5809, 4.1152]
⬆ return: weights (5×5) =
       The     cat     sat      on     mat
The  1.0000  0.0000  0.0000  0.0000  0.0000
cat  0.8176  0.1824  0.0000  0.0000  0.0000
sat  0.2327  0.3837  0.3837  0.0000  0.0000
on   0.2350  0.2350  0.1425  0.3875  0.0000
mat  0.1892  0.1892  0.1892  0.1892  0.2430
`def build_causal_mask(self, N: int) -> np.ndarray:`

Builds the upper-triangular boolean mask. True means the position is blocked (future token). This is the ONLY difference from standard attention.

EXECUTION STATE
⬇ input: N = 5 (number of tokens)
⬆ returns = np.ndarray (5,5) of bool — True where j > i
`return np.triu(np.ones((N, N), dtype=bool), k=1)`

np.triu extracts the upper triangle. k=1 means start 1 diagonal above the main diagonal (exclude the diagonal itself, since a token should see itself).

EXECUTION STATE
np.triu(..., k=1) = k=0 would include diagonal. k=1 starts ABOVE diagonal, so diagonal (self-attention) is NOT masked.
⬆ return: mask (5×5) =
       The    cat    sat     on    mat
The  False   True   True   True   True
cat  False  False   True   True   True
sat  False  False  False   True   True
on   False  False  False  False   True
mat  False  False  False  False  False
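Why `k=1` and not `k=0`? With `k=0` the diagonal is blocked too, so the first token (“The”) has no visible position at all. The sketch below (zero scores stand in for real ones; `masked_softmax` is an illustrative helper) shows the consequence: a plain $-\infty$ softmax turns that row into NaN. In `_softmax` above, the `-1e9` replacement would instead spread that row uniformly (0.2 each) over tokens it should not see, which is just as wrong.

```python
import numpy as np


def masked_softmax(scores, k):
    """Block np.triu(..., k=k) positions with -inf, then apply a row-wise softmax."""
    N = scores.shape[0]
    blocked = np.triu(np.ones((N, N), dtype=bool), k=k)
    masked = np.where(blocked, -np.inf, scores)
    with np.errstate(invalid="ignore"):  # a fully masked row computes -inf - (-inf)
        shifted = masked - masked.max(axis=-1, keepdims=True)
        weights = np.exp(shifted)
        return weights / weights.sum(axis=-1, keepdims=True)


scores = np.zeros((5, 5))                 # stand-in scores for the demonstration
print(masked_softmax(scores, k=1)[0])     # "The" attends only to itself
print(masked_softmax(scores, k=0)[0])     # every entry is NaN: nothing is visible
```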
`def apply_mask(self, scores: np.ndarray, mask: np.ndarray) -> np.ndarray:`

Sets all True (blocked) positions to -inf. After softmax, exp(-inf) = 0, giving these positions exactly zero attention weight.

EXECUTION STATE
⬇ input: scores (5×5) =
       The     cat     sat      on     mat
The  0.000   1.000   0.500   0.500   0.750
cat  1.500   0.000   1.000   0.500   0.250
sat  0.500   1.000   1.000   0.500   0.750
on   0.500   0.500   0.000   1.000   0.500
mat  0.500   0.500   0.500   0.500   0.750
⬇ input: mask (5×5) = Upper-triangular boolean (True = blocked)
⬆ returns = np.ndarray (5,5) with -inf in upper triangle
`masked_scores = scores.copy()`

Make a copy to avoid modifying the input array in-place. The original scores are preserved.

`masked_scores[mask] = -np.inf`

Boolean indexing: everywhere mask is True (j > i), set the score to -inf. This is the core operation of causal masking.

EXECUTION STATE
masked_scores (5×5) =
       The     cat     sat      on     mat
The  0.000    -inf    -inf    -inf    -inf
cat  1.500   0.000    -inf    -inf    -inf
sat  0.500   1.000   1.000    -inf    -inf
on   0.500   0.500   0.000   1.000    -inf
mat  0.500   0.500   0.500   0.500   0.750
`return masked_scores`

Return the masked score matrix. The lower triangle and diagonal retain their original values.

`def forward(self, Q: np.ndarray, K: np.ndarray, V: np.ndarray):`

Full forward pass: computes raw scores, scales, masks, applies softmax, and computes the weighted sum.

EXECUTION STATE
⬇ input: Q (5×4) =
       d0   d1   d2   d3
The  1.0  0.0  1.0  0.0
cat  0.0  2.0  0.0  1.0
sat  1.0  1.0  1.0  0.0
on   0.0  0.0  1.0  1.0
mat  1.0  0.0  0.0  1.0
⬇ input: K (5×4) =
       d0   d1   d2   d3
The  0.0  1.0  0.0  1.0
cat  1.0  0.0  1.0  0.0
sat  1.0  1.0  0.0  0.0
on   0.0  0.0  1.0  1.0
mat  1.0  0.0  0.5  0.5
⬇ input: V (5×4) =
       d0   d1   d2   d3
The  1.0  0.0  0.0  0.0
cat  0.0  1.0  0.0  0.0
sat  0.0  0.0  1.0  0.0
on   0.0  0.0  0.0  1.0
mat  0.5  0.5  0.5  0.5
⬆ returns = (weights (5×5), output (5×4))
`N = Q.shape[0]`

Number of tokens in the sequence.

EXECUTION STATE
N = 5
`raw_scores = Q @ K.T`

Matrix multiply Q (5×4) with K transposed (4×5). Each entry [i,j] is the dot product of query i with key j — how well token i's question matches token j's advertisement.

EXECUTION STATE
K.T = K transposed from (5×4) to (4×5)
raw_scores (5×5) =
      The   cat   sat    on   mat
The  0.00  2.00  1.00  1.00  1.50
cat  3.00  0.00  2.00  1.00  0.50
sat  1.00  2.00  2.00  1.00  1.50
on   1.00  1.00  0.00  2.00  1.00
mat  1.00  1.00  1.00  1.00  1.50
`scaled_scores = raw_scores / self.scale`

Divide every score by sqrt(d_k) = 2.0. If the query and key components have roughly unit variance, this keeps the variance of the scores near 1, preventing softmax from saturating.

EXECUTION STATE
self.scale = 2.0
scaled_scores (5×5) =
       The     cat     sat      on     mat
The  0.000   1.000   0.500   0.500   0.750
cat  1.500   0.000   1.000   0.500   0.250
sat  0.500   1.000   1.000   0.500   0.750
on   0.500   0.500   0.000   1.000   0.500
mat  0.500   0.500   0.500   0.500   0.750
`mask = self.build_causal_mask(N)`

Build the 5×5 boolean mask. True in upper triangle (j > i), False elsewhere.

EXECUTION STATE
mask (5×5) =
       The    cat    sat     on    mat
The  False   True   True   True   True
cat  False  False   True   True   True
sat  False  False  False   True   True
on   False  False  False  False   True
mat  False  False  False  False  False
`masked_scores = self.apply_mask(scaled_scores, mask)`

Set all upper-triangle entries to -inf. This is the ONLY step that differs from Chapter 1 (standard attention).

EXECUTION STATE
masked_scores (5×5) =
       The     cat     sat      on     mat
The  0.000    -inf    -inf    -inf    -inf
cat  1.500   0.000    -inf    -inf    -inf
sat  0.500   1.000   1.000    -inf    -inf
on   0.500   0.500   0.000   1.000    -inf
mat  0.500   0.500   0.500   0.500   0.750
`weights = self._softmax(masked_scores)`

Softmax over the masked scores. Blocked positions get weight 0.0000. Each row sums to 1.0.

EXECUTION STATE
weights (5×5) =
       The     cat     sat      on     mat
The  1.0000  0.0000  0.0000  0.0000  0.0000
cat  0.8176  0.1824  0.0000  0.0000  0.0000
sat  0.2327  0.3837  0.3837  0.0000  0.0000
on   0.2350  0.2350  0.1425  0.3875  0.0000
mat  0.1892  0.1892  0.1892  0.1892  0.2430
`output = weights @ V`

Multiply weights (5×5) by V (5×4). Each output row is a weighted sum of value vectors, using ONLY values from visible (past + present) tokens.

EXECUTION STATE
output (5×4) =
       d0      d1      d2      d3
The  1.0000  0.0000  0.0000  0.0000
cat  0.8176  0.1824  0.0000  0.0000
sat  0.2327  0.3837  0.3837  0.0000
on   0.2350  0.2350  0.1425  0.3875
mat  0.3108  0.3108  0.3108  0.3108
`return weights, output`

Return both the attention weights (for visualization/debugging) and the output (fed to the next layer).

`def explain(self, Q, K, V, tokens, query_idx=0):`

Utility method that prints a human-readable trace of the causal attention computation for one specific query token. Used for debugging and education.

`tokens = ["The", "cat", "sat", "on", "mat"]`

The shared example sentence used throughout the book. 5 tokens with d_k = 4.

EXECUTION STATE
tokens = ["The", "cat", "sat", "on", "mat"]
N = 5 tokens
`Q = np.array([...])`

Query matrix (5×4). Each row is what that token is 'asking for' — the features it wants to find in the sequence.

EXECUTION STATE
Q (5×4) =
       d0   d1   d2   d3
The  1.0  0.0  1.0  0.0
cat  0.0  2.0  0.0  1.0
sat  1.0  1.0  1.0  0.0
on   0.0  0.0  1.0  1.0
mat  1.0  0.0  0.0  1.0
`K = np.array([...])`

Key matrix (5×4). Each row is what that token 'advertises' about itself — the features it exposes for matching.

EXECUTION STATE
K (5×4) =
       d0   d1   d2   d3
The  0.0  1.0  0.0  1.0
cat  1.0  0.0  1.0  0.0
sat  1.0  1.0  0.0  0.0
on   0.0  0.0  1.0  1.0
mat  1.0  0.0  0.5  0.5
`V = np.array([...])`

Value matrix (5×4). Each row is the content that token contributes when selected by attention. Q and K decide WHO to attend to; V decides WHAT to retrieve.

EXECUTION STATE
V (5×4) =
       d0   d1   d2   d3
The  1.0  0.0  0.0  0.0
cat  0.0  1.0  0.0  0.0
sat  0.0  0.0  1.0  0.0
on   0.0  0.0  0.0  1.0
mat  0.5  0.5  0.5  0.5
`attn = CausalAttention(d_k=4)`

Instantiate the causal attention module with d_k=4. This sets self.scale = sqrt(4) = 2.0.

EXECUTION STATE
attn.d_k = 4
attn.scale = 2.0
`weights, output = attn.forward(Q, K, V)`

Run the full causal attention forward pass. Returns the 5×5 weight matrix and 5×4 output matrix.

EXECUTION STATE
weights (5×5) =
       The     cat     sat      on     mat
The  1.0000  0.0000  0.0000  0.0000  0.0000
cat  0.8176  0.1824  0.0000  0.0000  0.0000
sat  0.2327  0.3837  0.3837  0.0000  0.0000
on   0.2350  0.2350  0.1425  0.3875  0.0000
mat  0.1892  0.1892  0.1892  0.1892  0.2430
output (5×4) =
       d0      d1      d2      d3
The  1.0000  0.0000  0.0000  0.0000
cat  0.8176  0.1824  0.0000  0.0000
sat  0.2327  0.3837  0.3837  0.0000
on   0.2350  0.2350  0.1425  0.3875
mat  0.3108  0.3108  0.3108  0.3108
`attn.explain(Q, K, V, tokens, query_idx=2)`

Print a detailed trace for 'sat' (row 2). This shows all 5 steps of the causal attention computation for this specific token.

EXECUTION STATE
query_idx = 2 → 'sat'
visible tokens = ['The', 'cat', 'sat']
blocked tokens = ['on', 'mat']
import numpy as np
import math

class CausalAttention:
    """
    Causal (Masked) Self-Attention (Radford et al., 2018)

    Attention(Q, K, V) = softmax(mask(Q @ K^T / sqrt(d_k))) @ V

    The upper triangle of the score matrix is set to -inf
    before softmax, preventing tokens from attending to
    future positions in the sequence.
    """

    def __init__(self, d_k: int):
        """
        Args:
            d_k: Dimension of query/key vectors (used for scaling)
        """
        self.d_k = d_k
        self.scale = math.sqrt(d_k)

    def _softmax(self, x: np.ndarray) -> np.ndarray:
        """Numerically stable softmax that handles -inf values."""
        x_safe = np.where(np.isfinite(x), x, -1e9)
        x_shifted = x_safe - np.max(x_safe, axis=-1, keepdims=True)
        exp_x = np.exp(x_shifted)
        return exp_x / np.sum(exp_x, axis=-1, keepdims=True)

    def build_causal_mask(self, N: int) -> np.ndarray:
        """Build upper-triangular boolean mask (True = blocked)."""
        return np.triu(np.ones((N, N), dtype=bool), k=1)

    def apply_mask(self, scores: np.ndarray, mask: np.ndarray) -> np.ndarray:
        """Set masked (future) positions to -inf."""
        masked_scores = scores.copy()
        masked_scores[mask] = -np.inf
        return masked_scores

    def forward(self, Q: np.ndarray, K: np.ndarray, V: np.ndarray):
        """
        Full forward pass of causal attention.

        Args:
            Q: Query matrix  (N, d_k)
            K: Key matrix    (N, d_k)
            V: Value matrix  (N, d_v)

        Returns:
            weights: Causal attention weight matrix (N, N)
            output:  Context-enriched output         (N, d_v)
        """
        N = Q.shape[0]
        raw_scores = Q @ K.T
        scaled_scores = raw_scores / self.scale
        mask = self.build_causal_mask(N)
        masked_scores = self.apply_mask(scaled_scores, mask)
        weights = self._softmax(masked_scores)
        output = weights @ V
        return weights, output

    def explain(self, Q, K, V, tokens, query_idx=0):
        """Print a detailed trace for a specific query token."""
        N = Q.shape[0]
        raw = Q @ K.T
        scaled = raw / self.scale
        mask = self.build_causal_mask(N)
        masked = self.apply_mask(scaled, mask)
        weights = self._softmax(masked)
        output = weights @ V

        t = tokens[query_idx]
        print(f"\n=== Causal attention trace for '{t}' (row {query_idx}) ===")
        print(f"Q[{query_idx}] = {Q[query_idx]}")
        print(f"\nVisible tokens: {tokens[:query_idx+1]}")
        print(f"Blocked tokens: {tokens[query_idx+1:]}")

        print(f"\nStep 1-2: Raw scores -> scaled (/ {self.scale:.1f}):")
        for j, tk in enumerate(tokens):
            s = "BLOCKED" if j > query_idx else f"{scaled[query_idx, j]:.4f}"
            print(f"  S[{t},{tk}] = {s}")

        print(f"\nStep 3-4: Softmax weights:")
        for j, tk in enumerate(tokens):
            w = weights[query_idx, j]
            bar = '#' * int(w * 40)
            print(f"  A[{t},{tk}] = {w:.4f} |{bar}|")

        print(f"\nStep 5: Output = weighted sum of V:")
        print(f"  O[{t}] = {output[query_idx]}")


# ── Shared Example (same Q, K, V as every chapter) ──
tokens = ["The", "cat", "sat", "on", "mat"]

Q = np.array([
    [1.0, 0.0, 1.0, 0.0],   # The
    [0.0, 2.0, 0.0, 1.0],   # cat
    [1.0, 1.0, 1.0, 0.0],   # sat
    [0.0, 0.0, 1.0, 1.0],   # on
    [1.0, 0.0, 0.0, 1.0],   # mat
])

K = np.array([
    [0.0, 1.0, 0.0, 1.0],   # The
    [1.0, 0.0, 1.0, 0.0],   # cat
    [1.0, 1.0, 0.0, 0.0],   # sat
    [0.0, 0.0, 1.0, 1.0],   # on
    [1.0, 0.0, 0.5, 0.5],   # mat
])

V = np.array([
    [1.0, 0.0, 0.0, 0.0],   # The
    [0.0, 1.0, 0.0, 0.0],   # cat
    [0.0, 0.0, 1.0, 0.0],   # sat
    [0.0, 0.0, 0.0, 1.0],   # on
    [0.5, 0.5, 0.5, 0.5],   # mat
])

# ── Run ──
attn = CausalAttention(d_k=4)
weights, output = attn.forward(Q, K, V)

print("Causal Attention Weights (5x5):")
print(np.round(weights, 4))

print("\nCausal Output (5x4):")
print(np.round(output, 4))

# Detailed trace for "sat"
attn.explain(Q, K, V, tokens, query_idx=2)
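As a cross-check on the listing above, the sketch below condenses the same computation into a hypothetical helper, attention(Q, K, V, causal=...), and uses the shared example to confirm two properties discussed in this chapter. First, removing the mask changes every row except the last one (“mat”), because that row has no future keys to block. Second, replacing the future tokens “on” and “mat” with random vectors leaves the causal outputs of “The”, “cat”, and “sat” untouched while changing their bidirectional outputs. The random perturbation is an illustrative assumption, not part of the original example.

```python
import numpy as np

def attention(Q, K, V, causal=True):
    """Scaled dot-product attention; causal=True applies the upper-triangular mask."""
    N, d_k = Q.shape
    scores = Q @ K.T / np.sqrt(d_k)
    if causal:
        future = np.triu(np.ones((N, N), dtype=bool), k=1)  # True = future position
        scores = np.where(future, -np.inf, scores)
    # Each row max is finite because the diagonal is never masked.
    scores = scores - scores.max(axis=-1, keepdims=True)
    w = np.exp(scores)  # exp(-inf) = 0, so masked keys get zero weight
    return (w / w.sum(axis=-1, keepdims=True)) @ V

# Shared example (same Q, K, V as the listing above)
Q = np.array([[1, 0, 1, 0], [0, 2, 0, 1], [1, 1, 1, 0], [0, 0, 1, 1], [1, 0, 0, 1]], dtype=float)
K = np.array([[0, 1, 0, 1], [1, 0, 1, 0], [1, 1, 0, 0], [0, 0, 1, 1], [1, 0, 0.5, 0.5]], dtype=float)
V = np.array([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1], [0.5, 0.5, 0.5, 0.5]], dtype=float)

causal = attention(Q, K, V, causal=True)
standard = attention(Q, K, V, causal=False)

# Property 1: rows agree only for "mat", the one row with no masked keys
# (only the last entry printed below is True).
print(np.isclose(causal, standard).all(axis=1))

# Property 2: overwrite "on" and "mat" (positions 3-4) with random vectors.
rng = np.random.default_rng(0)
Q2, K2, V2 = Q.copy(), K.copy(), V.copy()
for M in (Q2, K2, V2):
    M[3:] = rng.standard_normal((2, 4))

causal2 = attention(Q2, K2, V2, causal=True)
standard2 = attention(Q2, K2, V2, causal=False)
print(np.allclose(causal[:3], causal2[:3]))      # True: rows 0-2 never see positions 3-4
print(np.allclose(standard[:3], standard2[:3]))  # False: bidirectional rows do
```

The same check generalizes to any position: causal row $i$ depends only on $Q_i$ and rows $0, \ldots, i$ of $K$ and $V$, which is exactly the autoregressive constraint.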

PyTorch Implementation

The PyTorch implementation below includes both a trainable module (with learned $W_Q, W_K, W_V$ projections) and a static method for our manual simulation. The final section demonstrates the production-ready F.scaled_dot_product_attention API with is_causal=True, which can dispatch to a fused kernel such as Flash Attention when the hardware, dtype, and inputs support it, and otherwise falls back to a reference math implementation.

Causal Attention: PyTorch Implementation
causal_attention_torch.py
import torch

PyTorch provides GPU-accelerated tensor operations and autograd for backpropagation.

import torch.nn as nn

Neural network module. nn.Linear provides learned weight matrices.

import torch.nn.functional as F

Functional API: F.softmax, F.scaled_dot_product_attention.

import math

For math.sqrt(d_k), the scaling factor.

class CausalAttention(nn.Module)

PyTorch module for causal self-attention. Inherits nn.Module for parameter tracking and GPU support.

def __init__(self, d_model, d_k)

Initialize with d_model (input dimension) and d_k (projection dimension). Creates three learned projection matrices.

EXECUTION STATE
⬇ input: d_model = embedding dimension (e.g. 512)
⬇ input: d_k = 4 (query/key dimension)
self.d_k = d_k

Store d_k for scaling.

self.W_Q = nn.Linear(d_model, d_k, bias=False)

Learned query projection. Transforms input from d_model to d_k dimensions. bias=False because the original transformer omits bias in attention projections.

EXECUTION STATE
W_Q.weight = shape (d_k, d_model) — learned during training
self.W_K = nn.Linear(d_model, d_k, bias=False)

Learned key projection. Same shape as W_Q.

self.W_V = nn.Linear(d_model, d_k, bias=False)

Learned value projection. Transforms inputs into the 'content' representation.

def forward(self, x) -> tuple

Forward pass with learned projections. Takes raw embeddings x, projects to Q/K/V, and applies causal attention.

EXECUTION STATE
⬇ input: x = torch.Tensor shape (N, d_model)
N = x.size(0)

Number of tokens in the sequence.

EXECUTION STATE
N = 5
Q = self.W_Q(x)

Project x through W_Q to get queries. Shape: (N, d_k).

K = self.W_K(x)

Project x through W_K to get keys. Shape: (N, d_k).

V = self.W_V(x)

Project x through W_V to get values. Shape: (N, d_k).

scores = Q @ K.transpose(-2, -1) / math.sqrt(self.d_k)

Scaled dot-product scores. K.transpose(-2, -1) swaps the last two dims. Divide by sqrt(d_k) = 2.0.

EXECUTION STATE
.transpose(-2, -1) = swap last two dimensions: (N, d_k) → (d_k, N). Equivalent to .T for 2D tensors.
mask = torch.triu(..., diagonal=1)

Upper-triangular boolean mask. diagonal=1 means start 1 above the main diagonal, so the diagonal itself (self-attention) is NOT masked.

EXECUTION STATE
diagonal=1 = equivalent to NumPy k=1. Start masking 1 position above the main diagonal.
mask (5×5) =
       The    cat    sat     on    mat
The  False   True   True   True   True
cat  False  False   True   True   True
sat  False  False  False   True   True
on   False  False  False  False   True
mat  False  False  False  False  False
scores = scores.masked_fill(mask, float('-inf'))

PyTorch's masked_fill replaces True positions with -inf. This is cleaner than boolean indexing.

EXECUTION STATE
.masked_fill(mask, val) = where mask is True, replace with val. Non-masked positions unchanged.
float('-inf') = negative infinity — exp(-inf) = 0 in softmax
weights = F.softmax(scores, dim=-1)

Apply softmax along the last dimension, i.e. across the key positions in each row. Each row becomes a probability distribution summing to 1.0, and positions set to -inf get weight 0.

EXECUTION STATE
dim=-1 = softmax along the last axis (key dimension). Each row independently sums to 1.0.
output = weights @ V

Weighted sum of value vectors. Shape: (N, d_k).

return output, weights

Return output tensor and attention weights.

@staticmethod

Decorator: this method doesn't use self (no learned weights). Allows calling CausalAttention.manual(Q, K, V) directly.

def manual(Q, K, V)

Static method for running causal attention with pre-computed Q, K, V matrices (no learned projections). Used for our shared example.

EXECUTION STATE
⬇ input: Q = torch.Tensor (5, 4)
⬇ input: K = torch.Tensor (5, 4)
⬇ input: V = torch.Tensor (5, 4)
d_k = K.size(-1)

Infer d_k from the key matrix.

EXECUTION STATE
d_k = 4
N = Q.size(0)

Number of tokens.

EXECUTION STATE
N = 5
scores = Q @ K.transpose(-2, -1) / math.sqrt(d_k)

Same as forward(): scaled dot-product scores.

EXECUTION STATE
scores (5×5) =
       The     cat     sat      on     mat
The  0.000   1.000   0.500   0.500   0.750
cat  1.500   0.000   1.000   0.500   0.250
sat  0.500   1.000   1.000   0.500   0.750
on   0.500   0.500   0.000   1.000   0.500
mat  0.500   0.500   0.500   0.500   0.750
mask = torch.triu(..., diagonal=1)

Build the causal mask. Same as forward().

scores = scores.masked_fill(mask, float('-inf'))

Apply causal mask: future positions → -inf.

weights = F.softmax(scores, dim=-1)

Softmax with -inf positions zeroed out.

EXECUTION STATE
weights (5×5) =
       The     cat     sat      on     mat
The  1.0000  0.0000  0.0000  0.0000  0.0000
cat  0.8176  0.1824  0.0000  0.0000  0.0000
sat  0.2327  0.3837  0.3837  0.0000  0.0000
on   0.2350  0.2350  0.1425  0.3875  0.0000
mat  0.1892  0.1892  0.1892  0.1892  0.2430
output = weights @ V

Final output: weighted sum of visible value vectors.

EXECUTION STATE
output (5×4) =
       d0      d1      d2      d3
The  1.0000  0.0000  0.0000  0.0000
cat  0.8176  0.1824  0.0000  0.0000
sat  0.2327  0.3837  0.3837  0.0000
on   0.2350  0.2350  0.1425  0.3875
mat  0.3108  0.3108  0.3108  0.3108
return output, weights

Return (output, weights) tuple.

tokens = ["The", "cat", "sat", "on", "mat"]

Shared example sentence.

Q = torch.tensor([...])

Query matrix as a PyTorch tensor. Same values as the NumPy version.

K = torch.tensor([...])

Key matrix as a PyTorch tensor.

V = torch.tensor([...])

Value matrix as a PyTorch tensor.

output, weights = CausalAttention.manual(Q, K, V)

Run causal attention using the static method (no learned projections, uses our exact Q/K/V).

EXECUTION STATE
output shape = torch.Size([5, 4])
weights shape = torch.Size([5, 5])
builtin_output = F.scaled_dot_product_attention(..., is_causal=True)

PyTorch 2.0+ built-in attention with is_causal=True. This is the recommended production API: it handles the scaling, masking, softmax, and weighted sum in one call and can dispatch to a Flash Attention kernel when the hardware, dtype, and inputs support it (otherwise it falls back to a reference math implementation).

EXECUTION STATE
.unsqueeze(0) = add a batch dimension: (5,4) → (1,5,4). SDPA accepts inputs shaped (..., seq, dim); the usual layout is (batch, heads, seq, dim).
is_causal=True = automatically applies causal mask internally. No need to build/pass the mask manually.
.squeeze(0) = remove the batch dimension: (1,5,4) → (5,4).
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class CausalAttention(nn.Module):
    """
    Causal (Masked) Self-Attention in PyTorch.

    Supports both:
    1. Learnable projections (W_Q, W_K, W_V) for training
    2. Pre-computed Q, K, V for manual simulation
    """

    def __init__(self, d_model: int, d_k: int):
        super().__init__()
        self.d_k = d_k
        self.W_Q = nn.Linear(d_model, d_k, bias=False)
        self.W_K = nn.Linear(d_model, d_k, bias=False)
        self.W_V = nn.Linear(d_model, d_k, bias=False)

    def forward(self, x: torch.Tensor) -> tuple:
        """Forward pass with learned projections."""
        N = x.size(0)
        Q = self.W_Q(x)
        K = self.W_K(x)
        V = self.W_V(x)
        scores = Q @ K.transpose(-2, -1) / math.sqrt(self.d_k)
        mask = torch.triu(
            torch.ones(N, N, dtype=torch.bool, device=x.device),
            diagonal=1
        )
        scores = scores.masked_fill(mask, float('-inf'))
        weights = F.softmax(scores, dim=-1)
        output = weights @ V
        return output, weights

    @staticmethod
    def manual(Q, K, V):
        """Run with pre-computed Q, K, V (no learned weights)."""
        d_k = K.size(-1)
        N = Q.size(0)
        scores = Q @ K.transpose(-2, -1) / math.sqrt(d_k)
        # Build the mask on the same device as Q so GPU inputs also work
        mask = torch.triu(torch.ones(N, N, dtype=torch.bool, device=Q.device), diagonal=1)
        scores = scores.masked_fill(mask, float('-inf'))
        weights = F.softmax(scores, dim=-1)
        output = weights @ V
        return output, weights


# ── Manual simulation with shared example ──
tokens = ["The", "cat", "sat", "on", "mat"]

Q = torch.tensor([
    [1.0, 0.0, 1.0, 0.0],
    [0.0, 2.0, 0.0, 1.0],
    [1.0, 1.0, 1.0, 0.0],
    [0.0, 0.0, 1.0, 1.0],
    [1.0, 0.0, 0.0, 1.0],
])

K = torch.tensor([
    [0.0, 1.0, 0.0, 1.0],
    [1.0, 0.0, 1.0, 0.0],
    [1.0, 1.0, 0.0, 0.0],
    [0.0, 0.0, 1.0, 1.0],
    [1.0, 0.0, 0.5, 0.5],
])

V = torch.tensor([
    [1.0, 0.0, 0.0, 0.0],
    [0.0, 1.0, 0.0, 0.0],
    [0.0, 0.0, 1.0, 0.0],
    [0.0, 0.0, 0.0, 1.0],
    [0.5, 0.5, 0.5, 0.5],
])

output, weights = CausalAttention.manual(Q, K, V)
print("Causal Weights:\n", weights.round(decimals=4))
print("Causal Output:\n", output.round(decimals=4))

# Using PyTorch built-in (recommended for production)
with torch.no_grad():
    builtin_output = F.scaled_dot_product_attention(
        Q.unsqueeze(0), K.unsqueeze(0), V.unsqueeze(0),
        is_causal=True
    ).squeeze(0)
print("\nBuilt-in Output:\n", builtin_output.round(decimals=4))
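A sanity check for the built-in path, assuming PyTorch 2.0 or later and using hypothetical random inputs rather than the shared example. It compares is_causal=True against an explicit boolean attn_mask and against the masked_fill recipe from the listing above. One detail to watch: SDPA's boolean attn_mask uses the opposite convention from this chapter's mask (True means “may attend”), so the chapter's upper-triangular mask must be inverted before it is passed in. The last lines show why the listing uses diagonal=1: masking the diagonal too leaves row 0 with no visible keys, and softmax over a row of all -inf produces NaN.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
N, d_k = 5, 4
# Hypothetical random inputs with a leading batch dimension: (batch, seq, dim)
Q, K, V = (torch.randn(1, N, d_k) for _ in range(3))

# This chapter's convention: True = blocked (used with masked_fill)
blocked = torch.triu(torch.ones(N, N, dtype=torch.bool), diagonal=1)
# SDPA's boolean attn_mask convention: True = allowed to attend
allowed = ~blocked

with torch.no_grad():
    out_flag = F.scaled_dot_product_attention(Q, K, V, is_causal=True)
    out_mask = F.scaled_dot_product_attention(Q, K, V, attn_mask=allowed)
    scores = Q @ K.transpose(-2, -1) / d_k ** 0.5
    out_manual = F.softmax(scores.masked_fill(blocked, float("-inf")), dim=-1) @ V

print(torch.allclose(out_flag, out_mask, atol=1e-5))    # True
print(torch.allclose(out_flag, out_manual, atol=1e-5))  # True

# Masking the diagonal as well (diagonal=0) leaves row 0 with no visible keys.
too_much = torch.triu(torch.ones(N, N, dtype=torch.bool), diagonal=0)
w_bad = F.softmax(scores.masked_fill(too_much, float("-inf")), dim=-1)
print(torch.isnan(w_bad[0, 0]).all())  # row 0 is entirely NaN
```

When the mask is purely causal, is_causal=True is generally the better choice: an explicit attn_mask can prevent some fused backends, such as the Flash Attention kernel, from being selected.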

Key Takeaways

  1. Causal masking is the only difference from standard attention. It sets the upper triangle of the score matrix to $-\infty$ before softmax. It adds no new parameters and only a trivial masking step; all it does is remove information.
  2. It enforces the autoregressive constraint. Token $i$ can only attend to tokens $0, 1, \ldots, i$. This prevents the model from “cheating” by looking at future tokens during training.
  3. Training and inference are consistent. Because the mask blocks future tokens during training (where the full sequence is available), the model learns to operate the same way it must at inference time (where future tokens do not exist).
  4. The last token sees everything. For the last position in the sequence, causal attention is identical to standard attention. Earlier tokens are progressively more constrained.
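  5. It enables the KV-cache. Because causal masking makes each token's keys and values depend only on itself and earlier tokens, they never change as new tokens are appended, so they can be cached and reused during generation. Each new token then needs only one query against the cached keys, reducing the attention work per step from $O(N^2)$ (recomputing the full matrix) to $O(N)$.
  6. Flash Attention makes it faster, not slower. Its causal mode skips tiles that lie entirely in the masked upper triangle and applies the mask only to tiles that straddle the diagonal, so it performs roughly half the attention FLOPs of the unmasked computation.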
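To make takeaway 6 concrete, here is a NumPy sketch of the tiling idea. It is not FlashAttention's actual GPU kernel, and the function names and block sizes are illustrative assumptions. Queries and keys are split into blocks; each query block keeps a running row max, softmax denominator, and weighted sum of values (the online-softmax trick), and key blocks lying entirely in the future of every query in the block are skipped. The sketch checks its result against dense causal attention and counts how many (query block, key block) tiles were actually computed.

```python
import numpy as np

def dense_causal_attention(Q, K, V):
    """Reference: build the full N x N score matrix, mask, softmax, multiply."""
    N, d_k = Q.shape
    scores = Q @ K.T / np.sqrt(d_k)
    scores = np.where(np.triu(np.ones((N, N), dtype=bool), k=1), -np.inf, scores)
    w = np.exp(scores - scores.max(axis=-1, keepdims=True))
    return (w / w.sum(axis=-1, keepdims=True)) @ V

def tiled_causal_attention(Q, K, V, block):
    """Tiled causal attention with online softmax; fully masked tiles are skipped.

    Returns the output and the number of (query block, key block) tiles computed.
    """
    N, d_k = Q.shape
    out = np.zeros((N, V.shape[1]))
    tiles = 0
    for qs in range(0, N, block):
        qe = min(qs + block, N)
        m = np.full(qe - qs, -np.inf)          # running row max
        l = np.zeros(qe - qs)                  # running softmax denominator
        acc = np.zeros((qe - qs, V.shape[1]))  # running unnormalized output
        for ks in range(0, N, block):
            if ks >= qe:  # every key from here on is in the future: skip the rest
                break
            ke = min(ks + block, N)
            tiles += 1
            s = Q[qs:qe] @ K[ks:ke].T / np.sqrt(d_k)
            # Only tiles that straddle the diagonal actually need the mask.
            future = np.arange(ks, ke)[None, :] > np.arange(qs, qe)[:, None]
            s = np.where(future, -np.inf, s)
            m_new = np.maximum(m, s.max(axis=1))
            p = np.exp(s - m_new[:, None])
            rescale = np.exp(m - m_new)  # corrects earlier partial sums for the new max
            l = l * rescale + p.sum(axis=1)
            acc = acc * rescale[:, None] + p @ V[ks:ke]
            m = m_new
        out[qs:qe] = acc / l[:, None]
    return out, tiles

rng = np.random.default_rng(0)
N, d, block = 1024, 64, 64
Q, K, V = (rng.standard_normal((N, d)) for _ in range(3))

out, tiles = tiled_causal_attention(Q, K, V, block)
T = N // block
print(np.allclose(out, dense_causal_attention(Q, K, V)))  # True
print(tiles, "of", T * T, "tiles computed")               # 136 of 256
```

With $T$ blocks per side, query block $t$ touches key blocks $0, \ldots, t$, so the sketch computes $T(T+1)/2$ of the $T^2$ tiles, a fraction of $\tfrac{1}{2} + \tfrac{1}{2T}$ that approaches one half as $T$ grows. Real kernels get their speed from keeping each tile in fast on-chip memory, which NumPy cannot model; only the arithmetic structure carries over.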

Exercises

Exercise 1: Compute “on”'s Output by Hand

Token “on” (position 3) can see “The”, “cat”, “sat”, and itself. Using the causal attention weights from this chapter, compute $O_{\text{on}}$ by taking the weighted sum of the four visible value vectors. Verify that your result matches $[0.2350, 0.2350, 0.1425, 0.3875]$.

Exercise 2: What If “The” Could See “cat”?

In causal attention, “The” gets output $[1, 0, 0, 0]$ (it only sees itself). In standard attention, it gets $[0.2254, 0.4135, 0.2964, 0.2964]$. Explain conceptually why this difference matters for a language model during training.

Exercise 3: Bidirectional Mask for Position 2

Suppose you wanted token “sat” to use standard (bidirectional) attention while all other tokens remain causal. Modify the mask matrix for this scenario and recompute the weights for row 2. How does it differ from pure causal?

Exercise 4: Implement Prefix-LM Masking

Write a function prefix_lm_mask(N, P) that returns a mask where positions $0, \ldots, P-1$ (the prefix) use bidirectional attention and positions $P, \ldots, N-1$ use causal attention. Apply it to our example with $P = 2$ (so “The” and “cat” are bidirectional, and “sat”, “on”, “mat” are causal).

Exercise 5: KV-Cache Simulation

Simulate autoregressive generation with KV-caching. Starting with only “The,” compute the output for each position as you add tokens one at a time. At each step, append the new K and V vectors to the cache and compute attention for only the new query against all cached keys. Verify that your results match the full causal attention output.


References

  1. Vaswani, A. et al. (2017). “Attention Is All You Need.” Advances in NeurIPS, 30. arXiv:1706.03762.
  2. Radford, A., Narasimhan, K., Salimans, T., & Sutskever, I. (2018). “Improving Language Understanding by Generative Pre-Training.” OpenAI Technical Report.
  3. Radford, A. et al. (2019). “Language Models are Unsupervised Multitask Learners.” OpenAI Technical Report (GPT-2).
  4. Brown, T. B. et al. (2020). “Language Models are Few-Shot Learners.” Advances in NeurIPS, 33 (GPT-3). arXiv:2005.14165.
  5. Devlin, J., Chang, M., Lee, K., & Toutanova, K. (2018). “BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding.” arXiv:1810.04805.
  6. Dao, T., Fu, D., Ermon, S., Rudra, A., & Ré, C. (2022). “FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness.” Advances in NeurIPS, 35. arXiv:2205.14135.
  7. Tay, Y. et al. (2023). “UL2: Unifying Language Learning Paradigms.” ICLR 2023. arXiv:2205.05131.
  8. Das, A. et al. (2024). “A Decoder-only Foundation Model for Time-Series Forecasting.” ICML 2024 (TimesFM). arXiv:2310.10688.