Section 3.3 introduced the LSTM cell as a working-memory analogue: three gates (input, forget, output) plus a candidate state. The cell-state update $c_t = f_t \odot c_{t-1} + i_t \odot g_t$ is the key line, because it lets gradients flow across long sequences. We worked through a SCALAR cell to keep the math visible. This section writes the same algebra in vector form, which is the calculation nn.LSTM performs internally.
From Scalar Cell to Vector Cell
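With $x_t \in \mathbb{R}^D$ and $h_t, c_t \in \mathbb{R}^H$, every gate becomes a linear map of $(x_t, h_{t-1})$ followed by a non-linearity:
$$
\begin{aligned}
i_t &= \sigma(W_{ix} x_t + W_{ih} h_{t-1} + b_i) \\
f_t &= \sigma(W_{fx} x_t + W_{fh} h_{t-1} + b_f) \\
g_t &= \tanh(W_{gx} x_t + W_{gh} h_{t-1} + b_g) \\
o_t &= \sigma(W_{ox} x_t + W_{oh} h_{t-1} + b_o)
\end{aligned}
$$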
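Then $c_t = f_t \odot c_{t-1} + i_t \odot g_t$ and $h_t = o_t \odot \tanh(c_t)$. The only difference from §3.3 is that the gate products are now elementwise operations on length-H vectors, and the weights are matrices: each $W_{\cdot x}$ has shape (H, D) and each $W_{\cdot h}$ has shape (H, H).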
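Before any packing, the equations above translate almost line for line into NumPy. The sketch below keeps all eight weight matrices separate so the shapes stay visible; the dictionary layout, the name `lstm_step_unpacked`, and the tiny sizes (D = 4, H = 3) are illustrative choices of ours, not part of any library.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def lstm_step_unpacked(x, h_prev, c_prev, W, b):
    """One vector LSTM step with separate per-gate weights (illustrative layout).

    For each gate k in "ifgo": W[k + "x"] is (H, D), W[k + "h"] is (H, H), b[k] is (H,).
    """
    i = sigmoid(W["ix"] @ x + W["ih"] @ h_prev + b["i"])
    f = sigmoid(W["fx"] @ x + W["fh"] @ h_prev + b["f"])
    g = np.tanh(W["gx"] @ x + W["gh"] @ h_prev + b["g"])
    o = sigmoid(W["ox"] @ x + W["oh"] @ h_prev + b["o"])
    c = f * c_prev + i * g      # elementwise: keep part of the old state, write new content
    h = o * np.tanh(c)          # elementwise: expose a squashed view of the cell state
    return h, c

rng = np.random.default_rng(0)
D, H = 4, 3
W = {}
for k in "ifgo":
    W[k + "x"] = rng.standard_normal((H, D)) / np.sqrt(D)
    W[k + "h"] = rng.standard_normal((H, H)) / np.sqrt(H)
b = {k: np.zeros(H) for k in "ifgo"}

h, c = lstm_step_unpacked(rng.standard_normal(D), np.zeros(H), np.zeros(H), W, b)
print(h.shape, c.shape)  # (3,) (3,)
```

The matmuls mix coordinates; everything after them (the non-linearities and the two update lines) acts on each of the H coordinates independently.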
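The packing trick. PyTorch stacks the four input-to-hidden matrices into one tall (4H, D) matrix (weight_ih_l0 in nn.LSTM) and the four hidden-to-hidden matrices into one (4H, H) matrix (weight_hh_l0), in the gate order i, f, g, o. All four gate pre-activations then come out of one matmul per packed matrix instead of four small ones, and slicing peels them apart for the gate-specific non-linearities. PyTorch also keeps two (4H,) bias vectors, bias_ih and bias_hh; the NumPy code in this section folds them into a single b.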
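A quick way to convince yourself that packing changes nothing mathematically: stack four per-gate matrices row-wise, do one matmul per packed matrix, and compare each slice with the separate per-gate result. This is a NumPy sketch with small made-up sizes; biases are omitted because they would only add one more (4H,) vector.

```python
import numpy as np

rng = np.random.default_rng(1)
D, H = 5, 4

# Four per-gate input matrices (H, D) and recurrent matrices (H, H), in order i, f, g, o.
Wx = [rng.standard_normal((H, D)) for _ in range(4)]
Wh = [rng.standard_normal((H, H)) for _ in range(4)]

# Pack by stacking rows: gate k occupies rows k*H .. (k+1)*H - 1.
W_x = np.vstack(Wx)   # (4H, D)
W_h = np.vstack(Wh)   # (4H, H)

x, h_prev = rng.standard_normal(D), rng.standard_normal(H)
z = W_x @ x + W_h @ h_prev   # all four pre-activations at once, shape (4H,)

# Slicing the packed result recovers each gate's separate pre-activation.
for k in range(4):
    separate = Wx[k] @ x + Wh[k] @ h_prev
    assert np.allclose(z[k * H:(k + 1) * H], separate)

print(W_x.shape, W_h.shape, z.shape)  # (16, 5) (16, 4) (16,)
```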
Interactive: Step-Trace (Recap)
The 8-cycle pulse trace from §3.3, reproduced. The vector cell applies, to each of its H coordinates, the same gate arithmetic that the scalar cell applied to its single unit.
Python: Vector LSTM Cell From Scratch
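The class below packs the four gates' input weights into one (4H, D) matrix and their recurrent weights into one (4H, H) matrix, so each step computes all four gate pre-activations with one matmul per packed matrix (matching nn.LSTM's internal layout). We run it on a 30-cycle sequence of 64-D inputs with hidden size 256.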
Vector LSTM cell with packed weight matrices
lstm_cell_vector.py
```python
import numpy as np


# A vector LSTM cell: generalises the scalar version from §3.3.
# Inputs and states are now vectors of length H; weights are matrices.
class LSTMCellVector:
    def __init__(self, input_size: int, hidden_size: int, seed: int = 0):
        rng = np.random.default_rng(seed)
        H, D = hidden_size, input_size

        # PyTorch packs all 4 gates into one weight matrix of shape (4H, ·).
        # Layout: rows 0..H-1 = i, H..2H-1 = f, 2H..3H-1 = g, 3H..4H-1 = o.
        self.W_x = rng.standard_normal((4 * H, D)).astype(np.float32) / np.sqrt(D)
        self.W_h = rng.standard_normal((4 * H, H)).astype(np.float32) / np.sqrt(H)
        self.b = np.zeros(4 * H, dtype=np.float32)
        self.b[H:2 * H] = 1.0  # forget bias = 1, like §3.3
        self.H = H

    def step(self, x: np.ndarray, h_prev: np.ndarray, c_prev: np.ndarray):
        sig = lambda z: 1.0 / (1.0 + np.exp(-z))
        H = self.H

        # All 4 gate pre-activations: one matmul per packed matrix
        z = self.W_x @ x + self.W_h @ h_prev + self.b  # (4H,)

        i = sig(z[:H])
        f = sig(z[H:2 * H])
        g = np.tanh(z[2 * H:3 * H])
        o = sig(z[3 * H:4 * H])

        c = f * c_prev + i * g
        h = o * np.tanh(c)
        return h, c, (i, f, g, o)


# ----- Run on a 30-cycle window -----
np.random.seed(0)
T, D, H = 30, 64, 256
cell = LSTMCellVector(input_size=D, hidden_size=H, seed=0)
seq = np.random.randn(T, D).astype(np.float32) * 0.5
h, c = np.zeros(H, dtype=np.float32), np.zeros(H, dtype=np.float32)  # zero-init

for t in range(T):
    h, c, gates = cell.step(seq[t], h, c)  # re-binding (h, c) carries the state forward

print("h.shape :", h.shape)  # (256,)
print("c.shape :", c.shape)  # (256,)
print("|h| :", np.linalg.norm(h).round(3))
print("|c| :", np.linalg.norm(c).round(3))
```

Notes on the listing. LSTMCellVector generalises §3.3's scalar LSTMCellMicro to vector states: the same 4-gate algebra, with matrix multiplications instead of scalar products. All four gates' input weights are packed into ONE matrix of shape (4H, D), and the recurrent weights into one of shape (4H, H), as PyTorch does internally; one large matmul is typically faster than four small ones. The weights are initialised with std 1/sqrt(D) (1/sqrt(H) for W_h), a poor man's Xavier.

Output: h.shape : (256,) and c.shape : (256,). The two norms are representative and depend on the random draws: |h| ≈ 6.5 for the final hidden state and |c| ≈ 10.2 for the cell state. Unlike h, whose entries tanh squashes into (-1, 1), c is stored without a squashing non-linearity, so its entries can exceed 1 in magnitude.
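Both PyTorch APIs in the next part take a leading batch dimension. A minimal NumPy sketch of the same packed step for a batch of B row vectors follows; `lstm_step_batched` is an illustrative name, and the `X @ W_x.T` form is our assumption that inputs are stored row-major as (B, D), which is also the convention nn.Linear uses.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def lstm_step_batched(X, Hprev, Cprev, W_x, W_h, b):
    """One LSTM step for a batch (illustrative sketch).

    X: (B, D), Hprev/Cprev: (B, H), W_x: (4H, D), W_h: (4H, H), b: (4H,).
    Gate layout along the last axis: i, f, g, o (same as LSTMCellVector).
    """
    H = Hprev.shape[1]
    Z = X @ W_x.T + Hprev @ W_h.T + b   # (B, 4H); b broadcasts over the batch
    i = sigmoid(Z[:, :H])
    f = sigmoid(Z[:, H:2 * H])
    g = np.tanh(Z[:, 2 * H:3 * H])
    o = sigmoid(Z[:, 3 * H:])
    C = f * Cprev + i * g
    Hnew = o * np.tanh(C)
    return Hnew, C

rng = np.random.default_rng(0)
B, D, H = 2, 64, 256
W_x = rng.standard_normal((4 * H, D)) / np.sqrt(D)
W_h = rng.standard_normal((4 * H, H)) / np.sqrt(H)
b = np.zeros(4 * H)
b[H:2 * H] = 1.0                        # forget bias = 1
X = rng.standard_normal((B, D))
H1, C1 = lstm_step_batched(X, np.zeros((B, H)), np.zeros((B, H)), W_x, W_h, b)

# Each row of the batched result matches running the first sample on its own.
h0, c0 = lstm_step_batched(X[:1], np.zeros((1, H)), np.zeros((1, H)), W_x, W_h, b)
print(H1.shape, np.allclose(H1[0], h0[0]))  # (2, 256) True
```

Batching changes only the layout: the gate algebra per sample is identical to the vector cell.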
PyTorch: nn.LSTMCell vs nn.LSTM
Two APIs do almost the same thing. nn.LSTMCell does ONE timestep per call and is useful for custom recurrent loops. nn.LSTM processes a whole sequence in one call and, on a CUDA GPU, can dispatch to cuDNN's fused kernel, so prefer it for ordinary whole-sequence work.
Single-step nn.LSTMCell next to sequence-level nn.LSTM
Single-timestep API. Useful when you want explicit control over the recurrence (e.g., custom inputs at each step). For sequence processing prefer nn.LSTM.
```python
import torch
import torch.nn as nn

# nn.LSTMCell: ONE timestep at a time
torch.manual_seed(0)
cell = nn.LSTMCell(input_size=64, hidden_size=256)

x_t = torch.randn(2, 64)   # (B, D) at one timestep
h = torch.zeros(2, 256)
c = torch.zeros(2, 256)
h, c = cell(x_t, (h, c))
print("after cell h.shape:", tuple(h.shape))  # (2, 256)
print("after cell c.shape:", tuple(c.shape))  # (2, 256)


# nn.LSTM: WHOLE sequence in one call
seq = torch.randn(2, 30, 64)   # (B, T, D)
lstm = nn.LSTM(input_size=64, hidden_size=256, num_layers=1, batch_first=True)
out, (h_n, c_n) = lstm(seq)

print("\nseq :", tuple(seq.shape))
print("LSTM out :", tuple(out.shape))  # (2, 30, 256)
print("LSTM h_n :", tuple(h_n.shape))  # (1, 2, 256)
```
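For a single-layer, unidirectional LSTM, nn.LSTM's `out` is the stack of per-step hidden states and `h_n` is the last of them, with a leading num_layers axis. The NumPy sketch below unrolls the packed cell to reproduce those shapes; `lstm_sequence` is a hypothetical helper, it assumes zero initial states (what nn.LSTM uses when no (h_0, c_0) is passed) and a single folded bias, and it is not PyTorch's actual implementation.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def lstm_sequence(seq, W_x, W_h, b):
    """Unroll a packed LSTM cell over seq of shape (B, T, D) (illustrative sketch).

    Returns out (B, T, H) and (h_n, c_n), each (1, B, H), mirroring the shapes
    nn.LSTM(batch_first=True, num_layers=1) reports.
    """
    B, T, _ = seq.shape
    H = W_h.shape[1]
    h = np.zeros((B, H))
    c = np.zeros((B, H))
    outputs = []
    for t in range(T):
        z = seq[:, t] @ W_x.T + h @ W_h.T + b   # (B, 4H), gate order i, f, g, o
        i = sigmoid(z[:, :H])
        f = sigmoid(z[:, H:2 * H])
        g = np.tanh(z[:, 2 * H:3 * H])
        o = sigmoid(z[:, 3 * H:])
        c = f * c + i * g
        h = o * np.tanh(c)
        outputs.append(h)
    out = np.stack(outputs, axis=1)             # (B, T, H)
    return out, (h[None], c[None])              # add the num_layers axis

rng = np.random.default_rng(0)
B, T, D, H = 2, 30, 64, 256
W_x = rng.standard_normal((4 * H, D)) / np.sqrt(D)
W_h = rng.standard_normal((4 * H, H)) / np.sqrt(H)
b = np.zeros(4 * H)
out, (h_n, c_n) = lstm_sequence(rng.standard_normal((B, T, D)), W_x, W_h, b)
print(out.shape, h_n.shape)                # (2, 30, 256) (1, 2, 256)
print(np.array_equal(out[:, -1], h_n[0]))  # True: the last output step is the final hidden state
```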
Parameter Accounting
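Per LSTM layer and direction, with hidden size H and input dim D, counting one bias vector per gate as in the NumPy class above:
$$P = 4H(D + H + 1).$$
For our backbone, D = 64 and H = 256: $4 \cdot 256 \cdot (64 + 256 + 1) = 4 \cdot 256 \cdot 321 = 328{,}704$ parameters per direction per layer. Naively, two directions × two layers ≈ 1.3M, but the second layer's input dim is 2H = 512 (it receives the bidirectional concat from layer 1), which raises each layer-2 direction to $4 \cdot 256 \cdot (512 + 256 + 1) = 787{,}456$ and the total to about 2.23M. PyTorch stores two (4H,) bias vectors per layer and direction (bias_ih and bias_hh), so its count is $4H(D + H + 2)$ per layer and direction, 4,096 more in total here. Section 9.4 walks through the exact numbers.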
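The same accounting as a small function, so the arithmetic can be rerun for other sizes. `lstm_param_count` is an illustrative helper; `n_bias=1` follows the formula above, and `n_bias=2` is meant to mirror PyTorch's separate bias_ih and bias_hh vectors.

```python
def lstm_param_count(input_size, hidden_size, num_layers=1, bidirectional=False, n_bias=1):
    """Parameter count of a stacked LSTM (illustrative helper).

    Per layer and direction: 4H * (D_in + H + n_bias), where D_in is input_size
    for the first layer and num_directions * H for every deeper layer.
    """
    H = hidden_size
    dirs = 2 if bidirectional else 1
    total = 0
    for layer in range(num_layers):
        d_in = input_size if layer == 0 else dirs * H
        total += dirs * 4 * H * (d_in + H + n_bias)
    return total

print(lstm_param_count(64, 256))                                              # 328704
print(lstm_param_count(64, 256, num_layers=2, bidirectional=True))            # 2232320
print(lstm_param_count(64, 256, num_layers=2, bidirectional=True, n_bias=2))  # 2236416
```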
Two Math Pitfalls
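Pitfall 1: Wrong gate ordering. If you slice a packed (4H,) pre-activation in a different order from the one its weights were packed in (PyTorch uses i, f, g, o; some old papers use i, g, f, o), for example when porting pretrained weights, nothing crashes because the shapes still match, but each gate receives another gate's pre-activation and the outputs are wildly wrong. Even when training from scratch, a forget bias of 1 written at the wrong indices initialises the wrong gate.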
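A NumPy sketch of where each gate lives in a packed tensor, and of converting a block packed as i, g, f, o into PyTorch's i, f, g, o order. The helper `reorder_ifgo_from_igfo` is hypothetical, and we assume the source layout is exactly i, g, f, o; check the actual layout of whatever weights you are porting.

```python
import numpy as np

H = 3
gates = ["i", "f", "g", "o"]   # PyTorch's packing order

# Row slice for each gate in a packed (4H, ...) tensor.
slices = {name: slice(k * H, (k + 1) * H) for k, name in enumerate(gates)}
print(slices["f"])             # slice(3, 6, None)

# Forget-bias init: the value 1 must land in the f block, rows H..2H-1.
b = np.zeros(4 * H)
b[slices["f"]] = 1.0

def reorder_ifgo_from_igfo(W, H):
    """Hypothetical helper: convert a (4H, ...) block packed as i, g, f, o into i, f, g, o."""
    i, g, f, o = (W[k * H:(k + 1) * H] for k in range(4))
    return np.concatenate([i, f, g, o], axis=0)

# A bias packed as i, g, f, o with forget bias = 1 has its ones in rows 2H..3H-1 ...
b_igfo = np.zeros(4 * H)
b_igfo[2 * H:3 * H] = 1.0
# ... and only after reordering does it line up with the i, f, g, o layout.
print(np.array_equal(reorder_ifgo_from_igfo(b_igfo, H), b))   # True
```

The same reordering applies row-wise to (4H, D) and (4H, H) weight matrices, since the helper only slices and concatenates along axis 0.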
Pitfall 2: tanh vs sigmoid mix-up. The candidate state $g_t$ uses tanh; the three gates use sigmoid. Mixing them up (e.g., applying sigmoid to $g_t$) restricts $g_t$ to (0, 1), so $i_t \odot g_t$ is always positive and, starting from $c_0 = 0$, the cell can never write a negative value into c. The model degrades without crashing.
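The effect is easy to see numerically. This NumPy sketch runs the same random packed cell twice from $c_0 = 0$, once with the correct tanh candidate and once with the buggy sigmoid candidate; `run_cell` and its sizes are illustrative, and both runs share the seed, so only the candidate non-linearity differs.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def run_cell(candidate_fn, T=30, D=16, H=64, seed=0):
    """Unroll a random packed LSTM cell from c = 0, using candidate_fn for g_t."""
    rng = np.random.default_rng(seed)
    W_x = rng.standard_normal((4 * H, D)) / np.sqrt(D)
    W_h = rng.standard_normal((4 * H, H)) / np.sqrt(H)
    h, c = np.zeros(H), np.zeros(H)
    for _ in range(T):
        z = W_x @ rng.standard_normal(D) + W_h @ h   # gate order i, f, g, o
        i, f, o = sigmoid(z[:H]), sigmoid(z[H:2 * H]), sigmoid(z[3 * H:])
        g = candidate_fn(z[2 * H:3 * H])
        c = f * c + i * g
        h = o * np.tanh(c)
    return c

c_ok = run_cell(np.tanh)    # correct: g in (-1, 1)
c_bug = run_cell(sigmoid)   # bug: g in (0, 1), so i * g > 0 at every step
print((c_ok < 0).any(), (c_bug > 0).all())   # True True
```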
The point. The LSTM cell is six lines of vector math. Production code packs the gate weights into one (4H, D) and one (4H, H) matrix and computes all four gate pre-activations with one matmul per matrix. Everything else (bidirectional, multi-layer) is a wrapper around this six-line core.