Chapter 8

Full Encoder-Decoder Transformer

Transformer Decoder

Introduction

Over the last seven chapters we built every component of the Transformer piece by piece. Chapter 7 finished the encoder: a stack of $N$ identical self-attention blocks that turns a source sentence into a bank of contextualized vectors we will call the memory. Section 5 of this chapter finished the decoder: a stack of $N$ identical blocks, each doing masked self-attention, cross-attention against the memory, and a feed-forward projection.

This section glues them together. We wrap the two stacks, plus the source and target embedding tables and the final vocabulary projection, inside a single Transformer module whose $\text{forward}(src, tgt)$ runs end-to-end in one call. This is the Vaswani et al. (2017) architecture with no extras; the only open choice is whether to share embedding weights, covered under Shared vs Separate Vocabularies below.

What changed in 2017: before the Transformer, sequence-to-sequence models were typically encoder-decoder RNNs. Even attention-augmented variants (Bahdanau 2014, Luong 2015) still stepped through the sequence one token at a time. The Transformer replaces recurrence with pure attention + FFN, so during training the encoder and decoder each process a whole sequence as batched matrix multiplies. That work parallelizes well on GPUs, and it is a major reason trillion-parameter models became trainable.

Architecture Overview

At the highest level, the full Transformer is a three-stage pipeline. Source tokens flow through the encoder once; the resulting memory is then reused by every decoder step.

```text
Source token ids (B, T_src)
                            |
                  +---------v---------+
                  | src_embed * sqrt  |
                  |  + positional enc |
                  +---------+---------+
                            |
                            v
                  +---------------------+
                  |   Encoder (N x)     |     self-attn + FFN
                  |  - MHSA  - FFN      |     Add & Norm inside
                  |  stack: L1 -> ... L6|
                  +---------+-----------+
                            |
                            v          memory: (B, T_src, d_model)
  Target ids (B, T_tgt)     +-------------------------+
         |                                            |
         v                                            |
+---------------------+                               |
| tgt_embed * sqrt    |                               |
|  + positional enc   |                               |
+---------+-----------+                               |
          |                                           |
          v                                           v
+-----------------------------+          (K, V come from memory)
|     Decoder (N x)           |
|  - Masked MHSA              |
|  - Cross-Attn (Q from self, |
|    K/V from memory)         |
|  - FFN                      |
|  stack: L1 -> ... L6        |
+-------------+---------------+
              |
              v  hidden: (B, T_tgt, d_model)
   +----------+-----------+
   |  output_proj (Linear)|   d_model -> tgt_vocab
   +----------+-----------+
              |
              v
         logits (B, T_tgt, tgt_vocab)
```

Three things are worth staring at:

  1. The encoder runs exactly once per input sentence. During inference, its output (the memory) is computed and then reused across every decoder step.
  2. The decoder's cross-attention gets its $K, V$ from memory and its $Q$ from the decoder's own self-attention output. That is how target tokens "look at" the source.
  3. The final projection maps $d_{model} \to V_{tgt}$ (its weight $W_{\text{out}}$ is stored as a $V_{tgt} \times d_{model}$ matrix, as in nn.Linear), producing one logit per target vocab id at every position.
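To make points 2 and 3 concrete, here is a small shape check that uses PyTorch's nn.MultiheadAttention directly as a stand-in for one cross-attention sublayer. The random tensors dec_hidden and memory are placeholders for real activations, and the toy sizes are borrowed from the toy config used later in this section; treat it as a sketch, not the module built below.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
B, T_src, T_tgt, d_model, H, V_tgt = 1, 4, 3, 8, 2, 10

# Random placeholders: decoder self-attention output (queries) and encoder memory.
dec_hidden = torch.randn(B, T_tgt, d_model)
memory = torch.randn(B, T_src, d_model)

# Cross-attention: Q comes from the decoder, K and V both come from memory.
cross_attn = nn.MultiheadAttention(d_model, H, batch_first=True)
out, weights = cross_attn(query=dec_hidden, key=memory, value=memory)
print(out.shape)      # torch.Size([1, 3, 8]): one vector per target position
print(weights.shape)  # torch.Size([1, 3, 4]): each target row attends over 4 source tokens

# Vocabulary head: maps d_model -> V_tgt; nn.Linear stores the weight as (V_tgt, d_model).
output_proj = nn.Linear(d_model, V_tgt, bias=False)
print(output_proj.weight.shape)  # torch.Size([10, 8])
print(output_proj(out).shape)    # torch.Size([1, 3, 10])
```

The attention weights have one row per target position and one column per source position, which is exactly the "target looks at source" pattern described above.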

Forward Pass Equations

Written compactly, the end-to-end forward pass is just two lines. Let $\text{src} \in \mathbb{Z}^{B \times T_{src}}$ be source token ids and $\text{tgt} \in \mathbb{Z}^{B \times T_{tgt}}$ target token ids. Then

$$\text{memory} = \text{Encoder}(\text{src}, \text{src\_mask}), \qquad \text{logits} = \text{Decoder}(\text{tgt}, \text{memory}, \text{tgt\_mask}, \text{src\_mask}).$$

Here $\text{memory} \in \mathbb{R}^{B \times T_{src} \times d_{model}}$ and $\text{logits} \in \mathbb{R}^{B \times T_{tgt} \times V_{tgt}}$. The source mask $\text{src\_mask}$ (a key-padding mask) is reused in both the encoder's self-attention and the decoder's cross-attention. The target mask $\text{tgt\_mask}$ is the $T_{tgt} \times T_{tgt}$ causal mask from §2.
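A minimal sketch of building both masks in PyTorch. It assumes id 0 is the padding token (as in the toy code later in this section) and uses PyTorch's convention that True in a key-padding mask means "ignore this key"; the batch src is made up for illustration.

```python
import torch

PAD = 0  # assumption: id 0 is <pad>
src = torch.tensor([[11, 8, 2, 4, PAD, PAD],
                    [ 5, 3, 9, 1,   7,   6]])  # (B=2, T_src=6), second row unpadded
T_tgt = 3

# Key-padding mask (B, T_src): True = PAD = ignore.
# Reused by encoder self-attention AND decoder cross-attention.
src_mask = src == PAD

# Additive causal mask (T_tgt, T_tgt): 0 on/below the diagonal, -inf above it.
tgt_mask = torch.triu(torch.full((T_tgt, T_tgt), float('-inf')), diagonal=1)

print(src_mask)
print(tgt_mask)
```

Note the different shapes: the padding mask is per batch row, while the causal mask is one square matrix shared by every row.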

Internally, Decoder expands to

$$h_0 = \text{Dropout}\big(\text{Embed}(\text{tgt})\sqrt{d_{model}} + PE\big),$$
$$h_\ell = \text{DecoderLayer}_\ell(h_{\ell-1}, \text{memory}, \text{tgt\_mask}, \text{src\_mask}), \quad \ell = 1, \ldots, N,$$
$$\text{logits} = h_N W_{\text{out}}^{\top}.$$

The Encoder expands analogously (§5 of ch07).
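The expansion can be checked numerically. The sketch below, an illustration rather than this chapter's module, builds a tiny nn.TransformerDecoder, runs $h_0$ through decoder.layers one at a time, projects with W_out, and compares the result against a single call to the stack. Assumptions: random memory stands in for the encoder output, a zero tensor stands in for the positional table, there is no padding (so src_mask is omitted), and everything is in eval mode so dropout is the identity.

```python
import math
import torch
import torch.nn as nn

torch.manual_seed(0)
B, T_src, T_tgt, d_model, H, d_ff, N, V_tgt = 1, 4, 3, 8, 2, 16, 2, 10

tgt_embed = nn.Embedding(V_tgt, d_model)
decoder = nn.TransformerDecoder(
    nn.TransformerDecoderLayer(d_model, H, d_ff, batch_first=True), num_layers=N)
W_out = nn.Linear(d_model, V_tgt, bias=False)
dropout = nn.Dropout(0.1)
decoder.eval()
dropout.eval()  # eval mode: every Dropout becomes the identity

tgt = torch.randint(1, V_tgt, (B, T_tgt))
memory = torch.randn(B, T_src, d_model)  # stand-in for the encoder output
pe = torch.zeros(1, T_tgt, d_model)      # stand-in for the sinusoidal table
tgt_mask = torch.triu(torch.full((T_tgt, T_tgt), float('-inf')), diagonal=1)

# h_0 = Dropout(Embed(tgt) * sqrt(d_model) + PE)
h0 = dropout(tgt_embed(tgt) * math.sqrt(d_model) + pe)

# h_l = DecoderLayer_l(h_{l-1}, memory, tgt_mask) for l = 1..N (no padding -> no src_mask)
h = h0
for layer in decoder.layers:
    h = layer(h, memory, tgt_mask=tgt_mask)
logits_manual = W_out(h)  # logits = h_N W_out^T

# Same computation via one call to the stack (norm=None by default, so no final LayerNorm).
logits_stack = W_out(decoder(h0, memory, tgt_mask=tgt_mask))
print(torch.allclose(logits_manual, logits_stack, atol=1e-5))
```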


Training vs Inference

Training: teacher forcing, one forward pass per batch

During training we have the full reference translation in hand. Written with its boundary tokens it is $[\langle sos \rangle, y_1, \ldots, y_n, \langle eos \rangle]$. We feed $\text{tgt\_in} = [\langle sos \rangle, y_1, \ldots, y_n]$ into the decoder and ask it to predict $\text{tgt\_out} = [y_1, \ldots, y_n, \langle eos \rangle]$, the same sequence shifted left by one, so both have length $T_{tgt} = n + 1$. The causal mask ensures that position $t$ cannot peek at positions $> t$. One forward pass computes logits for ALL target positions simultaneously, and the loss is the mean cross-entropy between those logits and $\text{tgt\_out}$ (with padding positions ignored in a padded batch).
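A minimal sketch of the shift in PyTorch. The special-token ids PAD, SOS, EOS and the sentence y are hypothetical, and random logits stand in for the model's output; the point is how tgt_in and tgt_out line up and how F.cross_entropy consumes them.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
PAD, SOS, EOS = 0, 1, 2   # hypothetical special-token ids
y = [7, 4, 9]             # made-up reference sentence y_1..y_n (n = 3)

seq = torch.tensor([[SOS] + y + [EOS]])  # (1, n + 2)
tgt_in = seq[:, :-1]                     # [[1, 7, 4, 9]] -> decoder input
tgt_out = seq[:, 1:]                     # [[7, 4, 9, 2]] -> prediction targets

V_tgt = 10
logits = torch.randn(1, tgt_in.size(1), V_tgt)  # stand-in for the (B, T_tgt, V_tgt) model output

# F.cross_entropy expects (N, C) logits and (N,) class ids, so flatten batch and time.
loss = F.cross_entropy(logits.reshape(-1, V_tgt), tgt_out.reshape(-1), ignore_index=PAD)
print(tgt_in.tolist(), tgt_out.tolist(), round(loss.item(), 4))
```

ignore_index=PAD has no effect here because the example is unpadded, but it is what keeps padding out of the mean in a real batch.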

Why teacher forcing works: the causal mask turns what is morally $T_{tgt}$ separate autoregressive problems into one parallel matrix computation. Every target position sees only its left context, just as it would during inference (the one difference being that the context is the ground-truth prefix rather than the model's own earlier predictions), yet a whole sentence costs one forward pass instead of $T_{tgt}$ sequential ones.
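The left-context claim can be tested directly: with a causal mask, the output at position $t$ from one full pass should match the output at $t$ from a pass over only the first $t+1$ tokens. This is a sketch under stated assumptions: an untrained nn.TransformerDecoder in eval mode, with random tensors standing in for the target embeddings and the memory.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
B, T_src, T_tgt, d_model, H = 1, 4, 5, 8, 2

decoder = nn.TransformerDecoder(
    nn.TransformerDecoderLayer(d_model, H, dim_feedforward=16, batch_first=True),
    num_layers=2)
decoder.eval()  # no dropout, so outputs are deterministic

x = torch.randn(B, T_tgt, d_model)       # stand-in for target embeddings + PE
memory = torch.randn(B, T_src, d_model)  # stand-in for the encoder output

def causal(T):
    # (T, T) additive mask: 0 on/below the diagonal, -inf above it
    return torch.triu(torch.full((T, T), float('-inf')), diagonal=1)

# One teacher-forced pass over every target position at once...
full = decoder(x, memory, tgt_mask=causal(T_tgt))

# ...agrees with the "inference view": one pass per growing prefix.
for t in range(T_tgt):
    prefix_out = decoder(x[:, : t + 1], memory, tgt_mask=causal(t + 1))
    assert torch.allclose(full[:, t], prefix_out[:, t], atol=1e-5)
print("every position matches its prefix-only computation")
```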

Inference: autoregressive loop

At test time we don't have $y$. We start with $\langle sos \rangle$ and generate one token per step until we hit $\langle eos \rangle$ or $t = T_{\max}$. Each step runs the FULL decoder on the partial sequence, takes the last position's logits, picks a next token, and appends it.

This naive loop is $O(T_{tgt}^3)$ in work and $O(T_{tgt}^2)$ in peak memory (ignoring factors of $d_{model}$): step $t$ recomputes self-attention among all $t$ prefix positions, which costs $O(t^2)$, and summing over $t = 1, \ldots, T_{tgt}$ gives $O(T_{tgt}^3)$. Chapter 9 introduces beam search, temperature / top-k / top-p sampling, and, critically, KV-caching, which stores the $K, V$ of previous positions so step $t$ only pays $O(T_{tgt})$ attention work.
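A minimal greedy-decoding sketch of the naive loop, assuming hypothetical ids SOS=1 and EOS=2 and an untrained toy model assembled from nn.Embedding, nn.Transformer, and nn.Linear. Positional encodings are left out to keep it short, so this is not the chapter's Transformer module; it only shows the control flow: encode once, then re-run the full decoder on the growing prefix and keep the last position's argmax.

```python
import math
import torch
import torch.nn as nn

torch.manual_seed(0)
SOS, EOS = 1, 2  # hypothetical special-token ids
V_src, V_tgt, d_model = 12, 10, 8

# Throwaway, untrained toy model (NOT the chapter's Transformer class).
# Positional encodings are omitted purely to keep the sketch short.
src_embed = nn.Embedding(V_src, d_model)
tgt_embed = nn.Embedding(V_tgt, d_model)
core = nn.Transformer(d_model, nhead=2, num_encoder_layers=2, num_decoder_layers=2,
                      dim_feedforward=16, batch_first=True)
out_proj = nn.Linear(d_model, V_tgt, bias=False)
core.eval()  # disable dropout for deterministic decoding

@torch.no_grad()
def greedy_decode(src, max_len=6):
    # The encoder runs ONCE; its memory is reused at every step below.
    memory = core.encoder(src_embed(src) * math.sqrt(d_model))
    ys = torch.full((src.size(0), 1), SOS, dtype=torch.long)  # start with <sos>
    for _ in range(max_len):
        T = ys.size(1)
        causal = torch.triu(torch.full((T, T), float('-inf')), diagonal=1)
        # Naive step: re-run the FULL decoder over the whole prefix (no KV cache).
        h = core.decoder(tgt_embed(ys) * math.sqrt(d_model), memory, tgt_mask=causal)
        next_tok = out_proj(h[:, -1]).argmax(dim=-1, keepdim=True)  # last position, (B, 1)
        ys = torch.cat([ys, next_tok], dim=1)
        if (next_tok == EOS).all():  # simplification: stop only if every row emits <eos> now
            break
    return ys

print(greedy_decode(torch.tensor([[11, 8, 2, 4]])))
```

Because nothing is cached, step $t$ re-processes all $t$ prefix tokens, which is the cost that KV-caching removes. A batched implementation would also track which rows have already finished.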


Shared vs Separate Vocabularies

The Transformer has two embedding tables, $E_{src} \in \mathbb{R}^{V_{src} \times d_{model}}$ and $E_{tgt} \in \mathbb{R}^{V_{tgt} \times d_{model}}$, plus an output projection $W_{\text{out}} \in \mathbb{R}^{V_{tgt} \times d_{model}}$. There are two independent design axes: whether source and target share one embedding table, and whether the output projection is tied to the target embedding. The options are:

  • Shared source & target embedding ($E_{src} = E_{tgt}$): sensible when the two languages share a script and you train a joint BPE vocabulary over both corpora. It was used for En↔De and En↔Fr in the original paper, and suits code-completion models where input and output come from the same token space.
  • Weight tying ($W_{\text{out}} = E_{tgt}$): ties the output projection to the target-embedding table (Press & Wolf 2017). Saves $V_{tgt} \cdot d_{model}$ parameters and usually improves perplexity, because both matrices encode the same "which direction in hidden space represents token $v$?" relation.
  • Separate vocabularies: the natural choice when scripts or tokenizers differ (En↔Zh, En↔Ar). Here a shared BPE would spend capacity on subwords that appear on only one side of the corpus.

For the Multi30k translation project in chapter 13 we use separate source (de) and target (en) BPE vocabularies, with $V_{src} \approx 7000$ and $V_{tgt} \approx 5500$. That is the setting the code below targets.
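As a rough illustration of what weight tying saves at these sizes, the sketch below uses $V_{tgt} = 5500$ (the approximate Multi30k figure above) and the base $d_{model} = 512$, ties output_proj.weight to tgt_embed.weight by assigning the same nn.Parameter, and counts unique parameters. Tying is independent of the separate-vocabulary choice: it only involves the target side. The helper n_params is hypothetical, written for this example.

```python
import torch.nn as nn

V_tgt, d_model = 5500, 512  # approximate Multi30k target vocab; base-model width

tgt_embed = nn.Embedding(V_tgt, d_model)             # weight: (5500, 512)
output_proj = nn.Linear(d_model, V_tgt, bias=False)  # weight: (5500, 512) as well

def n_params(*modules):
    """Hypothetical helper: count unique parameters, so a shared tensor counts once."""
    unique = {id(p): p for m in modules for p in m.parameters()}
    return sum(p.numel() for p in unique.values())

before = n_params(tgt_embed, output_proj)  # 2 * 5500 * 512

# Weight tying: both weights are (V_tgt, d_model), so one Parameter can serve both roles.
output_proj.weight = tgt_embed.weight
after = n_params(tgt_embed, output_proj)   # 5500 * 512

print(before, after, before - after)  # 5632000 2816000 2816000
```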


Plain-Python End-to-End

Before introducing the Transformer module, let's watch the shapes move by hand. We use the shared chapter config: $B=1, T_{src}=4, T_{tgt}=3, d_{model}=8, H=2, d_{ff}=16, N=2$, $V_{src}=12, V_{tgt}=10$, and torch.manual_seed(0). The printed values below were computed by running this code.

Plain-Python End-to-End Forward Pass (toy shapes)
🐍toy_transformer_forward.py
import torch

PyTorch's top-level package. Gives us tensors, autograd, random-number generation, and the neural-network toolbox we build on top of.

EXECUTION STATE
torch = Core tensor library. Provides torch.Tensor, torch.randint, torch.zeros, torch.arange, torch.triu, torch.manual_seed — every primitive we need below.
import torch.nn as nn

torch.nn holds the building blocks: nn.Embedding, nn.Linear, nn.TransformerEncoderLayer, nn.TransformerDecoderLayer. Every learnable piece of the Transformer we assemble lives here.

EXECUTION STATE
nn = Short alias for torch.nn. Used as nn.Embedding(vocab, d_model), nn.Linear(in, out), etc.
import torch.nn.functional as F

Functional (stateless) API. We don't call F explicitly in this snippet but it's the standard import used when we call e.g. F.softmax(logits, dim=-1) after this cell.

EXECUTION STATE
F = Stateless counterparts of nn modules. F.softmax, F.relu, F.log_softmax — no parameters to own, pure functions.
import math

Python standard-library math module. We need math.sqrt(d_model) for embedding scaling and math.log(10000.0) for sinusoidal positional encoding.

EXECUTION STATE
math.sqrt(8) = 2.8284... used to scale embeddings per the original paper
torch.manual_seed(0)

Sets the global RNG seed so all random draws below (token ids, embedding weights, layer weights) are reproducible. Same seed used throughout §4 and §5 so readers can compare values across sections.

EXECUTION STATE
📚 torch.manual_seed(seed) = Seeds PyTorch's CPU RNG. Affects torch.randint, nn.Embedding weight init, nn.Linear weight init — everything that draws randomness later in the program.
⬇ arg: seed = 0 = Chosen to match §4/§5 so the same memory[0,0,:4] values come out here.
B, T_src, T_tgt = 1, 4, 3

Shared tiny shapes. Batch=1, source length 4 tokens, target length 3 tokens. Small enough that every intermediate tensor is inspectable.

EXECUTION STATE
B = 1 = Batch size. Only one sentence pair, so every shape starts with 1.
T_src = 4 = Source sequence length. Represents e.g. a 4-token German sentence after tokenization.
T_tgt = 3 = Target sequence length. Represents e.g. a 3-token English translation.
d_model, H, d_ff = 8, 2, 16

Toy model dims. d_model=8 is the hidden size, H=2 heads (so d_k=4 per head), d_ff=16 is the FFN inner dim. Real models use 512/8/2048 (ch13).

EXECUTION STATE
d_model = 8 = Width of every token vector throughout the stack. Input, output, and memory are all [B, T, 8].
H = 2 = Number of attention heads. d_k = d_model / H = 4 per head.
d_ff = 16 = Inner width of the FFN sublayer. The base Transformer uses 4 * d_model; this toy config uses 2 * d_model = 16 to stay small.
N = 2

Number of stacked layers in BOTH encoder and decoder. Vaswani 2017 used N=6; we use 2 so the walkthrough is cheap.

EXECUTION STATE
N = 2 = 2 encoder layers + 2 decoder layers, stacked identically.
src_vocab, tgt_vocab = 12, 10

Two SEPARATE vocabularies: 12 source types (de) and 10 target types (en). That is why we need two embedding tables — ids from different vocabularies are not comparable.

EXECUTION STATE
src_vocab = 12 = Size of the source-side vocabulary. Source token ids live in [0, 12).
tgt_vocab = 10 = Size of the target-side vocabulary. Target token ids live in [0, 10). Output logits will have 10 classes.
src = torch.randint(1, src_vocab, (B, T_src))

Draws a random source sentence of 4 token ids in [1, 12). We skip id 0 because that is conventionally a <pad> token.

EXECUTION STATE
📚 torch.randint(low, high, size) = Draws integers uniformly from [low, high). Returns a tensor of the given size. Example: torch.randint(1, 5, (2,)) might return tensor([2, 4]).
⬇ arg: low = 1 = Inclusive lower bound. We skip 0 (reserved for <pad>).
⬇ arg: high = 12 = Exclusive upper bound = src_vocab.
⬇ arg: size = (1, 4) = Output shape (B, T_src).
⬆ result: src =
tensor([[11, 8, 2, 4]]) shape (1, 4)
tgt = torch.randint(1, tgt_vocab, (B, T_tgt))

Random target sentence of 3 ids. In real training these come from the dataloader; here random is fine because we only want to demonstrate the forward pass plumbing.

EXECUTION STATE
⬆ result: tgt =
tensor([[8, 7, 8]]) shape (1, 3)
src_embed = nn.Embedding(src_vocab, d_model)

Source-side lookup table. Row i of its weight matrix is the 8-dim vector for source token id i. Total parameters: src_vocab * d_model = 12 * 8 = 96.

EXECUTION STATE
📚 nn.Embedding(num_embeddings, embedding_dim) = Stores a learnable (num_embeddings, embedding_dim) matrix. Calling it with a LongTensor of ids selects the corresponding rows. Faster than doing a one-hot-then-matmul.
⬇ arg: num_embeddings = 12 = How many rows the table has — one row per src vocab id.
⬇ arg: embedding_dim = 8 = Width of each row = d_model. Matches the rest of the stack so shapes line up.
⬆ src_embed.weight = (12, 8) random Gaussian init — 96 learnable parameters.
tgt_embed = nn.Embedding(tgt_vocab, d_model)

Target-side lookup table. SEPARATE from src_embed because the target vocabulary is different (en vs de). Total parameters: 10 * 8 = 80.

EXECUTION STATE
⬆ tgt_embed.weight = (10, 8) — 80 learnable parameters, independent from src_embed.
def make_pe(max_len, d)

Builds the fixed sinusoidal positional-encoding table from Vaswani 2017. Returns a (max_len, d) matrix where row p is the positional signature for position p.

EXECUTION STATE
⬇ input: max_len = 16 = How many positions to precompute. We only use the first T_src or T_tgt rows.
⬇ input: d = 8 = Same as d_model so PE can be added directly to embeddings.
⬆ returns = torch.Tensor (16, 8) — fixed, not learnable.
pe = torch.zeros(max_len, d)

Allocate a (16, 8) zero matrix we will fill with sin/cos values.

EXECUTION STATE
pe = (16, 8) — all zeros initially.
pos = torch.arange(0, max_len).float().unsqueeze(1)

Column vector [0,1,2,...,15]^T of floats. .unsqueeze(1) turns shape (16,) into (16, 1) so we can broadcast against div of shape (4,).

EXECUTION STATE
📚 .unsqueeze(1) = Inserts a new axis of size 1 at position 1. Shape (16,) -> (16, 1). Enables broadcasting: (16, 1) * (4,) -> (16, 4).
pos = (16, 1) column vector 0..15
div = exp(arange(0, d, 2) * (-log(10000)/d))

Frequency terms. For d=8 we need 4 frequencies (even dims 0,2,4,6). Higher dims -> lower frequencies. This gives us positions encoded at multiple scales.

EXECUTION STATE
torch.arange(0, 8, 2) = tensor([0, 2, 4, 6]) — even indices only
-math.log(10000.0) / 8 = -1.1513 ...
div = (4,) tensor of decreasing frequency scales
pe[:, 0::2] = torch.sin(pos * div)

Fills even columns with sin(pos * freq). pos(16,1) * div(4,) broadcasts to (16, 4), matching pe[:, 0::2] which is also (16, 4).

EXECUTION STATE
0::2 = Python slice: start=0, stop=end, step=2 — every EVEN index.
pe[:, 0::2].shape = (16, 4) — the 4 even columns of pe.
pe[:, 1::2] = torch.cos(pos * div)

Fills odd columns with cos(pos * freq). So even dims use sin, odd dims use cos — this is the classic sinusoidal scheme.

return pe

Hand back the (16, 8) position table.

EXECUTION STATE
⬆ return: pe = (16, 8) fixed sinusoidal positional encoding
pe = make_pe(16, d_model)

Build the table once; we slice it to T_src or T_tgt rows as needed.

src_x = src_embed(src) * sqrt(d_model) + pe[:T_src].unsqueeze(0)

Look up source embeddings, scale by sqrt(d_model)=2.828 (Vaswani trick — keeps embedding magnitude comparable to PE), then add positional encoding for positions 0..T_src-1.

EXECUTION STATE
src_embed(src) = Shape (1, 4, 8). Row i is src_embed.weight[src[0,i]].
* math.sqrt(d_model) = Scalar multiply by 2.8284. Scales embeddings up so adding PE (whose values are in [-1, 1]) doesn't dominate.
pe[:T_src] = First 4 rows of pe -> shape (4, 8).
.unsqueeze(0) = (4, 8) -> (1, 4, 8) so it broadcasts against (1, 4, 8) embeddings.
⬆ src_x =
(1, 4, 8). First row: src_x[0,0,:4] = [2.008, -3.342, -1.167, 3.733] (with seed=0).
enc_layer = nn.TransformerEncoderLayer(...)

Builds ONE encoder layer (self-attn + FFN + 2 LayerNorms). We use PyTorch's built-in for this demo; ch07 showed how to write it from scratch.

EXECUTION STATE
📚 nn.TransformerEncoderLayer = Implements: x = LayerNorm(x + SelfAttn(x)); x = LayerNorm(x + FFN(x)). Post-norm by default.
⬇ arg: d_model = 8 = Hidden width.
⬇ arg: nhead = 2 = Number of attention heads.
⬇ arg: dim_feedforward = 16 = FFN inner dim = d_ff.
⬇ arg: batch_first = True = Tensors are (B, T, d_model) instead of the older (T, B, d_model). We match this everywhere.
⬇ arg: activation = 'relu' = FFN nonlinearity. Vaswani 2017 used ReLU; modern LLMs use GELU/SwiGLU.
encoder = nn.TransformerEncoder(enc_layer, num_layers=N)

Clones enc_layer N=2 times with independent weights and stacks them. Output of layer k is the input to layer k+1.

EXECUTION STATE
📚 nn.TransformerEncoder(layer, num_layers) = Takes one template layer and a count; produces a stack via deepcopy so each layer has its own parameters.
memory = encoder(src_x)

Runs src_x through the N-layer encoder. Output shape: (B, T_src, d_model) = (1, 4, 8). This is the 'memory' the decoder will cross-attend to.

EXECUTION STATE
⬆ memory =
(1, 4, 8). memory[0,0,:4] = [0.730, 0.329, -0.952, 1.499] (seed=0).
→ semantic meaning = Each of the 4 rows is a context-enriched representation of one source token. Row 0 'knows' about tokens 1,2,3 via self-attention.
tgt_x = tgt_embed(tgt) * sqrt(d_model) + pe[:T_tgt].unsqueeze(0)

Same treatment for the target side. Separate embedding table, same PE table. Shape (1, 3, 8).

EXECUTION STATE
⬆ tgt_x =
(1, 3, 8) — target-side token+position vectors.
dec_layer = nn.TransformerDecoderLayer(...)

Decoder layer = masked self-attn + cross-attn + FFN + 3 LayerNorms. Same hyperparameters as encoder layer.

EXECUTION STATE
📚 nn.TransformerDecoderLayer = x = LN(x + MaskedSelfAttn(x)); x = LN(x + CrossAttn(x, memory)); x = LN(x + FFN(x)).
decoder = nn.TransformerDecoder(dec_layer, num_layers=N)

Stack of 2 independent decoder layers.

causal = torch.triu(ones(T_tgt, T_tgt) * -inf, diagonal=1)

Builds the (3, 3) additive causal mask used for target self-attention. Entries strictly above the diagonal become -inf; diagonal and below stay 0.

EXECUTION STATE
📚 torch.triu(input, diagonal=k) = Returns the upper-triangular part of input, zeroing everything below diagonal k. diagonal=1 skips the main diagonal itself.
⬇ arg: input = ones(3,3)*-inf = A 3x3 matrix full of -inf.
⬇ arg: diagonal = 1 = Keep values strictly above the main diagonal; everything on and below becomes 0.
⬆ causal (3, 3) =
[[0, -inf, -inf],
 [0,   0 , -inf],
 [0,   0 ,   0 ]]
h = decoder(tgt_x, memory, tgt_mask=causal)

Run target through the decoder stack, using memory as cross-attention K/V and causal as the self-attention mask.

EXECUTION STATE
⬇ arg: tgt_x = (1, 3, 8) — target embeddings+PE (the queries).
⬇ arg: memory = (1, 4, 8) — encoder output, used as K and V in the cross-attention sublayers.
⬇ arg: tgt_mask = causal = (3, 3) additive mask. Prevents position t from attending to positions > t.
⬆ h =
(1, 3, 8). h[0,0,:4] = [1.348, -0.131, -1.344, -0.364] (seed=0).
W_out = nn.Linear(d_model, tgt_vocab, bias=False)

Output projection from hidden space (8) to target vocabulary (10). No bias, the usual convention for vocabulary heads. Its weight has shape (10, 8), the same as tgt_embed.weight, so it could be tied to the embedding table (Press & Wolf 2017).

EXECUTION STATE
📚 nn.Linear(in, out, bias=False) = y = x @ W.T, no bias term. Parameters = in * out = 8 * 10 = 80.
⬇ arg: in_features = 8 = Matches d_model.
⬇ arg: out_features = 10 = Matches tgt_vocab. One logit per vocab id.
⬇ arg: bias = False = Standard for vocabulary heads (saves 10 parameters, no harm).
logits = W_out(h)

Per-position projection: (1, 3, 8) @ (8, 10) -> (1, 3, 10). logits[0, t, v] is the score the model assigns to vocab id v at target position t.

EXECUTION STATE
⬆ logits =
(1, 3, 10). logits[0,0] = [0.361, -0.603, 0.058, 0.903, 0.147, -0.001, -0.121, -0.979, 0.720, 0.566].
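→ next step = Training: compare to the shifted-left targets tgt_out with F.cross_entropy (tgt here plays the role of the decoder input). Inference: take the argmax over the vocabulary at the last position to pick the next token (ch09 covers beam search / sampling).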
print('memory[0,0,:4] =', ...)

Sanity print. Confirms encoder produced finite values.

EXECUTION STATE
output = memory[0,0,:4] = [0.7302, 0.3295, -0.9525, 1.4986]
print('logits[0,0] =', ...)

Sanity print for logits of the first target position. 10 numbers = scores over the 10-way target vocabulary.

EXECUTION STATE
output = logits[0,0] = [0.361, -0.603, 0.058, 0.903, 0.147, -0.001, -0.121, -0.979, 0.720, 0.566]
→ argmax = argmax = 3 -> predicted vocab id 3 for the first target step.
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

# ---- Shared tiny config (see §4 of this chapter) ----
torch.manual_seed(0)
B, T_src, T_tgt = 1, 4, 3
d_model, H, d_ff = 8, 2, 16
N = 2
src_vocab, tgt_vocab = 12, 10

# Random source / target token ids
src = torch.randint(1, src_vocab, (B, T_src))   # [[11, 8, 2, 4]]
tgt = torch.randint(1, tgt_vocab, (B, T_tgt))   # [[8, 7, 8]]

# Separate embedding tables (de vs en vocabularies)
src_embed = nn.Embedding(src_vocab, d_model)
tgt_embed = nn.Embedding(tgt_vocab, d_model)

# Sinusoidal positional encoding (fixed, not learned)
def make_pe(max_len, d):
    pe = torch.zeros(max_len, d)
    pos = torch.arange(0, max_len).float().unsqueeze(1)
    div = torch.exp(torch.arange(0, d, 2).float() * (-math.log(10000.0) / d))
    pe[:, 0::2] = torch.sin(pos * div)
    pe[:, 1::2] = torch.cos(pos * div)
    return pe

pe = make_pe(16, d_model)

# Stage 1: source side -> encoder
src_x = src_embed(src) * math.sqrt(d_model) + pe[:T_src].unsqueeze(0)

enc_layer = nn.TransformerEncoderLayer(
    d_model=d_model, nhead=H, dim_feedforward=d_ff,
    batch_first=True, activation='relu')
encoder = nn.TransformerEncoder(enc_layer, num_layers=N)
# Note: modules start in training mode, so dropout (p=0.1 by default) is active here.
memory = encoder(src_x)                         # [1, 4, 8]

# Stage 2: target side -> decoder (queries memory via cross-attn)
tgt_x = tgt_embed(tgt) * math.sqrt(d_model) + pe[:T_tgt].unsqueeze(0)

dec_layer = nn.TransformerDecoderLayer(
    d_model=d_model, nhead=H, dim_feedforward=d_ff,
    batch_first=True, activation='relu')
decoder = nn.TransformerDecoder(dec_layer, num_layers=N)

causal = torch.triu(torch.ones(T_tgt, T_tgt) * float('-inf'), diagonal=1)
h = decoder(tgt_x, memory, tgt_mask=causal)     # [1, 3, 8]

# Stage 3: project decoder hidden states to vocabulary logits
W_out = nn.Linear(d_model, tgt_vocab, bias=False)
logits = W_out(h)                               # [1, 3, 10]

print("memory[0,0,:4] =", [round(x, 4) for x in memory[0, 0, :4].tolist()])
print("logits[0,0]    =", [round(x, 3) for x in logits[0, 0].tolist()])
print("argmax logits[0,0] =", logits[0, 0].argmax().item())
```

Running this prints:

```text
memory[0,0,:4] = [0.7302, 0.3295, -0.9525, 1.4986]
logits[0,0]    = [0.361, -0.603, 0.058, 0.903, 0.147, -0.001, -0.121, -0.979, 0.72, 0.566]
argmax logits[0,0] = 3     # predicted vocab id for the first target position
```

The shape chain is exactly what the architecture diagram promised: source ids $(1, 4) \to$ memory $(1, 4, 8) \to$ decoder hidden $(1, 3, 8) \to$ logits $(1, 3, 10)$. PyTorch's built-in versions of every sublayer we wrote from scratch in the previous chapters are hidden inside the two stack calls, encoder(...) and decoder(...).


PyTorch Implementation

The Transformer module

Now the real thing: a single nn.Module that owns every learnable piece.

class Transformer(nn.Module) — wire encoder + decoder + embeddings
🐍transformer.py
import math

math.sqrt for embedding scaling and math.log for sinusoidal frequencies.

import torch

Core tensor library. Provides torch.Tensor, torch.arange, torch.zeros — all the primitives used in _build_pe.

import torch.nn as nn

nn.Module base class, nn.Embedding, nn.Linear, nn.Dropout, nn.TransformerEncoder{,Layer}, nn.TransformerDecoder{,Layer}.

from typing import Optional

Lets us type mask arguments as Optional[torch.Tensor] — either a tensor or None.

class Transformer(nn.Module)

Our module subclasses nn.Module so PyTorch can auto-track parameters, move them to GPU, save/load state dicts, and plug into DataParallel etc.

EXECUTION STATE
📚 nn.Module = Base class for neural net modules. Subclasses implement __init__ (register sublayers) and forward (compute). Parameters are discovered automatically via setattr.
def __init__(self, src_vocab, tgt_vocab, d_model=512, ...)

Constructor signature. Accepts both vocabulary sizes (two separate vocabs) and the standard Vaswani hyperparameters, with defaults matching the original paper.

EXECUTION STATE
⬇ src_vocab (int) = Size of source vocab (e.g. 7000 for Multi30k de BPE). Sets number of rows in self.src_embed.weight.
⬇ tgt_vocab (int) = Size of target vocab (e.g. 5500 for Multi30k en BPE). Sets rows in tgt_embed AND out_features of output_proj.
⬇ d_model = 512 = Hidden dim throughout. Every token is a 512-dim vector.
⬇ num_heads = 8 = Attention heads per layer. d_k = 512/8 = 64 per head.
⬇ num_layers = 6 = Stack depth — applied SYMMETRICALLY to encoder and decoder in this implementation.
⬇ d_ff = 2048 = FFN inner dim = 4 * d_model (Vaswani default).
⬇ max_len = 5000 = Precompute PE for up to 5000 positions. Caps the maximum sentence length the model can handle without rebuilding.
⬇ dropout = 0.1 = Applied after embedding+PE and inside each encoder/decoder sublayer. 0.1 is Vaswani default; 0.3 for lower-resource setups like Multi30k.
super().__init__()

Runs nn.Module's __init__ so parameter tracking, hooks, and children registry get set up BEFORE we assign submodules.

self.d_model = d_model

Stash d_model so encode() and decode() can scale embeddings by sqrt(d_model).

self.src_embed = nn.Embedding(src_vocab, d_model)

Source-side lookup table. src_vocab * d_model parameters (e.g. 7000 * 512 = 3.584M for Multi30k).

EXECUTION STATE
⬇ arg: num_embeddings = src_vocab = One row per source token id.
⬇ arg: embedding_dim = d_model = Row width = 512 so shapes line up with the rest of the stack.
self.tgt_embed = nn.Embedding(tgt_vocab, d_model)

Target-side lookup table. Separate from src_embed — the two languages have different vocabularies.

self.register_buffer('pe', self._build_pe(max_len, d_model))

Stores PE as a BUFFER (not a Parameter). Buffers move with .to(device) and get saved in state_dict, but they are NOT updated by the optimizer — exactly what we want for fixed sinusoidal PE.

EXECUTION STATE
📚 register_buffer(name, tensor) = Registers a non-trainable tensor with the module. Accessible as self.pe afterwards. Appears in state_dict by default.
⬇ arg: name = 'pe' = Attribute name; we'll read it as self.pe later.
⬇ arg: tensor = (1, max_len, d_model) precomputed sinusoidal PE.
self.dropout = nn.Dropout(dropout)

Applied after embedding+PE addition on both encoder and decoder inputs (Vaswani 2017, §5.4).

EXECUTION STATE
📚 nn.Dropout(p) = In training mode, sets each element to 0 with probability p and scales surviving ones by 1/(1-p). In eval mode (model.eval()), it's a no-op.
enc_layer = nn.TransformerEncoderLayer(d_model, num_heads, d_ff, dropout, batch_first=True)

Builds ONE encoder block: self-attn (not causally masked; only padding is masked) + FFN, each wrapped in Add&Norm. We reuse PyTorch's well-tested implementation here; ch07 built the same thing from scratch.

EXECUTION STATE
⬇ arg: d_model = 512 = Hidden dim.
⬇ arg: nhead = num_heads = 8 = Attention heads.
⬇ arg: dim_feedforward = d_ff = 2048 = FFN inner dim.
⬇ arg: dropout = 0.1 = Dropout inside the sublayer.
⬇ arg: batch_first = True = Expect tensors of shape (B, T, d_model). Without this PyTorch defaults to (T, B, d_model).
self.encoder = nn.TransformerEncoder(enc_layer, num_layers=num_layers)

Clones enc_layer num_layers times (each gets its own weights via deepcopy) and wires them in series.

EXECUTION STATE
📚 nn.TransformerEncoder(layer, num_layers) = Deep-copies the template layer num_layers times; forward runs them in order.
dec_layer = nn.TransformerDecoderLayer(d_model, num_heads, d_ff, dropout, batch_first=True)

One decoder block: masked self-attn + cross-attn (queries memory) + FFN, each wrapped in Add&Norm.

self.decoder = nn.TransformerDecoder(dec_layer, num_layers=num_layers)

Stack of num_layers decoder blocks, independent weights.

self.output_proj = nn.Linear(d_model, tgt_vocab, bias=False)

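Final projection from d_model to tgt_vocab classes. Its weight has shape (tgt_vocab, d_model), the same as tgt_embed.weight, which is what makes weight-tying possible; dropping the bias is the usual convention for vocabulary heads.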

EXECUTION STATE
⬇ arg: in_features = d_model = 512 = Matches encoder/decoder hidden size.
⬇ arg: out_features = tgt_vocab = One logit per target vocab id.
⬇ arg: bias = False = Convention for vocab heads. Not strictly required for weight-tying (output_proj.weight = tgt_embed.weight shares only the weight matrix), but it keeps the head a pure tied projection.
@staticmethod

Marks _build_pe as not needing self — it's a pure function. Lets us build PE before any instance state exists if we want.

def _build_pe(max_len, d_model)

Same sinusoidal table construction as in the plain-Python walkthrough above. Returns shape (1, max_len, d_model) so it broadcasts against (B, T, d_model) directly.

EXECUTION STATE
⬇ input: max_len = How many positions to precompute.
⬇ input: d_model = Must match the hidden dim used throughout.
⬆ returns = torch.Tensor (1, max_len, d_model).
pe = torch.zeros(max_len, d_model)

(5000, 512) zero matrix.

pos = torch.arange(0, max_len).float().unsqueeze(1)

Column vector of positions 0..max_len-1, shape (max_len, 1). .unsqueeze(1) turns (5000,) into (5000, 1) for broadcasting.

div = exp(arange(0, d_model, 2) * -log(10000)/d_model)

256 frequencies for d_model=512 (every other dim). Higher dims -> lower frequencies, giving multi-scale position signatures.

pe[:, 0::2] = torch.sin(pos * div)

Even columns filled with sin(pos * freq).

pe[:, 1::2] = torch.cos(pos * div)

Odd columns filled with cos(pos * freq).

return pe.unsqueeze(0)

.unsqueeze(0) turns (max_len, d_model) into (1, max_len, d_model) so we can add it directly to (B, T, d_model) embeddings.

EXECUTION STATE
⬆ return = (1, 5000, 512) fixed PE, stored as a buffer.
def encode(self, src, src_mask=None)

Runs the source-side pipeline: embed -> scale -> add PE -> dropout -> N encoder layers. Output is the 'memory' the decoder will cross-attend to.

EXECUTION STATE
⬇ input: src = LongTensor (B, T_src) of source token ids. Example: tensor([[11, 8, 2, 4]]).
⬇ input: src_mask = Optional BoolTensor (B, T_src). True means PAD -> ignore in attention. None = no masking.
⬆ returns = torch.Tensor (B, T_src, d_model) — encoder output = memory.
T = src.size(1)

Dynamic source length. We slice self.pe to this many positions instead of using all max_len.

x = self.src_embed(src) * sqrt(d_model) + self.pe[:, :T]

Lookup + Vaswani scale + add PE. Shapes: (B, T, d_model) + (1, T, d_model) broadcasts to (B, T, d_model).

EXECUTION STATE
self.src_embed(src) = (B, T, d_model) — per-token embedding rows.
* math.sqrt(self.d_model) = For d_model=512 this is sqrt(512) = 22.627. Scales embeddings up so PE doesn't dominate after addition.
self.pe[:, :T] = (1, T, d_model) — only the first T positions.
x = self.dropout(x)

Per-element dropout on embedding+PE. Acts as regularization at the very input layer.

return self.encoder(x, src_key_padding_mask=src_mask)

Runs all N encoder layers. src_key_padding_mask (B, T_src) tells each self-attention sublayer which source positions are padding and should receive -inf scores.

EXECUTION STATE
📚 src_key_padding_mask = Bool mask of shape (B, T_src). True -> position is PAD -> zero out its attention contribution. This is PyTorch's convention (True = ignore).
⬆ return = (B, T_src, d_model) encoder output a.k.a. memory.
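In practice src_mask is usually derived from the padded token ids. A minimal sketch, assuming right-padding and a hypothetical PAD_ID of 0 (use whatever id your tokenizer assigns to <pad>):

```python
import torch

PAD_ID = 0  # hypothetical <pad> id; depends on your tokenizer

# Two source sentences right-padded to T_src = 5.
src = torch.tensor([[11, 8, 2, 4, PAD_ID],
                    [7, 3, PAD_ID, PAD_ID, PAD_ID]])

# PyTorch key_padding_mask convention: True = PAD = ignore this key.
src_mask = src == PAD_ID          # BoolTensor, shape (B, T_src)
print(src_mask.shape)             # torch.Size([2, 5])
# Row 0: only the last position is True. Row 1: the last three are True.
```

The same src_mask is what forward() passes both to the encoder and, as memory_mask, to the decoder's cross-attention.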
def decode(self, tgt, memory, tgt_mask=None, memory_mask=None)

Mirror of encode() for the target side, plus cross-attention over memory.

EXECUTION STATE
⬇ input: tgt = LongTensor (B, T_tgt) of target token ids.
⬇ input: memory = (B, T_src, d_model) encoder output; used as K and V in cross-attention.
⬇ input: tgt_mask = (T_tgt, T_tgt) additive causal mask — prevents attending to future positions during training.
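⬇ input: memory_mask = (B, T_src) key-padding mask for cross-attention so PAD source positions are ignored. It is passed to PyTorch's memory_key_padding_mask argument, not to PyTorch's own memory_mask argument (which is a (T_tgt, T_src) attention mask).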
⬆ returns = (B, T_tgt, d_model) decoder output.
T = tgt.size(1)

Dynamic target length.

x = self.tgt_embed(tgt) * sqrt(d_model) + self.pe[:, :T]

Target-side embedding + position. It reuses the same fixed PE table as the source side (sinusoidal PE has no learnable parameters, so sharing it costs nothing); only the token embeddings differ.

x = self.dropout(x)

Dropout on target-side embedding+PE.

return self.decoder(x, memory, tgt_mask=tgt_mask, memory_key_padding_mask=memory_mask)

Runs all N decoder layers. tgt_mask is the (T_tgt, T_tgt) causal mask for self-attn; memory_key_padding_mask masks source PAD positions in cross-attn.

EXECUTION STATE
📚 tgt_mask vs memory_key_padding_mask = tgt_mask: additive, shape (T_tgt, T_tgt), masks FUTURE target positions (causal). memory_key_padding_mask: bool, shape (B, T_src), masks PAD source positions. Different roles, different shapes.
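One way to convince yourself that memory_key_padding_mask really hides PAD positions: overwrite the padded memory rows with large random values and check that the decoder output does not change. This sketch uses a tiny, randomly initialized nn.TransformerDecoder with dropout disabled; all sizes are arbitrary assumptions.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
B, T_src, T_tgt, d_model = 2, 5, 3, 16

# Tiny decoder, dropout off and eval mode so outputs are deterministic.
layer = nn.TransformerDecoderLayer(d_model, nhead=4, dim_feedforward=32,
                                   dropout=0.0, batch_first=True)
decoder = nn.TransformerDecoder(layer, num_layers=2).eval()

x = torch.randn(B, T_tgt, d_model)        # embedded target tokens
memory = torch.randn(B, T_src, d_model)   # encoder output

# Causal mask for self-attention: additive float, shape (T_tgt, T_tgt).
tgt_mask = torch.triu(torch.full((T_tgt, T_tgt), float("-inf")), diagonal=1)
# Padding mask for cross-attention: bool, shape (B, T_src), True = PAD.
memory_pad = torch.tensor([[False, False, False, True, True],
                           [False, False, False, False, False]])

out1 = decoder(x, memory, tgt_mask=tgt_mask, memory_key_padding_mask=memory_pad)

# Fill the two PAD rows of sentence 0 with large garbage values.
memory2 = memory.clone()
memory2[0, 3:] = 1e3 * torch.randn(2, d_model)
out2 = decoder(x, memory2, tgt_mask=tgt_mask, memory_key_padding_mask=memory_pad)

print(torch.allclose(out1, out2, atol=1e-5))  # True: PAD memory is invisible
```

Changing a non-PAD memory row instead (say memory[0, 0]) does change out1[0], which is the behavior you want from cross-attention.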
def forward(self, src, tgt, src_mask=None, tgt_mask=None)

Single-call end-to-end forward: source -> memory -> logits. Used during training (teacher forcing). Greedy generation calls encode() once and decode() at each step instead, so the source is not re-encoded at every step.

EXECUTION STATE
⬇ input: src = (B, T_src) source token ids.
⬇ input: tgt = (B, T_tgt) target token ids. Training: full shifted-right target. Inference: partial sequence generated so far.
⬇ input: src_mask = Optional (B, T_src) key-padding mask; reused by BOTH the encoder and the decoder's cross-attention.
⬇ input: tgt_mask = Optional (T_tgt, T_tgt) causal mask. Supplied during training; during greedy generation we build it dynamically.
⬆ returns = (B, T_tgt, tgt_vocab) logits tensor.
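The "shifted-right" target can be made concrete. Assuming each target row is stored as <sos> ... <eos> followed by padding (the special-token ids below are hypothetical), the decoder input drops the last column and the labels drop the first:

```python
import torch

SOS_ID, EOS_ID, PAD_ID = 1, 2, 0   # hypothetical special-token ids

# One padded target sentence: <sos> w1 w2 w3 <eos> <pad>
tgt = torch.tensor([[SOS_ID, 21, 22, 23, EOS_ID, PAD_ID]])

tgt_in = tgt[:, :-1]    # decoder input: <sos> w1 w2 w3 <eos>
tgt_out = tgt[:, 1:]    # labels:        w1 w2 w3 <eos> <pad>

# Causal mask sized to the decoder input, not to the full target.
T = tgt_in.size(1)
tgt_mask = torch.triu(torch.full((T, T), float("-inf")), diagonal=1)

# Position t of tgt_in may look at tokens 0..t and must predict tgt_out[:, t].
print(tgt_in.tolist())   # [[1, 21, 22, 23, 2]]
print(tgt_out.tolist())  # [[21, 22, 23, 2, 0]]
```

During training you would call model(src, tgt_in, src_mask, tgt_mask) and compare the resulting logits with tgt_out; the final PAD label is excluded from the loss.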
memory = self.encode(src, src_mask=src_mask)

Stage 1: source -> memory. Runs exactly once per forward call (and, during inference, exactly once per sentence).

hidden = self.decode(tgt, memory, tgt_mask=tgt_mask, memory_mask=src_mask)

Stage 2: target + memory -> hidden. src_mask is fed as memory_mask here because the decoder's cross-attention uses memory as its keys and values, so the same source PAD positions must be masked there too.

logits = self.output_proj(hidden)

Stage 3: project to vocab logits. Shape (B, T_tgt, tgt_vocab). These get passed to F.cross_entropy during training.

EXECUTION STATE
⬆ logits = (B, T_tgt, tgt_vocab). argmax over last dim = greedy predicted token at each position.
return logits

Returns logits. NOTE: we do not apply softmax here — F.cross_entropy expects raw logits and applies log_softmax internally for numerical stability.
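To see what "log_softmax internally" means, this sketch compares F.cross_entropy on flattened logits with an explicit log_softmax plus negative log-likelihood. It assumes PAD labels are excluded via ignore_index with a hypothetical PAD_ID of 0, and uses random logits in place of a real model output.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, T, V, PAD_ID = 2, 4, 10, 0        # PAD_ID is a hypothetical choice

logits = torch.randn(B, T, V)        # stand-in for model(src, tgt_in, ...)
labels = torch.tensor([[3, 5, 7, PAD_ID],
                       [4, 4, PAD_ID, PAD_ID]])

# F.cross_entropy expects (N, V) logits and (N,) labels: flatten B and T.
loss = F.cross_entropy(logits.reshape(-1, V), labels.reshape(-1),
                       ignore_index=PAD_ID)

# Same value by hand: log_softmax, pick the label's log-prob, average non-PAD.
logp = F.log_softmax(logits, dim=-1)
nll = -logp.gather(-1, labels.unsqueeze(-1)).squeeze(-1)   # (B, T)
keep = labels != PAD_ID
manual = nll[keep].mean()

print(torch.allclose(loss, manual))  # True
```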

```python
import math
import torch
import torch.nn as nn
from typing import Optional


class Transformer(nn.Module):
    """Full Vaswani 2017 encoder-decoder Transformer.

    Wires:
      - source token embedding + positional encoding
      - N encoder layers  -> memory
      - target token embedding + positional encoding
      - N decoder layers (queries memory)       -> hidden
      - output projection to target vocab logits
    """

    def __init__(
        self,
        src_vocab: int,
        tgt_vocab: int,
        d_model: int = 512,
        num_heads: int = 8,
        num_layers: int = 6,
        d_ff: int = 2048,
        max_len: int = 5000,
        dropout: float = 0.1,
    ):
        super().__init__()
        self.d_model = d_model

        # Separate embedding tables: src_vocab != tgt_vocab (translation case)
        self.src_embed = nn.Embedding(src_vocab, d_model)
        self.tgt_embed = nn.Embedding(tgt_vocab, d_model)

        # Fixed sinusoidal positional encoding (shared across source & target)
        self.register_buffer("pe", self._build_pe(max_len, d_model))
        self.dropout = nn.Dropout(dropout)

        # Encoder: PyTorch's built-in counterpart of the ch07 block
        enc_layer = nn.TransformerEncoderLayer(
            d_model, num_heads, d_ff, dropout, batch_first=True)
        self.encoder = nn.TransformerEncoder(enc_layer, num_layers=num_layers)

        # Decoder: PyTorch's built-in counterpart of the §4/§5 block
        dec_layer = nn.TransformerDecoderLayer(
            d_model, num_heads, d_ff, dropout, batch_first=True)
        self.decoder = nn.TransformerDecoder(dec_layer, num_layers=num_layers)

        # Output projection to target vocabulary (no bias)
        self.output_proj = nn.Linear(d_model, tgt_vocab, bias=False)

    @staticmethod
    def _build_pe(max_len: int, d_model: int) -> torch.Tensor:
        pe = torch.zeros(max_len, d_model)
        pos = torch.arange(0, max_len).float().unsqueeze(1)
        div = torch.exp(torch.arange(0, d_model, 2).float()
                        * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(pos * div)
        pe[:, 1::2] = torch.cos(pos * div)
        return pe.unsqueeze(0)  # (1, max_len, d_model)

    def encode(self, src: torch.Tensor,
               src_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        T = src.size(1)
        x = self.src_embed(src) * math.sqrt(self.d_model) + self.pe[:, :T]
        x = self.dropout(x)
        return self.encoder(x, src_key_padding_mask=src_mask)

    def decode(self, tgt: torch.Tensor, memory: torch.Tensor,
               tgt_mask: Optional[torch.Tensor] = None,
               memory_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        T = tgt.size(1)
        x = self.tgt_embed(tgt) * math.sqrt(self.d_model) + self.pe[:, :T]
        x = self.dropout(x)
        return self.decoder(
            x, memory,
            tgt_mask=tgt_mask,
            memory_key_padding_mask=memory_mask,  # (B, T_src) PAD mask
        )

    def forward(
        self,
        src: torch.Tensor,
        tgt: torch.Tensor,
        src_mask: Optional[torch.Tensor] = None,
        tgt_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        memory = self.encode(src, src_mask=src_mask)
        hidden = self.decode(tgt, memory,
                             tgt_mask=tgt_mask,
                             memory_mask=src_mask)
        logits = self.output_proj(hidden)  # (B, T_tgt, tgt_vocab)
        return logits
```

A couple of design choices worth calling out:

  • PE as a buffer, not a Parameter. We want PE to travel with the module via .to(device) and state_dict, but we do NOT want the optimizer to update it.
  • src_mask is fed twice: once as src_key_padding_mask for the encoder, again as memory_key_padding_mask for the decoder's cross-attention. Both places need to know which source positions are padding.
  • No softmax in forward. We return raw logits so that F.cross_entropy can apply $\log\text{-softmax}$ internally with better numerics.
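The buffer-versus-parameter distinction in the first bullet is easy to check on a toy module. WithPE is a hypothetical stand-in that registers its table the same way Transformer registers pe:

```python
import torch
import torch.nn as nn


class WithPE(nn.Module):
    """Hypothetical toy module: one learnable table, one fixed buffer."""

    def __init__(self, max_len: int = 10, d_model: int = 4):
        super().__init__()
        self.embed = nn.Embedding(7, d_model)                         # learnable
        self.register_buffer("pe", torch.zeros(1, max_len, d_model))  # fixed


m = WithPE()
param_names = [name for name, _ in m.named_parameters()]
state_keys = list(m.state_dict().keys())
print(param_names)         # ['embed.weight']: all an optimizer would receive
print(sorted(state_keys))  # ['embed.weight', 'pe']: saved and loaded together

m = m.to(torch.float64)    # .to() converts buffers as well as parameters
print(m.pe.dtype)          # torch.float64
```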

Greedy generation

For the simplest possible inference, we pair forward with a generate method that loops over steps. This belongs on the same module so it can reuse self.encode / self.decode.

Transformer.generate — greedy autoregressive decoding
🐍transformer_generate.py
@torch.no_grad()

Decorator that disables autograd for this method. During inference we don't need gradients, and skipping them saves memory and time.

EXECUTION STATE
📚 torch.no_grad = Context manager / decorator that disables gradient tracking for all ops inside: no autograd graph is recorded and their outputs have requires_grad=False. Same effect as wrapping the body in `with torch.no_grad(): ...`.
def generate(self, src, max_len, sos_id, eos_id, src_mask=None)

Greedy decoding loop. At each step: run the decoder on the tokens generated so far, take the argmax of the LAST position's logits, append it, and repeat until max_len or eos.

EXECUTION STATE
⬇ input: src = (B, T_src) source token ids.
⬇ input: max_len = Hard cap on output length, counting the leading <sos>. Must be at least 2 to generate any token.
⬇ input: sos_id = Start-of-sequence token id. First target column is filled with this — it's what the decoder conditions on to produce token #1.
⬇ input: eos_id = End-of-sequence token id. We stop early only if EVERY batch element emits eos at the same step.
⬇ input: src_mask = Source padding mask, passed through to encode() and cross-attn.
⬆ returns = (B, <=max_len) LongTensor of generated ids, including <sos>.
self.eval()

Switch to eval mode. Disables dropout, which would otherwise inject random noise into decoding; critical for deterministic greedy decoding. The switch persists after generate() returns, so call self.train() before resuming training.

EXECUTION STATE
📚 self.eval() = Sets self.training = False on this module and all submodules. Modules check self.training to branch: Dropout becomes the identity and BatchNorm would switch to its running statistics; LayerNorm (the only normalization in this model) behaves identically in both modes.
device = src.device

Grab the device (cpu or cuda:k) of the input. We'll build any new tensors (ys, causal mask) on the same device to avoid host-device copies.

B = src.size(0)

Batch size. We decode all B sequences in parallel.

memory = self.encode(src, src_mask=src_mask)

Encoder runs ONCE. Memory is the same for every decoder step — no need to recompute.

EXECUTION STATE
→ why once? = src doesn't change during decoding, so its encoder output is constant. This is the single biggest efficiency trick for encoder-decoder inference.
ys = torch.full((B, 1), sos_id, dtype=torch.long, device=device)

Seed the output with <sos>. ys starts as (B, 1) and grows by one column per step.

EXECUTION STATE
📚 torch.full(size, fill_value, dtype, device) = Creates a tensor of given size filled with fill_value. Like torch.zeros but with a custom constant.
⬇ arg: size = (B, 1) = One sos per batch element.
⬇ arg: fill_value = sos_id = The start token's integer id.
⬇ arg: dtype = torch.long = int64, the standard index dtype for embedding lookups (nn.Embedding expects integer indices).
⬇ arg: device = device = Same device as src so we avoid cpu-gpu copies.
for _ in range(max_len - 1)

Loop up to max_len-1 times (we already have <sos> accounting for 1). If a step emits eos for every batch element we break early.

LOOP TRACE · 3 iterations
step 0
ys before = (B, 1) — just <sos>
T = 1
ys after = (B, 2) — <sos> + first predicted token
step 1
ys before = (B, 2)
T = 2
ys after = (B, 3)
step k
ys before = (B, k+1)
T = k+1
ys after = (B, k+2)
break? = If next_id == eos for ALL B rows, break. Otherwise continue.
T = ys.size(1)

Current target length (grows by 1 each step).

causal = torch.triu(ones(T, T, device=device) * -inf, diagonal=1)

Build a fresh causal mask of size T x T. We rebuild each step because T grows.

EXECUTION STATE
📚 torch.triu(input, diagonal=1) = Keeps entries strictly above the diagonal; zeros out the rest. With input full of -inf, we get -inf above diagonal, 0 elsewhere — exactly the additive causal mask.
⬇ arg: input = ones(T,T,device)*-inf = Full -inf matrix on the right device.
⬇ arg: diagonal = 1 = Skip the main diagonal so position t CAN attend to itself, but NOT to t+1, t+2, ...
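A quick way to see that this additive mask does what the annotation claims: add it to a score matrix and apply softmax. This sketch uses random scores as a stand-in for Q K^T / sqrt(d_k):

```python
import torch

torch.manual_seed(0)
T = 4
causal = torch.triu(torch.ones(T, T) * float("-inf"), diagonal=1)

scores = torch.randn(T, T)                        # stand-in for Q K^T / sqrt(d_k)
weights = torch.softmax(scores + causal, dim=-1)  # row t = query position t

# Everything above the diagonal (the future) gets exactly zero weight,
# and each row is still a probability distribution over positions 0..t.
print(torch.triu(weights, diagonal=1).abs().sum().item())  # 0.0
print(weights[0, 0].item())                                # 1.0: position 0 sees only itself
print(torch.allclose(weights.sum(dim=-1), torch.ones(T)))  # True
```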
hidden = self.decode(ys, memory, tgt_mask=causal, memory_mask=src_mask)

One decoder forward pass on the full partial sequence. NOTE: without KV-caching (ch09) this recomputes attention for all previous positions every step — O(T^2) per step, O(T^3) total.
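To put numbers on that complexity claim, this sketch counts query-key dot products in a single self-attention head while generating T tokens, with and without a K/V cache. It deliberately ignores the FFN, cross-attention, and the d_model factor, so it only illustrates the growth rates (roughly T^3/3 versus T^2/2). The function name is hypothetical.

```python
def attention_scores_computed(T: int, kv_cache: bool) -> int:
    """Query-key dot products in ONE self-attention head over T greedy steps.

    Without a cache, step t re-runs all t query positions against t keys: t*t.
    With a K/V cache, step t runs only the newest query against t keys: t.
    """
    return sum(t if kv_cache else t * t for t in range(1, T + 1))


for T in (10, 100, 1000):
    print(T, attention_scores_computed(T, False), attention_scores_computed(T, True))
# 10 385 55
# 100 338350 5050
# 1000 333833500 500500
```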

step_logits = self.output_proj(hidden[:, -1])

Only the LAST position's hidden vector matters — it represents the next token prediction.

EXECUTION STATE
hidden[:, -1] = (B, d_model) — last target position's hidden state.
self.output_proj(...) = Linear from d_model -> tgt_vocab.
⬆ step_logits = (B, tgt_vocab)
next_id = step_logits.argmax(dim=-1, keepdim=True)

Greedy pick: choose the single highest-scoring vocab id. For beam search / sampling, see ch09.

EXECUTION STATE
📚 .argmax(dim=-1, keepdim=True) = Returns the index of the max value along the given dim. keepdim=True preserves the reduced axis as size 1 so shapes line up for torch.cat.
⬇ arg: dim = -1 = Reduce over the vocab dimension (last axis).
⬇ arg: keepdim = True = (B, tgt_vocab) -> (B, 1) instead of (B,) so we can concat.
⬆ next_id = (B, 1) LongTensor of predicted next-token ids.
ys = torch.cat([ys, next_id], dim=1)

Append the newly predicted column. ys grows from (B, T) to (B, T+1).

EXECUTION STATE
📚 torch.cat(tensors, dim) = Concatenate a list of tensors along the given dim. Sizes must match on all other dims.
⬇ arg: tensors = [ys, next_id] = (B, T) and (B, 1). Match on dim 0.
⬇ arg: dim = 1 = Concatenate along sequence length.
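if (next_id == eos_id).all(): break

Early-stop only when every batch element emits eos at this same step; .all() returns True only if all elements of the boolean tensor are True. A row that emitted eos at an earlier step keeps generating, so each row must still be truncated at its first eos afterwards.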

EXECUTION STATE
📚 .all() = Reduces a boolean tensor to a single Python bool: True iff every entry is True.
(next_id == eos_id) = (B, 1) BoolTensor — True where the prediction was eos.
return ys

Final output: (B, <=max_len) LongTensor containing <sos> + generated ids.

EXECUTION STATE
⬆ return = In translation, you'd now strip <sos> (and everything past eos) and decode via the target tokenizer.
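For that stripping step, a minimal pure-Python sketch, assuming one row of ys converted with .tolist(); strip_special is a hypothetical helper, not part of the chapter's code:

```python
def strip_special(ids: list[int], sos_id: int, eos_id: int) -> list[int]:
    """Drop a leading <sos> and everything from the first <eos> onward."""
    if ids and ids[0] == sos_id:
        ids = ids[1:]
    if eos_id in ids:
        ids = ids[: ids.index(eos_id)]
    return ids


# A row that hit <eos> early and kept generating afterwards:
print(strip_special([1, 21, 22, 2, 9, 2], sos_id=1, eos_id=2))  # [21, 22]
```

Apply it per row, e.g. [strip_special(row, sos_id, eos_id) for row in ys.tolist()], before handing the ids to the target tokenizer.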
```python
    @torch.no_grad()
    def generate(
        self,
        src: torch.Tensor,
        max_len: int,
        sos_id: int,
        eos_id: int,
        src_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Greedy autoregressive decoding. O(max_len * forward).

        For beam search, top-k / top-p sampling, and KV-cache acceleration,
        see chapter 9.
        """
        self.eval()  # disables dropout; the module stays in eval mode afterwards
        device = src.device
        B = src.size(0)

        # 1) Encode source ONCE. Memory is reused for every step.
        memory = self.encode(src, src_mask=src_mask)

        # 2) Start with <sos> for every batch element.
        ys = torch.full((B, 1), sos_id, dtype=torch.long, device=device)

        for _ in range(max_len - 1):
            T = ys.size(1)
            causal = torch.triu(
                torch.ones(T, T, device=device) * float("-inf"),
                diagonal=1,
            )
            hidden = self.decode(ys, memory,
                                 tgt_mask=causal,
                                 memory_mask=src_mask)
            step_logits = self.output_proj(hidden[:, -1])       # (B, tgt_vocab)
            next_id = step_logits.argmax(dim=-1, keepdim=True)  # (B, 1)
            ys = torch.cat([ys, next_id], dim=1)

            # Fires only if every row emits eos at THIS step; rows that
            # finished earlier keep generating, so truncate at each first eos.
            if (next_id == eos_id).all():
                break

        return ys
```
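The early-stop test (next_id == eos_id).all() only fires when every row emits eos at the same step; a row that finished earlier keeps producing tokens. A common alternative is to track a per-row finished flag and pad rows after they finish. The sketch below shows that variant under stated assumptions: step_logits_fn is a hypothetical callable standing in for self.output_proj(self.decode(ys, memory, ...)[:, -1]), tensors live on the CPU (pass a device in real use), and fake_step is a toy model used only to exercise the loop.

```python
import torch
from typing import Callable


def greedy_with_finished(step_logits_fn: Callable[[torch.Tensor], torch.Tensor],
                         batch_size: int, max_len: int,
                         sos_id: int, eos_id: int, pad_id: int) -> torch.Tensor:
    """Greedy loop that stops once every row has emitted eos at ANY step.

    step_logits_fn(ys) returns (B, V) next-token logits for the prefix ys.
    Rows that already finished receive pad_id instead of a new prediction.
    """
    ys = torch.full((batch_size, 1), sos_id, dtype=torch.long)
    finished = torch.zeros(batch_size, dtype=torch.bool)
    for _ in range(max_len - 1):
        next_id = step_logits_fn(ys).argmax(dim=-1)                     # (B,)
        next_id = torch.where(finished, torch.full_like(next_id, pad_id), next_id)
        ys = torch.cat([ys, next_id.unsqueeze(1)], dim=1)
        finished |= next_id == eos_id                                   # sticky flag
        if finished.all():
            break
    return ys


def fake_step(ys: torch.Tensor) -> torch.Tensor:
    """Toy model: row 0 emits eos when the prefix has length 2, row 1 at length 4."""
    B, T = ys.shape
    logits = torch.zeros(B, 6)
    logits[0, 2 if T == 2 else 5] = 1.0
    logits[1, 2 if T == 4 else 3] = 1.0
    return logits


out = greedy_with_finished(fake_step, 2, max_len=10, sos_id=1, eos_id=2, pad_id=0)
print(out.tolist())  # [[1, 5, 2, 0, 0], [1, 3, 3, 3, 2]]
```

With the all-rows-at-once test, this toy batch would run all the way to max_len, because the two rows never emit eos on the same step.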
Forward reference: greedy decoding is the cheapest, least-diverse strategy. Chapter 9 covers beam search (keep top-$k$ partial hypotheses), sampling (temperature, top-$k$, top-$p$, typical decoding), and KV-caching (reuse past $K, V$ so each step is $O(T)$ instead of $O(T^2)$).

Parameter Count

Let's sanity-check the model size for the ch13 Multi30k config: $d_{model}=512$, $H=8$, $d_{ff}=2048$, $N=6$, $V_{src}=7000$, $V_{tgt}=5500$. We count each learnable piece.

| Component | Formula | Parameters |
|---|---|---|
| One encoder layer (MHSA + FFN + 2 LN) | 4*(d*d+d) + d*dff + dff + dff*d + d + 2*2*d | 3,152,384 |
| Encoder stack (N=6) | 6 x encoder layer | 18,914,304 |
| One decoder layer (MHSA + Cross + FFN + 3 LN) | 2 * 4*(d*d+d) + d*dff + dff + dff*d + d + 3*2*d | 4,204,032 |
| Decoder stack (N=6) | 6 x decoder layer | 25,224,192 |
| Source embedding | V_src * d_model = 7000 * 512 | 3,584,000 |
| Target embedding | V_tgt * d_model = 5500 * 512 | 2,816,000 |
| Output projection (bias=False, as in the code) | d_model * V_tgt = 512 * 5500 | 2,816,000 |
| Total | sum of rows above (stacks, embeddings, projection) | 53,354,496 (~53.4M) |

About 53M parameters, in the same ballpark as the original Vaswani base model (65M) and a bit smaller because our vocabularies are smaller than WMT's. For reference, Multi30k is a low-resource dataset, and you'll usually train a smaller variant ($d_{model}=256$, $N=3$, roughly 10–15M parameters) to avoid overfitting.
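The table's arithmetic can be reproduced in a few lines of plain Python. This sketch assumes PyTorch's built-in layers (biases in every projection, one LayerNorm with weight and bias per sublayer), an output_proj without bias as in the class above, and excludes the fixed PE buffer; the function name is hypothetical. For the smaller variant it additionally assumes $d_{ff}=1024$, which the text does not specify.

```python
def transformer_params(d: int, d_ff: int, n_enc: int, n_dec: int,
                       v_src: int, v_tgt: int, out_bias: bool = False) -> dict:
    """Learnable-parameter count for the encoder-decoder Transformer above."""
    attn = 4 * (d * d + d)                  # Q, K, V, O projections with biases
    ffn = d * d_ff + d_ff + d_ff * d + d    # two Linear layers with biases
    ln = 2 * d                              # LayerNorm gamma + beta
    enc_layer = attn + ffn + 2 * ln
    dec_layer = 2 * attn + ffn + 3 * ln     # self-attn + cross-attn
    counts = {
        "encoder": n_enc * enc_layer,
        "decoder": n_dec * dec_layer,
        "src_embed": v_src * d,
        "tgt_embed": v_tgt * d,
        "output_proj": d * v_tgt + (v_tgt if out_bias else 0),
    }
    counts["total"] = sum(counts.values())
    return counts


print(transformer_params(512, 2048, 6, 6, 7000, 5500)["total"])  # 53354496
print(transformer_params(256, 1024, 3, 3, 7000, 5500)["total"])  # 10137600
```

On a constructed model, sum(p.numel() for p in model.parameters()) should give the same total. Under the $d_{ff}=1024$ assumption the small variant lands at about 10.1M, at the low end of the quoted range.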

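Two parameter-sharing tricks cut this further:

  • Weight tying $W_{\text{out}} = E_{tgt}$: removes the ~2.82M output-projection parameters. Mathematically $W_{\text{out}}$ has shape $(d_{model}, V_{tgt})$ and $E_{tgt}$ has shape $(V_{tgt}, d_{model})$, so one is the other's transpose; in PyTorch, nn.Linear already stores its weight as (out_features, in_features) $= (V_{tgt}, d_{model})$, so the two tensors have identical shapes and can be shared directly.
  • Shared source+target embedding (when vocabularies are merged via joint BPE): one table of size $V_{joint} \times d_{model}$ replaces the two separate tables, saving $(V_{src} + V_{tgt} - V_{joint}) \cdot d_{model}$ parameters. If the joint vocabulary is no larger than $\max(V_{src}, V_{tgt})$, that is at least $\min(V_{src}, V_{tgt}) \cdot d_{model}$.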
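Weight tying can be checked directly. In this sketch the two modules are built standalone with the Multi30k sizes (an illustration, not the full model); wrapping them in an nn.ModuleDict just lets parameters() count them together, and parameters() yields a shared Parameter only once:

```python
import torch.nn as nn

d_model, v_tgt = 512, 5500
tgt_embed = nn.Embedding(v_tgt, d_model)             # weight: (v_tgt, d_model)
output_proj = nn.Linear(d_model, v_tgt, bias=False)  # weight: (v_tgt, d_model) too

head = nn.ModuleDict({"tgt_embed": tgt_embed, "output_proj": output_proj})
before = sum(p.numel() for p in head.parameters())

# Tie: both modules now hold the SAME Parameter object.
output_proj.weight = tgt_embed.weight
after = sum(p.numel() for p in head.parameters())    # shared Parameter counted once

print(before, after)  # 5632000 2816000
```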

Interactive Visualization

The visualization below animates the full pipeline on a toy German→English pair, "Der Hund läuft" → "The dog runs". Watch the source tokens flow through the encoder stack, deposit as memory blocks, and then get queried by each decoder step through glowing cross-attention lines. Hover a line to see its (illustrative) attention weight.

What to look for: (1) the encoder colors intensify as the source tokens flow through; (2) memory blocks crystallize once the encoder finishes; (3) each target token's cross-attention lines concentrate on its aligned source word: "The"→"Der", "dog"→"Hund", "runs"→"läuft". These alignment patterns are the hallmark of a trained translation model.

In Modern Systems

Encoder-decoder: T5, BART, mBART, NLLB

The original Vaswani architecture is still the right default for sequence-to-sequence tasks where the full input is available up front and the output is a separate sequence. T5 (Raffel et al. 2019) reframes every NLP task as text-to-text and trains a shared encoder-decoder on a denoising objective. BART (Lewis et al. 2019) uses a similar shape for summarization and infilling. NLLB-200 (Meta 2022) extends multilingual encoder-decoder translation, in the line of mBART and M2M-100, to 200 languages. It remains an encoder-decoder because the decoder benefits from attending to the FULL source regardless of where it is in generation.

Decoder-only: GPT, Llama, Mistral, PaLM

Modern general-purpose LLMs drop the encoder entirely. They are single decoder stacks with masked self-attention only — no cross-attention, no separate vocabulary. Input and output are concatenated into one token stream, and training is plain next-token prediction on web-scale corpora. Inference splits into prefill (process the prompt in parallel) and decode (generate tokens one at a time with KV-caching). This architecture won for general LLMs because (1) training data is unlabeled text, no source/target split needed; (2) one stack is cheaper than two; (3) the prefill/decode split maps cleanly onto GPU systems.
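For a seq2seq task like translation, "one token stream" means packing source and target around a separator. A minimal sketch, assuming a hypothetical separator id and masking the source-side labels with -100 (the default ignore_index of F.cross_entropy) so that only target tokens contribute to the loss; masking the prompt is a common choice, not a requirement:

```python
def pack_for_decoder_only(src_ids: list[int], tgt_ids: list[int],
                          sep_id: int, eos_id: int,
                          ignore: int = -100) -> tuple[list[int], list[int]]:
    """Build (input_ids, labels) for next-token training on: src <sep> tgt <eos>.

    labels[i] is the token the model should predict after reading input_ids[:i+1].
    Positions whose next token is still source (or the separator) get `ignore`.
    """
    seq = list(src_ids) + [sep_id] + list(tgt_ids) + [eos_id]
    input_ids = seq[:-1]
    # The first len(src_ids) next-token labels are s2 .. <sep>: mask them out.
    labels = [ignore] * len(src_ids) + seq[len(src_ids) + 1:]
    return input_ids, labels


inp, lab = pack_for_decoder_only([11, 8, 4], [21, 22], sep_id=3, eos_id=2)
print(inp)  # [11, 8, 4, 3, 21, 22]
print(lab)  # [-100, -100, -100, 21, 22, 2]
```

A single causal mask over the whole packed sequence then replaces both the encoder's bidirectional attention and the decoder's cross-attention.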

Encoder-only: BERT, RoBERTa

Drop the decoder instead, keep the encoder. Useful for classification, retrieval, and sentence embeddings — tasks where you want a rich representation, not a generated sequence.

When each wins

| Task family | Best architecture | Representative models |
|---|---|---|
| General LM / chat / code | Decoder-only | GPT-4, Llama-3, Mistral, Claude |
| Machine translation (esp. many-to-many) | Encoder-decoder | NLLB, M2M-100, mBART |
| Summarization, infilling | Encoder-decoder | BART, Flan-T5, PEGASUS |
| Classification, retrieval, embeddings | Encoder-only | BERT, RoBERTa, E5, BGE |
| Multimodal captioning, speech | Encoder-decoder | Whisper, BLIP-2, Pix2Struct |

Recent research has revived encoder-decoder models for their efficiency on supervised seq2seq tasks: Flan-T5 and UL2 (Google 2022) suggest the split is alive, even as decoder-only dominates general chat.


Summary

  1. The full Transformer is three stages: source → memory (encoder, once), target + memory → hidden (decoder, one step per output token at inference), hidden → logits (linear).
  2. Training is one parallel forward pass over the whole target using teacher forcing + causal mask. Inference is an autoregressive loop; KV-caching (ch09) makes it affordable.
  3. Source and target vocabularies can be separate (our Multi30k setup) or shared (joint BPE, common for same-script pairs). Weight tying $W_{\text{out}} = E_{tgt}$ removes a full embedding-sized matrix of parameters for free.
  4. Modern taxonomy: encoder-decoder for seq2seq, decoder-only for general LLMs, encoder-only for classification/embeddings. Each shape exists because it fits a specific family of tasks.

Chapter Summary

Over the six sections of this chapter, we built the decoder from scratch:

  1. Section 1 — Decoder Architecture Overview. The three-sublayer pattern (masked self-attn + cross-attn + FFN) and why the decoder exists.
  2. Section 2 — Causal Masking. The triangular $-\infty$ mask that lets teacher forcing pretend to be autoregressive.
  3. Section 3 — Cross-Attention. Queries from the target side, keys and values from the encoder memory; how the decoder "looks at" the source.
  4. Section 4 — Implementing TransformerDecoderLayer. The three sublayers wired together with Add&Norm, dropout placement, and the pre-norm vs post-norm debate.
  5. Section 5 — Complete Transformer Decoder. Stacking $N$ decoder layers with target embedding, positional encoding, and output projection. Weight tying and depth-scaling.
  6. Section 6 — Full Encoder-Decoder Transformer (this section). Glue encoder + decoder into one module with a unified forward pass and a greedy generate loop.

You now have every piece of the Vaswani 2017 architecture implemented end-to-end. The rest of the book builds on this foundation: chapter 9 adds smarter decoding, chapter 10 adds training infrastructure, and chapter 13 walks through training this exact model on Multi30k until it translates real German sentences.


Exercises

Easy

1. Modify the Transformer constructor to accept separate num_enc_layers and num_dec_layers. Re-count the parameters for the Multi30k config with 4 encoder layers and 6 decoder layers.

Medium

2. Implement weight tying by setting self.output_proj.weight = self.tgt_embed.weight. Verify that training still runs (both tensors must have the same shape and dtype). Measure the parameter-count reduction and sanity-check the initial logits distribution.

3. Extend generate to return attention weights from the LAST decoder layer's cross-attention so you can visualize source↔target alignments after translation.

Hard

4. Replace the encoder-decoder with a decoder-only variant: concatenate src and tgt with a separator token and train next-token prediction on the concatenation. Measure translation quality on Multi30k: does the encoder pull its weight, or can a decoder-only model beat it at this scale?


Next Section Preview

Chapter 9 picks up exactly where generate left off. We look at greedy decoding, beam search, sampling (temperature, top-$k$, top-$p$), and KV-caching, the collection of tricks that turn the naive $O(T^3)$ generation loop above into the real-time autoregressive systems powering modern LLMs.
