Chapter 8

Implementing TransformerDecoderLayer

Transformer Decoder

Introduction

Sections 1–3 introduced the three sublayers a decoder needs (masked self-attention, encoder-decoder cross-attention, and the position-wise feed-forward network), but we have not yet assembled them into a single reusable module with a clean forward() signature. That is what this section does. We build TransformerDecoderLayer: one block that takes a target tensor and an encoder memory tensor, threads them through the three sublayers with residual Add & Norm wrappers, and returns a tensor of exactly the same shape as its target input.

Same-shape-in / same-shape-out is the crucial property. It is what lets us stack N copies of this layer in Section 5. If the output shape ever differed from the input shape, stacking would require a separate adapter between every pair of layers, and the Transformer recipe would fall apart.
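A minimal NumPy sketch of why that property matters. The `toy_layer` below is a hypothetical stand-in (one residual update followed by a LayerNorm without γ/β), not the decoder layer itself. The point is only that a shape-preserving function can be applied N times in a plain loop, with no adapters between calls.

```python
import numpy as np

def toy_layer(x, W):
    """Hypothetical stand-in for a decoder layer: residual update + per-token normalization.
    Maps [B, T, d_model] -> [B, T, d_model]."""
    h = x + np.tanh(x @ W)                      # residual add of a sublayer-like update
    mu = h.mean(axis=-1, keepdims=True)
    var = h.var(axis=-1, keepdims=True)
    return (h - mu) / np.sqrt(var + 1e-5)       # LayerNorm without gamma/beta

rng = np.random.default_rng(0)
B, T_tgt, d_model, N = 2, 3, 8, 6
x = rng.standard_normal((B, T_tgt, d_model))
weights = [rng.standard_normal((d_model, d_model)) / np.sqrt(d_model) for _ in range(N)]

h = x
for W in weights:                               # stacking N layers is just a loop
    h = toy_layer(h, W)

print(h.shape)  # (2, 3, 8)
```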

Why this matters: the large language models you have heard of (GPT, Llama, Mistral, T5, PaLM) are, at their core, stacks of identical Transformer layers closely related to this one. T5 keeps all three sublayers; decoder-only models such as GPT, Llama, Mistral, and PaLM drop cross-attention and keep the other two. Understanding this single layer therefore covers most of modern LLM architecture; most of what remains is the token embedding, positional encoding, and a final projection head.

Recap: The Three Sublayers

A decoder layer has three sublayers, applied in order. Each one is wrapped in a residual add and a LayerNorm (jointly called Add & Norm).

  1. Masked self-attention: the decoder looks back at its own earlier target tokens. The causal mask $M_{ij} = -\infty$ for $j > i$ blocks attention to the future (section 2).
  2. Cross-attention (encoder-decoder attention): queries come from the decoder stream; keys and values come from the encoder output $\text{memory}$ (section 3). This is the only place where source-side information enters the decoder.
  3. Position-wise feed-forward network $\text{FFN}(x) = \max(0, xW_1 + b_1)W_2 + b_2$, applied independently to every target position. It provides per-token nonlinear capacity.
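To make "position-wise" concrete, here is a small NumPy sketch with random weights and illustrative sizes: applying the FFN to the whole $[B, T_{tgt}, d_{model}]$ tensor gives the same result as applying it to each target position separately, because $W_1$ and $W_2$ act only on the feature axis.

```python
import numpy as np

def ffn(x, W1, b1, W2, b2):
    # FFN(x) = max(0, x W1 + b1) W2 + b2, acting on the last (feature) axis only
    return np.maximum(0, x @ W1 + b1) @ W2 + b2

rng = np.random.default_rng(0)
B, T_tgt, d_model, d_ff = 1, 3, 8, 16
x = rng.standard_normal((B, T_tgt, d_model))
W1, b1 = rng.standard_normal((d_model, d_ff)) * 0.3, np.zeros(d_ff)
W2, b2 = rng.standard_normal((d_ff, d_model)) * 0.3, np.zeros(d_model)

full = ffn(x, W1, b1, W2, b2)                        # all positions at once
per_pos = np.stack([ffn(x[:, t], W1, b1, W2, b2)     # one position at a time
                    for t in range(T_tgt)], axis=1)

print(np.allclose(full, per_pos))  # True
```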

Notation reminder (shared across all of chapter 8): $B$ is the batch size, $T_{src}$ and $T_{tgt}$ are source / target lengths, $d_{model}$ is the residual-stream width, $H$ is the number of heads, $d_k = d_{model}/H$, and $d_{ff}$ is the FFN inner width.


DecoderLayer Forward Pass Math

In the original post-norm formulation (Vaswani et al., 2017), one decoder layer executes three residual-plus-norm updates:

  • $x_1 = \mathrm{LayerNorm}\!\left(x + \mathrm{Dropout}\bigl(\mathrm{MaskedSelfAttn}(x,\, \text{tgt\_mask})\bigr)\right)$
  • $x_2 = \mathrm{LayerNorm}\!\left(x_1 + \mathrm{Dropout}\bigl(\mathrm{CrossAttn}(x_1,\, \text{memory},\, \text{memory\_mask})\bigr)\right)$
  • $x_3 = \mathrm{LayerNorm}\!\left(x_2 + \mathrm{Dropout}\bigl(\mathrm{FFN}(x_2)\bigr)\right)$

Read each line as: "take the current residual stream, compute a sublayer update, apply dropout to it, add it back, re-normalize". The vectors $x, x_1, x_2, x_3$ form a residual stream that flows through the layer, with each sublayer adding its contribution. This is the same idea as ResNet (He et al., 2015): sublayers do not replace the representation, they add a correction to it. The skip connection gives gradients a direct path around every sublayer, which is what helps deep stacks train (post-norm still places a LayerNorm on that path, a point the next subsection returns to).

Every sublayer function returns a tensor of the same shape as its (query) input, $[B, T_{tgt}, d_{model}]$, so the residual add is always well-defined. LayerNorm re-centers and re-scales each token vector independently across its $d_{model}$ features: $\mathrm{LayerNorm}(z)_i = \gamma_i \cdot \frac{z_i - \mu_z}{\sqrt{\sigma^2_z + \varepsilon}} + \beta_i$, where $\mu_z$ and $\sigma^2_z$ are the mean and variance of the $d_{model}$ entries of $z$, and $\gamma, \beta$ are learned per-feature scale and shift vectors.
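The Add & Norm step and the LayerNorm formula above can be sketched in NumPy as follows. This is an illustrative version with dropout omitted (as in eval mode) and `gamma`/`beta` passed in explicitly; the names `layer_norm` and `add_and_norm` are our own, and the random `update` merely stands in for a sublayer output.

```python
import numpy as np

def layer_norm(z, gamma, beta, eps=1e-5):
    # Normalize each token vector over its d_model features, then apply per-feature gamma/beta
    mu = z.mean(axis=-1, keepdims=True)
    var = z.var(axis=-1, keepdims=True)          # biased variance (divide by d_model)
    return gamma * (z - mu) / np.sqrt(var + eps) + beta

def add_and_norm(x, sublayer_out, gamma, beta):
    # Post-norm wrapper: LayerNorm(x + Sublayer(x)); dropout omitted (eval mode)
    return layer_norm(x + sublayer_out, gamma, beta)

rng = np.random.default_rng(0)
x = rng.standard_normal((1, 3, 8))
update = 0.1 * rng.standard_normal((1, 3, 8))    # stands in for a sublayer output
gamma, beta = np.ones(8), np.zeros(8)

y = add_and_norm(x, update, gamma, beta)
print(y.shape)               # (1, 3, 8)
print(y.mean(axis=-1))       # approximately 0 for every token
print(y.std(axis=-1))        # approximately 1 for every token
```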


Pre-Norm vs Post-Norm

Where you place LayerNorm inside the sublayer wrapper is not a detail—it materially changes training stability at depth. Two families exist:

  • Post-norm (Vaswani 2017): $y = \mathrm{LayerNorm}(x + \mathrm{Sublayer}(x))$. Norm is applied after the residual add.
  • Pre-norm (modern LLMs): $y = x + \mathrm{Sublayer}(\mathrm{LayerNorm}(x))$. Norm is applied inside the sublayer branch; the residual path is never normalized.

Why does anyone prefer pre-norm? Xiong et al. (2020), "On Layer Normalization in the Transformer Architecture", showed that at initialization a post-norm Transformer has large expected gradients for the parameters near the output layer, which is why post-norm training depends on a carefully tuned learning-rate warmup. Pre-norm avoids this by keeping the residual path un-normalized: gradients flow straight through each addition, giving a clean gradient highway from the final layer back to the input embedding, and the paper reports that pre-norm models can be trained without the warmup stage. The tradeoff: pre-norm networks are sometimes slightly weaker at small depth and can produce larger-magnitude activations deep in the stack, so many modern systems pair it with RMSNorm (Zhang & Sennrich, 2019), a cheaper normalization, and with careful initialization.
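For comparison, RMSNorm drops LayerNorm's mean subtraction and bias and simply rescales each token by its root mean square. A minimal NumPy sketch, with `g` as the learned per-feature gain (set to ones here) and an `eps` value chosen for illustration:

```python
import numpy as np

def rms_norm(z, g, eps=1e-6):
    # RMSNorm: no re-centering, just divide by the root-mean-square over features
    rms = np.sqrt(np.mean(z ** 2, axis=-1, keepdims=True) + eps)
    return g * z / rms

rng = np.random.default_rng(0)
z = 3.0 * rng.standard_normal((1, 3, 8)) + 1.0   # deliberately off-center, large-scale input
y = rms_norm(z, np.ones(8))

print(np.sqrt(np.mean(y ** 2, axis=-1)))  # close to 1 for each token
print(y.mean(axis=-1))                    # NOT forced to 0, unlike LayerNorm
```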

The practical consequence: GPT-2, GPT-3, Llama 1/2/3, Mistral, PaLM, Falcon, Qwen, and essentially every production LLM since ~2019 use pre-norm. The original Vaswani post-norm survives in some translation systems and in educational code. We implement post-norm in this section because it matches the canonical equations of the 2017 paper. Switching to pre-norm is a small change: rewrite each sublayer wrapper from $\mathrm{LN}(x + \text{sub}(x))$ to $x + \text{sub}(\mathrm{LN}(x))$, one line per sublayer, and (as pre-norm models typically do) add a final LayerNorm after the last layer of the stack.
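The two wrappers differ only in where the normalization is called. A minimal NumPy sketch, assuming a hypothetical one-matrix ReLU `sublayer` in place of attention or the FFN, and a LayerNorm without γ/β:

```python
import numpy as np

def layer_norm(z, eps=1e-5):
    mu = z.mean(axis=-1, keepdims=True)
    return (z - mu) / np.sqrt(z.var(axis=-1, keepdims=True) + eps)

def post_norm_block(x, sublayer):
    # Vaswani 2017: normalize AFTER the residual add
    return layer_norm(x + sublayer(x))

def pre_norm_block(x, sublayer):
    # Modern LLMs: normalize only the sublayer input; the residual path is untouched
    return x + sublayer(layer_norm(x))

rng = np.random.default_rng(0)
W = rng.standard_normal((8, 8)) * 0.1

def sublayer(h):
    # Hypothetical stand-in for attention or the FFN: one ReLU projection
    return np.maximum(0, h @ W)

x = 5.0 * rng.standard_normal((1, 3, 8))
post = post_norm_block(x, sublayer)
pre = pre_norm_block(x, sublayer)

print(post.shape, pre.shape)                          # (1, 3, 8) (1, 3, 8)
print(np.allclose(pre - x, sublayer(layer_norm(x))))  # True: x passes through unnormalized
```

Note how the post-norm output is always re-centered per token, while the pre-norm output keeps the raw scale of `x`; that un-normalized residual stream is why pre-norm stacks usually end with one extra LayerNorm.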


Dropout Placement

There are three places dropout enters a Transformer decoder layer. Each one serves a different regularization purpose:

| Location | Applied to | Rationale | Typical rate |
| --- | --- | --- | --- |
| Attention dropout | Softmax weights (inside MultiHeadAttention) | Prevents over-reliance on any single key-query pair; smooths the attention distribution | 0.1 (0.0 in large LLMs) |
| Sublayer output dropout | Output of each sublayer, before the residual add | Stochastic residuals: the layer output becomes a noisy estimate of the clean residual path, which regularizes the whole stack | 0.1 |
| FFN internal dropout | Between the two linear layers of the FFN | Regularizes the wide d_ff hidden representation, which has more capacity than d_model | 0.1 |

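The original paper specifies $P_{drop} = 0.1$ for the sublayer outputs (and for the sums of embeddings and positional encodings); widely used reference implementations also apply 0.1 in the other two positions. Modern large models trained on billions of tokens often set dropout to $p = 0.0$: at that data scale, regularization is less needed and the noise interferes with fitting. During inference (after calling .eval()), dropout is always a no-op.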
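All three dropout sites use inverted dropout, the scheme nn.Dropout implements: during training, each entry is zeroed with probability p and the survivors are scaled by 1/(1-p), so expectations are preserved; in eval mode the input passes through unchanged. A small NumPy sketch of that behavior (the `dropout` function here is our own, not a library call):

```python
import numpy as np

def dropout(x, p, training, rng):
    # Inverted dropout: zero each entry with probability p, scale survivors by 1/(1-p)
    if not training or p == 0.0:
        return x                                   # eval mode (or p = 0): identity
    keep = rng.random(x.shape) >= p
    return np.where(keep, x / (1.0 - p), 0.0)

rng = np.random.default_rng(0)
x = np.ones((1000, 8))
y_train = dropout(x, 0.1, training=True, rng=rng)
y_eval = dropout(x, 0.1, training=False, rng=rng)

print(np.unique(np.round(y_train, 4)))   # only two values: 0 and 1/(1-p) ~ 1.1111
print(y_train.mean())                    # close to 1.0: expectation preserved
print(np.array_equal(y_eval, x))         # True
```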


Plain-Python Walkthrough

Before we wrap everything in nn.Module, let us run one decoder layer in pure NumPy. Every value is visible, every shape is explicit, and no autograd machinery hides the logic. This is the same configuration, $B=1,\ T_{src}=4,\ T_{tgt}=3,\ d_{model}=8,\ H=2,\ d_{ff}=16$, that §5 and §6 will reuse.

One DecoderLayer forward — pure NumPy, seeded
🐍decoder_layer_numpy.py
import numpy as np

NumPy gives us fast N-dimensional arrays and matrix multiplication (@). We deliberately use plain NumPy (not PyTorch) so every value is visible and there is no hidden autograd machinery — this is the from-scratch forward pass.

EXECUTION STATE
numpy = Numerical computing library. Provides np.ndarray, @ for matmul, broadcasting, and elementwise math used below.
as np = Standard alias so we write np.random, np.sqrt, np.exp instead of numpy.random, etc.
np.random.seed(0)

Fixes the pseudo-random stream so every weight matrix and input tensor is reproducible. Same seed across sections 4, 5, 6 keeps numerics comparable.

EXECUTION STATE
📚 np.random.seed(s) = Initializes NumPy's global RNG. Every subsequent randn(...) call draws from the same deterministic sequence. Example: after seed(0), np.random.randn(2) always yields [1.7641, 0.4002].
Shared demo config

Fixed across every ch08 plain-Python trace. Small enough to print every tensor; large enough that each concept (batching, heads, FFN expansion) exists.

EXECUTION STATE
B = 1 = Batch size. One sequence at a time.
T_src = 4 = Source length (number of encoder tokens in memory).
T_tgt = 3 = Target length (number of decoder tokens we are currently processing).
d_model = 8 = Residual-stream width (embedding dim).
H = 2 = Number of attention heads.
d_ff = 16 = FFN inner width. Typical ratio is 4 x d_model (here 4 x 8 = 32 would be canonical, but 16 keeps the printout readable).
d_k = d_model // H

Per-head key / value / query width. The d_model residual stream is split evenly across H heads, so the total projection and attention cost stays about the same as single-head attention at full width.

EXECUTION STATE
d_k = 4 = 8 // 2 = 4. Each head projects into a 4-dim subspace. d_v = d_k here (standard setup).
x = np.random.randn(B, T_tgt, d_model) * 0.3

Simulated decoder input after token embedding + positional encoding + shift-right. The 0.3 factor only keeps the raw values small (standard deviation 0.3); LayerNorm later rescales every token to unit variance regardless of this factor.

EXECUTION STATE
📚 np.random.randn(*shape) = Draws from standard normal N(0, 1). Shape args here are (1, 3, 8) = 24 scalars.
x[0][0] = [ 0.5292 0.1200 0.2936 0.6723 0.5603 -0.2932 0.2850 -0.0454]
x[0][1] = [-0.0310 0.1232 0.0432 0.4363 0.2283 0.0365 0.1332 0.1001]
x[0][2] = [ 0.4482 -0.0615 0.0939 -0.2562 -0.7659 0.1961 0.2593 -0.2226]
shape(x) = (1, 3, 8) -> [B, T_tgt, d_model]
memory = np.random.randn(B, T_src, d_model) * 0.3

Encoder output. Same d_model as x but a different sequence length (T_src = 4). Cross-attention uses this as keys and values.

EXECUTION STATE
memory[0][0] = [ 0.6809 -0.4363 0.0137 -0.0562 0.4598 0.4408 0.0465 0.1134]
shape(memory) = (1, 4, 8) -> [B, T_src, d_model]
tgt_mask = np.triu(np.ones((T_tgt, T_tgt), dtype=bool), k=1)

Upper-triangular boolean matrix. True entries mark (query_i, key_j) pairs where j > i — tokens the self-attention must NOT see. We combine this with scores via np.where(mask, -1e9, scores) so softmax kills the future.

EXECUTION STATE
📚 np.triu(A, k) = Keeps elements on or above the k-th diagonal and zeros everything below. k=1 means strictly above the main diagonal.
tgt_mask =
[[False  True  True]
 [False False  True]
 [False False False]]
purpose = Row i is the mask for query i: it forbids looking at any key j with j > i. Row 0 can only see key 0, row 1 keys 0-1, row 2 keys 0-2.
def init_W(inp, out)

Xavier-ish init: weights ~ N(0, 1/inp). Keeps Q/K/V projections from blowing up variance at d_model = 8.

EXECUTION STATE
⬇ input: inp = Number of input features into the linear layer. For Q/K/V projections this equals d_model = 8.
⬇ input: out = Number of output features. For Q/K/V this is d_model (so we can split across heads); for the FFN it is d_ff.
⬆ return = np.ndarray of shape (inp, out) with each entry ~ N(0, 1/inp). Used with x @ W so that Var(output) ≈ Var(input).
return np.random.randn(inp, out) * (1.0 / np.sqrt(inp))

Sample a standard-normal matrix and scale by 1/sqrt(inp), giving each entry variance 1/inp. This is LeCun-style normal initialization, a close relative of Xavier/Glorot init.

EXECUTION STATE
1.0 / np.sqrt(inp) = For inp=8: 1/sqrt(8) ≈ 0.3536. Each entry's std is 0.3536 instead of 1.0.
Wq1, Wk1, Wv1, Wo1: self-attention projections

Four learned matrices for the masked self-attention sublayer. Each is (d_model, d_model) = (8, 8). In a real PyTorch module these are nn.Linear(d_model, d_model) layers.

EXECUTION STATE
Wq1 shape = (8, 8) -> projects x into queries, then heads split it as (H, d_k) = (2, 4)
Wo1 shape = (8, 8) -> merges heads back to d_model
Wq2, Wk2, Wv2, Wo2: cross-attention projections

Separate set for cross-attention. Q comes from the decoder stream x1; K and V come from memory. Sharing weights with self-attn would couple two very different subspaces, so the two sets are disjoint.

EXECUTION STATE
why separate = Self-attn queries look up other target tokens; cross-attn queries look up source tokens. Different tasks deserve different projections.
W1, W2: FFN projections

Two-layer MLP applied position-wise. W1: (8, 16) expands, W2: (16, 8) contracts.

EXECUTION STATE
W1 shape = (d_model, d_ff) = (8, 16)
W2 shape = (d_ff, d_model) = (16, 8)
def multihead(Q_in, K_in, V_in, Wq, Wk, Wv, Wo, mask_bool=None)

One multi-head attention call. Used both for masked self-attention (Q=K=V=x) and for cross-attention (Q=x1, K=V=memory). The mask_bool argument turns it into masked or unmasked behavior.

EXECUTION STATE
⬇ input: Q_in = Source of queries. Shape [B, Tq, d_model]. For self-attn Tq = T_tgt. For cross-attn Tq = T_tgt but Q_in is x1, not memory.
⬇ input: K_in, V_in = Source of keys/values. Shape [B, Tk, d_model]. For self-attn Tk = T_tgt. For cross-attn Tk = T_src = 4.
⬇ input: Wq/Wk/Wv/Wo = Projection matrices, each (d_model, d_model). Kept separate so self-attn and cross-attn can learn different mappings.
⬇ input: mask_bool = Optional boolean mask. True entries get score = -1e9 before softmax, which makes softmax ≈ 0. Shape broadcasts to [B, H, Tq, Tk].
⬆ return = np.ndarray of shape [B, Tq, d_model] — one attended context vector per query position, merged across heads.
Bsz, Tq, _ = Q_in.shape

Unpack query shape. Underscore discards the trailing d_model because it's redundant with the module constant.

EXECUTION STATE
Bsz = 1
Tq = 3 for self-attn, 3 for cross-attn (queries come from x1)
_, Tk, _ = K_in.shape

Unpack key sequence length. Tk differs from Tq in cross-attention.

EXECUTION STATE
Tk = 3 for self-attn (K = x), 4 for cross-attn (K = memory)
Q = (Q_in @ Wq).reshape(Bsz, Tq, H, d_k).transpose(0, 2, 1, 3)

Project to queries, split feature dim across heads, move the head axis forward. Result shape: [B, H, Tq, d_k].

EXECUTION STATE
📚 @ (matmul) = NumPy matrix multiply. Last two axes are matmul'd; leading batch dims broadcast. (1, 3, 8) @ (8, 8) -> (1, 3, 8).
📚 .reshape(Bsz, Tq, H, d_k) = Same 24 numbers, interpreted as (1, 3, 2, 4). The last axis d_model = 8 is split into H = 2 chunks of d_k = 4.
📚 .transpose(0, 2, 1, 3) = Permute axes from (B, Tq, H, d_k) to (B, H, Tq, d_k). Heads become the batch-like outer dim so each head's attention can be computed independently.
Q final shape = (1, 2, 3, 4) -> [B, H, Tq, d_k]
K = (K_in @ Wk).reshape(...).transpose(...)

Same pattern as Q, but Tk may differ from Tq. In self-attn Tk = 3; in cross-attn Tk = 4.

EXECUTION STATE
K final shape = self-attn: (1, 2, 3, 4); cross-attn: (1, 2, 4, 4)
V = (V_in @ Wv).reshape(...).transpose(...)

Values. Shares sequence length with K (both come from the same source). Different projection Wv means V is a different subspace than K.

EXECUTION STATE
V final shape = same as K
scores = Q @ K.transpose(0, 1, 3, 2) / np.sqrt(d_k)

Raw attention scores, per head. Swap the last two dims of K so matmul compatibility holds: Q(1,2,3,4) @ K^T(1,2,4,Tk) -> (1,2,3,Tk). Divide by sqrt(d_k)=2 to stabilize the softmax.

EXECUTION STATE
📚 K.transpose(0, 1, 3, 2) = Swap axes 3 and 2 of K. Shape (1, 2, Tk, 4) becomes (1, 2, 4, Tk).
np.sqrt(d_k) = sqrt(4) = 2.0. Scaling by 1/sqrt(d_k) keeps dot-product variance ~1 regardless of d_k (see ch02).
scores shape = self-attn: (1, 2, 3, 3); cross-attn: (1, 2, 3, 4)
if mask_bool is not None:

Branch only executes for masked self-attention. Cross-attention here uses no mask; in real training a memory (source-padding) mask would also live on this branch.

scores = np.where(mask_bool, -1e9, scores)

Replace disallowed positions with a huge negative number so exp(score) ≈ 0 after softmax. We use -1e9 instead of -inf to avoid NaN when an entire row is masked.

EXECUTION STATE
📚 np.where(cond, a, b) = Elementwise select. Where cond is True pick a, else pick b. Broadcasts mask_bool (1, 1, 3, 3) against scores (1, 2, 3, 3).
row 0 effect = Entries (0, 1) and (0, 2) are set to -1e9. After softmax they become ~0 so the first query only attends to key 0.
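A quick NumPy check of the -1e9 versus -inf choice, assuming a hypothetical row in which every key is masked (as can happen for an all-padding row): with -inf, the max-subtraction step computes -inf - (-inf), which is NaN; with -1e9, the row degrades to uniform weights instead.

```python
import numpy as np

def softmax(s):
    s = s - s.max(axis=-1, keepdims=True)      # same stability trick as the trace
    w = np.exp(s)
    return w / w.sum(axis=-1, keepdims=True)

scores = np.array([[0.5, -1.0, 2.0]])
all_masked = np.array([[True, True, True]])    # a row where every key is blocked

with np.errstate(invalid="ignore"):            # silence the inf - inf warning
    w_inf = softmax(np.where(all_masked, -np.inf, scores))
w_big = softmax(np.where(all_masked, -1e9, scores))

print(w_inf)   # [[nan nan nan]]: -inf - (-inf) is undefined
print(w_big)   # uniform weights of 1/3, no NaN
```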
scores -= scores.max(axis=-1, keepdims=True)

Numerical-stability trick: subtract the per-row max before exp. softmax(x) = softmax(x - c) for any constant c, so the distribution is unchanged but exp() can no longer overflow.

EXECUTION STATE
axis=-1 = Reduce along the last axis (keys). Each (batch, head, query) row gets its own max.
keepdims=True = Keep axis as size 1 so scores (., ., ., Tk) - max (., ., ., 1) broadcasts cleanly.
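To see why the max subtraction matters, compare a naive softmax on large scores with the shifted version. A minimal NumPy sketch with made-up scores:

```python
import numpy as np

s = np.array([1000.0, 1001.0, 1002.0])

with np.errstate(over="ignore", invalid="ignore"):
    naive = np.exp(s) / np.exp(s).sum()        # exp(1000) overflows to inf, then inf/inf = nan

shifted = s - s.max()                          # [-2, -1, 0]
stable = np.exp(shifted) / np.exp(shifted).sum()

print(naive)    # [nan nan nan]
print(stable)   # identical to softmax([0, 1, 2])
```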
w = np.exp(scores); w /= w.sum(axis=-1, keepdims=True)

Standard softmax over the key axis. Result rows sum to 1. For self-attn row 0 is a one-hot on key 0 thanks to the mask.

EXECUTION STATE
📚 np.exp(x) = Elementwise e^x. e^(-1e9) ≈ 0, so masked entries vanish.
w row-sums = Each row over keys sums to 1.0 (within float tolerance).
w shape = (B, H, Tq, Tk) — attention weight tensor
ctx = (w @ V).transpose(0, 2, 1, 3).reshape(Bsz, Tq, d_model)

Weighted sum of values, then collapse heads back into d_model. The transpose puts (B, Tq, H, d_k) order so the final reshape merges H and d_k contiguously.

EXECUTION STATE
📚 w @ V = Shape (1, 2, Tq, Tk) @ (1, 2, Tk, d_k) -> (1, 2, Tq, d_k). Each head produces one context vector per query.
📚 .reshape(Bsz, Tq, d_model) = (1, Tq, 2, 4) -> (1, Tq, 8). Heads concatenated along the feature axis.
ctx shape = (1, 3, 8) -> [B, Tq, d_model]
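The split and merge are exact inverses only if the merge transposes back before reshaping. A small NumPy sketch with labeled features (the helper names `split_heads` and `merge_heads` are our own):

```python
import numpy as np

B, T, d_model, H = 1, 3, 8, 2
d_k = d_model // H

def split_heads(t):
    # [B, T, d_model] -> [B, T, H, d_k] -> [B, H, T, d_k]
    return t.reshape(B, T, H, d_k).transpose(0, 2, 1, 3)

def merge_heads(t):
    # [B, H, T, d_k] -> [B, T, H, d_k] -> [B, T, d_model]
    return t.transpose(0, 2, 1, 3).reshape(B, T, d_model)

x = np.arange(B * T * d_model, dtype=float).reshape(B, T, d_model)
heads = split_heads(x)

print(heads.shape)                                   # (1, 2, 3, 4)
print(heads[0, 1, 0])                                # [4. 5. 6. 7.]: head 1 owns features 4..7
print(np.array_equal(merge_heads(heads), x))         # True
print(np.array_equal(heads.reshape(B, T, d_model), x))  # False: skipping the transpose scrambles tokens
```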
return ctx @ Wo

Final linear mixing across heads. Wo is a learned (d_model, d_model) that blends the concatenated head outputs. This is what makes multi-head non-trivially different from H independent attentions.

EXECUTION STATE
⬆ return = np.ndarray (B, Tq, d_model). For our run, the first row of self-attn output: [-0.2127 -0.0391 0.2221 -0.0998 0.1540 -0.0465 -0.0045 -0.1185]
def layernorm(z, eps=1e-5)

Per-token standardization across the feature axis. gamma and beta are left as 1 and 0 here so we can focus on the residual structure — a real implementation adds those learnable affine parameters.

EXECUTION STATE
⬇ input: z = Tensor of shape (B, T, d_model). Each (b, t) vector across d_model will be re-centered and re-scaled independently.
⬇ input: eps = Small positive constant (1e-5) added inside the sqrt to avoid division by zero when var = 0.
⬆ return = z with mean 0 and variance 1 along the last axis.
mu = z.mean(axis=-1, keepdims=True)

Per-token mean across features. Shape collapses the feature dim to 1 so we can subtract with broadcasting.

EXECUTION STATE
axis=-1 = Average the 8 feature values for each (B, T) position independently.
mu shape = (B, T, 1)
var = z.var(axis=-1, keepdims=True)

Per-token variance across features. np.var defaults to the biased estimator (divide by N, ddof=0), which is also what PyTorch's nn.LayerNorm uses.

return (z - mu) / np.sqrt(var + eps)

Standardize each token vector. Without eps a constant-input row would produce 0/0 = NaN.

mask1 = tgt_mask.reshape(1, 1, T_tgt, T_tgt)

Add batch and head axes so the (3, 3) bool mask broadcasts against (B, H, 3, 3) scores.

EXECUTION STATE
broadcast target = scores shape (1, 2, 3, 3). Mask shape (1, 1, 3, 3) tiles across heads.
sa = multihead(x, x, x, Wq1, Wk1, Wv1, Wo1, mask_bool=mask1)

Masked self-attention. Q, K, V all come from the decoder input x; the causal mask forbids looking at future target tokens.

EXECUTION STATE
sa[0][0] (first target token) = [-0.2127 -0.0391 0.2221 -0.0998 0.1540 -0.0465 -0.0045 -0.1185]
sa[0][1] = [-0.2055 -0.0342 0.1184 -0.0020 -0.0202 0.0380 -0.0353 -0.1949]
sa[0][2] = [ 0.1049 0.0542 0.0709 0.0001 0.1344 -0.0098 -0.0974 0.2219]
sa shape = (1, 3, 8)
x1 = layernorm(x + sa)

Residual add + LayerNorm (POST-norm). x1 is the input to the next sublayer. Notice how LayerNorm re-centers each 8-dim vector so values are roughly in [-2, 2].

EXECUTION STATE
x1[0][0] = [ 0.2021 -0.4840 0.7825 0.9476 1.3605 -1.7090 0.0973 -1.1971]
x1[0][1] = [-1.7585 -0.0152 0.3745 1.8351 0.6235 -0.0927 0.0327 -0.9994]
x1[0][2] = [ 1.6268 -0.0876 0.4391 -0.8486 -1.9970 0.5046 0.4303 -0.0676]
mean of each row = ≈ 0.0
std of each row = ≈ 1.0
ca = multihead(x1, memory, memory, Wq2, Wk2, Wv2, Wo2, mask_bool=None)

Cross-attention. Queries from x1 (decoder stream); keys and values from encoder memory. No causal mask because decoder positions are allowed to look at all source positions.

EXECUTION STATE
ca[0][0] = [-0.2662 -0.3756 -0.2223 0.1929 -0.0257 0.2025 0.1609 0.3490]
ca[0][1] = [-0.2698 -0.3539 -0.2229 0.2040 -0.0173 0.1984 0.1624 0.3287]
ca[0][2] = [-0.3006 -0.2777 -0.1210 0.0892 0.1396 0.2126 0.2194 0.2276]
x2 = layernorm(x1 + ca)

Second residual + LayerNorm. x2 carries a blend of decoder history (via x1) and source information (via ca).

EXECUTION STATE
x2[0][0] = [-0.0696 -0.9085 0.5886 1.2006 1.4055 -1.5906 0.2703 -0.8963]
x2[0][1] = [-1.8849 -0.3458 0.1372 1.8880 0.5589 0.0946 0.1775 -0.6256]
x2[0][2] = [ 1.3947 -0.4164 0.3153 -0.8384 -2.0141 0.7426 0.6703 0.1460]
ff = np.maximum(0, x2 @ W1) @ W2

Position-wise 2-layer MLP with ReLU. Applied identically to every target position. Expands 8 -> 16, applies ReLU, contracts 16 -> 8. The biases b1, b2 from the FFN formula are omitted here for brevity.

EXECUTION STATE
📚 np.maximum(0, z) = Elementwise max(0, z_ij) = ReLU. Zeros negative entries.
ff[0][0] = [ 0.4177 -0.1750 -0.3390 -0.3945 -0.4040 -0.8389 0.5337 0.1462]
out = layernorm(x2 + ff)

Final residual + LayerNorm. This is the layer output — same shape as x, ready to feed into the next decoder layer.

EXECUTION STATE
out[0][0] = [ 0.4296 -0.8520 0.3414 0.8396 1.0145 -2.0569 0.8376 -0.5536]
out[0][1] = [-1.8868 0.3195 -0.1563 1.5612 0.8544 -0.5559 0.5753 -0.7114]
out[0][2] = [ 1.3633 0.0677 0.4245 -0.9663 -2.0952 0.4925 0.6174 0.0961]
out shape = (1, 3, 8) -> identical to x
print(out.shape)

Prints (1, 3, 8). Same-shape-in-same-shape-out is what lets us stack N copies of this layer in §5.

EXECUTION STATE
⬆ console = (1, 3, 8)
import numpy as np

np.random.seed(0)

# Shared demo config (same across ch08 sections)
B, T_src, T_tgt, d_model, H, d_ff = 1, 4, 3, 8, 2, 16
d_k = d_model // H  # 4

# Decoder input x (stand-in for shifted-right, position-encoded target embeddings) and encoder memory
x = np.random.randn(B, T_tgt, d_model) * 0.3        # [1, 3, 8]
memory = np.random.randn(B, T_src, d_model) * 0.3   # [1, 4, 8]

# Causal mask for target self-attention. True = disallowed (future).
tgt_mask = np.triu(np.ones((T_tgt, T_tgt), dtype=bool), k=1)

def init_W(inp, out):
    return np.random.randn(inp, out) * (1.0 / np.sqrt(inp))

# Two sets of attention weights (self-attn, cross-attn) and one FFN
Wq1, Wk1, Wv1, Wo1 = init_W(d_model, d_model), init_W(d_model, d_model), init_W(d_model, d_model), init_W(d_model, d_model)
Wq2, Wk2, Wv2, Wo2 = init_W(d_model, d_model), init_W(d_model, d_model), init_W(d_model, d_model), init_W(d_model, d_model)
W1, W2 = init_W(d_model, d_ff), init_W(d_ff, d_model)

def multihead(Q_in, K_in, V_in, Wq, Wk, Wv, Wo, mask_bool=None):
    Bsz, Tq, _ = Q_in.shape
    _, Tk, _ = K_in.shape
    Q = (Q_in @ Wq).reshape(Bsz, Tq, H, d_k).transpose(0, 2, 1, 3)
    K = (K_in @ Wk).reshape(Bsz, Tk, H, d_k).transpose(0, 2, 1, 3)
    V = (V_in @ Wv).reshape(Bsz, Tk, H, d_k).transpose(0, 2, 1, 3)
    scores = Q @ K.transpose(0, 1, 3, 2) / np.sqrt(d_k)
    if mask_bool is not None:
        scores = np.where(mask_bool, -1e9, scores)
    scores -= scores.max(axis=-1, keepdims=True)
    w = np.exp(scores); w /= w.sum(axis=-1, keepdims=True)
    ctx = (w @ V).transpose(0, 2, 1, 3).reshape(Bsz, Tq, d_model)
    return ctx @ Wo

def layernorm(z, eps=1e-5):
    mu = z.mean(axis=-1, keepdims=True)
    var = z.var(axis=-1, keepdims=True)
    return (z - mu) / np.sqrt(var + eps)

# --- DecoderLayer forward (POST-NORM, Vaswani 2017) ---
mask1 = tgt_mask.reshape(1, 1, T_tgt, T_tgt)
sa = multihead(x,  x,      x,      Wq1, Wk1, Wv1, Wo1, mask_bool=mask1)   # masked self-attn
x1 = layernorm(x + sa)                                                    # Add & Norm
ca = multihead(x1, memory, memory, Wq2, Wk2, Wv2, Wo2, mask_bool=None)    # cross-attn
x2 = layernorm(x1 + ca)                                                   # Add & Norm
ff = np.maximum(0, x2 @ W1) @ W2                                          # position-wise FFN (biases omitted)
out = layernorm(x2 + ff)                                                  # Add & Norm

print(out.shape)   # -> (1, 3, 8), same shape as input x
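A property worth checking on any decoder layer is causality: changing a future target token must not change the outputs at earlier positions. The sketch below repackages the trace above into a function (our own `decoder_layer`, using a fresh `np.random.default_rng` generator, so its numbers differ from the seeded trace) and perturbs only the last target token.

```python
import numpy as np

rng = np.random.default_rng(0)
B, T_src, T_tgt, d_model, H, d_ff = 1, 4, 3, 8, 2, 16
d_k = d_model // H

def init_W(inp, out):
    return rng.standard_normal((inp, out)) / np.sqrt(inp)

P = {name: init_W(d_model, d_model)
     for name in ["Wq1", "Wk1", "Wv1", "Wo1", "Wq2", "Wk2", "Wv2", "Wo2"]}
P["W1"], P["W2"] = init_W(d_model, d_ff), init_W(d_ff, d_model)

def multihead(q_in, kv_in, Wq, Wk, Wv, Wo, blocked=None):
    def split(t):
        # [B, T, d_model] -> [B, H, T, d_k]
        return t.reshape(B, t.shape[1], H, d_k).transpose(0, 2, 1, 3)
    Q, K, V = split(q_in @ Wq), split(kv_in @ Wk), split(kv_in @ Wv)
    s = Q @ K.transpose(0, 1, 3, 2) / np.sqrt(d_k)
    if blocked is not None:
        s = np.where(blocked, -1e9, s)            # True = blocked (NumPy-trace convention)
    s -= s.max(axis=-1, keepdims=True)
    w = np.exp(s)
    w /= w.sum(axis=-1, keepdims=True)
    return (w @ V).transpose(0, 2, 1, 3).reshape(B, q_in.shape[1], d_model) @ Wo

def layernorm(z, eps=1e-5):
    return (z - z.mean(-1, keepdims=True)) / np.sqrt(z.var(-1, keepdims=True) + eps)

def decoder_layer(x, memory):
    # Post-norm decoder layer, dropout omitted (eval mode), FFN biases omitted
    causal = np.triu(np.ones((T_tgt, T_tgt), dtype=bool), k=1)[None, None]
    x1 = layernorm(x + multihead(x, x, P["Wq1"], P["Wk1"], P["Wv1"], P["Wo1"], causal))
    x2 = layernorm(x1 + multihead(x1, memory, P["Wq2"], P["Wk2"], P["Wv2"], P["Wo2"]))
    return layernorm(x2 + np.maximum(0, x2 @ P["W1"]) @ P["W2"])

x = 0.3 * rng.standard_normal((B, T_tgt, d_model))
memory = 0.3 * rng.standard_normal((B, T_src, d_model))
out = decoder_layer(x, memory)

x_changed = x.copy()
x_changed[:, -1] += rng.standard_normal(d_model)     # perturb only the LAST target token
out_changed = decoder_layer(x_changed, memory)

print(np.allclose(out[:, :-1], out_changed[:, :-1]))  # True: earlier positions unaffected
print(np.allclose(out[:, -1], out_changed[:, -1]))    # False: the changed position moves
```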

Shape table

The five tensors a DecoderLayer actually touches:

| Tensor | Shape | Role |
| --- | --- | --- |
| tgt | [B, T_tgt, d_model] | Decoder input (shifted-right embeddings) |
| memory | [B, T_src, d_model] | Encoder output (keys and values for cross-attn) |
| tgt_mask | [T_tgt, T_tgt] | Causal mask; blocks attention to future target tokens |
| memory_mask | [B, T_src] | Optional source-padding mask |
| out | [B, T_tgt, d_model] | Layer output; same shape as tgt so layers can stack |
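The table lists memory_mask as [B, T_src], but the cross-attention scores are [B, H, T_tgt, T_src], so the mask needs two extra singleton axes before it can broadcast. A NumPy sketch using the trace's convention (True = blocked) and hypothetical source lengths:

```python
import numpy as np

B, H, T_tgt, T_src = 2, 2, 3, 4
src_lengths = np.array([4, 2])                   # hypothetical: 2nd source has 2 real tokens

# [B, T_src]: True where the source token is padding (True = blocked)
memory_pad = np.arange(T_src)[None, :] >= src_lengths[:, None]

# Insert head and query axes -> [B, 1, 1, T_src], which broadcasts over [B, H, T_tgt, T_src]
memory_mask = memory_pad[:, None, None, :]
scores = np.zeros((B, H, T_tgt, T_src))
masked = np.where(memory_mask, -1e9, scores)

print(memory_pad)      # row 0: no padding; row 1: last two source positions are padding
print(masked.shape)    # (2, 2, 3, 4)
```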

PyTorch Implementation

The PyTorch version is a straight transcription of the NumPy trace above, with three upgrades: (1) parameters live in nn.Linear layers that autograd tracks, (2) LayerNorm and Dropout are proper nn.Module submodules so .eval() disables dropout automatically, and (3) we inherit from nn.Module so the layer can be composed, saved, moved to GPU, and wrapped in torch.compile like any other PyTorch module.

We include MultiHeadAttention and PositionwiseFeedForward inline for completeness—in a real project these would be imported from chapter03 and chapter06 respectively.

TransformerDecoderLayer — PyTorch class
🐍decoder_layer.py
import math

Python's standard math module. We use math.sqrt(d_k) for the attention scaling factor: d_k is a plain Python int, so math.sqrt returns a float directly, whereas torch.sqrt expects a tensor.

from typing import Optional

Lets us annotate parameters that may be None, like masks. Clearer than writing tgt_mask: torch.Tensor = None — Optional[T] means 'a T or None'.

EXECUTION STATE
Optional[torch.Tensor] = Equivalent to Union[torch.Tensor, None]. Tools like mypy/pyright check that callers can pass either a tensor or None.
import torch

PyTorch. Gives us torch.Tensor, matmul, softmax, LayerNorm, autograd, GPU support.

import torch.nn as nn

nn module — contains Module, Linear, LayerNorm, Dropout, MultiheadAttention. We use nn.Linear and nn.LayerNorm directly so our layer uses standard learned parameters.

import torch.nn.functional as F

Stateless ops: F.softmax, F.relu, F.dropout. Use F.* for plain functions; use nn.* when we want a submodule that stores parameters or configuration (like nn.Dropout holding its rate) and follows .train()/.eval().

EXECUTION STATE
F vs nn = nn.Dropout(p) is a Module with no learnable parameters; it follows the module's training flag, which .train()/.eval() toggle. F.dropout() is a plain function, so you must pass training=... yourself. We use nn.Dropout so .eval() disables it automatically.
class MultiHeadAttention(nn.Module):

Multi-head attention as a reusable Module. Both self-attn and cross-attn use the same class; the caller decides whether Q, K, V come from the same or different tensors.

def __init__(self, d_model, num_heads, dropout=0.1)

Constructor. Stores config and creates the four projection matrices + dropout.

EXECUTION STATE
⬇ input: d_model = Residual-stream width. Typical 512 (base Transformer), 768 (GPT-2 small), 4096 (Llama 7B).
⬇ input: num_heads = Number of parallel attention heads. Must evenly divide d_model. Typical 8 or 12.
⬇ input: dropout = Probability applied to softmax weights (attention dropout). Original paper used 0.1; modern LLMs often use 0.0 in large-scale pretraining.
super().__init__()

Required for nn.Module subclasses. Initializes internal state (the parameter and submodule registries). Forgetting it is a very common PyTorch bug: the first attempt to assign a submodule such as nn.Linear raises an AttributeError.

assert d_model % num_heads == 0, ...

Hard-fail if d_model can't be evenly split across heads. An integer check beats silently computing wrong shapes at runtime.

self.d_model = d_model

Store for use in forward() (needed for the final reshape).

self.num_heads = num_heads

Stored for reshape in forward().

self.d_k = d_model // num_heads

Per-head width. For d_model=512, num_heads=8 this gives d_k=64.

EXECUTION STATE
// operator = Floor division in Python. Returns an int when both operands are ints. d_model // num_heads guarantees an integer d_k.
self.W_q = nn.Linear(d_model, d_model)

Query projection. Goes from (B, T, d_model) to (B, T, d_model). Learnable weight + bias inside.

EXECUTION STATE
📚 nn.Linear(in_features, out_features, bias=True) = Affine map y = x @ W^T + b. Stores W of shape (out_features, in_features) and b of shape (out_features,). Forward uses the transpose automatically.
shape(W) = (d_model, d_model)
self.W_k = nn.Linear(d_model, d_model)

Key projection. Separate parameters from W_q so K encodes 'what I contain', distinct from Q's 'what I'm looking for'.

self.W_v = nn.Linear(d_model, d_model)

Value projection. Holds the content returned when a query matches.

self.W_o = nn.Linear(d_model, d_model)

Output projection. Mixes the concatenated head outputs back into d_model. Without W_o, multi-head attention collapses to H independent attentions glued together with no cross-head interaction.

self.attn_dropout = nn.Dropout(dropout)

Dropout on the attention weights (after softmax, before multiplying V). Randomly zeros entries of the attention matrix during training.

def forward(self, query, key, value, mask=None)

Public call. query is the source of Q; key and value are the source of K and V. For self-attention the caller passes the same tensor three times; for cross-attention query=decoder, key=value=encoder memory.

EXECUTION STATE
⬇ input: query = Shape (B, T_q, d_model). For the decoder's self-attn: the decoder's own previous state. For cross-attn: x1 (post first Add&Norm).
⬇ input: key = Shape (B, T_k, d_model). Same sequence as value. For cross-attn this is the encoder output memory.
⬇ input: value = Shape (B, T_k, d_model). Content to return. Typically key and value come from the same source.
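⬇ input: mask = Optional boolean-ish tensor. True / 1 = keep. False / 0 = block. Broadcasts to scores shape (B, H, T_q, T_k). Note that this is the opposite of the NumPy trace above (where True meant blocked) and of the boolean attn_mask in torch.nn.MultiheadAttention (where True means not allowed to attend).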
⬆ return = (B, T_q, d_model). Same shape as query — attention preserves the query's sequence length.
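Because this class treats 1/True as keep while the NumPy trace treats True as blocked, it helps to convert explicitly. A NumPy sketch, assuming both conventions exactly as described in this section; the `np.where(keep == 0, -np.inf, ...)` line mirrors what `masked_fill(mask == 0, float("-inf"))` does:

```python
import numpy as np

T = 3
blocked = np.triu(np.ones((T, T), dtype=bool), k=1)   # NumPy trace: True = blocked
keep = ~blocked                                        # this class: True/1 = keep
keep_tril = np.tril(np.ones((T, T), dtype=bool))       # the same keep-mask built directly

def softmax(s):
    s = s - s.max(axis=-1, keepdims=True)
    w = np.exp(s)
    return w / w.sum(axis=-1, keepdims=True)

scores = np.random.default_rng(0).standard_normal((T, T))
w_numpy_style = softmax(np.where(blocked, -1e9, scores))
w_class_style = softmax(np.where(keep == 0, -np.inf, scores))

print(np.array_equal(keep, keep_tril))              # True
print(np.allclose(w_numpy_style, w_class_style))    # True: same weights, opposite conventions
```

Passing the NumPy-style mask to this class unchanged would invert it: past and current tokens would be blocked, and the last row (which has no True entries) would be fully masked and turn into NaN.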
B, Tq, _ = query.shape

Unpack batch and query-length. d_model is dropped because we already stored it in self.d_model.

EXECUTION STATE
query.shape = torch.Size([B, T_q, d_model])
Tk = key.size(1)

Key sequence length. Not the same as Tq for cross-attention. .size(1) is equivalent to .shape[1].

Q = self.W_q(query).view(B, Tq, self.num_heads, self.d_k).transpose(1, 2)

Project to queries, split the feature axis into (H, d_k), move heads ahead of the sequence dim.

EXECUTION STATE
📚 .view(...) = Reinterprets the tensor's memory with a new shape. Requires the tensor to be contiguous. Cheaper than reshape when it works.
📚 .transpose(1, 2) = Swap dims 1 and 2. (B, Tq, H, d_k) becomes (B, H, Tq, d_k).
Q final shape = (B, H, Tq, d_k)
K = self.W_k(key).view(B, Tk, ...).transpose(1, 2)

Key projection uses Tk (not Tq) in the view. For self-attn Tk == Tq; for cross-attn they differ.

V = self.W_v(value).view(B, Tk, ...).transpose(1, 2)

Value projection. Same sequence length as K.

scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)

Compute scaled dot-product scores per head.

EXECUTION STATE
📚 torch.matmul = Batched matrix multiply. Last two axes matmul'd, leading dims broadcast. (B, H, Tq, d_k) x (B, H, d_k, Tk) -> (B, H, Tq, Tk).
📚 K.transpose(-2, -1) = Swap the last two dims of K. (B, H, Tk, d_k) becomes (B, H, d_k, Tk). Required for matmul shape compatibility.
math.sqrt(self.d_k) = For d_k=64, equals 8.0. Keeps the dot-product variance ~ 1 regardless of d_k.
if mask is not None:

Only apply the mask if one was passed. Unmasked calls are common: cross-attention over a source with no padding, or encoder self-attention on an unpadded batch, needs no mask.

scores = scores.masked_fill(mask == 0, float("-inf"))

Replace blocked positions with -inf so softmax sends them to zero. The convention used here: mask == 1 allows, mask == 0 blocks.

EXECUTION STATE
📚 .masked_fill(mask_bool, value) = Wherever mask_bool is True, replace the entry in scores with value. Broadcasts across heads.
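-inf vs -1e9 = float('-inf') is exact: softmax maps it to exactly 0. The catch is a fully masked row: softmax over a row of all -inf yields NaN in any precision, whereas a large finite negative yields uniform weights. In FP16, -1e9 itself overflows to -inf (the largest finite FP16 value is 65504), so FP16 code often uses a value such as -1e4 instead.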
attn = F.softmax(scores, dim=-1)

Normalize each query's scores over keys into a probability distribution.

EXECUTION STATE
📚 F.softmax(x, dim) = softmax(x_i) = exp(x_i) / sum_j exp(x_j). Numerically stable: PyTorch subtracts per-slice max internally.
dim=-1 = Normalize across the KEY axis. Each (batch, head, query) row sums to 1.
attn = self.attn_dropout(attn)

During training, each attention weight is zeroed with probability p and the survivors are scaled by 1/(1-p) (inverted dropout), so the expected value of every weight is unchanged, although an individual row no longer sums exactly to 1. In eval() mode this is a no-op.
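A short sketch of inverted dropout on a toy attention matrix (random scores, p=0.5 chosen only to make the effect visible): in training mode the rows stop summing to 1, the average over many masks recovers the original weights, and eval() turns the module into the identity.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
attn = torch.softmax(torch.randn(4, 6), dim=-1)   # 4 query rows over 6 keys; each row sums to 1
drop = nn.Dropout(p=0.5)

drop.train()
dropped = drop(attn)                 # each entry is either 0 or attn / (1 - p) = 2 * attn
print(dropped.sum(dim=-1))           # rows generally no longer sum to 1

# Averaged over many independent masks, the expected value is preserved.
samples = drop(attn.repeat(2000, 1, 1))          # (2000, 4, 6), fresh mask per element
print(torch.allclose(samples.mean(dim=0), attn, atol=0.1))   # True

drop.eval()
print(torch.equal(drop(attn), attn))             # True: identity in eval mode
```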

ctx = torch.matmul(attn, V)

Weighted sum of values. Shape: (B, H, Tq, Tk) x (B, H, Tk, d_k) -> (B, H, Tq, d_k). Each head outputs d_k features per query.

ctx = ctx.transpose(1, 2).contiguous().view(B, Tq, self.d_model)

Merge heads back into d_model.

EXECUTION STATE
📚 .contiguous() = After transpose, the memory layout is non-contiguous. .view() requires contiguous memory, so .contiguous() materializes a fresh tensor with the standard row-major layout.
view(B, Tq, d_model) = (B, Tq, H, d_k) -> (B, Tq, H*d_k) = (B, Tq, d_model).
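The split and merge steps are exact inverses, and the .contiguous() call is what makes the merge legal. In this sketch, ctx is built as a contiguous (B, H, T, d_k) tensor to stand in for the output of matmul(attn, V); the sizes are arbitrary.

```python
import torch

torch.manual_seed(0)
B, T, H, d_k = 2, 5, 4, 3
d_model = H * d_k
x = torch.randn(B, T, d_model)

# Split: (B, T, d_model) -> (B, T, H, d_k) -> (B, H, T, d_k), materialized like matmul output
ctx = x.view(B, T, H, d_k).transpose(1, 2).contiguous()

# Merge: (B, H, T, d_k) -> (B, T, H, d_k) -> (B, T, d_model)
t = ctx.transpose(1, 2)
print(t.is_contiguous())        # False

try:
    t.view(B, T, d_model)       # strides no longer allow merging the H and d_k axes
    view_failed = False
except RuntimeError:
    view_failed = True
print(view_failed)              # True

merged = t.contiguous().view(B, T, d_model)
print(torch.equal(merged, x))   # True: split followed by merge is the identity
```

t.reshape(B, T, d_model) would give the same result, copying only because a view is impossible here.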
return self.W_o(ctx)

Final linear mixing across heads. Without this, the H head outputs would stay in disjoint feature subspaces.

EXECUTION STATE
⬆ return = (B, Tq, d_model) tensor
class PositionwiseFeedForward(nn.Module):

Position-wise 2-layer MLP: the same weights are applied independently to every sequence position, and no information flows between positions.

def __init__(self, d_model, d_ff, dropout=0.1)

Constructor. d_ff is the inner width; standard ratio is 4 x d_model.

EXECUTION STATE
⬇ input: d_model = Input/output width (residual stream width).
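⬇ input: d_ff = Inner hidden width. 4 x d_model is canonical (Vaswani et al.). Llama's SwiGLU FFN starts from 4 x d_model and scales it by 2/3 (roughly 2.7 x d_model, rounded up to a hardware-friendly multiple) to offset its extra weight matrix.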
⬇ input: dropout = Applied between the two linear layers.
super().__init__()

Initialize the nn.Module base class.

self.linear1 = nn.Linear(d_model, d_ff)

Expansion layer. For d_model=512, d_ff=2048: weight is (2048, 512), bias is (2048,).

self.linear2 = nn.Linear(d_ff, d_model)

Contraction layer, projecting back to d_model. Together the two linears form an inverted bottleneck: narrow, then wide, then narrow again.

self.dropout = nn.Dropout(dropout)

Dropout between the two linears. Regularizes the overcomplete d_ff hidden activations.

def forward(self, x)

Applies linear1 -> ReLU -> Dropout -> linear2 in one composed call.

EXECUTION STATE
⬇ input: x = (B, T, d_model) — same shape as decoder residual stream.
⬆ return = (B, T, d_model) — shape preserved; content transformed.
return self.linear2(self.dropout(F.relu(self.linear1(x))))

Inside-out execution order: linear1(x) -> ReLU -> Dropout -> linear2. Read it like function composition.

EXECUTION STATE
📚 F.relu(z) = max(0, z) elementwise. This is the FFN's only nonlinearity; without it, linear1 and linear2 would collapse into a single linear map.
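A minimal sketch of the position-wise property. It re-declares the PositionwiseFeedForward class from this section so the block runs on its own; the sizes are arbitrary. Running the FFN on the whole sequence matches running it token by token, and perturbing one token changes only that token's output.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class PositionwiseFeedForward(nn.Module):
    """Same definition as in this section's listing."""

    def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        self.linear1 = nn.Linear(d_model, d_ff)
        self.linear2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear2(self.dropout(F.relu(self.linear1(x))))


torch.manual_seed(0)
ffn = PositionwiseFeedForward(d_model=8, d_ff=32).eval()   # eval(): dropout off
x = torch.randn(2, 5, 8)                                   # (B, T, d_model)

with torch.no_grad():
    full = ffn(x)
    per_token = torch.stack([ffn(x[:, t]) for t in range(x.size(1))], dim=1)

    x_perturbed = x.clone()
    x_perturbed[:, 3] += 1.0                                   # change token 3 only
    delta = (ffn(x_perturbed) - full).abs().amax(dim=(0, 2))   # max change per position

print(torch.allclose(full, per_token, atol=1e-6))   # True
print(delta)   # nonzero only at position 3
```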
class TransformerDecoderLayer(nn.Module):

The module we are building in this section. Wraps the three sublayers with residual Add + LayerNorm + sublayer-output dropout, following the Vaswani 2017 (post-norm) convention.

def __init__(self, d_model, num_heads, d_ff, dropout=0.1)

Top-level constructor. The argument order mirrors torch.nn.TransformerDecoderLayer, which spells the second and third arguments nhead and dim_feedforward.

EXECUTION STATE
⬇ input: d_model = Residual width — must match the encoder so cross-attn shapes line up.
⬇ input: num_heads = H — number of heads. Must divide d_model.
⬇ input: d_ff = FFN inner width.
⬇ input: dropout = Applied in three places: inside each attention (on softmax weights), inside the FFN, and on each sublayer output before the residual add.
super().__init__()

Base-class init (required).

self.self_attn = MultiHeadAttention(d_model, num_heads, dropout)

Masked self-attention instance. Its parameters are disjoint from the cross-attn instance below.

self.cross_attn = MultiHeadAttention(d_model, num_heads, dropout)

Cross-attention instance. Same class, separate weights — lets the layer learn different Q/K/V subspaces for 'attend to past target tokens' vs 'attend to source tokens'.

self.ffn = PositionwiseFeedForward(d_model, d_ff, dropout)

Position-wise FFN instance.

self.norm1 = nn.LayerNorm(d_model)

LayerNorm after the first residual add. Per-token normalization across features.

EXECUTION STATE
📚 nn.LayerNorm(d_model) = Normalizes over the last dim of size d_model. Holds two learnable parameters (not buffers): weight (gamma, init 1) and bias (beta, init 0), each of shape (d_model,).
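A quick check of both points, the parameter names and the per-token formula. The manual recomputation is a sketch of the documented LayerNorm formula (biased variance, eps inside the square root), not PyTorch's internal kernel.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
ln = nn.LayerNorm(8)
print([name for name, _ in ln.named_parameters()])   # ['weight', 'bias']
print(list(ln.named_buffers()))                       # []

x = torch.randn(2, 3, 8)
mean = x.mean(dim=-1, keepdim=True)                   # per-token mean over features
var = ((x - mean) ** 2).mean(dim=-1, keepdim=True)    # biased variance, as LayerNorm uses
manual = (x - mean) / torch.sqrt(var + ln.eps) * ln.weight + ln.bias
print(torch.allclose(ln(x), manual, atol=1e-5))        # True
```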
self.norm2 = nn.LayerNorm(d_model)

LayerNorm after cross-attn's residual.

self.norm3 = nn.LayerNorm(d_model)

LayerNorm after the FFN's residual.

self.drop1 = nn.Dropout(dropout)

Dropout on the self-attn sublayer's output, applied before the residual add.

self.drop2 = nn.Dropout(dropout)

Dropout on the cross-attn sublayer's output.

self.drop3 = nn.Dropout(dropout)

Dropout on the FFN sublayer's output.

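def forward(self, tgt, memory, tgt_mask=None, memory_mask=None)

One layer's forward pass. The first four arguments line up with torch.nn.TransformerDecoderLayer.forward, but the boolean mask polarity is reversed (here True = keep, in PyTorch True = block), so this class is not a drop-in replacement unless you convert the masks.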

EXECUTION STATE
⬇ input: tgt = (B, T_tgt, d_model) — decoder input for this layer. For the first layer this is embed(y) + PE; for deeper layers it is the previous layer's output.
⬇ input: memory = (B, T_src, d_model) — encoder output. Same tensor is passed to every decoder layer.
⬇ input: tgt_mask = Optional mask over target-target attention. Usually a causal lower-triangular mask of shape (T_tgt, T_tgt).
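⬇ input: memory_mask = Optional mask over target-source attention, typically a source-padding mask. A (B, T_src) padding mask must be reshaped to (B, 1, 1, T_src), e.g. by indexing it with [:, None, None, :], so it broadcasts over heads and query positions; passed as (B, T_src) it would be aligned with the wrong axes.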
⬆ return = (B, T_tgt, d_model) — same shape as tgt.
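A minimal sketch of building such a memory_mask from source lengths. The src_lengths values are hypothetical, and the all-zero scores stand in for Q K^T / sqrt(d_k) so the effect of the mask is easy to read off.

```python
import torch

B, H, T_tgt, T_src = 2, 2, 3, 4
src_lengths = torch.tensor([4, 2])   # hypothetical: sentence 2 ends in 2 padding tokens

# (B, T_src) padding mask in this module's convention: True = real token (keep)
pad_keep = torch.arange(T_src)[None, :] < src_lengths[:, None]
memory_mask = pad_keep[:, None, None, :]          # (B, 1, 1, T_src)

scores = torch.zeros(B, H, T_tgt, T_src)          # stand-in for the real attention scores
attn = torch.softmax(scores.masked_fill(memory_mask == 0, float("-inf")), dim=-1)
print(attn[1, 0])
# every query of sentence 2 puts 0.5 on each real token and 0 on the padding
```

Passing pad_keep with shape (2, 4) directly would raise a broadcasting error here, because its leading axis would be matched against T_tgt = 3; if B happened to equal T_tgt, it would broadcast silently along the wrong axis.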
# Sublayer 1: masked self-attention + Add & Norm

Comment marking the first of three sublayers. Order matters: self-attn must come before cross-attn because cross-attn uses the refined decoder state (x1), not the raw tgt.

sa = self.self_attn(tgt, tgt, tgt, mask=tgt_mask)

Q, K and V all come from tgt, which is what makes this self-attention. The causal tgt_mask forbids attending to future target tokens.

EXECUTION STATE
sa shape = (B, T_tgt, d_model)
x1 = self.norm1(tgt + self.drop1(sa))

Post-norm Add & Norm: Dropout(sa) -> add residual tgt -> LayerNorm. x1 is the input to sublayer 2.

# Sublayer 2: cross-attention + Add & Norm

Marker for the bridge-to-encoder sublayer. Queries come from the refined decoder state x1; keys and values come from memory.

ca = self.cross_attn(x1, memory, memory, mask=memory_mask)

Cross-attention. Note the asymmetry: Q source (x1) has length T_tgt, K/V source (memory) has length T_src.

EXECUTION STATE
ca shape = (B, T_tgt, d_model)
x2 = self.norm2(x1 + self.drop2(ca))

Second residual + LayerNorm. x2 carries both target history and source context.

# Sublayer 3: FFN + Add & Norm

Third sublayer marker. The FFN adds per-token nonlinear capacity — it does not mix across positions.

ff = self.ffn(x2)

Apply position-wise FFN. Same shape in, same shape out. Inside: Linear -> ReLU -> Dropout -> Linear.

return self.norm3(x2 + self.drop3(ff))

Final Add & Norm. The return value is the decoder-layer output — you can either stack another copy of this layer on top of it (§5) or send it to the output projection (§6).

EXECUTION STATE
⬆ return = (B, T_tgt, d_model) — identical shape to tgt, ready for the next layer.
import math
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F


class MultiHeadAttention(nn.Module):
    """Multi-head attention reused by both self-attn and cross-attn sublayers."""

    def __init__(self, d_model: int, num_heads: int, dropout: float = 0.1):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)
        self.attn_dropout = nn.Dropout(dropout)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        B, Tq, _ = query.shape
        Tk = key.size(1)
        Q = self.W_q(query).view(B, Tq, self.num_heads, self.d_k).transpose(1, 2)
        K = self.W_k(key).view(B, Tk, self.num_heads, self.d_k).transpose(1, 2)
        V = self.W_v(value).view(B, Tk, self.num_heads, self.d_k).transpose(1, 2)
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float("-inf"))
        attn = F.softmax(scores, dim=-1)
        attn = self.attn_dropout(attn)
        ctx = torch.matmul(attn, V)
        ctx = ctx.transpose(1, 2).contiguous().view(B, Tq, self.d_model)
        return self.W_o(ctx)


class PositionwiseFeedForward(nn.Module):
    """FFN(x) = max(0, x W_1 + b_1) W_2 + b_2 with dropout between."""

    def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        self.linear1 = nn.Linear(d_model, d_ff)
        self.linear2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear2(self.dropout(F.relu(self.linear1(x))))


class TransformerDecoderLayer(nn.Module):
    """
    One decoder layer = masked self-attn  ->  cross-attn  ->  FFN.
    Each sublayer wrapped in Dropout + residual Add + LayerNorm (post-norm).
    """

    def __init__(
        self,
        d_model: int,
        num_heads: int,
        d_ff: int,
        dropout: float = 0.1,
    ):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads, dropout)
        self.cross_attn = MultiHeadAttention(d_model, num_heads, dropout)
        self.ffn = PositionwiseFeedForward(d_model, d_ff, dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.drop1 = nn.Dropout(dropout)
        self.drop2 = nn.Dropout(dropout)
        self.drop3 = nn.Dropout(dropout)

    def forward(
        self,
        tgt: torch.Tensor,
        memory: torch.Tensor,
        tgt_mask: Optional[torch.Tensor] = None,
        memory_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # Sublayer 1: masked self-attention + Add & Norm
        sa = self.self_attn(tgt, tgt, tgt, mask=tgt_mask)
        x1 = self.norm1(tgt + self.drop1(sa))
        # Sublayer 2: cross-attention + Add & Norm
        ca = self.cross_attn(x1, memory, memory, mask=memory_mask)
        x2 = self.norm2(x1 + self.drop2(ca))
        # Sublayer 3: FFN + Add & Norm
        ff = self.ffn(x2)
        return self.norm3(x2 + self.drop3(ff))

Using the layer

Instantiate the layer, set it to eval mode, build a causal mask, and call forward. Note the mask convention used by our module: 1 = keep, 0 = block. torch.triu(..., diagonal=1) applied to a matrix of ones gives the opposite, a mask that is True strictly above the diagonal (the future positions to block), so we invert it with ~ before passing it in.

Running one DecoderLayer end-to-end
run_decoder_layer.py
import torch

Loads PyTorch.

torch.manual_seed(0)

Deterministic parameter init and deterministic torch.randn. Same seed as §5 and §6 demos.

Shared demo config

Same constants as the plain-Python trace so values are comparable.

EXECUTION STATE
B = 1
T_src = 4
T_tgt = 3
d_model = 8
H = 2
d_ff = 16
layer = TransformerDecoderLayer(d_model=8, num_heads=2, d_ff=16, dropout=0.0)

Instantiate one layer. dropout=0.0 for the demo so the output is reproducible — in real training set it to 0.1.

layer.eval()

Switches every contained nn.Dropout to pass-through. It would also switch any BatchNorm to its running statistics, but LayerNorm (the only normalization here) behaves identically in train and eval mode. Call eval() before any forward pass whose output should be deterministic.

EXECUTION STATE
📚 Module.eval() = Sets self.training = False recursively on all submodules. Dropout becomes identity; BatchNorm uses running stats; LayerNorm is unchanged.
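A small sketch separating what eval() does from what torch.no_grad() (used further down in the script) does. The toy nn.Sequential block is an assumption for illustration only.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
block = nn.Sequential(nn.Linear(8, 8), nn.Dropout(p=0.5), nn.LayerNorm(8))
x = torch.randn(4, 8)

block.eval()                                     # sets .training = False on every submodule
flags = [m.training for m in block.modules()]    # the Sequential plus its 3 children
y1, y2 = block(x), block(x)
print(flags)                                     # [False, False, False, False]
print(torch.equal(y1, y2), y1.requires_grad)     # True True: eval() does not turn off autograd

with torch.no_grad():                            # autograd off, independent of eval()
    y3 = block(x)
print(y3.requires_grad, torch.equal(y1, y3))     # False True

ln = block[2]
ln.train()
a = ln(y1.detach())
ln.eval()
b = ln(y1.detach())
print(torch.equal(a, b))                         # True: LayerNorm ignores the mode flag
```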
tgt = torch.randn(B, T_tgt, d_model)

Sample a decoder input. Shape (1, 3, 8).

EXECUTION STATE
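📚 torch.randn(*shape) = Standard normal samples. Constructing the layer above already drew random numbers for parameter initialization, so tgt is not the first randn draw after manual_seed(0) and will not match values printed right after a fresh seed; reseed immediately before this line if you need an exact match with another trace.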
memory = torch.randn(B, T_src, d_model)

Fake encoder output. Shape (1, 4, 8).

causal = torch.triu(torch.ones(T_tgt, T_tgt, dtype=torch.bool), diagonal=1)

Upper-triangular True mask: entry (i, j) is True iff j > i. We will INVERT this to get 'allowed' positions.

EXECUTION STATE
📚 torch.triu(A, diagonal=k) = Zero out everything below the k-th diagonal. diagonal=1 keeps entries strictly above the main diagonal.
causal =
[[False  True  True]
 [False False  True]
 [False False False]]
tgt_mask = (~causal).unsqueeze(0).unsqueeze(0)

Convention in our module: mask == 1 (True) ALLOWS attention, mask == 0 (False) BLOCKS. So we invert causal and add batch + head axes for broadcasting.

EXECUTION STATE
📚 ~ (logical not) = Bitwise NOT on a bool tensor. Flips True<->False elementwise.
📚 .unsqueeze(dim) = Insert a size-1 axis at dim. Two unsqueezes take shape (3,3) -> (1,3,3) -> (1,1,3,3) so the mask broadcasts against scores of shape (B, H, T, T).
tgt_mask[0][0] =
[[ True False False]
 [ True  True False]
 [ True  True  True]]
with torch.no_grad():

Disables autograd bookkeeping so no computation graph is built. Saves memory and runs faster — good for inference-only forwards.

EXECUTION STATE
📚 torch.no_grad() = Context manager. Inside, every tensor op sets requires_grad=False on outputs and skips graph recording.
out = layer(tgt, memory, tgt_mask=tgt_mask)

Call the layer's forward(). Positional args tgt and memory; tgt_mask as kwarg. memory_mask omitted (None).

EXECUTION STATE
out shape = torch.Size([1, 3, 8])
print(out.shape)

Confirms the layer is shape-preserving — the core invariant we need to stack N copies in §5.

EXECUTION STATE
⬆ console = torch.Size([1, 3, 8])
import torch  # TransformerDecoderLayer must already be defined (module listing above)

torch.manual_seed(0)

B, T_src, T_tgt, d_model, H, d_ff = 1, 4, 3, 8, 2, 16

layer = TransformerDecoderLayer(d_model=d_model, num_heads=H, d_ff=d_ff, dropout=0.0)
layer.eval()  # turn off dropout so the forward is deterministic

tgt = torch.randn(B, T_tgt, d_model)        # decoder input
memory = torch.randn(B, T_src, d_model)     # encoder output
causal = torch.triu(torch.ones(T_tgt, T_tgt, dtype=torch.bool), diagonal=1)
tgt_mask = (~causal).unsqueeze(0).unsqueeze(0)   # True = keep, shape (1,1,T,T)

with torch.no_grad():
    out = layer(tgt, memory, tgt_mask=tgt_mask)

print(out.shape)   # torch.Size([1, 3, 8])
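The shape check does not prove that the causal mask works. The following sketch re-declares this section's three classes so it runs on its own, changes only the last target token, and confirms that the outputs at earlier positions are unaffected, while the same change does leak backwards when tgt_mask is omitted. The perturbation size (+5.0) is an arbitrary choice.

```python
import math
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F


class MultiHeadAttention(nn.Module):
    """Same math as the listing above; mask convention: True/1 = keep, False/0 = block."""

    def __init__(self, d_model: int, num_heads: int, dropout: float = 0.1):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        self.d_model, self.num_heads, self.d_k = d_model, num_heads, d_model // num_heads
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)
        self.attn_dropout = nn.Dropout(dropout)

    def forward(self, query, key, value, mask: Optional[torch.Tensor] = None):
        B, Tq, _ = query.shape
        Tk = key.size(1)
        Q = self.W_q(query).view(B, Tq, self.num_heads, self.d_k).transpose(1, 2)
        K = self.W_k(key).view(B, Tk, self.num_heads, self.d_k).transpose(1, 2)
        V = self.W_v(value).view(B, Tk, self.num_heads, self.d_k).transpose(1, 2)
        scores = Q @ K.transpose(-2, -1) / math.sqrt(self.d_k)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float("-inf"))
        attn = self.attn_dropout(F.softmax(scores, dim=-1))
        ctx = (attn @ V).transpose(1, 2).contiguous().view(B, Tq, self.d_model)
        return self.W_o(ctx)


class PositionwiseFeedForward(nn.Module):
    def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        self.linear1 = nn.Linear(d_model, d_ff)
        self.linear2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        return self.linear2(self.dropout(F.relu(self.linear1(x))))


class TransformerDecoderLayer(nn.Module):
    def __init__(self, d_model: int, num_heads: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads, dropout)
        self.cross_attn = MultiHeadAttention(d_model, num_heads, dropout)
        self.ffn = PositionwiseFeedForward(d_model, d_ff, dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.drop1 = nn.Dropout(dropout)
        self.drop2 = nn.Dropout(dropout)
        self.drop3 = nn.Dropout(dropout)

    def forward(self, tgt, memory, tgt_mask=None, memory_mask=None):
        x1 = self.norm1(tgt + self.drop1(self.self_attn(tgt, tgt, tgt, mask=tgt_mask)))
        x2 = self.norm2(x1 + self.drop2(self.cross_attn(x1, memory, memory, mask=memory_mask)))
        return self.norm3(x2 + self.drop3(self.ffn(x2)))


torch.manual_seed(0)
B, T_src, T_tgt, d_model = 1, 4, 3, 8
layer = TransformerDecoderLayer(d_model, num_heads=2, d_ff=16, dropout=0.0).eval()
tgt = torch.randn(B, T_tgt, d_model)
memory = torch.randn(B, T_src, d_model)
keep = torch.tril(torch.ones(T_tgt, T_tgt, dtype=torch.bool))[None, None]  # (1, 1, T, T), True = keep

tgt_changed = tgt.clone()
tgt_changed[:, -1] += 5.0   # alter only the last target token

with torch.no_grad():
    out = layer(tgt, memory, tgt_mask=keep)
    out_changed = layer(tgt_changed, memory, tgt_mask=keep)
    plain = layer(tgt, memory)                  # same inputs, no causal mask
    plain_changed = layer(tgt_changed, memory)

earlier_same = torch.allclose(out[:, :-1], out_changed[:, :-1], atol=1e-6)
leak = not torch.allclose(plain[:, :-1], plain_changed[:, :-1], atol=1e-6)
print(out.shape, earlier_same, leak)   # torch.Size([1, 3, 8]) True True
```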
Comparison with PyTorch: our layer has the same structure as PyTorch's built-in nn.TransformerDecoderLayer(d_model=8, nhead=2, dim_feedforward=16, dropout=0.0, batch_first=True), whose defaults are also post-norm with ReLU. The constructor spells the arguments nhead and dim_feedforward, and forward accepts the same (tgt, memory, tgt_mask, memory_mask) arguments, but PyTorch's boolean masks mean "True = block", the opposite of ours. The two layers will not produce the same numbers unless you also copy the weights across. Always check your library's mask polarity before wiring things together.
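For a concrete comparison, here is a sketch that runs PyTorch's own nn.TransformerDecoderLayer on the demo shapes. The only adaptation it needs is inverting our keep-mask into PyTorch's block-mask; the causality check at the end (perturb the last target token, confirm earlier outputs are unchanged) is our own test, not part of the PyTorch API.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
B, T_src, T_tgt, d_model = 1, 4, 3, 8
builtin = nn.TransformerDecoderLayer(d_model=d_model, nhead=2, dim_feedforward=16,
                                     dropout=0.0, batch_first=True).eval()

tgt = torch.randn(B, T_tgt, d_model)
memory = torch.randn(B, T_src, d_model)

keep = torch.tril(torch.ones(T_tgt, T_tgt, dtype=torch.bool))  # our convention: True = keep
block_mask = ~keep                                              # PyTorch bool convention: True = block

tgt_changed = tgt.clone()
tgt_changed[:, -1] += 5.0          # perturb only the last target token

with torch.no_grad():
    out = builtin(tgt, memory, tgt_mask=block_mask)
    out_changed = builtin(tgt_changed, memory, tgt_mask=block_mask)

print(out.shape)                                                    # torch.Size([1, 3, 8])
print(torch.allclose(out[:, :-1], out_changed[:, :-1], atol=1e-5))  # True: the mask is causal
```

Passing our un-inverted keep mask instead would block the past and allow the future, which is exactly the polarity bug the paragraph above warns about.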

Interactive Visualization

The interactive DecoderLayer Flow diagram traces the three sublayers, labels each arrow with the tensor shape it carries, and toggles between Post-Norm (Vaswani 2017) and Pre-Norm (Llama, GPT-3, Mistral); the toggle moves the LayerNorm from after the residual add to inside the sublayer branch.

[Figure: DecoderLayer Flow. tgt → Masked Self-Attention (tgt_mask) → LN → x_1 → Cross-Attention (memory from encoder) → LN → x_2 → Feed-Forward Network → LN → out. Post-Norm: y = LN(x + Sublayer(x)).]

Masked Self-Attention (post-norm): x_1 = LN(x + Dropout(MaskedSelfAttn(x, tgt_mask))). Each target token looks back at earlier target tokens only; the causal tgt_mask forbids attending to the future. LayerNorm is applied after the residual add.
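Written out for all three sublayers, the two placements differ only in where LN sits. Here $x_0$ denotes tgt and $m$ the encoder memory (both labels introduced for this summary), SelfAttn includes the causal mask, and dropout is omitted for brevity; in both forms $m$ enters the layer unnormalized, as in PyTorch's norm_first=True variant.

$$
\begin{aligned}
\text{post-norm:}\quad x_1 &= \mathrm{LN}_1\big(x_0 + \mathrm{SelfAttn}(x_0)\big) &\qquad \text{pre-norm:}\quad x_1 &= x_0 + \mathrm{SelfAttn}\big(\mathrm{LN}_1(x_0)\big)\\
x_2 &= \mathrm{LN}_2\big(x_1 + \mathrm{CrossAttn}(x_1, m)\big) & x_2 &= x_1 + \mathrm{CrossAttn}\big(\mathrm{LN}_2(x_1), m\big)\\
y &= \mathrm{LN}_3\big(x_2 + \mathrm{FFN}(x_2)\big) & y &= x_2 + \mathrm{FFN}\big(\mathrm{LN}_3(x_2)\big)
\end{aligned}
$$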

In Modern Systems

The layer we just built is the Vaswani-2017 blueprint. Most production Transformers since then amount to a small set of edits on top of it:

  • Llama 1/2/3, Mistral, and Qwen use pre-norm with RMSNorm (Zhang & Sennrich, 2019) instead of LayerNorm. RMSNorm drops the mean-subtraction step and the $\beta$ shift, keeping only $\mathrm{RMSNorm}(z) = \gamma \odot z / \sqrt{\tfrac{1}{d}\sum_i z_i^2 + \varepsilon}$. It is cheaper and has empirically matched LayerNorm quality.
  • T5, BART, mBART, NLLB keep cross-attention (they are encoder-decoder models). T5's layer norm is already simplified: it only rescales, with no bias and no mean subtraction, which is the same RMS-style normalization Llama uses.
  • Decoder-only models (GPT-2, GPT-3, the Llama family, Mistral, Qwen) drop cross-attention entirely. Their DecoderLayer has only two sublayers, masked self-attention and FFN, because there is no encoder to attend to. Section 6 will show how "encoder-decoder" and "decoder-only" architectures differ at the top level.
  • FFN variants: Llama uses SwiGLU (Shazeer, 2020), which replaces $\mathrm{ReLU}(xW_1)W_2$ with $(\mathrm{Swish}(xW_g) \odot xW_u)\,W_d$: two parallel up-projections, a Swish gate, an elementwise product, and a down projection. At equal d_ff this has 1.5× the parameters of the ReLU FFN (three matrices instead of two), which is why Llama shrinks d_ff by 2/3; Shazeer reported better perplexity than the ReLU baseline.
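A minimal sketch of both replacements, under stated assumptions: the class names, eps=1e-6, and bias-free projections are choices made for illustration, not taken from any particular model's source. F.silu is PyTorch's Swish with β = 1.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class RMSNorm(nn.Module):
    """RMSNorm(z) = gamma * z / sqrt(mean(z^2) + eps): no mean subtraction, no beta."""

    def __init__(self, d_model: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(d_model))   # gamma

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        rms = torch.sqrt(z.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return self.weight * z / rms


class SwiGLUFeedForward(nn.Module):
    """(Swish(x W_g) * (x W_u)) W_d with bias-free projections (an assumption)."""

    def __init__(self, d_model: int, d_ff: int):
        super().__init__()
        self.w_gate = nn.Linear(d_model, d_ff, bias=False)   # W_g
        self.w_up = nn.Linear(d_model, d_ff, bias=False)     # W_u
        self.w_down = nn.Linear(d_ff, d_model, bias=False)   # W_d

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.w_down(F.silu(self.w_gate(x)) * self.w_up(x))


def n_params(module: nn.Module) -> int:
    return sum(p.numel() for p in module.parameters())


torch.manual_seed(0)
x = torch.randn(2, 3, 512)
norm = RMSNorm(512)
print(norm(x).pow(2).mean(dim=-1))   # close to 1.0 for every token (gamma = 1 at init)

relu_ffn = nn.Sequential(nn.Linear(512, 2048, bias=False), nn.ReLU(),
                         nn.Linear(2048, 512, bias=False))
swiglu_same = SwiGLUFeedForward(512, 2048)
swiglu_shrunk = SwiGLUFeedForward(512, 2048 * 2 // 3)   # 2/3 shrink, before any rounding
print(n_params(relu_ffn), n_params(swiglu_same), n_params(swiglu_shrunk))
# 2097152 3145728 2096640
```

The shrunk SwiGLU lands within a fraction of a percent of the ReLU FFN's parameter count, which is the point of the 2/3 factor.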

The point is not to memorize every variant; it is to see that the residual sublayer structure is the stable core of the Transformer, and that the choices we made above (post-norm, ReLU-FFN, LayerNorm) are the textbook starting point. Swap those three for their modern counterparts (pre-norm, SwiGLU, RMSNorm), drop cross-attention for a decoder-only model, and you are close to a current-generation LLM block.


Summary

  • A TransformerDecoderLayer composes three sublayers—masked self-attention, cross-attention, feed-forward—each wrapped in Dropout + residual Add + LayerNorm.
  • Input and output shapes are both $[B, T_{tgt}, d_{model}]$, which is what lets us stack $N$ copies in Section 5.
  • Post-norm (original paper) is easier to reason about; pre-norm (modern LLMs) is more stable at depth. Switching is a one-line change per sublayer.
  • Dropout appears in three kinds of places: on the softmax weights inside each attention, on each sublayer output, and between the two FFN linears. A typical rate is $p = 0.1$; many large pretraining runs use $p = 0$.
  • The plain-Python forward pass and the PyTorch nn.Module implement exactly the same math—the PyTorch version just adds parameter tracking, GPU support, and autograd.

Exercises

  1. Pre-norm conversion (easy). Rewrite TransformerDecoderLayer.forward in pre-norm form. For each sublayer, apply LayerNorm to the input before the sublayer function and leave the residual path un-normalized: e.g. replace $x_1 = \mathrm{LN}(x + \mathrm{SelfAttn}(x))$ with $x_1 = x + \mathrm{SelfAttn}(\mathrm{LN}(x))$. Run the same usage snippet and verify the output shape is still $(1, 3, 8)$.
  2. Mask polarity bug (medium). PyTorch's built-in nn.MultiheadAttention uses the opposite boolean mask convention from our module (True = block). Write a small adapter that converts our "True = keep" mask so that, once the projection weights are copied across, the built-in module produces numerically identical results. Test on the same tgt / memory / causal mask as the usage snippet.
  3. Parameter count (medium). Compute the total number of learnable parameters in one TransformerDecoderLayer for $d_{model} = 512,\ H = 8,\ d_{ff} = 2048$. Break the count down by sublayer (self-attn, cross-attn, FFN, plus the three LayerNorms) and by projection (Q, K, V, O, W_1, W_2). Verify against sum(p.numel() for p in layer.parameters()).
  4. Cross-attention is the only source path (hard). Replace the cross-attention sublayer with an identity (skip it entirely) and run the layer on a translation dataset of your choice for 5 epochs. Compare BLEU against the full decoder. You should see catastrophic degradation—explain why in one paragraph by reference to the residual-stream picture from the math section.

Next Section Preview

One layer is not a decoder. In Section 5 (Complete Transformer Decoder) we stack $N$ identical copies of the module we just built, add target-side token embeddings and positional encoding at the input, and close the stack with a final output projection that turns the residual stream into vocabulary logits. We will also cover weight tying (reusing the embedding matrix for the output projection) and walk through how representational depth produces the feature hierarchy observed in BERTology-style probes.
