Chapter 18

Building an LSTM Classifier

Text Classification Project

The Goal: From a Sentence to a Verdict

In section 1 we turned raw text into a sequence of embedding vectors. Beautiful numbers, but still useless on their own — a sequence of vectors is not a verdict. The job of a text classifier is to read a whole sentence and emit a single symbol: "positive", "spam", "urgent", "French". The question of this section is: how do we compress a variable-length sequence of embeddings into one fixed-size prediction?

The recipe we are going to build is the canonical one that dominated NLP from about 2014 until the transformer era:

  1. Each token is looked up in an embedding table and becomes a dense vector $x_t \in \mathbb{R}^{d_\text{embed}}$.
  2. An LSTM reads the sequence left to right, updating its hidden state $h_t$ and cell state $c_t$ at every step.
  3. A pooling step collapses the sequence of hidden states $(h_0, \ldots, h_{T-1})$ into a single summary vector.
  4. A linear layer + softmax turns that summary into class probabilities.

Why This Matters: For years, strong results on the classical NLP benchmarks (sentiment, topic, spam, language identification) came from exactly this architecture. Transformers have since replaced it at the top of the leaderboards, but it is still widely deployed because it is cheap, fast, and good enough for many tasks.

Anatomy of an LSTM Classifier

Conceptually, the classifier is a four-stage pipeline. Data flows left-to-right, dimensions change at each stage, and nothing flows back to earlier stages during inference.

| Stage | What it does | Input → Output shape |
|---|---|---|
| Embedding | Look up the dense vector of each token id | (B, T) → (B, T, d_embed) |
| LSTM encoder | Carry information forward through time via gated recurrence | (B, T, d_embed) → (B, T, d_hidden) |
| Pooling | Collapse the time axis into one summary vector | (B, T, d_hidden) → (B, d_hidden) |
| Classifier head | Linear projection to class logits | (B, d_hidden) → (B, n_classes) |

B is the batch size, T is the sequence length, d_embed is the embedding dimension, d_hidden is the LSTM hidden size, and n_classes is the number of labels. B, d_embed, and d_hidden are hyperparameters you will tune; T is set by the data (or by a truncation length you choose), and n_classes by the task.
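
To make the shape bookkeeping concrete, here is a minimal NumPy sketch that pushes a random batch through the four stages and prints the shape after each one. The sizes, the random weights, and the helper name `lstm_encode` are illustrative assumptions rather than values from this chapter, and last-state pooling is used only for brevity.

```python
import numpy as np

rng = np.random.default_rng(0)

# Illustrative sizes (assumptions, kept small so the sketch runs instantly).
B, T, vocab_size, d_embed, d_hidden, n_classes = 3, 5, 10, 4, 6, 2

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def lstm_encode(X, W_x, W_h, b):
    """Batched LSTM: (B, T, d_embed) -> (B, T, d_hidden)."""
    batch, steps, _ = X.shape
    hidden = W_h.shape[1]
    h = np.zeros((batch, hidden))
    c = np.zeros((batch, hidden))
    outputs = []
    for t in range(steps):
        gates = X[:, t] @ W_x.T + h @ W_h.T + b           # (B, 4*d_hidden)
        a_i, a_f, a_g, a_o = np.split(gates, 4, axis=1)   # four (B, d_hidden) blocks
        c = sigmoid(a_f) * c + sigmoid(a_i) * np.tanh(a_g)
        h = sigmoid(a_o) * np.tanh(c)
        outputs.append(h)
    return np.stack(outputs, axis=1)

E = rng.normal(size=(vocab_size, d_embed))
W_x = rng.normal(size=(4 * d_hidden, d_embed))
W_h = rng.normal(size=(4 * d_hidden, d_hidden))
b = np.zeros(4 * d_hidden)
W_cls = rng.normal(size=(n_classes, d_hidden))
b_cls = np.zeros(n_classes)

ids = rng.integers(0, vocab_size, size=(B, T))   # (B, T)
X = E[ids]                                       # (B, T, d_embed)
H = lstm_encode(X, W_x, W_h, b)                  # (B, T, d_hidden)
pooled = H[:, -1]                                # (B, d_hidden), last-state pooling
logits = pooled @ W_cls.T + b_cls                # (B, n_classes)

for name, arr in [("ids", ids), ("X", X), ("H", H), ("pooled", pooled), ("logits", logits)]:
    print(f"{name:7s} {arr.shape}")
```

The printed shapes should match the right-hand column of the stage table row by row.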

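The only stage with time dependencies is the LSTM; the embedding lookup and the classifier head are time-independent matrix algebra, and pooling is a single reduction over the time axis. This is one reason a classifier is cheaper to train than a language model built on the same encoder: there is one loss per sequence rather than a vocabulary-sized softmax and loss at every time step.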

[Interactive figure: shape-flow diagram listing the learnable parameter matrices at each stage. Its parameter counts use the production defaults we will train in section 3: $B=64,\ T=200,\ d_\text{embed}=128,\ d_\text{hidden}=256$.]
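
As a rough check on where the parameters live, the sketch below counts them for the single-bias formulation used in this section. `d_embed` and `d_hidden` are the defaults quoted above; `vocab_size=20_000` and `n_classes=2` are hypothetical values chosen only for illustration, and the function name is likewise an assumption.

```python
def lstm_classifier_param_count(vocab_size, d_embed, d_hidden, n_classes):
    """Parameter count for theta = (E, W_x, W_h, b, W_cls, b_cls)."""
    embedding = vocab_size * d_embed                             # E
    lstm = 4 * d_hidden * (d_embed + d_hidden) + 4 * d_hidden    # W_x, W_h, b
    head = n_classes * d_hidden + n_classes                      # W_cls, b_cls
    return {"embedding": embedding, "lstm": lstm, "head": head,
            "total": embedding + lstm + head}

# vocab_size and n_classes are hypothetical; B and T do not affect the count.
counts = lstm_classifier_param_count(vocab_size=20_000, d_embed=128,
                                     d_hidden=256, n_classes=2)
for name, n in counts.items():
    print(f"{name:10s} {n:>10,}")
```

Under these assumptions the embedding table dominates the total. PyTorch's `nn.LSTM` stores two bias vectors per layer (`bias_ih` and `bias_hh`), so it reports `4 * d_hidden` more LSTM parameters than this count.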

The Mathematical Recipe, End to End

Let us write the whole classifier as a composition of functions. Given a sentence of token ids $(\text{id}_0, \ldots, \text{id}_{T-1})$:

1. Embedding lookup

$x_t = E[\text{id}_t] \in \mathbb{R}^{d_\text{embed}}$. We met this operation in section 1: one row lookup per token, so the whole sentence costs $O(T \cdot d_\text{embed})$ time.
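
A quick way to convince yourself that the row lookup is the same operation as the one-hot matrix product, only cheaper, is to compute both. The sizes and random values below are assumptions for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
V, d_embed = 6, 2                       # toy vocabulary and embedding sizes
E = rng.normal(size=(V, d_embed))
ids = np.array([1, 2, 3, 4])

one_hot = np.eye(V)[ids]                # (T, V), one 1 per row
via_matmul = one_hot @ E                # O(T * V * d_embed) multiply-adds
via_lookup = E[ids]                     # O(T * d_embed) copies

print(np.allclose(via_matmul, via_lookup))
```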

2. LSTM recurrence

This is the gated recurrence introduced by Hochreiter & Schmidhuber (1997) to mitigate the vanishing-gradient problem of vanilla RNNs. For $t = 0, 1, \ldots, T-1$, starting from $h_{-1} = c_{-1} = 0$:

Gate pre-activations (one matmul computes all four gates): $\begin{bmatrix} a_i \\ a_f \\ a_g \\ a_o \end{bmatrix} = W_x \, x_t + W_h \, h_{t-1} + b$

Gate activations: $i_t = \sigma(a_i),\; f_t = \sigma(a_f),\; g_t = \tanh(a_g),\; o_t = \sigma(a_o)$. The three sigmoids produce soft switches in $(0, 1)$; tanh produces the candidate content in $(-1, 1)$.

State updates: $c_t = f_t \odot c_{t-1} + i_t \odot g_t$ (cell) and $h_t = o_t \odot \tanh(c_t)$ (hidden), where $\odot$ is element-wise multiplication. The cell update is the one additive path in the LSTM: $c_{t-1}$ reaches $c_t$ scaled only by $f_t$, with no weight matrix or squashing nonlinearity in between, which is why gradients along it decay far more slowly than in a vanilla RNN.

3. Pooling

Choose one of three canonical recipes. The simplest is $\bar{h} = h_{T-1}$: just take the final hidden state. We will compare the alternatives in the next section.

4. Classifier head

$z = W_\text{cls} \, \bar{h} + b_\text{cls} \in \mathbb{R}^{n_\text{classes}}$, then $p = \text{softmax}(z)$. The predicted label is $\hat{y} = \arg\max_k p_k$.

5. Loss

Training uses cross-entropy against the true label $y$: $\mathcal{L} = -\log p_y = -\log \frac{e^{z_y}}{\sum_k e^{z_k}}$. Section 3 of this chapter will walk through the training loop; here we focus exclusively on the forward pass.
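
The loss formula can be evaluated directly from the logits without ever forming $p$. The sketch below does so with a log-sum-exp shift so that large logits do not overflow; the logit values and the helper names are hypothetical.

```python
import numpy as np

def log_softmax(z):
    z = z - z.max()                        # shift so the largest entry is 0
    return z - np.log(np.exp(z).sum())

def cross_entropy(z, y):
    """L = -log p_y, computed from raw logits z and the true class index y."""
    return -log_softmax(z)[y]

z = np.array([2.0, 1.0, 0.1])   # hypothetical logits for three classes
y = 0                           # hypothetical true label

p = np.exp(log_softmax(z))
loss = float(cross_entropy(z, y))
big_loss = float(cross_entropy(np.array([1000.0, 1001.0]), 1))

print("p    =", p.round(3))
print("loss =", round(loss, 3))
print("loss with huge logits =", round(big_loss, 3))
```

The second call is the case that would produce `inf / inf` if the softmax were computed naively and then passed to a logarithm.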

The whole model is a single differentiable function $f_\theta$ with parameters $\theta = (E, W_x, W_h, b, W_\text{cls}, b_\text{cls})$. PyTorch's autograd will compute $\nabla_\theta \mathcal{L}$ for free once we call .backward(), but only because every step above is differentiable.

Before we roll the four gates over a real sentence, anchor each one in isolation. The three sigmoids produce soft switches in $(0, 1)$, while tanh produces candidate content in $(-1, 1)$. The cell update $c_t = f_t \odot c_{t-1} + i_t \odot g_t$ multiplies these outputs together elementwise.

[Interactive figure: gate activations visualizer]

Now zoom out to the cell-state highway. Each of the $d_\text{hidden}$ channels has its own lane; memory from the previous step persists through the forget gate, and new content rides in via the input gate. With $f_t$ close to 1 the original memory survives every step; with $f_t$ near 0 the lane clears. This additive path is why LSTM gradients vanish far more slowly than those of a vanilla RNN.

[Interactive figure: cell-state highway]
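
To put numbers on the highway, the sketch below computes the direct-path factor of $\partial c_T / \partial c_0$ for one channel, which is simply the product of that channel's forget gates. Holding the gate constant over time and ignoring the indirect paths through $h_t$ are simplifying assumptions; `highway_gradient` is a hypothetical helper name.

```python
import numpy as np

def highway_gradient(forget_gates):
    """Direct-path d c_T / d c_0 for one channel: the product of its forget gates."""
    return float(np.prod(forget_gates))

T = 100
results = {}
for f in (0.99, 0.9, 0.5):
    results[f] = highway_gradient(np.full(T, f))
    print(f"f_t = {f:<4} for {T} steps -> gradient factor {results[f]:.3e}")
```

A forget gate held near 1 keeps a usable fraction of the gradient after 100 steps, while a gate at 0.5 drives it to numerical zero, which is the behaviour the paragraph above describes.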

Pooling: Last, Mean, or Max Hidden State?

The pooling step is where a surprising amount of modelling judgment lives. Given the full trajectory $H = (h_0, h_1, \ldots, h_{T-1}) \in \mathbb{R}^{T \times d_\text{hidden}}$, we need one vector.

| Pooling | Formula | When it works | When it hurts |
|---|---|---|---|
| Last | h̄ = h_{T-1} | Short sentences where the whole signal can survive the recurrence | Long sentences: early information has been overwritten |
| Mean | h̄ = (1/T) Σ h_t | When every token contributes useful evidence (topic classification) | When a few rare tokens carry all the signal and get averaged out to noise |
| Max | h̄[k] = max_t h_t[k] (per-dim) | When salient features appear at unpredictable positions (sentiment, urgency) | When the magnitude of features is not a good proxy for importance |
| Attention | h̄ = Σ α_t h_t with learned α_t | When you can afford extra parameters and want position-aware weighting | Small datasets: α_t over-fits easily |

A useful rule of thumb: if your classifier is doing worse than a bag-of-words baseline, try switching from last pooling to a concatenation of mean and max pooling. The change costs almost nothing (the classifier head's input simply doubles to 2·d_hidden) and can be worth several points of F1.
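
The first three rows of the table, plus the mean + max concatenation from the rule of thumb, each take one line of NumPy. The trajectory `H` below is a made-up example with a single spike at t=1, not output from a trained model.

```python
import numpy as np

# Hypothetical trajectory: T=4 steps, d_hidden=3, with a spike in channel 0 at t=1.
H = np.array([
    [0.1, 0.0, 0.2],
    [0.9, 0.1, 0.1],
    [0.2, 0.0, 0.3],
    [0.1, 0.2, 0.2],
])

last_pool = H[-1]                                   # h_{T-1}
mean_pool = H.mean(axis=0)                          # (1/T) * sum_t h_t
max_pool = H.max(axis=0)                            # per-dimension max over time
mean_max = np.concatenate([mean_pool, max_pool])    # shape (2 * d_hidden,)

print("last    ", last_pool)
print("mean    ", mean_pool)
print("max     ", max_pool)
print("mean+max", mean_max)
```

Max pooling keeps the 0.9 spike that the last state has already lost and that the mean dilutes. With padded batches, the padded positions would need to be masked out before taking the mean or max.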

[Interactive figure: pooling comparison. Left panel: a canned hidden-state trajectory $H \in \mathbb{R}^{T \times d_\text{hidden}}$. Right panel: the pooled vector for the selected strategy.] Notice how max-pool captures a spike that mean-pool dilutes, and how attention-pool biases toward the end of the sequence for this particular query.

The attention pooling row is the crack through which transformers crawled into NLP. An attention pool is literally a single-head attention block with a learned query. Once you accept that the sum $\sum_t \alpha_t h_t$ is a good idea, you are two steps away from dropping the LSTM entirely and letting every token attend to every other token.
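
A minimal sketch of attention pooling follows, assuming scaled dot-product scores against a single query vector. That scoring function is one common choice rather than something this section prescribes, and the random `q` merely stands in for a learned parameter.

```python
import numpy as np

def attention_pool(H, q):
    """Pool H (T, d_hidden) with one query q (d_hidden,) into a single vector."""
    scores = H @ q / np.sqrt(H.shape[1])     # (T,) scaled dot-product scores
    scores = scores - scores.max()           # shift for a stable softmax
    alpha = np.exp(scores) / np.exp(scores).sum()
    return alpha @ H, alpha                  # (d_hidden,), (T,)

rng = np.random.default_rng(0)
H = rng.normal(size=(5, 4))    # hypothetical trajectory: T=5, d_hidden=4
q = rng.normal(size=4)         # stand-in for a learned query vector

pooled, alpha = attention_pool(H, q)
print("alpha :", alpha.round(3), "sum =", alpha.sum().round(3))
print("pooled:", pooled.round(3))

# A zero query gives uniform weights, so attention pooling reduces to mean pooling.
uniform_pooled, uniform_alpha = attention_pool(H, np.zeros(4))
print(np.allclose(uniform_pooled, H.mean(axis=0)))
```

The last check shows that mean pooling is the special case of attention pooling in which every position gets the same weight.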


Interactive Visualization: A Sentence Through an LSTM

The animation below lets you scrub through a sentence token by token and watch the hidden state $h_t$ evolve. Two sentences are pre-wired: an obviously positive one and an obviously negative one. The 4-bar histogram above the active cell is the current hidden vector; the yellow text below is the cell state $c_t$. When you reach the last token, the classifier head appears on the right and softmax turns the final hidden state $h_{T-1}$ into a probability over positive and negative.

[Interactive figure: LSTM classifier flow]

Three things to look for while you play with the slider:

  • Slow start. At $t = 0$ both sentences share the same token ("I") and therefore the same tiny hidden activations. The network has almost no opinion yet.
  • The decisive token. At $t = 2$ ("love" or "hate") the trajectories split dramatically and the sign of the hidden bars flips. This is the cell-state additive highway at work: one strong token can rewrite memory in a single step.
  • Noisy coda. The last two tokens ("this movie") are shared again, yet the two trajectories stay separated. The LSTM has committed. A transformer would let those later tokens re-weight the verdict via attention.

Building the Classifier From Scratch in NumPy

Before we hand the problem to PyTorch we are going to build the full classifier in pure NumPy. You will see that there is nothing magical inside nn.LSTM — just the four-gate recipe from the previous section, rolled over time with a Python for loop, followed by a single linear projection. Every matrix, every gate, every state update is visible.

The example uses $d_\text{embed} = 2$, $d_\text{hidden} = 2$, and a four-word sentence so every intermediate value fits comfortably on one line. Each entry below quotes one line of the script and explains what that line computes; the full listing follows the walkthrough.

Forward pass — pure NumPy, every value traced
🐍lstm_classifier_numpy.py
1import numpy as np

NumPy gives us ndarrays, the @ matrix-multiply operator, broadcasting, and vectorized math (exp, tanh). The array operations inside each step run as optimized C; only the loop over time steps stays in Python.

EXECUTION STATE
numpy = Numerical library — provides ndarray, dot/matmul, element-wise exp/tanh, broadcasting.
as np = Universal convention; lets us write np.exp(), np.array(), @ for matmul, etc.
3# Section — tiny vocab + embedding matrix

We are building the smallest possible working classifier: a 6-token vocabulary, 2-dimensional embeddings, 2-dimensional hidden state. Small enough to trace by hand, rich enough to show every moving part of an LSTM classifier.

4vocab = { '<pad>': 0, 'this': 1, 'movie': 2, ... }

A Python dictionary mapping each token string to a unique integer id. These ids are exactly the row indices that nn.Embedding (or our E matrix) expects.

EXECUTION STATE
📚 Python dict = Hash map with O(1) average lookup — vocab['great'] returns 4 in constant time even for million-word vocabularies.
⬆ vocab = {'<pad>': 0, 'this': 1, 'movie': 2, 'is': 3, 'great': 4, 'awful': 5}
→ why <pad>? = Reserved id 0 so we can pad short sentences to a common length in a batch without introducing a real word.
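
To show what the reserved id is for, here is a small sketch that pads two sentences of different lengths into one rectangular batch and builds a mask of real tokens. The helper name `pad_batch` and the second sentence are assumptions for illustration; the vocabulary is the one defined above.

```python
import numpy as np

vocab = {'<pad>': 0, 'this': 1, 'movie': 2, 'is': 3, 'great': 4, 'awful': 5}

def pad_batch(sentences, vocab, pad_id=0):
    """Map sentences to ids and right-pad them to the longest one."""
    ids = [[vocab[w] for w in s.split()] for s in sentences]
    T = max(len(seq) for seq in ids)
    batch = np.full((len(ids), T), pad_id, dtype=np.int64)
    for row, seq in enumerate(ids):
        batch[row, :len(seq)] = seq
    mask = batch != pad_id          # True where a real token sits
    return batch, mask

batch, mask = pad_batch(["this movie is great", "awful movie"], vocab)
print(batch)
print(mask.sum(axis=1))   # number of real tokens in each sentence
```

The mask is what a pooling step would use to ignore the padded positions.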
5E = np.array([...])

The embedding matrix. Row k is the d_embed-dimensional vector assigned to token id k. In a trained model these rows are learned; here we hand-craft two polar opposites — 'great' points in the positive direction, 'awful' in the negative direction — so the forward pass yields an interpretable result.

EXECUTION STATE
📚 np.array = Converts a Python list of lists into a 2-D ndarray with dtype float64 (inferred from the floats).
⬆ E (6 × 2) =
          d0    d1
<pad>   0.00  0.00
this    0.10  0.20
movie   0.40  0.30
is      0.20  0.10
great   0.90  0.70
awful  -0.80 -0.60
→ shape = (vocab_size=6, d_embed=2) = 12 floats
→ intuition = 'great' sits in the upper-right, 'awful' in the lower-left — an LSTM that integrates embeddings over time will drift toward one pole or the other.
14# Section — the input sentence

We now pick a concrete sentence and push it through the pipeline. This is exactly the control flow used by any real classifier, just with tiny numbers.

15sentence = "this movie is great"

A 4-word sentence. All four tokens live in our vocabulary, so there are no <unk> tokens and no padding — the simplest possible case.

EXECUTION STATE
sentence = "this movie is great"
expected label = positive — the word 'great' is the emotional anchor.
16input_ids = np.array([vocab[w] for w in sentence.split()])

Split the sentence on whitespace, look up each word in the vocab, wrap the result as an ndarray so we can use fancy indexing in the next line.

EXECUTION STATE
📚 str.split() = No argument ⇒ split on any whitespace. 'this movie is great'.split() → ['this','movie','is','great'].
📚 np.array([...]) = Converts a Python list of ints into a 1-D ndarray. dtype defaults to the platform integer type (int64 on most 64-bit systems).
⬆ input_ids = [1, 2, 3, 4]
→ shape = (T=4,)
17X = E[input_ids] # (T=4, d_embed=2)

NumPy fancy indexing: E[[1,2,3,4]] returns a new 4×2 matrix whose rows are E[1], E[2], E[3], E[4]. This is the O(T·d) lookup that replaces the O(T·V·d) one-hot-times-matrix multiplication from section 1.

EXECUTION STATE
📚 E[int_array] = NumPy fancy indexing — returns a new array whose i-th row is E[int_array[i]]. Works because E is 2-D and the index is 1-D of ints.
⬇ arg: input_ids = [1, 2, 3, 4]
⬆ result: X (4×2) =
          d0    d1
this    0.10  0.20
movie   0.40  0.30
is      0.20  0.10
great   0.90  0.70
→ role = X[t] is the embedding handed to the LSTM at time step t. These are the only word-level features the recurrent network will ever see.
19# Section — LSTM weights

An LSTM has four gates per hidden unit: input (i), forget (f), cell-candidate (g), output (o). We stack their weight rows into one big matrix so the whole gate vector is computed in a single matmul.

20d_embed, d_hidden = 2, 2

Tuple unpacking — sets the two dimensions that govern every weight shape below. Real models use d_embed = 128..4096 and d_hidden = 128..4096; the math is identical.

EXECUTION STATE
d_embed = 2 — per-token input size
d_hidden = 2 — per-token hidden/cell size
21# 4 gates stacked: [input, forget, cell, output]

Comment reminding us that the gate rows are concatenated in a fixed order. PyTorch's nn.LSTM uses the same convention.

22W_x = np.array([...]) # (4*d_hidden, d_embed)

The input→gates weight matrix. Shape is (4·d_hidden, d_embed) = (8, 2). Rows 0-1 produce the input-gate logits, rows 2-3 the forget-gate logits, rows 4-5 the cell-candidate logits, rows 6-7 the output-gate logits.

EXECUTION STATE
⬆ W_x (8 × 2) =
         d0    d1
i[0]    0.5   0.3
i[1]    0.1  -0.2
f[0]    0.4   0.6
f[1]    0.2   0.1
g[0]    0.7   0.5
g[1]    0.3   0.4
o[0]    0.2  -0.1
o[1]    0.0   0.3
→ 8 rows = 4 gates × 2 units = Every row decides one component of one gate. The 'cell' gate (g) is the only one with tanh; the others get sigmoid.
30W_h = np.zeros((4 * d_hidden, d_hidden))

The hidden→gates weight matrix. For clarity we zero it out in this toy, so the gate values depend only on the current embedding, not on the previous hidden state. The cell state still carries memory forward through f * c_prev. Real LSTMs never zero W_h, but doing so makes the numbers easy to trace.

EXECUTION STATE
📚 np.zeros(shape) = Returns a new ndarray of the given shape, dtype float64, filled with 0.0.
⬆ W_h (8 × 2) =
all zeros
→ simplification = Setting W_h = 0 removes the h_{t-1} term from the gate equation, so the gates at step t depend only on x_t; c_t still depends on c_{t-1}. You would never do this in a real model.
31b = np.zeros(4 * d_hidden)

Bias vector for all four gates stacked. Some frameworks initialise the forget-gate bias to 1 to encourage remembering early in training.

EXECUTION STATE
⬆ b = [0, 0, 0, 0, 0, 0, 0, 0]
→ tip = Jozefowicz et al. (2015) showed that initialising b_f ≈ 1 substantially improves LSTM training on long sequences.
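
With the stacked layout used here, a forget-gate bias of 1 is a one-line slice assignment. The sketch assumes the gate order [input, forget, cell, output] from this example and shows the gate value the bias alone produces when the weighted inputs are zero.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

d_hidden = 2
b = np.zeros(4 * d_hidden)
# Gate order is [input, forget, cell, output], so the forget gate
# owns the slice d_hidden : 2 * d_hidden.
b[d_hidden:2 * d_hidden] = 1.0

f_zero_bias = float(sigmoid(0.0))   # forget gate with b_f = 0 and no input signal
f_one_bias = float(sigmoid(1.0))    # forget gate with b_f = 1 and no input signal

print("b =", b)
print("f with b_f = 0:", f_zero_bias)
print("f with b_f = 1:", round(f_one_bias, 3))
```

Starting with the gate at roughly 0.73 rather than 0.5 means more of the cell state survives each step before training has had a chance to shape the gates.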
33# Section — classifier head

After the LSTM has read the whole sentence, a tiny linear layer converts the final hidden state into class logits.

34W_cls = np.array([[-1.0, 1.2], [1.1, -0.9]])

A 2×2 classifier matrix — rows are classes, columns are hidden units. Row 0 fires when h[1] is large (positive anchor). Row 1 fires when h[0] is large (negative anchor).

EXECUTION STATE
⬆ W_cls (2 × 2) =
           h[0]    h[1]
positive  -1.00   1.20
negative   1.10  -0.90
→ reading = Positive logit = -1.0·h[0] + 1.2·h[1]. A hidden state with small h[0] and large h[1] → strong positive prediction.
36b_cls = np.array([0.0, 0.0])

Two class biases, both zero — the classifier is perfectly neutral until the hidden state pushes it one way.

EXECUTION STATE
⬆ b_cls = [0.0, 0.0]
38def sigmoid(x) → np.ndarray

σ(x) = 1 / (1 + e^(-x)). Squashes any real number into (0, 1). Used for gates because a gate value has to behave like a soft switch: 0 means 'forget / block', 1 means 'keep / let through'.

EXECUTION STATE
⬇ input: x = Any ndarray of real numbers — gate pre-activations in our case.
📚 np.exp = Element-wise e^x. np.exp(0) = 1, np.exp(-1) ≈ 0.368, np.exp(1) ≈ 2.718.
⬆ return = Element-wise 1 / (1 + e^(-x)). σ(0)=0.5, σ(2)≈0.88, σ(-2)≈0.12.
41def lstm_step(x_t, h_prev, c_prev) → (h_t, c_t)

One step of the LSTM recurrence. Takes the current embedding, the previous hidden and cell states, and returns the new hidden and cell states. Running this over a loop of length T is the entire forward pass.

EXECUTION STATE
⬇ input: x_t (d_embed,) = Current token's embedding, a 2-vector from X. At t=0 this is E[vocab['this']] = [0.10, 0.20].
⬇ input: h_prev (d_hidden,) = Hidden state from the previous step. Initialised to zeros at t=0.
⬇ input: c_prev (d_hidden,) = Cell state from the previous step — the LSTM's long-term memory line. Also zero at t=0.
⬆ returns: (h_t, c_t) = Both (d_hidden,) = (2,) float vectors.
42 docstring — the four-gate recipe

The LSTM is the textbook cure for vanishing gradients in plain RNNs. Its four gates independently decide how much of the past to keep (f), how much of the present to absorb (i, g), and how much of the memory to expose (o).

43gates = W_x @ x_t + W_h @ h_prev + b

One matmul computes the pre-activations for all four gates at once. This fusing is why real LSTMs store W_i, W_f, W_g, W_o as a single stacked matrix — a single BLAS call replaces four separate ones.

EXECUTION STATE
📚 @ (matmul) = Python's matrix-multiplication operator. W_x @ x_t has shape (8,2) @ (2,) = (8,).
→ step t=0 numbers = W_x @ [0.10, 0.20] = [0.5·.1+0.3·.2, 0.1·.1-0.2·.2, 0.4·.1+0.6·.2, 0.2·.1+0.1·.2, 0.7·.1+0.5·.2, 0.3·.1+0.4·.2, 0.2·.1-0.1·.2, 0.0·.1+0.3·.2] = [0.11, -0.03, 0.16, 0.04, 0.17, 0.11, 0.00, 0.06]
→ W_h @ h_prev = zeros (W_h = 0), so no contribution at any step.
⬆ gates (t=0) = [0.11, -0.03, 0.16, 0.04, 0.17, 0.11, 0.00, 0.06]
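
The claim that one stacked product replaces four separate ones can be checked directly. The sketch reuses W_x and the embedding of 'this' from this example and splits the stacked matrix back into its four gate blocks.

```python
import numpy as np

W_x = np.array([
    [0.5, 0.3], [0.1, -0.2],    # input gate
    [0.4, 0.6], [0.2, 0.1],     # forget gate
    [0.7, 0.5], [0.3, 0.4],     # cell candidate
    [0.2, -0.1], [0.0, 0.3],    # output gate
])
x_t = np.array([0.10, 0.20])    # embedding of 'this'

fused = W_x @ x_t                                  # one (8, 2) @ (2,) product
W_i, W_f, W_g, W_o = np.split(W_x, 4, axis=0)      # four (2, 2) blocks
separate = np.concatenate([W_i @ x_t, W_f @ x_t, W_g @ x_t, W_o @ x_t])

print(fused.round(2))
print(np.allclose(fused, separate))
```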
44i = sigmoid(gates[0:2])

The input gate. Each of the two components decides, on a 0-to-1 scale, how much of the cell-candidate to write into the corresponding cell slot.

EXECUTION STATE
📚 slice [0:2] = Rows 0 and 1 of the 8-vector — the input-gate pre-activations.
→ at t=0 = sigmoid([0.11, -0.03]) ≈ [0.527, 0.493]
⬆ i (2,) = Values ≈ 0.5 at t=0 — the untrained gate is neutrally open because the pre-activations are tiny.
45f = sigmoid(gates[2:4])

The forget gate. A value near 0 erases the corresponding cell slot; a value near 1 keeps it verbatim.

EXECUTION STATE
→ at t=0 = sigmoid([0.16, 0.04]) ≈ [0.540, 0.510]
⬆ f (2,) = ~0.52 — lets about half the old cell leak through.
→ why sigmoid? = σ output ∈ (0,1) matches the semantics of a soft switch: 0 = fully closed, 1 = fully open.
46g = np.tanh(gates[4:6])

The cell candidate — the fresh content we might write into the cell. tanh squashes it into (-1, 1) so the cell state stays bounded across many steps.

EXECUTION STATE
📚 np.tanh = Element-wise hyperbolic tangent. tanh(0)=0, tanh(1)≈0.762, tanh(-1)≈-0.762.
→ at t=0 = tanh([0.17, 0.11]) ≈ [0.168, 0.110]
⬆ g (2,) = Small positive values — early in the sentence the cell candidate drifts slightly positive.
47o = sigmoid(gates[6:8])

The output gate. Decides how much of the (nonlinearly squashed) cell to reveal as the hidden state.

EXECUTION STATE
→ at t=0 = sigmoid([0.00, 0.06]) ≈ [0.500, 0.515]
⬆ o (2,) = ~0.5 — the hidden state exposes about half the cell.
48c_t = f * c_prev + i * g

The cell-state update — the heart of the LSTM. Element-wise: keep f·c_prev of the old memory, add i·g of the new candidate. No matrix multiplication here, which is exactly what gives the LSTM its well-behaved gradients.

EXECUTION STATE
📚 * (Hadamard) = Element-wise multiplication on ndarrays. Different from @ which is matrix multiplication.
→ at t=0 = f·c_prev = [0.540, 0.510] · [0, 0] = [0, 0]; i·g = [0.527, 0.493] · [0.168, 0.110] ≈ [0.089, 0.054]
⬆ c_0 = ≈ [0.089, 0.054]
→ why additive? = The + is the reason LSTMs resist vanishing gradients: along this additive 'highway' the derivative of c_t with respect to c_prev is just f, which stays close to 1 whenever the forget gate is open.
49h_t = o * np.tanh(c_t)

The hidden state is a gated, nonlinearly-squashed view of the cell. tanh keeps h_t bounded; the output gate controls which cell slots get exposed.

EXECUTION STATE
→ at t=0 = tanh(c_0) ≈ tanh([0.089, 0.054]) ≈ [0.089, 0.054]; o · tanh(c_0) ≈ [0.500, 0.515] · [0.089, 0.054] ≈ [0.044, 0.028]
⬆ h_0 = ≈ [0.044, 0.028]
→ meaning = After reading only 'this', the network barely knows anything — the hidden state is close to zero. It will grow as stronger words like 'great' arrive.
50return h_t, c_t

The cell hands the next step two pieces of information: the publicly-visible hidden state h_t (used by the classifier and by the next cell) and the private long-term memory c_t.

EXECUTION STATE
⬆ (h_t, c_t) = At t=0: (≈[0.044, 0.028], ≈[0.089, 0.054]). Both are d_hidden-vectors.
52# Section — roll the LSTM over the sentence

Now we loop over all T time steps, threading (h, c) from each cell into the next. This is the one Python for-loop we cannot easily remove from a pure RNN/LSTM: step t needs the output of step t-1, so the time axis cannot be parallelised. That sequential bottleneck is one reason attention-based models took over once sequences got long.

53T = X.shape[0]

Read the sequence length from the shape of X. For our sentence T = 4.

EXECUTION STATE
📚 ndarray.shape = Tuple of dimensions. X.shape = (4, 2), so X.shape[0] is 4 = sequence length.
⬆ T = 4
54h = np.zeros(d_hidden)

Initial hidden state — all zeros. Before reading any word, the network knows nothing.

EXECUTION STATE
⬆ h (initial) = [0.0, 0.0]
55c = np.zeros(d_hidden)

Initial cell state — also zeros. Some implementations learn this initial state; most use zeros.

EXECUTION STATE
⬆ c (initial) = [0.0, 0.0]
56hidden_history = []

We save every h_t so that we can later pool them (mean, max, or just take the last). For a 'last hidden state' classifier only h_T is strictly needed — keeping all of them makes switching pooling strategies trivial.

EXECUTION STATE
⬆ hidden_history = [] — empty Python list
57for t in range(T):

Iterate once per token. Every iteration consumes one row of X and produces one (h_t, c_t) pair.

LOOP TRACE · 4 iterations
t=0 token 'this' x_t=[0.10, 0.20]
h_0 ≈ [0.044, 0.028]
c_0 ≈ [0.089, 0.054]
t=1 token 'movie' x_t=[0.40, 0.30]
gates = W_x·[0.4,0.3] = [0.29,-0.02,0.34,0.11,0.43,0.24,0.05,0.09]
h_1 ≈ [0.142, 0.075]
c_1 ≈ [0.284, 0.145]
t=2 token 'is' x_t=[0.20, 0.10]
gates = W_x·[0.2,0.1] = [0.13,0.00,0.14,0.05,0.19,0.10,0.03,0.03]
h_2 ≈ [0.125, 0.063]
c_2 ≈ [0.252, 0.124]
t=3 token 'great' x_t=[0.90, 0.70]
gates = W_x·[0.9,0.7] = [0.66,-0.05,0.78,0.25,0.98,0.55,0.11,0.21]
i ≈ [0.659, 0.488]
f ≈ [0.686, 0.562]
g = tanh([0.98,0.55]) ≈ [0.753, 0.501]
o ≈ [0.527, 0.552]
c_3 ≈ [0.669, 0.314]
h_3 ≈ [0.308, 0.168]
58h, c = lstm_step(X[t], h, c)

Tuple-unpack: call our cell function and rebind h and c to the new values. This single Python statement is the entire recurrence.

EXECUTION STATE
📚 tuple unpacking = Python assigns the two returned ndarrays to h and c in one line — equivalent to result = lstm_step(...); h, c = result[0], result[1].
→ note = Because lstm_step returns fresh ndarrays, there is no aliasing bug. If it returned views, we would need .copy() to avoid mutating the previous state.
59hidden_history.append(h.copy())

We keep a snapshot of every h_t. The .copy() matters — without it, if a future version of lstm_step returned a view, every entry in the list could end up pointing at the same final array.

EXECUTION STATE
📚 ndarray.copy() = Returns a deep copy of the array contents — independent of the source.
⬆ hidden_history after loop = [array([0.044, 0.028]), array([0.142, 0.075]), array([0.125, 0.063]), array([0.308, 0.168])]
61# Section — take final hidden state and classify

With the full sequence read, we pick the last hidden state and push it through a single linear layer + softmax. This is the simplest possible read-out (often called 'last pooling'); in PyTorch it corresponds to feeding the final hidden state h_n returned by nn.LSTM into an nn.Linear layer.

62h_T = hidden_history[-1]

Python's negative indexing: -1 is the last element, -2 the second to last, etc. Here it's h_{T-1} = h_3, the state after reading 'great'.

EXECUTION STATE
⬆ h_T = ≈ [0.308, 0.168]
→ why last? = Because an LSTM's hidden state is supposed to summarise everything seen so far. In practice this is optimistic — hence mean/max pooling or attention in the next section.
63logits = W_cls @ h_T + b_cls

A single linear layer turns the 2-dimensional hidden state into 2 class logits. No activation yet — softmax comes next.

EXECUTION STATE
→ computation = W_cls @ h_T = [-1.0·0.3082 + 1.2·0.1678, 1.1·0.3082 - 0.9·0.1678] ≈ [-0.3082 + 0.2014, 0.3390 - 0.1510] ≈ [-0.107, 0.188]. The larger logit is the 'negative' one, so this hand-set model gets the sentence wrong. The weights were chosen to keep the arithmetic traceable, not to classify well; training would reshape W_x, W_h and W_cls until h_T for this sentence lands on the 'positive' side.
⬆ logits (educational trace) = [-0.107, 0.188] (hand-set, untrained weights)
64probs = np.exp(logits - logits.max())

The 'shift-by-max' trick for numerically stable softmax. Subtracting the maximum before exp() prevents overflow while leaving the softmax output unchanged (softmax is shift-invariant).

EXECUTION STATE
📚 np.exp = Element-wise e^x. With shift: exp([-0.295, 0.000]) ≈ [0.745, 1.000].
📚 ndarray.max() = No args → scalar max over all elements. logits.max() ≈ 0.188 in our case.
⬆ unnormalised probs = ≈ [0.745, 1.000]
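
The shift only matters once logits get large, which the toy values here never do. The sketch below compares a naive softmax with the shifted one on a small input and on deliberately huge, hypothetical logits.

```python
import numpy as np

def softmax_naive(z):
    e = np.exp(z)
    return e / e.sum()

def softmax_stable(z):
    e = np.exp(z - z.max())     # largest exponent becomes exp(0) = 1
    return e / e.sum()

small = np.array([-0.1, 0.2])
big = np.array([1000.0, 1001.0])   # hypothetical extreme logits

# Identical results while nothing overflows.
print(np.allclose(softmax_naive(small), softmax_stable(small)))

# exp(1000) overflows to inf, so the naive version yields nan.
with np.errstate(over="ignore", invalid="ignore"):
    print(softmax_naive(big))

print(softmax_stable(big))
```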
65probs /= probs.sum()

Divide each entry by the sum so the result is a valid probability distribution (non-negative, sums to 1).

EXECUTION STATE
📚 /= (in-place divide) = Equivalent to probs = probs / probs.sum() but reuses the same memory.
⬆ probs = ≈ [0.427, 0.573]
→ row sum = 0.427 + 0.573 = 1.000 ✓
66pred = int(np.argmax(probs))

Pick the index of the largest probability. int(...) converts the NumPy scalar to a plain Python int — useful when we later index a list with it.

EXECUTION STATE
📚 np.argmax = Returns the index of the maximum element. For [0.427, 0.573] it returns 1.
⬆ pred = 1
67label = ['positive', 'negative'][pred]

Indexing a Python list with the integer class id turns a number back into a human-readable string.

EXECUTION STATE
⬆ label = 'negative'
→ note = The untrained weights give the 'wrong' answer on purpose — we built the model; training it on real data (chapter 18 §3) is what shapes the weights so 'great' truly dominates.
69print("h_T :", h_T)

Show the final summary vector — the single 2-dim state that has to encode everything we read.

70print("logits :", logits)

Show the raw scores before softmax. The classifier is making its decision purely from these two numbers.

71print("probs :", probs)

Post-softmax probabilities — what you would report in a UI or use to compute cross-entropy during training.

72print("predicted:", label)

The final human-readable verdict. This is the one value a user of the classifier would see.

import numpy as np

# --- 1. A tiny vocabulary and six embeddings (d_embed = 2) ---
vocab = {'<pad>': 0, 'this': 1, 'movie': 2, 'is': 3, 'great': 4, 'awful': 5}
E = np.array([
    [0.00, 0.00],   # <pad>
    [0.10, 0.20],   # this
    [0.40, 0.30],   # movie
    [0.20, 0.10],   # is
    [0.90, 0.70],   # great: pulls activations positive
    [-0.80, -0.60], # awful: pulls activations negative
])

# --- 2. Input sentence, tokenised + mapped to ids ---
sentence = "this movie is great"
input_ids = np.array([vocab[w] for w in sentence.split()])
X = E[input_ids]                       # (T=4, d_embed=2)

# --- 3. LSTM parameters (d_hidden = 2) ---
d_embed, d_hidden = 2, 2
# 4 gates stacked: [input, forget, cell, output] → rows = 4 * d_hidden
W_x = np.array([                       # (4*d_hidden, d_embed) = (8, 2)
    [ 0.5,  0.3], [ 0.1, -0.2],        # input gate    rows 0-1
    [ 0.4,  0.6], [ 0.2,  0.1],        # forget gate   rows 2-3
    [ 0.7,  0.5], [ 0.3,  0.4],        # cell cand.    rows 4-5
    [ 0.2, -0.1], [ 0.0,  0.3],        # output gate   rows 6-7
])
W_h = np.zeros((4 * d_hidden, d_hidden))   # simplified: no recurrent weights in this toy
b   = np.zeros(4 * d_hidden)

# --- 4. Classifier head: W_cls (2, d_hidden) + bias ---
W_cls = np.array([[-1.0,  1.2],        # logit for 'positive'
                  [ 1.1, -0.9]])        # logit for 'negative'
b_cls = np.array([0.0, 0.0])

def sigmoid(x: np.ndarray) -> np.ndarray:
    return 1.0 / (1.0 + np.exp(-x))

def lstm_step(x_t: np.ndarray, h_prev: np.ndarray, c_prev: np.ndarray):
    """One LSTM cell step: the four-gate recipe."""
    gates = W_x @ x_t + W_h @ h_prev + b     # (8,)
    i = sigmoid(gates[0:2])                  # input gate   (2,)
    f = sigmoid(gates[2:4])                  # forget gate  (2,)
    g = np.tanh(gates[4:6])                  # cell cand.   (2,)
    o = sigmoid(gates[6:8])                  # output gate  (2,)
    c_t = f * c_prev + i * g                 # cell state   (2,)
    h_t = o * np.tanh(c_t)                   # hidden state (2,)
    return h_t, c_t

# --- 5. Roll the LSTM over the sentence ---
T = X.shape[0]
h = np.zeros(d_hidden)
c = np.zeros(d_hidden)
hidden_history = []
for t in range(T):
    h, c = lstm_step(X[t], h, c)
    hidden_history.append(h.copy())

# --- 6. Take the final hidden state and classify ---
h_T     = hidden_history[-1]             # (d_hidden,)
logits  = W_cls @ h_T + b_cls             # (2,)
probs   = np.exp(logits - logits.max())
probs  /= probs.sum()                     # softmax → (2,)
pred    = int(np.argmax(probs))
label   = ['positive', 'negative'][pred]

print("h_T      :", h_T)
print("logits   :", logits)
print("probs    :", probs)
print("predicted:", label)

Quick Check: The untrained network above predicts "negative" for "this movie is great". That is not a bug: the weights were set by hand, not learned, so the verdict is arbitrary. Training (chapter 18 §3) is what shapes $W_x, W_h, W_\text{cls}$ so that the final hidden state aligns with the correct class direction.

The Same Model in PyTorch

Now we wrap the same recipe in an nn.Module. Three things to notice as you read the PyTorch version:

  1. The shapes are identical to the NumPy version — we simply trade np.array for torch.Tensor.
  2. The entire LSTM recurrence collapses into one line: self.lstm(x). On a CUDA GPU, PyTorch dispatches to a fused cuDNN kernel that is typically far faster than the NumPy loop.
  3. We return logits, not probabilities, because nn.CrossEntropyLoss expects raw logits and performs a numerically-stable log-softmax internally.
Forward pass — PyTorch, production-grade
🐍lstm_classifier_pytorch.py
1import torch

The core PyTorch package. Brings in torch.Tensor, autograd, CUDA support, and the numerical backend. Every model, weight, and gradient in this file is a torch.Tensor under the hood.

EXECUTION STATE
torch = Deep-learning framework. Provides tensors (like np.ndarray but with GPU support and automatic differentiation).
2import torch.nn as nn

The neural-network module — pre-built layers (Linear, LSTM, Embedding, ...) and the nn.Module base class that every model subclasses.

EXECUTION STATE
nn.Module = Base class. Tracks parameters automatically, supplies .to(device), .train()/.eval(), and .parameters().
nn.LSTM = Optimised implementation of the very loop you just wrote in NumPy; on CUDA GPUs it dispatches to cuDNN.
nn.Embedding = The O(T·d) row-lookup layer from section 1.
nn.Linear = A learnable affine layer: y = x W.T + b.
3import torch.nn.functional as F

Functional versions of common operations (softmax, relu, cross_entropy). They have no parameters of their own — you pass everything explicitly.

EXECUTION STATE
F.softmax = Stateless function — σ(x_i)=exp(x_i)/Σ exp(x_j). Unlike nn.Softmax, there is no object to hold.
5class LSTMClassifier(nn.Module):

Every PyTorch model subclasses nn.Module. By doing so, any nn.* attributes we assign (embedding, lstm, classifier) are automatically registered — their parameters show up in .parameters(), move with .to(device), and are saved by .state_dict().

EXECUTION STATE
→ why subclass? = Gives you free parameter tracking, device placement, save/load, train/eval modes, and hooks.
6def __init__(self, vocab_size, d_embed, d_hidden, n_classes, pad_idx=0):

Constructor. Receives the four sizes the model needs plus an optional padding id. These are the same knobs as the NumPy version.

EXECUTION STATE
⬇ vocab_size = 6 — total tokens in our toy vocab.
⬇ d_embed = 2 — dimension of each word vector.
⬇ d_hidden = 2 — dimension of the LSTM hidden state.
⬇ n_classes = 2 — positive / negative.
⬇ pad_idx = 0 — vocabulary id of the <pad> token. Handed to nn.Embedding so its row is kept at zeros and excluded from gradient updates.
8super().__init__()

Must be called first in any nn.Module subclass. Initialises the internal dictionaries PyTorch uses to track parameters and submodules.

EXECUTION STATE
📚 super() = Python 3 shortcut for calling the parent class. Equivalent to nn.Module.__init__(self).
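→ consequence = Forgetting this line makes the first layer assignment below fail with an AttributeError ("cannot assign module before Module.__init__() call"), because the registries that track parameters and submodules do not exist yet. A very common bug.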
9self.embedding = nn.Embedding(vocab_size, d_embed, padding_idx=pad_idx)

Creates the learnable vocab_size × d_embed embedding matrix — the PyTorch equivalent of our hand-crafted E.

EXECUTION STATE
📚 nn.Embedding(num_embeddings, embedding_dim, padding_idx) = A thin wrapper around a (num_embeddings, embedding_dim) weight matrix. Forward pass is a row lookup: output[i] = weight[input_ids[i]].
⬇ arg 1: vocab_size=6 = Number of rows in the weight matrix. Must be ≥ max(input_id) + 1.
⬇ arg 2: d_embed=2 = Number of columns.
⬇ arg 3: padding_idx=0 = Keeps row 0 at zeros. Also, gradients flowing into that row are zeroed during backprop, so padding never moves.
⬆ self.embedding.weight = Tensor of shape (6, 2), randomly initialised ~ 𝒩(0, 1), except row 0 (the padding row), which starts at zeros.
10self.lstm = nn.LSTM(d_embed, d_hidden, batch_first=True)

A single-layer LSTM. Internally contains W_x, W_h, and b — exactly the weights we wrote by hand — but executes the whole recurrence in a single cuDNN kernel on GPU.

EXECUTION STATE
📚 nn.LSTM(input_size, hidden_size, batch_first=...) = Builds a multi-layer LSTM. Without num_layers= this is a single layer. Parameters follow the four-gate stacking convention from the NumPy version.
⬇ arg 1: input_size=2 (d_embed) = Dimension of each x_t.
⬇ arg 2: hidden_size=2 (d_hidden) = Dimension of h_t and c_t.
⬇ arg 3: batch_first=True = Input/output shape is (batch, time, feature). Without this flag PyTorch defaults to (time, batch, feature) — the older academic convention. Almost all modern code uses batch_first=True.
→ parameter count = W_ih: (4·2, 2)=8·2=16 · W_hh: (4·2, 2)=16 · biases 2·(4·2)=16 → 48 floats total. Scales to 4·(d_embed+d_hidden+2)·d_hidden for real sizes.
11self.classifier = nn.Linear(d_hidden, n_classes)

A fully-connected layer that projects the pooled hidden state into class logits.

EXECUTION STATE
📚 nn.Linear(in_features, out_features, bias=True) = Applies y = x W.T + b. Weight shape: (out_features, in_features).
⬇ arg 1: d_hidden=2 = Input features — equals the LSTM hidden size. Must match the last dim of h_T.
⬇ arg 2: n_classes=2 = Output features — one logit per class.
⬆ self.classifier.weight = (2, 2) — the analogue of W_cls. bias has shape (2,).
13def forward(self, input_ids: torch.Tensor) -> torch.Tensor:

The forward pass. Autograd records every operation inside, so loss.backward() will compute gradients end-to-end without you writing a single line of backward code.

EXECUTION STATE
⬇ input_ids = Integer tensor of shape (B, T). B = batch size, T = sequence length. Values must be valid vocabulary ids.
⬆ returns = Float tensor of shape (B, n_classes) — the raw logits.
→ why call forward via model(x)? = model(x) goes through __call__, which runs any registered hooks and then invokes .forward. Prefer model(x) over calling model.forward(x) directly, which skips those hooks.
14x = self.embedding(input_ids) # (B, T, d_embed)

Row lookup: for each integer id in input_ids, pull the corresponding row of the embedding matrix. Adds a new last dimension d_embed.

EXECUTION STATE
⬇ input: input_ids = tensor([[1, 2, 3, 4]]) — shape (B=1, T=4).
⬆ result: x = Float tensor shape (1, 4, 2). x[0, t, :] is the embedding of the t-th token.
→ equivalence = This is exactly X = E[input_ids] from the NumPy version.
15outputs, (h_n, c_n) = self.lstm(x) # outputs: (B, T, d_hidden)

The LSTM returns two things: the stacked hidden states for every time step (outputs) and the final (h, c) tuple. Tuple-unpacking lets us grab both in one line.

EXECUTION STATE
📚 self.lstm(x) → (outputs, (h_n, c_n)) = For a single-layer LSTM: • outputs shape (B, T, d_hidden) — h_t stacked along time. • h_n shape (num_layers, B, d_hidden) — final h. • c_n shape (num_layers, B, d_hidden) — final c.
⬇ input: x = Float tensor (1, 4, 2).
⬆ outputs = Float tensor (1, 4, 2) — the full trajectory h_0, h_1, h_2, h_3.
⬆ h_n = Float tensor (1, 1, 2) — identical to outputs[:, -1, :].unsqueeze(0) for this one-layer LSTM.
⬆ c_n = Float tensor (1, 1, 2) — the final cell state.
16h_T = outputs[:, -1, :]

Take the last time step across the batch. The triple slice keeps the batch dim, picks time index -1, and keeps all hidden features.

EXECUTION STATE
📚 tensor[:, -1, :] = Standard slicing. dim 0 (batch): keep all; dim 1 (time): take the last (-1); dim 2 (feat): keep all.
⬆ h_T = Float tensor (B, d_hidden) = (1, 2). Shape has collapsed the time dimension.
→ alternatives = Mean pooling: outputs.mean(dim=1); max pooling: outputs.max(dim=1).values; attention pooling: weighted sum with learned weights.
17logits = self.classifier(h_T) # (B, n_classes)

Run the pooled hidden state through the final linear layer. No activation yet — softmax happens outside the model (e.g., inside nn.CrossEntropyLoss), which is the standard PyTorch pattern.

EXECUTION STATE
⬇ input: h_T = (B, d_hidden) = (1, 2).
⬆ logits = (B, n_classes) = (1, 2). Raw, unnormalised scores. Can be any real number.
→ pro tip = Returning logits (not probs) from .forward is a PyTorch convention: nn.CrossEntropyLoss expects raw logits and applies log-softmax internally for numerical stability.
18return logits

Hand the logits back to the caller. The outside world decides whether to softmax them (for display) or feed them straight into the loss.

EXECUTION STATE
⬆ return = Float tensor (1, 2) — raw class scores.
20model = LSTMClassifier(vocab_size=6, d_embed=2, d_hidden=2, n_classes=2)

Instantiate the network with our toy sizes. Keyword arguments keep the call self-documenting.

EXECUTION STATE
⬆ model = LSTMClassifier instance. model.parameters() yields tensors for: embedding.weight (6×2), lstm.weight_ih_l0, lstm.weight_hh_l0, lstm.bias_ih_l0, lstm.bias_hh_l0, classifier.weight (2×2), classifier.bias (2,).
→ total params = 6·2 + (4·2·2 + 4·2·2 + 4·2 + 4·2) + (2·2 + 2) = 12 + 48 + 6 = 66 floats.
21ids = torch.tensor([[1, 2, 3, 4]]) # 'this movie is great'

Build an input batch. The outer brackets make it a 2-D tensor of shape (1, 4): one sentence of length four. Our forward() indexes outputs[:, -1, :], so it expects a batch dimension even for a single example.

EXECUTION STATE
📚 torch.tensor = Factory that copies data into a new tensor. dtype is inferred from the Python values — list of ints → int64.
⬆ ids = tensor([[1, 2, 3, 4]]) — shape (1, 4), dtype int64.
22logits = model(ids)

Trigger the forward pass. Under the hood PyTorch builds an autograd graph so that loss.backward() can compute gradients later. For pure inference you would wrap this in torch.no_grad() to skip the graph.

EXECUTION STATE
⬇ input: ids = Int tensor (1, 4).
⬆ logits = Float tensor (1, 2), e.g. tensor([[0.07, -0.02]]) for a freshly-initialised model. Your numbers will differ — weights are random.
23probs = F.softmax(logits, dim=-1)

Convert logits into a probability distribution along the class dimension.

EXECUTION STATE
📚 F.softmax(x, dim) = softmax_i(x) = exp(x_i) / Σ_j exp(x_j). Numerically stable implementation — internally shifts by the max like our NumPy code did.
⬇ arg: dim=-1 = Normalise along the last axis (classes). For shape (B, C) this means each row sums to 1.
⬆ probs = Float tensor (1, 2). Each row sums to 1.000. With random weights you might see tensor([[0.52, 0.48]]); after training on real data the 'positive' column dominates for this sentence.
24print("logits:", logits)

Print the raw logits. Useful for diagnosing calibration issues and verifying that the shapes are what you expect before worrying about accuracy.

25print("probs :", probs)

Print the normalised probabilities. This is what a user-facing UI would show (e.g., a sentiment bar).

import torch
import torch.nn as nn
import torch.nn.functional as F

class LSTMClassifier(nn.Module):
    def __init__(self, vocab_size: int, d_embed: int,
                 d_hidden: int, n_classes: int, pad_idx: int = 0):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_embed, padding_idx=pad_idx)
        self.lstm      = nn.LSTM(d_embed, d_hidden, batch_first=True)
        self.classifier = nn.Linear(d_hidden, n_classes)

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        x            = self.embedding(input_ids)          # (B, T, d_embed)
        outputs, (h_n, c_n) = self.lstm(x)                # outputs: (B, T, d_hidden)
        h_T          = outputs[:, -1, :]                  # last-step pooling
        logits       = self.classifier(h_T)                # (B, n_classes)
        return logits

model  = LSTMClassifier(vocab_size=6, d_embed=2, d_hidden=2, n_classes=2)
ids    = torch.tensor([[1, 2, 3, 4]])                     # 'this movie is great'
logits = model(ids)
probs  = F.softmax(logits, dim=-1)
print("logits:", logits)
print("probs :", probs)
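The point about returning logits can be checked numerically. The sketch below is only an illustration: the logit values are made up, and treating class 0 as 'positive' is an assumption. It shows that handing raw logits to the cross-entropy loss gives the same number as taking the negative log-softmax of the true class by hand.

```python
import torch
import torch.nn.functional as F

# Illustrative logits for one example and two classes (values are made up).
logits = torch.tensor([[0.07, -0.02]])
target = torch.tensor([0])          # assumption: class 0 = 'positive'

# Path 1: hand the raw logits to the loss, as nn.CrossEntropyLoss does.
loss_from_logits = F.cross_entropy(logits, target)

# Path 2: the same quantity spelled out as the negative log-softmax
# of the true class.
log_probs = F.log_softmax(logits, dim=-1)
loss_manual = -log_probs[0, target[0]]

# Probabilities are only needed for display, not for the loss.
probs = F.softmax(logits, dim=-1)

print(loss_from_logits.item(), loss_manual.item())
print(probs)
```

Applying softmax inside the model and then passing the result to the loss would apply the normalisation twice, which is why the model stops at logits.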
The parameter count of the PyTorch model is slightly larger than the NumPy one because nn.LSTM keeps two separate bias vectors, $b_\text{ih}$ and $b_\text{hh}$ (an artefact of the old cuDNN convention). Only their sum enters the gate equations, so the second bias is mathematically redundant; it is kept for API compatibility.
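The counts quoted in the walkthrough can be reproduced by asking PyTorch directly. This is a small sketch using the toy sizes from this section; the helper name count is ours, not part of PyTorch.

```python
import torch.nn as nn

# Toy sizes from the walkthrough above.
vocab_size, d_embed, d_hidden, n_classes = 6, 2, 2, 2

embedding = nn.Embedding(vocab_size, d_embed, padding_idx=0)
lstm = nn.LSTM(d_embed, d_hidden, batch_first=True)
classifier = nn.Linear(d_hidden, n_classes)

def count(module: nn.Module) -> int:
    """Number of scalar parameters registered on a module (hypothetical helper)."""
    return sum(p.numel() for p in module.parameters())

# The four LSTM tensors: weight_ih_l0, weight_hh_l0, bias_ih_l0, bias_hh_l0.
lstm_shapes = {name: tuple(p.shape) for name, p in lstm.named_parameters()}

n_embedding = count(embedding)      # 6 * 2
n_lstm = count(lstm)                # 4*2*2 + 4*2*2 + 4*2 + 4*2
n_classifier = count(classifier)    # 2*2 + 2
n_total = n_embedding + n_lstm + n_classifier

# Closed form quoted in the walkthrough: 4 * (d_embed + d_hidden + 2) * d_hidden.
n_lstm_formula = 4 * (d_embed + d_hidden + 2) * d_hidden

print(lstm_shapes)
print(n_embedding, n_lstm, n_classifier, n_total)
```

The two bias tensors each have shape (4 · d_hidden,), which is where the extra 8 floats relative to a single-bias implementation come from.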

Making it Better: Bidirectional and Stacked LSTMs

The classifier above has a structural weakness: every hidden state $h_t$ depends only on tokens $0, 1, \ldots, t$. The phrase "not great" is read in order: by the time the LSTM sees "great" it has already committed to a representation of "not". It cannot go back.

The fix is mechanically trivial: run a second LSTM in the reverse direction and concatenate the two hidden state streams. This is a bidirectional LSTM (BiLSTM).

ht=LSTM(x0:t)h_t^\rightarrow = \text{LSTM}_\rightarrow(x_{0:t}), ht=LSTM(xt:T)h_t^\leftarrow = \text{LSTM}_\leftarrow(x_{t:T}), and ht=[ht;ht]h_t = [\,h_t^\rightarrow ;\, h_t^\leftarrow\,]. The classifier head now receives a vector of size 2dhidden2 d_\text{hidden}.

In PyTorch the change is a single keyword argument:

🐍bidirectional.py
1bidirectional=True

Adds a second LSTM that runs right-to-left. PyTorch stacks forward and backward hidden states along the last dimension, so outputs now has shape (B, T, 2·d_hidden).

EXECUTION STATE
⬆ outputs shape = (B, T, 2 · d_hidden)
→ parameter cost = Doubles LSTM parameters (two independent cells). The classifier head also doubles its input size.
2nn.Linear(2 * d_hidden, n_classes)

The classifier input is now the concatenation of the forward hidden state at the last time step and the backward hidden state at the first time step, so its size doubles.

EXECUTION STATE
→ common pattern = For last-pooling: take outputs[:, -1, :d_hidden] (forward final) concatenated with outputs[:, 0, d_hidden:] (backward final).
self.lstm = nn.LSTM(d_embed, d_hidden, batch_first=True, bidirectional=True)
self.classifier = nn.Linear(2 * d_hidden, n_classes)
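The "common pattern" for last-step pooling with a BiLSTM can be sanity-checked against h_n. The following is a minimal sketch with arbitrary sizes, assuming an unpadded batch; it confirms that the forward half of outputs at the last step and the backward half at step 0 are exactly the two final states PyTorch returns.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
B, T, d_embed, d_hidden = 3, 5, 4, 6   # arbitrary illustrative sizes

lstm = nn.LSTM(d_embed, d_hidden, batch_first=True, bidirectional=True)
x = torch.randn(B, T, d_embed)

with torch.no_grad():
    # outputs: (B, T, 2*d_hidden); h_n: (2, B, d_hidden) for one layer, two directions
    outputs, (h_n, c_n) = lstm(x)

# The forward direction finishes at the last time step; the backward direction
# finishes at time step 0, after reading the sentence right-to-left.
fwd_final = outputs[:, -1, :d_hidden]
bwd_final = outputs[:, 0, d_hidden:]
pooled = torch.cat([fwd_final, bwd_final], dim=-1)     # (B, 2*d_hidden)

# h_n holds the same two vectors: index 0 = forward, index 1 = backward.
pooled_from_h_n = torch.cat([h_n[0], h_n[1]], dim=-1)

print(pooled.shape, torch.allclose(pooled, pooled_from_h_n))
```

Note that the naive slice outputs[:, -1, :] would pair the forward final state with a backward state that has seen only one token, which is why the two halves are taken from opposite ends.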

Stacking is the other standard upgrade. Passing num_layers=2 to nn.LSTM runs a second LSTM on top of the first — the second layer reads the hidden states of the first as its inputs. Two or three layers is typical; more than that rarely helps and routinely hurts because of vanishing gradients across layers (not time).

A bidirectional stacked LSTM with max pooling was the workhorse text classifier from roughly 2015 to 2019. It is still a very strong baseline — if you are benchmarking a new model, implement this first.
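A minimal sketch of that baseline is given below. The class name is hypothetical, the sizes are the toy ones from this section, and masking padded positions before the max is our assumption about how one would handle padding; it also assumes every row contains at least one real token.

```python
import torch
import torch.nn as nn

class BiLSTMMaxPoolClassifier(nn.Module):
    """Hypothetical name: stacked bidirectional LSTM with max pooling over time."""

    def __init__(self, vocab_size, d_embed, d_hidden, n_classes,
                 num_layers=2, pad_idx=0):
        super().__init__()
        self.pad_idx = pad_idx
        self.embedding = nn.Embedding(vocab_size, d_embed, padding_idx=pad_idx)
        self.lstm = nn.LSTM(d_embed, d_hidden, num_layers=num_layers,
                            batch_first=True, bidirectional=True)
        self.classifier = nn.Linear(2 * d_hidden, n_classes)

    def forward(self, input_ids):
        x = self.embedding(input_ids)                  # (B, T, d_embed)
        outputs, _ = self.lstm(x)                      # (B, T, 2*d_hidden)
        # Exclude padded positions from the max by setting them to -inf.
        pad_mask = (input_ids == self.pad_idx).unsqueeze(-1)   # (B, T, 1)
        outputs = outputs.masked_fill(pad_mask, float("-inf"))
        pooled = outputs.max(dim=1).values             # (B, 2*d_hidden)
        return self.classifier(pooled)                 # (B, n_classes)

torch.manual_seed(0)
model = BiLSTMMaxPoolClassifier(vocab_size=6, d_embed=2, d_hidden=2, n_classes=2)
ids = torch.tensor([[1, 2, 3, 4],
                    [1, 5, 0, 0]])    # second row is padded with id 0
logits = model(ids)
print(logits.shape)
```

In this sketch the LSTM still steps over the padded positions; packing the batch with torch.nn.utils.rnn.pack_padded_sequence is the usual way to avoid that, at the cost of a little extra bookkeeping.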

The Fundamental Bottleneck

Take a second look at the interactive animation. At every time step the LSTM rewrites one fixed-size vector $h_t \in \mathbb{R}^{d_\text{hidden}}$. Regardless of sentence length (4 tokens or 4,000), the entire memory of what was read must fit into those few hundred floats.

The information bottleneck: an LSTM classifier must squeeze every relevant detail of a variable-length sequence into a single fixed-size vector $h_T$. Its capacity is constant; its input length is unbounded. Something has to give.
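The constant-capacity claim is easy to see in code. In this sketch (sizes are arbitrary, inputs are random) the per-step outputs grow with the sequence length, but the final state that a last-step classifier consumes keeps the same shape whether the input has 4 tokens or 4,000.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_embed, d_hidden = 8, 16            # arbitrary illustrative sizes
lstm = nn.LSTM(d_embed, d_hidden, batch_first=True)

shapes = {}
with torch.no_grad():
    for T in (4, 400, 4000):
        x = torch.randn(1, T, d_embed)           # one random "sentence" of length T
        outputs, (h_n, c_n) = lstm(x)
        shapes[T] = (tuple(outputs.shape), tuple(h_n.shape))

for T, (out_shape, state_shape) in shapes.items():
    print(T, out_shape, state_shape)
```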

Empirically, three things go wrong once sequences get long:

  • Forgetting. Information from tokens read 200 steps ago has usually been largely overwritten. Even with the additive cell highway, gradients along length-200 paths tend to be numerically tiny during training.
  • Position blur. "The cat sat on the mat" and "The mat sat on the cat" produce different $h_T$, but often not by much: the LSTM has no explicit notion of position, only of order.
  • Sequential latency. Step $t$ cannot begin until step $t-1$ has finished. GPUs hate this. For a transformer, all timesteps go through the matmul in parallel.
The original motivation for attention was exactly this bottleneck: Bahdanau et al.'s 2014 paper literally let the decoder "peek" back at the whole encoder trajectory instead of relying on one final state. Three years later, the transformer paper (Vaswani et al., 2017) used that same idea to replace the recurrence entirely.

From LSTM Classifier to Transformer Classifier

Conceptually the transformer classifier is the same shape: embed, encode, pool, classify. The only stage that changes is the encoder.

| Stage | LSTM classifier | Transformer classifier |
| --- | --- | --- |
| Embedding | E[input_ids] | E[input_ids] + positional encoding |
| Encoder | LSTM: O(T) sequential steps, O(d_hidden) state | Stack of self-attention + MLP blocks: O(T²·d) compute per block, O(T·d) state |
| Pooling | Last, mean, or max hidden state | [CLS] token's final hidden state (BERT-style) or mean over tokens |
| Head | Linear + softmax | Linear + softmax (unchanged) |

Self-attention is what replaces the recurrence. Instead of threading a single hidden vector through time, every token attends to every other token via queries, keys, and values: $\text{Attn}(Q, K, V) = \text{softmax}(QK^\top / \sqrt{d_k})\, V$. Crucially, all T positions are processed in parallel; there is no sequential loop. That is the superpower that made transformers feasible on modern GPUs.
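The attention formula translates almost symbol for symbol into NumPy. This is a single-head sketch with random weights; the projection matrices W_q, W_k, W_v and all sizes are illustrative assumptions.

```python
import numpy as np

def softmax(x, axis=-1):
    """Numerically stable softmax: shift by the max before exponentiating."""
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def attention(Q, K, V):
    """softmax(Q K^T / sqrt(d_k)) V for a single head, no masking."""
    d_k = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(d_k)       # (T, T): every token vs every token
    weights = softmax(scores, axis=-1)    # each row is a distribution over tokens
    return weights @ V, weights

rng = np.random.default_rng(0)
T, d_model, d_k = 4, 8, 8                 # illustrative sizes
X = rng.normal(size=(T, d_model))         # stand-in for T token embeddings
W_q, W_k, W_v = (rng.normal(size=(d_model, d_k)) for _ in range(3))

out, weights = attention(X @ W_q, X @ W_k, X @ W_v)
print(out.shape, weights.sum(axis=-1))
```

There is no loop over time steps: the whole (T, T) score matrix comes out of one matrix product, which is the parallelism the paragraph above refers to.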

Multi-head attention runs $h$ such attention functions in parallel on different linear projections of Q, K, V. Each head can specialise: one may look at syntactic parents, another at coreference, another at position. The concatenated heads are projected back to $d_\text{model}$. In an LSTM the hidden state mixes all of those roles into one vector; in multi-head attention they live in disjoint subspaces.
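A compact way to see the "disjoint subspaces" is to split the projected vectors into heads with a reshape. The sketch below uses random weights and illustrative sizes; the function name and the output projection W_o are our own labels, not taken from a particular library.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def multi_head_attention(X, W_q, W_k, W_v, W_o, n_heads):
    """Self-attention with n_heads heads, each acting on a d_model/n_heads slice."""
    T, d_model = X.shape
    d_head = d_model // n_heads

    def split(M):
        # (T, d_model) -> (n_heads, T, d_head)
        return M.reshape(T, n_heads, d_head).transpose(1, 0, 2)

    Q, K, V = split(X @ W_q), split(X @ W_k), split(X @ W_v)
    scores = Q @ K.transpose(0, 2, 1) / np.sqrt(d_head)     # (n_heads, T, T)
    weights = softmax(scores, axis=-1)
    heads = weights @ V                                      # (n_heads, T, d_head)
    concat = heads.transpose(1, 0, 2).reshape(T, d_model)    # concatenate the heads
    return concat @ W_o, weights

rng = np.random.default_rng(0)
T, d_model, n_heads = 5, 8, 2             # illustrative sizes; d_model % n_heads == 0
X = rng.normal(size=(T, d_model))
W_q, W_k, W_v, W_o = (rng.normal(size=(d_model, d_model)) for _ in range(4))

out, weights = multi_head_attention(X, W_q, W_k, W_v, W_o, n_heads)
print(out.shape, weights.shape)
```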

A useful mental picture: the LSTM hidden state is one summary; a transformer layer produces T summaries (one per token), each conditioned on the whole sequence. Pooling at the end is then a cheap aggregation rather than a compression.

Flash Attention and the Memory Story

Self-attention fixes the LSTM's information bottleneck, but introduces a new one: memory. The naive formulation materialises the $T \times T$ attention matrix $S = QK^\top / \sqrt{d_k}$. For a sequence of 8,192 tokens that is 67 million floats, or 256 MiB in fp32, per head per layer, just for intermediates. On a 32-head, 80-layer model you run out of GPU HBM long before you run out of compute.
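The arithmetic behind those figures is short enough to write out. The sketch assumes fp32 scores (4 bytes each) and, for the last line, that every head and layer keeps its score matrix alive at once, as a naive training setup that saves all intermediates would.

```python
T = 8192
bytes_per_float = 4                    # assumption: fp32 attention scores

entries = T * T                        # size of the T x T score matrix
mib_per_head_layer = entries * bytes_per_float / 2**20

n_heads, n_layers = 32, 80             # model shape quoted in the text
total_gib = entries * bytes_per_float * n_heads * n_layers / 2**30

print(f"{entries:,} entries per head per layer")
print(f"{mib_per_head_layer:.0f} MiB per head per layer")
print(f"{total_gib:.0f} GiB if all heads and layers are held simultaneously")
```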

Flash Attention (Dao et al., 2022) is the fix that made long-context transformers practical. The key insight: we never actually need S in full, only its softmax output times V. So we tile Q, K, V into blocks that fit in GPU SRAM, compute the softmax incrementally (online softmax), and accumulate the output without ever writing the full $T \times T$ matrix to HBM.

| Quantity | Naive attention | Flash Attention |
| --- | --- | --- |
| HBM reads/writes | O(T² + T·d) | O(T²·d²/M), where M is the on-chip SRAM size |
| FLOPs | O(T²·d) | O(T²·d), the same order |
| Peak memory | O(T²) | O(T·d) |
| Reported wall-clock speedup | 1× (baseline) | FlashAttention v1: up to ~3× on 1k–4k sequences (Dao et al. 2022, arXiv:2205.14135). FlashAttention-2 roughly doubles v1's throughput at large T (Dao 2023, arXiv:2307.08691). |
Flash Attention does not change the math: it computes exact attention, and its outputs match the naive implementation up to floating-point rounding. It is purely an IO-aware reorganisation of the memory accesses. The lesson: at scale, what kills you is often not FLOPs but HBM bandwidth.
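The online-softmax idea can be illustrated in plain NumPy. This is a sketch of the numerical trick only, not of the real kernel: it tiles over keys and values but not queries, runs on the CPU, and the function names and block size are our own choices. It keeps a running row maximum and row sum, rescales the accumulators whenever the maximum changes, and never forms the full T × T matrix.

```python
import numpy as np

def naive_attention(Q, K, V):
    """Reference: materialises the full (T, T) score matrix."""
    S = Q @ K.T / np.sqrt(Q.shape[-1])
    P = np.exp(S - S.max(axis=-1, keepdims=True))
    return (P / P.sum(axis=-1, keepdims=True)) @ V

def tiled_attention(Q, K, V, block=4):
    """Same result, processing keys/values one block at a time."""
    T, d = Q.shape
    out = np.zeros((T, V.shape[-1]))       # unnormalised output accumulator
    row_max = np.full((T, 1), -np.inf)     # running max of each score row
    row_sum = np.zeros((T, 1))             # running softmax denominator
    for start in range(0, K.shape[0], block):
        K_blk, V_blk = K[start:start + block], V[start:start + block]
        S_blk = Q @ K_blk.T / np.sqrt(d)                  # (T, block) only
        new_max = np.maximum(row_max, S_blk.max(axis=-1, keepdims=True))
        scale = np.exp(row_max - new_max)                 # rescale old accumulators
        P_blk = np.exp(S_blk - new_max)
        row_sum = row_sum * scale + P_blk.sum(axis=-1, keepdims=True)
        out = out * scale + P_blk @ V_blk
        row_max = new_max
    return out / row_sum

rng = np.random.default_rng(0)
T, d = 10, 8                               # T deliberately not a multiple of block
Q, K, V = (rng.normal(size=(T, d)) for _ in range(3))

reference = naive_attention(Q, K, V)
tiled = tiled_attention(Q, K, V, block=4)
print(np.abs(reference - tiled).max())
```

The two results agree to rounding error, which is the sense in which the tiled computation is exact rather than approximate.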

An LSTM classifier has no analogue of this problem because its state size is $O(d_\text{hidden})$, independent of sequence length. That same fact is why LSTMs cannot reach back: constant state is both the feature and the bug.


Positional Encodings and the KV-Cache

Without positional information, a transformer layer is permutation-equivariant: permute the input tokens and the outputs are permuted in the same way. Self-attention sees a set of vectors, not a sequence. That is a deal-breaker for language, where "dog bites man" is very different from "man bites dog".
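Permutation equivariance is easy to verify numerically. In this sketch (single head, random weights, no positional encoding, all sizes illustrative) shuffling the input rows simply shuffles the output rows in the same way.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def self_attention(X, W_q, W_k, W_v):
    """Single-head self-attention with no positional information."""
    Q, K, V = X @ W_q, X @ W_k, X @ W_v
    return softmax(Q @ K.T / np.sqrt(Q.shape[-1])) @ V

rng = np.random.default_rng(0)
T, d = 5, 8                                # illustrative sizes
X = rng.normal(size=(T, d))                # stand-in for token embeddings
W_q, W_k, W_v = (rng.normal(size=(d, d)) for _ in range(3))

perm = rng.permutation(T)                  # a random reordering of the tokens
out = self_attention(X, W_q, W_k, W_v)
out_shuffled = self_attention(X[perm], W_q, W_k, W_v)

# Shuffling the inputs just shuffles the outputs: order carries no information.
equivariant = np.allclose(out_shuffled, out[perm])
print(equivariant)
```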

Positional encodings are the patch. The original transformer added a fixed sinusoidal vector $\text{PE}(t) \in \mathbb{R}^{d_\text{model}}$ to every token embedding. Modern variants, including RoPE (rotary), ALiBi (attention with linear biases), learned absolute, and T5-relative, all share one goal: inject the order information that an LSTM got for free from its recurrence.
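For reference, the fixed sinusoidal encoding of the original transformer paper (Vaswani et al., 2017) uses sines on even dimensions and cosines on odd dimensions, with wavelengths that grow geometrically. The sketch below assumes an even d_model; the function name is ours.

```python
import numpy as np

def sinusoidal_positional_encoding(T, d_model):
    """PE[t, 2i] = sin(t / 10000^(2i/d_model)), PE[t, 2i+1] = cos(same angle)."""
    positions = np.arange(T)[:, None]                  # (T, 1)
    even_dims = np.arange(0, d_model, 2)[None, :]      # (1, d_model / 2)
    angles = positions / np.power(10000.0, even_dims / d_model)
    pe = np.zeros((T, d_model))
    pe[:, 0::2] = np.sin(angles)
    pe[:, 1::2] = np.cos(angles)
    return pe

T, d_model = 6, 8                           # illustrative sizes, d_model even
pe = sinusoidal_positional_encoding(T, d_model)

# The encoding is simply added to the token embeddings before the first layer.
rng = np.random.default_rng(0)
token_embeddings = rng.normal(size=(T, d_model))
encoder_input = token_embeddings + pe
print(pe.shape, encoder_input.shape)
```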

Trade-off: an LSTM gets order structure from its sequential computation, which is free in bits but expensive in wall-clock time. A transformer gets parallelism for free but must inject position explicitly. Sinusoidal encodings, RoPE and ALiBi do so without adding learned parameters, and ALiBi in particular was designed to extrapolate to sequences longer than those seen in training.

KV-cache — reclaiming incremental generation

There is one thing LSTMs do effortlessly that transformers struggle with: autoregressive generation. An LSTM, at decoding time, only needs to keep the last $(h_t, c_t)$ pair in memory. A transformer, naively, re-computes attention over the entire prefix at every new token, an $O(T^2)$ blowup.

The KV-cache is the standard fix. Observe that once a token's key and value vectors are computed, they never change: under causal masking they depend only on that token and the tokens before it, never on tokens generated later. So we cache them. At generation step $t$, we compute only the new token's $q_t, k_t, v_t$, append $k_t, v_t$ to the cache, and attend $q_t$ over the full cache.

| Per-step cost | Without cache | With KV-cache |
| --- | --- | --- |
| Compute | O(T²·d) | O(T·d) |
| Memory for cache | 0 | O(T · L · d) (L = # layers) |
| Bottleneck | Re-computing keys/values for every prefix token | HBM bandwidth reading the cache, now the dominant cost |
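The caching scheme described above can be sketched for one attention head in a single layer. Everything here is illustrative (random weights, toy sizes, list-based cache); the check at the end confirms that attending each new query over the cached keys and values reproduces full causal attention row by row.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def causal_attention(X, W_q, W_k, W_v):
    """Full recomputation: every position attends to itself and earlier positions."""
    Q, K, V = X @ W_q, X @ W_k, X @ W_v
    T = X.shape[0]
    scores = Q @ K.T / np.sqrt(Q.shape[-1])
    mask = np.tril(np.ones((T, T), dtype=bool))        # lower triangle = allowed
    scores = np.where(mask, scores, -np.inf)
    return softmax(scores) @ V

rng = np.random.default_rng(0)
T, d_model, d_k = 6, 8, 8                  # illustrative sizes
X = rng.normal(size=(T, d_model))          # stand-in for the token sequence
W_q, W_k, W_v = (rng.normal(size=(d_model, d_k)) for _ in range(3))

# Incremental decoding: one new token per step, keys/values appended to a cache.
k_cache, v_cache, steps = [], [], []
for x_t in X:
    q_t, k_t, v_t = x_t @ W_q, x_t @ W_k, x_t @ W_v    # only the new token
    k_cache.append(k_t)
    v_cache.append(v_t)
    K, V = np.stack(k_cache), np.stack(v_cache)         # (t+1, d_k)
    w = softmax(q_t @ K.T / np.sqrt(d_k))                # attend over the cache
    steps.append(w @ V)

incremental = np.stack(steps)
full = causal_attention(X, W_q, W_k, W_v)
print(np.abs(incremental - full).max())
```

Each step does work proportional to the current cache length rather than its square, which is the saving shown in the table.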

For a 70-billion-parameter model at 8k context, the KV-cache alone can reach tens of gigabytes, which is why production systems use techniques like multi-query attention (one K/V head shared across all Q heads), grouped-query attention (Ainslie et al., 2023), and paged attention (introduced by vLLM) to keep the cache tractable.
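A back-of-the-envelope estimate shows where numbers of that size come from and how much sharing K/V heads helps. The model shape below (80 layers, 64 query heads of dimension 128, fp16 cache, one sequence) is a hypothetical configuration chosen for illustration, not the specification of any particular model.

```python
def kv_cache_bytes(n_layers, n_kv_heads, d_head, seq_len,
                   bytes_per_value=2, batch=1):
    """Keys plus values (factor 2) for every layer, KV head and cached token."""
    return 2 * n_layers * n_kv_heads * d_head * seq_len * bytes_per_value * batch

# Hypothetical 70B-class shape: these numbers are assumptions for illustration.
n_layers, n_q_heads, d_head, seq_len = 80, 64, 128, 8192

mha = kv_cache_bytes(n_layers, n_kv_heads=n_q_heads, d_head=d_head, seq_len=seq_len)
gqa = kv_cache_bytes(n_layers, n_kv_heads=8, d_head=d_head, seq_len=seq_len)
mqa = kv_cache_bytes(n_layers, n_kv_heads=1, d_head=d_head, seq_len=seq_len)

for name, size in [("multi-head", mha), ("grouped-query (8 KV heads)", gqa),
                   ("multi-query (1 KV head)", mqa)]:
    print(f"{name:28s} {size / 2**30:8.4f} GiB per sequence")
```

Under these assumptions the cache shrinks in direct proportion to the number of K/V heads, and the full multi-head figure scales linearly with batch size and context length.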

The LSTM classifier never needs any of this. Its "cache" is two vectors, period. The whole KV-cache engineering stack exists to give transformers the same incremental-generation property LSTMs had trivially — without giving up parallelism during training.

The Scaling Story: Why We Left LSTMs Behind

If LSTMs are so elegant, why is almost every state-of-the-art language system a transformer? The answer is one word: parallelism.

| Axis | LSTM | Transformer |
| --- | --- | --- |
| Training-time parallelism | O(T) sequential steps, cannot be parallelised across time | O(1) sequential steps, all T tokens go through the same matmul |
| Compute per token | O(d_hidden²) | O(T · d_model + d_model²) |
| Memory per token (train) | O(d_hidden) | O(T · d_model) (for attention) |
| Long-range dependencies | Decay with path length | Constant-depth reach to any token |
| Hardware utilisation (rough, indicative) | ~15–30% on A100 | ~50–70% with Flash Attention |

On a modern accelerator, wall-clock time is dominated by how well the computation matches the hardware's parallelism. A transformer block is essentially a short chain of large matmuls; an A100 or H100 is designed to do exactly those matmuls at peak throughput. An LSTM layer is a tight sequential loop, so the GPU's thousands of cores sit mostly idle waiting on a data dependency.

This is the deep reason the field migrated. Not that LSTMs cannot learn the task — for sentiment classification they still do, and well — but that once you want to scale to billions of parameters and hundreds of billions of tokens, an architecture that wastes 70% of your hardware is not tenable. Transformers were the architecture that hugged the hardware.

But LSTMs are not dead. On-device keyboard suggestion, streaming speech recognition, small-footprint sensor models — anywhere latency and energy matter more than peak quality — still use LSTMs or GRUs. And in 2023-2024 the SSM/Mamba family revived constant-state recurrence with transformer-level quality on several benchmarks. The story is not over.

Summary

In this section we built the canonical LSTM text classifier end-to-end:

  1. Derived the four-stage architecture (embedding → LSTM → pool → linear + softmax) and wrote it as a single differentiable function.
  2. Implemented the entire forward pass in pure NumPy, tracing every gate, cell update and hidden state to numerical values so nothing remains mysterious.
  3. Wrapped the same recipe in a PyTorch nn.Module, contrasted its parameter layout, and explained every argument of nn.Embedding, nn.LSTM and nn.Linear.
  4. Compared four pooling strategies and upgraded to the standard bidirectional, stacked workhorse.
  5. Exposed the fixed-state bottleneck that drove the field toward attention, and traced the line from it to multi-head attention, positional encodings, Flash Attention and the KV-cache.

In section 3 we close the loop: feeding this classifier a labelled dataset, training with cross-entropy and Adam, monitoring F1, and watching the hidden-state trajectories in the visualization above actually learn to separate positive from negative.


References

  • Hochreiter, S. & Schmidhuber, J. (1997). Long Short-Term Memory. Neural Computation 9(8), 1735–1780.
  • Bahdanau, D., Cho, K. & Bengio, Y. (2014). Neural Machine Translation by Jointly Learning to Align and Translate. ICLR 2015 / arXiv:1409.0473.
  • Vaswani, A. et al. (2017). Attention Is All You Need. NeurIPS 2017 / arXiv:1706.03762.
  • Dao, T., Fu, D. Y., Ermon, S., Rudra, A. & Ré, C. (2022). FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness. NeurIPS 2022 / arXiv:2205.14135.
  • Dao, T. (2023). FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning. arXiv:2307.08691.
  • Ainslie, J. et al. (2023). GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints. EMNLP 2023 / arXiv:2305.13245.