Chapter 3

Recurrent Networks & LSTM Cells

Mathematical Preliminaries

Working Memory: a Running Summary

When you read this sentence, you maintain a running summary of the words that came before. By the time you reach the period you have compressed a couple of dozen tokens into a small mental state — enough to disambiguate this, they, that on the next line. Your hippocampus does it for navigation; your auditory cortex does it for music; your frontal cortex does it for plans that span minutes. Recurrent neural networks are the engineering caricature of that capacity.

For RUL prediction the same idea applies to a different kind of memory: as a turbofan's sensor stream rolls past, the model needs to integrate slow degradation patterns over many cycles. A single spike at cycle 23 means little; a slow upward drift over cycles 5 to 30 means the engine is dying. Convolutions (Section 3.2) only see $K$ cycles at a time. RNNs carry a hidden state across the entire window.

The contract. An RNN consumes one timestep at a time and updates a hidden state. Read the sequence forward; at every step the hidden state summarises everything you have seen so far.

Vanilla RNN and Its Failure Mode

The simplest recurrent cell — the “vanilla RNN” — has one update equation:

$$h_t = \tanh\bigl(W_x \, x_t + W_h \, h_{t-1} + b\bigr).$$

The hidden state at step $t$ is a learned non-linear function of the current input and the previous hidden state. Repeat for $T$ steps and you have read the whole sequence. Beautiful, simple, and broken on long sequences.
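The recurrence can be sketched in a few lines of NumPy. This is a minimal illustration, not the book's implementation: the function name rnn_forward, the hidden size of 8, and the random weights are all assumptions chosen for the example.

```python
import numpy as np

def rnn_forward(xs, W_x, W_h, b, h0=None):
    """Run a vanilla RNN over xs of shape (T, F); return all hidden states, shape (T, H)."""
    T = xs.shape[0]
    H = W_h.shape[0]
    h = np.zeros(H) if h0 is None else h0
    hs = np.empty((T, H))
    for t in range(T):
        # h_t = tanh(W_x x_t + W_h h_{t-1} + b)
        h = np.tanh(W_x @ xs[t] + W_h @ h + b)
        hs[t] = h
    return hs

rng = np.random.default_rng(0)
T, F, H = 30, 17, 8                      # 30 cycles, 17 sensors; H = 8 is arbitrary
xs = rng.normal(size=(T, F))             # fake sensor window
W_x = rng.normal(scale=0.1, size=(H, F))
W_h = rng.normal(scale=0.1, size=(H, H))
b = np.zeros(H)

hs = rnn_forward(xs, W_x, W_h, b)
print(hs.shape)                          # (30, 8)
```

The same three parameters are reused at every step; only the hidden state h changes as the loop walks through the window.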

The problem is the gradient. Backpropagating through a length-$T$ RNN produces a product of $T$ Jacobians, each containing a $\tanh'$ derivative bounded above by 1. With $T = 30$, if each factor has magnitude around 0.5, the gradient at the first cycle is attenuated by roughly $0.5^{30} \approx 10^{-9}$. In practice the model cannot learn dependencies more than ~10 cycles long. This is the famous vanishing gradient problem.
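The attenuation is easy to check numerically. The sketch below assumes a scalar vanilla RNN with hypothetical weights w_h = 0.5 and w_x = 1.0 and a constant input (none of these values come from the text); it accumulates the chain-rule product for the derivative of the last hidden state with respect to the initial one.

```python
import math

def grad_through_time(w_h, w_x, xs, h0=0.0):
    """Return d h_T / d h_0 for the scalar RNN h_t = tanh(w_x * x_t + w_h * h_{t-1})."""
    h, grad = h0, 1.0
    for x in xs:
        h = math.tanh(w_x * x + w_h * h)
        # Chain rule: d h_t / d h_{t-1} = tanh'(.) * w_h = (1 - h_t^2) * w_h
        grad *= (1.0 - h * h) * w_h
    return grad

xs = [1.0] * 30                                   # constant input, T = 30
g = grad_through_time(w_h=0.5, w_x=1.0, xs=xs)
print(f"d h_30 / d h_0 = {g:.3e}")
print(f"0.5 ** 30      = {0.5 ** 30:.3e}")        # 9.313e-10
```

Because the tanh derivative never exceeds 1, every factor here is at most 0.5, so the product can only be smaller than the 0.5^30 bound.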

The fix. Hochreiter and Schmidhuber introduced the LSTM in 1997 specifically to address this. Replace the vanilla recurrence with an additive cell-state update plus three gates that let the network learn what to keep, what to forget, and what to expose. The gradient flows through the cell state almost unattenuated.

The LSTM Cell: Four Gates

An LSTM cell maintains two state vectors that propagate through time: $c_t \in \mathbb{R}^{H}$ (the cell state, a long-term memory) and $h_t \in \mathbb{R}^{H}$ (the hidden output, what downstream layers see). Three gates and one candidate regulate the update:

| Symbol | Equation | Meaning |
| --- | --- | --- |
| $i_t$ | $\sigma(W_i [h_{t-1}, x_t] + b_i)$ | Input gate: how much new info to admit |
| $f_t$ | $\sigma(W_f [h_{t-1}, x_t] + b_f)$ | Forget gate: how much old $c$ to keep |
| $g_t$ | $\tanh(W_g [h_{t-1}, x_t] + b_g)$ | Candidate update: what new content |
| $o_t$ | $\sigma(W_o [h_{t-1}, x_t] + b_o)$ | Output gate: how much $c$ to expose |

Together they update the cell and hidden state via

$$c_t = f_t \odot c_{t-1} + i_t \odot g_t, \qquad h_t = o_t \odot \tanh(c_t).$$

The element-wise $\odot$ is critical: the cell state is updated by simple addition and pointwise scaling, with no matrix multiplication on the direct path from $c_{t-1}$ to $c_t$. That is why gradients survive across long sequences: along that path the partial derivative is $\partial c_t / \partial c_{t-1} = f_t$, which is element-wise close to 1 when the forget gate is open.
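To get a feel for the direct cell-state path, the following sketch multiplies a constant forget gate over 30 steps. It assumes a scalar cell and ignores the indirect paths that run through the hidden state into the gates; cell_path_gain is a hypothetical helper name and the gate values are illustrative.

```python
import math

def cell_path_gain(forget_gates):
    """Product of per-step forget gates: the direct-path d c_T / d c_0 of a scalar cell."""
    return math.prod(forget_gates)

T = 30
for f in (0.5, 0.731, 0.99):
    # A gate near 1 keeps the product close to 1; a gate near 0.5 collapses it.
    print(f"f = {f:<5}  gain over {T} steps = {cell_path_gain([f] * T):.3e}")
```

The contrast with the vanilla RNN is that the factor is now a learned gate rather than a fixed weight times a tanh derivative, so the network can choose to hold it near 1 for the memories it needs.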

Interactive: Step Through 8 Timesteps

The trace below runs a scalar LSTM cell on the input pulse $[0, 0, 0, 1, 1, 1, 0, 0]$. Press play or step one timestep at a time. Watch the cell state ramp up when the pulse arrives, climb to about 0.94 while the pulse is active, then decay gradually back toward zero after the pulse ends: the LSTM is remembering the recent input.


The forget gate stays at $f \approx 0.731$ throughout because the input does not affect it (we hard-coded $W_{if} = 0$). 73% of the cell state survives each step, which gives the post-pulse decay a half-life of about 2.2 steps. In a real LSTM the forget gate is learned and can range from 0 (forget everything) to nearly 1 (remember forever).
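The half-life follows from solving $f^n = 0.5$ for $n$. A short check, assuming only the forget bias of 1 used in this section:

```python
import math

f = 1.0 / (1.0 + math.exp(-1.0))            # sigmoid(b_f) with b_f = 1
half_life = math.log(0.5) / math.log(f)     # n such that f ** n == 0.5
print(f"f = {f:.3f}, half-life = {half_life:.2f} steps")
# f = 0.731, half-life = 2.21 steps
```

A forget gate of 0.99 would stretch the same calculation to roughly 69 steps, which is why a learned gate close to 1 behaves like long-term memory.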

Interactive: The Unrolled Diagram

Another way to see an RNN is to unroll it across time — draw one cell per timestep, with arrows showing how the hidden state flows from left to right. The animation below unrolls a generic recurrent cell over six timesteps; the same picture applies to any RNN variant including LSTMs.

[Figure: RNN unrolled through time. Each hidden state $h_t$ depends on the previous state $h_{t-1}$. Legend: input ($x_t$), hidden state ($h_t$), output ($y_t$), sequential flow.]

Python: An LSTM Cell From Scratch

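Twenty-five lines of NumPy and the entire algorithm is exposed. We write a scalar LSTM cell (one input dim, one hidden dim) so the gates are visible as actual numbers. One simplification: the cell drops the recurrent weights on $h_{t-1}$ (equivalently, sets them to zero), so every gate is a function of the current input alone. Generalising to vector cells is straightforward: replace each scalar weight with a matrix, each scalar multiplication with a matmul, restore the $h_{t-1}$ terms, and the algebra of the four gates is unchanged.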

A scalar LSTM cell, 8-step pulse trace
lstm_cell_micro.py

`import numpy as np`

Need np.exp for sigmoid and np.tanh for the candidate state.

`class LSTMCellMicro:`

Scalar LSTM cell: one input dim, one hidden dim. Every weight is a single number. Generalises to a real LSTM by replacing scalar weights with matrices and multiplications with matrix multiplications; the four-gate algebra is identical.

`def __init__(self):`

Initialise four (input weight, bias) pairs, one per gate.

`self.W_ii, self.b_i = 0.5, 0.0`

Input-gate weights. Decides how much of the candidate update g to admit.

Execution state:
  • i_t = sigmoid(W_ii * x + b_i): input gate in (0, 1)

`self.W_if, self.b_f = 0.0, 1.0`

Forget-gate weights. W_if = 0 makes f insensitive to input. b_f = 1 sets f = sigmoid(1) ≈ 0.731, so 73% of the cell state survives each step.

Execution state:
  • Why bias 1? Initialising the forget bias to 1 (Jozefowicz et al. 2015) prevents premature memory erasure early in training. PyTorch's nn.LSTM does not do this by default (it draws every weight and bias from the same uniform distribution), so set it by hand if you want it; Keras does it by default via unit_forget_bias=True.

`self.W_ig, self.b_g = 0.8, 0.0`

Candidate-update weights. The candidate g_t is the new content the cell would write IF the input gate is open.

Execution state:
  • g_t = tanh(W_ig * x + b_g): candidate value in (-1, 1)
  • Example, x = 1: g = tanh(0.8) = 0.664
  • Example, x = 0: g = tanh(0) = 0

`self.W_io, self.b_o = 1.0, 0.0`

Output-gate weights. Decides how much of the cell state to expose as the hidden output.

`def step(self, x, h_prev, c_prev):`

Single timestep forward pass. Inputs: current x, previous (h, c). Outputs: new (h, c) plus the four gate values for inspection.

Execution state:
  • input x: current timestep's input
  • input h_prev: previous hidden state (accepted for interface parity; the gates of this simplified cell do not read it)
  • input c_prev: previous cell state, the long-term memory
  • returns: (h, c, gates), the new states plus a tuple of the four gate values

`sig = lambda z: 1.0 / (1.0 + np.exp(-z))`

Sigmoid as a one-line lambda. Squashes any real number into (0, 1). All three gates use it; g uses tanh.

Execution state:
  • sigmoid(0) = 0.500
  • sigmoid(1) = 0.731
  • sigmoid(-1) = 0.269

`i = sig(self.W_ii * x + self.b_i)`

Input gate. With our weights, i = sigmoid(0.5 * x). At x=0 it is 0.5; at x=1, sigmoid(0.5) = 0.622.

Execution state:
  • Example, x=0: i = 0.500
  • Example, x=1: i = 0.622

`f = sig(self.W_if * x + self.b_f)`

Forget gate. W_if = 0 makes the gate constant at sigmoid(1) = 0.731.

Execution state:
  • f at every timestep = 0.731

`g = np.tanh(self.W_ig * x + self.b_g)`

Candidate cell content. tanh keeps the candidate in (-1, 1).

Execution state:
  • Example, x=0: g = 0
  • Example, x=1: g = 0.664

`o = sig(self.W_io * x + self.b_o)`

Output gate. With our weights, o = sigmoid(x).

`c = f * c_prev + i * g`

THE CRUCIAL LINE. The cell state is updated by element-wise multiplying the previous cell state by the forget gate (how much to keep) and ADDING the input gate times the candidate (how much new info to write). No matrix multiply sits on the path from c_prev to c, which is why gradients survive across long sequences.

Execution state:
  • t=3, c_prev=0: c = 0.731 * 0 + 0.622 * 0.664 = 0.413
  • t=4, c_prev=0.413: c = 0.731 * 0.413 + 0.622 * 0.664 = 0.716
  • t=6, c_prev=0.936, x=0: c = 0.731 * 0.936 + 0.500 * 0 = 0.685 (decays gradually)
  • Memory effect: the pulse's last step is t=5, yet c is still 0.685 at t=6 and 0.500 at t=7. The LSTM remembers.

`h = o * np.tanh(c)`

Hidden output. tanh re-squashes the cell state into (-1, 1); the output gate masks it. The hidden output is what downstream layers see.

Execution state:
  • t=4, c=0.716: h = 0.731 * tanh(0.716) = 0.731 * 0.614 = 0.449

`return h, c, (i, f, g, o)`

Return the new states plus the gate values so the trace can inspect them.

`cell = LSTMCellMicro()`

Instantiate the cell.

`seq = [0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 0.0, 0.0]`

A square-pulse input: three zeros, three ones, two zeros. Tests whether the LSTM can RAMP UP its cell state when the pulse arrives and HOLD ON to it after it ends.

Execution state:
  • seq = [0, 0, 0, 1, 1, 1, 0, 0]

`h, c = 0.0, 0.0`

Zero-initialised hidden + cell state.

`for t, x in enumerate(seq):`

Loop over the 8 timesteps.

Loop trace (8 iterations; t=0..2 shown as one row):
  • t=0..2 (x=0): the candidate g is 0, so nothing is written and c stays at 0
  • t=3 (x=1): pulse arrives, c jumps to 0.413
  • t=4 (x=1): c climbs to 0.716
  • t=5 (x=1): c reaches 0.936, still short of the roughly 1.54 it would settle at under a sustained pulse
  • t=6 (x=0): pulse ends, c decays to 0.685 (kept 73%)
  • t=7 (x=0): c continues to decay to 0.500

`h, c, (i, f, g, o) = cell.step(x, h, c)`

One forward step. Re-binding h and c carries the memory forward across timesteps; that is the 'recurrent' in 'recurrent neural network'.

`print(...)`

Pretty-print the per-timestep state. Numerical values match the comments at the end of the listing.

Execution state:
  • Output t=3 (pulse onset): i=0.622 f=0.731 g=+0.664 o=0.731 c=+0.413 h=+0.286
  • Output t=5 (pulse peak): c=+0.936 h=+0.536, the highest cell state in the trace
  • Output t=7 (memory persists): c=+0.500, roughly half the pulse-peak value retained two steps later
```python
import numpy as np

# ----- A scalar LSTM cell, written from scratch -----
# input_size = hidden_size = 1, so each "matrix" is just a number.
# Simplification: recurrent weights on h_prev are omitted (treated as 0),
# so every gate depends on the current input x only.
# Generalises to nn.LSTMCell: replace scalars with matrices and add the h_prev terms.
class LSTMCellMicro:
    def __init__(self):
        self.W_ii, self.b_i = 0.5, 0.0
        self.W_if, self.b_f = 0.0, 1.0     # forget bias = 1: start "remembering"
        self.W_ig, self.b_g = 0.8, 0.0
        self.W_io, self.b_o = 1.0, 0.0

    def step(self, x: float, h_prev: float, c_prev: float):
        sig = lambda z: 1.0 / (1.0 + np.exp(-z))
        i = sig(self.W_ii * x + self.b_i)        # input gate in (0, 1)
        f = sig(self.W_if * x + self.b_f)        # forget gate in (0, 1)
        g = np.tanh(self.W_ig * x + self.b_g)    # candidate in (-1, 1)
        o = sig(self.W_io * x + self.b_o)        # output gate in (0, 1)
        c = f * c_prev + i * g                   # CELL STATE UPDATE
        h = o * np.tanh(c)                       # HIDDEN OUTPUT
        return h, c, (i, f, g, o)


# ----- Run on a square pulse -----
cell = LSTMCellMicro()
seq = [0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 0.0, 0.0]
h, c = 0.0, 0.0
for t, x in enumerate(seq):
    h, c, (i, f, g, o) = cell.step(x, h, c)
    print(f"t={t}  x={x:+.0f}  i={i:.3f}  f={f:.3f}  g={g:+.3f}  "
          f"o={o:.3f}  c={c:+.3f}  h={h:+.3f}")

# Selected output rows:
# t=3  x=+1  i=0.622  f=0.731  g=+0.664  o=0.731  c=+0.413  h=+0.286
# t=5  x=+1  i=0.622  f=0.731  g=+0.664  o=0.731  c=+0.936  h=+0.536
# t=7  x=+0  i=0.500  f=0.731  g=+0.000  o=0.500  c=+0.500  h=+0.231
```

The numbers tell the story

Read down the cell-state column: 0 → 0 → 0 → 0.413 → 0.716 → 0.936 → 0.685 → 0.500. The cell state builds up while the pulse is active and decays gracefully after it ends. That is what working memory looks like in arithmetic.
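For comparison, here is one way the vector version might look, with the gates reading the concatenation $[h_{t-1}, x_t]$ as in the gate table. It is a sketch under stated assumptions: the function name lstm_step, the single stacked weight matrix with rows ordered i, f, g, o, and the hidden size of 4 are choices made for this example, not the layout of any particular library.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def lstm_step(x, h_prev, c_prev, W, b):
    """One LSTM step. W has shape (4H, H + F) and b has shape (4H,).

    Rows are stacked in the order i, f, g, o (a layout assumed for this sketch).
    """
    H = h_prev.shape[0]
    z = W @ np.concatenate([h_prev, x]) + b    # all four pre-activations at once
    i = sigmoid(z[0:H])                        # input gate
    f = sigmoid(z[H:2 * H])                    # forget gate
    g = np.tanh(z[2 * H:3 * H])                # candidate
    o = sigmoid(z[3 * H:4 * H])                # output gate
    c = f * c_prev + i * g                     # additive cell-state update
    h = o * np.tanh(c)
    return h, c

rng = np.random.default_rng(0)
F, H, T = 17, 4, 30                            # H = 4 is arbitrary
W = rng.normal(scale=0.1, size=(4 * H, H + F))
b = np.zeros(4 * H)
b[H:2 * H] = 1.0                               # forget-gate bias of 1, as in the scalar cell

h, c = np.zeros(H), np.zeros(H)
for x in rng.normal(size=(T, F)):              # fake 30-cycle window
    h, c = lstm_step(x, h, c, W, b)
print(h.shape, c.shape)                        # (4,) (4,)

# Sanity check: with H = F = 1, the scalar cell's weights, and zero recurrent
# weights, one step at x = 1 reproduces the t=3 row of the pulse trace.
W1 = np.array([[0.0, 0.5], [0.0, 0.0], [0.0, 0.8], [0.0, 1.0]])  # columns: [h_prev, x]
b1 = np.array([0.0, 1.0, 0.0, 0.0])
h1, c1 = lstm_step(np.array([1.0]), np.zeros(1), np.zeros(1), W1, b1)
print(round(float(c1[0]), 3), round(float(h1[0]), 3))            # 0.413 0.286
```

Stacking the four gate blocks into one matrix turns four small matmuls into one larger one per step; batched implementations commonly use the same idea.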

PyTorch: nn.LSTM in Six Lines

Production code never writes the cell from scratch. PyTorch's nn.LSTM wraps a CUDA-optimised batched implementation with multi-layer and bidirectional support built in.

Bidirectional, two-layer nn.LSTM with batch_first=True
nn_lstm.py

`import torch`

Top-level PyTorch.

`import torch.nn as nn`

nn.LSTM lives here.

`torch.manual_seed(0)`

Deterministic weight initialisation.

`B, T, F = 2, 30, 17`

Same (batch, time, features) shape convention from §3.1 and §3.2.

`rnn = nn.LSTM(...)`

PyTorch's batched, multi-layer, bidirectional LSTM. It dispatches to cuDNN kernels when run on a CUDA device, which is usually much faster than looping over a hand-written cell.

`input_size=17`

Per-cycle input dim. 17 sensors per cycle on C-MAPSS; matches our F.

Execution state:
  • Relation to §3.2: if you stack a Conv1D BEFORE the LSTM with out_channels=64, then input_size=64 here.

`hidden_size=256`

Hidden / cell state dim. Paper uses 256.

Execution state:
  • hidden_size = 256

`num_layers=2`

Stack two LSTM layers. Output of layer 1 (post bidirectional concat) becomes input of layer 2.

`bidirectional=True`

Run two LSTMs in parallel, one forward and one backward, and concatenate their hidden outputs at every cycle. Output dim doubles to 2 * 256 = 512.

Execution state:
  • Caveat: cannot be used for streaming inference (requires the full future). For RUL on a complete window this is fine.

`batch_first=True`

Tells PyTorch to expect (B, T, F) and emit (B, T, output_dim). Default is False, which expects (T, B, F). Forgetting this is the #1 LSTM bug.

Execution state:
  • batch_first=True: input/output shape (B, T, F)
  • batch_first=False: input/output shape (T, B, F), the legacy default
`x = torch.randn(B, T, F)`

Fake batch in (B, T, F) order.

Execution state:
  • x.shape = torch.Size([2, 30, 17])

`out, (h_n, c_n) = rnn(x)`

Forward pass returns a 2-tuple: SEQUENCE of hidden outputs + FINAL hidden / cell states. Most downstream models use `out`.

Execution state:
  • out = (2, 30, 512): full sequence of hidden outputs from the LAST layer
  • h_n = (num_layers * num_directions, B, hidden) = (4, 2, 256)
  • c_n: same shape as h_n. Final cell states.

`print("input x :", tuple(x.shape))`

Verify input shape.

Execution state:
  • Output: input x : (2, 30, 17)

`print("output out :", tuple(out.shape))`

Output sequence shape. Time axis preserved; feature dim = 2 * 256 = 512 (forward concat with backward).

Execution state:
  • Output: output out : (2, 30, 512)

`print("hidden h_n :", tuple(h_n.shape))`

Final hidden state: 4 entries because 2 layers x 2 directions.

Execution state:
  • Output: hidden h_n : (4, 2, 256)

`print("cell c_n :", tuple(c_n.shape))`

Final cell state, same shape as h_n.

Execution state:
  • Output: cell c_n : (4, 2, 256)

`print("# params :", ...)`

Parameter count. ~2.1M parameters, which makes the LSTM the largest single component of the backbone we will build in Chapter 9.

Execution state:
  • Output: # params : 2,140,160
```python
import torch
import torch.nn as nn

# ----- A two-layer bidirectional LSTM, batch-first -----
torch.manual_seed(0)
B, T, F = 2, 30, 17                # 2 engines, 30 cycles, 17 sensors

rnn = nn.LSTM(
    input_size=17,
    hidden_size=256,
    num_layers=2,
    bidirectional=True,
    batch_first=True,              # CRITICAL: default is False!
)

# Forward pass on a fake batch
x = torch.randn(B, T, F)
out, (h_n, c_n) = rnn(x)

print("input  x   :", tuple(x.shape))     # (2, 30, 17)
print("output out :", tuple(out.shape))    # (2, 30, 512)
print("hidden h_n :", tuple(h_n.shape))    # (4, 2, 256)
print("cell   c_n :", tuple(c_n.shape))    # (4, 2, 256)

# Format with thousands separators so the output matches the figure quoted in the text.
n_params = sum(p.numel() for p in rnn.parameters())
print(f"# params  : {n_params:,}")         # 2,140,160
```
Production shortcut. The call nn.LSTM(input_size=17, hidden_size=256, num_layers=2, bidirectional=True, batch_first=True) is verbatim what the backbone in Chapter 9 uses. At ~2.1M parameters it is the largest single component of the network.
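The 2,140,160 figure can be reproduced by hand. The helper below is a hypothetical function written for this check; it assumes the PyTorch layout of four gate blocks per direction, separate input and recurrent weight matrices, and two bias vectors, with no projection layer.

```python
def lstm_param_count(input_size, hidden_size, num_layers=1, bidirectional=False):
    """Parameter count of a PyTorch-style LSTM (bias enabled, no projection)."""
    num_directions = 2 if bidirectional else 1
    total = 0
    for layer in range(num_layers):
        # Layers after the first consume the concatenated directions of the layer below.
        in_dim = input_size if layer == 0 else hidden_size * num_directions
        weights = 4 * hidden_size * (in_dim + hidden_size)   # input + recurrent matrices
        biases = 2 * 4 * hidden_size                          # two bias vectors per direction
        total += num_directions * (weights + biases)
    return total

n = lstm_param_count(17, 256, num_layers=2, bidirectional=True)
print(f"{n:,}")   # 2,140,160
```

Most of the total (1,576,960 parameters) sits in the second layer, because its input is the 512-wide concatenation of both directions rather than the 17 raw sensors.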

Recurrent Networks Beyond RUL

| Domain | Sequence | Hidden state captures | Famous architecture |
| --- | --- | --- | --- |
| RUL (this book) | 30 cycles of 17 sensors | Cumulative degradation | CNN-BiLSTM-Attention |
| Language models | Word or subword tokens | Sentence-level meaning | ELMo |
| Speech recognition | Audio frames | Phoneme context | DeepSpeech 2 |
| Machine translation | Source tokens | Sentence representation | seq2seq + attention |
| Music generation | Audio samples | Tonal / rhythmic motif | SampleRNN |
| Reinforcement learning | Game frames | Belief over hidden state | DRQN, R2D2 |
| Time-series forecasting | Hourly observations | Trend + seasonality | DeepAR, encoder-decoder LSTM |
| Medical event prediction | EHR codes / vitals | Patient trajectory | Doctor AI, RETAIN |

Every row shipped a state-of-the-art result at some point in the last decade. Transformers (Section 3.4) have since taken over large-scale text and vision, but LSTMs remain the right tool for small-batch, low-latency, low-data settings — including most prognostic problems with under a million sensor samples.

The Three Pitfalls

Pitfall 1: batch_first=False by default. The single most common LSTM bug. PyTorch's historical default expects $(T, B, F)$, but every other layer in the book uses $(B, T, F)$. Always pass batch_first=True.
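Pitfall 2: Stale hidden state across batches. If you re-feed h_n from one batch into the next without .detach(), autograd keeps the previous batch's computation graph attached to the new one. The next backward() then either fails with a "backward through the graph a second time" error or, if you retain graphs, backpropagates through every earlier batch until memory runs out (the “BPTT-through-batches” bug).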
Pitfall 3: Bidirectional + streaming inference. bidirectional=True requires the full future to compute the backward direction. You cannot use a bidirectional LSTM in a streaming setting where new cycles arrive one at a time. For end-of-window RUL prediction (our setting) it is fine.
The point. An LSTM is a learnable, gated working memory. The gates are the engineering trick that solves vanishing gradients; the cell state is the long-term memory; the hidden output is what the rest of the model sees.
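Pitfall 2 is easiest to see in a training loop that carries state from one chunk to the next. The sketch below is illustrative only: the tiny sizes, the linear head, the random data and the SGD optimiser are all assumptions for the example, not the Chapter 9 setup.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
rnn = nn.LSTM(input_size=17, hidden_size=8, batch_first=True)   # small, illustrative sizes
head = nn.Linear(8, 1)
opt = torch.optim.SGD(list(rnn.parameters()) + list(head.parameters()), lr=0.01)

state = None                                   # None makes nn.LSTM start from zeros
for step in range(3):                          # three consecutive chunks of one stream
    x = torch.randn(2, 30, 17)                 # fake (B, T, F) chunk
    y = torch.randn(2, 1)                      # fake targets
    out, state = rnn(x, state)
    loss = ((head(out[:, -1]) - y) ** 2).mean()
    opt.zero_grad()
    loss.backward()
    opt.step()
    # Cut the graph here: keep the values, drop this chunk's autograd history.
    state = tuple(s.detach() for s in state)

print(state[0].requires_grad)                  # False
```

Detaching keeps the memory of earlier chunks in the forward pass while limiting backpropagation to the current chunk, which is the usual truncated-BPTT compromise.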

Takeaway

  • RNNs read sequences and update a hidden state. They add temporal modelling on top of whatever frontend you put in place (Conv1D in this book).
  • Vanilla RNN gradients vanish. Even length-30 sequences are too long for a plain tanh recurrence; the gradient reaching the first cycles shrinks toward zero.
  • LSTMs fix it with three gates and a candidate. $c_t = f_t \odot c_{t-1} + i_t \odot g_t$ and $h_t = o_t \odot \tanh(c_t)$. Gradients flow through $c_t$ almost unattenuated.
  • PyTorch's nn.LSTM needs only a handful of arguments. input_size, hidden_size, num_layers, bidirectional and batch_first (plus dropout between stacked layers if you want it): the LSTM at the core of the Chapter 9 backbone fits on one line.
  • Always set batch_first=True. The default is $(T, B, F)$ and breaks every other layer's shape contract.