Over the last seven chapters we built every component of the Transformer piece by piece. Chapter 7 finished the encoder — a stack of N identical self-attention blocks that turns a source sentence into a bank of contextualized vectors we will call the memory. Section 5 of this chapter finished the decoder — a stack of N identical blocks that does masked self-attention, cross-attention against the memory, and feed-forward projection.
This section glues them together. We wrap the two stacks, plus the source and target embedding tables and the final vocabulary projection, inside a single Transformer module whose `forward(src, tgt)` runs end to end in one call. This is exactly the Vaswani et al. (2017) architecture, with no extras and no missing pieces.
What changed in 2017: before the Transformer, sequence-to-sequence models were encoder-decoder RNNs. Even attention-augmented variants (Bahdanau 2014, Luong 2015) still stepped through the sequence one token at a time. The Transformer replaces recurrence with pure attention + FFN, so during training both encoder and decoder process every position at once as batched matrix multiplies. That parallelizes well on GPUs and is a large part of why we can now train trillion-parameter models.
Architecture Overview
At the highest level, the full Transformer is a three-stage pipeline. Source tokens flow through the encoder once; the resulting memory is then reused by every decoder step.
The encoder runs exactly once per input sentence. During inference, its output (the memory) is computed and then reused across every decoder step.
The decoder's cross-attention gets its $K, V$ from memory and its $Q$ from the decoder's own self-attention output. That is how target tokens "look at" the source.
The final projection $W_{\text{out}}^\top$ has shape $(d_{\text{model}}, V_{\text{tgt}})$, producing one logit per target vocab id at every position.
Forward Pass Equations
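Written compactly, the end-to-end forward pass is just two lines. Let $\text{src} \in \mathbb{Z}^{B \times T_{\text{src}}}$ be source token ids and $\text{tgt} \in \mathbb{Z}^{B \times T_{\text{tgt}}}$ target token ids. Then

$$\text{memory} = \text{Encoder}(\text{src}, \text{src\_mask}), \qquad \text{logits} = \text{Decoder}(\text{tgt}, \text{memory}, \text{tgt\_mask}, \text{src\_mask}).$$

Here $\text{memory} \in \mathbb{R}^{B \times T_{\text{src}} \times d_{\text{model}}}$ and $\text{logits} \in \mathbb{R}^{B \times T_{\text{tgt}} \times V_{\text{tgt}}}$. The source mask src_mask (a key-padding mask) is reused in both the encoder's self-attention and the decoder's cross-attention. The target mask tgt_mask is the $T_{\text{tgt}} \times T_{\text{tgt}}$ causal mask from §2.

Internally, Decoder expands to

$$h_0 = \text{Dropout}\big(\text{Embed}(\text{tgt})\sqrt{d_{\text{model}}} + \text{PE}\big), \qquad h_\ell = \text{DecoderLayer}_\ell(h_{\ell-1}, \text{memory}, \text{tgt\_mask}, \text{src\_mask}) \quad \text{for } \ell = 1, \dots, N,$$

and finally $\text{logits} = h_N W_{\text{out}}^\top$. The Encoder expands analogously (§5 of ch07).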
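To make the mask bookkeeping concrete, here is a small sketch on PyTorch's stock nn.TransformerEncoder / nn.TransformerDecoder, with random tensors standing in for embedded tokens. It assumes id 0 is <pad>, builds src_mask with PyTorch's True-means-ignore convention, passes the same mask to the encoder as src_key_padding_mask and to the decoder as memory_key_padding_mask, and then checks that overwriting the padded memory rows leaves the decoder output unchanged:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
PAD = 0                                   # assumption: id 0 is <pad>
src = torch.tensor([[5, 7, 2, PAD],
                    [3, 9, PAD, PAD]])    # (B=2, T_src=4), right-padded
src_mask = src == PAD                     # (B, T_src) bool; True = ignore this key
T_tgt, d = 3, 8
tgt_mask = torch.triu(torch.full((T_tgt, T_tgt), float("-inf")), diagonal=1)

encoder = nn.TransformerEncoder(
    nn.TransformerEncoderLayer(d_model=d, nhead=2, dim_feedforward=16, batch_first=True),
    num_layers=1)
decoder = nn.TransformerDecoder(
    nn.TransformerDecoderLayer(d_model=d, nhead=2, dim_feedforward=16, batch_first=True),
    num_layers=1).eval()                  # eval(): no dropout, so runs are comparable

src_x = torch.randn(2, 4, d)              # stand-in for embedded source
tgt_x = torch.randn(2, T_tgt, d)          # stand-in for embedded target
# 1) Encoder self-attention ignores padded source keys.
memory = encoder(src_x, src_key_padding_mask=src_mask).detach()

with torch.no_grad():
    # 2) The SAME mask hides padded memory rows from cross-attention.
    h = decoder(tgt_x, memory, tgt_mask=tgt_mask, memory_key_padding_mask=src_mask)
    memory_junk = memory.clone()
    memory_junk[src_mask] = 123.0         # overwrite only the padded rows
    h_junk = decoder(tgt_x, memory_junk, tgt_mask=tgt_mask,
                     memory_key_padding_mask=src_mask)

print(torch.allclose(h, h_junk))          # True: padded memory is invisible
```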
Training vs Inference
Training: teacher forcing, one forward pass per batch
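During training we have the full target sequence $y = (y_1, \dots, y_{T_{\text{tgt}}})$ in hand, whose last token $y_{T_{\text{tgt}}}$ is $\langle\text{eos}\rangle$. We feed $\text{tgt\_in} = [\langle\text{sos}\rangle, y_1, \dots, y_{T_{\text{tgt}}-1}]$ into the decoder and ask it to predict $\text{tgt\_out} = [y_1, \dots, y_{T_{\text{tgt}}}]$, the same sequence shifted left by one. The causal mask ensures that position $t$ cannot peek at positions $> t$. One forward pass computes logits for ALL target positions simultaneously, and the loss is the mean cross-entropy between the logits and tgt_out (padding positions excluded).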
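A minimal sketch of the shift and the loss, assuming hypothetical special-token ids (<pad> = 0, <sos> = 1, <eos> = 2) and random logits standing in for the model's single forward pass; `ignore_index` keeps padded positions out of the mean:

```python
import torch
import torch.nn.functional as F

SOS, EOS, PAD = 1, 2, 0      # hypothetical special-token ids
V_tgt = 10

# Two target sentences; each ends in <eos>, the shorter one is right-padded.
y = torch.tensor([[7, 4, 9, EOS],
                  [5, EOS, PAD, PAD]])               # (B=2, T_tgt=4)

# Shift right by one: tgt_in starts with <sos>; tgt_out is y itself.
sos_col = torch.full((y.size(0), 1), SOS, dtype=torch.long)
tgt_in = torch.cat([sos_col, y[:, :-1]], dim=1)      # [[1, 7, 4, 9], [1, 5, 2, 0]]
tgt_out = y

# Stand-in for model(src, tgt_in, ...): one forward pass -> (B, T_tgt, V_tgt).
torch.manual_seed(0)
logits = torch.randn(y.size(0), y.size(1), V_tgt)

# Mean cross-entropy over all real (non-<pad>) target positions at once.
loss = F.cross_entropy(logits.reshape(-1, V_tgt), tgt_out.reshape(-1),
                       ignore_index=PAD)
print(tgt_in.tolist(), round(loss.item(), 4))
```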
Why teacher forcing works: the causal mask turns what is morally $T_{\text{tgt}}$ separate autoregressive problems into one parallel matrix computation. Every target position sees only its left context (the ground-truth prefix), just as at inference it sees only its generated prefix, but training needs one decoder pass per batch instead of $T_{\text{tgt}}$ sequential ones.
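The claim that each position sees only its left context can be checked directly. In this sketch (toy sizes and random inputs chosen for illustration), changing only the last target position leaves the decoder's outputs at all earlier positions unchanged:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_model, T = 8, 5
layer = nn.TransformerDecoderLayer(d_model=d_model, nhead=2, dim_feedforward=16,
                                   batch_first=True)
decoder = nn.TransformerDecoder(layer, num_layers=2).eval()   # eval(): no dropout

memory = torch.randn(1, 4, d_model)       # stand-in encoder output
tgt = torch.randn(1, T, d_model)          # stand-in embedded target prefix
causal = torch.triu(torch.full((T, T), float("-inf")), diagonal=1)

with torch.no_grad():
    h1 = decoder(tgt, memory, tgt_mask=causal)
    tgt2 = tgt.clone()
    tgt2[:, -1] = torch.randn(1, d_model)  # change ONLY the last position
    h2 = decoder(tgt2, memory, tgt_mask=causal)

# Earlier positions never attend to the last one, so they are untouched.
print(torch.allclose(h1[:, :-1], h2[:, :-1], atol=1e-6))  # True
print(torch.allclose(h1[:, -1], h2[:, -1], atol=1e-6))    # False
```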
Inference: autoregressive loop
At test time we don't have $y$. We start with $\langle\text{sos}\rangle$ and generate one token per step until we hit $\langle\text{eos}\rangle$ or $t = T_{\max}$. Each step runs the FULL decoder on the partial sequence, takes the last position's logits, picks a next token, and appends it.
This naive loop costs $O(T_{\text{tgt}}^3)$ attention work and $O(T_{\text{tgt}}^2)$ peak memory, because step $t$ recomputes attention for all $t$ positions so far ($t^2$ query-key scores per step). Chapter 9 introduces beam search, temperature / top-k / top-p sampling, and, critically, KV-caching, which stores the $K, V$ of previous positions so step $t$ only pays $O(T_{\text{tgt}})$ attention work.
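A back-of-the-envelope count makes the gap concrete. This sketch counts only query-key scores for one attention head in one layer, ignoring the projections and the FFN (a simplifying assumption):

```python
def attention_scores(T: int, kv_cache: bool) -> int:
    """Query-key dot products one self-attention head computes over a
    T-step greedy loop, where step t has t tokens in the prefix."""
    total = 0
    for t in range(1, T + 1):
        # Without a cache, step t recomputes all t queries against t keys.
        # With a KV cache, only the newest query is scored against t keys.
        total += t if kv_cache else t * t
    return total


for T in (16, 64, 256):
    naive, cached = attention_scores(T, False), attention_scores(T, True)
    print(f"T={T:4d}  naive={naive:>10,}  cached={cached:>7,}  ratio={naive / cached:.1f}")
```

The naive total is $\sum_t t^2 = T(T+1)(2T+1)/6$ and the cached one $\sum_t t = T(T+1)/2$, so the ratio is $(2T+1)/3$ and grows linearly with output length.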
Shared vs Separate Vocabularies
The Transformer has two embedding tables, $E_{\text{src}} \in \mathbb{R}^{V_{\text{src}} \times d_{\text{model}}}$ and $E_{\text{tgt}} \in \mathbb{R}^{V_{\text{tgt}} \times d_{\text{model}}}$, plus an output projection $W_{\text{out}} \in \mathbb{R}^{V_{\text{tgt}} \times d_{\text{model}}}$. You have two design axes:
Shared source and target embedding ($E_{\text{src}} = E_{\text{tgt}}$): sensible when the two languages share a script and you train a joint BPE vocabulary over both corpora. Common for En↔De and En↔Fr in the original paper, and for code-completion models where input and output come from the same token space.
Weight tying ($W_{\text{out}} = E_{\text{tgt}}$): ties the output projection to the target-embedding table (Press & Wolf 2017). Saves $V_{\text{tgt}} \cdot d_{\text{model}}$ parameters and usually improves perplexity, because both matrices encode the same "which direction in hidden space represents token $v$?" relation.
Separate vocabularies: required when scripts or tokenizers differ (En↔Zh, En↔Ar). Here a shared BPE would waste capacity on tokens that never co-occur.
For the Multi30k translation project in chapter 13 we use separate source (de) and target (en) BPE vocabularies, with $V_{\text{src}} \approx 7000$ and $V_{\text{tgt}} \approx 5500$. That is the setting the code below targets.
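A minimal sketch of weight tying with Multi30k-sized shapes; the `TargetSide` container is hypothetical, not the chapter's module. In PyTorch, nn.Linear stores its weight as (out_features, in_features) = (V_tgt, d_model), the same shape as the embedding table, so tying is a single assignment, and `Module.parameters()` yields the shared tensor only once:

```python
import torch
import torch.nn as nn

V_tgt, d_model = 5500, 512   # Multi30k-sized target side from the text


class TargetSide(nn.Module):
    """Hypothetical container: target embedding + vocabulary head."""

    def __init__(self, tie: bool):
        super().__init__()
        self.tgt_embed = nn.Embedding(V_tgt, d_model)             # weight: (V_tgt, d_model)
        self.output_proj = nn.Linear(d_model, V_tgt, bias=False)  # weight: (V_tgt, d_model)
        if tie:
            # Same storage shape, so tying is a plain assignment (no transpose).
            self.output_proj.weight = self.tgt_embed.weight


def n_params(m: nn.Module) -> int:
    # .parameters() yields a shared Parameter only once.
    return sum(p.numel() for p in m.parameters())


untied, tied = TargetSide(tie=False), TargetSide(tie=True)
print(n_params(untied), n_params(tied))   # 5632000 2816000

h = torch.randn(2, 3, d_model)
# The tied head computes h @ E_tgt^T: a dot product with every embedding row.
assert torch.allclose(tied.output_proj(h), h @ tied.tgt_embed.weight.T, atol=1e-4)
```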
Plain-Python End-to-End
Before introducing the Transformer module, let's watch the shapes move by hand. We use the shared chapter config: $B=1$, $T_{\text{src}}=4$, $T_{\text{tgt}}=3$, $d_{\text{model}}=8$, $H=2$, $d_{\text{ff}}=16$, $N=2$, $V_{\text{src}}=12$, $V_{\text{tgt}}=10$, and `torch.manual_seed(0)`. The printed values below were computed by running this code.
Plain-Python End-to-End Forward Pass (toy shapes)
Notes on toy_transformer_forward.py, followed by the full listing:
`import torch`: PyTorch's top-level package. It gives us tensors, autograd, random-number generation, and every primitive used below: torch.Tensor, torch.randint, torch.zeros, torch.arange, torch.triu, torch.manual_seed.
`import torch.nn as nn`: the building blocks nn.Embedding, nn.Linear, nn.TransformerEncoderLayer, and nn.TransformerDecoderLayer. Every learnable piece of the Transformer we assemble lives here; nn is the usual short alias, as in nn.Embedding(vocab, d_model) or nn.Linear(in, out).
`import torch.nn.functional as F`: the functional (stateless) API, with counterparts of nn modules that own no parameters, such as F.softmax, F.relu, and F.log_softmax. We don't call F in this snippet, but it is the standard import for follow-up cells such as F.softmax(logits, dim=-1).
`import math`: Python's standard-library math module. We need math.sqrt(d_model) for embedding scaling (math.sqrt(8) ≈ 2.8284, following the original paper) and math.log(10000.0) for the sinusoidal positional encoding.
`torch.manual_seed(0)`: seeds PyTorch's CPU random-number generator, so every later random draw (torch.randint token ids, nn.Embedding and nn.Linear weight init) is reproducible. Seed 0 matches §4 and §5, so the same memory[0,0,:4] values come out here and readers can compare values across sections.
`B, T_src, T_tgt = 1, 4, 3`: shared tiny shapes, small enough that every intermediate tensor is inspectable. B = 1 means a single sentence pair, so every shape starts with 1; T_src = 4 stands for, e.g., a 4-token German sentence after tokenization, and T_tgt = 3 for a 3-token English translation.
`d_model, H, d_ff = 8, 2, 16`: toy model dims. d_model = 8 is the width of every token vector throughout the stack (inputs, outputs, and memory are all [B, T, 8]); H = 2 heads gives d_k = d_model / H = 4 per head; d_ff = 16 is the inner width of the FFN sublayer. Real models use 512 / 8 / 2048 (ch13), where d_ff = 4 · d_model; the toy uses d_ff = 2 · d_model to stay small.
`N = 2`: the number of stacked layers in BOTH encoder and decoder. Vaswani et al. (2017) used N = 6; we use 2 so the walkthrough is cheap.
`src_vocab, tgt_vocab = 12, 10`: two SEPARATE vocabularies, 12 source types (de) and 10 target types (en). Source ids live in [0, 12) and target ids in [0, 10), so the output logits have 10 classes. Ids from different vocabularies are not comparable, which is why we need two embedding tables.
`src = torch.randint(1, src_vocab, (B, T_src))`: draws a random source sentence of 4 token ids in [1, 12). torch.randint(low, high, size) samples integers uniformly from [low, high) into a tensor of the given size; for example, torch.randint(1, 5, (2,)) might return tensor([2, 4]). The inclusive lower bound low = 1 skips id 0, which is conventionally the <pad> token.
`tgt = torch.randint(1, tgt_vocab, (B, T_tgt))`: a random target sentence of 3 ids; with this seed, tgt = tensor([[8, 7, 8]]), shape (1, 3). In real training these ids come from the dataloader; random ids are fine here because we only want to demonstrate the forward-pass plumbing.
`src_embed = nn.Embedding(src_vocab, d_model)`: the source-side lookup table. nn.Embedding(num_embeddings, embedding_dim) stores a learnable (num_embeddings, embedding_dim) matrix: here 12 rows, one per source id, each of width d_model = 8 so shapes line up with the rest of the stack, for 12 · 8 = 96 parameters. Calling it on a LongTensor of ids selects the corresponding rows, which is equivalent to, but faster than, one-hot encoding followed by a matmul.
`def make_pe(max_len, d)`: builds the fixed sinusoidal positional-encoding table from Vaswani et al. (2017). It returns a (max_len, d) tensor, fixed and not learnable, whose row p is the positional signature for position p. We call it with max_len = 16 (we only ever use the first T_src or T_tgt rows) and d = 8 = d_model, so PE can be added directly to the embeddings.
`pe = torch.zeros(max_len, d)`: allocates the (16, 8) zero matrix we fill with sin/cos values.
`pos = torch.arange(0, max_len).float().unsqueeze(1)`: the positions 0, 1, …, 15 as a float column vector. .unsqueeze(1) inserts a size-1 axis at position 1, turning shape (16,) into (16, 1) so it broadcasts against div: (16, 1) * (4,) → (16, 4).
`div = torch.exp(torch.arange(0, d, 2).float() * (-math.log(10000.0) / d))`: the frequency terms, $\text{div}_i = 10000^{-2i/d}$. For d = 8 we need 4 frequencies, one per even dim: torch.arange(0, 8, 2) = tensor([0, 2, 4, 6]) and -math.log(10000.0) / 8 ≈ -1.1513, so div is a (4,) tensor of decreasing frequencies. Higher dims get lower frequencies, which encodes position at multiple scales.
`pe[:, 0::2] = torch.sin(pos * div)`: fills the even columns with sin(pos · freq). The slice 0::2 (start 0, step 2) selects every EVEN index, so pe[:, 0::2] has shape (16, 4), matching the broadcast product pos(16, 1) * div(4,).
`pe[:, 1::2] = torch.cos(pos * div)`: fills the odd columns with cos(pos · freq). Even dims use sin and odd dims use cos: the classic sinusoidal scheme.
`return pe`: hands back the (16, 8) fixed sinusoidal positional-encoding table.
`pe = make_pe(16, d_model)`: builds the table once; we slice it to T_src or T_tgt rows as needed.
`src_x = src_embed(src) * math.sqrt(d_model) + pe[:T_src].unsqueeze(0)`: looks up the source embeddings, scales them, and adds the positional encoding for positions 0..T_src-1. src_embed(src) has shape (1, 4, 8), with row i equal to src_embed.weight[src[0, i]]. Multiplying by √d_model ≈ 2.8284 (the Vaswani trick) keeps the embedding magnitude comparable to PE, whose values lie in [-1, 1], so adding PE does not dominate the token identity. pe[:T_src] takes the first 4 rows, shape (4, 8), and .unsqueeze(0) makes it (1, 4, 8) so it broadcasts against the (1, 4, 8) embeddings. tgt_x is built the same way with pe[:T_tgt].
`causal = torch.triu(torch.ones(T_tgt, T_tgt) * float('-inf'), diagonal=1)`: builds the (3, 3) additive causal mask for target self-attention. torch.triu(input, diagonal=k) returns the upper-triangular part of input and zeroes everything below diagonal k; diagonal = 1 zeroes the main diagonal as well. Starting from a 3 × 3 matrix full of -inf, entries strictly above the diagonal stay -inf while the diagonal and everything below become 0:

$$\text{causal} = \begin{bmatrix} 0 & -\infty & -\infty \\ 0 & 0 & -\infty \\ 0 & 0 & 0 \end{bmatrix}$$
`h = decoder(tgt_x, memory, tgt_mask=causal)`: runs the target through the decoder stack, with memory supplying the cross-attention K/V and causal masking the self-attention. h has shape (1, 3, 8).
`W_out = nn.Linear(d_model, tgt_vocab, bias=False)`: the output projection from hidden space (in_features = 8 = d_model) to the target vocabulary (out_features = 10 = tgt_vocab, one logit per vocab id). With bias=False it computes y = x @ W.T and has 8 · 10 = 80 parameters. Dropping the bias is common for vocabulary heads (a bias would add only 10 parameters here) and matches the weight-tying variant in which this W shares weights with tgt_embed (Press & Wolf 2017).
`logits = W_out(h)`: a per-position projection, (1, 3, 8) @ (8, 10) → (1, 3, 10); logits[0, t, v] is the score the model assigns to vocab id v at target position t. Next, training would compare these logits with the target ids via F.cross_entropy, while inference takes the argmax over the last dim to pick a token (ch09 covers beam search and sampling).
`print('memory[0,0,:4] =', ...)`: sanity prints confirming that the encoder produced finite values. The argmax of logits[0, 0] is 3, so this untrained model predicts vocab id 3 for the first target step.
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

# ---- Shared tiny config (see §4 of this chapter) ----
torch.manual_seed(0)
B, T_src, T_tgt = 1, 4, 3
d_model, H, d_ff = 8, 2, 16
N = 2
src_vocab, tgt_vocab = 12, 10

# Random source / target token ids
src = torch.randint(1, src_vocab, (B, T_src))  # [[11, 8, 2, 4]]
tgt = torch.randint(1, tgt_vocab, (B, T_tgt))  # [[8, 7, 8]]

# Separate embedding tables (de vs en vocabularies)
src_embed = nn.Embedding(src_vocab, d_model)
tgt_embed = nn.Embedding(tgt_vocab, d_model)

# Sinusoidal positional encoding (fixed, not learned)
def make_pe(max_len, d):
    pe = torch.zeros(max_len, d)
    pos = torch.arange(0, max_len).float().unsqueeze(1)
    div = torch.exp(torch.arange(0, d, 2).float() * (-math.log(10000.0) / d))
    pe[:, 0::2] = torch.sin(pos * div)
    pe[:, 1::2] = torch.cos(pos * div)
    return pe

pe = make_pe(16, d_model)

# Stage 1: source side -> encoder
src_x = src_embed(src) * math.sqrt(d_model) + pe[:T_src].unsqueeze(0)

enc_layer = nn.TransformerEncoderLayer(
    d_model=d_model, nhead=H, dim_feedforward=d_ff,
    batch_first=True, activation='relu')
encoder = nn.TransformerEncoder(enc_layer, num_layers=N)
memory = encoder(src_x)  # [1, 4, 8]

# Stage 2: target side -> decoder (queries memory via cross-attn)
tgt_x = tgt_embed(tgt) * math.sqrt(d_model) + pe[:T_tgt].unsqueeze(0)

dec_layer = nn.TransformerDecoderLayer(
    d_model=d_model, nhead=H, dim_feedforward=d_ff,
    batch_first=True, activation='relu')
decoder = nn.TransformerDecoder(dec_layer, num_layers=N)

# Additive causal mask: -inf strictly above the diagonal, 0 elsewhere
causal = torch.triu(torch.ones(T_tgt, T_tgt) * float('-inf'), diagonal=1)
h = decoder(tgt_x, memory, tgt_mask=causal)  # [1, 3, 8]

# Stage 3: project decoder hidden states to vocabulary logits
W_out = nn.Linear(d_model, tgt_vocab, bias=False)
logits = W_out(h)  # [1, 3, 10]

print("memory[0,0,:4] =", [round(x, 4) for x in memory[0, 0, :4].tolist()])
print("logits[0,0] =", [round(x, 3) for x in logits[0, 0].tolist()])
print("argmax logits[0,0] =", logits[0, 0].argmax().item())
```
Running this prints:
```text
memory[0,0,:4] = [0.7302, 0.3295, -0.9525, 1.4986]
logits[0,0] = [0.361, -0.603, 0.058, 0.903, 0.147, -0.001, -0.121, -0.979, 0.720, 0.566]
argmax logits[0,0] = 3   # predicted vocab id for the first target position
```
The shape chain is exactly what the architecture diagram promised: source ids (1, 4) → memory (1, 4, 8) → decoder hidden (1, 3, 8) → logits (1, 3, 10). Every sublayer we wrote from scratch in the previous chapters is hidden inside those two stack calls.
PyTorch Implementation
The Transformer module
Now the real thing — a single nn.Module that owns every learnable piece.
class Transformer(nn.Module) — wire encoder + decoder + embeddings
Notes on transformer.py:
`import math`: math.sqrt for embedding scaling and math.log for the sinusoidal frequencies.
`import torch`: the core tensor library, providing torch.Tensor, torch.arange, and torch.zeros, all the primitives used in _build_pe.
`import torch.nn as nn`: the nn.Module base class plus nn.Embedding, nn.Linear, nn.Dropout, nn.TransformerEncoder / nn.TransformerEncoderLayer, and nn.TransformerDecoder / nn.TransformerDecoderLayer.
`from typing import Optional`: lets us type mask arguments as Optional[torch.Tensor], i.e. either a tensor or None.
`class Transformer(nn.Module)`: nn.Module is the base class for neural-net modules; subclasses register sublayers in __init__ and compute in forward. Subclassing it lets PyTorch discover parameters automatically when they are assigned as attributes, move them to the GPU, save and load state dicts, and plug into wrappers such as DataParallel.
Constructor arguments: both vocabulary sizes (two separate vocabularies) plus the standard Vaswani hyperparameters, with defaults matching the original paper.
src_vocab (int): size of the source vocabulary (e.g. 7000 for Multi30k de BPE); sets the number of rows in self.src_embed.weight.
tgt_vocab (int): size of the target vocabulary (e.g. 5500 for Multi30k en BPE); sets the rows of tgt_embed AND the out_features of output_proj.
d_model = 512: the hidden dim throughout; every token is a 512-dim vector.
num_heads = 8: attention heads per layer, so d_k = 512 / 8 = 64 per head.
num_layers = 6: stack depth, applied SYMMETRICALLY to encoder and decoder in this implementation.
max_len = 5000: PE is precomputed for up to 5000 positions, which caps the sentence length the model can handle without rebuilding the table.
dropout = 0.1: applied after embedding + PE and inside each encoder/decoder sublayer. 0.1 is the Vaswani default; 0.3 is a common choice for lower-resource setups like Multi30k.
`super().__init__()`: runs nn.Module's __init__ so parameter tracking, hooks, and the children registry are set up BEFORE we assign submodules.
`self.d_model = d_model`: stashes d_model so encode() and decode() can scale embeddings by √d_model.
`self.register_buffer('pe', ...)`: stores PE as a BUFFER, not a Parameter. register_buffer(name, tensor) registers a non-trainable tensor under the given name, so it is read back as self.pe. Buffers move with .to(device) and appear in state_dict by default, but the optimizer never updates them, which is exactly what we want for fixed sinusoidal PE.
Dropout, `nn.Dropout(dropout)`: applied after the embedding + PE addition on both encoder and decoder inputs (Vaswani et al. 2017, §5.4). In training mode nn.Dropout(p) zeroes each element with probability p and scales the survivors by 1/(1-p); in eval mode (model.eval()) it is a no-op.
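Encoder layer, `nn.TransformerEncoderLayer`: builds ONE encoder block, self-attention (masked only for padding, never causally) plus an FFN, each wrapped in Add & Norm. We reuse PyTorch's well-tested implementation here; ch07 built the same thing from scratch.
Output projection, `nn.Linear(d_model, tgt_vocab, bias=False)`: out_features = tgt_vocab gives one logit per target vocab id. bias=False is the usual convention for vocabulary heads, and it means that once output_proj.weight = tgt_embed.weight is tied, the head owns no parameters of its own.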
`@staticmethod`: marks _build_pe as not needing self. It is a pure function, so PE can be built before any instance state exists.
`def _build_pe(max_len, d_model)`: the same sinusoidal table construction as in the walkthrough above. max_len is how many positions to precompute, and d_model must match the hidden dim used throughout. Unlike make_pe, it returns shape (1, max_len, d_model) so the table broadcasts directly against (B, T, d_model).
Encoder call: runs all N encoder layers. src_key_padding_mask, a bool mask of shape (B, T_src), tells each self-attention sublayer which source positions are padding. Under PyTorch's convention True means PAD (ignore): those keys receive -inf scores, so they contribute nothing to attention.
PE as a buffer, not a Parameter. We want PE to travel with the module via .to(device) and state_dict, but we do NOT want the optimizer to update it.
src_mask is fed twice: once as src_key_padding_mask for the encoder, again as memory_key_padding_mask for the decoder's cross-attention. Both places need to know which source positions are padding.
No softmax in forward. We return raw logits so that F.cross_entropy can call log-softmax internally with better numerics.
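A minimal sketch of a module matching these notes follows. It assumes the names encode, decode, and output_proj that generate (below) calls; the `_embed` helper and the `d_ff` argument (default 2048, as in the base model) are our own additions. It builds on PyTorch's stock encoder/decoder layers rather than the from-scratch versions of earlier chapters:

```python
import math
from typing import Optional

import torch
import torch.nn as nn


class Transformer(nn.Module):
    """Sketch: encoder-decoder Transformer on top of PyTorch's stock layers."""

    def __init__(self, src_vocab: int, tgt_vocab: int, d_model: int = 512,
                 num_heads: int = 8, num_layers: int = 6, d_ff: int = 2048,
                 max_len: int = 5000, dropout: float = 0.1):
        super().__init__()
        self.d_model = d_model
        # Fixed sinusoidal PE as a buffer: follows .to(device), is saved in
        # state_dict, and is never touched by the optimizer.
        self.register_buffer("pe", self._build_pe(max_len, d_model))
        self.dropout = nn.Dropout(dropout)  # applied after embedding + PE

        # Two separate vocabularies -> two embedding tables.
        self.src_embed = nn.Embedding(src_vocab, d_model)
        self.tgt_embed = nn.Embedding(tgt_vocab, d_model)

        enc_layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=num_heads, dim_feedforward=d_ff,
            dropout=dropout, batch_first=True)
        self.encoder = nn.TransformerEncoder(enc_layer, num_layers=num_layers)
        dec_layer = nn.TransformerDecoderLayer(
            d_model=d_model, nhead=num_heads, dim_feedforward=d_ff,
            dropout=dropout, batch_first=True)
        self.decoder = nn.TransformerDecoder(dec_layer, num_layers=num_layers)

        # Vocabulary head: one logit per target id, no bias.
        self.output_proj = nn.Linear(d_model, tgt_vocab, bias=False)

    @staticmethod
    def _build_pe(max_len: int, d_model: int) -> torch.Tensor:
        pe = torch.zeros(max_len, d_model)
        pos = torch.arange(0, max_len).float().unsqueeze(1)
        div = torch.exp(torch.arange(0, d_model, 2).float()
                        * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(pos * div)
        pe[:, 1::2] = torch.cos(pos * div)
        return pe.unsqueeze(0)  # (1, max_len, d_model): broadcasts over batch

    def _embed(self, table: nn.Embedding, ids: torch.Tensor) -> torch.Tensor:
        # Embed * sqrt(d_model) + PE for positions 0..T-1, then dropout.
        x = table(ids) * math.sqrt(self.d_model) + self.pe[:, :ids.size(1)]
        return self.dropout(x)

    def encode(self, src: torch.Tensor,
               src_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        # src_mask: (B, T_src) bool, True = <pad> (PyTorch convention).
        return self.encoder(self._embed(self.src_embed, src),
                            src_key_padding_mask=src_mask)

    def decode(self, tgt: torch.Tensor, memory: torch.Tensor,
               tgt_mask: Optional[torch.Tensor] = None,
               memory_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        # memory_mask here is the SOURCE padding mask, routed to cross-attention.
        return self.decoder(self._embed(self.tgt_embed, tgt), memory,
                            tgt_mask=tgt_mask,
                            memory_key_padding_mask=memory_mask)

    def forward(self, src: torch.Tensor, tgt: torch.Tensor,
                src_mask: Optional[torch.Tensor] = None,
                tgt_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        memory = self.encode(src, src_mask=src_mask)
        hidden = self.decode(tgt, memory, tgt_mask=tgt_mask, memory_mask=src_mask)
        return self.output_proj(hidden)  # raw logits; cross_entropy does log-softmax


torch.manual_seed(0)
model = Transformer(src_vocab=12, tgt_vocab=10, d_model=8, num_heads=2,
                    num_layers=2, d_ff=16, max_len=16)
src = torch.randint(1, 12, (2, 4))
tgt = torch.randint(1, 10, (2, 3))
causal = torch.triu(torch.full((3, 3), float("-inf")), diagonal=1)
logits = model(src, tgt, src_mask=(src == 0), tgt_mask=causal)
print(logits.shape)  # torch.Size([2, 3, 10])
```

One design choice worth noting: decode forwards its memory_mask argument to PyTorch's memory_key_padding_mask, because here it carries the (B, T_src) source padding mask, not PyTorch's (T_tgt, T_src) memory_mask attention mask.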
Greedy generation
For the simplest possible inference, we pair forward with a generate method that loops over steps. This belongs on the same module so it can reuse self.encode / self.decode.
`@torch.no_grad()`: disables autograd for this method. During inference we don't need gradients, and skipping them saves memory and time. torch.no_grad works both as a context manager and as a decorator; as a decorator it has the same effect as wrapping the body in `with torch.no_grad(): ...`, so no operation inside records a graph and results come out with requires_grad=False.
`def generate(self, src, max_len, sos_id, eos_id, src_mask=None)`: the greedy decoding loop. At each step it runs the decoder on the current prefix, takes the argmax of the LAST position's logits, appends it, and repeats until max_len or <eos>. Arguments and return value:
src: (B, T_src) source token ids.
max_len: hard cap on the output length, counting the leading <sos>; it must be > 1 for any token to be generated.
sos_id: start-of-sequence token id. The first target column is filled with it; it is what the decoder conditions on to produce token #1.
eos_id: end-of-sequence token id. We stop early once EVERY batch element has emitted <eos>.
src_mask: source padding mask, passed through to encode() and to cross-attention.
Returns a (B, ≤ max_len) LongTensor of generated ids, including the leading <sos>.
`self.eval()`: switches to eval mode by setting self.training = False recursively. Modules branch on self.training: Dropout becomes the identity (its random noise would otherwise make greedy decoding non-deterministic) and BatchNorm switches to running statistics, while LayerNorm behaves the same in both modes. Because the flag persists after generate returns, call model.train() before resuming training.
`device = src.device`: grabs the device (cpu or cuda:k) of the input so new tensors (ys, the causal mask) are built on the same device, avoiding host-device copies and device-mismatch errors.
`B = src.size(0)`: the batch size; all B sequences are decoded in parallel.
`memory = self.encode(src, src_mask=src_mask)`: the encoder runs ONCE. src does not change during decoding, so its encoder output is constant and is reused at every decoder step; this is the single biggest efficiency trick for encoder-decoder inference.
`causal = torch.triu(torch.ones(T, T, device=device) * float("-inf"), diagonal=1)`: a fresh T × T causal mask, rebuilt every step because T grows. Starting from a full -inf matrix on the right device, triu with diagonal = 1 keeps -inf strictly above the diagonal and zeroes the rest, so position t CAN attend to itself but NOT to t+1, t+2, ….
`hidden = self.decode(ys, memory, tgt_mask=causal, memory_mask=src_mask)`: one decoder forward pass over the full partial sequence. Without KV-caching (ch09) this recomputes attention for all previous positions every step: $O(T^2)$ per step, $O(T^3)$ total.
`step_logits = self.output_proj(hidden[:, -1])`: only the LAST position's hidden vector matters, because it carries the next-token prediction; step_logits has shape (B, tgt_vocab).
`next_id = step_logits.argmax(dim=-1, keepdim=True)`: the greedy pick, i.e. the single highest-scoring vocab id (for beam search and sampling, see ch09). dim = -1 reduces over the vocab dimension (the last axis), and keepdim = True keeps the reduced axis as size 1, turning (B, tgt_vocab) into (B, 1) rather than (B,) so the result can be concatenated. next_id is a (B, 1) LongTensor of predicted next-token ids.
`ys = torch.cat([ys, next_id], dim=1)`: appends the newly predicted column, so ys grows from (B, T) to (B, T + 1). torch.cat(tensors, dim) concatenates along the given dim, here dim = 1 (sequence length); the inputs (B, T) and (B, 1) must match on every other dim, here dim 0.
Early stop, `finished |= next_id.squeeze(1) == eos_id` followed by `if finished.all(): break`: stop once EVERY batch element has emitted <eos> at some step. finished is a (B,) BoolTensor that remembers which rows are done, and .all() reduces it to a single True only if every entry is True. Checking (next_id == eos_id).all() alone would require all rows to emit <eos> on the same step, which rarely happens in a batch, so the loop would usually run all the way to max_len.
`return ys`: the final (B, ≤ max_len) LongTensor of <sos> plus generated ids. In translation you would now strip <sos> (and everything after each row's first <eos>) and decode via the target tokenizer.

```python
@torch.no_grad()
def generate(
    self,
    src: torch.Tensor,
    max_len: int,
    sos_id: int,
    eos_id: int,
    src_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Greedy autoregressive decoding. O(max_len * forward).

    For beam search, top-k / top-p sampling, and KV-cache acceleration,
    see chapter 9.
    """
    self.eval()
    device = src.device
    B = src.size(0)

    # 1) Encode source ONCE. Memory is reused for every step.
    memory = self.encode(src, src_mask=src_mask)

    # 2) Start with <sos> for every batch element.
    ys = torch.full((B, 1), sos_id, dtype=torch.long, device=device)
    # Per-row flag: has this sequence emitted <eos> at any step so far?
    finished = torch.zeros(B, dtype=torch.bool, device=device)

    for _ in range(max_len - 1):
        T = ys.size(1)
        causal = torch.triu(
            torch.ones(T, T, device=device) * float("-inf"),
            diagonal=1,
        )
        hidden = self.decode(ys, memory,
                             tgt_mask=causal,
                             memory_mask=src_mask)
        step_logits = self.output_proj(hidden[:, -1])       # (B, tgt_vocab)
        next_id = step_logits.argmax(dim=-1, keepdim=True)  # (B, 1)
        ys = torch.cat([ys, next_id], dim=1)

        # 3) Stop once every row has produced <eos> at SOME step.
        finished |= next_id.squeeze(1) == eos_id
        if finished.all():
            break

    return ys
```
Forward reference: greedy decoding is the cheapest, least-diverse strategy. Chapter 9 covers beam search (keep the top-k partial hypotheses), sampling (temperature, top-k, top-p, typical decoding), and KV-caching (reuse past $K, V$ so each step is $O(T)$ instead of $O(T^2)$).
Parameter Count
Let's sanity-check the model size for the ch13 Multi30k config: $d_{\text{model}}=512$, $H=8$, $d_{\text{ff}}=2048$, $N=6$, $V_{\text{src}}=7000$, $V_{\text{tgt}}=5500$. We count each learnable piece; the output projection has no bias, as in the module above.

| Component | Formula | Parameters |
|---|---|---|
| One encoder layer (MHSA + FFN + 2 LN) | `4*(d*d + d) + d*d_ff + d_ff + d_ff*d + d + 2*2*d` | 3,152,384 |
| Encoder stack (N = 6) | 6 × encoder layer | 18,914,304 |
| One decoder layer (MHSA + cross-attn + FFN + 3 LN) | `2*4*(d*d + d) + d*d_ff + d_ff + d_ff*d + d + 3*2*d` | 4,204,032 |
| Decoder stack (N = 6) | 6 × decoder layer | 25,224,192 |
| Source embedding | `V_src * d_model = 7000 * 512` | 3,584,000 |
| Target embedding | `V_tgt * d_model = 5500 * 512` | 2,816,000 |
| Output projection (bias=False) | `d_model * V_tgt = 512 * 5500` | 2,816,000 |
| Total | sum of rows above | 53,354,496 (~53.4M) |
About 53M parameters, in the same ballpark as the original Vaswani base model (65M) and a bit smaller because our vocabularies are smaller than WMT's. For reference, Multi30k is a low-resource dataset, and you'll usually train a smaller variant ($d_{\text{model}}=256$, $N=3$, closer to 10–15M parameters) to avoid overfitting.
Two memory-saving tricks cut this further:
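Weight tying $W_{\text{out}} = E_{\text{tgt}}$: removes the roughly 2.8M output-projection parameters. Both matrices are stored with shape $(V_{\text{tgt}}, d_{\text{model}})$, so they can share one tensor outright; the projection simply applies it transposed, $\text{logits} = h\,W_{\text{out}}^\top$.
Shared source+target embedding (when vocabularies are merged via joint BPE): one $V_{\text{joint}} \times d_{\text{model}}$ table replaces the two, saving $(V_{\text{src}} + V_{\text{tgt}} - V_{\text{joint}}) \cdot d_{\text{model}}$ parameters, which is at least $\min(V_{\text{src}}, V_{\text{tgt}}) \cdot d_{\text{model}}$ as long as the joint vocabulary is no larger than the bigger of the two.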
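The table can be reproduced in a few lines. This sketch assumes the bias-free output projection and ignores the positional-encoding buffer (it holds no trainable parameters); `V_joint` is a hypothetical argument for the shared-vocabulary case. The last lines cross-check the per-layer formulas against PyTorch's own layers:

```python
import torch.nn as nn


def layer_params(d: int, d_ff: int, n_attn: int, n_ln: int) -> int:
    attn = n_attn * 4 * (d * d + d)            # Q, K, V, O projections + biases
    ffn = d * d_ff + d_ff + d_ff * d + d       # two Linear layers + biases
    return attn + ffn + n_ln * 2 * d           # each LayerNorm: gain + bias


def transformer_params(d, d_ff, N, V_src, V_tgt, tie=False, V_joint=None):
    enc = N * layer_params(d, d_ff, n_attn=1, n_ln=2)   # self-attn only
    dec = N * layer_params(d, d_ff, n_attn=2, n_ln=3)   # self-attn + cross-attn
    if V_joint is None:                  # separate vocabularies
        emb, V_out = (V_src + V_tgt) * d, V_tgt
    else:                                # one shared joint-BPE table
        emb, V_out = V_joint * d, V_joint
    head = 0 if tie else d * V_out       # bias-free output projection
    return enc + dec + emb + head


cfg = dict(d=512, d_ff=2048, N=6, V_src=7000, V_tgt=5500)
print(f"{transformer_params(**cfg):,}")            # 53,354,496
print(f"{transformer_params(**cfg, tie=True):,}")  # 50,538,496

# Cross-check the per-layer formulas against PyTorch's own layers.
enc_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8, dim_feedforward=2048)
dec_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8, dim_feedforward=2048)
print(sum(p.numel() for p in enc_layer.parameters()))  # 3152384
print(sum(p.numel() for p in dec_layer.parameters()))  # 4204032
```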
Interactive Visualization
The visualization below animates the full pipeline on a toy German→English pair, "Der Hund läuft" → "The dog runs". Watch the source tokens flow through the encoder stack, deposit as memory blocks, and then get queried by each decoder step through glowing cross-attention lines. Hover a line to see its (illustrative) attention weight.
What to look for: (1) the encoder colors intensify as all four source tokens flow through; (2) memory blocks crystallize once the encoder finishes; (3) each target token's cross-attention lines concentrate on its aligned source word — "The"→"Der", "dog"→"Hund", "runs"→"läuft". These alignment patterns are the hallmark of a trained translation model.
In Modern Systems
Encoder-decoder: T5, BART, mBART, NLLB
The original Vaswani architecture is still the right default for sequence-to-sequence tasks where the full input is available up front and the output is a separate sequence. T5 (Raffel et al. 2019) reframes every NLP task as text-to-text and trains a shared encoder-decoder on a denoising objective. BART (Lewis et al. 2019) uses a similar shape for summarization and infilling. NLLB-200 (Meta 2022) applies an encoder-decoder with mixture-of-experts layers to translation among 200 languages; it stays encoder-decoder because the decoder wants to attend to the FULL source regardless of where it is in generation.
Decoder-only: GPT, Llama, Mistral, PaLM
Modern general-purpose LLMs drop the encoder entirely. They are single decoder stacks with masked self-attention only — no cross-attention, no separate vocabulary. Input and output are concatenated into one token stream, and training is plain next-token prediction on web-scale corpora. Inference splits into prefill (process the prompt in parallel) and decode (generate tokens one at a time with KV-caching). This architecture won for general LLMs because (1) training data is unlabeled text, no source/target split needed; (2) one stack is cheaper than two; (3) the prefill/decode split maps cleanly onto GPU systems.
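A minimal sketch of the decoder-only setup applied to the toy translation pair, under stated assumptions: a hypothetical separator id, one shared vocabulary, and learned positional embeddings in place of sinusoidal PE for brevity. Because a decoder-only block is just causal self-attention plus an FFN (no cross-attention), PyTorch's nn.TransformerEncoderLayer with a causal mask plays that role here:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
V, d_model, max_len = 20, 16, 32
SEP = 3                                   # hypothetical separator id (shared vocab)

src = torch.tensor([[11, 8, 2, 4]])       # "input" part
tgt = torch.tensor([[8, 7, 9]])           # "output" part
stream = torch.cat([src, torch.full((1, 1), SEP), tgt], dim=1)   # one token stream

# Plain next-token prediction on the concatenated stream.
inp, labels = stream[:, :-1], stream[:, 1:]                       # (1, 7) each

embed = nn.Embedding(V, d_model)
pos = nn.Embedding(max_len, d_model)      # learned positions, for brevity
block = nn.TransformerEncoderLayer(d_model=d_model, nhead=2, dim_feedforward=32,
                                   batch_first=True)
stack = nn.TransformerEncoder(block, num_layers=2)
lm_head = nn.Linear(d_model, V, bias=False)

T = inp.size(1)
causal = torch.triu(torch.full((T, T), float("-inf")), diagonal=1)
x = embed(inp) + pos(torch.arange(T))
logits = lm_head(stack(x, mask=causal))                           # (1, 7, V)
loss = F.cross_entropy(logits.reshape(-1, V), labels.reshape(-1))
print(logits.shape)  # torch.Size([1, 7, 20])
```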
Encoder-only: BERT, RoBERTa
Drop the decoder instead, keep the encoder. Useful for classification, retrieval, and sentence embeddings — tasks where you want a rich representation, not a generated sequence.
When each wins
| Task family | Best architecture | Representative models |
|---|---|---|
| General LM / chat / code | Decoder-only | GPT-4, Llama-3, Mistral, Claude |
| Machine translation (esp. many-to-many) | Encoder-decoder | NLLB, M2M-100, mBART |
| Summarization, infilling | Encoder-decoder | BART, Flan-T5, PEGASUS |
| Classification, retrieval, embeddings | Encoder-only | BERT, RoBERTa, E5, BGE |
| Multimodal captioning, speech | Encoder-decoder | Whisper, BLIP-2, Pix2Struct |
Recent research has revived encoder-decoder models for their efficiency on supervised seq2seq tasks: Flan-T5 and UL2 (Google 2022) both ship encoder-decoder checkpoints, a sign that the split is alive even as decoder-only dominates general chat.
Summary
The full Transformer is three stages: source → memory (encoder, once), target + memory → hidden (decoder, one step per output token at inference), hidden → logits (linear).
Training is one parallel forward pass over the whole target using teacher forcing + causal mask. Inference is an autoregressive loop; KV-caching (ch09) makes it affordable.
Source and target vocabularies can be separate (our Multi30k setup) or shared (joint BPE, common for same-script pairs). Weight tying $W_{\text{out}} = E_{\text{tgt}}$ removes a full embedding-sized matrix of parameters for free.
Modern taxonomy: encoder-decoder for seq2seq, decoder-only for general LLMs, encoder-only for classification/embeddings. Each shape exists because it fits a specific family of tasks.
Chapter Summary
Over the six sections of this chapter, we built the decoder from scratch:
Section 1 — Decoder Architecture Overview. The three-sublayer pattern (masked self-attn + cross-attn + FFN) and why the decoder exists.
Section 2 — Causal Masking. The triangular −∞ mask that lets teacher forcing pretend to be autoregressive.
Section 3 — Cross-Attention. Queries from the target side, keys and values from the encoder memory; how the decoder "looks at" the source.
Section 4 — Implementing TransformerDecoderLayer. The three sublayers wired together with Add&Norm, dropout placement, and the pre-norm vs post-norm debate.
Section 5 — Complete Transformer Decoder. Stacking N decoder layers with target embedding, positional encoding, and output projection. Weight tying and depth-scaling.
Section 6 — Full Encoder-Decoder Transformer (this section). Glue encoder + decoder into one module with a unified forward pass and a greedy generate loop.
You now have every piece of the Vaswani 2017 architecture implemented end-to-end. The rest of the book builds on this foundation: chapter 9 adds smarter decoding, chapter 10 adds training infrastructure, and chapter 13 walks through training this exact model on Multi30k until it translates real German sentences.
Exercises
Easy
1. Modify the Transformer constructor to accept separate num_enc_layers and num_dec_layers. Re-count the parameters for the Multi30k config with 4 encoder layers and 6 decoder layers.
Medium
2. Implement weight tying by setting `self.output_proj.weight = self.tgt_embed.weight`. Verify that training still runs (both tensors must have the same shape and dtype). Measure the parameter-count reduction and sanity-check the initial logits distribution.
3. Extend generate to return attention weights from the LAST decoder layer's cross-attention so you can visualize source↔target alignments after translation.
Hard
4. Replace the encoder-decoder with a decoder-only variant: concatenate src and tgt with a separator token and train next-token prediction on the concatenation. Measure translation quality on Multi30k — does the encoder pull its weight, or can a decoder-only beat it at this scale?
Next Section Preview
Chapter 9 picks up exactly where generate left off. We look at greedy decoding, beam search, sampling (temperature, top-k, top-p), and KV-caching: the collection of tricks that turn the naive $O(T^3)$ generation loop above into the real-time autoregressive systems powering modern LLMs.