In section 2 we built a bidirectional LSTM that maps a sentence to two logits. Right now its parameters are random; it is wrong on roughly half of the validation reviews. Training is the process that turns that untrained architecture into a useful sentiment classifier.
Scale-up from the previous sections. Sections 18.1 and 18.2 used tiny toy sizes so every value could be traced by hand. This section uses production-scale defaults. The recipe is identical; only the numbers change:
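| Setting | § 18.1 toy | § 18.2 toy | § 18.3 production |
| --- | --- | --- | --- |
| Vocab V | 7 | 6 | 5 000 |
| d_embed | 4 | 2 | 128 |
| d_hidden | n/a | 2 | 256 |
| Batch B | n/a | 1 | 64 |
| Seq length T | 6 | 4 | 200 |
| Bidirectional | n/a | no | yes |
| Classes | n/a | 2 | 2 |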
Every supervised learning problem — for an LSTM, a transformer, or a logistic regression — reduces to exactly one optimization:
$$\theta^\star = \arg\min_\theta L(\theta) = \arg\min_\theta \frac{1}{N}\sum_{i=1}^{N} \ell\big(f_\theta(x_i),\, y_i\big)$$
Read this equation word by word. θ is every learnable number in the network (embeddings, LSTM gate weights, linear head, biases — about 1.8 million of them for our classifier). fθ is the model as a function of those parameters. ℓ is a per-example loss — how wrong the model is on a single review. The full loss L averages that across all N training examples, and we want the θ that makes it as small as possible.
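To make the objective concrete, here is a minimal sketch that evaluates L(θ) for a few candidate parameter vectors and picks the smallest. The linear model, the four data points, and the three candidates are all assumptions made up for illustration; real training searches over θ with gradients, not by enumeration.

```python
import numpy as np

def per_example_loss(logit, label):
    """Binary cross-entropy for one example (an illustrative choice of the loss)."""
    p = 1.0 / (1.0 + np.exp(-logit))
    return -(label * np.log(p) + (1 - label) * np.log(1 - p))

def empirical_risk(theta, X, y):
    """L(theta) = (1/N) * sum_i loss(f_theta(x_i), y_i), with f_theta(x) = theta . x."""
    logits = X @ theta
    return float(np.mean([per_example_loss(z, t) for z, t in zip(logits, y)]))

# Hypothetical toy data: two positive and two negative examples.
X = np.array([[1.0, 0.5], [0.8, 1.0], [-1.0, -0.4], [-0.7, -1.2]])
y = np.array([1, 1, 0, 0])

# Three hand-picked candidates standing in for the search over theta.
candidates = [np.array([0.0, 0.0]), np.array([1.0, 1.0]), np.array([-1.0, -1.0])]
risks = [empirical_risk(theta, X, y) for theta in candidates]
best = candidates[int(np.argmin(risks))]

for theta, risk in zip(candidates, risks):
    print(theta, round(risk, 4))
print("argmin over candidates:", best)
```

The all-zero candidate scores ln 2, the loss of a coin flip, which is a useful baseline to keep in mind for the rest of the section.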
The central claim of deep learning. If you pick a flexible enough fθ, a good enough ℓ, and enough data, then repeatedly nudging θ opposite the gradient of L is sufficient to discover parameters that generalize to unseen inputs. No explicit feature engineering, no rulebook.
Nothing in this equation cares whether fθ is a two-parameter logistic regression or a 70-billion-parameter transformer. The machinery — dataset, forward pass, loss, backward pass, optimizer step — is identical. Scale is what changes, and we'll trace every one of those scaling pressures down to Flash Attention and the KV-cache by the end of this section.
The Dataset and Mini-Batches
Our dataset is the tokenised, padded reviews we prepared in Section 18.1 — say 25 000 labelled movie reviews split into 20 000 for training, 2 500 for validation, and 2 500 held out for the final test. In PyTorch this is two Dataset objects wrapped by DataLoaders that yield shuffled mini-batches.
Why mini-batches instead of feeding all 20 000 reviews at once? Three reasons, in decreasing order of how often they matter:
Memory. A single batch of 64 reviews at length 200 with a BiLSTM is already ~100 MB of activations. The full 20 000-review gradient would need hundreds of gigabytes. A GPU cannot hold it.
Noise is useful. A mini-batch gradient is a noisy estimate of the full-dataset gradient. The noise helps the optimizer escape narrow, sharp minima that do not generalize.
Parallelism. A batch of 64 reviews is processed as one batched matrix multiply, which keeps a modern GPU's parallel arithmetic units busy. Larger batches stop scaling once the layer is already at peak throughput.
| Split | Size | Purpose | Shuffled? | Gradient? |
| --- | --- | --- | --- | --- |
| train | 20 000 | drives parameter updates | every epoch | yes |
| val | 2 500 | detects overfitting, picks the best epoch | no | no |
| test | 2 500 | final, untouched, reported once | no | no |
The animation below shows one epoch through the DataLoader. Each step pulls a batch of 64 from the shuffled pool; once the pool is empty, the epoch ends, the pool refills, and a fresh permutation is drawn. Turning shuffle off reveals why the default is on: consecutive batches become correlated and the gradient updates lose their SGD character.
[Interactive animation: one epoch of shuffled mini-batches drawn from the DataLoader]
The test set must be touched exactly once, at the end. Every decision you make while looking at a metric — picking an epoch, tuning a learning rate, adjusting dropout — contaminates that metric. The validation set is where those decisions happen; the test set is the final, independent audit.
An epoch is one full pass of the DataLoader over the training set — 20 000 / 64 ≈ 313 steps of SGD per epoch. A typical sentiment run completes in 20-30 epochs, or roughly 7 000 parameter updates.
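The step count above can be checked with a small sketch of what a shuffling loader does. This is not PyTorch's DataLoader, only a NumPy stand-in; the function name `iterate_minibatches` is hypothetical.

```python
import numpy as np

def iterate_minibatches(n_examples, batch_size, rng, shuffle=True):
    """Yield index arrays that together cover every example exactly once."""
    order = rng.permutation(n_examples) if shuffle else np.arange(n_examples)
    for start in range(0, n_examples, batch_size):
        yield order[start:start + batch_size]

rng = np.random.default_rng(0)
batches = list(iterate_minibatches(20_000, 64, rng))

steps_per_epoch = len(batches)       # ceil(20000 / 64)
last_batch_size = len(batches[-1])   # the final batch is only partly full
print(steps_per_epoch, last_batch_size)
```

Because 20 000 is not a multiple of 64, the last batch holds the 32 leftover reviews. Calling the generator again with the same `rng` draws a fresh permutation, which is the per-epoch reshuffle described above.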
The Loss Function: Cross-Entropy
For a classification problem the standard per-example loss is cross-entropy. For a single sample with true class y∈{1,…,C} and the model's predicted distribution pθ(c∣x) it is:
$$\ell_{\mathrm{CE}}\big(f_\theta(x),\, y\big) = -\log p_\theta(y \mid x) = -\log \frac{e^{z_y}}{\sum_{c=1}^{C} e^{z_c}}$$
where zc is the raw logit the model produced for class c. For the binary sentiment case C=2 and we can equivalently write ℓBCE=−[ylogp+(1−y)log(1−p)] where p=σ(z1−z0).
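As a quick check of the equivalence just stated, the sketch below computes the softmax cross-entropy from two logits and compares it with the binary form using p = σ(z1 − z0). The logit values are hypothetical, and classes are indexed from 0 in the code. Subtracting the maximum logit before exponentiating is a standard stability trick and does not change the result.

```python
import numpy as np

def cross_entropy_from_logits(z, y):
    """-log softmax(z)[y], computed via a max-shifted log-sum-exp."""
    z = np.asarray(z, dtype=float)
    shifted = z - z.max()
    log_probs = shifted - np.log(np.exp(shifted).sum())
    return -log_probs[y]

def sigmoid(t):
    return 1.0 / (1.0 + np.exp(-t))

z = np.array([0.3, 1.1])   # hypothetical logits (z_0, z_1)
p = sigmoid(z[1] - z[0])   # probability of class 1 in the binary form

ce_pos = cross_entropy_from_logits(z, 1)   # true class is 1
ce_neg = cross_entropy_from_logits(z, 0)   # true class is 0
bce_pos = -np.log(p)
bce_neg = -np.log(1.0 - p)

print(ce_pos, bce_pos)
print(ce_neg, bce_neg)
```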
Why this particular loss?
Cross-entropy arises naturally from three different angles that all land on the same formula:
Maximum likelihood. Under the model's own distribution, the probability of observing the training labels is ∏ipθ(yi∣xi). Taking the negative log gives precisely the cross-entropy sum. Minimising cross-entropy is maximising the likelihood of the data.
Information theory. Cross-entropy H(p⋆,pθ) is the average number of bits needed to encode a sample from the true distribution p⋆ using a code built for the model's distribution pθ. It splits into two parts, H(p⋆,pθ)=H(p⋆)+KL(p⋆‖pθ): the irreducible entropy of the data plus the extra bits paid for the mismatch. When the two distributions match, the excess term drops to zero and H(p⋆,pθ)=H(p⋆).
Well-behaved gradients. Combined with softmax, its gradient collapses to the clean ∂ℓ/∂z_c = p_c − 𝟙[c=y]. No saturating derivatives, no vanishing gradients at wrong-but-confident predictions (the MSE gradient does vanish there).
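The information-theoretic reading in the list above can be verified numerically. The sketch below uses two made-up distributions over two classes and checks that cross-entropy equals entropy plus the KL divergence; the helper names are hypothetical.

```python
import numpy as np

def entropy_bits(p):
    return float(-np.sum(p * np.log2(p)))

def cross_entropy_bits(p, q):
    return float(-np.sum(p * np.log2(q)))

def kl_bits(p, q):
    return float(np.sum(p * np.log2(p / q)))

p_true = np.array([0.7, 0.3])    # hypothetical true label distribution
q_model = np.array([0.5, 0.5])   # hypothetical model that always says 50/50

h = entropy_bits(p_true)
h_cross = cross_entropy_bits(p_true, q_model)
excess = kl_bits(p_true, q_model)

print("entropy       :", round(h, 4))
print("cross-entropy :", round(h_cross, 4))
print("KL (excess)   :", round(excess, 4))
```

A model that matches the true distribution pays no excess: `cross_entropy_bits(p_true, p_true)` equals `entropy_bits(p_true)`.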
Why not MSE? Squared error can be minimised over a sigmoid, but its gradient (p−y)⋅p(1−p) shrinks to zero when the model is confidently wrong — exactly when you'd want the biggest correction. Cross-entropy's gradient p−y does not, so learning never stalls on confident mistakes.
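To see the two gradients side by side, the sketch below evaluates both formulas for a positive example while the logit becomes more and more confidently wrong. The logit values are chosen only for illustration, and the MSE loss is taken as 0.5 (p − y)² so that its gradient matches the expression above.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

y = 1.0                                       # true label
logits = np.array([0.0, -2.0, -5.0, -10.0])   # increasingly confident, and wrong
p = sigmoid(logits)

grad_mse = (p - y) * p * (1 - p)   # d/dz of 0.5 * (p - y)**2
grad_ce = p - y                    # d/dz of -[y log p + (1 - y) log(1 - p)]

for z, gm, gc in zip(logits, grad_mse, grad_ce):
    print(f"z={z:6.1f}  MSE grad={gm: .6f}  CE grad={gc: .6f}")
```

As the logit moves toward −10 the cross-entropy gradient approaches −1, while the MSE gradient shrinks toward zero because of the extra p(1 − p) factor.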
This gradient simplification is not a happy accident — it is the reason every classifier in this book, from logistic regression to transformer language models, uses the softmax/cross-entropy pair. We will exploit it directly in the NumPy trace below.
From SGD to Adam: Learning Rates That Adapt
Once we have L(θ) and its gradient, the update rule is deceptively simple:
$$\theta_{t+1} = \theta_t - \eta\, \nabla_\theta L(\theta_t)$$
This is stochastic gradient descent (SGD). A single scalar learning rate η controls step size for every parameter. It works, but on real networks it is fragile: one learning rate has to simultaneously suit parameters whose gradients vary by many orders of magnitude. The embedding matrix, the LSTM gate weights and the final bias all live on different scales.
Momentum — remember the past
First fix: accumulate an exponential moving average of the gradients themselves.
$$m_t = \beta_1 m_{t-1} + (1-\beta_1)\, g_t, \qquad \theta_{t+1} = \theta_t - \eta\, m_t$$
Now consistent gradient directions build up speed, while jittery coordinates (where gradients flip sign each step) average toward zero. This is the ball-rolling-down-a-hill intuition.
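The averaging effect is easy to reproduce. In the sketch below, a made-up gradient stream has one coordinate that always points the same way and one that flips sign every step; only the moving average from the momentum equation is computed, with no parameter update.

```python
import numpy as np

beta1 = 0.9
m = np.zeros(2)

for t in range(50):
    # coordinate 0: consistent direction; coordinate 1: sign flips each step
    g = np.array([1.0, (-1.0) ** t])
    m = beta1 * m + (1 - beta1) * g

print("momentum after 50 steps:", np.round(m, 4))
```

The consistent coordinate climbs toward 1, while the jittery one settles near 0.05 in magnitude, roughly (1 − β1)/(1 + β1) of the raw gradient size.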
Per-parameter scaling — Adam
Adam (Kingma & Ba, 2014) adds a second EMA, this time of squared gradients, and uses it to scale each parameter's step individually:
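$$m_t = \beta_1 m_{t-1} + (1-\beta_1)\, g_t, \qquad v_t = \beta_2 v_{t-1} + (1-\beta_2)\, g_t^2$$
$$\hat{m}_t = \frac{m_t}{1-\beta_1^t}, \qquad \hat{v}_t = \frac{v_t}{1-\beta_2^t}$$
$$\theta_{t+1} = \theta_t - \eta\, \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \varepsilon}$$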
Where do these weird terms come from?
mt — momentum, the mean direction of recent gradients. Typical β1=0.9 gives an effective memory of about 10 steps.
vt — the mean-squared gradient per parameter. Large v means the parameter has been receiving big gradients; Adam shrinks its step. Small v means the parameter has been receiving tiny gradients; Adam grows its step.
$\hat{m}, \hat{v}$: bias-corrected versions. Since $m_0 = v_0 = 0$, the raw EMAs are biased toward zero in early steps; dividing by $1-\beta^t$ undoes this.
$\varepsilon \approx 10^{-8}$: prevents division by zero when a parameter has received no gradient yet.
The effective step size for each parameter is $\eta/\sqrt{\hat{v}_t}$: small-gradient parameters get bigger steps, large-gradient parameters get smaller steps. Adam is the default optimizer in much of modern deep learning precisely because this per-parameter normalization makes the single global learning rate robust.
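The update rule translates almost line for line into NumPy. The sketch below is a minimal, assumption-laden version: the function name `adam_step` is hypothetical, and the defaults (η = 1e-3, β1 = 0.9, β2 = 0.999, ε = 1e-8) are the commonly quoted ones, not values taken from this section's training run.

```python
import numpy as np

def adam_step(theta, g, m, v, t, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    """One Adam update. t is the 1-based step count used for bias correction."""
    m = beta1 * m + (1 - beta1) * g          # EMA of gradients
    v = beta2 * v + (1 - beta2) * g ** 2     # EMA of squared gradients
    m_hat = m / (1 - beta1 ** t)             # undo the bias toward zero
    v_hat = v / (1 - beta2 ** t)
    theta = theta - lr * m_hat / (np.sqrt(v_hat) + eps)
    return theta, m, v

theta = np.zeros(2)
m, v = np.zeros(2), np.zeros(2)

# Two parameters whose gradients differ by five orders of magnitude.
g = np.array([100.0, 0.001])
theta, m, v = adam_step(theta, g, m, v, t=1)
print("raw m after one step :", m)        # only 10% of g without bias correction
print("parameter change     :", theta)    # close to -lr in both coordinates
```

On the first step the bias-corrected moments equal g and g², so both parameters move by almost exactly the learning rate even though their gradients are wildly different in scale.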
Interactive: SGD vs Adam on a Narrow Valley
To see why Adam's per-parameter scaling matters, here is a 3D loss landscape with two optimizers dropped at the same starting point. The loss is $L(w_1, w_2) = \tfrac{1}{2}\,(w_1^2 + 9\,w_2^2)$, a textbook narrow valley: nine times steeper along $w_2$ than along $w_1$. Real loss surfaces have many such direction-dependent curvatures; this one simply isolates the phenomenon.
[Interactive 3D visualizer: SGD (red) and Adam (blue) on the narrow-valley loss surface]
Click Play. The red ball (SGD, fixed learning rate) is caught immediately by the steep w2 direction: its step is too large for that curvature, so it overshoots, lands on the opposite wall, overshoots again, and zig-zags its way toward the minimum. Meanwhile it makes only slow progress along the gentler w1 direction.
The blue ball (Adam) tracks $v_t$ separately for $w_1$ and $w_2$. On $w_2$ the running squared gradient is large, so the effective step $\eta/\sqrt{\hat{v}}$ shrinks and there is no overshoot. On $w_1$ it stays small, so Adam pushes harder. The result is a smooth, almost straight trajectory to the minimum.
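The same experiment can be run without the visualizer. The sketch below is not the code behind the interactive widget; the starting point, the step counts, and both learning rates are assumptions chosen so that SGD's zig-zag along w2 is easy to see.

```python
import numpy as np

def loss(w):
    return 0.5 * (w[0] ** 2 + 9.0 * w[1] ** 2)

def grad(w):
    return np.array([w[0], 9.0 * w[1]])

w_start = np.array([-4.0, 1.5])   # hypothetical starting point

# Plain SGD: one global learning rate for both coordinates.
lr_sgd = 0.2
w = w_start.copy()
sgd_path = [w.copy()]
for _ in range(30):
    w = w - lr_sgd * grad(w)
    sgd_path.append(w.copy())
sgd_path = np.array(sgd_path)

# Adam: each coordinate is scaled by its own running squared gradient.
lr_adam, beta1, beta2, eps = 0.1, 0.9, 0.999, 1e-8
w = w_start.copy()
m, v = np.zeros(2), np.zeros(2)
adam_path = [w.copy()]
for t in range(1, 201):
    g = grad(w)
    m = beta1 * m + (1 - beta1) * g
    v = beta2 * v + (1 - beta2) * g ** 2
    m_hat = m / (1 - beta1 ** t)
    v_hat = v / (1 - beta2 ** t)
    w = w - lr_adam * m_hat / (np.sqrt(v_hat) + eps)
    adam_path.append(w.copy())
adam_path = np.array(adam_path)

print("SGD  w2, first 5 points:", np.round(sgd_path[:5, 1], 3))
print("Adam first step        :", np.round(adam_path[1] - adam_path[0], 3))
print("final loss  SGD:", loss(sgd_path[-1]), " Adam:", loss(adam_path[-1]))
```

With this learning rate SGD multiplies w2 by −0.8 on every step, so it jumps from wall to wall, while Adam's first step has the same magnitude along both axes despite the ninefold difference in curvature.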
This is why scaling laws papers for transformers almost universally use Adam (or its cousins AdamW, Lion, Adafactor). Transformer loss landscapes have extremely ill-conditioned curvature across their millions of parameters — a single global learning rate would be unusable without per-parameter adaptation.
The Five-Step Training Loop
Every training step — every one of the ~7 000 weight updates our LSTM will perform — is exactly the same five operations. Learn this rhythm once and it applies to every model you will ever train.
1. Forward pass. $\hat{y} = f_\theta(x)$: feed the batch through the model and record every intermediate activation.
2. Loss. $\ell = L(\hat{y}, y)$: a single scalar tensor measuring current wrongness.
3. Backward pass. $\nabla_\theta \ell$: the chain rule in reverse mode, filling every parameter's .grad attribute.
4. Optimizer step. $\theta \leftarrow \mathrm{Update}(\theta, \nabla_\theta \ell)$: Adam applies its moment-corrected rule.
5. Zero the gradients. optimizer.zero_grad(): clear the accumulators so the next .backward() starts from a clean slate.
Periodically — usually once per epoch — we pause training and run the full validation set through the forward pass only, recording loss and accuracy. That gives us the two curves in the visualization below.
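The rhythm of the five steps, plus the once-per-epoch validation pass, can be sketched in plain NumPy. Everything here is a stand-in chosen for brevity: a two-feature logistic regression replaces the LSTM, the data is synthetic, plain SGD replaces Adam, and the gradient buffers are accumulated and cleared by hand to mimic what .backward() and zero_grad() do in PyTorch.

```python
import numpy as np

rng = np.random.default_rng(0)

# Synthetic data (an assumption for illustration): label is 1 when x0 + x1 > 0.
X = rng.normal(size=(400, 2))
y = (X[:, 0] + X[:, 1] > 0).astype(float)
X_train, y_train = X[:320], y[:320]
X_val, y_val = X[320:], y[320:]

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def bce(p, t, eps=1e-9):
    return float(-np.mean(t * np.log(p + eps) + (1 - t) * np.log(1 - p + eps)))

w, b = np.zeros(2), 0.0
grad_w, grad_b = np.zeros(2), 0.0    # gradient accumulators
lr, batch_size = 0.5, 32
val_losses = []

for epoch in range(5):
    order = rng.permutation(len(X_train))          # reshuffle every epoch
    for start in range(0, len(order), batch_size):
        idx = order[start:start + batch_size]
        xb, yb = X_train[idx], y_train[idx]

        p = sigmoid(xb @ w + b)                    # 1. forward pass
        loss = bce(p, yb)                          # 2. loss
        grad_z = (p - yb) / len(yb)                # 3. backward pass (accumulates)
        grad_w += xb.T @ grad_z
        grad_b += grad_z.sum()
        w = w - lr * grad_w                        # 4. optimizer step (plain SGD)
        b = b - lr * grad_b
        grad_w[:] = 0.0                            # 5. zero the gradients
        grad_b = 0.0

    # Validation: forward pass only, no parameter update.
    p_val = sigmoid(X_val @ w + b)
    val_losses.append(bce(p_val, y_val))
    val_acc = float(np.mean((p_val > 0.5) == (y_val == 1)))
    print(f"epoch {epoch}: val loss {val_losses[-1]:.4f}  val acc {val_acc:.3f}")
```

Skipping step 5 in this sketch would make each update use the sum of all gradients seen so far, which is why the zeroing step exists.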
Tracing LSTM Backprop in NumPy
The logistic-regression trace in the next section is the cleanest way to see a gradient derivation on one page, but you also need to know that the LSTM step from section 18.2 has a concrete backward pass. Below is a single cell traced line-by-line — forward, then backward — so you can see that ∂L/∂ct passes straight through the additive identity (the reason gradients do not vanish) while each gate's weight gradient comes from its own sigmoid or tanh derivative.
One LSTM step — forward and backward, fully traced
lstm_backward.py
1import numpy as np
NumPy is the numerical-array library we rely on. For this trace it gives us element-wise +/-/*, matrix-multiply via @, np.tanh, np.exp (inside sigmoid), np.outer (for the rank-1 weight-gradient update), and np.concatenate (to stitch the four gate-gradient chunks into one vector).
EXECUTION STATE
numpy = Array library. Every elementwise op below is vectorised C, no Python loops.
3# --- From section 18.2: a single LSTM step, plus its gradient ---
Header comment. We re-use the exact toy from section 18.2 (d_embed = d_hidden = 2, four gates stacked into an 8-row weight matrix) and trace ONE step forward and then the corresponding gradient backward.
4d_embed, d_hidden = 2, 2
Tiny dimensions so every intermediate tensor fits on a single line. In production d_embed=128 and d_hidden=256 (the sizes the PyTorch block trains with), but the math is identical.
EXECUTION STATE
d_embed = 2 = Dimension of each token's embedding vector.
d_hidden = 2 = Dimension of the LSTM's hidden and cell state.
6# Inputs at time t (one step only, to keep values on one screen)
We are going to simulate what happens at a single timestep t. The three inputs are: x_t (the new embedding just read), h_prev (the hidden state produced by step t-1), c_prev (the cell state carried forward from step t-1).
7x_t = np.array([0.40, 0.30]) # embedding for "movie"
The embedding vector for the current token. Pretend this came out of nn.Embedding at step t. Arbitrary but fixed so every calculation is reproducible.
EXECUTION STATE
📚 np.array(list) = Creates a dense ndarray from a Python list. Dtype is float64 for a list of floats.
⬆ x_t (2,) = [0.40, 0.30]
8h_prev = np.array([0.10, -0.05])
The hidden state produced by the previous step. In a real model this would be h_{t-1} returned by the previous LSTM cell invocation.
EXECUTION STATE
⬆ h_prev (2,) = [0.10, -0.05]
9c_prev = np.array([0.20, 0.10])
The previous cell state. This is the memory that survives through the forget gate — the quantity whose additive update gives the LSTM its gradient-preserving highway.
EXECUTION STATE
⬆ c_prev (2,) = [0.20, 0.10]
11# Gate pre-activation matrices (reused from section 18.2)
Comment pointing back to the section-18.2 weights. W_x stacks the four gates' input-hidden weights, W_h their hidden-hidden weights, and b the biases. Zeroing W_h and b isolates the contribution of the current input for clarity.
12W_x = np.array([...])
The 8 × 2 weight matrix mapping the embedding vector x_t into the four gate pre-activations. Rows 0-1 produce a_i (input gate), 2-3 a_f (forget gate), 4-5 a_g (cell candidate), 6-7 a_o (output gate).
→ why 4*d_hidden rows? = One block of d_hidden rows per gate. Stacking them lets us compute all four pre-activations in a single matmul — PyTorch's nn.LSTM does exactly the same trick with weight_ih_l0.
13[ 0.5, 0.3], [ 0.1, -0.2], # input gate
Rows 0 and 1 of W_x. Together they drive a_i — the 2-dim input-gate pre-activation.
14[ 0.4, 0.6], [ 0.2, 0.1], # forget gate
Rows 2 and 3. These drive a_f — the forget gate's pre-activation.
15[ 0.7, 0.5], [ 0.3, 0.4], # cell candidate
Rows 4 and 5. These produce a_g — the pre-activation of the cell candidate g.
16[ 0.2, -0.1], [ 0.0, 0.3], # output gate
Rows 6 and 7. These drive a_o — the output gate's pre-activation.
17])
Closing bracket. W_x is now a single (8, 2) ndarray held in one contiguous block of memory.
18W_h = np.zeros((8, 2)) # simplified: recurrent weights zeroed
Zero the recurrent weights so the trace's backward pass is easy to read. A real LSTM has W_h drawn from small-random init; it is restored in the PyTorch block below.
EXECUTION STATE
📚 np.zeros((8, 2)) = Creates an (8, 2) ndarray filled with 0.0, dtype float64. Shape is passed as a tuple.
⬆ W_h (8, 2) =
All zeros.
19b = np.zeros(8)
Bias vector for the eight pre-activations. Zero so the only signal reaching the gates is W_x · x_t.
EXECUTION STATE
⬆ b (8,) = [0, 0, 0, 0, 0, 0, 0, 0]
21def sigmoid(z): return 1.0 / (1.0 + np.exp(-z))
The logistic sigmoid — squashes any real number into (0, 1). We will use it for the input / forget / output gates. Its derivative sigmoid(z) · (1 − sigmoid(z)) shows up in the backward pass.
EXECUTION STATE
📚 np.exp(x) = Element-wise e^x. Returns a new array of the same shape. Example: np.exp([-0.29]) = array([0.7483]).
⬆ returns = An ndarray of the same shape as z, each element in (0, 1).
23# --- Forward ---
Section header. Everything between here and the loss's gradient is the same four-gate recipe used in section 18.2.
24a = W_x @ x_t + W_h @ h_prev + b
Compute all eight gate pre-activations in one matmul. Because W_h and b are zero, this collapses to W_x @ x_t — a pure function of the current embedding.
EXECUTION STATE
📚 @ operator = NumPy matrix multiply. For (8, 2) @ (2,) → (8,): each row of W_x is dot-producted with x_t.
25i_g = sigmoid(a[0:2])
Slice the first two entries of a (the input-gate chunk) and squash through sigmoid. Result lives in (0, 1) and gates how much of the cell candidate to write.
26f_g = sigmoid(a[2:4])
Forget-gate activation. Same squash, applied to the next two-element chunk of a. Controls how much of c_prev survives.
EXECUTION STATE
⬇ a[2:4] = [0.34, 0.11]
⬆ f_g (2,) = [0.5842, 0.5275]
27g_c = np.tanh(a[4:6])
The cell-candidate activation. tanh is used here (not sigmoid) because the new content should be allowed to be either positive or negative.
EXECUTION STATE
📚 np.tanh(x) = Element-wise hyperbolic tangent, range (-1, 1). Derivative is 1 - tanh^2(x), used in the backward pass.
⬇ a[4:6] = [0.43, 0.24]
⬆ g_c (2,) = [0.4053, 0.2355]
28o_g = sigmoid(a[6:8])
Output-gate activation. Controls how much of the (tanh-squashed) cell state leaks out as the hidden state h_t.
EXECUTION STATE
⬇ a[6:8] = [0.05, 0.09]
⬆ o_g (2,) = [0.5125, 0.5225]
29c_t = f_g * c_prev + i_g * g_c
The cell-state update — the additive identity that keeps gradients alive across many steps. Elementwise: forget-fraction × old memory + input-fraction × new content.
→ why this line matters = This is the one arithmetic operation in the entire LSTM whose Jacobian with respect to c_prev is the diagonal of f_g — a number in (0,1). That avoids the exploding / vanishing chain you get when you multiply many arbitrary matrices together in an RNN.
30tanh_c = np.tanh(c_t)
Squash the raw cell state through tanh before emitting it as the hidden state. Keeps h_t bounded in (-1, 1) regardless of how large c_t has grown.
EXECUTION STATE
⬆ tanh_c (2,) = [0.3352, 0.1677]
31h_t = o_g * tanh_c
The output: multiply the output-gate mask elementwise into the tanh-squashed cell state. This h_t is what the next timestep consumes and what the classifier head will read at t = T-1.
EXECUTION STATE
→ channel 0 = h_t[0] = 0.5125 * 0.3352 = 0.1718
→ channel 1 = h_t[1] = 0.5225 * 0.1677 = 0.0876
⬆ h_t (2,) = [0.1718, 0.0876]
33# --- Upstream gradient: pretend dL/dh_t comes from the next layer ---
In a real training step, dL/dh_t is produced by whatever sits above the LSTM — the pooling stage or the classifier head. For this trace we hand-pick a plausible-looking upstream gradient so you can see how it flows backward into W_x, b, and h_prev.
34dL_dh = np.array([0.10, -0.20])
The upstream gradient. Positive in channel 0 means 'the loss wants h_t[0] to grow'; negative in channel 1 means 'the loss wants h_t[1] to shrink'.
EXECUTION STATE
⬆ dL_dh (2,) = [0.10, -0.20]
36# --- Manual backward through the LSTM cell ---
Section header. Now the chain rule: walk back through every op we did on the forward pass, multiplying by each op's local derivative.
38dc_t = dL_dh * o_g * (1 - tanh_c**2) # dL/dc = dL/dh * o * (1-tanh^2(c))
This is the critical line. h_t = o_g * tanh(c_t), so partial h_t / partial c_t = o_g * (1 − tanh²(c_t)). Notice the additive structure: dc_t can flow straight through to c_prev without ever being multiplied by a matrix. That is the gradient-preserving highway Hochreiter & Schmidhuber designed.
EXECUTION STATE
📚 ** operator on ndarray = Elementwise power. tanh_c**2 squares every element. NumPy overloads **, +, - for elementwise math on arrays.
43# Through each gate's activation (sigmoid/tanh derivatives)
Now we back-propagate through the gate non-linearities. Two identities do all the work: d sigmoid(a)/da = sigmoid(a)(1 - sigmoid(a)), and d tanh(a)/da = 1 - tanh²(a).
44da_i = di_g * i_g * (1 - i_g)
Since i_g = sigmoid(a_i), its local derivative is i_g * (1 - i_g). Multiply by the incoming gradient di_g.
48da = np.concatenate([da_i, da_f, da_g, da_o]) # (8,)
Stitch the four 2-dim gate gradients back into a single 8-dim vector that lines up with the rows of W_x. Now we can propagate through the one remaining op: the W_x @ x_t matmul.
EXECUTION STATE
📚 np.concatenate(list, axis=0) = Joins a sequence of ndarrays along an axis (default 0). Shapes must agree on every axis except the one being concatenated. Here: four (2,) chunks → one (8,) vector.
50# Gradient to W_x: outer product of da and x_t
For a linear map y = W x, the chain rule gives dL/dW = (dL/dy) ⊗ xᵀ, a rank-one matrix whose (i, j) entry is dL/dy_i · x_j. This is the last step we need to update W_x.
51dW_x = np.outer(da, x_t) # shape (8, 2)
Compute every element of dL/dW_x in one line. np.outer builds the (8, 2) rank-one matrix whose row i is da[i] · x_t.
EXECUTION STATE
📚 np.outer(a, b) = Outer product: for 1-D inputs of lengths m and n, returns an (m, n) matrix with [i, j] = a[i] * b[j]. Example: np.outer([1,2], [3,4]) = [[3,4],[6,8]].
→ update rule = W_x ← W_x - lr * dW_x. With lr = 0.1 and a typical batch-averaged gradient, every one of the 16 W_x entries moves in whichever direction reduces L for this batch.
52db = da # shape (8,)
For a linear layer y = W x + b, dL/db = dL/dy. No work to do: the bias gradient is just da.
53dh_prev = W_h.T @ da # shape (2,)
The gradient that flows back to the previous timestep's hidden state. Since in this toy W_h is all zeros, dh_prev is also all zeros, so the gradient is cut off here. In a real LSTM W_h has non-trivial entries and dh_prev feeds the step-(t-1) backward pass.
EXECUTION STATE
📚 .T = Transpose. For a (8, 2) matrix, .T gives (2, 8). Then matmul with da (8,) yields a (2,) vector.
⬆ dh_prev (2,) = [0.0000, 0.0000]
→ dc_prev too = Similarly, dc_prev = dc_t * f_g = [0.0266, -0.0536]. It is implicit here — the real training loop threads this back into the step-(t-1) backward pass.
55print("h_t :", np.round(h_t, 4))
Forward-pass sanity print. Rounds h_t to 4 decimals for readability.
EXECUTION STATE
⬆ stdout = h_t : [0.1718 0.0876]
56print("c_t :", np.round(c_t, 4))
Inspecting c_t confirms the forget-gate scaling plus the cell-candidate addition landed where we expect.
59print("dh_prev :", np.round(dh_prev, 4))
Gradient flowing back into step t-1. Zero here only because W_h was zeroed. Swap W_h for an actual init and this vector becomes the upstream gradient of the previous step's backward pass; that is how gradients travel across long time chains.
EXECUTION STATE
⬆ stdout = dh_prev : [0. 0.]
Full listing (lstm_backward.py):

import numpy as np

# --- From section 18.2: a single LSTM step, plus its gradient ---
d_embed, d_hidden = 2, 2

# Inputs at time t (one step only, to keep values on one screen)
x_t = np.array([0.40, 0.30])      # embedding for "movie"
h_prev = np.array([0.10, -0.05])
c_prev = np.array([0.20, 0.10])

# Gate pre-activation matrices (reused from section 18.2)
W_x = np.array([
    [ 0.5,  0.3], [ 0.1, -0.2],   # input gate
    [ 0.4,  0.6], [ 0.2,  0.1],   # forget gate
    [ 0.7,  0.5], [ 0.3,  0.4],   # cell candidate
    [ 0.2, -0.1], [ 0.0,  0.3],   # output gate
])
W_h = np.zeros((8, 2))
b = np.zeros(8)

def sigmoid(z): return 1.0 / (1.0 + np.exp(-z))

# --- Forward ---
a = W_x @ x_t + W_h @ h_prev + b
i_g = sigmoid(a[0:2])
f_g = sigmoid(a[2:4])
g_c = np.tanh(a[4:6])
o_g = sigmoid(a[6:8])
c_t = f_g * c_prev + i_g * g_c
tanh_c = np.tanh(c_t)
h_t = o_g * tanh_c

# --- Upstream gradient: pretend dL/dh_t comes from the next layer ---
dL_dh = np.array([0.10, -0.20])

# --- Manual backward through the LSTM cell ---
do_g = dL_dh * tanh_c                   # dL/do = dL/dh * tanh(c)
dc_t = dL_dh * o_g * (1 - tanh_c**2)    # dL/dc = dL/dh * o * (1-tanh^2(c))
df_g = dc_t * c_prev                    # dL/df
di_g = dc_t * g_c                       # dL/di
dg_c = dc_t * i_g                       # dL/dg

# Through each gate's activation (sigmoid/tanh derivatives)
da_i = di_g * i_g * (1 - i_g)
da_f = df_g * f_g * (1 - f_g)
da_g = dg_c * (1 - g_c**2)
da_o = do_g * o_g * (1 - o_g)
da = np.concatenate([da_i, da_f, da_g, da_o])   # (8,)

# Gradient to W_x: outer product of da and x_t
dW_x = np.outer(da, x_t)   # shape (8, 2)
db = da                    # shape (8,)
dh_prev = W_h.T @ da       # shape (2,)

print("h_t :", np.round(h_t, 4))
print("c_t :", np.round(c_t, 4))
print("dW_x[0,:] :", np.round(dW_x[0], 4))   # input-gate row 0
print("db :", np.round(db, 4))
print("dh_prev :", np.round(dh_prev, 4))
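A hand-derived backward pass is only trustworthy once it agrees with a numerical estimate. The sketch below re-runs the same toy cell and compares the analytic dW_x against central finite differences. It assumes a surrogate scalar loss L = dL_dh · h_t, chosen only because its gradient with respect to h_t is exactly the hand-picked upstream gradient; the helper names `forward` and `surrogate_loss` are hypothetical. W_h and b are omitted because they are zero in the trace.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

x_t = np.array([0.40, 0.30])
c_prev = np.array([0.20, 0.10])
dL_dh = np.array([0.10, -0.20])
W_x = np.array([
    [0.5, 0.3], [0.1, -0.2],    # input gate
    [0.4, 0.6], [0.2, 0.1],     # forget gate
    [0.7, 0.5], [0.3, 0.4],     # cell candidate
    [0.2, -0.1], [0.0, 0.3],    # output gate
])

def forward(W):
    a = W @ x_t
    i_g, f_g = sigmoid(a[0:2]), sigmoid(a[2:4])
    g_c, o_g = np.tanh(a[4:6]), sigmoid(a[6:8])
    c_t = f_g * c_prev + i_g * g_c
    tanh_c = np.tanh(c_t)
    return i_g, f_g, g_c, o_g, tanh_c, o_g * tanh_c

def surrogate_loss(W):
    # Scalar whose gradient with respect to h_t is exactly dL_dh.
    return float(dL_dh @ forward(W)[-1])

# Analytic gradient, same formulas as the trace.
i_g, f_g, g_c, o_g, tanh_c, h_t = forward(W_x)
do_g = dL_dh * tanh_c
dc_t = dL_dh * o_g * (1 - tanh_c ** 2)
da = np.concatenate([
    dc_t * g_c * i_g * (1 - i_g),        # input gate
    dc_t * c_prev * f_g * (1 - f_g),     # forget gate
    dc_t * i_g * (1 - g_c ** 2),         # cell candidate
    do_g * o_g * (1 - o_g),              # output gate
])
dW_analytic = np.outer(da, x_t)

# Numerical gradient by central differences, one weight at a time.
h = 1e-6
dW_numeric = np.zeros_like(W_x)
for r in range(W_x.shape[0]):
    for c in range(W_x.shape[1]):
        W_plus, W_minus = W_x.copy(), W_x.copy()
        W_plus[r, c] += h
        W_minus[r, c] -= h
        dW_numeric[r, c] = (surrogate_loss(W_plus) - surrogate_loss(W_minus)) / (2 * h)

max_err = float(np.abs(dW_analytic - dW_numeric).max())
print("max |analytic - numeric| =", max_err)
```

The same check works for any hand-written backward pass: perturb one parameter, re-run the forward pass, and compare the slope with the analytic value.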
One Training Step in Pure NumPy
Before we trust PyTorch's autograd, let us do every operation by hand. The model below is the simplest supervised learner there is — binary logistic regression with two features — but every line maps one-to-one onto what happens inside the LSTM's training step: a forward pass, a loss, an analytic gradient, and a weight update.
One SGD Step — Fully Traced
train_step.py
1import numpy as np
NumPy provides ndarrays, the @ matmul operator, broadcasting, and vectorized math (exp, log, mean). Every operation below — the forward pass, the gradient, and the weight update — runs as optimized C code rather than slow Python loops.
as np = Universal alias so we write np.exp(), np.log(), np.mean(), np.array() throughout.
3# Section — four labeled training examples
We use a tiny 2-feature dataset so every number in the forward and backward pass can be traced by hand. Two examples are positive (y = 1), two are negative (y = 0).
4X = np.array([ ... ])
The design matrix X stacks all training examples as rows. Shape (N=4, d=2) means 4 samples, each a 2-dimensional feature vector. Think of these features as the output of an embedding layer or any earlier neural layer.
EXECUTION STATE
📚 np.array() = Creates an ndarray from a Python list-of-lists. The outer list becomes rows, each inner list becomes one row's columns.
-> purpose = Each row is one review's feature vector. Positive reviews have positive features; negative reviews have negative features. Logistic regression must discover this pattern.
5[ 0.5, 0.2 ] — 'great'
Feature vector for example 0: a positive review. Both features are positive. The model should learn to produce a high probability of class 1 here.
EXECUTION STATE
x0 = [ 0.50, 0.20 ]
y0 = 1 (positive)
6[ 0.8, 0.4 ] — 'amazing'
Feature vector for example 1: another positive review, features even more positive than example 0. This row is 'easy' — the model should be most confident here.
EXECUTION STATE
x1 = [ 0.80, 0.40 ]
y1 = 1 (positive)
7[ -0.3, -0.5 ] — 'awful'
Negative example. Both features negative, so after w·x the logit should be negative, pushing the sigmoid below 0.5 toward class 0.
EXECUTION STATE
x2 = [ -0.30, -0.50 ]
y2 = 0 (negative)
8[ -0.6, -0.2 ] — 'boring'
Another negative example. Features point the same direction as example 2, so the learned weights should produce negative logits for both.
EXECUTION STATE
x3 = [ -0.60, -0.20 ]
y3 = 0 (negative)
9]) — close the array
Closes the np.array() call. Result: X is now a single (4, 2) ndarray held in one contiguous memory block.
10y = np.array([1, 1, 0, 0])
The target labels. Same ordering as X's rows. We use 1 for positive and 0 for negative — this is what lets the sigmoid output be directly interpretable as P(y=1|x).
EXECUTION STATE
⬆ y (4,) = [1, 1, 0, 0]
-> purpose = Tells the loss function what the correct answer is for each row of X. Every supervised learning step needs these labels.
12# Section — model parameters (logistic regression)
The whole 'model' is just a weight vector w and a bias scalar b. Prediction: p(x) = sigmoid(w·x + b). Training = find w, b that make p(x_i) close to y_i for every i.
13w = np.array([0.1, -0.1])
Initial weight vector, one weight per feature. In real networks weights are initialized from a small random distribution (e.g., He or Xavier); we pick these tiny fixed values so the math is easy to follow.
EXECUTION STATE
⬆ w (2,) = [ 0.10, -0.10 ]
-> w[0] = How strongly feature 0 pushes the logit up (if positive) or down (if negative).
-> w[1] = Same for feature 1. Initially -0.1, meaning positive feature 1 values push the logit slightly toward class 0.
14b = 0.0
Bias — a constant offset added to every example's logit. Starts at 0 so there is no prior preference for either class.
EXECUTION STATE
b = 0.0
-> why bias? = Lets the decision boundary shift off the origin. Without a bias, the model can only draw lines through (0,0) in feature space.
15lr = 0.5
Learning rate — how big a step we take down the gradient. 0.5 is large for a real network but fine here because the loss surface is extremely well-conditioned with only 2 parameters.
EXECUTION STATE
lr = 0.5
-> too small = Training is slow; may get stuck in plateaus.
-> too large = Updates overshoot the minimum and the loss oscillates or diverges.
17def sigmoid(z) -> np.ndarray
Elementwise sigmoid activation. Squashes any real number into (0, 1), giving it a direct interpretation as a probability. Used here to convert raw logits into class-1 probabilities.
EXECUTION STATE
⬇ input: z = An ndarray of raw logits (any real numbers). Will be computed by z = X @ w + b.
⬆ returns = An ndarray of same shape as z, each element in (0, 1).
18"""sigma(z) = 1 / (1 + e^-z) - squashes any real number into (0, 1)."""
The canonical sigmoid formula. e is Euler's number ≈ 2.71828. Note that sigmoid(-z) = 1 - sigmoid(z), a useful symmetry for the BCE derivation below.
19return 1.0 / (1.0 + np.exp(-z))
Element-wise implementation. np.exp broadcasts over the array; the division and addition broadcast too. No loop, no matrix tricks — just the formula.
EXECUTION STATE
📚 np.exp() = NumPy's elementwise e^x. For every element z_i in z, returns exp(-z_i). Example: np.exp(-0.04) = 0.9608.
⬆ return: sigmoid(z) = For z = [0.03, 0.04, 0.02, -0.04]:
sigmoid = [0.5075, 0.5100, 0.5050, 0.4900]
21# Section — FORWARD PASS
The forward pass computes predictions and the loss. For a single SGD step we need: (1) logits z, (2) probabilities p = sigmoid(z), (3) loss = average BCE over the batch.
22z = X @ w + b
Compute logits for all 4 examples simultaneously via matrix multiplication. Broadcasting adds the scalar bias b to every element.
EXECUTION STATE
📚 @ operator = NumPy matrix multiply. For (4,2) @ (2,) -> (4,): each row of X is dot-producted with w.
23p = sigmoid(z)
Squash every logit into (0, 1). These are now the model's current probabilities of class 1 for each example. Close to 0.5 everywhere because the untrained weights are tiny.
-> interpretation = p[0] = 0.5075 means the model is 50.75% sure example 0 is positive. Target is 1, so there is room to push this much higher.
25# Comment — BCE loss averaged over the batch
The binary cross-entropy loss measures the distance between predicted probability p and true label y. Averaging over the batch gives us a scalar to minimize.
26eps = 1e-9
A tiny floor to avoid log(0) when p is exactly 0 or 1. Without eps, a confident wrong prediction would produce log(0) = -infinity and propagate NaNs everywhere.
EXECUTION STATE
eps = 1e-9 (= 0.000000001)
-> why 1e-9? = Small enough not to bias the loss meaningfully, large enough that float32 can represent it without underflow.
27loss = -np.mean(y * np.log(p + eps) + (1 - y) * np.log(1 - p + eps))
Binary cross-entropy per sample: -[y·log(p) + (1-y)·log(1-p)]. Averaged across the batch, this is one scalar that tells us how wrong the model currently is.
EXECUTION STATE
📚 np.log() = Elementwise natural logarithm. log(x) is negative for x in (0, 1), zero at x = 1. Example: np.log(0.5075) = -0.6783.
📚 np.mean() = Average of all elements. For a 1-D array of length n, sum(x) / n.
-> benchmark = Random prediction p = 0.5 gives loss = ln 2 ≈ 0.693. We're only slightly better than random — unsurprising with untrained weights.
29# Section — BACKWARD PASS (analytic gradient)
We need dLoss/dw and dLoss/db. Because BCE + sigmoid combine into a beautifully simple gradient, we can compute them in 3 lines instead of running autograd.
30# Comment — dL/dz = (p - y) / n
The magic of pairing BCE with sigmoid: the chain-rule product of dBCE/dp and dsigma/dz simplifies to the prediction error (p - y). Dividing by n keeps the gradient scale independent of batch size.
31grad_z = (p - y) / len(y)
The residual per example, averaged by batch size. Positive residuals mean 'we predicted too high' (need to push logit down); negative means 'predicted too low'.
EXECUTION STATE
📚 len() = Python built-in. For an ndarray it returns the size of axis 0. len(y) = 4 here — our batch size.
p - y = [0.5075-1, 0.5100-1, 0.5050-0, 0.4900-0] = [ -0.4925, -0.4900, +0.5050, +0.4900 ]
-> sign intuition = Negative grad_z for positive examples -> we need to INCREASE z there. Positive grad_z for negative examples -> we need to DECREASE z. Exactly what you'd hope.
32grad_w = X.T @ grad_z
Chain rule: dL/dw = dL/dz · dz/dw, and dz/dw = x. So each feature's gradient is X.T @ grad_z. This is the single most important line in all of supervised learning.
EXECUTION STATE
📚 .T (transpose) = Swaps axes. X is (4,2); X.T is (2,4). Matrix multiply (2,4) @ (4,) -> (2,).
-> interpretation = Both gradient components are negative, so SGD will INCREASE both w[0] and w[1]. That makes sense — positive features correlate with positive labels.
33grad_b = grad_z.sum()
For the bias, dz/db = 1 for every example, so the gradient is simply the sum (not mean, because we already averaged inside grad_z).
EXECUTION STATE
📚 .sum() = NumPy method: sum of all elements. For 1-D array [a,b,c,d] returns a+b+c+d.
-> very small = The two positive and two negative residuals nearly cancel, so bias gets almost no update this step — the initial b=0 is already close to optimal.
35# Section — PARAMETER UPDATE (one SGD step)
The actual 'learning' — one step of stochastic gradient descent. We move every parameter slightly in the OPPOSITE direction of its gradient.
36w = w - lr * grad_w
Elementwise update: each weight moves against its gradient, scaled by the learning rate. Because grad_w is negative, w grows.
-> direction = Weights rotate toward the positive-feature direction. The next forward pass will produce larger logits on positive examples and smaller on negative ones. The loss will drop.
37b = b - lr * grad_b
Same update rule for the bias. Tiny update because grad_b is tiny.
EXECUTION STATE
-> calculation = 0.0 - 0.5 * 0.00313 = -0.00156
⬆ new b = -0.001562
39print("loss :", round(loss, 4))
Print the batch loss. This is the single number we watch during training — it should decrease over many steps.
EXECUTION STATE
⬆ output = loss : 0.682
40print("new w :", np.round(w, 4))
Print the updated weights after one SGD step. These are the parameters we'll carry into the next step.
EXECUTION STATE
⬆ output = new w : [ 0.2355 -0.0194 ]
41print("new b :", round(b, 6))
Print the updated bias. Round to 6 places so we can see the small change that would otherwise print as 0.
EXECUTION STATE
⬆ output = new b : -0.001562
import numpy as np

# --- 1. Four labeled training examples (2 features each) ---
X = np.array([
    [0.5, 0.2],    # "great"   -> y = 1 (positive)
    [0.8, 0.4],    # "amazing" -> y = 1 (positive)
    [-0.3, -0.5],  # "awful"   -> y = 0 (negative)
    [-0.6, -0.2],  # "boring"  -> y = 0 (negative)
])
y = np.array([1, 1, 0, 0])

# --- 2. Model parameters (logistic regression) ---
w = np.array([0.1, -0.1])  # starting weights
b = 0.0                    # starting bias
lr = 0.5                   # learning rate

def sigmoid(z: np.ndarray) -> np.ndarray:
    """sigma(z) = 1 / (1 + e^-z) - squashes any real number into (0, 1)."""
    return 1.0 / (1.0 + np.exp(-z))

# --- 3. FORWARD PASS ---
z = X @ w + b   # (4,) logits
p = sigmoid(z)  # (4,) probabilities of class 1

# Binary cross-entropy loss (averaged over the batch)
eps = 1e-9
loss = -np.mean(y * np.log(p + eps) + (1 - y) * np.log(1 - p + eps))

# --- 4. BACKWARD PASS (analytic gradient of BCE + sigmoid) ---
# For BCE + sigmoid, dL/dz collapses to (p - y) / n
grad_z = (p - y) / len(y)  # (4,)
grad_w = X.T @ grad_z      # (2,)
grad_b = grad_z.sum()      # scalar

# --- 5. PARAMETER UPDATE (one SGD step) ---
w = w - lr * grad_w
b = b - lr * grad_b

print("loss :", round(loss, 4))
print("new w :", np.round(w, 4))
print("new b :", round(b, 6))
The walkthrough above records the state at every line: input values, intermediate products, shapes, and the reason every argument is what it is. Notice how the BCE loss of 0.6820 is barely below the random-guess loss of ln 2 ≈ 0.693. Recomputing the loss with the updated w and b gives about 0.634: a modest improvement, but repeat this step a few hundred times and the model locks onto the boundary.
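To make the "repeat this step" claim concrete, here is a minimal pure-Python sketch that reuses the four examples, starting weights, and learning rate from the listing and simply loops the same forward, backward, and update sequence. The helper names (`forward`, `bce_loss`, `step`) and the choice of 300 iterations are illustrative assumptions, not part of the original listing.

```python
import math

# Same toy data and starting point as the NumPy listing.
X = [[0.5, 0.2], [0.8, 0.4], [-0.3, -0.5], [-0.6, -0.2]]
y = [1, 1, 0, 0]
w, b, lr = [0.1, -0.1], 0.0, 0.5


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def forward(w, b):
    """Predicted probability of class 1 for every example."""
    return [sigmoid(sum(wj * xj for wj, xj in zip(w, x)) + b) for x in X]


def bce_loss(p):
    """Mean binary cross-entropy over the batch."""
    total = sum(yi * math.log(pi) + (1 - yi) * math.log(1 - pi) for pi, yi in zip(p, y))
    return -total / len(y)


def step(w, b):
    """One SGD step using the analytic gradient dL/dz = (p - y) / n."""
    p = forward(w, b)
    grad_z = [(pi - yi) / len(y) for pi, yi in zip(p, y)]
    grad_w = [sum(x[j] * g for x, g in zip(X, grad_z)) for j in range(len(w))]
    grad_b = sum(grad_z)
    new_w = [wj - lr * gj for wj, gj in zip(w, grad_w)]
    return new_w, b - lr * grad_b


losses = [bce_loss(forward(w, b))]
for _ in range(300):  # 300 is an arbitrary illustrative step count
    w, b = step(w, b)
    losses.append(bce_loss(forward(w, b)))

print(f"loss before any update: {losses[0]:.4f}")
print(f"loss after 1 step     : {losses[1]:.4f}")
print(f"loss after 300 steps  : {losses[-1]:.4f}")
```

Because the four points are linearly separable, the loss keeps shrinking as the weights grow; on real data it would level off instead.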
The simplification to remember. Because we paired BCE with sigmoid, the gradient of the loss with respect to the logit collapses to the clean residual (p−y)/n. The same cancellation holds when softmax is paired with cross-entropy: the gradient with respect to the logits is the predicted distribution minus the one-hot target. That is why it carries over to the transformer language head that predicts the next token with cross-entropy over a 50 000-word softmax. The only thing that changes at that scale is the dimension of the softmax.
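The multi-class version of this identity can be checked numerically. The sketch below, in pure Python, compares the closed-form gradient of softmax plus cross-entropy for a single example (softmax output minus the one-hot target) against central finite differences. The logits, the target index, and the helper names are hypothetical values chosen only for illustration.

```python
import math


def softmax(logits):
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def cross_entropy(logits, target):
    """-log softmax(logits)[target] for a single example."""
    return -math.log(softmax(logits)[target])


def analytic_grad(logits, target):
    """Closed form: softmax(logits) minus the one-hot target."""
    p = softmax(logits)
    return [pk - (1.0 if k == target else 0.0) for k, pk in enumerate(p)]


def numeric_grad(logits, target, h=1e-6):
    """Central finite differences, one logit at a time."""
    grads = []
    for k in range(len(logits)):
        up = logits[:k] + [logits[k] + h] + logits[k + 1:]
        down = logits[:k] + [logits[k] - h] + logits[k + 1:]
        grads.append((cross_entropy(up, target) - cross_entropy(down, target)) / (2 * h))
    return grads


logits = [2.0, -1.0, 0.5, 0.1]  # hypothetical 4-class logits
target = 2                      # hypothetical true class index

ga = analytic_grad(logits, target)
gn = numeric_grad(logits, target)
max_gap = max(abs(a - n) for a, n in zip(ga, gn))

print("analytic:", [round(g, 4) for g in ga])
print("numeric :", [round(g, 4) for g in gn])
print("max gap :", max_gap)
```

For a batch with mean reduction, each row of this gradient is additionally divided by the batch size, matching the (p−y)/n form of the binary case.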
The Full Training Loop in PyTorch
PyTorch replaces the NumPy gradient derivation with loss.backward() and the SGD update with optimizer.step(). Everything else — forward pass, loss, the inner loop over batches — looks identical. The full, production-grade template is below.
LSTM Classifier — Full Training Loop
🐍train_lstm.py
1import torch
PyTorch — the tensor library and autograd engine. Every array becomes a torch.Tensor with gradient tracking, every math op is recorded onto a computation graph, and .backward() walks that graph to compute all gradients.
2import torch.nn as nn
The neural-network building blocks: Module, Linear, LSTM, Embedding, Dropout, CrossEntropyLoss. Every layer we need is in here.
3from torch.utils.data import DataLoader
DataLoader wraps a Dataset and yields shuffled mini-batches. It also handles multi-process loading, collation, and (with pin_memory=True) placing batches in page-locked host memory so that copies to the GPU are faster.
EXECUTION STATE
📚 DataLoader(ds, batch_size, shuffle) = Iterable that yields (x_batch, y_batch) tuples. Internally creates a RandomSampler (if shuffle=True) and calls __getitem__ on the dataset for each sampled index.
5# Section — classifier from Section 18.2
We re-use the bidirectional LSTM built in Section 18.2. This section focuses on training and evaluation; the architecture is unchanged.
6class SentimentLSTM(nn.Module):
Every trainable PyTorch model inherits from nn.Module. This gives us .train()/.eval() mode switches, .parameters() to feed to the optimizer, .state_dict() for saving, and automatic device handling.
EXECUTION STATE
📚 nn.Module = Base class: manages children modules, parameters, buffers. Its __call__ invokes forward() and runs hooks.
7def __init__(self, vocab_size=5000, d_embed=128, d_hidden=256, n_classes=2):
Constructor: creates every sub-layer exactly once, so the optimizer sees their parameters. Defaults are typical for a small-vocabulary sentiment task.
EXECUTION STATE
⬇ vocab_size = 5000 = Number of unique tokens the embedding layer knows. Larger = bigger embedding table.
⬇ d_embed = 128 = Dimension of each token embedding vector.
⬇ d_hidden = 256 = Hidden-state width of the LSTM cell. With bidir=True the output will be 2*d_hidden = 512 wide.
⬇ n_classes = 2 = Number of output logits (positive / negative).
8super().__init__()
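Runs nn.Module's constructor. It must run before any sub-layer is assigned to self: the constructor creates the registries that track parameters and sub-modules, and without them PyTorch raises an AttributeError ("cannot assign module before Module.__init__() call") on the first assignment.
9self.embed = nn.Embedding(vocab_size, d_embed, padding_idx=0)
A lookup table mapping each token id to a learnable d_embed-dimensional vector.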
EXECUTION STATE
📚 nn.Embedding(num_embeddings, embedding_dim, padding_idx) = Stores an (N, D) weight matrix. Forward: given integer ids, returns the corresponding rows. Under the hood: F.embedding(input, weight).
⬇ arg 1: vocab_size = 5000 = Rows of the embedding table.
⬇ arg 2: d_embed = 128 = Columns of the table — dimensionality of each token vector.
⬇ arg 3: padding_idx = 0 = Index 0 is treated specially: its embedding stays all zeros and receives no gradient. Lets batches contain variable-length sentences padded with 0.
10self.lstm = nn.LSTM(d_embed, d_hidden, batch_first=True, bidirectional=True)
A bidirectional LSTM processes the sequence left->right and right->left. Its per-step output concatenates the two, giving 2*d_hidden channels.
EXECUTION STATE
📚 nn.LSTM(input_size, hidden_size, batch_first, bidirectional) = Wraps the 4-gate cell from Section 18.2. Manages the recurrence over the time dimension. Returns (outputs, (h_n, c_n)).
⬇ input_size = 128 = Must equal d_embed.
⬇ hidden_size = 256 = Per-direction hidden width.
⬇ batch_first = True = Input shape is (B, T, d_embed) instead of PyTorch's historical default (T, B, d_embed). More intuitive and matches how DataLoader emits batches.
⬇ bidirectional = True = Doubles the parameter count but usually wins on classification tasks — each token sees left AND right context.
11self.drop = nn.Dropout(0.4)
Regularization: during training, randomly zeros 40% of features and rescales the rest by 1/(1-p). During eval it is a no-op. Critical for preventing over-fitting on small datasets.
EXECUTION STATE
📚 nn.Dropout(p) = Stochastic regularizer. p is the drop probability. Example: input [1,2,3,4], output might be [0,3.33,5,0] (each position independently).
⬇ p = 0.4 = 40% of the pooled features will be dropped per forward pass during training. A typical value for RNNs is 0.3-0.5.
12self.head = nn.Linear(2 * d_hidden, n_classes)
Final classification layer: linear projection from the 512-dim pooled vector to 2 logits.
EXECUTION STATE
📚 nn.Linear(in_features, out_features) = y = x @ W.T + b. Stores learnable W (out, in) and bias (out,).
⬇ in_features = 512 = 2 * d_hidden because the LSTM is bidirectional.
⬇ out_features = 2 = One logit per class. A softmax is applied implicitly inside CrossEntropyLoss.
14def forward(self, x):
The forward pass is invoked every time the model is called. Autograd automatically records each op for later .backward().
EXECUTION STATE
⬇ input: x = LongTensor of shape (B, T) — B sentences, each padded to length T, values are vocabulary ids.
⬆ returns = FloatTensor of shape (B, 2) — unnormalized logits over the two classes.
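15e = self.embed(x)
Embedding lookup: maps each token id to its d_embed-dimensional vector, giving a tensor of shape (B, T, d_embed).
-> memory = For B=64, T=200: 64*200*128 float32 = ~6.55 MB per batch.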
16out, _ = self.lstm(e)
Runs the BiLSTM over the full sequence. 'out' is the per-timestep output; we discard the final (h_n, c_n) tuple with '_'.
EXECUTION STATE
⬆ out shape = (B=64, T=200, 2*d_hidden=512)
-> discarded '_' = (h_n, c_n) — the final hidden/cell states. We don't need them because we mean-pool the full 'out' instead.
17pool = out.mean(dim=1)
Mean-pool over time: average all T hidden states per sentence. Converts (B, T, 512) -> (B, 512). The sequence is now one fixed-length vector per review, ready for classification.
EXECUTION STATE
📚 .mean(dim) = Collapses the given dimension by averaging. dim=1 targets the time axis T. Result's shape drops that axis.
⬇ dim = 1 = Axis 0 is batch, 1 is time, 2 is features. We want to pool over time, so dim=1.
⬆ pool shape = (64, 512)
18return self.head(self.drop(pool))
Apply dropout then the linear head. Dropout sits on the pooled features — the most important regularization point in this architecture.
EXECUTION STATE
⬆ return shape = (B=64, n_classes=2) — raw logits
-> no softmax here = CrossEntropyLoss applies log-softmax internally, so we return raw logits to avoid double-normalizing.
20# Section — instantiate model, loss, optimizer
Three lines to get from 'architecture' to 'trainable system'. Keep them together in a single scope so checkpoint-saving and resume logic can re-build identically.
21model = SentimentLSTM()
⬆ model = A SentimentLSTM instance on the CPU (use .to('cuda') or .to('mps') for GPU).
22criterion = nn.CrossEntropyLoss()
The standard classification loss. Combines log-softmax + NLL in one numerically stable op.
EXECUTION STATE
📚 nn.CrossEntropyLoss() = Given logits (B, C) and targets (B,), computes -log(softmax(logits)[target]) averaged over B. Equivalent to F.cross_entropy(logits, targets).
-> why not softmax + NLL separately? = The fused op computes log-softmax with the log-sum-exp trick (subtracting the largest logit before exponentiating), which avoids underflow/overflow when logits are large. A standard numerical-stability trick.
23optim = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-5)
Adam optimizer with weight decay. It will track first and second moments (m, v) of every parameter gradient across steps.
EXECUTION STATE
📚 torch.optim.Adam(params, lr, betas, eps, weight_decay) = Maintains per-parameter m (EMA of gradients) and v (EMA of squared gradients), then updates theta -= lr * m_hat / (sqrt(v_hat) + eps).
⬇ arg 1: model.parameters() = Generator yielding every Parameter the optimizer should update. Excludes buffers like running stats of BatchNorm.
⬇ arg 2: lr = 1e-3 = The canonical Adam learning rate. Works for a remarkably wide range of models — reason Adam became the default.
⬇ arg 3: weight_decay = 1e-5 = L2 penalty on the weights. Adds weight_decay * w to the gradient each step, nudging all weights toward 0 and reducing overfitting.
25# Section — data loaders
We assume train_ds and val_ds were produced earlier from the tokenised, padded IMDb-style reviews built in Section 18.1.
26train_loader = DataLoader(train_ds, batch_size=64, shuffle=True)
Batches of 64 training examples, reshuffled every epoch. Shuffling is essential: without it, mini-batch gradients would be correlated across steps.
EXECUTION STATE
⬇ batch_size = 64 = Trade-off: larger batches give smoother gradients and higher GPU utilization; smaller batches add useful noise that can improve generalization.
⬇ shuffle = True = Randomly permutes dataset indices at the start of each epoch. The optimizer sees a different ordering every time — critical for SGD theory to hold.
27val_loader = DataLoader(val_ds, batch_size=256)
Validation uses a bigger batch (no backward pass, so memory cost is lower) and is not shuffled (ordering doesn't affect metrics).
EXECUTION STATE
⬇ batch_size = 256 = Fits easily in memory because eval skips the backward pass and stores no activations for autograd.
-> shuffle default = shuffle defaults to False. Deterministic order makes val metrics reproducible across runs.
29# Section — training loop with early stopping
The canonical pattern: for every epoch, run one train pass (backprop) and one val pass (no backprop), then decide whether to early-stop.
30best_val, bad, PATIENCE = float("inf"), 0, 4
Early-stopping bookkeeping. best_val tracks the lowest val loss so far; bad counts consecutive epochs without improvement; PATIENCE is how many we tolerate before stopping.
EXECUTION STATE
best_val = inf (so the first real loss will definitely be lower)
bad = 0
PATIENCE = 4 — stop if val loss fails to improve for 4 consecutive epochs.
32for epoch in range(1, 26):
Iterate for up to 25 epochs. Early stopping may cut this short if val loss plateaus.
34model.train()
Puts the model into training mode. Concretely: enables Dropout (in eval mode it is a no-op) and makes BatchNorm layers normalize with the current batch's statistics while updating their running averages (we have none here but the habit matters).
EXECUTION STATE
📚 .train(mode=True) = Recursively sets self.training = True on this module and all submodules. Affects dropout, batchnorm, and anything else that checks self.training.
35train_loss = 0.0
Running tally of per-sample losses this epoch. We'll divide by dataset size at the end.
36for xb, yb in train_loader:
Each iteration yields one mini-batch: xb is LongTensor (B, T), yb is LongTensor (B,). DataLoader shuffled the indices and stacked them for us.
yb shape = (64,) — class index for each sentence (0 or 1).
37optim.zero_grad()
Clears the .grad attribute of every parameter. PyTorch ACCUMULATES gradients across .backward() calls, so if we skipped this line, the gradients from previous steps would pile up and ruin training.
EXECUTION STATE
📚 .zero_grad(set_to_none=True) = Walks every param_group and sets each parameter's .grad to None (default) or zero tensor. Faster than allocating a zero tensor every step.
-> classic bug = Forgetting this line: the loss looks fine at first but slowly goes crazy. Gradient accumulation is SOMETIMES desired (e.g., to simulate a larger batch), but only with explicit control.
38logits = model(xb)
Forward pass. Calls __call__ -> forward(xb). Autograd records every intermediate tensor on a fresh computation graph.
EXECUTION STATE
⬆ logits shape = (64, 2)
-> graph memory = All intermediate activations (embeddings, LSTM hidden states at every t, pooled vector, dropout mask) are retained in memory for .backward(). This is the 'activation memory' bill.
39loss = criterion(logits, yb)
Scalar tensor. Averaging over the batch happens inside CrossEntropyLoss (reduction='mean' is the default).
⬆ loss = 0-dim tensor holding a single float, e.g., tensor(0.631, grad_fn=<NllLossBackward>).
40loss.backward()
The most magical line. PyTorch walks the computation graph backward from the scalar loss, applying the chain rule at every node, and fills the .grad attribute of every leaf tensor (parameter).
EXECUTION STATE
📚 .backward() = Reverse-mode automatic differentiation. One backward pass populates gradients for EVERY parameter the loss depends on.
-> what's in memory now = Every parameter's .grad attribute has been set. The forward graph is freed (unless retain_graph=True).
41torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
Gradient clipping: if the total gradient L2 norm exceeds 1.0, rescale all gradients so that the norm is (up to a tiny epsilon) 1.0. Prevents exploding-gradient spikes, which RNNs are especially prone to.
EXECUTION STATE
📚 clip_grad_norm_(params, max_norm) = Computes total norm = sqrt(sum(g_i^2 for all i)); if > max_norm, multiplies every grad by max_norm / total_norm. In-place (trailing _).
⬇ arg 1: model.parameters() = Every learnable tensor whose gradient should be clipped together.
⬇ arg 2: max_norm = 1.0 = The cap on the combined gradient magnitude. 1.0 is the canonical RNN default.
42optim.step()
Adam reads each parameter's .grad, updates its moment estimates m and v, applies bias correction, and writes the new value back into the parameter tensor.
EXECUTION STATE
📚 optim.step() = For each param p:
m = b1*m + (1-b1)*p.grad
v = b2*v + (1-b2)*p.grad**2
m_hat = m / (1-b1**t); v_hat = v / (1-b2**t)
p -= lr * m_hat / (sqrt(v_hat) + eps)
-> state kept = optim maintains m and v for every parameter between steps. That's why we need a persistent optimizer object.
43train_loss += loss.item() * xb.size(0)
Convert the 0-dim tensor to a Python float with .item() (detaches from graph), then weight by the actual batch size — the last batch may be smaller than 64.
EXECUTION STATE
📚 .item() = Extracts a Python number from a 0-dim tensor. Fails for multi-element tensors. Also synchronizes CUDA if the tensor lives on a GPU.
📚 .size(0) = First-dim length — the actual batch size, which equals 64 except possibly for the final (partial) batch.
44train_loss /= len(train_ds)
After the epoch, divide by total samples to get average loss per example. This is what we plot on the loss curve.
47model.eval()
Switch out of training mode. Dropout now passes its input through unchanged. BatchNorm would switch to its stored running statistics.
EXECUTION STATE
📚 .eval() = Shortcut for .train(False). Affects layers that behave differently during inference (Dropout, BatchNorm, DropPath, etc.).
48val_loss, correct = 0.0, 0
Per-epoch validation tallies. We track both total loss (for the loss curve) and correct predictions (for accuracy).
49with torch.no_grad():
Context manager that disables autograd's computation-graph recording inside the block. Saves activation memory and speeds up the forward pass — essential for eval on large models.
EXECUTION STATE
📚 torch.no_grad() = Sets a thread-local flag; any op inside its 'with' block creates tensors with requires_grad=False and does not extend the graph.
-> savings = No activation tensors are retained for a backward pass, so evaluation needs noticeably less memory than training at the same batch size; the exact saving depends on the model.
50for xb, yb in val_loader:
Iterate over the full validation set in order. No shuffling needed.
51logits = model(xb)
Forward pass on a validation batch. Because we are inside torch.no_grad() AND model.eval(), no graph is built and dropout is a pass-through.
53correct += (logits.argmax(dim=1) == yb).sum().item()
Compute accuracy: the predicted class is whichever logit is larger (argmax). Compare against the true label and count matches.
EXECUTION STATE
📚 .argmax(dim) = Returns the index of the maximum along a given axis. For logits shape (B, 2) with dim=1, returns (B,) -- the predicted class per sample.
⬇ dim = 1 = Reduce over the class axis. Produces one prediction per row (per sample).
-> == yb = Elementwise equality -> Bool tensor of shape (B,). True where prediction was correct.
-> .sum().item() = .sum() casts Bool->Long and sums to a 0-dim tensor. .item() extracts the Python int count.
54val_loss /= len(val_ds)
Normalize to per-sample loss just like the training loss. Now we can fairly compare train_loss and val_loss — the gap is our overfitting signal.
55val_acc = correct / len(val_ds)
Total correct / total validation samples. In the learning-curve chart later in this section you can see this rise from ~0.55 to ~0.85 over 13 epochs, then plateau or dip slightly as the model overfits.
56print(f"ep {epoch:2d} | train {train_loss:.3f} | val {val_loss:.3f} | acc {val_acc:.3f}")
Single-line progress report. Include enough that you can read overfitting straight out of the console (large val-train gap).
58# Section — early stopping
Stop training when the val loss hasn't improved for PATIENCE epochs. Saves compute and prevents overfitting damage.
59if val_loss < best_val:
We just set a new best. Reset the bad counter and checkpoint this model.
60best_val, bad = val_loss, 0
Update the best-seen val loss and reset the patience counter.
61torch.save(model.state_dict(), "best.pt")
Persist just the parameter dict (not the Python object) to disk. When we later reload, we instantiate a fresh SentimentLSTM() and call .load_state_dict() — cleaner than pickling the entire object.
EXECUTION STATE
📚 model.state_dict() = OrderedDict mapping full parameter name -> tensor. E.g., 'lstm.weight_ih_l0' -> Tensor(...)
📚 torch.save(obj, path) = Pickles the object (handling tensors with efficient binary format) to disk.
62else:
No improvement this epoch. Advance the bad counter.
63bad += 1
One more epoch without improvement.
64if bad >= PATIENCE:
We have run out of patience. Break the loop; training stops early.
65print(f"early stop at epoch {epoch}")
Logged so you know the model at 'best.pt' came from an earlier epoch, not the final one.
66break
Exits the for loop. The saved best.pt now holds the best-generalizing model — this is what you deploy.
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

# --- 1. Build the bidirectional LSTM classifier from Section 18.2 ---
class SentimentLSTM(nn.Module):
    def __init__(self, vocab_size=5000, d_embed=128, d_hidden=256, n_classes=2):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, d_embed, padding_idx=0)
        self.lstm = nn.LSTM(d_embed, d_hidden, batch_first=True, bidirectional=True)
        self.drop = nn.Dropout(0.4)
        self.head = nn.Linear(2 * d_hidden, n_classes)

    def forward(self, x):
        e = self.embed(x)                  # (B, T, d_embed)
        out, _ = self.lstm(e)              # (B, T, 2*d_hidden)
        pool = out.mean(dim=1)             # mean-pool over time
        return self.head(self.drop(pool))  # (B, n_classes)

# --- 2. Instantiate model, loss, optimizer ---
model = SentimentLSTM()
criterion = nn.CrossEntropyLoss()
optim = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-5)

# --- 3. Data loaders built earlier from tokenised reviews ---
train_loader = DataLoader(train_ds, batch_size=64, shuffle=True)
val_loader = DataLoader(val_ds, batch_size=256)

# --- 4. Training loop with early stopping ---
best_val, bad, PATIENCE = float("inf"), 0, 4

for epoch in range(1, 26):
    # ---- train ----
    model.train()
    train_loss = 0.0
    for xb, yb in train_loader:
        optim.zero_grad()             # wipe previous gradients
        logits = model(xb)            # forward
        loss = criterion(logits, yb)  # scalar tensor
        loss.backward()               # autograd
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optim.step()                  # Adam update
        train_loss += loss.item() * xb.size(0)
    train_loss /= len(train_ds)

    # ---- validate ----
    model.eval()
    val_loss, correct = 0.0, 0
    with torch.no_grad():
        for xb, yb in val_loader:
            logits = model(xb)
            val_loss += criterion(logits, yb).item() * xb.size(0)
            correct += (logits.argmax(dim=1) == yb).sum().item()
    val_loss /= len(val_ds)
    val_acc = correct / len(val_ds)
    print(f"ep {epoch:2d} | train {train_loss:.3f} | val {val_loss:.3f} | acc {val_acc:.3f}")

    # ---- early stopping ----
    if val_loss < best_val:
        best_val, bad = val_loss, 0
        torch.save(model.state_dict(), "best.pt")
    else:
        bad += 1
        if bad >= PATIENCE:
            print(f"early stop at epoch {epoch}")
            break
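The one call in the template whose arithmetic stays hidden is optim.step(). As a rough illustration of the update rule spelled out in the walkthrough (moment estimates, bias correction, scaled step), here is a pure-Python sketch for a single scalar parameter. The function name `adam_step`, the toy objective, and the learning rate of 1e-2 are assumptions made for this example; it omits weight decay and is not PyTorch's implementation.

```python
import math


def adam_step(p, grad, m, v, t, lr=1e-3, b1=0.9, b2=0.999, eps=1e-8):
    """One Adam update for a single scalar parameter (t is the 1-based step count)."""
    m = b1 * m + (1 - b1) * grad           # EMA of gradients
    v = b2 * v + (1 - b2) * grad * grad    # EMA of squared gradients
    m_hat = m / (1 - b1 ** t)              # bias correction for the zero initialisation
    v_hat = v / (1 - b2 ** t)
    p = p - lr * m_hat / (math.sqrt(v_hat) + eps)
    return p, m, v


# Toy objective f(p) = (p - 3)^2, whose gradient is 2 * (p - 3).
p, m, v = 0.0, 0.0, 0.0
first_step = None
for t in range(1, 5001):
    grad = 2.0 * (p - 3.0)
    p, m, v = adam_step(p, grad, m, v, t, lr=1e-2)
    if t == 1:
        first_step = p

print("after 1 step    :", round(first_step, 6))
print("after 5000 steps:", round(p, 4))
```

On the first step the bias-corrected ratio m_hat / sqrt(v_hat) has magnitude one, so the parameter moves by almost exactly lr regardless of how large the raw gradient is. That scale-invariance is a large part of why one learning rate transfers across so many models.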
Five things in that template are worth internalising because you will write them in every training loop for the rest of your career:
model.train() / model.eval(): switch dropout and batchnorm between their training and inference behavior.
optim.zero_grad() at the top of every step — PyTorch accumulates gradients by default, which is almost never what you want.
with torch.no_grad(): around evaluation. It skips building the computation graph, which cuts activation memory and speeds up the forward pass.
clip_grad_norm_ — a cheap safety belt against exploding gradients that RNNs are notorious for.
torch.save(model.state_dict(), ...) on every best-so-far epoch — cheaper than re-training when something goes wrong.
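Of these five habits, gradient clipping is the one whose arithmetic is easiest to lose behind the API call. The sketch below shows clipping by global norm in pure Python: all gradients are treated as one long vector, and that vector is rescaled only if its L2 norm exceeds the cap. The function name `clip_by_global_norm` and the example gradients are hypothetical, and the real clip_grad_norm_ handles details this sketch leaves out.

```python
import math


def clip_by_global_norm(grads, max_norm):
    """Rescale a list of gradient vectors so their combined L2 norm is at most max_norm.

    Returns the (possibly rescaled) gradients and the norm measured before clipping.
    """
    total_norm = math.sqrt(sum(g * g for vec in grads for g in vec))
    if total_norm <= max_norm:
        return grads, total_norm
    scale = max_norm / total_norm
    return [[g * scale for g in vec] for vec in grads], total_norm


# Hypothetical gradients for two parameter tensors, flattened to lists.
grads = [[3.0, 4.0], [12.0]]  # combined norm = sqrt(9 + 16 + 144) = 13
clipped, norm_before = clip_by_global_norm(grads, max_norm=1.0)
norm_after = math.sqrt(sum(g * g for vec in clipped for g in vec))

print("norm before:", norm_before)
print("norm after :", round(norm_after, 6))
print("clipped    :", [[round(g, 4) for g in vec] for vec in clipped])
```

Every component is multiplied by the same factor, so the direction of the update is preserved and only its length changes.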
Evaluation: Accuracy, Precision, Recall, F1
Accuracy — the fraction of predictions we got right — is the number you watch first, but it can be misleading. If 95% of your reviews are positive, a model that always predicts "positive" scores 95% accuracy without doing any work. We need richer metrics, and all of them derive from four counts organised in the confusion matrix.
| | Predicted positive | Predicted negative |
|---|---|---|
| Actual positive | TP (true positive) | FN (false negative, missed) |
| Actual negative | FP (false positive, false alarm) | TN (true negative) |
From those four counts we derive:
$\text{Accuracy} = \dfrac{TP + TN}{TP + TN + FP + FN}$: fraction correct. Dominated by the majority class in imbalanced datasets.
$\text{Precision} = \dfrac{TP}{TP + FP}$: of the reviews we flagged positive, how many actually were? High precision = few false alarms.
$\text{Recall} = \dfrac{TP}{TP + FN}$: of the actually-positive reviews, how many did we catch? High recall = few misses.
$F_1 = \dfrac{2 \cdot P \cdot R}{P + R}$: harmonic mean of precision and recall. Punishes models that trade one against the other. The default single-number metric for classification.
Why the harmonic mean? Because if either P or R is close to zero, the harmonic mean is also close to zero. The arithmetic mean (P+R)/2 would let a model cheat by being great at one metric and terrible at the other. F1 says you must do well on both.
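A few lines of Python make the contrast between the two means visible. The (precision, recall) pairs below are hypothetical, picked only to show how the arithmetic mean stays comfortable while F1 collapses as one metric degrades.

```python
def arithmetic_mean(p, r):
    return (p + r) / 2


def f1(p, r):
    """Harmonic mean of precision and recall; taken as 0 when both are 0."""
    return 0.0 if p + r == 0 else 2 * p * r / (p + r)


# Hypothetical (precision, recall) pairs chosen to show the contrast.
cases = [(0.9, 0.9), (1.0, 0.5), (1.0, 0.1), (1.0, 0.0)]
for p, r in cases:
    print(f"P={p:.1f} R={r:.1f}  arithmetic={arithmetic_mean(p, r):.3f}  F1={f1(p, r):.3f}")
```

The two means agree only when P equals R; everywhere else F1 sits below the arithmetic mean, and the gap widens as the metrics drift apart.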
Threshold matters — precision/recall is a knob
For a binary classifier, the prediction comes from p(x)>τ. Default τ=0.5, but you can slide it. Raise the threshold and precision goes up (we only call something positive if we are very sure), at the cost of recall. Lower it and recall rises but precision suffers. The right setting depends on which mistake costs more — an FP or an FN.
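The trade-off can be reproduced on a handful of scores. The sketch below sweeps τ over ten made-up model probabilities and recomputes precision and recall each time; the scores, labels, and the helper name `precision_recall` are all hypothetical.

```python
def precision_recall(probs, labels, tau):
    """Precision and recall when we predict positive iff prob > tau."""
    tp = sum(1 for p, y in zip(probs, labels) if p > tau and y == 1)
    fp = sum(1 for p, y in zip(probs, labels) if p > tau and y == 0)
    fn = sum(1 for p, y in zip(probs, labels) if p <= tau and y == 1)
    precision = tp / max(1, tp + fp)
    recall = tp / max(1, tp + fn)
    return precision, recall


# Hypothetical model scores for ten reviews (1 = positive, 0 = negative).
probs = [0.95, 0.90, 0.80, 0.70, 0.65, 0.55, 0.45, 0.40, 0.30, 0.10]
labels = [1, 1, 1, 0, 1, 0, 1, 0, 0, 0]

for tau in (0.3, 0.5, 0.75):
    p, r = precision_recall(probs, labels, tau)
    print(f"tau={tau:.2f}  precision={p:.3f}  recall={r:.3f}")
```

On this toy data precision rises from 0.625 to 1.0 as τ goes from 0.3 to 0.75, while recall falls from 1.0 to 0.6. The model never changes; only the operating point does.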
Computing the Metrics in Code
All four metrics drop out of a few NumPy boolean operations. You will rarely write this by hand in production — sklearn, torchmetrics, and HuggingFace's evaluate all ship one-liners — but understanding the counts behind the numbers is essential.
Precision, recall, F1 from a confusion matrix
🐍metrics.py
1import numpy as np
NumPy is Python's numerical array library. We use it here for vectorized boolean operations on the prediction arrays: comparing two ndarrays with == returns a boolean ndarray of the same shape, and the & operator does element-wise AND. This is how we count TP/FP/FN/TN in four lines instead of a Python for-loop.
EXECUTION STATE
numpy = Fast C-backed array library. Provides ndarray with element-wise == and &, plus .sum() which counts True values as 1.
as np = Universal alias — lets us write np.array() instead of numpy.array().
3# A model's predictions on a 20-sample validation slice.
Comment describing the data. We use 20 samples so every single prediction is visible in one glance — the tradeoff is that each metric is coarse (each TP/FP counts for 5% of accuracy). Real validation sets are thousands of samples.
4# 1 = positive review, 0 = negative review.
Label convention. By using 1 for positive, the cross-entropy loss in the earlier training section can be interpreted directly as -log p(y=1|x) for positives and -log p(y=0|x) for negatives.
The ground-truth labels for 20 reviews, in dataset order. These are the 'gold' answers — whatever the classifier predicts will be compared against these.
EXECUTION STATE
📚 np.array(list) = Converts a Python list into a dense ndarray. Dtype is inferred as int64 for a list of ints. Stored contiguously in memory, enabling fast vectorized ops.
The model's predicted labels for the same 20 reviews. Same length, same ordering. Notice positions 3 (true 1, pred 0), 6 (true 0, pred 1), 11 (true 1, pred 0), 15 (true 0, pred 1) — these are the mistakes we'll count below.
→ where they differ from y_true = i=3 (y_true=1, y_pred=0) → FN (missed a positive)
i=6 (y_true=0, y_pred=1) → FP (false alarm)
i=11 (y_true=1, y_pred=0) → FN
i=15 (y_true=0, y_pred=1) → FP
All other 16 positions match.
8# --- Confusion matrix counts ---
Section header. Every metric below (accuracy, precision, recall, F1) is derived from just four integer counts: TP, FP, FN, TN. NumPy's vectorized boolean ops compute each in one line.
9tp = int(((y_pred == 1) & (y_true == 1)).sum())
Count true positives: positions where the model predicted 1 AND the true label is also 1. Three NumPy operations fused in one line.
EXECUTION STATE
📚 y_pred == 1 = Element-wise equality: returns a boolean ndarray of the same shape, True where y_pred[i] == 1, False otherwise. No loop — NumPy broadcasts the scalar 1 across the whole array.
📚 & operator on bool arrays = Element-wise logical AND. Returns True at position i only if BOTH arrays have True at position i. Uses bitwise AND on the underlying bytes (True=1, False=0), so it's fully vectorized in C.
📚 .sum() on bool array = Treats True as 1 and False as 0, returns the integer total. Since the mask is True exactly at TP positions, this directly counts true positives.
📚 int(...) = Cast np.int64 → Python int. Cosmetic: the value then displays as `9` rather than `np.int64(9)`, which is how NumPy 2.0 and later show scalars in their repr.
⬆ result: tp = 9 — the 9 positions where both y_pred and y_true are 1: i = 0, 2, 5, 8, 10, 13, 16, 17, 18.
10fp = int(((y_pred == 1) & (y_true == 0)).sum())
Count false positives: predicted 1 but the true label is 0. These are the 'false alarms' — reviews we called positive that were actually negative.
→ sanity check = tp + fp + fn + tn = 9 + 2 + 2 + 7 = 20 ✓ (every sample falls into exactly one cell of the confusion matrix).
14# --- Derived metrics ---
Section header. Accuracy, precision, recall, and F1 are all algebraic combinations of the four counts above. No new data is read from y_true / y_pred.
15accuracy = (tp + tn) / (tp + tn + fp + fn)
Fraction of predictions that are correct. Simple but blind to class imbalance: if 95% of your data is negative, always-predict-negative scores 0.95 without learning anything.
⬆ accuracy = 0.8 — 80% of the 20 predictions are correct.
16precision = tp / max(1, tp + fp)
Of all the samples we CALLED positive, how many actually were? High precision → few false alarms. The `max(1, ...)` guards against zero-division when the classifier never predicts positive.
EXECUTION STATE
📚 max(1, tp + fp) = Python builtin that returns the larger value. If tp + fp = 0 (classifier never said positive), the denominator would be zero — undefined. Using max(1, 0) = 1 keeps the division safe.
⬆ precision = 0.81818… — ~82% of positive predictions were correct.
17recall = tp / max(1, tp + fn)
Of all the samples that ARE positive, how many did we catch? High recall → few misses. Again guarded with max to avoid zero-division when the positive class is empty.
⬆ recall = 0.81818… — we caught ~82% of the actual positives.
→ precision ≈ recall here = Coincidence of symmetric error counts: FP = FN = 2. In general, threshold tuning pushes one up at the cost of the other — see the interactive threshold viz below.
Harmonic mean of precision and recall. Punishes models that trade one against the other — both must be high for F1 to be high. The `max(1e-9, ...)` floor is a tiny safety against both being zero.
EXECUTION STATE
📚 harmonic mean = F1 = 2PR / (P+R). Mathematically equivalent to 1 / mean(1/P, 1/R). If P=1.0 and R=0.0, arithmetic mean = 0.5 (misleading) but harmonic mean = 0.0 (honest).
First half of the metrics print. `:.3f` formats each float to exactly three decimal places. Python's implicit string concatenation merges this f-string with the next line's.
22 f"recall={recall:.3f} f1={f1:.3f}")
Second half. Implicit concatenation: Python joins adjacent string literals at compile time, so the two f-strings become one print argument.
Plug in one of the epochs from the learning-curve chart below. At epoch 13 the validation set had TP=139, FN=9, FP=72, TN=280, giving accuracy≈0.838, P≈0.659, R≈0.939, and F1≈0.774. The precision is dragged down by a large number of false positives — a signal that a higher decision threshold would trade a little recall for a lot of precision.
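Those figures are easy to re-derive. The helper below (the name `metrics_from_counts` is an assumption for this sketch) applies the same guarded formulas used in the walkthrough to the epoch-13 counts, and then to the counts from the 20-sample example as a cross-check.

```python
def metrics_from_counts(tp, fp, fn, tn):
    """Accuracy, precision, recall and F1 from the four confusion-matrix counts."""
    accuracy = (tp + tn) / (tp + tn + fp + fn)
    precision = tp / max(1, tp + fp)
    recall = tp / max(1, tp + fn)
    f1 = 2 * precision * recall / max(1e-9, precision + recall)
    return accuracy, precision, recall, f1


# Counts quoted in the text for epoch 13.
acc, p, r, f1 = metrics_from_counts(tp=139, fp=72, fn=9, tn=280)
print(f"accuracy={acc:.3f} precision={p:.3f} recall={r:.3f} f1={f1:.3f}")
# accuracy=0.838 precision=0.659 recall=0.939 f1=0.774

# Counts from the 20-sample walkthrough (TP=9, FP=2, FN=2, TN=7).
acc20, p20, r20, f1_20 = metrics_from_counts(tp=9, fp=2, fn=2, tn=7)
print(f"accuracy={acc20:.3f} precision={p20:.3f} recall={r20:.3f} f1={f1_20:.3f}")
# accuracy=0.800 precision=0.818 recall=0.818 f1=0.818
```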
The knob matters. Below, drag the decision threshold τ and watch precision climb as you demand more confidence — at the cost of recall. The PR curve traces every operating point your classifier can reach without retraining.
[Interactive figure: confusion-matrix visualizer with an adjustable decision threshold τ and the resulting precision-recall curve]
Learning Curves and the Overfitting Story
Plotting train loss, val loss and accuracy against epoch is the single most important diagnostic you can run on a model. Three failure modes reveal themselves here before they poison your test set.
[Interactive figure: training curves (train loss, val loss, and val accuracy against epoch) with an epoch slider]
Drag the epoch slider across a realistic 25-epoch run of this exact LSTM classifier. Read the curves this way:
Epochs 1-6 — learning. Both losses drop quickly. Train and val curves hug each other. The model is discovering the sentiment signal.
Epochs 7-13 — consolidation. Rate of improvement slows. Val loss is still decreasing. This is your "sweet spot" region, and the green best line marks its minimum.
Epochs 14-25 — overfitting. Train loss keeps dropping, but val loss starts rising and val accuracy stalls or decays. The model is memorising idiosyncrasies of the training set — rare reviewer phrases, punctuation patterns — that do not generalise.
The rule is blunt: deploy the model from the best val-loss epoch, not the last one. That is the whole job of the early-stopping code block in the PyTorch template.
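The patience logic behind that rule can be sketched independently of PyTorch. The function name `best_epoch_with_patience` and the validation losses below are hypothetical; in a real loop the "new best" branch is where the checkpoint would be saved.

```python
def best_epoch_with_patience(val_losses, patience=4):
    """Return (best_epoch, stop_epoch), both 1-indexed.

    Training stops once val loss has failed to improve for `patience`
    consecutive epochs; the model to deploy is the one from best_epoch.
    """
    best_loss = float("inf")
    best_epoch = 0
    bad_epochs = 0
    for epoch, loss in enumerate(val_losses, start=1):
        if loss < best_loss:
            # New best: a real training loop would save a checkpoint here.
            best_loss, best_epoch, bad_epochs = loss, epoch, 0
        else:
            bad_epochs += 1
            if bad_epochs >= patience:
                return best_epoch, epoch
    return best_epoch, len(val_losses)


# Illustrative validation losses (not the chapter's actual run).
val_losses = [0.69, 0.60, 0.52, 0.47, 0.44, 0.45, 0.43, 0.46, 0.47, 0.48, 0.50]
best, stopped = best_epoch_with_patience(val_losses, patience=4)
print(f"deploy epoch {best}, training halted at epoch {stopped}")
```

Note that the brief uptick at epoch 6 does not trigger a stop, because the counter resets when epoch 7 sets a new best.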
Underfitting shows up as the opposite pattern: both losses stay high and the curves plateau early. Remedy: make the model bigger, lower regularization, train longer, or feed richer features.
Regularization: Dropout, Weight Decay, Early Stopping
Overfitting happens when the model's capacity exceeds what the data actually supports. Every practical regularization technique is a constraint that shrinks effective capacity without shrinking raw parameter count.
| Technique | What it does | Where in our model | How to tune |
| --- | --- | --- | --- |
| Dropout | Randomly zero a fraction p of activations during training; rescale surviving ones by 1/(1-p) | nn.Dropout(0.4) before the linear head | Start 0.3-0.5 for RNNs, 0.1-0.3 for transformers |
| Weight decay (L2) | Add $\frac{\lambda}{2}\lVert w\rVert^2$ to the loss; for plain SGD this adds a shrink step $w \leftarrow w - \eta\lambda w$ on top of each gradient update | weight_decay=1e-5 in Adam | 1e-4 to 1e-6; decouple it (AdamW) for transformers |
| Early stopping | Halt training once val loss stops improving | PATIENCE = 4 in the loop | Patience 3-10; restore best checkpoint |
| Data augmentation | Perturb inputs in label-preserving ways to grow effective dataset | Random token dropout / backtranslation (optional) | Domain-specific; test gain per augmentation |
| Gradient clipping | Cap the total gradient norm at a constant | clip_grad_norm_(params, 1.0) | 0.5-5.0; always safe for RNNs |
Dropout — the Bayesian intuition
Dropout can be read as training an exponentially large ensemble of thinned sub-networks that share weights. With p = 0.4 on a 512-dim pooled vector, there are $2^{512}$ possible masks. At test time we leave all units on; because the surviving activations were already rescaled by 1/(1-p) during training, no extra scaling is needed, and the single full-network pass is a free approximation to the ensemble's mean prediction. The ensemble lens explains why dropout reduces variance without hurting expressivity.
Dropout was introduced by Srivastava et al. 2014 (JMLR 15) and first tuned for RNN-specific use by Zaremba, Sutskever & Vinyals 2014 (arXiv:1409.2329), which is why our RNN dropout rates (0.3–0.5) sit higher than the transformer-era defaults (0.1–0.3).
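The train/test asymmetry is easy to check by hand. Below is a plain-Python sketch of inverted dropout (the variant that rescales during training, as in the table above). The function name `inverted_dropout` is hypothetical and the code is an illustration of the idea, not how `nn.Dropout` is implemented internally.

```python
import random


def inverted_dropout(x, p, training, rng):
    """Zero each element with probability p; rescale survivors by 1/(1-p).

    In eval mode the input passes through unchanged.
    """
    if not training or p == 0.0:
        return list(x)
    keep = 1.0 - p
    return [xi / keep if rng.random() < keep else 0.0 for xi in x]


rng = random.Random(0)          # seeded so the run is repeatable
x = [1.0] * 10_000              # stand-in for a long activation vector

train_out = inverted_dropout(x, p=0.4, training=True, rng=rng)
eval_out = inverted_dropout(x, p=0.4, training=False, rng=rng)

zero_frac = train_out.count(0.0) / len(train_out)
train_mean = sum(train_out) / len(train_out)
print(f"fraction zeroed ~ {zero_frac:.2f}, mean activation ~ {train_mean:.2f}")
```

Roughly 40% of the activations are zeroed, yet the mean stays close to 1.0, which is why the eval-mode pass needs no correction factor.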
Weight decay — keep the weights small
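Adding $\frac{\lambda}{2}\lVert\theta\rVert^2$ to the loss and differentiating gives $\theta \leftarrow \theta - \eta(g + \lambda\theta) = (1 - \eta\lambda)\theta - \eta g$. Every step multiplies the weight by $(1 - \eta\lambda) < 1$, pulling it slightly toward zero. Large weights need a reason to survive.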
For Adam, prefer AdamW. When L2 regularization is folded into the gradient, vanilla Adam scales the decay term by $1/\sqrt{\hat{v}}$ along with the rest of the gradient, so it no longer behaves like the multiplicative weight decay derived above. AdamW decouples the decay and applies it directly to θ. This matters more than people realise: transformer training recipes essentially always use AdamW for this reason (Loshchilov & Hutter 2019, Decoupled Weight Decay Regularization, ICLR 2019 / arXiv:1711.05101).
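A single hand-computed step makes the difference concrete. The sketch below writes out the first update (t = 1) of both rules for one scalar weight with a zero data gradient, so only the decay term acts. The function names and the values lr = 0.01, λ = 0.1 are illustrative assumptions; this is not a replacement for `torch.optim.Adam` or `torch.optim.AdamW`.

```python
import math


def adam_l2_step(theta, grad, lr=0.01, wd=0.1, beta1=0.9, beta2=0.999, eps=1e-8):
    """First Adam step with L2 regularization folded into the gradient."""
    g = grad + wd * theta                 # decay term enters the gradient
    m_hat = ((1 - beta1) * g) / (1 - beta1)        # bias-corrected first moment, t = 1
    v_hat = ((1 - beta2) * g * g) / (1 - beta2)    # bias-corrected second moment, t = 1
    return theta - lr * m_hat / (math.sqrt(v_hat) + eps)


def adamw_step(theta, grad, lr=0.01, wd=0.1, beta1=0.9, beta2=0.999, eps=1e-8):
    """First AdamW step: decay applied directly to theta, outside the moments."""
    theta = theta * (1 - lr * wd)         # decoupled multiplicative decay
    m_hat = ((1 - beta1) * grad) / (1 - beta1)
    v_hat = ((1 - beta2) * grad * grad) / (1 - beta2)
    return theta - lr * m_hat / (math.sqrt(v_hat) + eps)


for theta in (1.0, 100.0):
    print(f"theta={theta:6.1f}  Adam+L2 -> {adam_l2_step(theta, 0.0):.4f}  "
          f"AdamW -> {adamw_step(theta, 0.0):.4f}")
```

With Adam + L2 both weights move by about lr = 0.01 regardless of their size, because the decay term is normalised away by $\sqrt{\hat{v}}$. With AdamW each weight shrinks by the same fraction ηλ, so the large weight is pulled in a hundred times harder than the small one.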
Early stopping — the free regularizer
Early stopping is often the simplest regularizer to deploy. It costs nothing at train time, has one hyperparameter (patience), and works regardless of architecture. The caveat: you need a validation set you trust.
Scaling the Same Recipe to Transformers
Everything we have written for the LSTM — train mode, mini-batch loop, cross-entropy, Adam, early stopping — transfers to transformers essentially unchanged. What changes is scale and the bottlenecks that scale exposes. Below is the training anatomy of a modern transformer language model at roughly 1 B parameters, side-by-side with our ~1.8 M-parameter LSTM.
| Axis | Our LSTM classifier | 1B-parameter transformer LM |
| --- | --- | --- |
| Parameters | ~1.8 M | ~1 B |
| Dataset | ~20 k examples | ~200 B tokens |
| Sequence length T | ~200 | 2 048 - 131 072 |
| Loss | CrossEntropy (2 classes) | CrossEntropy (50 k-token softmax, per position) |
| Optimizer | Adam, lr 1e-3 | AdamW, lr ~3e-4 with warmup+cosine decay |
| Batch size | 64 samples | 0.5 - 4 M tokens (via gradient accumulation) |
| Epochs | ~25 | Usually < 1 full pass (tokens seen once) |
| Precision | fp32 | bf16 / fp8 with mixed-precision |
| Hardware | 1 CPU / GPU | 1 000+ GPUs with 3-D parallelism |
| Dominant cost | LSTM recurrence | Attention (quadratic in T) + matmuls |
Every line of that table is the same five-step training loop wearing a different coat. A few extensions deserve mention because they are direct consequences of scaling the recipe:
Learning-rate schedule. Transformers warm up η linearly for the first 1-4 k steps (because $\hat{v}$ needs time to converge to a good estimate), then cosine-decay it over the remainder of training. Warmup came out of the original transformer paper, which paired it with inverse-square-root decay; cosine decay became the common choice later, and the combination survives to this day.
Gradient accumulation. When the desired batch (e.g., 4 M tokens) does not fit in GPU memory, we accumulate gradients over several micro-batches before calling optimizer.step(). Same math, lower memory peak.
Mixed-precision training. Forward and backward in bf16 or fp8, but Adam's moments in fp32. Relative to fp32, bf16 roughly halves activation memory and roughly doubles throughput on A100/H100-class GPUs (fp8 needs H100-class hardware), typically with no accuracy penalty if done correctly.
Parallelism. Data parallelism (replicate the model, split the batch), tensor parallelism (split individual weight matrices), pipeline parallelism (split the layers into stages), and ZeRO sharding (partition optimizer state, gradients and parameters across data-parallel ranks) are complementary strategies that can be combined to scale one training step across a cluster.
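Two of the extensions above, the warmup-plus-cosine schedule and gradient accumulation, fit in a short plain-Python sketch. The peak rate 3e-4 comes from the table; the 2 000 warmup steps, the 100 000 total steps, the decay-to-zero floor and the toy one-parameter model are assumptions chosen for illustration.

```python
import math


def lr_at(step, peak_lr=3e-4, warmup_steps=2_000, total_steps=100_000):
    """Linear warmup to peak_lr, then cosine decay to zero."""
    if step < warmup_steps:
        return peak_lr * (step + 1) / warmup_steps
    progress = min(1.0, (step - warmup_steps) / (total_steps - warmup_steps))
    return 0.5 * peak_lr * (1.0 + math.cos(math.pi * progress))


def grad_mse(w, xs, ys):
    """d/dw of mean((w * x - y)^2) for a one-parameter linear model."""
    return sum(2.0 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


# Gradient accumulation: two micro-batches of 2 versus one batch of 4.
xs, ys, w = [1.0, 2.0, 3.0, 4.0], [2.0, 3.0, 5.0, 8.0], 0.5
full_batch_grad = grad_mse(w, xs, ys)

n_micro = 2
accumulated_grad = 0.0
for i in range(0, len(xs), 2):
    # Dividing each micro-batch gradient by n_micro keeps the overall mean correct.
    accumulated_grad += grad_mse(w, xs[i:i + 2], ys[i:i + 2]) / n_micro

print(f"lr at steps 0 / 2000 / 51000 / 100000: "
      f"{lr_at(0):.2e} {lr_at(2_000):.2e} {lr_at(51_000):.2e} {lr_at(100_000):.2e}")
print(f"full-batch grad = {full_batch_grad}, accumulated grad = {accumulated_grad}")
```

The accumulated gradient matches the full-batch gradient because all micro-batches have the same size; with unequal micro-batches the weights would need to be proportional to their sizes.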
Flash Attention: Fixing the Training-Time Memory Wall
If you take our five-step loop and naively swap the LSTM for a transformer, the first wall you hit is not compute; it is activation memory. Self-attention produces a score matrix of shape (T,T) at every layer. For T=8192 that is $8192^2 \approx 6.7 \times 10^7$ floats per head, and we need to keep it around for the backward pass. A 32-layer, 32-head model at 8 k context chews through hundreds of gigabytes of activation memory before anyone trains anything.
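The "hundreds of gigabytes" figure follows from simple multiplication. The sketch below counts only the raw (T,T) score matrices for a single sequence; the helper name is hypothetical, and a real naive implementation would also store softmax outputs and other activations, so these numbers are a lower bound.

```python
def attention_score_bytes(seq_len, n_layers, n_heads, bytes_per_elem, batch_size=1):
    """Bytes needed to keep every (T, T) score matrix around for the backward pass."""
    return batch_size * n_layers * n_heads * seq_len * seq_len * bytes_per_elem


GIB = 2 ** 30
fp16_bytes = attention_score_bytes(8192, n_layers=32, n_heads=32, bytes_per_elem=2)
fp32_bytes = attention_score_bytes(8192, n_layers=32, n_heads=32, bytes_per_elem=4)

print(f"scores per head: {8192 * 8192:,}")
print(f"all layers and heads: {fp16_bytes / GIB:.0f} GiB in fp16, "
      f"{fp32_bytes / GIB:.0f} GiB in fp32")
# scores per head: 67,108,864
# all layers and heads: 128 GiB in fp16, 256 GiB in fp32
```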
Flash Attention (Dao et al., 2022) solves this by rewriting self-attention as a tiled, online softmax. It never materialises the full (T,T) score matrix in high-bandwidth memory (HBM). Instead it:
Loads small tiles of Q,K,V into SRAM.
Computes each tile's contribution to the softmax on the fly, updating a running maximum and running sum using the numerically-stable online softmax algorithm.
Writes only the attention output back to HBM — which is (T,d), not (T,T).
For the backward pass, recomputes the tiles on the fly rather than caching them — trading cheap SRAM compute for expensive HBM storage.
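The running-maximum and running-sum bookkeeping in the steps above can be checked in plain Python for a single query row. This is only a sketch of the online-softmax idea under simplifying assumptions: one query, tiles over keys and values only, no masking, and no actual SRAM/HBM management, so it is not the FlashAttention kernel itself. The function names and toy numbers are illustrative.

```python
import math


def softmax_weighted_sum(scores, values):
    """Reference: materialise all softmax weights, then average the values."""
    m = max(scores)
    w = [math.exp(s - m) for s in scores]
    z = sum(w)
    return [sum(wi * v[j] for wi, v in zip(w, values)) / z
            for j in range(len(values[0]))]


def online_softmax_weighted_sum(scores, values, tile=2):
    """Same result, visiting scores/values one tile at a time."""
    running_max, running_sum = float("-inf"), 0.0
    acc = [0.0] * len(values[0])            # unnormalised output accumulator
    for start in range(0, len(scores), tile):
        s_tile = scores[start:start + tile]
        v_tile = values[start:start + tile]
        new_max = max(running_max, max(s_tile))
        scale = math.exp(running_max - new_max)   # rescale what was accumulated so far
        w = [math.exp(s - new_max) for s in s_tile]
        running_sum = running_sum * scale + sum(w)
        acc = [a * scale + sum(wi * v[j] for wi, v in zip(w, v_tile))
               for j, a in enumerate(acc)]
        running_max = new_max
    return [a / running_sum for a in acc]


# Toy attention scores for one query against five keys, with 2-dim values.
scores = [0.3, 2.1, -1.0, 4.0, 0.5]
values = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0], [-1.0, 3.0], [0.5, 0.5]]

reference = softmax_weighted_sum(scores, values)
tiled = online_softmax_weighted_sum(scores, values, tile=2)
print(reference, tiled)
```

The two results agree to floating-point precision even though the tiled version never holds more than `tile` softmax weights at once, which is the property that lets the real kernel avoid writing the (T,T) matrix to HBM.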
| | Naive attention | Flash Attention |
| --- | --- | --- |
| Peak activation memory (fwd) | O(T²·h) in HBM | O(T·d) in HBM + tiles in SRAM |
| HBM reads / writes per layer | O(T²) | O(T²·d²/M), where M = SRAM size |
| Speedup vs naive at T = 8 k | 1x | 2-4x forward, 3-6x backward |
| What it enables | T ~ 2 k before OOM | T ~ 128 k fits with room for bigger batches |
Flash Attention is arguably one of the most important implementation-level papers of the transformer era. It is not a new model, a new loss, or an approximation: it computes exactly the same attention output, reorganised to be memory-aware. Yet it made long-context training practical on existing accelerators and is a key reason modern LLMs can be trained on book-length contexts.
The LSTM classifier in this chapter never needs any of this. Its hidden-state tensor at one timestep is a single d_hidden-dimensional vector, not a quadratic score matrix. This is the memory upside of recurrence: cheap in VRAM, expensive in wall-clock time. Flash Attention is the engineering answer to the question "how do we keep the matrix-multiply speed of a transformer while giving back the memory efficiency of an LSTM?"
KV-Cache: Why Evaluation Looks Different from Training
There is one more place where our LSTM and a transformer diverge sharply, and it shows up at evaluation / inference time — the same torch.no_grad() block in our training loop.
During LSTM eval, decoding a new token means one cheap gate computation that consumes the previous hidden and cell state and produces the next. The only "cache" we carry forward is those two d_hidden-dimensional vectors.
During transformer eval, generating token T+1 means re-running self-attention where the new token's query attends to all previous tokens' keys and values. Recomputing those keys and values from scratch at every step would make generation $O(T^2)$ per token.
The KV-cache saves K and V matrices for every past token, layer-by-layer, so each new step only computes one new Q/K/V triple and attends against the growing cache. The price is memory:
| Model | T | Cache size per token (fp16) | Total cache |
| --- | --- | --- | --- |
| Our LSTM | 200 | ~1 KB (h + c) | ~1 KB (state is constant) |
| GPT-2 small, 12 layers, d=768 | 1 024 | ~36 KB (across all layers) | ~36 MB |
| Llama-3 70B (GQA, 80 layers, 8 KV heads × 128) | 8 192 | ~320 KB | ~2.5 GB |
| Llama-3 70B (GQA, same config) | 131 072 | ~320 KB | ~40 GB |
| Same 70B without GQA (MHA, 64 KV heads) | 131 072 | ~2.5 MB | ~320 GB (infeasible) |
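Under grouped-query attention the per-token KV cache is $2 \cdot L \cdot n_{\text{kv\_heads}} \cdot d_{\text{head}} \cdot \text{bytes}$, where the leading 2 counts K and V. Llama-3 70B (Grattafiori et al. 2024, arXiv:2407.21783) uses $L = 80$, $n_{\text{kv\_heads}} = 8$, and $d_{\text{head}} = 128$, so at fp16 a single token consumes $2 \times 80 \times 8 \times 128 \times 2 = 327{,}680$ bytes ≈ 320 KB. Without GQA, i.e. with 64 K/V heads to match the 64 query heads, that figure would be $64/8 = 8\times$ bigger. At 128k context that is the difference between a roughly 40 GB cache, which fits in the 80 GB of a single H100 (model weights aside), and a roughly 320 GB cache, which does not.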
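The table rows can be regenerated from that formula. In this sketch the helper name is hypothetical and the configurations are the ones quoted in the table; GPT-2 small's d=768 is written as 12 heads of 64 dimensions.

```python
def kv_cache_bytes_per_token(n_layers, n_kv_heads, d_head, bytes_per_elem=2):
    """2 (K and V) x layers x KV heads x head dim x bytes per element."""
    return 2 * n_layers * n_kv_heads * d_head * bytes_per_elem


# (layers, KV heads, head dim, context length T)
configs = {
    "GPT-2 small (12 heads x 64)": (12, 12, 64, 1_024),
    "Llama-3 70B, GQA (8 KV heads)": (80, 8, 128, 131_072),
    "70B with MHA (64 KV heads)": (80, 64, 128, 131_072),
}

for name, (n_layers, n_kv_heads, d_head, seq_len) in configs.items():
    per_token = kv_cache_bytes_per_token(n_layers, n_kv_heads, d_head)
    total = per_token * seq_len
    print(f"{name}: {per_token / 1024:.0f} KiB per token, "
          f"{total / 2 ** 30:.2f} GiB at T={seq_len}")
```

The totals are per sequence; serving a batch multiplies them by the batch size, which is how the cache comes to rival the weights in memory.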
At 70 B scale the KV-cache alone can exceed the model weights in memory. This drove the invention of several variants:
Multi-Query Attention (MQA): share one K/V head across all Q heads. Cuts KV-cache size by a factor of h (the head count), at a small quality cost.
Grouped-Query Attention (GQA): a compromise in which g K/V heads are shared across h Q heads (Llama 2 70B and Llama 3 use g=8).
Paged Attention (vLLM): the OS's virtual-memory trick applied to attention. The KV-cache is stored in fixed-size pages so batches of different sequence lengths don't waste space.
Evaluation for our LSTM is trivial — one forward pass, no cache. Evaluation for a 70 B transformer is a systems-engineering effort in its own right. The entire serving stack at every big-model company is essentially an answer to "how do we make torch.no_grad() fast for a transformer?"
Summary
In this section we closed the loop on the sentiment classifier we have been building for three sections:
Stated training as one optimization problem over parameters and showed that the same equation drives everything from logistic regression to a 1 B-parameter transformer.
Derived cross-entropy from three angles (likelihood, information theory, gradient friendliness) and showed why it pairs so cleanly with sigmoid / softmax.
Walked from SGD to momentum to Adam, visualized the difference with an interactive 3D loss surface, and explained why Adam (and AdamW) is the default optimizer of modern deep learning.
Traced one complete training step in pure NumPy — forward, loss, backward, update — with every numeric value computed. Then wrote the same recipe as an industrial PyTorch training loop, annotating every function call and argument.
Defined accuracy, precision, recall, F1 and the confusion matrix from first principles, and scrubbed through a realistic 25-epoch run to see overfitting appear exactly where the theory predicted.
Cataloged the regularization toolkit — dropout, weight decay, early stopping, gradient clipping, augmentation — and showed which one each hyperparameter in our training loop corresponds to.
Extended the same five-step loop to billion-parameter transformers and identified what scaling breaks: activation memory (Flash Attention fixes it) and inference memory (the KV-cache and its MQA / GQA / paged variants fix it).
The line from our modest LSTM to modern LLMs is remarkably straight. Swap the architecture, swap the dataset, swap the hardware — but the training loop, the loss, the optimizer, the metrics and even the overfitting diagnostic are the same five steps we just wrote. Everything else is engineering in service of making those five steps fit onto the hardware.
References
Kingma, D. P. & Ba, J. (2014). Adam: A Method for Stochastic Optimization. arXiv:1412.6980.
Srivastava, N., Hinton, G., Krizhevsky, A., Sutskever, I. & Salakhutdinov, R. (2014). Dropout: A Simple Way to Prevent Neural Networks from Overfitting. JMLR 15.
Zaremba, W., Sutskever, I. & Vinyals, O. (2014). Recurrent Neural Network Regularization. arXiv:1409.2329.
Loshchilov, I. & Hutter, F. (2019). Decoupled Weight Decay Regularization. ICLR 2019 / arXiv:1711.05101.
Dao, T., Fu, D. Y., Ermon, S., Rudra, A. & Ré, C. (2022). FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness. NeurIPS 2022 / arXiv:2205.14135.
Ainslie, J., Lee-Thorp, J., de Jong, M., Zemlyanskiy, Y., Lebrón, F. & Sanghai, S. (2023). GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints. EMNLP 2023 / arXiv:2305.13245.
Kwon, W. et al. (2023). Efficient Memory Management for Large Language Model Serving with PagedAttention. SOSP 2023 / arXiv:2309.06180.
Grattafiori, A. et al. (2024). The Llama 3 Herd of Models. arXiv:2407.21783.