Chapter 9

LSTM Cell Mathematics

Bidirectional LSTM Encoder

What §3.3 Established

Section 3.3 introduced the LSTM cell as a working-memory analogue: three gates (input, forget, output) plus a candidate state. The cell-state update $c_t = f_t \odot c_{t-1} + i_t \odot g_t$ is the magic line that lets gradients flow through long sequences. We worked through a SCALAR cell to keep the math visible. This section shows the same algebra in vector form: the production-grade calculation that nn.LSTM performs internally.
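Why does that line help gradients? Hold the gates fixed, which ignores the fact that $f_t$, $i_t$ and $g_t$ themselves depend on $h_{t-1}$. Under that simplification the update is linear in $c_{t-1}$, so the direct cell-to-cell Jacobian is just the forget gate. Chaining it over $k$ steps gives a product of forget gates:

$$\frac{\partial c_t}{\partial c_{t-1}}\bigg|_{\text{direct}} = \operatorname{diag}(f_t), \qquad \frac{\partial c_t}{\partial c_{t-k}}\bigg|_{\text{direct}} = \operatorname{diag}\Bigl(\prod_{j=t-k+1}^{t} f_j\Bigr)$$

Along this path the gradient decays only as fast as the forget gates allow. With forget gates near 1 it survives many steps.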

From Scalar Cell to Vector Cell

With $x_t \in \mathbb{R}^D$ and $h_t, c_t \in \mathbb{R}^H$, every gate becomes a linear-plus-non-linearity of $(x_t, h_{t-1})$:

$$i_t = \sigma\!\bigl(W_{ix} x_t + W_{ih} h_{t-1} + b_i\bigr)$$

$$f_t = \sigma\!\bigl(W_{fx} x_t + W_{fh} h_{t-1} + b_f\bigr)$$

$$g_t = \tanh\!\bigl(W_{gx} x_t + W_{gh} h_{t-1} + b_g\bigr)$$

$$o_t = \sigma\!\bigl(W_{ox} x_t + W_{oh} h_{t-1} + b_o\bigr)$$

Then $c_t = f_t \odot c_{t-1} + i_t \odot g_t$ and $h_t = o_t \odot \tanh(c_t)$. The only difference from §3.3 is that the non-linearities and gate products are now element-wise operations on length-H vectors. The input weights $W_{ix}, W_{fx}, W_{gx}, W_{ox}$ have shape (H, D), and the hidden weights $W_{ih}, W_{fh}, W_{gh}, W_{oh}$ have shape (H, H).
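Below is a minimal NumPy sketch of the equations above. It keeps one separate $(W_{kx}, W_{kh}, b_k)$ triple per gate, unlike PyTorch's packed layout. The helper name `lstm_step_unpacked` and the dict-of-gates layout are illustrative choices, not part of any library.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def lstm_step_unpacked(x, h_prev, c_prev, W, b):
    """One LSTM step with separate per-gate weights (hypothetical helper).

    W[k] = (W_kx, W_kh) with shapes (H, D) and (H, H); b[k] has shape (H,),
    for each gate k in "ifgo".
    """
    pre = {k: W[k][0] @ x + W[k][1] @ h_prev + b[k] for k in "ifgo"}
    i, f, o = sigmoid(pre["i"]), sigmoid(pre["f"]), sigmoid(pre["o"])
    g = np.tanh(pre["g"])           # candidate uses tanh, not sigmoid
    c = f * c_prev + i * g          # element-wise on (H,) vectors
    h = o * np.tanh(c)              # element-wise on (H,) vectors
    return h, c

D, H = 4, 3
rng = np.random.default_rng(0)
W = {k: (rng.standard_normal((H, D)), rng.standard_normal((H, H))) for k in "ifgo"}
b = {k: np.zeros(H) for k in "ifgo"}
h, c = lstm_step_unpacked(rng.standard_normal(D), np.zeros(H), np.zeros(H), W, b)
print(h.shape, c.shape)  # (3,) (3,)
```

Each gate costs two matrix-vector products here, so one step makes eight. The packing trick described next removes that overhead.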

The packing trick. PyTorch stacks the four input-to-gate matrices into one tall (4H, D) matrix and the four hidden-to-gate matrices into one (4H, H) matrix. The biases are packed into (4H,) vectors the same way. Each packed matrix then needs a SINGLE matmul call, so a step takes two matmuls instead of eight. Slicing the resulting (4H,) pre-activation peels the four gates apart for their gate-specific non-linearities.
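The next sketch checks the packing claim, assuming NumPy and the i, f, g, o row order. It stacks four per-gate matrices with `np.concatenate`, then confirms that slicing the stacked pre-activation gives the same numbers as computing each gate separately.

```python
import numpy as np

D, H = 4, 3
rng = np.random.default_rng(1)
# Four separate per-gate parameters, in PyTorch's order: i, f, g, o.
W_x_parts = [rng.standard_normal((H, D)) for _ in range(4)]
W_h_parts = [rng.standard_normal((H, H)) for _ in range(4)]
b_parts = [rng.standard_normal(H) for _ in range(4)]

# Pack: stack the row blocks -> (4H, D), (4H, H), (4H,)
W_x = np.concatenate(W_x_parts, axis=0)
W_h = np.concatenate(W_h_parts, axis=0)
b = np.concatenate(b_parts)

x, h_prev = rng.standard_normal(D), rng.standard_normal(H)

# Packed: one matmul per weight matrix (2 total) instead of 8.
z = W_x @ x + W_h @ h_prev + b          # (4H,)

# Unpacked reference: one pre-activation per gate.
ref = [Wx @ x + Wh @ h_prev + bk for Wx, Wh, bk in zip(W_x_parts, W_h_parts, b_parts)]

# Slicing z by H-row blocks recovers the same per-gate pre-activations.
for k in range(4):
    assert np.allclose(z[k * H:(k + 1) * H], ref[k])
print(W_x.shape, W_h.shape, z.shape)  # (12, 4) (12, 3) (12,)
```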

Interactive: Step-Trace (Recap)

The 8-cycle pulse trace from §3.3, reproduced. The vector form runs the same per-step update as the scalar form; each gate now holds H values instead of one.


Python: Vector LSTM Cell From Scratch

The class below packs the four gates' input weights into one matrix and their hidden weights into another. All four pre-activations therefore come from one matmul per matrix, matching nn.LSTM's internal row layout. It runs on a 30-cycle, 64-D input with hidden size 256.

Vector LSTM cell with packed weight matrix
lstm_cell_vector.py
import numpy as np

Standard alias.

class LSTMCellVector:

Generalises §3.3's scalar LSTMCellMicro to vector states. Same 4-gate algebra; matrix multiplications instead of scalar ones.

def __init__(self, input_size, hidden_size, seed=0):

Constructor takes vector dimensions plus a seed for reproducible init.

EXECUTION STATE
input: input_size = D - feature dim of x_t (64 in our backbone)
input: hidden_size = H - feature dim of h_t and c_t (256)
rng = np.random.default_rng(seed)

Modern NumPy RNG (preferred over np.random.seed for new code).

H, D = hidden_size, input_size

Local aliases for readability.

self.W_x = (rng.standard_normal((4 * H, D)) / np.sqrt(D)).astype(np.float32)

ALL 4 gate-input weights packed into ONE matrix of shape (4H, D). PyTorch does this internally: one big matmul is faster than four small ones. Init with std=1/sqrt(D), a LeCun-style scaling (poor man's Xavier). Scaling before the float32 cast keeps the result float32.

EXECUTION STATE
W_x.shape = (1024, 64) for H=256, D=64
→ packing convention = Rows 0..H-1 are input gate, H..2H-1 forget, 2H..3H-1 candidate, 3H..4H-1 output. PyTorch follows this.
self.W_h = (rng.standard_normal((4 * H, H)) / np.sqrt(H)).astype(np.float32)

Hidden-to-gate weights; same packing.

EXECUTION STATE
W_h.shape = (1024, 256)
self.b = np.zeros(4 * H, dtype=np.float32)

Bias vector packing all 4 gates.

self.b[H:2*H] = 1.0

Forget bias = 1: the standard initialisation trick. PyTorch's nn.LSTMCell does NOT do this by default; it draws every weight and bias from U(-1/√H, 1/√H). In PyTorch you have to set it yourself.

self.H = H

Stash for use in step().

def step(self, x, h_prev, c_prev):

One timestep forward. Vector inputs.

EXECUTION STATE
input: x (D,) = Current timestep input vector
input: h_prev (H,) = Previous hidden state
input: c_prev (H,) = Previous cell state
sig = lambda z: 1.0 / (1.0 + np.exp(-z))

Sigmoid - element-wise on vectors.

H = self.H

Local alias.

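z = self.W_x @ x + self.W_h @ h_prev + self.b

ALL 4 GATE PRE-ACTIVATIONS in one expression: two matmuls plus two element-wise adds. NumPy does not fuse these into a single BLAS call. The saving comes from packing, which needs 2 matmuls per step instead of 8 (two per gate).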

EXECUTION STATE
z.shape = (4H,) = (1024,) - all 4 gates' pre-activations stacked
i = sig(z[:H])

Slice rows 0..H-1 of z for the input gate, then sigmoid.

EXECUTION STATE
i.shape = (H,) = (256,)
f = sig(z[H:2*H])

Forget gate (rows H..2H-1).

g = np.tanh(z[2*H:3*H])

Candidate state (rows 2H..3H-1, with tanh).

o = sig(z[3*H:4*H])

Output gate (rows 3H..4H-1).

c = f * c_prev + i * g

Cell state update. Element-wise on (H,)-shaped vectors. Same line as §3.3 but now operates on 256-D vectors.

h = o * np.tanh(c)

Hidden output - same idea.

return h, c, (i, f, g, o)

Return new states + gate values for inspection.

T, D, H = 30, 64, 256

30 cycles, 64-D input, 256-D hidden - matching the production backbone.

cell = LSTMCellVector(input_size=D, hidden_size=H, seed=0)

Instantiate.

seq = np.random.randn(T, D).astype(np.float32) * 0.5

Fake CNN-frontend output sequence.

h, c = np.zeros(H, dtype=np.float32), np.zeros(H, dtype=np.float32)

Zero-init.

for t in range(T):

Process all 30 cycles.

h, c, gates = cell.step(seq[t], h, c)

Step. Re-binding (h, c) each iteration carries the state forward.

print("h.shape :", h.shape)

Final hidden state.

EXECUTION STATE
Output = h.shape : (256,)
print("c.shape :", c.shape)

Final cell state.

EXECUTION STATE
Output = c.shape : (256,)
print("|h| :", np.linalg.norm(h).round(3))

L2 norm of the final hidden state.

EXECUTION STATE
Output (representative) = |h| : ~6.5
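print("|c| :", np.linalg.norm(c).round(3))

Cell-state norm. Unlike h, c is not squashed by tanh. Per coordinate, |c_t| ≤ |c_{t-1}| + 1, so after 30 steps each entry is bounded only by 30. That is why |c| can exceed |h|, whose entries always lie within (-1, 1).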

EXECUTION STATE
Output (representative) = |c| : ~10.2
```python
import numpy as np

# A vector LSTM cell: generalises the scalar version from §3.3.
# Inputs and states are now vectors of length H; weights are matrices.
class LSTMCellVector:
    def __init__(self, input_size: int, hidden_size: int, seed: int = 0):
        rng = np.random.default_rng(seed)
        H, D = hidden_size, input_size
        # Like PyTorch, pack all 4 gates into one weight matrix of shape (4H, ·).
        # Layout: rows 0..H-1 = i, H..2H-1 = f, 2H..3H-1 = g, 3H..4H-1 = o.
        # Scale first, then cast, so the weights stay float32.
        self.W_x = (rng.standard_normal((4 * H, D)) / np.sqrt(D)).astype(np.float32)
        self.W_h = (rng.standard_normal((4 * H, H)) / np.sqrt(H)).astype(np.float32)
        self.b   = np.zeros(4 * H, dtype=np.float32)
        self.b[H:2*H] = 1.0          # forget bias = 1, like §3.3
        self.H = H

    def step(self, x: np.ndarray, h_prev: np.ndarray, c_prev: np.ndarray):
        sig = lambda z: 1.0 / (1.0 + np.exp(-z))
        H = self.H

        # All 4 gates' pre-activations from two packed matmuls
        z = self.W_x @ x + self.W_h @ h_prev + self.b   # (4H,)

        i = sig(z[:H])               # input gate
        f = sig(z[H:2*H])            # forget gate
        g = np.tanh(z[2*H:3*H])      # candidate state
        o = sig(z[3*H:4*H])          # output gate

        c = f * c_prev + i * g
        h = o * np.tanh(c)
        return h, c, (i, f, g, o)


# ----- Run on a 30-cycle window -----
np.random.seed(0)
T, D, H = 30, 64, 256
cell = LSTMCellVector(input_size=D, hidden_size=H, seed=0)
seq  = np.random.randn(T, D).astype(np.float32) * 0.5
h, c = np.zeros(H, dtype=np.float32), np.zeros(H, dtype=np.float32)

for t in range(T):
    h, c, gates = cell.step(seq[t], h, c)

print("h.shape :", h.shape)         # (256,)
print("c.shape :", c.shape)         # (256,)
print("|h|     :", np.linalg.norm(h).round(3))
print("|c|     :", np.linalg.norm(c).round(3))
```
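The two norms behave differently because tanh bounds $h$ but not $c$. Since $f$, $i$ and $|g|$ are at most 1, each $|c_j|$ can grow by at most 1 per step. The sketch below reruns the same initialisation recipe in float64 and asserts both bounds at every step. Its weights and inputs are fresh random draws, so its maxima will not match the representative outputs above.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

T, D, H = 30, 64, 256
rng = np.random.default_rng(0)
# Same init recipe as LSTMCellVector, kept in float64 for the check.
W_x = rng.standard_normal((4 * H, D)) / np.sqrt(D)
W_h = rng.standard_normal((4 * H, H)) / np.sqrt(H)
b = np.zeros(4 * H)
b[H:2 * H] = 1.0                      # forget bias = 1

h, c = np.zeros(H), np.zeros(H)
max_h = max_c = 0.0
for t in range(1, T + 1):
    x = 0.5 * rng.standard_normal(D)  # stand-in input, like the fake CNN output
    z = W_x @ x + W_h @ h + b
    i, f = sigmoid(z[:H]), sigmoid(z[H:2 * H])
    g, o = np.tanh(z[2 * H:3 * H]), sigmoid(z[3 * H:])
    c = f * c + i * g
    h = o * np.tanh(c)
    # |c_t| <= |c_{t-1}| + 1 per coordinate, starting from 0, so |c_t| <= t.
    assert np.all(np.abs(c) <= t)
    # o and |tanh(c)| are at most 1, so every |h_j| is at most 1.
    assert np.all(np.abs(h) <= 1.0)
    max_h = max(max_h, float(np.abs(h).max()))
    max_c = max(max_c, float(np.abs(c).max()))

print(f"largest |h_j| seen: {max_h:.3f}   largest |c_j| seen: {max_c:.3f}")
```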

PyTorch: nn.LSTMCell vs nn.LSTM

Two APIs do almost the same thing. nn.LSTMCell runs ONE timestep per call and is useful for custom recurrent loops. nn.LSTM processes a whole sequence in one call and, on a CUDA GPU, dispatches to cuDNN's fused kernel. Prefer it for ordinary sequence encoding.

Single-step nn.LSTMCell next to sequence-level nn.LSTM
lstm_cell_vs_lstm.py
import torch

Top-level PyTorch.

import torch.nn as nn

Layers.

torch.manual_seed(0)

Determinism.

cell = nn.LSTMCell(input_size=64, hidden_size=256)

Single-timestep API. Useful when you want explicit control over the recurrence (e.g., custom inputs at each step). For sequence processing prefer nn.LSTM.

x_t = torch.randn(2, 64)

ONE timestep's input - shape (B, D).

h = torch.zeros(2, 256)

Zero-init hidden state.

c = torch.zeros(2, 256)

Zero-init cell state.

h, c = cell(x_t, (h, c))

ONE LSTM step. Returns updated (h, c) pair.

print("after cell h.shape:", tuple(h.shape))

Verify shape.

EXECUTION STATE
Output = after cell h.shape: (2, 256)
print("after cell c.shape:", tuple(c.shape))

Same for c.

seq = torch.randn(2, 30, 64)

Full sequence: (B, T, D).

lstm = nn.LSTM(input_size=64, hidden_size=256, num_layers=1, batch_first=True)

Sequence-level API. It computes the same recurrence as a loop over nn.LSTMCell but runs the loop in native code (a fused cuDNN kernel on CUDA GPUs). Typically much faster than a Python loop over LSTMCell.

out, (h_n, c_n) = lstm(seq)

One call processes all 30 timesteps.

EXECUTION STATE
out.shape = (2, 30, 256) - per-timestep hidden outputs
h_n.shape = (num_layers, B, hidden) = (1, 2, 256)
print("\nseq :", tuple(seq.shape))

Input shape.

EXECUTION STATE
Output = seq : (2, 30, 64)
print("LSTM out :", tuple(out.shape))

Output sequence.

EXECUTION STATE
Output = LSTM out : (2, 30, 256)
print("LSTM h_n :", tuple(h_n.shape))

Final hidden state.

EXECUTION STATE
Output = LSTM h_n : (1, 2, 256)
```python
import torch
import torch.nn as nn

# nn.LSTMCell: ONE timestep at a time
torch.manual_seed(0)
cell = nn.LSTMCell(input_size=64, hidden_size=256)

x_t   = torch.randn(2, 64)            # (B, D) at one timestep
h     = torch.zeros(2, 256)
c     = torch.zeros(2, 256)
h, c  = cell(x_t, (h, c))
print("after cell h.shape:", tuple(h.shape))    # (2, 256)
print("after cell c.shape:", tuple(c.shape))    # (2, 256)


# nn.LSTM: WHOLE sequence in one call
seq   = torch.randn(2, 30, 64)        # (B, T, D)
lstm  = nn.LSTM(input_size=64, hidden_size=256, num_layers=1, batch_first=True)
out, (h_n, c_n) = lstm(seq)

print("\nseq                :", tuple(seq.shape))
print("LSTM out           :", tuple(out.shape))     # (2, 30, 256)
print("LSTM h_n           :", tuple(h_n.shape))     # (1, 2, 256)
```
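Mathematically, nn.LSTM computes the per-step cell unrolled over T; this sketch ignores its two-bias layout and default initialisation. Below is a NumPy sketch for batch_first input that reuses the packed i, f, g, o layout from the NumPy class; `lstm_sequence` is a hypothetical helper name. In PyTorch, h_n carries an extra leading (num_layers × num_directions) axis, which is why the listing prints (1, 2, 256). For a single-layer unidirectional LSTM, out[:, -1] equals h_n[0].

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def lstm_sequence(seq, W_x, W_h, b):
    """Unroll a packed LSTM cell over a batch_first sequence (hypothetical helper).

    seq: (B, T, D). Returns out (B, T, H) and the final (h_n, c_n), each (B, H).
    """
    B, T, _ = seq.shape
    H = W_h.shape[1]
    h, c = np.zeros((B, H)), np.zeros((B, H))
    outs = []
    for t in range(T):
        # Batched form: rows are samples, so multiply by the transposed weights.
        z = seq[:, t] @ W_x.T + h @ W_h.T + b            # (B, 4H)
        i, f = sigmoid(z[:, :H]), sigmoid(z[:, H:2 * H])
        g, o = np.tanh(z[:, 2 * H:3 * H]), sigmoid(z[:, 3 * H:])
        c = f * c + i * g
        h = o * np.tanh(c)
        outs.append(h)                                   # per-timestep output
    return np.stack(outs, axis=1), (h, c)

B, T, D, H = 2, 30, 64, 256
rng = np.random.default_rng(0)
W_x = rng.standard_normal((4 * H, D)) / np.sqrt(D)
W_h = rng.standard_normal((4 * H, H)) / np.sqrt(H)
b = np.zeros(4 * H)
out, (h_n, c_n) = lstm_sequence(rng.standard_normal((B, T, D)), W_x, W_h, b)
print(out.shape, h_n.shape)           # (2, 30, 256) (2, 256)
print(np.allclose(out[:, -1], h_n))   # True
```

The Python loop is exactly what nn.LSTM moves into native code. The math is the same; only the loop overhead and kernel fusion differ.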

Parameter Accounting

Per LSTM layer and direction, with hidden size $H$, input dim $D$ and one bias vector per gate:

$$P = 4H(D + H + 1).$$

For our backbone with $D = 64$, $H = 256$: $4 \cdot 256 \cdot (64 + 256 + 1) = 4 \cdot 256 \cdot 321 = 328{,}704$ parameters per direction for the first layer. Two directions × two layers would naively give ~1.3M. However, the second layer's input dim is 2H = 512, because it receives the bidirectional concat from layer 1, so the total grows to about 2.23M (2,232,320). PyTorch stores two bias vectors per layer and direction (bias_ih and bias_hh), so its count follows $4H(D + H + 2)$ and comes to 2,236,416. Section 9.4 walks through the exact numbers.
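Here is the same accounting as a few lines of Python arithmetic. `lstm_params` and `bilstm_params` are illustrative helpers, and `n_bias=2` models PyTorch's separate bias_ih and bias_hh vectors. If PyTorch is installed, `sum(p.numel() for p in nn.LSTM(64, 256, num_layers=2, bidirectional=True).parameters())` should match the `n_bias=2` figure.

```python
def lstm_params(D, H, n_bias=1):
    """Parameters of one LSTM layer in one direction.

    n_bias=1 matches 4H(D + H + 1); n_bias=2 matches PyTorch's two bias vectors.
    """
    return 4 * H * (D + H + n_bias)

def bilstm_params(D, H, num_layers, n_bias=1):
    """Total parameters of a stacked bidirectional LSTM with equal hidden sizes."""
    total, d_in = 0, D
    for _ in range(num_layers):
        total += 2 * lstm_params(d_in, H, n_bias)  # forward + backward direction
        d_in = 2 * H                               # next layer sees the concat
    return total

print(lstm_params(64, 256))                    # 328704
print(bilstm_params(64, 256, 2))               # 2232320
print(bilstm_params(64, 256, 2, n_bias=2))     # 2236416
```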

Two Math Pitfalls

Pitfall 1: Wrong gate ordering. Suppose you slice the packed (4H,) pre-activation in the wrong order (PyTorch uses i, f, g, o; some old papers use i, g, f, o). You get wildly wrong outputs with no error, because every slice has the same (H,) shape. The forget bias=1 trick also lands in the wrong block.
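When porting packed weights between layouts, permute the four H-row blocks once rather than re-slicing at run time. The sketch below uses a hypothetical helper, `reorder_gate_blocks`. It works on packed biases (4H,) and packed weight matrices (4H, ·) alike because it only touches axis 0.

```python
import numpy as np

def reorder_gate_blocks(packed, src_order, dst_order):
    """Permute the four H-row gate blocks of a packed (4H, ...) array.

    Hypothetical helper for illustration; gates are named by single letters.
    """
    H = packed.shape[0] // 4
    blocks = {gate: packed[k * H:(k + 1) * H] for k, gate in enumerate(src_order)}
    return np.concatenate([blocks[gate] for gate in dst_order], axis=0)

H = 2
# A bias packed in "i, g, f, o" order, with forget bias = 1 in the f block.
b_igfo = np.array([0, 0, 0, 0, 1, 1, 0, 0], dtype=float)
b_ifgo = reorder_gate_blocks(b_igfo, "igfo", "ifgo")
print(b_ifgo)   # [0. 0. 1. 1. 0. 0. 0. 0.]
```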
Pitfall 2: tanh vs sigmoid mix-up. The candidate state $g_t$ uses tanh; the three gates use sigmoid. Mixing them up (e.g., applying sigmoid to $g_t$) makes the cell unable to write negative values into c, breaking the model without crashing.
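Here is a small NumPy demonstration of Pitfall 2. It uses random pre-activations as a stand-in for $W_x x + W_h h + b$, a simplification made only for this demo. With sigmoid on $g$, every term added to $c$ is positive and $f \odot c$ preserves $c$'s sign. A cell that starts at zero therefore never goes negative.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

rng = np.random.default_rng(0)
T, H = 50, 8
c_ok, c_bad = np.zeros(H), np.zeros(H)
for _ in range(T):
    # Random pre-activations stand in for W_x x + W_h h + b (demo assumption).
    zi, zf, zg = rng.standard_normal((3, H)) * 2.0
    i, f = sigmoid(zi), sigmoid(zf)
    c_ok = f * c_ok + i * np.tanh(zg)    # correct: g in (-1, 1)
    c_bad = f * c_bad + i * sigmoid(zg)  # bug: g in (0, 1)

print("correct cell, any negative c:", bool((c_ok < 0).any()))
print("sigmoid-g cell, min c       :", c_bad.min())  # never below 0
```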
The point. The LSTM cell is six lines of vector math. Production code packs the gate weights into one (4H, D) input matrix and one (4H, H) hidden matrix, so the four gate pre-activations need one matmul per matrix. Everything else (bidirectional, multi-layer) wraps around this six-line core.

Takeaway

  • Vector LSTM cell = scalar LSTM cell, vectorised. Same algebra; H-dim element-wise ops.
  • Per-layer, per-direction parameter count is 4H(D+H+1), or 4H(D+H+2) with PyTorch's two bias vectors. For our backbone's first layer this is ~328k per direction.
  • nn.LSTM >> loop over nn.LSTMCell. On GPU, the fused cuDNN kernel is typically 5-10× faster.
  • Pack the 4 gate weights into one matrix per input (x and h). One matmul each; slice to recover individual gates.