Chapter 15

Long Short-Term Memory


Learning Objectives

By the end of this section, you will be able to:

  1. Understand the LSTM architecture including its cell state and gating mechanisms
  2. Explain the purpose of each gate (forget, input, output) and how they control information flow
  3. Write the complete LSTM equations and understand what each component computes
  4. Explain why LSTM solves the vanishing gradient problem through additive cell state updates
  5. Visualize gradient flow through the "constant error carousel"
  6. Develop intuition for when and how LSTM learns to remember or forget information
Why This Matters: LSTM was a breakthrough that made sequence modeling practical. Before LSTM, training RNNs on sequences longer than 10-20 steps was nearly impossible. After LSTM, neural networks could learn dependencies spanning hundreds of timesteps. This enabled transformative applications in machine translation, speech recognition, and language modeling. Understanding LSTM deeply is essential because: (1) it remains widely used in production systems, (2) its gating mechanisms inspired later architectures including Transformers, and (3) it beautifully illustrates how architectural choices can solve fundamental training problems.

The Story Behind LSTM

In 1997, Sepp Hochreiter and Jürgen Schmidhuber published a paper that would change the course of deep learning: "Long Short-Term Memory." The title itself captures the key insight: they wanted to enable neural networks to have both long-term memory (information that persists over many timesteps) and short-term flexibility (the ability to update and use that information appropriately).

The Problem They Were Solving

Recall from the previous section that vanilla RNNs suffer from the vanishing gradient problem. The gradient from time $T$ back to time $t$ involves a product of Jacobians:

$$\frac{\partial h_T}{\partial h_t} = \prod_{k=t+1}^{T} W_{hh}^T \cdot \text{diag}(\tanh'(z_k))$$

Each factor in this product is typically less than 1, so the gradient decays exponentially. Hochreiter and Schmidhuber's key insight was: what if we could create a pathway where the gradient multiplication factor is exactly 1?
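A quick numerical sketch (with 0.9 as an illustrative per-step factor) shows how fast such a product collapses:

```python
# Multiply fifty per-step gradient factors, each slightly below 1
grad = 1.0
for _ in range(50):
    grad *= 0.9
print(f"{grad:.1e}")  # 5.2e-03
```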

The Breakthrough Idea

The solution came from a simple but profound observation: addition doesn't shrink gradients the way multiplication does. If we update a quantity by adding to it rather than multiplying:

$$C_t = C_{t-1} + \text{(something)}$$

Then the gradient $\frac{\partial C_t}{\partial C_{t-1}} = 1$. The gradient flows through unchanged! This is the foundation of the LSTM cell state.
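This is easy to verify with PyTorch's autograd; the values below are arbitrary illustrations:

```python
import torch

# Additive update: C_t = C_{t-1} + (something)
c_prev = torch.tensor([1.5], requires_grad=True)
update = torch.tensor([0.25])  # stands in for "(something)"
c_t = c_prev + update

c_t.backward()
print(c_prev.grad)  # tensor([1.]): the gradient passes through unchanged
```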

The Constant Error Carousel

Hochreiter and Schmidhuber called this the "Constant Error Carousel" (CEC). When the network learns to keep the cell state unchanged ($C_t = C_{t-1}$), errors (gradients) can flow backward through time without decay. The network learns when to preserve and when to update its memory.

The Key Innovation: The Cell State

The LSTM introduces a new quantity called the cell state, denoted $C_t$. This cell state runs through the entire sequence like a "memory highway," carrying information with minimal modification.

Two Parallel Paths

Unlike vanilla RNNs, which have only the hidden state $h_t$, LSTM maintains two parallel quantities:

| Quantity | Symbol | Role | Update Mechanism |
|---|---|---|---|
| Cell state | Cₜ | Long-term memory storage | Additive updates (preserves gradients) |
| Hidden state | hₜ | Short-term output/representation | Multiplicative gating of cell state |

The cell state $C_t$ is the key to LSTM's success. It can remain unchanged across many timesteps if needed, or it can be selectively updated. The hidden state $h_t$ is derived from the cell state but filtered through an "output gate," so the LSTM can remember things internally that it doesn't expose in its outputs.

Quick Check

What is the key difference between the cell state update in LSTM compared to the hidden state update in vanilla RNNs?


LSTM Cell Architecture

[Figure: the LSTM cell. The cell state runs along the top as a path from $C_{t-1}$ to $C_t$; the forget ($f_t$), input ($i_t$), and output ($o_t$) gates, each a sigmoid over the concatenated $[h_{t-1}, x_t]$, together with the tanh candidate $\tilde{C}_t$, control what enters and leaves it, producing the hidden state $h_t$.]

Information Flow Order

  1. Forget Gate
  2. Input Gate
  3. Cell State Update
  4. Output Gate

The LSTM Key Insight: The cell state (top horizontal line) acts as a "memory highway" that allows information to flow unchanged through time. The gates control what gets added to or removed from this highway. Because the cell state update uses addition rather than multiplication, gradients can flow backward through many timesteps without vanishing.


The Four Components of LSTM

An LSTM cell has four key components: three gates (forget, input, output) and a candidate cell state generator. Each plays a distinct role in controlling information flow; the subsections below also cover the cell state update that combines them.

1. The Forget Gate (fₜ)

The forget gate decides what information from the previous cell state should be discarded. It looks at the previous hidden state $h_{t-1}$ and the current input $x_t$ and outputs a value between 0 and 1 for each element of $C_{t-1}$:

$$f_t = \sigma(W_f \cdot [h_{t-1}, x_t] + b_f)$$
| Gate Value | Meaning | Effect on Cₜ₋₁ |
|---|---|---|
| fₜ ≈ 0 | "Forget this" | Information is discarded |
| fₜ ≈ 1 | "Remember this" | Information is preserved |
| fₜ = 0.5 | "Partially remember" | Information is attenuated by 50% |

Why 'Forget' Gate?

The name is slightly misleading. The forget gate controls what to keep, not what to forget. When $f_t = 1$, everything is kept; when $f_t = 0$, everything is forgotten. The gate "forgets" when it outputs low values.
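The table rows above correspond to simple element-wise arithmetic; a minimal sketch with made-up values:

```python
import torch

c_prev = torch.tensor([4.0, 4.0, 4.0])  # three identical memory cells
f_t = torch.tensor([0.0, 0.5, 1.0])     # forget / attenuate / keep
print(f_t * c_prev)  # tensor([0., 2., 4.])
```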

2. The Input Gate (iₜ)

The input gate decides which new information to add to the cell state. Like the forget gate, it uses a sigmoid to output values between 0 and 1:

$$i_t = \sigma(W_i \cdot [h_{t-1}, x_t] + b_i)$$

3. The Candidate Cell State (C̃ₜ)

While the input gate decides how much new information to add, the candidate cell state determines what new information to potentially add. It uses tanh to produce values between -1 and 1:

$$\tilde{C}_t = \tanh(W_C \cdot [h_{t-1}, x_t] + b_C)$$

Why tanh for the candidate?

The candidate uses tanh (range: [-1, 1]) rather than sigmoid (range: [0, 1]) because we want to be able to both increase and decrease the cell state. Positive values add to the cell state, negative values subtract from it.
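A small sketch (illustrative values) shows a negative candidate pulling the cell state down:

```python
import torch

c_prev = torch.tensor([1.0])
i_t = torch.tensor([1.0])                   # input gate fully open
c_tilde = torch.tanh(torch.tensor([-2.0]))  # negative candidate, about -0.964
c_t = c_prev + i_t * c_tilde                # the cell state decreases
print(c_t)  # tensor([0.0360])
```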

4. The Cell State Update

Now we can update the cell state. This is where the magic happens:

$$C_t = f_t \odot C_{t-1} + i_t \odot \tilde{C}_t$$

This equation says: take the old cell state, forget some parts (multiply by $f_t$), and add new information (scaled by $i_t$). The $\odot$ symbol denotes element-wise (Hadamard) multiplication.

Critical Observation

Notice the plus sign between the two terms! This is what enables gradients to flow backward through time. When $f_t \approx 1$ and $i_t \approx 0$, we have $C_t \approx C_{t-1}$, and the gradient passes through unchanged.
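Autograd confirms the observation directly; the gate values here are illustrative constants:

```python
import torch

f_t = torch.tensor([0.9])
i_t = torch.tensor([0.1])
c_tilde = torch.tensor([0.3])
c_prev = torch.tensor([2.0], requires_grad=True)

c_t = f_t * c_prev + i_t * c_tilde
c_t.backward()
print(c_prev.grad)  # tensor([0.9000]): exactly the forget gate value
```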

5. The Output Gate (oₜ)

Finally, the output gate controls what part of the cell state is exposed as the hidden state output:

$$o_t = \sigma(W_o \cdot [h_{t-1}, x_t] + b_o)$$

The hidden state output is computed by applying tanh to the cell state (to push values to [-1, 1]) and then filtering through the output gate:

$$h_t = o_t \odot \tanh(C_t)$$

Why Filter the Output?

The output gate allows the LSTM to "hide" internal memories. The network can remember something in $C_t$ without revealing it in $h_t$. This is useful when information is needed later but shouldn't influence immediate outputs.
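A two-cell sketch (values chosen for illustration) makes the hiding concrete:

```python
import torch

c_t = torch.tensor([3.0, 3.0])  # two memory cells with identical content
o_t = torch.tensor([0.0, 1.0])  # hide the first cell, expose the second
h_t = o_t * torch.tanh(c_t)
print(h_t)  # tensor([0.0000, 0.9951])
```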

Worked Example: One LSTM Step

The gates work together to update the cell state and produce the hidden state output. Suppose the pre-activations (i.e., $W \cdot [h_{t-1}, x_t] + b$ for each gate) are 0.260, 0.180, 0.300, and 0.460, and the previous cell state is $C_{t-1} = 0.800$. Then:

  • Forget gate: $f_t = \sigma(0.260) = 0.565$, so $C_{t-1}$ is partially kept
  • Input gate: $i_t = \sigma(0.180) = 0.545$, so some new information is stored
  • Candidate: $\tilde{C}_t = \tanh(0.300) = 0.291$, the new candidate memory value (range: -1 to 1)
  • Output gate: $o_t = \sigma(0.460) = 0.613$, so part of the cell state is output

The state updates follow:

$$C_t = 0.565 \times 0.800 + 0.545 \times 0.291 = 0.610$$
$$h_t = 0.613 \times \tanh(0.610) = 0.334$$

The updated long-term memory $C_t = 0.610$ changes by -0.190 from $C_{t-1}$; the new hidden state $h_t = 0.334$ is both the output and the input to the next timestep.

Try this: set the forget gate to favor keeping ($f_t > 0.9$) and the input gate to favor ignoring new information ($i_t < 0.1$). The cell state $C_t$ stays close to $C_{t-1}$. This is how LSTM "remembers" information across many timesteps!
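The example's numbers can be reproduced in a few lines of PyTorch (the pre-activation values 0.260, 0.180, 0.300, 0.460 and the previous cell state 0.800 are the example's inputs):

```python
import torch

f_t, i_t, o_t = torch.sigmoid(torch.tensor([0.260, 0.180, 0.460]))
c_tilde = torch.tanh(torch.tensor(0.300))
c_prev = torch.tensor(0.800)

c_t = f_t * c_prev + i_t * c_tilde
h_t = o_t * torch.tanh(c_t)
print(f"C_t = {c_t:.3f}, h_t = {h_t:.3f}")  # C_t = 0.610, h_t = 0.334
```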


Complete Mathematical Formulation

Let's bring all the equations together. Given input $x_t$ at time $t$, previous hidden state $h_{t-1}$, and previous cell state $C_{t-1}$, the LSTM computes:

Gate Computations

$$\begin{aligned} f_t &= \sigma(W_f \cdot [h_{t-1}, x_t] + b_f) && \text{(Forget gate)} \\ i_t &= \sigma(W_i \cdot [h_{t-1}, x_t] + b_i) && \text{(Input gate)} \\ \tilde{C}_t &= \tanh(W_C \cdot [h_{t-1}, x_t] + b_C) && \text{(Candidate cell state)} \\ o_t &= \sigma(W_o \cdot [h_{t-1}, x_t] + b_o) && \text{(Output gate)} \end{aligned}$$

State Updates

$$\begin{aligned} C_t &= f_t \odot C_{t-1} + i_t \odot \tilde{C}_t && \text{(Cell state update)} \\ h_t &= o_t \odot \tanh(C_t) && \text{(Hidden state output)} \end{aligned}$$

Notation Summary

| Symbol | Description | Dimensions |
|---|---|---|
| xₜ | Input at timestep t | d (input dimension) |
| hₜ | Hidden state (output) | n (hidden dimension) |
| Cₜ | Cell state (internal memory) | n (hidden dimension) |
| fₜ, iₜ, oₜ | Gate activations | n (hidden dimension) |
| C̃ₜ | Candidate cell state | n (hidden dimension) |
| Wf, Wᵢ, WC, Wₒ | Weight matrices | n × (n + d) |
| bf, bᵢ, bC, bₒ | Bias vectors | n |
| σ | Sigmoid function (outputs 0 to 1) | |
| tanh | Hyperbolic tangent (outputs -1 to 1) | |
| ⊙ | Element-wise multiplication | |
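The table implies $4n(n+d)$ weights plus biases, which can be checked against PyTorch's built-in nn.LSTMCell (note that PyTorch stores two bias vectors, bias_ih and bias_hh, where the formulation above uses one):

```python
import torch.nn as nn

d, n = 64, 128  # input and hidden dimensions
cell = nn.LSTMCell(d, n)
total = sum(p.numel() for p in cell.parameters())

# Four gates, each an n x (n + d) weight block, plus two 4n bias vectors
assert total == 4 * n * (n + d) + 8 * n
print(total)  # 99328
```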
LSTM Cell Implementation from Scratch

lstm_cell.py

LSTM Cell Class

We implement the LSTM cell as an nn.Module, which manages parameters and enables integration with PyTorch's autograd system.

Combined Weight Matrix

Instead of four separate weight matrices (W_f, W_i, W_C, W_o), we use a single input-to-hidden matrix of size (4 * hidden_size, input_size) and a single hidden-to-hidden matrix of size (4 * hidden_size, hidden_size). This is more memory-efficient and lets all four gate pre-activations be computed with one matrix multiplication per input.

Forget Gate Bias Initialization

A common practice is to initialize the forget gate bias to 1, which makes the initial forget gate output close to 1 (after sigmoid). This encourages the network to remember information by default early in training.

Gate Computation

We compute all four gate pre-activations in one go and then split the result with chunk. This is mathematically equivalent to computing each gate separately but more efficient.

Gate Activations

The three gates (f, i, o) use sigmoid to produce values in [0, 1] for gating. The candidate cell state uses tanh to produce values in [-1, 1].

Cell State Update

This is the critical equation: C_t = f_t * C_{t-1} + i_t * C_tilde_t. The addition (not just multiplication) is what allows gradients to flow through time.

Hidden State Output

The hidden state is the cell state (pushed through tanh) filtered by the output gate. This controls what information is exposed as the LSTM's output.

```python
import torch
import torch.nn as nn


class LSTMCell(nn.Module):
    """LSTM cell implemented from first principles."""

    def __init__(self, input_size: int, hidden_size: int):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size

        # Combined weight matrices for all four gates:
        # one multiplication computes the [f, i, C_tilde, o] pre-activations
        self.weight_ih = nn.Parameter(
            torch.randn(4 * hidden_size, input_size) / (input_size ** 0.5)
        )
        self.weight_hh = nn.Parameter(
            torch.randn(4 * hidden_size, hidden_size) / (hidden_size ** 0.5)
        )
        self.bias = nn.Parameter(torch.zeros(4 * hidden_size))

        # Initialize forget gate bias to 1 (common practice):
        # this encourages the network to remember by default
        nn.init.ones_(self.bias[0:hidden_size])

    def forward(
        self,
        x: torch.Tensor,
        hx: tuple[torch.Tensor, torch.Tensor] | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            x: Input tensor of shape (batch, input_size)
            hx: Tuple of (h_prev, c_prev), each (batch, hidden_size)

        Returns:
            Tuple of (h_t, c_t)
        """
        batch_size = x.size(0)

        # Initialize hidden and cell state if not provided
        if hx is None:
            h_prev = torch.zeros(batch_size, self.hidden_size, device=x.device)
            c_prev = torch.zeros(batch_size, self.hidden_size, device=x.device)
        else:
            h_prev, c_prev = hx

        # Compute all gate pre-activations in one go:
        # gates = x @ W_ih^T + h_prev @ W_hh^T + b
        gates = x @ self.weight_ih.T + h_prev @ self.weight_hh.T + self.bias

        # Split into the four gate pre-activations
        f_gate, i_gate, c_tilde, o_gate = gates.chunk(4, dim=1)

        # Apply activations
        f_t = torch.sigmoid(f_gate)      # Forget gate
        i_t = torch.sigmoid(i_gate)      # Input gate
        c_tilde_t = torch.tanh(c_tilde)  # Candidate cell state
        o_t = torch.sigmoid(o_gate)      # Output gate

        # Cell state update: C_t = f_t * C_{t-1} + i_t * C_tilde_t
        c_t = f_t * c_prev + i_t * c_tilde_t

        # Hidden state output: h_t = o_t * tanh(C_t)
        h_t = o_t * torch.tanh(c_t)

        return h_t, c_t


# Example usage
def example_forward_pass():
    batch_size, seq_len = 32, 50
    input_size, hidden_size = 64, 128

    lstm_cell = LSTMCell(input_size, hidden_size)

    # Create a random input sequence
    x_sequence = torch.randn(seq_len, batch_size, input_size)

    # Process the sequence one timestep at a time
    outputs = []
    hx = None
    for t in range(seq_len):
        h, c = lstm_cell(x_sequence[t], hx)
        hx = (h, c)
        outputs.append(h)

    output_sequence = torch.stack(outputs, dim=0)
    print(f"Output shape: {output_sequence.shape}")
    # Output shape: torch.Size([50, 32, 128])

    return output_sequence
```

Why LSTM Solves the Vanishing Gradient Problem

Now let's analyze why LSTM's architecture solves the vanishing gradient problem. The key is understanding the gradient of the cell state.

Gradient Through the Cell State

Consider the cell state update equation:

$$C_t = f_t \odot C_{t-1} + i_t \odot \tilde{C}_t$$

Taking the partial derivative with respect to $C_{t-1}$ (treating the gates as constants for this analysis):

$$\frac{\partial C_t}{\partial C_{t-1}} = f_t$$

This is remarkable! The gradient from $C_t$ to $C_{t-1}$ is simply the forget gate value $f_t$. If the network learns to set $f_t \approx 1$, the gradient passes through unchanged.

Gradient Over Many Timesteps

For a loss $\mathcal{L}$ at time $T$, the gradient with respect to an early cell state $C_1$ is:

$$\frac{\partial \mathcal{L}}{\partial C_1} = \frac{\partial \mathcal{L}}{\partial C_T} \cdot \prod_{t=2}^{T} f_t$$

Compare this to the vanilla RNN, where the product involves $W_{hh}$ and $\tanh'$ terms. In LSTM:

| Network | Gradient Product | Typical Value | After 50 Steps |
|---|---|---|---|
| Vanilla RNN | ∏(Wₕₕ · tanh′) | ≈ 0.8 per step | ≈ 10⁻⁵ |
| LSTM | ∏(fₜ) | ≈ 0.95 per step | ≈ 0.08 |

The Power of Learning to Remember

The forget gate is not fixed; it is learned. When the network needs to remember something for a long time, it can learn to set $f_t \approx 1$ for those memory cells. When information is no longer needed, it can set $f_t \approx 0$ to clear the memory. This adaptive memory management is what makes LSTM so powerful.
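A quick experiment with PyTorch's built-in recurrent modules measures how much gradient reaches the first timestep when the loss sits at the last one. Exact magnitudes depend on the random initialization, so no particular values are claimed here:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
seq_len, batch, dim = 50, 4, 32

def first_step_grad(rnn: nn.Module) -> float:
    """Mean |gradient| at the first input when the loss depends only on the last output."""
    x = torch.randn(seq_len, batch, dim, requires_grad=True)
    out, _ = rnn(x)
    out[-1].sum().backward()
    return x.grad[0].abs().mean().item()

print(f"vanilla RNN: {first_step_grad(nn.RNN(dim, dim)):.2e}")
print(f"LSTM:        {first_step_grad(nn.LSTM(dim, dim)):.2e}")
```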

Quick Check

If the forget gate outputs fₜ = 0.98 at every timestep, what fraction of the original gradient remains after 100 timesteps?


Gradient Flow: Vanilla RNN vs. LSTM

The cell state in LSTM acts as a "gradient highway" that preserves gradients across many timesteps. Compare the gradient flowing backward through each architecture:

  • Vanilla RNN: $\frac{\partial \mathcal{L}}{\partial h_1} = \prod W_{hh}^T \cdot \text{diag}(\tanh')$. Each backward step multiplies by $W_{hh}$; if $\|W_{hh}\| < 1$, gradients shrink exponentially.
  • LSTM cell state: $\frac{\partial C_t}{\partial C_{t-1}} = f_t$. Each backward step multiplies only by the forget gate; when $f_t \approx 1$, gradients flow unchanged.

After many timesteps, the vanilla RNN gradient at the earliest timestep is typically orders of magnitude weaker than the LSTM's, which is why LSTM can still learn long-term dependencies.

The Constant Error Carousel: Hochreiter and Schmidhuber (1997) called the cell state a "constant error carousel" because when the forget gate $f_t \approx 1$ and the input gate $i_t \approx 0$, the cell state update becomes $C_t \approx C_{t-1}$. This means the gradient $\frac{\partial C_t}{\partial C_{t-1}} \approx 1$, allowing errors (gradients) to flow unchanged through many timesteps. The network learns when to preserve and when to update memory.


Building LSTM Intuition

Let's develop intuition for how LSTM processes information through several examples.

Example 1: Remembering a Subject for Verb Agreement

Consider the sentence: "The cat, which was chasing the mice in the garden, is hungry."

The LSTM needs to remember that the subject is "cat" (singular) to correctly predict "is" instead of "are":

  • When processing "cat": The input gate opens ($i_t \approx 1$) to store the singular subject information in the cell state.
  • During "which was chasing the mice in the garden": The forget gate stays high ($f_t \approx 1$) to preserve the subject information, while the input gate stays low ($i_t \approx 0$) to avoid overwriting it.
  • When predicting the verb: The output gate opens ($o_t \approx 1$) to access the stored subject information.

Example 2: Language Modeling with Context

"I grew up in France. ... ... ... I speak fluent French."

Even with many sentences in between, the LSTM can remember "France" and use it to predict "French":

  • Storing context: The word "France" triggers high input gate activation, storing location information.
  • Preserving over time: High forget gate values maintain this information across the intervening text.
  • Using context: When predicting the language, the network accesses the stored location to generate "French".

Example 3: Closing Brackets in Code

Matching opening and closing brackets requires counting: "((())" has three opening and two closing brackets, leaving one unclosed.

  • Opening bracket "(": Increment the cell state ($i_t \cdot \tilde{C}_t > 0$).
  • Closing bracket ")": Decrement the cell state ($i_t \cdot \tilde{C}_t < 0$).
  • Predicting the next token: If the cell state is positive, more closing brackets are needed.
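The counter described above can be written directly; this idealized function (a hypothetical helper, not part of any LSTM library) computes the trace a trained cell state would approximate:

```python
def bracket_depth(s: str) -> list[int]:
    """Running bracket count: +1 for each '(' and -1 for each ')'."""
    depth, trace = 0, []
    for ch in s:
        depth += 1 if ch == "(" else -1
        trace.append(depth)
    return trace

print(bracket_depth("((())"))  # [1, 2, 3, 2, 1]
```

The final positive value signals that one closing bracket is still owed.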

Why These Examples Work

In each example, the LSTM uses its gates to selectively store information (input gate), preserve it over time (forget gate), and access it when needed (output gate). The cell state acts as a persistent memory that can maintain information across many timesteps without decay.

Summary

LSTM is a carefully designed architecture that solves the vanishing gradient problem through several key innovations:

Key Concepts

ConceptInnovationPurpose
Cell StateAdditive updates (Cₜ = f·C + i·C̃)Enables gradient flow without decay
Forget GateLearned forgetting (fₜ = σ(...))Selective memory erasure
Input GateLearned writing (iₜ = σ(...))Selective memory writing
Output GateLearned reading (oₜ = σ(...))Selective memory exposure
Dual StateCell state + Hidden stateInternal memory vs. external output

Key Equations

  1. Gates: $f_t, i_t, o_t = \sigma(W \cdot [h_{t-1}, x_t] + b)$
  2. Candidate: $\tilde{C}_t = \tanh(W_C \cdot [h_{t-1}, x_t] + b_C)$
  3. Cell update: $C_t = f_t \odot C_{t-1} + i_t \odot \tilde{C}_t$
  4. Output: $h_t = o_t \odot \tanh(C_t)$
  5. Gradient: $\frac{\partial C_t}{\partial C_{t-1}} = f_t$ (the key insight!)

Looking Forward

In the next section, we'll implement a complete LSTM from scratch in PyTorch and train it on real sequence tasks. We'll see how the theory translates to practice and explore training techniques specific to LSTMs.


Knowledge Check

Test your understanding of LSTM architecture and mechanics:

What is the primary purpose of the forget gate in an LSTM?

Exercises

Conceptual Questions

  1. Explain why the forget gate is crucial for learning long-term dependencies. What would happen if we removed the forget gate and always had $f_t = 1$?
  2. The output gate allows the LSTM to "hide" information in the cell state without exposing it in the hidden state. Give an example scenario where this capability would be useful.
  3. Compare the number of learnable parameters in a vanilla RNN cell vs. an LSTM cell with the same hidden size. Why is LSTM more parameter-efficient than simply stacking multiple RNN layers?
  4. Explain the role of tanh in the cell state update. Why is tanh applied before the output gate but not before the forget/input operations on the cell state?

Mathematical Exercises

  1. Gradient Computation: Derive the gradient $\frac{\partial C_T}{\partial C_1}$ for a sequence of length $T$. Show that it equals $\prod_{t=2}^{T} f_t$.
  2. Full Gradient: The gradient through LSTM also flows through the hidden state path. Write out the complete gradient $\frac{\partial \mathcal{L}}{\partial C_1}$ including both the cell state and hidden state paths.
  3. Parameter Count: For an LSTM with input dimension $d$ and hidden dimension $n$, calculate the total number of learnable parameters.

Coding Exercises

  1. Gate Visualization: Implement a function that processes a sequence through an LSTM and records the forget gate values at each timestep. Visualize these as a heatmap for a sentence processing task.
  2. Memory Persistence Test: Create a synthetic task where the network must remember a binary value for $n$ timesteps before using it. Compare vanilla RNN and LSTM performance as $n$ increases.
  3. Gradient Analysis: Modify the LSTM implementation to track gradient magnitudes at each timestep during backpropagation. Compare to vanilla RNN and verify that LSTM maintains better gradient flow.

Solution Hints

  • Exercise 1: Register forward hooks on the LSTM layer to capture intermediate activations.
  • Exercise 2: The "copy memory" task: input a bit, then $n$ zeros, then a signal to output the original bit.
  • Exercise 3: Use retain_graph=True in backward() and register backward hooks on the cell state tensor.

Challenge Project

Build an LSTM Debugger: Create an interactive tool that visualizes LSTM internals during sequence processing. Include:

  • Real-time visualization of all gate activations
  • Cell state evolution over time
  • Attention-style visualization showing which inputs most strongly affected each output
  • Gradient magnitude tracking through both cell state and hidden state paths
  • Comparison mode to contrast vanilla RNN behavior

Now that you understand the LSTM architecture and why it works, you're ready to implement it from scratch. In the next section, we'll build a complete LSTM in PyTorch, train it on a real task, and explore practical considerations for getting LSTMs to work well in production.