In section 1 we turned raw text into a sequence of embedding vectors. Beautiful numbers, but still useless on their own — a sequence of vectors is not a verdict. The job of a text classifier is to read a whole sentence and emit a single symbol: "positive", "spam", "urgent", "French". The question of this section is: how do we compress a variable-length sequence of embeddings into one fixed-size prediction?
The recipe we are going to build is the canonical one that dominated NLP from about 2014 until the transformer era:
Each token is looked up in an embedding table and becomes a dense vector $x_t \in \mathbb{R}^{d_{\text{embed}}}$.
An LSTM reads the sequence left to right, updating its hidden state $h_t$ and cell state $c_t$ at every step.
A pooling step collapses the sequence of hidden states $(h_0, \dots, h_{T-1})$ into a single summary vector.
A linear layer + softmax turns that summary into class probabilities.
Why This Matters: Almost every "classical" NLP benchmark (sentiment, topic, spam, language identification) was dominated for several years by variants of exactly this architecture. Transformers have since taken over the top of the leaderboards, but LSTM classifiers remain widely deployed because they are cheap, fast, and good enough for many tasks.
Anatomy of an LSTM Classifier
Conceptually, the classifier is a three-stage pipeline. Data flows left-to-right, dimensions change at each stage, and nothing flows back to earlier stages during inference.
| Stage | What it does | Input → Output shape |
|---|---|---|
| Embedding | Look up the dense vector of each token id | (B, T) → (B, T, d_embed) |
| LSTM encoder | Carry information forward through time via gated recurrence | (B, T, d_embed) → (B, T, d_hidden) |
| Pooling | Collapse the time axis into one summary vector | (B, T, d_hidden) → (B, d_hidden) |
| Classifier head | Linear projection to class logits | (B, d_hidden) → (B, n_classes) |
B is the batch size, T is the sequence length, d_embed is the embedding dimension, d_hidden is the LSTM hidden size, and n_classes is the number of labels. The number of classes is fixed by the task; the others are hyperparameters you will tune.
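The only stage with time dependencies is the LSTM; the embedding lookup, pooling and classifier head involve no recurrence. This is one reason classifiers are much faster to train than, say, language models: there is a single small softmax and loss per sentence instead of a vocabulary-sized softmax and loss term at every time step.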
Each stage owns its own learnable parameter matrices. The production defaults we will train in section 3 are $B = 64$, $T = 200$, $d_{\text{embed}} = 128$, $d_{\text{hidden}} = 256$.
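To make the stage table concrete, here is a small sketch that counts the learnable parameters of each stage and lists each stage's output shape at the production defaults. The vocabulary size (20,000) and the number of classes (2) are our own illustrative assumptions, and the LSTM count follows PyTorch's convention of two bias vectors per layer.

```python
def count_parameters(vocab_size, d_embed, d_hidden, n_classes, two_biases=True):
    """Learnable parameters per stage of an embedding -> LSTM -> pool -> linear classifier."""
    embedding = vocab_size * d_embed                  # E: (vocab_size, d_embed)
    n_bias_vectors = 2 if two_biases else 1           # PyTorch keeps b_ih and b_hh separately
    lstm = (4 * d_hidden * d_embed                    # W_x: (4*d_hidden, d_embed)
            + 4 * d_hidden * d_hidden                 # W_h: (4*d_hidden, d_hidden)
            + n_bias_vectors * 4 * d_hidden)          # bias vector(s): (4*d_hidden,)
    head = d_hidden * n_classes + n_classes           # W_cls and b_cls
    return {"embedding": embedding, "lstm": lstm, "pooling": 0, "head": head}


def output_shapes(B, T, d_embed, d_hidden, n_classes):
    """Output shape of each stage for a batch of B sentences of length T."""
    return {
        "embedding": (B, T, d_embed),
        "lstm": (B, T, d_hidden),
        "pooling": (B, d_hidden),
        "head": (B, n_classes),
    }


# Production defaults from the text; vocab_size and n_classes are illustrative assumptions.
counts = count_parameters(vocab_size=20_000, d_embed=128, d_hidden=256, n_classes=2)
shapes = output_shapes(B=64, T=200, d_embed=128, d_hidden=256, n_classes=2)
for stage in ("embedding", "lstm", "pooling", "head"):
    print(f"{stage:<10} {counts[stage]:>10,} params  ->  {shapes[stage]}")
```

Under these assumptions the embedding table (2,560,000 parameters) dwarfs the LSTM (395,264) and the head (514), and pooling has no parameters at all.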
The Mathematical Recipe, End to End
Let us write the whole classifier as a composition of functions. Given a sentence of token ids $(\mathrm{id}_0, \dots, \mathrm{id}_{T-1})$:
1. Embedding lookup
$x_t = E[\mathrm{id}_t] \in \mathbb{R}^{d_{\text{embed}}}$. We met this operation in section 1: a single row lookup per token, $O(T \cdot d_{\text{embed}})$ time for the whole sentence.
2. LSTM recurrence
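This is the gated recurrence introduced by Hochreiter & Schmidhuber (1997), in the now-standard form that includes a forget gate (added by Gers, Schmidhuber & Cummins, 2000). It greatly mitigates the vanishing-gradient problem of vanilla RNNs. For $t = 0, 1, \dots, T-1$, starting from $h_{-1} = c_{-1} = 0$: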
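Gate pre-activations (the four gates are stacked, so one weight matrix per input computes all of them at once):
$$\begin{bmatrix} a_i \\ a_f \\ a_g \\ a_o \end{bmatrix} = W_x x_t + W_h h_{t-1} + b$$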
Gate activations: $i_t = \sigma(a_i)$, $f_t = \sigma(a_f)$, $g_t = \tanh(a_g)$, $o_t = \sigma(a_o)$. The three sigmoids produce soft switches in $(0, 1)$; $\tanh$ produces the candidate content in $(-1, 1)$.
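State updates: $c_t = f_t \odot c_{t-1} + i_t \odot g_t$ (cell) and $h_t = o_t \odot \tanh(c_t)$ (hidden), where $\odot$ is element-wise multiplication. The cell update is the key piece: old memory is carried forward by an element-wise scaling and a sum, with no weight matrix and no squashing nonlinearity in between, so along this path the gradient is multiplied only by $f_t$ at each step. That is why LSTM gradients vanish far more slowly than those of a vanilla RNN.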
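The pre-activation equation above can be computed as literally one matmul: concatenating $x_t$ and $h_{t-1}$ and multiplying by the side-by-side matrix $[W_x \mid W_h]$ gives the same result as the two separate products. The NumPy sketch below checks that equivalence with arbitrary random weights and sizes (our choices), then runs one batched LSTM step with the gates stacked in the order input, forget, candidate, output.

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def lstm_step_batched(x_t, h_prev, c_prev, W, b):
    """One LSTM step for a whole batch.

    x_t: (B, d_embed); h_prev, c_prev: (B, d_hidden)
    W: (4*d_hidden, d_embed + d_hidden), i.e. [W_x | W_h] side by side
    b: (4*d_hidden,)
    Gate rows are stacked as input, forget, candidate, output.
    """
    xh = np.concatenate([x_t, h_prev], axis=1)           # (B, d_embed + d_hidden)
    a = xh @ W.T + b                                     # (B, 4*d_hidden): all four gates
    a_i, a_f, a_g, a_o = np.split(a, 4, axis=1)          # each (B, d_hidden)
    i, f, o = sigmoid(a_i), sigmoid(a_f), sigmoid(a_o)   # soft switches in (0, 1)
    g = np.tanh(a_g)                                     # candidate content in (-1, 1)
    c_t = f * c_prev + i * g                             # additive cell update
    h_t = o * np.tanh(c_t)                               # exposed hidden state
    return h_t, c_t


rng = np.random.default_rng(0)
B, d_embed, d_hidden = 3, 5, 4
W_x = rng.normal(size=(4 * d_hidden, d_embed))
W_h = rng.normal(size=(4 * d_hidden, d_hidden))
b = rng.normal(size=4 * d_hidden)
x = rng.normal(size=(B, d_embed))
h = rng.normal(size=(B, d_hidden))
c = rng.normal(size=(B, d_hidden))

# Fused matrix vs. the two separate products: same pre-activations.
W = np.concatenate([W_x, W_h], axis=1)
separate = x @ W_x.T + h @ W_h.T + b
fused = np.concatenate([x, h], axis=1) @ W.T + b
print("fused == separate:", np.allclose(separate, fused))

h_new, c_new = lstm_step_batched(x, h, c, W, b)
print(h_new.shape, c_new.shape)
```

The rows here multiply row vectors (x @ W.T), the usual layout for batched code; it is the same math as the column-vector form $W_x x_t$ in the equation.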
3. Pooling
Choose one of several canonical recipes. The simplest is $\bar h = h_{T-1}$: just take the final hidden state. The alternatives are compared below.
4. Classifier head
$z = W_{\text{cls}} \bar h + b_{\text{cls}} \in \mathbb{R}^{n_{\text{classes}}}$, then $p = \mathrm{softmax}(z)$. The predicted label is $\hat y = \arg\max_k p_k$.
5. Loss
Training uses cross-entropy against the true label $y$: $\mathcal{L} = -\log p_y = -\log \dfrac{e^{z_y}}{\sum_k e^{z_k}}$. Section 3 of this chapter walks through the training loop; here we focus exclusively on the forward pass.
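Steps 4 and 5 are usually implemented together, because computing $\log p_y$ directly is more stable than taking the log of a softmax. A minimal NumPy sketch using the identity $\log p_y = z_y - \log \sum_k e^{z_k}$, with the maximum subtracted first; the logits are made-up values, and the second row is chosen large enough to overflow a naive exp.

```python
import numpy as np


def log_softmax(z):
    """Numerically stable log-softmax over the last axis."""
    z_shift = z - z.max(axis=-1, keepdims=True)      # softmax is shift-invariant: no overflow
    return z_shift - np.log(np.exp(z_shift).sum(axis=-1, keepdims=True))


def cross_entropy(z, y):
    """Mean of -log p_y over a batch. z: (B, n_classes) logits, y: (B,) integer labels."""
    logp = log_softmax(z)
    return -logp[np.arange(len(y)), y].mean()


z = np.array([[2.0, -1.0, 0.5],
              [1000.0, 0.0, -1000.0]])   # np.exp(1000.0) alone would overflow to inf
y = np.array([0, 0])

p = np.exp(log_softmax(z))               # p = softmax(z), each row sums to 1
y_hat = p.argmax(axis=-1)                # predicted labels
loss = cross_entropy(z, y)
print("p     :", p)
print("y_hat :", y_hat)
print("loss  :", loss)
```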
The whole model is a single differentiable function $f_\theta$ with parameters $\theta = (E, W_x, W_h, b, W_{\text{cls}}, b_{\text{cls}})$. PyTorch's autograd will compute $\nabla_\theta \mathcal{L}$ for free once we call .backward(), but only because every step above is differentiable.
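To see that in action, here is a sketch that wires nn.Embedding, nn.LSTM and nn.Linear together on toy sizes (our own choices), computes the loss with F.cross_entropy, and calls .backward(). Every parameter ends up with a gradient, and the embedding row reserved for padding (padding_idx=0) receives a zero gradient.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
vocab_size, d_embed, d_hidden, n_classes = 6, 2, 2, 2   # toy sizes (assumed)
embedding = nn.Embedding(vocab_size, d_embed, padding_idx=0)
lstm = nn.LSTM(d_embed, d_hidden, batch_first=True)
head = nn.Linear(d_hidden, n_classes)

ids = torch.tensor([[1, 2, 3, 4],
                    [5, 2, 0, 0]])        # second sentence is right-padded with id 0
y = torch.tensor([0, 1])                  # true labels

outputs, _ = lstm(embedding(ids))         # (B, T, d_hidden)
logits = head(outputs[:, -1, :])          # last pooling (row 2 reads a pad position here)
loss = F.cross_entropy(logits, y)         # log-softmax + negative log-likelihood in one call
loss.backward()                           # fills .grad for every parameter

for name, p in [("E", embedding.weight), *lstm.named_parameters(),
                *head.named_parameters()]:
    print(f"{name:14s} grad shape {tuple(p.grad.shape)}")
print("grad of <pad> row:", embedding.weight.grad[0])   # zero because padding_idx=0
```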
Now zoom out to the cell-state highway. Each of the $d_{\text{hidden}}$ channels has its own lane: memory from the previous step persists through the forget gate, and new content rides in via the input gate. If $f_t$ stays close to 1, the original memory survives every step; if it drops near 0, the lane clears. This additive path is why LSTM gradients vanish far more slowly than in a vanilla RNN.
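A small numeric illustration of the highway, under a deliberate simplification: the gate values are fixed numbers rather than functions of $h_{t-1}$, which removes the LSTM's other gradient paths and leaves only the cell recursion. Under that assumption $\partial c_T / \partial c_0$ is exactly the product of the forget gates, whatever the input gate and candidate are, so a lane with $f_t = 0.99$ keeps about 60% of its memory and gradient after 50 steps, while a lane with $f_t = 0.5$ keeps essentially nothing.

```python
import numpy as np


def cell_highway(c0, f, i, g):
    """Roll c_t = f_t * c_{t-1} + i_t * g_t with externally fixed gate values.

    Holding the gates fixed isolates the cell-state highway; in a real LSTM the
    gates also depend on h_{t-1}, which adds extra (non-highway) gradient paths.
    """
    c = np.asarray(c0, dtype=float)
    for f_t, i_t, g_t in zip(f, i, g):
        c = f_t * c + i_t * g_t
    return c


def highway_grad(f):
    """d c_T / d c_0 along the highway: backprop multiplies by f_t at each step."""
    grad = np.ones_like(f[0])
    for f_t in reversed(f):
        grad = grad * f_t
    return grad


T, d_hidden = 50, 2
# Two lanes: lane 0 keeps its forget gate near 1, lane 1 forgets half every step.
f = np.tile([0.99, 0.5], (T, 1))        # (T, d_hidden)
i = np.zeros((T, d_hidden))             # input gate closed: no new content is written
g = np.zeros((T, d_hidden))

c_T = cell_highway([1.0, 1.0], f, i, g)
grad = highway_grad(f)
print("memory left after 50 steps:", c_T)
print("d c_T / d c_0             :", grad)
```

The input gate is closed here only so that the remaining memory and the gradient print as the same numbers; opening it changes $c_T$ but not the gradient with respect to $c_0$.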
Pooling: Last, Mean, or Max Hidden State?
The pooling step is where a surprising amount of modelling judgment lives. Given the full trajectory $H = (h_0, h_1, \dots, h_{T-1}) \in \mathbb{R}^{T \times d_{\text{hidden}}}$, we need one vector.
| Pooling | Formula | When it works | When it hurts |
|---|---|---|---|
| Last | $\bar h = h_{T-1}$ | Short sentences where the whole signal can survive the recurrence | Long sentences: early information has been overwritten |
| Mean | $\bar h = \frac{1}{T} \sum_t h_t$ | When every token contributes useful evidence (topic classification) | When a few rare tokens carry all the signal, which averaging dilutes into noise |
| Max | $\bar h[k] = \max_t h_t[k]$ (per dimension) | When salient features appear at unpredictable positions (sentiment, urgency) | When the magnitude of a feature is not a good proxy for its importance |
| Attention | $\bar h = \sum_t \alpha_t h_t$ with learned $\alpha_t$ | When you can afford extra parameters and want a content-dependent weighting of positions | Small datasets: $\alpha_t$ overfits easily |
A useful rule of thumb: if your classifier is doing worse than a bag-of-words baseline, switch from last pooling to a concatenation of mean and max pooling. The gain is often noticeable, and the only cost is a classifier head with twice as many inputs.
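A minimal NumPy sketch of the last, mean, max and mean+max recipes. It adds one practical detail the table leaves implicit: in a padded batch, a mask (1 for real tokens, 0 for <pad>) keeps padding out of the pooled vector. The right-padding assumption and the toy trajectory are ours.

```python
import numpy as np


def pool(H, mask, how):
    """Collapse hidden states H (B, T, d_hidden) into one vector per sentence.

    mask: (B, T), 1 for real tokens and 0 for <pad>. "last" assumes right-padding,
    i.e. all real tokens come before any padding.
    """
    m = mask[:, :, None].astype(H.dtype)                 # (B, T, 1)
    if how == "last":
        last_idx = mask.sum(axis=1).astype(int) - 1       # index of the last real token
        return H[np.arange(H.shape[0]), last_idx]         # (B, d_hidden)
    if how == "mean":
        return (H * m).sum(axis=1) / m.sum(axis=1)        # average over real tokens only
    if how == "max":
        return np.where(m > 0, H, -np.inf).max(axis=1)    # pads can never win the max
    if how == "mean+max":
        return np.concatenate([pool(H, mask, "mean"), pool(H, mask, "max")], axis=-1)
    raise ValueError(f"unknown pooling: {how}")


# One sentence: 4 real tokens + 1 pad. Channel 1 has a single spike at t=1.
H = np.array([[[0.1, 0.0],
               [0.2, 0.9],
               [0.3, 0.0],
               [0.4, 0.0],
               [9.9, 9.9]]])          # padded position: must be ignored
mask = np.array([[1, 1, 1, 1, 0]])

for how in ("last", "mean", "max", "mean+max"):
    print(f"{how:9s}", pool(H, mask, how))
```

On this trajectory, mean pooling reports the spike in channel 1 as 0.225 (0.9 divided by the four real tokens), max pooling keeps it at 0.9, and last pooling misses it entirely.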
The difference is easy to see on a concrete trajectory $H \in \mathbb{R}^{T \times d_{\text{hidden}}}$: a one-step spike in a single channel survives max pooling at full height but is divided by $T$ under mean pooling, while attention pooling weights each position by how well $h_t$ matches its learned query.
The attention pooling row is the crack through which transformers crawled into NLP. An attention pool is literally a single-head attention block with a learned query — once you accept that the sum ∑tαtht is a good idea, you are two steps away from dropping the LSTM entirely and letting every token attend to every other token.
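The claim that an attention pool is a single-head attention block with one learned query can be written out in a few lines. In this sketch the score is a scaled dot product $h_t \cdot q / \sqrt{d}$, which is one common choice rather than the only one, and the query $q$ is hand-picked for illustration instead of learned.

```python
import numpy as np


def attention_pool(H, q, mask=None):
    """Single-head attention pooling with one query vector q.

    H: (B, T, d) hidden states, q: (d,) query, mask: optional (B, T) with 0 for pads.
    Returns the pooled vectors (B, d) and the attention weights alpha (B, T).
    """
    d = H.shape[-1]
    scores = H @ q / np.sqrt(d)                          # (B, T): one score per position
    if mask is not None:
        scores = np.where(mask > 0, scores, -np.inf)     # pads get zero weight
    scores = scores - scores.max(axis=1, keepdims=True)  # stable softmax over time
    alpha = np.exp(scores)
    alpha = alpha / alpha.sum(axis=1, keepdims=True)     # each row sums to 1
    pooled = (alpha[:, :, None] * H).sum(axis=1)         # sum_t alpha_t h_t
    return pooled, alpha


H = np.array([[[1.0, 0.0],
               [0.0, 1.0],
               [0.0, 3.0]]])                # (B=1, T=3, d=2)
q = np.array([0.0, 2.0])                    # hypothetical "learned" query

pooled, alpha = attention_pool(H, q)
print("alpha :", alpha)                     # most weight on t=2, which best matches q
print("pooled:", pooled)

pooled0, _ = attention_pool(H, np.zeros(2)) # zero query: uniform weights
print("zero-query pool equals mean pool:", np.allclose(pooled0, H.mean(axis=1)))
```

A zero query makes every score equal, so attention pooling collapses to mean pooling; training the query is what lets it move away from the mean.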
A Sentence Through an LSTM
Follow two sentences through the classifier token by token and watch the hidden state $h_t$ evolve: an obviously positive one and an obviously negative one that split at the sentiment word. At every step the network carries a small hidden vector $h_t$ (four dimensions in this illustration) and a cell state $c_t$. After the last token, the classifier head and softmax turn $h_T$ into a probability over positive and negative.
Three things stand out in that walkthrough:
Slow start. At t=0 both sentences share the same token ("I") and therefore the same tiny hidden activations. The network has almost no opinion yet.
The decisive token. At t=2 ("love" or "hate") the trajectories split dramatically and the hidden activations flip sign. This is the additive cell update at work: the input gate lets one strong token write a large new value into memory in a single step.
Noisy coda. The last two tokens ("this movie") are shared again, yet the two trajectories stay separated. The LSTM has committed. A transformer would let those later tokens re-weight the verdict via attention.
Building the Classifier From Scratch in NumPy
Before we hand the problem to PyTorch we are going to build the full classifier in pure NumPy. You will see that there is nothing magical inside nn.LSTM — just the four-gate recipe from the previous section, rolled over time with a Python for loop, followed by a single linear projection. Every matrix, every gate, every state update is visible.
The example uses $d_{\text{embed}} = 2$, $d_{\text{hidden}} = 2$, and a four-word sentence so every intermediate value fits comfortably on one line. Each line of the listing is annotated with what it computes and the values it produces.
Forward pass — pure NumPy, every value traced
🐍lstm_classifier_numpy.py
import numpy as np
NumPy gives us ndarrays, the @ matrix-multiply operator, broadcasting, and vectorized math (exp, tanh). Every line of the forward pass below runs as optimized C rather than slow Python loops.
as np = Universal convention; lets us write np.exp(), np.array(), @ for matmul, etc.
# Section: tiny vocab + embedding matrix
We are building the smallest possible working classifier: a 6-token vocabulary, 2-dimensional embeddings, 2-dimensional hidden state. Small enough to trace by hand, rich enough to show every moving part of an LSTM classifier.
vocab = {'<pad>': 0, 'this': 1, 'movie': 2, 'is': 3, 'great': 4, 'awful': 5}
→ why <pad>? = Reserved id 0 so we can pad short sentences to a common length in a batch without introducing a real word.
E = np.array([...])
The embedding matrix. Row k is the d_embed-dimensional vector assigned to token id k. In a trained model these rows are learned; here we hand-craft two polar opposites — 'great' points in the positive direction, 'awful' in the negative direction — so the forward pass yields an interpretable result.
EXECUTION STATE
📚 np.array = Converts a Python list of lists into a 2-D ndarray with dtype float64 (inferred from the floats).
⬆ E (6 × 2) =
d0 d1
<pad> 0.00 0.00
this 0.10 0.20
movie 0.40 0.30
is 0.20 0.10
great 0.90 0.70
awful -0.80 -0.60
→ shape = (vocab_size=6, d_embed=2) = 12 floats
→ intuition = 'great' sits in the upper-right, 'awful' in the lower-left — an LSTM that integrates embeddings over time will drift toward one pole or the other.
# Section: the input sentence
We now pick a concrete sentence and push it through the pipeline. This is exactly the control flow used by any real classifier, just with tiny numbers.
sentence = "this movie is great"
A 4-word sentence. All four tokens live in our vocabulary, so there are no <unk> tokens and no padding — the simplest possible case.
EXECUTION STATE
sentence = "this movie is great"
expected label = positive — the word 'great' is the emotional anchor.
input_ids = np.array([vocab[w] for w in sentence.split()])
Split the sentence on whitespace, look up each word in the vocab, wrap the result as an ndarray so we can use fancy indexing in the next line.
EXECUTION STATE
📚 str.split() = No argument ⇒ split on any whitespace. 'this movie is great'.split() → ['this','movie','is','great'].
📚 np.array([...]) = Converts a Python list of ints into a 1-D ndarray. dtype defaults to int64.
⬆ input_ids = [1, 2, 3, 4]
→ shape = (T=4,)
X = E[input_ids]  # (T=4, d_embed=2)
NumPy fancy indexing: E[[1,2,3,4]] returns a new 4×2 matrix whose rows are E[1], E[2], E[3], E[4]. This is the O(T·d) lookup that replaces the O(T·V·d) one-hot-times-matrix multiplication from section 1.
EXECUTION STATE
📚 E[int_array] = NumPy fancy indexing — returns a new array whose i-th row is E[int_array[i]]. Works because E is 2-D and the index is 1-D of ints.
⬇ arg: input_ids = [1, 2, 3, 4]
⬆ result: X (4×2) =
d0 d1
this 0.10 0.20
movie 0.40 0.30
is 0.20 0.10
great 0.90 0.70
→ role = X[t] is the embedding handed to the LSTM at time step t. These are the only word-level features the recurrent network will ever see.
# Section: LSTM weights
An LSTM has four gates per hidden unit: input (i), forget (f), cell-candidate (g), output (o). We stack their weight rows into one big matrix so the whole gate vector is computed in a single matmul.
d_embed, d_hidden = 2, 2
Tuple unpacking — sets the two dimensions that govern every weight shape below. Real models use d_embed = 128..4096 and d_hidden = 128..4096; the math is identical.
W_x = np.array([...])  # (4 * d_hidden, d_embed) = (8, 2)
→ 8 rows = 4 gates × 2 units = Every row decides one component of one gate. The 'cell' gate (g) is the only one with tanh; the others get sigmoid.
W_h = np.zeros((4 * d_hidden, d_hidden))
The hidden→gates weight matrix. For clarity we zero it out in this toy — the cell therefore depends only on the current embedding, not on the previous hidden state. Real LSTMs never do this, but zeroing W_h makes the numbers easy to trace.
EXECUTION STATE
📚 np.zeros(shape) = Returns a new ndarray of the given shape, dtype float64, filled with 0.0.
⬆ W_h (8 × 2) =
all zeros
→ simplification = Setting W_h = 0 removes the h_{t-1} term from the gate equation, making each step depend only on x_t. You would never do this in a real model.
b = np.zeros(4 * d_hidden)
Bias vector for all four gates stacked. Some frameworks initialise the forget-gate bias to 1 to encourage remembering early in training.
EXECUTION STATE
⬆ b = [0, 0, 0, 0, 0, 0, 0, 0]
→ tip = Jozefowicz et al. (2015) recommend initialising b_f to 1; in their large-scale comparison of recurrent architectures this simple change substantially improved LSTM results.
# Section: classifier head
After the LSTM has read the whole sentence, a tiny linear layer converts the final hidden state into class logits.
W_cls = np.array([[-1.0, 1.2], [1.1, -0.9]])
A 2×2 classifier matrix — rows are classes, columns are hidden units. Row 0 fires when h[1] is large (positive anchor). Row 1 fires when h[0] is large (negative anchor).
EXECUTION STATE
⬆ W_cls (2 × 2) =
h[0] h[1]
positive -1.00 1.20
negative 1.10 -0.90
→ reading = Positive logit = -1.0·h[0] + 1.2·h[1]. A hidden state with small h[0] and large h[1] → strong positive prediction.
b_cls = np.array([0.0, 0.0])
Two class biases, both zero — the classifier is perfectly neutral until the hidden state pushes it one way.
EXECUTION STATE
⬆ b_cls = [0.0, 0.0]
def sigmoid(x) -> np.ndarray
σ(x) = 1 / (1 + e^(-x)). Squashes any real number into (0, 1). Used for gates because a gate value has to behave like a soft switch: 0 means 'forget / block', 1 means 'keep / let through'.
EXECUTION STATE
⬇ input: x = Any ndarray of real numbers — gate pre-activations in our case.
def lstm_step(x_t, h_prev, c_prev)
One step of the LSTM recurrence. Takes the current embedding, the previous hidden and cell states, and returns the new hidden and cell states. Running this over a loop of length T is the entire forward pass.
EXECUTION STATE
⬇ input: x_t (d_embed,) = Current token's embedding, a 2-vector from X. At t=0 this is E[vocab['this']] = [0.10, 0.20].
⬇ input: h_prev (d_hidden,) = Hidden state from the previous step. Initialised to zeros at t=0.
⬇ input: c_prev (d_hidden,) = Cell state from the previous step — the LSTM's long-term memory line. Also zero at t=0.
The LSTM is the textbook remedy for vanishing gradients in plain RNNs. Its four gates independently decide how much of the past to keep (f), how much of the present to absorb (i, g), and how much of the memory to expose (o).
gates = W_x @ x_t + W_h @ h_prev + b
One stacked matmul per input (plus the bias) computes the pre-activations for all four gates at once. This fusing is why real LSTM implementations store W_i, W_f, W_g, W_o as a single stacked matrix: one BLAS call per input replaces four separate ones.
g = np.tanh(gates[4:6])
⬆ g (2,) = At t=0 ≈ [0.168, 0.110]: small positive values, so early in the sentence the cell candidate drifts slightly positive.
o = sigmoid(gates[6:8])
The output gate. Decides how much of the (nonlinearly squashed) cell to reveal as the hidden state.
EXECUTION STATE
→ at t=0 = sigmoid([0.00, 0.06]) ≈ [0.500, 0.515]
⬆ o (2,) = ~0.5 — the hidden state exposes about half the cell.
c_t = f * c_prev + i * g
The cell-state update — the heart of the LSTM. Element-wise: keep f·c_prev of the old memory, add i·g of the new candidate. No matrix multiplication here, which is exactly what gives the LSTM its well-behaved gradients.
EXECUTION STATE
📚 * (Hadamard) = Element-wise multiplication on ndarrays. Different from @ which is matrix multiplication.
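→ why additive? = The + (rather than a matrix multiply) is why LSTMs avoid catastrophic vanishing: along this path ∂c_t/∂c_{t-1} = diag(f), so the gradient is rescaled only by the forget gate at each step instead of passing through a weight matrix and a squashing nonlinearity.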
h_t = o * np.tanh(c_t)
The hidden state is a gated, nonlinearly-squashed view of the cell. tanh keeps h_t bounded; the output gate controls which cell slots get exposed.
→ meaning = After reading only 'this', the network barely knows anything — the hidden state is close to zero. It will grow as stronger words like 'great' arrive.
return h_t, c_t
The cell hands the next step two pieces of information: the publicly-visible hidden state h_t (used by the classifier and by the next cell) and the private long-term memory c_t.
EXECUTION STATE
⬆ (h_t, c_t) = At t=0: (≈[0.044, 0.028], ≈[0.089, 0.054]). Both are d_hidden-vectors.
# Section: roll the LSTM over the sentence
Now we loop over all T time steps, threading (h, c) from each cell into the next. This is the one Python for-loop we cannot easily remove from a pure RNN/LSTM, because step t needs the state from step t−1, and it is a major reason attention-based models dominate once sequences get long.
T = X.shape[0]
Read the sequence length from the shape of X. For our sentence T = 4.
EXECUTION STATE
📚 ndarray.shape = Tuple of dimensions. X.shape = (4, 2), so X.shape[0] is 4 = sequence length.
⬆ T = 4
h = np.zeros(d_hidden)
Initial hidden state — all zeros. Before reading any word, the network knows nothing.
EXECUTION STATE
⬆ h (initial) = [0.0, 0.0]
c = np.zeros(d_hidden)
Initial cell state — also zeros. Some implementations learn this initial state; most use zeros.
EXECUTION STATE
⬆ c (initial) = [0.0, 0.0]
hidden_history = []
We save every h_t so that we can later pool them (mean, max, or just take the last). For a 'last hidden state' classifier only h_T is strictly needed — keeping all of them makes switching pooling strategies trivial.
EXECUTION STATE
⬆ hidden_history = [] — empty Python list
for t in range(T):
Iterate once per token. Every iteration consumes one row of X and produces one (h_t, c_t) pair.
h, c = lstm_step(X[t], h, c)
Tuple-unpack: call our cell function and rebind h and c to the new values. This single Python statement is the entire recurrence.
EXECUTION STATE
📚 tuple unpacking = Python assigns the two returned ndarrays to h and c in one line — equivalent to result = lstm_step(...); h, c = result[0], result[1].
→ note = Because lstm_step returns fresh ndarrays, there is no aliasing bug. If it returned views, we would need .copy() to avoid mutating the previous state.
hidden_history.append(h.copy())
We keep a snapshot of every h_t. The .copy() matters — without it, if a future version of lstm_step returned a view, every entry in the list could end up pointing at the same final array.
EXECUTION STATE
📚 ndarray.copy() = Returns a deep copy of the array contents — independent of the source.
# Section: take final hidden state and classify
With the full sequence read, we pick the last hidden state and push it through a single linear layer + softmax. This is the simplest possible read-out (often called 'last pooling'), and it is also the one the PyTorch version uses.
h_T = hidden_history[-1]
Python's negative indexing: -1 is the last element, -2 the second to last, etc. Here it's h_{T-1} = h_3, the state after reading 'great'.
EXECUTION STATE
⬆ h_T = ≈ [0.308, 0.168]
→ why last? = Because an LSTM's hidden state is supposed to summarise everything seen so far. In practice this is optimistic — hence mean/max pooling or attention in the next section.
logits = W_cls @ h_T + b_cls
A single linear layer turns the 2-dimensional hidden state into 2 class logits. No activation yet — softmax comes next.
EXECUTION STATE
→ computation = W_cls @ h_T = [-1.0·0.308 + 1.2·0.168, 1.1·0.308 - 0.9·0.168]
≈ [-0.308 + 0.201, 0.339 - 0.151]
≈ [-0.107, 0.188]
Wait, that predicts 'negative'. The weights here are hand-picked, not trained: 'great' pushes both hidden units up, but h[0] more than h[1], and W_cls reads a large h[0] as evidence for 'negative'. Training would reshape W_x, W_h and W_cls so that positive sentences end with an h_T of the [small, large] kind the positive row rewards, flipping the verdict.
probs = np.exp(logits - logits.max())
The 'shift-by-max' trick for numerically stable softmax. Subtracting the maximum before exp() prevents overflow while leaving the softmax output unchanged (softmax is shift-invariant).
📚 ndarray.max() = No args → scalar max over all elements. logits.max() ≈ 0.188 in our case.
⬆ unnormalised probs = ≈ [0.745, 1.000]
probs /= probs.sum()
Divide each entry by the sum so the result is a valid probability distribution (non-negative, sums to 1).
EXECUTION STATE
📚 /= (in-place divide) = Equivalent to probs = probs / probs.sum() but reuses the same memory.
⬆ probs = ≈ [0.427, 0.573]
→ row sum = 0.427 + 0.573 = 1.000 ✓
pred = int(np.argmax(probs))
Pick the index of the largest probability. int(...) converts the NumPy scalar to a plain Python int — useful when we later index a list with it.
EXECUTION STATE
📚 np.argmax = Returns the index of the maximum element. For [0.427, 0.573] it returns 1.
⬆ pred = 1
label = ['positive', 'negative'][pred]
Indexing a Python list with the integer class id turns a number back into a human-readable string.
EXECUTION STATE
⬆ label = 'negative'
→ note = The untrained weights give the 'wrong' answer on purpose — we built the model; training it on real data (chapter 18 §3) is what shapes the weights so 'great' truly dominates.
print("h_T :", h_T)
Show the final summary vector — the single 2-dim state that has to encode everything we read.
print("logits :", logits)
Show the raw scores before softmax. The classifier is making its decision purely from these two numbers.
print("probs :", probs)
Post-softmax probabilities — what you would report in a UI or use to compute cross-entropy during training.
print("predicted:", label)
The final human-readable verdict. This is the one value a user of the classifier would see.
```python
import numpy as np

# --- 1. A tiny vocabulary and its six embeddings (d_embed = 2) ---
vocab = {'<pad>': 0, 'this': 1, 'movie': 2, 'is': 3, 'great': 4, 'awful': 5}
E = np.array([
    [0.00, 0.00],    # <pad>
    [0.10, 0.20],    # this
    [0.40, 0.30],    # movie
    [0.20, 0.10],    # is
    [0.90, 0.70],    # great: pulls activations positive
    [-0.80, -0.60],  # awful: pulls activations negative
])

# --- 2. Input sentence, tokenised + mapped to ids ---
sentence = "this movie is great"
input_ids = np.array([vocab[w] for w in sentence.split()])
X = E[input_ids]  # (T=4, d_embed=2)

# --- 3. LSTM parameters (d_hidden = 2) ---
d_embed, d_hidden = 2, 2
# 4 gates stacked: [input, forget, cell, output] → rows = 4 * d_hidden
W_x = np.array([  # (4*d_hidden, d_embed) = (8, 2)
    [0.5, 0.3], [0.1, -0.2],   # input gate   rows 0-1
    [0.4, 0.6], [0.2, 0.1],    # forget gate  rows 2-3
    [0.7, 0.5], [0.3, 0.4],    # cell cand.   rows 4-5
    [0.2, -0.1], [0.0, 0.3],   # output gate  rows 6-7
])
W_h = np.zeros((4 * d_hidden, d_hidden))  # simplified: zeroed, so h_{t-1} is ignored
b = np.zeros(4 * d_hidden)

# --- 4. Classifier head: W_cls (2, d_hidden) + bias ---
W_cls = np.array([[-1.0, 1.2],   # logit for 'positive'
                  [1.1, -0.9]])  # logit for 'negative'
b_cls = np.array([0.0, 0.0])


def sigmoid(x: np.ndarray) -> np.ndarray:
    return 1.0 / (1.0 + np.exp(-x))


def lstm_step(x_t: np.ndarray, h_prev: np.ndarray, c_prev: np.ndarray):
    """One LSTM cell step: the four-gate recipe (gate slices assume d_hidden = 2)."""
    gates = W_x @ x_t + W_h @ h_prev + b  # (8,)
    i = sigmoid(gates[0:2])   # input gate   (2,)
    f = sigmoid(gates[2:4])   # forget gate  (2,)
    g = np.tanh(gates[4:6])   # cell cand.   (2,)
    o = sigmoid(gates[6:8])   # output gate  (2,)
    c_t = f * c_prev + i * g  # cell state   (2,)
    h_t = o * np.tanh(c_t)    # hidden state (2,)
    return h_t, c_t


# --- 5. Roll the LSTM over the sentence ---
T = X.shape[0]
h = np.zeros(d_hidden)
c = np.zeros(d_hidden)
hidden_history = []
for t in range(T):
    h, c = lstm_step(X[t], h, c)
    hidden_history.append(h.copy())

# --- 6. Take the final hidden state and classify ---
h_T = hidden_history[-1]               # (d_hidden,)
logits = W_cls @ h_T + b_cls           # (2,)
probs = np.exp(logits - logits.max())
probs /= probs.sum()                   # softmax → (2,)
pred = int(np.argmax(probs))
label = ['positive', 'negative'][pred]

print("h_T :", h_T)
print("logits :", logits)
print("probs :", probs)
print("predicted:", label)
```
Quick Check: The network above predicts "negative" for "this movie is great". That is not a bug: the weights are hand-picked rather than trained, so the verdict is arbitrary. Training (chapter 18 §3) is what shapes $W_x$, $W_h$, $W_{\text{cls}}$ so that $h_T$ aligns with the correct class direction.
The Same Model in PyTorch
Now we wrap the same recipe in an nn.Module. Three things to notice as you read the PyTorch version:
The shapes are identical to the NumPy version — we simply trade np.array for torch.Tensor.
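The entire LSTM recurrence collapses into one line: self.lstm(x). On a CUDA GPU PyTorch dispatches it to fused cuDNN kernels, and on CPU to its own compiled implementation; either way the time loop runs outside Python and is typically far faster than the NumPy loop at realistic batch sizes.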
We return logits, not probabilities, because nn.CrossEntropyLoss expects raw logits and performs a numerically-stable log-softmax internally.
Forward pass — PyTorch, production-grade
🐍lstm_classifier_pytorch.py
import torch
The core PyTorch package. Brings in torch.Tensor, autograd, CUDA support, and the numerical backend. Every model, weight, and gradient in this file is a torch.Tensor under the hood.
EXECUTION STATE
torch = Deep-learning framework. Provides tensors (like np.ndarray but with GPU support and automatic differentiation).
import torch.nn as nn
The neural-network module — pre-built layers (Linear, LSTM, Embedding, ...) and the nn.Module base class that every model subclasses.
EXECUTION STATE
nn.Module = Base class. Tracks parameters automatically, supplies .to(device), .train()/.eval(), and .parameters().
nn.LSTM = Compiled implementation (fused cuDNN kernels on a CUDA GPU) of the very loop you just wrote in NumPy.
nn.Embedding = The O(T·d) row-lookup layer from section 1.
nn.Linear = A learnable affine layer: y = x W.T + b.
import torch.nn.functional as F
Functional versions of common operations (softmax, relu, cross_entropy). They have no parameters of their own — you pass everything explicitly.
EXECUTION STATE
F.softmax = Stateless function computing softmax(x)_i = exp(x_i) / Σ_j exp(x_j). Unlike nn.Softmax, there is no module object to hold.
class LSTMClassifier(nn.Module):
Every PyTorch model subclasses nn.Module. By doing so, any nn.* attributes we assign (embedding, lstm, classifier) are automatically registered — their parameters show up in .parameters(), move with .to(device), and are saved by .state_dict().
EXECUTION STATE
→ why subclass? = Gives you free parameter tracking, device placement, save/load, train/eval modes, and hooks.
Creates the learnable vocab_size × d_embed embedding matrix — the PyTorch equivalent of our hand-crafted E.
EXECUTION STATE
📚 nn.Embedding(num_embeddings, embedding_dim, padding_idx) = A thin wrapper around a (num_embeddings, embedding_dim) weight matrix. Forward pass is a row lookup: output[i] = weight[input_ids[i]].
⬇ arg 1: vocab_size=6 = Number of rows in the weight matrix. Must be ≥ max(input_id) + 1.
⬇ arg 2: d_embed=2 = Number of columns.
⬇ arg 3: padding_idx=0 = Keeps row 0 at zeros. Also, gradients flowing into that row are zeroed during backprop, so padding never moves.
A single-layer LSTM. Internally contains W_x, W_h, and b — exactly the weights we wrote by hand — but executes the whole recurrence in a single cuDNN kernel on GPU.
EXECUTION STATE
📚 nn.LSTM(input_size, hidden_size, batch_first=...) = Builds a multi-layer LSTM. Without num_layers= this is a single layer. Parameters follow the four-gate stacking convention from the NumPy version.
⬇ arg 1: input_size=2 (d_embed) = Dimension of each x_t.
⬇ arg 2: hidden_size=2 (d_hidden) = Dimension of h_t and c_t.
⬇ arg 3: batch_first=True = Input/output shape is (batch, time, feature). Without this flag PyTorch expects (time, batch, feature), the original layout. Many modern codebases set batch_first=True.
→ parameter count = W_ih: (4·2, 2)=8·2=16 · W_hh: (4·2, 2)=16 · biases 2·(4·2)=16 → 48 floats total. Scales to 4·(d_embed+d_hidden+2)·d_hidden for real sizes.
The forward pass. Autograd records every operation inside, so loss.backward() will compute gradients end-to-end without you writing a single line of backward code.
EXECUTION STATE
⬇ input_ids = Integer tensor of shape (B, T). B = batch size, T = sequence length. Values must be valid vocabulary ids.
⬆ returns = Float tensor of shape (B, n_classes) — the raw logits.
→ why call forward via model(x)? = model(x) is __call__, which triggers hooks + autograd bookkeeping and then invokes .forward. Always call model(x), never model.forward(x) directly.
x = self.embedding(input_ids)  # (B, T, d_embed)
Row lookup: for each integer id in input_ids, pull the corresponding row of the embedding matrix. Adds a new last dimension d_embed.
The LSTM returns two things: the stacked hidden states for every time step (outputs) and the final (h, c) tuple. Tuple-unpacking lets us grab both in one line.
EXECUTION STATE
📚 self.lstm(x) → (outputs, (h_n, c_n)) = For a single-layer LSTM:
• outputs shape (B, T, d_hidden) — h_t stacked along time.
• h_n shape (num_layers, B, d_hidden) — final h.
• c_n shape (num_layers, B, d_hidden) — final c.
⬇ input: x = Float tensor (1, 4, 2).
⬆ outputs = Float tensor (1, 4, 2) — the full trajectory h_0, h_1, h_2, h_3.
⬆ h_n = Float tensor (1, 1, 2) — identical to outputs[:, -1, :].unsqueeze(0) for this one-layer LSTM.
⬆ c_n = Float tensor (1, 1, 2) — the final cell state.
h_T = outputs[:, -1, :]
Take the last time step across the batch. The triple slice keeps the batch dim, picks time index -1, and keeps all hidden features.
EXECUTION STATE
📚 tensor[:, -1, :] = Standard slicing:
dim 0 (batch) → keep all
dim 1 (time) → take the last (-1)
dim 2 (feat) → keep all
⬆ h_T = Float tensor (B, d_hidden) = (1, 2). Shape has collapsed the time dimension.
→ alternatives = Mean pooling: outputs.mean(dim=1)
Max pooling: outputs.max(dim=1).values
Attention pooling: weighted sum with learned weights.
logits = self.classifier(h_T)  # (B, n_classes)
Run the pooled hidden state through the final linear layer. No activation yet — softmax happens outside the model (e.g., inside nn.CrossEntropyLoss), which is the standard PyTorch pattern.
EXECUTION STATE
⬇ input: h_T = (B, d_hidden) = (1, 2).
⬆ logits = (B, n_classes) = (1, 2). Raw, unnormalised scores. Can be any real number.
→ pro tip = Returning logits (not probs) from .forward is a PyTorch convention: nn.CrossEntropyLoss expects raw logits and applies log-softmax internally for numerical stability.
return logits
Hand the logits back to the caller. The outside world decides whether to softmax them (for display) or feed them straight into the loss.
EXECUTION STATE
⬆ return = Float tensor (1, 2) — raw class scores.
ids = torch.tensor([[1, 2, 3, 4]])  # 'this movie is great'
Build an input batch. The outer brackets make it a 2-D tensor of shape (1, 4) — one sentence of length four. nn.LSTM requires a batch dimension even for a single example.
EXECUTION STATE
📚 torch.tensor = Factory that copies data into a new tensor. dtype is inferred from the Python values — list of ints → int64.
Trigger the forward pass. Under the hood PyTorch builds an autograd graph so that loss.backward() can compute gradients later. For pure inference you would wrap this in torch.no_grad() to skip the graph.
EXECUTION STATE
⬇ input: ids = Int tensor (1, 4).
⬆ logits = Float tensor (1, 2), e.g. tensor([[0.07, -0.02]]) for a freshly-initialised model. Your numbers will differ — weights are random.
probs = F.softmax(logits, dim=-1)
Convert logits into a probability distribution along the class dimension.
EXECUTION STATE
📚 F.softmax(x, dim) = softmax_i(x) = exp(x_i) / Σ_j exp(x_j). Numerically stable implementation — internally shifts by the max like our NumPy code did.
⬇ arg: dim=-1 = Normalise along the last axis (classes). For shape (B, C) this means each row sums to 1.
⬆ probs = Float tensor (1, 2). Each row sums to 1.000. With random weights you might see tensor([[0.52, 0.48]]); after training on real data the 'positive' column dominates for this sentence.
print("logits:", logits)
Print the raw logits. Useful for diagnosing calibration issues and verifying that the shapes are what you expect before worrying about accuracy.
print("probs :", probs)
Print the normalised probabilities. This is what a user-facing UI would show (e.g., a sentiment bar).
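The pro tip about returning logits is easy to check numerically. The sketch below reuses the example logits quoted above and assumes, hypothetically, that class 0 means 'positive'. It confirms that F.cross_entropy applied to raw logits equals the negative log of the softmax probability of the true class, and that a module called inside torch.no_grad() returns a tensor with no autograd graph attached. The nn.Linear is only a stand-in for the trained classifier.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Example logits of shape (1, 2); label 0 is assumed to mean "positive".
logits = torch.tensor([[0.07, -0.02]])
target = torch.tensor([0])

probs = F.softmax(logits, dim=-1)          # each row sums to 1
loss = F.cross_entropy(logits, target)     # applies log-softmax internally

# Cross-entropy on logits is the negative log-probability of the true class.
manual = -torch.log(probs[0, target[0]])
print(probs)                               # approximately tensor([[0.5225, 0.4775]])
print(loss.item(), manual.item())          # the two values agree

# Inference without building an autograd graph.
head = nn.Linear(4, 2)                     # stand-in for any trained module
x = torch.randn(1, 4)
with torch.no_grad():
    out = head(x)
print(out.requires_grad)                   # False
```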
The parameter count of the PyTorch model is slightly larger than the NumPy one because nn.LSTM keeps two separate bias vectors, $b_{ih}$ and $b_{hh}$ (an artefact of the cuDNN convention). Only their sum affects the output, so one of them is mathematically redundant, but both are kept for API compatibility.
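A quick way to see where the extra parameters come from is to list them. This sketch assumes illustrative sizes d_embed = 8 and d_hidden = 16, and assumes the NumPy version keeps a single bias vector per gate; the helper lstm_param_count is hypothetical, written only for this comparison.

```python
import torch.nn as nn

def lstm_param_count(d_embed: int, d_hidden: int, two_biases: bool) -> int:
    # Four gates, each with an input-to-hidden (d_hidden x d_embed)
    # and a hidden-to-hidden (d_hidden x d_hidden) weight matrix.
    weights = 4 * d_hidden * (d_embed + d_hidden)
    # One bias of length 4*d_hidden, or two (b_ih and b_hh) in PyTorch's layout.
    biases = 4 * d_hidden * (2 if two_biases else 1)
    return weights + biases

d_embed, d_hidden = 8, 16  # illustrative sizes
lstm = nn.LSTM(input_size=d_embed, hidden_size=d_hidden, batch_first=True)
for name, p in lstm.named_parameters():
    print(name, tuple(p.shape))
# weight_ih_l0 (64, 8)
# weight_hh_l0 (64, 16)
# bias_ih_l0 (64,)
# bias_hh_l0 (64,)

torch_count = sum(p.numel() for p in lstm.parameters())
print(torch_count)                                             # 1664
print(lstm_param_count(d_embed, d_hidden, two_biases=False))   # 1600
```

The gap is exactly 4·d_hidden: one redundant bias vector spanning the four gates.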
Making it Better: Bidirectional and Stacked LSTMs
The classifier above has a structural weakness: every hidden state $h_t$ depends only on tokens $0, 1, \dots, t$. The phrase "not great" is read in order: the representation built at "not" is fixed before the LSTM ever sees "great", and the LSTM cannot go back to revise it.
The fix is mechanically trivial: run a second LSTM in the reverse direction and concatenate the two hidden state streams. This is a bidirectional LSTM (BiLSTM).
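$\overrightarrow{h}_t = \overrightarrow{\mathrm{LSTM}}(x_{0:t})$, $\overleftarrow{h}_t = \overleftarrow{\mathrm{LSTM}}(x_{t:T})$, and $h_t = [\overrightarrow{h}_t; \overleftarrow{h}_t]$. The classifier head now receives a vector of size $2 d_{\text{hidden}}$.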
In PyTorch the change is a single keyword argument:
🐍bidirectional.py
bidirectional=True
Adds a second LSTM that runs right-to-left. PyTorch concatenates forward and backward hidden states along the last dimension, so outputs now has shape (B, T, 2·d_hidden).
EXECUTION STATE
⬆ outputs shape = (B, T, 2 · d_hidden)
→ parameter cost = Doubles LSTM parameters (two independent cells). The classifier head also doubles its input size.
nn.Linear(2 * d_hidden, n_classes)
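The classifier input is now a concatenation of forward and backward features (for last-pooling: the forward state at the last position and the backward state at the first position, where the right-to-left pass ends), so its size doubles.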
EXECUTION STATE
→ common pattern = For last-pooling: take outputs[:, -1, :d_hidden] (forward final) concatenated with outputs[:, 0, d_hidden:] (backward final).
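To make the indexing concrete, here is a small shape check with illustrative sizes. It assumes unpadded sequences and batch_first=True; with padding, the backward direction would start on pad tokens unless the batch is packed with nn.utils.rnn.pack_padded_sequence. The assertions confirm that the "common pattern" slices coincide with the final states PyTorch returns in h_n.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
B, T, d_embed, d_hidden = 2, 5, 8, 16    # illustrative sizes

bilstm = nn.LSTM(d_embed, d_hidden, batch_first=True, bidirectional=True)
x = torch.randn(B, T, d_embed)
outputs, (h_n, c_n) = bilstm(x)

print(outputs.shape)   # torch.Size([2, 5, 32]) -> (B, T, 2 * d_hidden)
print(h_n.shape)       # torch.Size([2, 2, 16]) -> (num_directions, B, d_hidden)

# The forward pass finishes at the last position, the backward pass at position 0.
fwd_final = outputs[:, -1, :d_hidden]
bwd_final = outputs[:, 0, d_hidden:]
assert torch.allclose(fwd_final, h_n[0])
assert torch.allclose(bwd_final, h_n[1])

# Input to the classifier head: (B, 2 * d_hidden)
features = torch.cat([fwd_final, bwd_final], dim=-1)
head = nn.Linear(2 * d_hidden, 3)        # 3 classes, illustrative
print(head(features).shape)              # torch.Size([2, 3])
```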
Stacking is the other standard upgrade. Passing num_layers=2 to nn.LSTM runs a second LSTM on top of the first: the second layer reads the hidden states of the first as its inputs. Two or three layers is typical; beyond that, extra layers rarely help and often make training harder, because gradients must now travel through depth as well as through time.
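Stacking changes the layout of the returned final states. A minimal sketch with illustrative sizes, combining num_layers=2 with bidirectional=True: outputs exposes only the top layer, while h_n holds one final state per layer and direction, which the nn.LSTM documentation says can be separated into (num_layers, num_directions, batch, hidden_size). The dropout argument applies between layers, so it only has an effect when num_layers > 1.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
B, T, d_embed, d_hidden, num_layers = 2, 5, 8, 16, 2   # illustrative sizes

stacked = nn.LSTM(d_embed, d_hidden, num_layers=num_layers,
                  batch_first=True, bidirectional=True, dropout=0.2)
outputs, (h_n, c_n) = stacked(torch.randn(B, T, d_embed))

print(outputs.shape)  # torch.Size([2, 5, 32]): top layer only
print(h_n.shape)      # torch.Size([4, 2, 16]): num_layers * num_directions

# Separate the layer and direction axes.
h_by_layer = h_n.reshape(num_layers, 2, B, d_hidden)
top_fwd, top_bwd = h_by_layer[-1, 0], h_by_layer[-1, 1]
assert torch.allclose(top_fwd, outputs[:, -1, :d_hidden])
assert torch.allclose(top_bwd, outputs[:, 0, d_hidden:])

# Layer 2 reads layer 1's bidirectional output, hence 2 * d_hidden input columns.
print(stacked.weight_ih_l1.shape)  # torch.Size([64, 32])
```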
A bidirectional stacked LSTM with max pooling was the workhorse text classifier from roughly 2015 to 2019. It is still a very strong baseline — if you are benchmarking a new model, implement this first.
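A compact version of that baseline might look like the following. The class name BiLSTMMaxPoolClassifier, the padding id 0 and all sizes are hypothetical choices for illustration, not a reference implementation. Padding positions are filled with -inf before the max so they can never win. The sketch does not pack sequences, so the backward direction still reads the pad tokens, and every sequence is assumed to contain at least one real token.

```python
import torch
import torch.nn as nn

class BiLSTMMaxPoolClassifier(nn.Module):
    """Hypothetical baseline: embedding -> stacked BiLSTM -> max over time -> linear."""

    def __init__(self, vocab_size, d_embed, d_hidden, n_classes,
                 num_layers=2, dropout=0.2, pad_id=0):
        super().__init__()
        self.pad_id = pad_id
        self.embed = nn.Embedding(vocab_size, d_embed, padding_idx=pad_id)
        self.lstm = nn.LSTM(d_embed, d_hidden, num_layers=num_layers,
                            batch_first=True, bidirectional=True, dropout=dropout)
        self.head = nn.Linear(2 * d_hidden, n_classes)

    def forward(self, ids):                          # ids: (B, T) int64
        x = self.embed(ids)                          # (B, T, d_embed)
        outputs, _ = self.lstm(x)                    # (B, T, 2 * d_hidden)
        # Keep padding out of the max by setting those positions to -inf.
        pad_mask = (ids == self.pad_id).unsqueeze(-1)  # (B, T, 1)
        outputs = outputs.masked_fill(pad_mask, float("-inf"))
        pooled = outputs.max(dim=1).values           # (B, 2 * d_hidden)
        return self.head(pooled)                     # raw logits (B, n_classes)

torch.manual_seed(0)
model = BiLSTMMaxPoolClassifier(vocab_size=100, d_embed=16, d_hidden=32, n_classes=2)
ids = torch.tensor([[1, 2, 3, 4], [5, 6, 0, 0]])    # second sentence padded with id 0
logits = model(ids)
print(logits.shape)  # torch.Size([2, 2])
```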
The Fundamental Bottleneck
Take a second look at the interactive animation. At every time step the LSTM rewrites one fixed-size vector $h_t \in \mathbb{R}^{d_{\text{hidden}}}$. Regardless of sentence length (4 tokens or 4,000), the entire memory of what was read must fit into those few hundred floats.
The information bottleneck: an LSTM classifier must squeeze every relevant detail of a variable-length sequence into a single fixed-size vector, the final hidden state $h_{T-1}$. Its capacity is constant; its input length is unbounded. Something has to give.
Empirically, three things go wrong once sequences get long:
Forgetting. Information from tokens read 200 steps ago has usually been largely overwritten. Even with the additive cell highway, gradients along length-200 paths tend to be numerically tiny during training.
Position blur. "The cat sat on the mat" and "The mat sat on the cat" produce different final states $h_{T-1}$, but not necessarily as different as their meanings are: the LSTM has no explicit notion of position, only of order.
Sequential latency. Step t cannot begin until step t−1 has finished. GPUs hate this. For a transformer, all timesteps go through the matmul in parallel.
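The forgetting problem can be put in numbers. In the standard LSTM cell update $c_t = f_t \odot c_{t-1} + i_t \odot g_t$ (writing the candidate as $g_t$), the share of $c_{t-k}$ that survives to $c_t$ along the cell path is the product of the $k$ intervening forget gates. Assuming, purely for illustration, that every forget gate holds one constant value, the arithmetic below shows how quickly that product vanishes over 200 steps: roughly 7e-10 for f = 0.9, 3.5e-5 for f = 0.95 and 0.13 for f = 0.99. The helper name retained_fraction is hypothetical.

```python
def retained_fraction(forget_gate: float, steps: int) -> float:
    """Share of c_{t-steps} left in c_t if every forget gate equals `forget_gate`.
    Ignores new writes through the input gate and indirect paths through h."""
    return forget_gate ** steps

for f in (0.9, 0.95, 0.99):
    print(f"f = {f:.2f}: after 200 steps -> {retained_fraction(f, 200):.1e}")
```

Real forget gates vary per unit and per step, so this is a bound on intuition rather than a measurement: even a gate that is open 99% of the time loses most of the signal over a few hundred steps.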
The original motivation for attention was exactly this bottleneck: Bahdanau et al.'s 2014 paper literally let the decoder "peek" back at the whole encoder trajectory instead of relying on one final state. Three years later, the transformer paper (Vaswani et al., 2017) used that same idea to replace the recurrence entirely.
From LSTM Classifier to Transformer Classifier
Conceptually the transformer classifier is the same shape: embed, encode, pool, classify. The only stage that changes is the encoder.
| Stage | LSTM classifier | Transformer classifier |
|---|---|---|
| Embedding | E[input_ids] | E[input_ids] + positional encoding |
| Encoder | LSTM: O(T) sequential steps, O(d_hidden) state | Stack of self-attention + MLP blocks: O(T²·d) compute per block, O(T·d) state |
| Pooling | Last, mean, or max over hidden states | [CLS] token's final hidden state (BERT-style) or mean over tokens |
| Head | Linear + softmax | Linear + softmax (unchanged) |
Self-attention is what replaces the recurrence. Instead of threading a single hidden vector through time, every token attends to every other token via queries, keys, and values: $\mathrm{Attn}(Q, K, V) = \mathrm{softmax}\left(QK^\top / \sqrt{d_k}\right) V$. Crucially, all T positions are processed in parallel; there is no sequential loop. That is the superpower that made transformers feasible on modern GPUs.
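A minimal NumPy sketch of the formula above, for a single head and no masking. X stands in for the embedded tokens and the projection matrices are random; both are assumptions for illustration. The whole computation is a chain of matrix products over all T positions at once, with no loop over time.

```python
import numpy as np

def softmax(x, axis=-1):
    # Subtract the row max for numerical stability.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def attention(Q, K, V):
    """Scaled dot-product attention. Q: (T_q, d_k), K: (T_k, d_k), V: (T_k, d_v)."""
    d_k = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(d_k)        # (T_q, T_k): every query against every key
    weights = softmax(scores, axis=-1)     # each row sums to 1
    return weights @ V, weights            # (T_q, d_v)

rng = np.random.default_rng(0)
T, d_model = 4, 8
X = rng.normal(size=(T, d_model))          # stand-in for embedded tokens
W_q, W_k, W_v = (rng.normal(size=(d_model, d_model)) / np.sqrt(d_model) for _ in range(3))

out, weights = attention(X @ W_q, X @ W_k, X @ W_v)
print(out.shape)                # (4, 8): one context vector per token
print(weights.sum(axis=-1))     # each row sums to 1
```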
Multi-head attention runs h such attention functions in parallel on different linear projections of Q, K, V. Each head can specialise; for example, one head may track syntactic parents, another coreference, another relative position. The concatenated heads are projected back to $d_{\text{model}}$. In an LSTM the hidden state mixes all of those roles into one vector; in multi-head attention each head works in its own lower-dimensional subspace until the output projection mixes them.
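The head splitting and concatenation can be sketched in a few lines of NumPy. This assumes d_model is divisible by the number of heads and uses random projection matrices; it is a shape-level illustration, not an optimised implementation. Each head works on its own d_model / n_heads slice of the projections, and W_o mixes the concatenated slices back together.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def multi_head_attention(X, W_q, W_k, W_v, W_o, n_heads):
    """X: (T, d_model); all weight matrices are (d_model, d_model)."""
    T, d_model = X.shape
    d_head = d_model // n_heads

    def split(M):
        # (T, d_model) -> (n_heads, T, d_head): each head gets its own slice.
        return M.reshape(T, n_heads, d_head).transpose(1, 0, 2)

    Q, K, V = split(X @ W_q), split(X @ W_k), split(X @ W_v)
    scores = Q @ K.transpose(0, 2, 1) / np.sqrt(d_head)   # (n_heads, T, T)
    heads = softmax(scores) @ V                           # (n_heads, T, d_head)
    # Concatenate heads back to (T, d_model), then mix them with W_o.
    concat = heads.transpose(1, 0, 2).reshape(T, d_model)
    return concat @ W_o

rng = np.random.default_rng(0)
T, d_model, n_heads = 5, 16, 4
X = rng.normal(size=(T, d_model))
W_q, W_k, W_v, W_o = (rng.normal(size=(d_model, d_model)) / np.sqrt(d_model) for _ in range(4))
Y = multi_head_attention(X, W_q, W_k, W_v, W_o, n_heads)
print(Y.shape)   # (5, 16)
```

With n_heads=1 the function reduces to ordinary single-head attention followed by W_o.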
A useful mental picture: the LSTM hidden state is one summary; a transformer layer produces T summaries (one per token), each conditioned on the whole sequence. Pooling at the end is then a cheap aggregation rather than a compression.
Flash Attention and the Memory Story
Self-attention fixes the LSTM's information bottleneck, but introduces a new one: memory. The naive formulation materialises the T×T score matrix $S = QK^\top / \sqrt{d_k}$. For a sequence of 8,192 tokens that is 67 million floats, or 256 MiB in fp32 per head, per layer, per sequence, just for intermediates. On a 32-head, 80-layer model you run out of GPU HBM long before you run out of compute.
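The numbers in the paragraph above are quick to reproduce. This assumes 4-byte fp32 scores and, for the total, the case where every head and layer keeps its score matrix at the same time, as naive training does when the matrices are saved for the backward pass. The helper name attention_matrix_bytes is hypothetical.

```python
def attention_matrix_bytes(seq_len: int, bytes_per_elem: int = 4) -> int:
    """Bytes for one T x T score matrix (one head, one layer, one sequence)."""
    return seq_len * seq_len * bytes_per_elem

T = 8192
per_head = attention_matrix_bytes(T)       # fp32
print(T * T)                    # 67108864 elements
print(per_head / 2**20)         # 256.0 (MiB)

# 32 heads x 80 layers, all held at once
total = per_head * 32 * 80
print(total / 2**30)            # 640.0 (GiB) for a single sequence
```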
Flash Attention (Dao et al., 2022) is the fix that made long-context transformers practical. The key insight: we never actually need S in full — we only need its softmax output times V. So we tile Q, K, V into blocks that fit in GPU SRAM, compute the softmax incrementally (online softmax), and accumulate the output without ever writing the full T×T matrix to HBM.
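The online-softmax trick at the heart of Flash Attention can be demonstrated on the CPU. The sketch below is a simplification: it tiles only the key/value axis, keeps all queries in one block, and runs in NumPy with no SRAM management or kernel fusion. It keeps a running row maximum m and denominator l, rescales earlier partial sums whenever the maximum grows, and never holds more than a (T, block) slice of scores, yet it reproduces naive attention up to rounding.

```python
import numpy as np

def naive_attention(Q, K, V):
    S = Q @ K.T / np.sqrt(Q.shape[1])              # full (T, T) score matrix
    P = np.exp(S - S.max(axis=1, keepdims=True))
    return (P / P.sum(axis=1, keepdims=True)) @ V

def tiled_attention(Q, K, V, block=16):
    """Online-softmax attention that only ever holds a (T, block) tile of scores."""
    T, d = Q.shape
    out = np.zeros((T, V.shape[1]))
    m = np.full(T, -np.inf)      # running row maximum
    l = np.zeros(T)              # running softmax denominator
    for start in range(0, K.shape[0], block):
        Kb, Vb = K[start:start + block], V[start:start + block]
        S = Q @ Kb.T / np.sqrt(d)                  # scores for this tile only
        m_new = np.maximum(m, S.max(axis=1))
        scale_old = np.exp(m - m_new)              # rescale earlier sums (0 on the first tile)
        P = np.exp(S - m_new[:, None])
        l = l * scale_old + P.sum(axis=1)
        out = out * scale_old[:, None] + P @ Vb
        m = m_new
    return out / l[:, None]                        # normalise once at the end

rng = np.random.default_rng(0)
T, d = 64, 8
Q, K, V = (rng.normal(size=(T, d)) for _ in range(3))
print(np.allclose(naive_attention(Q, K, V), tiled_attention(Q, K, V)))  # True
```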
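| Quantity | Naive attention | Flash Attention |
|---|---|---|
| HBM reads/writes | O(T² + T·d) | O(T²·d²/M), where M is the on-chip SRAM size; many times fewer for typical d and M |
| FLOPs | O(T²·d) | O(T²·d): same order (the backward pass recomputes scores instead of storing them) |
| Peak memory | O(T²) | O(T·d): linear in sequence length |
| Reported wall-clock speedup | 1× (baseline) | FlashAttention v1: up to ~3× on 1k–4k sequences (Dao et al. 2022, arXiv:2205.14135). FlashAttention-2 roughly doubles v1's throughput at large T (Dao 2023, arXiv:2307.08691). |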
Flash Attention does not change the math: outputs match the naive implementation up to floating-point rounding (the accumulation order differs, so results are generally not bit-identical). It is purely an IO-aware reorganisation of the memory accesses. The lesson: at scale, what kills you is often not FLOPs but HBM bandwidth.
An LSTM classifier has no analogue of this problem because its state size is $O(d_{\text{hidden}})$, independent of sequence length. That same fact is why LSTMs cannot reach back: constant state is both the feature and the bug.
Positional Encodings and the KV-Cache
A transformer layer is permutation-equivariant — swap any two input tokens and the output swaps along with them. Self-attention sees a bag of vectors, not a sequence. That is a deal-breaker for language, where "dog bites man" is very different from "man bites dog".
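The permutation-equivariance claim can be checked directly. In this NumPy sketch (random weights, no positional encoding, no causal mask, all assumptions for illustration), permuting the input rows simply permutes the output rows, so attention on its own has no way to tell the two orderings apart.

```python
import numpy as np

def self_attention(X, W_q, W_k, W_v):
    Q, K, V = X @ W_q, X @ W_k, X @ W_v
    S = Q @ K.T / np.sqrt(Q.shape[1])
    P = np.exp(S - S.max(axis=1, keepdims=True))
    return (P / P.sum(axis=1, keepdims=True)) @ V

rng = np.random.default_rng(0)
T, d = 3, 4
X = rng.normal(size=(T, d))                 # stand-in embeddings for "dog bites man"
W_q, W_k, W_v = (rng.normal(size=(d, d)) for _ in range(3))

perm = np.array([2, 1, 0])                  # reorder to "man bites dog"
out = self_attention(X, W_q, W_k, W_v)
out_perm = self_attention(X[perm], W_q, W_k, W_v)

# Same rows, just reordered.
print(np.allclose(out_perm, out[perm]))     # True
```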
Positional encodings are the patch. The original transformer added a fixed sinusoidal vector $\mathrm{PE}(t) \in \mathbb{R}^{d_{\text{model}}}$ to every token embedding. Modern variants, including RoPE (rotary), ALiBi (additive linear biases), learned absolute embeddings and T5-relative biases, all share one goal: inject order information that an LSTM got for free from its recurrence.
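For reference, the original sinusoidal scheme sets $\mathrm{PE}(t, 2i) = \sin\left(t / 10000^{2i/d_{\text{model}}}\right)$ and $\mathrm{PE}(t, 2i+1) = \cos\left(t / 10000^{2i/d_{\text{model}}}\right)$. A short NumPy sketch, assuming an even d_model; the random embeddings are a stand-in for a real embedding table, and the function name sinusoidal_positions is hypothetical.

```python
import numpy as np

def sinusoidal_positions(T: int, d_model: int) -> np.ndarray:
    """PE[t, 2i] = sin(t / 10000**(2i/d_model)), PE[t, 2i+1] = cos(same angle)."""
    t = np.arange(T)[:, None]                      # (T, 1)
    two_i = np.arange(0, d_model, 2)[None, :]      # (1, d_model/2): the even indices 2i
    angles = t / np.power(10000.0, two_i / d_model)
    pe = np.zeros((T, d_model))
    pe[:, 0::2] = np.sin(angles)
    pe[:, 1::2] = np.cos(angles)
    return pe

T, d_model = 6, 8
pe = sinusoidal_positions(T, d_model)
embeddings = np.random.default_rng(0).normal(size=(T, d_model))  # stand-in embeddings
x = embeddings + pe            # what the first transformer layer would receive
print(pe[0])                   # [0. 1. 0. 1. 0. 1. 0. 1.]
```

Because the encoding is a fixed function of t, it adds no learned parameters; each pair of dimensions rotates at its own frequency.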
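Trade-off: an LSTM gets order structure from its sequential computation, free in parameters but expensive in wall-clock time. A transformer gets parallelism but must inject position explicitly. Learned absolute embeddings cost parameters and cap the usable length; sinusoidal encodings, RoPE and ALiBi add no learned parameters. ALiBi was designed to extrapolate to sequences longer than those seen in training, while RoPE usually needs extra techniques to do so.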
KV-cache — reclaiming incremental generation
There is one thing LSTMs do effortlessly that transformers struggle with: autoregressive generation. An LSTM, at decoding time, only needs to keep the last $(h_t, c_t)$ pair in memory. A transformer, naively, re-computes attention over the entire prefix at every new token: an $O(T^2)$ blowup per generated token.
The KV-cache is the standard fix. Observe that once a token's key and value vectors are computed, they never change: under causal masking they depend only on that token and the tokens before it, never on tokens generated later. So we cache them. At generation step t, we compute only the new token's $q_t, k_t, v_t$, append $k_t, v_t$ to the cache, and attend $q_t$ over the full cache.
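The claim that caching changes the cost but not the result can be verified on a single attention layer with one head. In this NumPy sketch the rows of X stand in for that layer's inputs, which under causal masking do not change once computed; the weights are random, and the Python lists are a simplification of the preallocated buffers real systems use.

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max())
    return e / e.sum()

rng = np.random.default_rng(0)
T, d = 6, 8
X = rng.normal(size=(T, d))                     # one row per generated token
W_q, W_k, W_v = (rng.normal(size=(d, d)) / np.sqrt(d) for _ in range(3))

# Full causal attention, recomputed from scratch: token t attends to tokens 0..t.
Q, K, V = X @ W_q, X @ W_k, X @ W_v
full = np.stack([softmax(Q[t] @ K[:t + 1].T / np.sqrt(d)) @ V[:t + 1] for t in range(T)])

# Incremental decoding: only the new token's q, k, v are computed at each step.
k_cache, v_cache, incremental = [], [], []
for t in range(T):
    q_t, k_t, v_t = X[t] @ W_q, X[t] @ W_k, X[t] @ W_v
    k_cache.append(k_t)
    v_cache.append(v_t)
    K_c, V_c = np.stack(k_cache), np.stack(v_cache)   # (t + 1, d)
    incremental.append(softmax(q_t @ K_c.T / np.sqrt(d)) @ V_c)

print(np.allclose(full, np.stack(incremental)))  # True
```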
| Per-step cost | Without cache | With KV-cache |
|---|---|---|
| Compute | O(T²·d) | O(T·d) |
| Memory for cache | 0 | O(T · L · d) (L = # layers) |
| Bottleneck | Re-computing keys/values for every prefix token | HBM bandwidth reading the cache, now the dominant cost |
For a 70-billion-parameter model at 8k context, the KV-cache can itself reach 40 GB, depending on batch size and on how many key/value heads the model keeps. That is why production systems use techniques like multi-query attention (one K/V head shared across all Q heads), grouped-query attention (each K/V head shared by a group of Q heads), and paged attention (a vLLM innovation) to keep the cache tractable.
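A back-of-the-envelope sizing function makes these trade-offs visible. The configuration below (80 layers, 64 heads of dimension 128, fp16 cache, 8,192 tokens) is an assumed example of a large model, not the specification of any particular one; the function name kv_cache_bytes is hypothetical. The factor of 2 counts keys and values.

```python
def kv_cache_bytes(n_layers, n_kv_heads, head_dim, seq_len, batch=1, bytes_per_elem=2):
    """Keys and values for every layer, K/V head, position and sequence."""
    return 2 * n_layers * n_kv_heads * head_dim * seq_len * batch * bytes_per_elem

GiB = 2**30
# Assumed configuration: 80 layers, 64 heads x 128 dims, fp16, 8k context, batch 1
mha = kv_cache_bytes(n_layers=80, n_kv_heads=64, head_dim=128, seq_len=8192)
gqa = kv_cache_bytes(n_layers=80, n_kv_heads=8, head_dim=128, seq_len=8192)   # grouped-query
mqa = kv_cache_bytes(n_layers=80, n_kv_heads=1, head_dim=128, seq_len=8192)   # multi-query

print(mha / GiB, gqa / GiB, mqa / GiB)   # 20.0 2.5 0.3125
```

At batch size 2 the full multi-head layout already reaches 40 GiB, while grouped-query and multi-query attention shrink the cache in direct proportion to the number of K/V heads kept.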
The LSTM classifier never needs any of this. Its "cache" is two vectors, period. The whole KV-cache engineering stack exists to give transformers the same incremental-generation property LSTMs had trivially — without giving up parallelism during training.
The Scaling Story: Why We Left LSTMs Behind
If LSTMs are so elegant, why is almost every state-of-the-art language system a transformer? The answer is one word: parallelism.
| Axis | LSTM | Transformer |
|---|---|---|
| Training-time parallelism | O(T) sequential steps; cannot be parallelised across time | O(1) sequential steps; all T tokens go through the same matmuls |
| Compute per token | O(d_hidden²) | O(T · d_model + d_model²) |
| Memory per token (training) | O(d_hidden) | O(T · d_model) (for attention) |
| Long-range dependencies | Decay with path length | Constant-depth reach to any token |
| Hardware utilisation (rough, workload-dependent) | ~15–30% on A100 | ~50–70% with Flash Attention |
On a modern accelerator, wall-clock time is dominated by how well the computation matches the hardware's parallelism. A transformer block is dominated by a handful of large matmuls (the attention projections and the MLP); an A100 or H100 is designed to run exactly those matmuls at peak throughput. An LSTM layer is a tight sequential loop: the GPU's thousands of cores sit mostly idle waiting on a data dependency.
This is the deep reason the field migrated. Not that LSTMs cannot learn the task — for sentiment classification they still do, and well — but that once you want to scale to billions of parameters and hundreds of billions of tokens, an architecture that wastes 70% of your hardware is not tenable. Transformers were the architecture that hugged the hardware.
But LSTMs are not dead. On-device keyboard suggestion, streaming speech recognition, small-footprint sensor models — anywhere latency and energy matter more than peak quality — still use LSTMs or GRUs. And in 2023-2024 the SSM/Mamba family revived constant-state recurrence with transformer-level quality on several benchmarks. The story is not over.
Summary
In this section we built the canonical LSTM text classifier end-to-end:
Derived the four-stage architecture (embedding → LSTM → pool → linear + softmax) and wrote it as a single differentiable function.
Implemented the entire forward pass in pure NumPy, tracing every gate, cell update and hidden state to numerical values so nothing remains mysterious.
Wrapped the same recipe in a PyTorch nn.Module, contrasted its parameter layout, and explained every argument of nn.Embedding, nn.LSTM and nn.Linear.
Compared four pooling strategies and upgraded to the standard bidirectional, stacked workhorse.
Exposed the fixed-state bottleneck that drove the field toward attention, and traced the line from it to multi-head attention, positional encodings, Flash Attention and the KV-cache.
In section 3 we close the loop: feeding this classifier a labelled dataset, training with cross-entropy and Adam, monitoring F1, and watching the hidden-state trajectories in the visualization above actually learn to separate positive from negative.
References
Hochreiter, S. & Schmidhuber, J. (1997). Long Short-Term Memory. Neural Computation 9(8), 1735–1780.
Bahdanau, D., Cho, K. & Bengio, Y. (2014). Neural Machine Translation by Jointly Learning to Align and Translate. ICLR 2015 / arXiv:1409.0473.
Vaswani, A. et al. (2017). Attention Is All You Need. NeurIPS 2017 / arXiv:1706.03762.
Dao, T., Fu, D. Y., Ermon, S., Rudra, A. & Ré, C. (2022). FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness. NeurIPS 2022 / arXiv:2205.14135.
Dao, T. (2023). FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning. arXiv:2307.08691.
Ainslie, J. et al. (2023). GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints. EMNLP 2023 / arXiv:2305.13245.