Chapter 8

Complete Transformer Decoder

Transformer Decoder

Introduction

In the previous section we built one TransformerDecoderLayer — a single block that combines masked self-attention, cross-attention over encoder memory, and a position-wise FFN with LayerNorm and residuals between them. One layer, however, is not a decoder. A decoder that can translate, summarize, or generate language needs three additional pieces:

  1. A token embedding that turns integer ids into d_model-dim vectors.
  2. A positional encoding that injects order information.
  3. A stack of N decoder layers that iteratively refine the hidden state, followed by an output projection that turns the final hidden state into vocabulary logits.

This section assembles those pieces into TransformerDecoder — the second-to-last building block before the full encoder-decoder Transformer in §6. We'll cover the math of the stacked forward pass, the clever parameter-sharing trick called weight tying, a step-by-step numerical walkthrough, the PyTorch implementation, and a careful parameter count.


Why Stack Decoder Layers

Shallow models can't capture hierarchical structure

A single attention block can only mix information once. Language has structure at many scales (morphology, syntax, semantics, discourse) and those scales compose. In practice, extra width is an inefficient substitute for depth: you eventually need to build abstractions on top of abstractions, which is exactly what stacking does.

Empirically, translation quality improves with depth up to N=6 in Vaswani et al. 2017, and larger models keep improving in modern large language models. Kaplan et al. 2020 ("Scaling Laws for Neural Language Models") fit loss ∝ (parameter count)^(−0.076), where the count is the number of non-embedding parameters rather than the layer count N used in this section. Hoffmann et al. 2022 (Chinchilla) showed that compute-optimal training requires balancing parameters and data: adding layers alone isn't enough if you don't train on proportionally more tokens.
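
To get a feel for what an exponent of 0.076 means, the sketch below evaluates the power law for a few size multipliers. It only assumes the functional form loss ∝ size^(−0.076) quoted above; the helper name `loss_ratio` is made up for illustration.

```python
# Sketch: how a power law with exponent 0.076 responds to growing the model.
ALPHA = 0.076  # exponent quoted above (Kaplan et al. 2020)

def loss_ratio(size_multiplier: float, alpha: float = ALPHA) -> float:
    """Return new_loss / old_loss when the size grows by `size_multiplier`."""
    return size_multiplier ** (-alpha)

for mult in (2, 10, 100):
    print(f"{mult:>4}x larger -> loss multiplied by {loss_ratio(mult):.3f}")
```

Doubling the size lowers the loss by only about 5% under this fit, which is why scaling gains are usually discussed in orders of magnitude.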

What each layer learns — the BERTology view, adapted

Tenney et al. 2019 ("BERT Rediscovers the Classical NLP Pipeline") and Rogers et al. 2020 ("A Primer in BERTology") probed encoder layers and found a rough hierarchy: shallow layers handle local and morphological patterns, middle layers capture syntax, and deep layers represent semantics and discourse. Decoders exhibit an analogous gradient. Shallow decoder layers tend to specialize in copying source content via cross-attention; middle layers handle local target-side fluency; deep layers assemble longer-range coherence across the target sentence.

Why This Matters: Stacking is not just repetition. Each layer starts where the last one stopped, so Layer 3 sees a syntactically organized representation, not raw embeddings. The parallel attention heads within one layer (each with its own projections) give you breadth; stacking layers gives you depth. Both are needed.

Decoder Forward Pass

With the target input tokens $\mathrm{tgt}$ and the encoder output $\mathrm{memory}$ in hand, the decoder's forward pass has three stages:

Stage 1 — build h₀

Look up embeddings, scale them, add positional encoding, apply dropout: $h_0 = \mathrm{Dropout}\bigl(\mathrm{Embed}(\mathrm{tgt}) \cdot \sqrt{d_{model}} + \mathrm{PE}(\mathrm{tgt})\bigr)$.

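The $\sqrt{d_{model}}$ scaling is specified in Vaswani 2017 §3.4, which states the multiplication without a detailed justification. The usual rationale is about relative magnitude: sinusoidal PE has fixed amplitude ≈ 1, so if the embedding entries are small (for example, initialized with variance on the order of $1/d_{model}$, as some implementations do), the PE component would dominate the sum. Multiplying the embedding by $\sqrt{d_{model}}$ (≈ 22.6 when d_model = 512) brings it up to a scale where the identity of the token and its position both matter. Note that PyTorch's nn.Embedding initializes entries from $\mathcal{N}(0, 1)$ by default, so with that default the unscaled embedding is already comparable to the PE and the scaled embedding is the larger term; the balance in practice depends on the initialization you choose.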
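
The magnitudes involved can be checked numerically. The sketch below is a NumPy stand-in, not the PyTorch module: it assumes a unit-variance random table in place of a trained embedding, and the sizes (`vocab`, `T`) and the helper `rms` are illustrative choices.

```python
import numpy as np

d_model, vocab, T = 512, 1000, 16
rng = np.random.default_rng(0)

# Stand-in for an embedding table with N(0, 1) entries (assumption for this demo).
embed = rng.standard_normal((vocab, d_model))
tokens = rng.integers(0, vocab, size=T)

# Sinusoidal positional encoding table, shape (T, d_model).
pos = np.arange(T)[:, None]
div = np.exp(np.arange(0, d_model, 2) * (-np.log(10000.0) / d_model))
pe = np.zeros((T, d_model))
pe[:, 0::2] = np.sin(pos * div)
pe[:, 1::2] = np.cos(pos * div)

def rms(x):
    """Root-mean-square magnitude of all entries."""
    return float(np.sqrt(np.mean(x ** 2)))

rms_embed = rms(embed[tokens])                       # unscaled lookup
rms_scaled = rms(embed[tokens] * np.sqrt(d_model))   # after the sqrt(d_model) scaling
rms_pe = rms(pe)                                     # positional encoding

print(f"embedding: {rms_embed:.3f}  scaled: {rms_scaled:.3f}  PE: {rms_pe:.3f}")
```

Because each sine/cosine pair satisfies sin² + cos² = 1, the PE has a mean square of exactly 0.5 regardless of position, so its RMS stays near 0.707 while the scaled embedding grows with $\sqrt{d_{model}}$.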

Stage 2 — iterate the stack

For layers $\ell = 1, \dots, N$: $h_\ell = \mathrm{DecoderLayer}_\ell(h_{\ell-1}, \mathrm{memory}, \mathrm{tgt\_mask}, \mathrm{memory\_mask})$.

Every layer receives the same memory and masks; only $h$ changes. The layer internals (masked self-attn, cross-attn, FFN, pre-norm, residuals, dropout) were specified in §4 and are not reopened here.

Stage 3 — project to vocabulary

A final LayerNorm $h_N \leftarrow \mathrm{LayerNorm}(h_N)$ (required when using pre-norm; without it the last residual sum has unbounded scale), then: $\mathrm{logits} = h_N \, W_{out}^\top$.

The logits have shape $[B, T_{tgt}, V]$ where V is the target vocabulary size. Softmax over the last axis gives token probabilities; cross-entropy against the gold ids gives the training loss.
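
The last step, from logits to a training loss, can be sketched in NumPy as follows. The function name `cross_entropy_from_logits` and the tiny shapes are illustrative; a real training loop would use the framework's built-in loss and would also handle padding, which is ignored here.

```python
import numpy as np

def cross_entropy_from_logits(logits, targets):
    """logits: (B, T, V) floats; targets: (B, T) integer ids. Returns the mean loss."""
    # Log-softmax over the vocabulary axis, shifted by the max for stability.
    shifted = logits - logits.max(axis=-1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    B, T, _ = logits.shape
    # Pick the log-probability of the gold token at every (batch, position).
    picked = log_probs[np.arange(B)[:, None], np.arange(T)[None, :], targets]
    return float(-picked.mean())

B, T, V = 1, 3, 5
logits = np.zeros((B, T, V))        # a model that is maximally unsure
targets = np.array([[0, 2, 4]])     # gold ids, one per target position
loss = cross_entropy_from_logits(logits, targets)
print(loss, np.log(V))              # uniform predictions give a loss of ln(V)
```

A loss close to ln(V) at the start of training is a useful sanity check: it means the untrained model spreads probability roughly evenly over the vocabulary.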


Weight Tying

Notice that the embedding table $\mathrm{Embed} \in \mathbb{R}^{V \times d_{model}}$ and the output projection $W_{out} \in \mathbb{R}^{V \times d_{model}}$ have the same shape. Press & Wolf ("Using the Output Embedding to Improve Language Models", 2017) proposed using the same matrix for both: $W_{out} = \mathrm{Embed}$.

This has two benefits. First, it halves the parameters at the two largest matrices: for $V = 50{,}000$ and $d_{model} = 512$ that's 25.6M saved parameters. Second, it often improves quality: the inner product $h \cdot \mathrm{Embed}_w$ directly measures how close $h$ is to the embedding of candidate word $w$, which is the natural dual of the lookup operation done at the input.
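
Both benefits can be illustrated in a few lines of NumPy. This is a sketch under simplifying assumptions: the table `E` is random with unit-length rows (not a trained embedding), and the tiny vocabulary of 10 words is chosen only so the result is easy to inspect.

```python
import numpy as np

# Benefit 1: parameter savings for the sizes quoted above.
V, d_model = 50_000, 512
untied = 2 * V * d_model    # separate embedding table and output projection
tied = V * d_model          # one shared matrix
saved = untied - tied
print(f"saved parameters: {saved:,}")

# Benefit 2: the tied projection scores h against every embedding row.
rng = np.random.default_rng(0)
E = rng.standard_normal((10, 16))                # toy table: 10 words, 16 dims
E /= np.linalg.norm(E, axis=1, keepdims=True)    # unit-length rows (demo assumption)

w = 7
h = E[w]                 # pretend the final hidden state equals word 7's embedding
logits = h @ E.T         # tied output projection: W_out = E
predicted = int(np.argmax(logits))
print(predicted)         # the row most aligned with h is row 7 itself
```

With unit-length rows the score of the matching word is exactly 1 and every other score is strictly smaller, so the lookup and the projection are inverse operations in this idealized case.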

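When to tie: share a single vocabulary (language modeling, code models, shared-vocab translation). When not: separate source and target vocabularies (classical NMT) rule out sharing across source and target, but you can still tie the target embedding with the output projection. Vaswani 2017 (§3.4) went one step further and shared one weight matrix across both embedding layers and the pre-softmax projection, which relies on a shared source-target vocabulary.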

Plain-Python Stack Walkthrough

Before we reach for PyTorch, let's stack two tiny decoder layers by hand on the shared config $(B{=}1, T_{src}{=}4, T_{tgt}{=}3, d_{model}{=}8, H{=}2, d_{ff}{=}16, N{=}2)$. We'll watch the hidden state $h_0 \to h_1 \to h_2$ evolve across the stack using deterministic seeds so every number is reproducible.

Plain-Python stack of 2 decoder layers (numpy)
🐍decoder_stack_numpy.py
1import numpy as np

NumPy is the numerical workhorse used here. We use ndarray for all tensors, @ for matrix multiplication, broadcasting for element-wise operations, and np.random for reproducible pseudo-random weights. Aliased as np by convention.

EXECUTION STATE
numpy = N-dimensional array library — provides ndarray, linear algebra, random numbers.
as np = Alias so we write np.random.seed() instead of numpy.random.seed().
3np.random.seed(0)

Fixes the global NumPy RNG so every run of this code produces the same random numbers — essential for a book walkthrough where the reader must see the same values we describe.

EXECUTION STATE
📚 np.random.seed(seed) = Initializes NumPy's legacy global random generator. All subsequent calls to np.random.randn() follow a deterministic sequence.
⬇ arg: seed = 0 = Arbitrary but conventional integer. Any fixed value works; the important thing is that it is fixed.
5Shared config comment

Documents the tiny hyperparameters we share across sections §4, §5, and §6 so the reader can follow the same example all the way through.

6B, T_src, T_tgt, d_model, H, d_ff, N = 1, 4, 3, 8, 2, 16, 2

Tuple-unpacks the seven hyperparameters in one line. These names match the notation table in the chapter intro.

EXECUTION STATE
B = 1 = Batch size. One example at a time so individual numbers are visible.
T_src = 4 = Source (encoder memory) sequence length — 4 tokens.
T_tgt = 3 = Target sequence length — 3 tokens the decoder is producing.
d_model = 8 = Model embedding dimension. Every token is an 8-vector at every layer boundary.
H = 2 = Number of attention heads. d_model splits evenly: 8 ÷ 2 = 4 per head.
d_ff = 16 = Feed-forward hidden dimension. Usually ~4× d_model; here 2×d_model to keep it tiny.
N = 2 = Number of stacked decoder layers. Real Transformers use 6, 12, 48, 96.
7d_k = d_model // H # 4

Per-head key/query/value dimension. Each of the H heads projects into d_k = d_model/H so the concatenation of all heads has exactly d_model columns again.

EXECUTION STATE
// (floor division) = Integer division. 8 // 2 = 4. Uses // not / so d_k stays an int usable as a reshape dimension.
d_k = 4 — each head's projection dimension.
9Comment: h_0 construction note

Reminder that in a full Transformer, h_0 is Dropout(Embed(tgt) · sqrt(d_model) + PE(tgt)). For clarity we use a fixed random h_0 here — the embedding / positional-encoding machinery was covered in ch04.

11tgt_in = (np.random.randn(T_tgt, d_model) * 0.5).astype(np.float32)

Creates the starting hidden state h_0 for the 3 target tokens. The small scale (×0.5, so std 0.5) keeps activations from blowing up through the layer stack.

EXECUTION STATE
📚 np.random.randn(d0, d1) = Draws standard-normal samples, shape (d0, d1). Mean 0, variance 1.
⬇ arg: T_tgt = 3 = Number of rows — one row per target token.
⬇ arg: d_model = 8 = Number of columns — 8 features per token.
* 0.5 = Rescales the standard deviation to 0.5 (variance 0.25), so initial values sit around ±0.5.
.astype(np.float32) = Converts to 32-bit float — the PyTorch default, matching what we will compute later.
⬆ tgt_in (3×8) =
        d0     d1     d2     d3     d4     d5     d6     d7
tok0  0.882  0.200  0.489  1.120  0.934 -0.489  0.475 -0.076
tok1 -0.052  0.205  0.072  0.727  0.381  0.061  0.222  0.167
tok2  0.747 -0.103  0.157 -0.427 -1.276  0.327  0.432 -0.371
12memory = (np.random.randn(T_src, d_model) * 0.5).astype(np.float32)

Encoder memory — 4 source tokens, each d_model=8. In the real pipeline this would come from the encoder; here we fabricate it so the decoder has something to cross-attend to.

EXECUTION STATE
⬇ args: (T_src=4, d_model=8) = 4 rows × 8 columns — one row per source token.
⬆ memory (4×8) = Kept as the K and V source for every decoder layer's cross-attention. Values not printed for space; drawn from the same seed.
14def layernorm(x, eps=1e-5)

Pre-norm helper. Normalizes each row to zero mean, unit variance across the d_model axis (the LAST axis), then returns it. No learnable scale/bias — those stayed in §4's implementation.

EXECUTION STATE
⬇ input: x — any 2D ndarray with d_model last = We pass (T_tgt, d_model) or (T_src, d_model) tensors.
⬇ input: eps = 1e-5 = Small constant added inside the sqrt to avoid division by zero when variance is ~0.
15mu = x.mean(-1, keepdims=True)

Row-wise mean across d_model features. keepdims=True gives shape (T, 1) so broadcasting against x (T, d_model) works without a reshape.

EXECUTION STATE
📚 ndarray.mean(axis, keepdims) = Averages along an axis. axis=-1 = last axis.
⬇ arg: -1 = Last axis (d_model). For shape (3, 8), averages across the 8 features per row.
⬇ arg: keepdims=True = Result shape (3, 1) instead of (3,) — needed for broadcasting `x - mu`.
⬆ mu (3×1) = mean of each row's 8 values.
16var = x.var(-1, keepdims=True)

Row-wise variance across d_model. Same shape trick as the mean.

EXECUTION STATE
📚 ndarray.var(axis, keepdims) = Computes mean((x - mean)^2) along the given axis.
⬆ var (3×1) = per-row variance — measures spread of each token's 8 features.
17return (x - mu) / np.sqrt(var + eps)

Subtract row mean, divide by row std-dev. Because mu is (3,1) and x is (3,8), NumPy broadcasts mu across all 8 columns.

EXECUTION STATE
(x - mu) = Centers each row at 0.
np.sqrt(var + eps) = Row-wise std-dev (eps avoids zero).
⬆ return = Zero-mean, unit-variance rows. Shape unchanged: (T, d_model).
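
A quick check of the helper on hand-picked rows confirms the zero-mean, unit-variance claim. The function is copied from the walkthrough; the input values are arbitrary.

```python
import numpy as np

def layernorm(x, eps=1e-5):
    mu = x.mean(-1, keepdims=True)
    var = x.var(-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)

x = np.array([[1.0, 2.0, 3.0, 4.0],
              [10.0, 10.0, 10.0, 50.0]])
y = layernorm(x)

print(y.mean(-1))   # close to 0 for each row
print(y.var(-1))    # close to 1 for each row (slightly below, because of eps)

# A constant row has zero variance; eps keeps the division finite.
flat = layernorm(np.ones((1, 4)))
print(flat)         # all zeros, no NaNs
```

The two input rows have very different scales, yet both come out with the same statistics, which is the point of normalizing per token.
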
19def softmax(x, axis=-1)

Numerically stable softmax. Used inside attention to turn raw scores into probabilities.

20x = x - x.max(axis=axis, keepdims=True)

Subtract the axis-max BEFORE exponentiating. This max-subtraction trick (the same shift used in log-sum-exp) prevents exp(large) overflow without changing the result.

EXECUTION STATE
📚 ndarray.max(axis, keepdims) = Max along an axis with same keepdims semantics as .mean/.var.
⬇ arg: axis = -1 = Last axis — for attention scores (H, T_q, T_k), that's over keys.
21e = np.exp(x)

Element-wise exponential. After the shift, all inputs are ≤ 0, so every exp() is in (0, 1].

EXECUTION STATE
📚 np.exp(x) = Element-wise e^x. np.exp(0) = 1.0, np.exp(-1) ≈ 0.368.
22return e / e.sum(axis=axis, keepdims=True)

Divide by the sum along the same axis so each slice sums to 1, giving a valid probability distribution.

EXECUTION STATE
⬆ return = softmax output, same shape as input x.
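
To see why the shift matters, the sketch below compares the walkthrough's softmax with a naive version on deliberately large scores. The name `softmax_naive` is introduced here only for the comparison.

```python
import numpy as np

def softmax_naive(x):
    e = np.exp(x)                      # overflows for large x
    return e / e.sum(axis=-1, keepdims=True)

def softmax_stable(x):
    x = x - x.max(axis=-1, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=-1, keepdims=True)

big = np.array([1000.0, 1001.0, 1002.0])
with np.errstate(over="ignore", invalid="ignore"):
    naive = softmax_naive(big)         # exp gives inf, and inf / inf gives nan
stable = softmax_stable(big)

small = np.array([0.0, 1.0, 2.0])      # same scores, shifted down by 1000
print(naive)                           # all nan
print(stable)                          # finite probabilities
print(softmax_naive(small))            # matches the stable result
```

Softmax only depends on differences between scores, so shifting every score by the same amount leaves the output unchanged while keeping exp() in range.
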
24def mha(Q, K, V, W, causal=False)

Multi-head attention, used two different ways inside decoder_layer: masked self-attention (Q=K=V, causal=True) and cross-attention (Q=target, K=V=memory, causal=False). W is a dict of the 4 projection matrices.

EXECUTION STATE
⬇ input: Q = Query source — shape (T_q, d_model). For self-attn this IS x; for cross-attn it's x (target).
⬇ input: K, V = Key/Value source — shape (T_k, d_model). Same as Q for self-attn; memory for cross-attn.
⬇ input: W = Dict with 'q','k','v','o' each a (d_model, d_model) projection matrix.
⬇ input: causal = If True, upper-triangular mask blocks future positions. Used for self-attn on the target, never for cross-attn.
25q = (Q @ W['q']).reshape(-1, H, d_k).transpose(1, 0, 2)

Project Q, split into H heads, then move the head axis to the front so matmuls later treat heads as a batch.

EXECUTION STATE
Q @ W['q'] = Matrix multiply (T_q, d_model) @ (d_model, d_model) → (T_q, d_model).
📚 .reshape(-1, H, d_k) = Splits the 8-dim vector into 2 heads × 4 dims. -1 lets NumPy infer the first dim as T_q.
📚 .transpose(1, 0, 2) = Permute axes so the resulting shape is (H, T_q, d_k) — heads become the leading 'batch' dim.
→ Example = Before transpose: (T_q=3, H=2, d_k=4). After: (H=2, T_q=3, d_k=4). Same numbers, different layout.
26k = (K @ W['k']).reshape(-1, H, d_k).transpose(1, 0, 2)

Same projection + head-split pattern for K. The result has shape (H, T_k, d_k).

27v = (V @ W['v']).reshape(-1, H, d_k).transpose(1, 0, 2)

Same for V. Shape (H, T_k, d_k).
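
The reshape-then-transpose pattern is easier to follow with recognizable numbers. This sketch uses a counting array instead of real projections, with the same sizes as the walkthrough (T=3, H=2, d_k=4).

```python
import numpy as np

T, H, d_k = 3, 2, 4
d_model = H * d_k
x = np.arange(T * d_model).reshape(T, d_model)    # row t holds 8t .. 8t+7

# Split: (T, d_model) -> (T, H, d_k) -> (H, T, d_k)
heads = x.reshape(-1, H, d_k).transpose(1, 0, 2)

print(heads[0, 0])   # head 0, token 0: first half of row 0
print(heads[1, 0])   # head 1, token 0: second half of row 0
print(heads[0, 1])   # head 0, token 1: first half of row 1

# Merge: the inverse used at the end of mha()
merged = heads.transpose(1, 0, 2).reshape(-1, d_model)
print(np.array_equal(merged, x))
```

Each head sees a contiguous slice of d_k features from every token, and the merge step restores the original layout exactly.
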

28scores = q @ k.transpose(0, 2, 1) / np.sqrt(d_k)

Per-head scaled dot products. q is (H, T_q, d_k); k.transpose(0,2,1) is (H, d_k, T_k); the @ gives (H, T_q, T_k).

EXECUTION STATE
.transpose(0, 2, 1) = Swap the LAST two axes of k so d_k lines up for matrix multiply.
/ np.sqrt(d_k) = Scaling: √4 = 2. Keeps score variance ~1 regardless of d_k; prevents softmax saturation.
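
The variance claim can be checked empirically. The sketch assumes query and key entries are independent with zero mean and unit variance, which is an idealization of real projections; the sample count `n` is arbitrary.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 20_000   # number of random (q, k) pairs per setting

variances = {}
for d_k in (4, 64):
    q = rng.standard_normal((n, d_k))
    k = rng.standard_normal((n, d_k))
    raw = (q * k).sum(axis=1)        # unscaled dot products
    scaled = raw / np.sqrt(d_k)      # what attention feeds to softmax
    variances[d_k] = (float(raw.var()), float(scaled.var()))
    print(f"d_k={d_k:>2}  var(raw)={raw.var():6.2f}  var(scaled)={scaled.var():.2f}")
```

Under these assumptions the raw variance grows in proportion to d_k, while the scaled variance stays near 1 for both settings.
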
29if causal:

Only applies to masked self-attention (§4). Skipped entirely in cross-attention because the target may legitimately see all of memory.

30Tq = scores.shape[-2]

Read the query length from the current scores tensor — inside self-attn this is T_tgt=3.

31mask = np.triu(np.ones((Tq, Tq)), k=1).astype(bool)

Upper-triangular True matrix — True exactly where the query position must NOT see the key position.

EXECUTION STATE
📚 np.triu(A, k) = Upper-triangular of A, zeroing entries below the k-th diagonal. k=1 leaves the diagonal zero (self-attention on the current position is allowed).
⬇ arg: k=1 = Shift the diagonal up by 1 so positions can attend to themselves.
32scores = np.where(mask, -1e9, scores)

Replace masked entries with a huge negative number so softmax sends them to zero.

EXECUTION STATE
📚 np.where(cond, a, b) = Elementwise: where cond is True take a, else take b. Broadcasts shapes as needed.
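
The effect of the mask is clearest when all raw scores are equal. The sketch below builds the same mask for T=3 and applies it to a zero score matrix, so the resulting weights show only which positions are visible.

```python
import numpy as np

T = 3
mask = np.triu(np.ones((T, T)), k=1).astype(bool)   # True = blocked (future position)
print(mask)

scores = np.zeros((T, T))                 # equal raw scores for every pair
masked = np.where(mask, -1e9, scores)     # blocked entries get a huge negative score

# Softmax over keys (last axis), with the usual max shift.
shifted = masked - masked.max(axis=-1, keepdims=True)
weights = np.exp(shifted) / np.exp(shifted).sum(axis=-1, keepdims=True)
print(weights)
```

Row t spreads its attention evenly over positions 0 through t and gives exactly zero weight to later positions.
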
33out = (softmax(scores) @ v).transpose(1, 0, 2).reshape(-1, d_model)

softmax gives (H, T_q, T_k). Multiplying by v (H, T_k, d_k) yields (H, T_q, d_k). Then transpose back to (T_q, H, d_k) and reshape to (T_q, d_model) to merge the heads.

34return out @ W['o']

Final output projection. (T_q, d_model) @ (d_model, d_model) → (T_q, d_model). This is the W_o in Vaswani 2017.

EXECUTION STATE
⬆ return = Attention output for every query position. Shape (T_q, d_model).
36def ffn(x, W)

Position-wise feed-forward: two linear layers with ReLU between them. Applied identically to every token.

37return np.maximum(0, x @ W['w1']) @ W['w2']

x @ W['w1']: (T, d_model) → (T, d_ff). ReLU. Then @ W['w2']: (T, d_ff) → (T, d_model).

EXECUTION STATE
📚 np.maximum(0, z) = Elementwise ReLU: max(0, z). Negative entries become 0.
39def decoder_layer(x, memory, W)

One full decoder layer: masked self-attention, then cross-attention over memory, then FFN. Each wrapped in a pre-norm + residual.

EXECUTION STATE
⬇ input: x = Hidden state entering this layer — shape (T_tgt=3, d_model=8).
⬇ input: memory = Encoder output — shape (T_src=4, d_model=8). Same object passed to every layer.
⬇ input: W = Dict of this layer's weights: W['sa'], W['ca'], W['ff'].
40x = x + mha(layernorm(x), layernorm(x), layernorm(x), W['sa'], causal=True)

Pre-norm self-attention sublayer. LayerNorm is applied BEFORE the attention (modern pre-norm variant). Q=K=V because it's self-attention. causal=True because this is the decoder and we must not peek at future target tokens.

EXECUTION STATE
x + ... = Residual connection — preserves the signal from h_0 through every sublayer.
causal=True = Triggers the upper-triangular mask inside mha().
41x = x + mha(layernorm(x), layernorm(memory), layernorm(memory), W['ca'])

Cross-attention sublayer. Q comes from the DECODER side; K and V come from the ENCODER MEMORY. No causal mask — the decoder is allowed to look at every source position.

EXECUTION STATE
Q source = layernorm(x) — pre-normed target hidden state.
K, V source = layernorm(memory) — pre-normed encoder output. Same memory across all layers.
42x = x + ffn(layernorm(x), W['ff'])

Position-wise FFN sublayer, pre-normed and residualized.

43return x

Output hidden state for this layer — it becomes the next layer's input x.

EXECUTION STATE
⬆ return = Shape (T_tgt=3, d_model=8).
45def make_layer(seed)

Builds a fresh set of weights for one decoder layer. Using a different seed per layer guarantees distinct weights across the stack.

46r = np.random.default_rng(seed).standard_normal

Modern NumPy RNG — independent from np.random.seed so each layer's weights come from its own seeded stream.

EXECUTION STATE
📚 np.random.default_rng(seed) = Returns a Generator. .standard_normal(shape) draws N(0,1) samples.
47def P(shape):

Tiny helper that draws a matrix of the given shape, scales it by 0.2 (to keep activations bounded), and casts to float32.

48return (r(shape) * 0.2).astype(np.float32)

* 0.2 rescales standard-normal draws to std 0.2 — conservative initialization suitable for a toy demo.

49return {...layer weight dict...}

Dict with three sub-dicts: 'sa' (self-attn Q/K/V/O), 'ca' (cross-attn Q/K/V/O), and 'ff' (w1/w2 of the FFN). Every matrix is d_model×d_model except the FFN ones which are d_model×d_ff and d_ff×d_model.

50'sa': {'q': P((d_model, d_model)), ...}

Self-attention's four projection matrices, each (8, 8) here.

51 ... 'v': P((d_model, d_model)), 'o': P((d_model, d_model))

Closing out the 'sa' dict. Same shapes for all four matrices.

52'ca': {'q': P((d_model, d_model)), ...}

Cross-attention's four projection matrices. Different weights from self-attn, same shapes.

53 ... 'v': P((d_model, d_model)), 'o': P((d_model, d_model))

Closing the 'ca' dict.

54'ff': {'w1': P((d_model, d_ff)), 'w2': P((d_ff, d_model))}

FFN weights. w1 expands 8 → 16; w2 contracts 16 → 8.

57layers = [make_layer(1), make_layer(2)]

Two independently-seeded decoder layers. In PyTorch this will become an nn.ModuleList of TransformerDecoderLayer modules.

59h0 = tgt_in

Input to the stack. In a full model this would be Dropout(Embed(tgt)·√d_model + PE(tgt)).

EXECUTION STATE
h0 (3×8) =
        d0     d1     d2     d3     d4     d5     d6     d7
tok0  0.882  0.200  0.489  1.120  0.934 -0.489  0.475 -0.076
tok1 -0.052  0.205  0.072  0.727  0.381  0.061  0.222  0.167
tok2  0.747 -0.103  0.157 -0.427 -1.276  0.327  0.432 -0.371
60h1 = decoder_layer(h0, memory, layers[0])

First layer's forward pass. Notice the values shift: the residuals keep h0 visible, but self-attn, cross-attn, and FFN have each added a contribution.

EXECUTION STATE
h1 (3×8) =
        d0     d1     d2     d3     d4     d5     d6     d7
tok0  0.530  0.080  0.516  0.581  0.476 -0.617 -0.504 -0.161
tok1 -0.481  0.278  0.286  0.388 -0.082 -0.387 -1.063 -0.353
tok2  1.450  0.643  0.626 -0.598 -2.060 -0.356 -0.788 -0.535
→ What changed? = Every number moved from its h0 value. Each row is still anchored near its residual, but features have been remixed using attention and the FFN.
61h2 = decoder_layer(h1, memory, layers[1])

Second (and final here) layer. h2 is the representation we hand to the output projection to produce logits.

EXECUTION STATE
h2 (3×8) =
        d0     d1     d2     d3     d4     d5     d6     d7
tok0  0.822 -0.090  0.344  0.919  0.321 -0.648 -1.121  1.099
tok1 -0.085  0.031  0.075  0.987 -0.478 -0.556 -1.299 -0.026
tok2  1.363  0.431  0.835 -0.339 -2.172  0.040 -1.211 -0.804
→ Takeaway = Hidden states DO evolve across the stack — h2 is not a small perturbation of h1. Stacking lets each layer refine representations using both the target's own history (self-attn) and the source (cross-attn).
import numpy as np

np.random.seed(0)

# Shared tiny config used in §4, §5, §6
B, T_src, T_tgt, d_model, H, d_ff, N = 1, 4, 3, 8, 2, 16, 2
d_k = d_model // H  # 4

# Already-prepared h_0: Embed(tgt) * sqrt(d_model) + PE(tgt), then dropout.
# We build it by hand so the numbers are traceable.
tgt_in = (np.random.randn(T_tgt, d_model) * 0.5).astype(np.float32)
memory = (np.random.randn(T_src, d_model) * 0.5).astype(np.float32)

def layernorm(x, eps=1e-5):
    mu = x.mean(-1, keepdims=True)
    var = x.var(-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def mha(Q, K, V, W, causal=False):
    q = (Q @ W["q"]).reshape(-1, H, d_k).transpose(1, 0, 2)
    k = (K @ W["k"]).reshape(-1, H, d_k).transpose(1, 0, 2)
    v = (V @ W["v"]).reshape(-1, H, d_k).transpose(1, 0, 2)
    scores = q @ k.transpose(0, 2, 1) / np.sqrt(d_k)
    if causal:
        Tq = scores.shape[-2]
        mask = np.triu(np.ones((Tq, Tq)), k=1).astype(bool)
        scores = np.where(mask, -1e9, scores)
    out = (softmax(scores) @ v).transpose(1, 0, 2).reshape(-1, d_model)
    return out @ W["o"]

def ffn(x, W):
    return np.maximum(0, x @ W["w1"]) @ W["w2"]

def decoder_layer(x, memory, W):
    x = x + mha(layernorm(x), layernorm(x), layernorm(x), W["sa"], causal=True)
    x = x + mha(layernorm(x), layernorm(memory), layernorm(memory), W["ca"])
    x = x + ffn(layernorm(x), W["ff"])
    return x

def make_layer(seed):
    r = np.random.default_rng(seed).standard_normal
    def P(shape):
        return (r(shape) * 0.2).astype(np.float32)
    return {
        "sa": {"q": P((d_model, d_model)), "k": P((d_model, d_model)),
               "v": P((d_model, d_model)), "o": P((d_model, d_model))},
        "ca": {"q": P((d_model, d_model)), "k": P((d_model, d_model)),
               "v": P((d_model, d_model)), "o": P((d_model, d_model))},
        "ff": {"w1": P((d_model, d_ff)), "w2": P((d_ff, d_model))},
    }

layers = [make_layer(1), make_layer(2)]

h0 = tgt_in
h1 = decoder_layer(h0, memory, layers[0])
h2 = decoder_layer(h1, memory, layers[1])

The takeaway from the printed values of $h_0, h_1, h_2$ is that hidden states really do move as they pass through the stack. Token 2 ("tok2") starts at $d_4 = -1.276$, ends at $d_4 = -2.172$; its $d_0$ rises from 0.747 to 1.363. These shifts are not noise: they are the cumulative effect of attention and FFN remixing features using information the layer couldn't access before (the target's own history via self-attn, the source via cross-attn).
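
One way to quantify how far each token moved is to compare h_0 and h_2 directly. The sketch below uses only the rounded values printed in the walkthrough, so its results are accurate to about three decimals; the variable names `delta` and `relative_change` are illustrative.

```python
import numpy as np

# Rounded values copied from the h0 and h2 tables above.
h0 = np.array([
    [ 0.882,  0.200,  0.489,  1.120,  0.934, -0.489,  0.475, -0.076],
    [-0.052,  0.205,  0.072,  0.727,  0.381,  0.061,  0.222,  0.167],
    [ 0.747, -0.103,  0.157, -0.427, -1.276,  0.327,  0.432, -0.371],
])
h2 = np.array([
    [ 0.822, -0.090,  0.344,  0.919,  0.321, -0.648, -1.121,  1.099],
    [-0.085,  0.031,  0.075,  0.987, -0.478, -0.556, -1.299, -0.026],
    [ 1.363,  0.431,  0.835, -0.339, -2.172,  0.040, -1.211, -0.804],
])

delta = h2 - h0                                   # total update written by both layers
relative_change = np.linalg.norm(delta, axis=1) / np.linalg.norm(h0, axis=1)

print(delta[2, 4], delta[2, 0])    # the two tok2 shifts discussed in the text
print(relative_change)             # per-token update size relative to the input norm
```

Because every sublayer is residual, `delta` is exactly the sum of everything the two layers added on top of h_0.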


PyTorch Implementation

Now the production version. The class TransformerDecoder composes TransformerDecoderLayer (from §4), a token embedding, a positional encoding, a final LayerNorm, and an output projection — with the weight-tying toggle baked in.

TransformerDecoder (PyTorch)
🐍transformer_decoder.py
1import math

Standard library. We use math.sqrt(d_model) in the embedding scaling and math.log(10000.0) inside sinusoidal PE.

2import torch

PyTorch core. Provides torch.Tensor, torch.arange, torch.sin/cos, and autograd.

3import torch.nn as nn

Neural-network building blocks. We use nn.Module, nn.Embedding, nn.Linear, nn.LayerNorm, nn.Dropout, and nn.ModuleList.

4from typing import Optional

Lets us annotate tgt_mask and memory_mask as Optional[torch.Tensor] — they default to None.

6Comment — reuse §4's layer

A full decoder stacks N identical TransformerDecoderLayer instances built in the previous section. We DO NOT re-implement the layer here.

7from .decoder_layer import TransformerDecoderLayer

Relative import of §4's class. Holds the three sublayers (masked self-attn, cross-attn, FFN) with their LayerNorms and residuals.

10class SinusoidalPositionalEncoding(nn.Module)

Copy of the ch04 PE module so this file is self-contained. See ch04/02 for the full derivation.

11Docstring

Reminds the reader this is a duplicate, not new content.

13def __init__(self, d_model, max_len=5000)

Builds the (max_len, d_model) sinusoidal lookup table once.

EXECUTION STATE
⬇ input: d_model = Embedding dim. Must match the token-embedding dim so the sum PE + Embed is legal.
⬇ input: max_len = 5000 = Upper bound on sequence length. 5000 is plenty for most tasks; increase if you need longer contexts.
14super().__init__()

Registers this as an nn.Module so parameters and buffers are tracked.

15pe = torch.zeros(max_len, d_model)

Allocates the PE table. Shape (5000, d_model).

EXECUTION STATE
📚 torch.zeros(*shape) = Creates a tensor of zeros with the given shape. CPU by default.
16pos = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)

Column vector of positions 0..max_len-1.

EXECUTION STATE
📚 torch.arange(start, stop, dtype) = Evenly-spaced values like Python range() but as a tensor.
📚 .unsqueeze(1) = Insert a new dim of size 1 at axis 1. Turns shape (max_len,) into (max_len, 1) so broadcasting against (d_model/2,) works.
17div = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))

Frequency coefficients, one per 2-dim pair: div[i] = 10000^(−2i/d_model). Writing it as exp(2i · (−ln 10000 / d_model)) is the conventional idiom and yields the same values as the direct power.

18pe[:, 0::2] = torch.sin(pos * div)

Even columns get sines. Broadcasting: (max_len, 1) * (d_model/2,) → (max_len, d_model/2).

EXECUTION STATE
0::2 (slice) = Start 0, step 2 — every even index.
19pe[:, 1::2] = torch.cos(pos * div)

Odd columns get cosines. Same frequencies, shifted phase.

20self.register_buffer('pe', pe.unsqueeze(0))

Register PE as a buffer — saved in state_dict and moved by .to(device), but NOT a learnable parameter.

EXECUTION STATE
📚 .register_buffer(name, tensor) = Stores a non-learnable tensor attached to the module.
📚 .unsqueeze(0) = Insert a batch dim → shape (1, max_len, d_model). Enables clean broadcasting against (B, T, d_model) in forward().
22def forward(self, x) -> Tensor

Adds the positional encoding to x.

23return x + self.pe[:, : x.size(1)]

Take only the first T rows of pe and add them to every batch element. Broadcasting: (1, T, d_model) + (B, T, d_model) → (B, T, d_model).

EXECUTION STATE
📚 x.size(1) = Returns the length of axis 1 — the sequence length T. Same as x.shape[1].
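
The same table can be built in NumPy to inspect its values without PyTorch. This is a sketch that mirrors the lines above rather than the module itself; the function name `sinusoidal_pe` is made up here, and it assumes an even d_model.

```python
import math
import numpy as np

def sinusoidal_pe(max_len, d_model):
    """Return the (max_len, d_model) PE table and its frequency vector."""
    pe = np.zeros((max_len, d_model))
    pos = np.arange(max_len, dtype=np.float64)[:, None]           # (max_len, 1)
    div = np.exp(np.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
    pe[:, 0::2] = np.sin(pos * div)    # even columns
    pe[:, 1::2] = np.cos(pos * div)    # odd columns
    return pe, div

pe, div = sinusoidal_pe(max_len=50, d_model=8)

# The exp/log form equals the direct power 10000^(-2i/d_model).
direct = 10000.0 ** (-np.arange(0, 8, 2) / 8)
print(np.allclose(div, direct))

print(pe[0])        # position 0: sin(0)=0 in even columns, cos(0)=1 in odd columns
print(pe.shape)
```

Position 0 always encodes as alternating zeros and ones, which is a handy fingerprint when debugging whether the PE was actually added.
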
26class TransformerDecoder(nn.Module)

The complete decoder: token embedding, positional encoding, a stack of N TransformerDecoderLayer modules, a final LayerNorm, and the output projection to vocabulary logits.

27Docstring

One-line summary — kept short because the real documentation is this book.

29def __init__(self, num_layers, d_model, num_heads, d_ff, vocab_size, max_len=5000, dropout=0.1, tie_weights=True)

Builds every module that makes up the decoder. The core architectural arguments (num_layers, d_model, num_heads, d_ff, vocab_size) are required; only max_len, dropout, and tie_weights have defaults.

EXECUTION STATE
⬇ input: num_layers = N in the equations. Typical values: 6 (Vaswani), 12 (GPT-2 small), 48 (GPT-2 XL), 96 (GPT-3).
⬇ input: d_model = Embedding dimension, shared across embedding, PE, and every layer.
⬇ input: num_heads = Number of attention heads — must divide d_model evenly.
⬇ input: d_ff = FFN hidden dim. Usually 4·d_model.
⬇ input: vocab_size = Target-side vocabulary size. Determines both the Embedding table rows AND the output projection columns.
⬇ input: max_len = Longest supported target sequence length. 5000 is a sensible default.
⬇ input: dropout = Dropout probability applied to PE output and inside every sublayer. 0.1 is the standard Vaswani value.
⬇ input: tie_weights = If True, self.out_proj.weight IS self.tok_emb.weight — Press & Wolf 2017 weight tying. Halves the parameters at the two largest matrices.
39super().__init__()

Register as nn.Module.

40self.d_model = d_model

Stored so forward() can use it for the sqrt(d_model) scaling without needing a separate argument.

41self.tok_emb = nn.Embedding(vocab_size, d_model)

Learnable (vocab_size × d_model) table. Calling self.tok_emb(tgt) looks up rows for every token id in tgt.

EXECUTION STATE
📚 nn.Embedding(num_embeddings, embedding_dim) = Lookup table of learnable vectors. Input: int64 tensor of indices. Output: same shape + embedding_dim appended.
⬇ arg 1: vocab_size = Number of rows. Each row represents one vocabulary item.
⬇ arg 2: d_model = Columns per row — the embedding vector length.
42self.pos_enc = SinusoidalPositionalEncoding(d_model, max_len)

Non-learnable PE module. We reuse ch04's sinusoidal flavor; RoPE/ALiBi would slot in here instead in a modern LLM.

43self.drop = nn.Dropout(dropout)

Single Dropout module applied to h_0 just after PE. Every sublayer inside TransformerDecoderLayer has its own dropout too.

EXECUTION STATE
📚 nn.Dropout(p) = During training, zeros each element with prob p and scales survivors by 1/(1-p). In eval mode it's identity.
44self.layers = nn.ModuleList([...])

ModuleList (NOT a plain Python list) so every layer's parameters are registered, show up in .parameters(), and move with .to(device).

EXECUTION STATE
📚 nn.ModuleList = A list-like container that registers every contained Module as a submodule.
45TransformerDecoderLayer(d_model, num_heads, d_ff, dropout)

Constructs one decoder layer with the exact arguments it needs. See §4 for the internals.

46for _ in range(num_layers)

Builds num_layers copies, each with independent weights.

48self.norm = nn.LayerNorm(d_model)

Final LayerNorm after the stack. In pre-norm Transformers this normalizes the output before projection — without it, the last residual sum can have arbitrary scale.

49self.out_proj = nn.Linear(d_model, vocab_size, bias=False)

Projection to vocabulary logits. bias=False because weight tying replaces the weight matrix with the embedding table — keeping a bias would break the symmetry.

EXECUTION STATE
📚 nn.Linear(in, out, bias) = Fully-connected layer. output = x @ W.T + b (or no b).
⬇ arg 3: bias=False = Matches Vaswani 2017 and allows clean weight tying.
50if tie_weights:

Gate for weight tying. Useful mostly when the source and target share a vocabulary (language modeling, code models).

51self.out_proj.weight = self.tok_emb.weight # W_out = Embed

The SAME tensor object backs both the embedding table AND the output projection. Updating one updates the other. Saves vocab_size · d_model parameters.

EXECUTION STATE
Why it works = logits = h · out_proj.weight.T. When weight IS tok_emb.weight, we're asking 'which embedding vector is h closest to?' — the natural dual of lookup.
53def forward(self, tgt, memory, tgt_mask=None, memory_mask=None) -> Tensor

Full forward pass. Input tgt is a LongTensor of token ids; memory is the encoder output.

EXECUTION STATE
⬇ input: tgt = LongTensor shape (B, T_tgt). Each entry is a vocabulary id in [0, vocab_size).
⬇ input: memory = FloatTensor shape (B, T_src, d_model) — encoder output.
⬇ input: tgt_mask = Optional causal mask for self-attention. Typically an upper-triangular bool tensor of shape (T_tgt, T_tgt).
⬇ input: memory_mask = Optional padding mask over source positions. Shape (B, T_src) or broadcastable.
⬆ returns = FloatTensor shape (B, T_tgt, vocab_size) — unnormalized logits.
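For reference, a causal mask of the shape described above can be sketched with nested lists. This assumes the convention in which True marks a position that must not be attended to (the strictly upper triangle); the layer from §4 may use the opposite convention, so check before reusing it. The helper name causal_mask is hypothetical.

```python
def causal_mask(t_tgt):
    """Return a (t_tgt x t_tgt) bool matrix; True = blocked (assumed convention)."""
    return [[col > row for col in range(t_tgt)] for row in range(t_tgt)]

mask = causal_mask(4)
for row in mask:
    # "x" marks a blocked (future) position, "." an allowed one.
    print(" ".join("x" if blocked else "." for blocked in row))
```

In PyTorch the same pattern is commonly built with torch.triu(torch.ones(T, T, dtype=torch.bool), diagonal=1).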
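h = self.tok_emb(tgt) * math.sqrt(self.d_model)

Look up embeddings and scale by √d_model. Vaswani 2017 §3.4 specifies this multiplication; the usual rationale is that it keeps the token embeddings from being dominated by the PE (which has fixed amplitude ~1).

EXECUTION STATE
📚 math.sqrt = Standard-library square root that returns a Python float. For a scalar constant it avoids building a tensor, unlike torch.sqrt.
Why √d_model? = PE entries are bounded by 1 in magnitude. nn.Embedding is default-initialized ~N(0, 1), so each feature already has variance ~1, and multiplying by √d_model ≈ 22.6 (for d_model=512) makes the token signal clearly dominate the PE. The factor matters most when the table is initialized with a small standard deviation such as d_model^(-1/2), a common choice with tied weights: unscaled, those embeddings would be drowned out by the PE, and the scaling brings them back to roughly unit magnitude.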
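A quick plain-Python check of the magnitudes involved. For even d_model the root-mean-square of a sinusoidal PE row is exactly √0.5 ≈ 0.71, because each sin/cos pair contributes sin² + cos² = 1. The two initialization standard deviations compared below (1.0 and d_model^(-1/2)) and the helper names are assumptions for illustration.

```python
import math
import random

d_model = 512

def pe_row(pos, d_model):
    """One row of the sinusoidal positional encoding (sin/cos interleaved)."""
    row = []
    for i in range(0, d_model, 2):
        angle = pos / (10000.0 ** (i / d_model))
        row.extend([math.sin(angle), math.cos(angle)])
    return row

def rms(vec):
    """Root-mean-square magnitude of a vector."""
    return math.sqrt(sum(v * v for v in vec) / len(vec))

random.seed(0)
pe_rms = rms(pe_row(7, d_model))
scale = math.sqrt(d_model)

for init_std in (1.0, d_model ** -0.5):
    emb = [random.gauss(0.0, init_std) for _ in range(d_model)]
    print(f"init_std={init_std:.4f}  raw={rms(emb):.3f}  "
          f"scaled={rms(emb) * scale:.3f}  pe={pe_rms:.3f}")
```

With the small initialization the raw embedding sits well below the PE and the scaled one lands near 1; with the N(0, 1) default the scaled embedding is far larger than the PE.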
h = self.drop(self.pos_enc(h))

Add PE, then apply dropout. This is the 'input dropout' from Vaswani 2017; in modern LLMs it's often 0.0 or removed entirely.

for layer in self.layers:

Iterate through the N decoder layers in order. ModuleList supports direct iteration.

h = layer(h, memory, tgt_mask=tgt_mask, memory_mask=memory_mask)

Every layer takes the same memory and masks. Only h changes as we go up the stack. This is the code that realizes h_l = DecoderLayer_l(h_{l-1}, memory, ...).

h = self.norm(h)

Final pre-norm LayerNorm. Without this, the output of the last layer is a raw residual sum with unbounded scale.
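To see what the final normalization buys, here is a minimal plain-Python LayerNorm applied to the same direction at two very different scales. It omits the learnable γ and β, which is an intentional simplification, and uses the biased variance with eps = 1e-5 as nn.LayerNorm does by default.

```python
import math

def layer_norm(x, eps=1e-5):
    """Normalize one feature vector to zero mean and (nearly) unit variance."""
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    return [(v - mean) / math.sqrt(var + eps) for v in x]

small = [0.1, -0.2, 0.3, 0.0]
large = [v * 1000.0 for v in small]   # same direction, 1000x the scale

norm_small = layer_norm(small)
norm_large = layer_norm(large)
print([round(v, 3) for v in norm_small])
print([round(v, 3) for v in norm_large])
```

Both inputs come out almost identical, so the output projection sees a bounded scale no matter how large the residual sum has grown.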

logits = self.out_proj(h)

Project to vocabulary. With weight tying, this is mathematically h @ tok_emb.weight.T. Shape: (B, T_tgt, vocab_size).

return logits

Unnormalized logits — the caller applies softmax (for probabilities), argmax (for greedy decoding), or a cross-entropy loss against target ids.

EXECUTION STATE
⬆ return = (B, T_tgt, vocab_size).
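The three things a caller might do with one position's logits can be sketched in plain Python. The logits and the target id below are hypothetical values for a vocabulary of size 3.

```python
import math

def softmax(logits):
    """Numerically stable softmax over one position's logits."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def cross_entropy(logits, target_id):
    """Negative log-probability of the target token, computed from raw logits."""
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[target_id]

logits = [2.0, 0.5, -1.0]                                    # hypothetical logits
probs = softmax(logits)                                      # probabilities
greedy = max(range(len(logits)), key=lambda i: logits[i])    # greedy decoding
loss = cross_entropy(logits, target_id=0)                    # training loss
print(probs, greedy, loss)
```

The loss is computed directly from the logits, which is also how nn.CrossEntropyLoss expects its input: no softmax should be applied before it.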

```python
import math
import torch
import torch.nn as nn
from typing import Optional

# Assume TransformerDecoderLayer is the module built in §4.
from .decoder_layer import TransformerDecoderLayer


class SinusoidalPositionalEncoding(nn.Module):
    """Copy of the ch04 PE module. Included for completeness."""

    def __init__(self, d_model: int, max_len: int = 5000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        pos = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(pos * div)
        pe[:, 1::2] = torch.cos(pos * div)
        self.register_buffer("pe", pe.unsqueeze(0))  # (1, max_len, d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.pe[:, : x.size(1)]


class TransformerDecoder(nn.Module):
    """Full Transformer decoder: embedding + PE + N layers + output projection."""

    def __init__(
        self,
        num_layers: int,
        d_model: int,
        num_heads: int,
        d_ff: int,
        vocab_size: int,
        max_len: int = 5000,
        dropout: float = 0.1,
        tie_weights: bool = True,
    ):
        super().__init__()
        self.d_model = d_model
        self.tok_emb = nn.Embedding(vocab_size, d_model)
        self.pos_enc = SinusoidalPositionalEncoding(d_model, max_len)
        self.drop = nn.Dropout(dropout)
        self.layers = nn.ModuleList([
            TransformerDecoderLayer(d_model, num_heads, d_ff, dropout)
            for _ in range(num_layers)
        ])
        self.norm = nn.LayerNorm(d_model)
        self.out_proj = nn.Linear(d_model, vocab_size, bias=False)
        if tie_weights:
            self.out_proj.weight = self.tok_emb.weight  # W_out = Embed

    def forward(
        self,
        tgt: torch.Tensor,
        memory: torch.Tensor,
        tgt_mask: Optional[torch.Tensor] = None,
        memory_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        h = self.tok_emb(tgt) * math.sqrt(self.d_model)
        h = self.drop(self.pos_enc(h))
        for layer in self.layers:
            h = layer(h, memory, tgt_mask=tgt_mask, memory_mask=memory_mask)
        h = self.norm(h)
        logits = self.out_proj(h)
        return logits
```

Note on KV caching: at inference time we want to feed one token at a time and reuse the keys and values that every layer already computed for earlier positions. This is called KV caching, and it becomes critical once your stack is 96 layers deep. We cover it in §9/05.

Parameter Counting

Let's tally parameters for a small demo configuration: $d_{model} = 128$, $H = 4$, $d_{ff} = 512$, $N = 2$, $V = 1000$. Per-layer numbers assume bias on every linear layer (Vaswani 2017) and three LayerNorms per layer.

| Component | Formula | Parameters |
|---|---|---|
| Token embedding | V · d_model = 1000 · 128 | 128,000 |
| Positional encoding (sinusoidal) | 0 (no learnable params) | 0 |
| Self-attention (W_q, W_k, W_v, W_o) | 4 · d_model² + 4 · d_model = 4·128² + 4·128 | 66,048 |
| Cross-attention (W_q, W_k, W_v, W_o) | 4 · d_model² + 4 · d_model | 66,048 |
| FFN (W1, W2 with biases) | d_model·d_ff + d_ff + d_ff·d_model + d_model = 128·512·2 + 512 + 128 | 131,712 |
| 3 × LayerNorm (γ and β each d_model) | 3 · 2 · d_model = 6·128 | 768 |
| One TransformerDecoderLayer (sum of above) | SA + CA + FFN + LN | 264,576 |
| N=2 stacked layers | 2 · 264,576 | 529,152 |
| Output projection, untied (d_model·V + V) | 128·1000 + 1000 | 129,000 |
| Output projection, weight-tied to embedding | 0 new parameters | 0 |
| Total (untied) | embed + stack + out_proj_untied | 786,152 |
| Total (weight-tied) | embed + stack + 0 | 657,152 |

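Weight tying saves $V \cdot d_{model} + V$ parameters relative to an untied projection with bias: here 129,000 out of 786,152, roughly 16% of the model. Two caveats when comparing against the TransformerDecoder code above: it sets bias=False, so the saving there is $V \cdot d_{model}$ = 128,000, and the tally leaves out the final LayerNorm's $2 \cdot d_{model}$ = 256 parameters. At production scales (V=50k, d_model=4096), weight tying saves ~200M parameters.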
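The tally can be reproduced with a few lines of arithmetic. The function below is a sketch that follows the table's formulas; its name and the out_bias and final_norm switches are hypothetical additions, and the per-layer count assumes the §4 layer has biases on every linear layer and three LayerNorms, as the table does.

```python
def decoder_param_count(d_model, d_ff, num_layers, vocab_size,
                        tied, out_bias=True, final_norm=False):
    """Parameter tally following the table's formulas.

    out_bias and final_norm are illustrative switches: the table counts an
    output bias and skips the final LayerNorm, while the TransformerDecoder
    code uses bias=False and does include a final LayerNorm.
    """
    attn = 4 * d_model * d_model + 4 * d_model    # W_q, W_k, W_v, W_o plus biases
    ffn = 2 * d_model * d_ff + d_ff + d_model     # W1, W2 plus biases
    layer_norms = 3 * 2 * d_model                 # gamma and beta, three times
    per_layer = 2 * attn + ffn + layer_norms      # self-attn + cross-attn + FFN + LN
    embed = vocab_size * d_model
    out_proj = 0 if tied else vocab_size * d_model + (vocab_size if out_bias else 0)
    final_ln = 2 * d_model if final_norm else 0
    return embed + num_layers * per_layer + out_proj + final_ln

cfg = dict(d_model=128, d_ff=512, num_layers=2, vocab_size=1000)
untied = decoder_param_count(tied=False, **cfg)
tied = decoder_param_count(tied=True, **cfg)
as_coded = decoder_param_count(tied=True, final_norm=True, **cfg)
print(untied, tied, as_coded)
```

The first two values match the table's totals; the third adds the 256 parameters of the final LayerNorm that the code instantiates.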


Interactive Visualization

The visualizer below shows an N-deep decoder stack in 3D. Each layer is a column of three translucent planes — Self-Attn, Cross-Attn, FFN — with the encoder memory block sitting to the side. Press Step to push activations one sublayer at a time from the bottom of the stack upward, or Play for automatic stepping. Adjust N (1–6) and the grid resolution (a proxy for d_model) to see how the stack grows. When a cross-attention sublayer is active, a glowing line from the encoder memory to that layer indicates where K and V are sourced.

Accessibility notes: all controls are tab-reachable with Enter/Space activation; when prefers-reduced-motion is set, auto-advance is disabled and you advance with Step manually; on viewports under 768px a static SVG diagram is shown instead of the 3D canvas.


In Modern Systems

Depth across real models

| Model | N (decoder layers) | d_model | Notes |
|---|---|---|---|
| Vaswani 2017 (base) | 6 | 512 | Original encoder-decoder |
| GPT-2 small | 12 | 768 | Decoder-only |
| GPT-2 XL | 48 | 1600 | Decoder-only |
| GPT-3 175B | 96 | 12288 | Decoder-only; Kaplan-style scaling |
| Llama-7B | 32 | 4096 | RMSNorm + rotary PE |
| Llama-65B / Llama-2-70B | 80 | 8192 | Grouped-query attention (Llama-2-70B only) |
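As a rough sanity check on the decoder-only GPT rows, Kaplan et al. 2020 approximate the non-embedding parameter count as 12 · n_layer · d_model², which assumes d_ff = 4 · d_model and no cross-attention. The helper name below is hypothetical, and the estimate ignores embeddings, biases, and LayerNorm parameters.

```python
def approx_non_embedding_params(n_layers, d_model):
    """Approximation 12 * n_layers * d_model^2 (assumes d_ff = 4 * d_model)."""
    return 12 * n_layers * d_model ** 2

for name, n_layers, d_model in [
    ("GPT-2 small", 12, 768),
    ("GPT-2 XL", 48, 1600),
    ("GPT-3 175B", 96, 12288),
]:
    billions = approx_non_embedding_params(n_layers, d_model) / 1e9
    print(f"{name}: ~{billions:.2f}B non-embedding parameters")
```

The GPT-3 row comes out near 174B, consistent with the model's name; the approximation does not apply to the encoder-decoder row, whose decoder layers also carry cross-attention.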

Scaling laws

Kaplan et al. 2020 showed that cross-entropy loss scales roughly as a power law in model size, data, and compute, with a model-size term near $N^{-0.076}$ in the regime they studied. Here $N$ is the non-embedding parameter count, not the number of layers used elsewhere in this section. Hoffmann et al. 2022 (Chinchilla) revised this picture: for a fixed compute budget, parameters and training tokens should scale in roughly equal proportion. Doubling the parameter count without doubling the data is wasteful. This reframed many post-2022 training runs (e.g., Llama-2 trained on 2T tokens at 7–70B parameters).
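To get a feel for how shallow an exponent of 0.076 is, the sketch below evaluates the ratio of losses under a pure power law in parameter count. It assumes data and compute are not the bottleneck, which is exactly the assumption Chinchilla calls into question; the function name is hypothetical.

```python
ALPHA_N = 0.076  # model-size exponent quoted in the text (Kaplan et al. 2020)

def loss_ratio(param_multiplier, alpha=ALPHA_N):
    """Ratio L(k * P) / L(P) when loss is proportional to P^(-alpha)."""
    return param_multiplier ** (-alpha)

for k in (2, 10, 100):
    print(f"{k:>3}x parameters -> loss multiplied by {loss_ratio(k):.3f}")
```

Under this idealized law, doubling the parameter count lowers the loss by only about 5%, which is why large gains require order-of-magnitude increases in scale.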

Why decoders got very deep, very fast

In encoder-decoder translation models, depth tapered around N = 6–12 because most quality gains came from wider models and more data. Decoder-only LMs flipped that: Transformer weights are the only place a language model stores knowledge, so deep stacks directly increase model capacity. The jump from GPT-2 XL to GPT-3 (48 to 96 layers, d_model from 1600 to 12288) bought dramatic gains in in-context learning. This is why modern frontier models are much deeper than the original Vaswani decoder.


Summary

  • A decoder is embedding + PE + N decoder layers + final norm + output projection.
  • Stack forward pass: $h_0 = \mathrm{Dropout}(\mathrm{Embed}(\mathrm{tgt}) \cdot \sqrt{d_{model}} + \mathrm{PE})$; $h_\ell = \mathrm{DecoderLayer}_\ell(h_{\ell-1}, \mathrm{memory}, \ldots)$; $\mathrm{logits} = h_N W_{out}^\top$.
  • The $\sqrt{d_{model}}$ factor balances embedding magnitude against the positional encoding amplitude.
  • Weight tying (Press & Wolf 2017) sets $W_{out} = \mathrm{Embed}$; use it when input and output vocabularies match.
  • Across the stack, hidden states are genuinely refined — not just rescaled. Shallow layers handle local structure, deep layers handle semantics and discourse.
  • Real decoders range from N = 6 (Vaswani) through N = 80 (Llama-65B) to N = 96 (GPT-3).
  • KV caching (§9/05) will become essential once you run inference on deep stacks.

Exercises

  1. (Easy) Modify the plain-Python walkthrough to run N = 4 layers instead of 2. Print the first row of each $h_\ell$ and plot how a single feature (e.g., $h_\ell[0, 0]$) evolves across $\ell = 0, 1, 2, 3, 4$. Is the change monotone, or does it oscillate?
  2. (Medium) Instantiate TransformerDecoder with tie_weights=False and benchmark parameter counts against the tied version for $V \in \{1000, 10000, 50000\}$ at $d_{model} = 512$. Confirm that tying saves exactly $V \cdot d_{model}$ parameters (plus $V$ if the output projection had a bias, which we disabled).
  3. (Hard) Replace the final nn.LayerNorm with nn.Identity and retrain on a small language-modeling task. Track the norm of $h_N$ across training and describe what goes wrong. Why does pre-norm require a final LayerNorm while post-norm does not?

Next Section Preview

We now have a full encoder (ch07) and a full decoder (this section). Section 6 glues them into one Transformer module, walks through training-time teacher forcing vs inference-time autoregressive decoding, discusses shared-vs-separate vocabularies, and finishes with a parameter-count sanity check for the small Multi30k translation model we'll train in ch13.
