Chapter 15

Implementing LSTM

LSTM and GRU

Learning Objectives

By the end of this section, you will be able to:

  1. Implement an LSTM cell from scratch using only basic PyTorch tensor operations
  2. Build a full LSTM layer that processes sequences of any length
  3. Understand the efficiency tricks used in production LSTM implementations
  4. Use PyTorch's built-in LSTM and understand its parameters and outputs
  5. Train an LSTM on a real sequence prediction task from start to finish
  6. Debug common LSTM issues including vanishing gradients, exploding gradients, and memory problems

Why This Matters: Understanding the implementation details of LSTM is crucial for several reasons: (1) building each component yourself gives you deep intuition about what it does, (2) you'll be able to debug issues that arise in production, (3) you'll understand the trade-offs between different implementation choices, and (4) you'll be prepared to modify or extend LSTM for specialized applications. This section bridges the gap between mathematical understanding and practical application.

The Implementation Journey

In the previous section, we understood what LSTM does and why it works. Now we'll learn how to build one. We'll start from the ground up, implementing each component step by step.

Our Implementation Roadmap

Step | Component | Purpose
1 | Single LSTM cell | Process one timestep: (xₜ, hₜ₋₁, Cₜ₋₁) → (hₜ, Cₜ)
2 | LSTM layer | Process an entire sequence by iterating the cell
3 | Bidirectional LSTM | Process the sequence in both directions
4 | Stacked LSTM | Stack multiple LSTM layers for deeper representations

Let's start with the most fundamental component: the LSTM cell.


Building an LSTM Cell from Scratch

An LSTM cell takes three inputs and produces two outputs:

  • Inputs: current input xₜ, previous hidden state hₜ₋₁, previous cell state Cₜ₋₁
  • Outputs: new hidden state hₜ, new cell state Cₜ

Internally, it computes four components: forget gate, input gate, candidate cell state, and output gate. Let's implement this step by step.

LSTM Cell: The Clear Implementation
lstm_cell_scratch.py
Line 5: Class Definition

We inherit from nn.Module to get automatic parameter registration and gradient tracking.

Line 18: Forget Gate Weights

Wf has shape (hidden_size, hidden_size + input_size) because it takes the concatenation of hₜ₋₁ and xₜ as input.

Line 36: Xavier Initialization

Xavier (Glorot) initialization helps prevent vanishing/exploding gradients at the start of training by keeping variance stable across layers.

Line 44: Forget Gate Bias = 1

Initializing forget gate bias to 1 makes the initial forget gate output close to 1 (after sigmoid). This encourages the LSTM to remember information by default, which helps with learning long-term dependencies.

Line 69: Concatenation

We concatenate hₜ₋₁ and xₜ into a single vector. This allows us to use a single weight matrix for each gate instead of two separate matrices.

Line 73: Forget Gate

fₜ = σ(Wf @ [h,x] + bf). Sigmoid outputs values in [0,1], where 0 means 'forget everything' and 1 means 'remember everything'.

Line 76: Input Gate

iₜ = σ(Wᵢ @ [h,x] + bᵢ). Controls how much of the new candidate information to add to the cell state.

Line 79: Candidate Cell State

C̃ₜ = tanh(Wc @ [h,x] + bc). Tanh outputs values in [-1,1], allowing both addition and subtraction from the cell state.

Line 82: Output Gate

oₜ = σ(Wₒ @ [h,x] + bₒ). Controls which parts of the cell state are exposed in the hidden state output.

Line 87: Cell State Update (Key Equation!)

Cₜ = fₜ × Cₜ₋₁ + iₜ × C̃ₜ. The ADDITION is what enables gradient flow. When f≈1 and i≈0, Cₜ ≈ Cₜ₋₁ and gradients pass through unchanged.

Line 90: Hidden State Output

hₜ = oₜ × tanh(Cₜ). The cell state is squashed by tanh and filtered by the output gate. This is what the LSTM 'exposes' to the outside world.

1import torch
2import torch.nn as nn
3from typing import Optional
4
5class LSTMCellFromScratch(nn.Module):
6    """
7    LSTM cell implemented from first principles.
8
9    This implementation prioritizes clarity over efficiency.
10    Each gate is computed separately to match the mathematical formulation.
11    """
12
13    def __init__(self, input_size: int, hidden_size: int):
14        super().__init__()
15        self.input_size = input_size
16        self.hidden_size = hidden_size
17
18        # Forget gate parameters: W_f @ [h, x] + b_f
19        self.W_f = nn.Parameter(torch.randn(hidden_size, hidden_size + input_size))
20        self.b_f = nn.Parameter(torch.zeros(hidden_size))
21
22        # Input gate parameters: W_i @ [h, x] + b_i
23        self.W_i = nn.Parameter(torch.randn(hidden_size, hidden_size + input_size))
24        self.b_i = nn.Parameter(torch.zeros(hidden_size))
25
26        # Candidate cell state parameters: W_c @ [h, x] + b_c
27        self.W_c = nn.Parameter(torch.randn(hidden_size, hidden_size + input_size))
28        self.b_c = nn.Parameter(torch.zeros(hidden_size))
29
30        # Output gate parameters: W_o @ [h, x] + b_o
31        self.W_o = nn.Parameter(torch.randn(hidden_size, hidden_size + input_size))
32        self.b_o = nn.Parameter(torch.zeros(hidden_size))
33
34        # Initialize weights using Xavier initialization
35        self._init_weights()
36
37    def _init_weights(self):
38        """Initialize weights for stable training."""
39        # Xavier for weight matrices, zeros for biases (forget-gate bias set to 1 below)
40        for param in self.parameters():
41            if param.dim() > 1:
42                nn.init.xavier_uniform_(param)
43            else:
44                nn.init.zeros_(param)
45
46        # Bias forget gate towards remembering (common practice)
47        nn.init.ones_(self.b_f)
48
49    def forward(
50        self,
51        x: torch.Tensor,
52        state: Optional[tuple[torch.Tensor, torch.Tensor]] = None
53    ) -> tuple[torch.Tensor, torch.Tensor]:
54        """
55        Forward pass for one timestep.
56
57        Args:
58            x: Input tensor of shape (batch_size, input_size)
59            state: Tuple of (h_prev, c_prev), each (batch_size, hidden_size)
60                   If None, initializes to zeros.
61
62        Returns:
63            Tuple of (h_t, c_t), each (batch_size, hidden_size)
64        """
65        batch_size = x.size(0)
66
67        # Initialize state if not provided
68        if state is None:
69            h_prev = torch.zeros(batch_size, self.hidden_size, device=x.device)
70            c_prev = torch.zeros(batch_size, self.hidden_size, device=x.device)
71        else:
72            h_prev, c_prev = state
73
74        # Concatenate h_{t-1} and x_t: [h, x]
75        combined = torch.cat([h_prev, x], dim=1)
76
77        # ===== Gate Computations =====
78
79        # Forget gate: f_t = σ(W_f @ [h, x] + b_f)
80        f_t = torch.sigmoid(combined @ self.W_f.T + self.b_f)
81
82        # Input gate: i_t = σ(W_i @ [h, x] + b_i)
83        i_t = torch.sigmoid(combined @ self.W_i.T + self.b_i)
84
85        # Candidate cell state: C̃_t = tanh(W_c @ [h, x] + b_c)
86        c_tilde = torch.tanh(combined @ self.W_c.T + self.b_c)
87
88        # Output gate: o_t = σ(W_o @ [h, x] + b_o)
89        o_t = torch.sigmoid(combined @ self.W_o.T + self.b_o)
90
91        # ===== State Updates =====
92
93        # Cell state update: C_t = f_t ⊙ C_{t-1} + i_t ⊙ C̃_t
94        c_t = f_t * c_prev + i_t * c_tilde
95
96        # Hidden state output: h_t = o_t ⊙ tanh(C_t)
97        h_t = o_t * torch.tanh(c_t)
98
99        return h_t, c_t
100
101
102# Test the implementation
103def test_lstm_cell():
104    batch_size = 4
105    input_size = 10
106    hidden_size = 20
107
108    cell = LSTMCellFromScratch(input_size, hidden_size)
109
110    # Single input
111    x = torch.randn(batch_size, input_size)
112    h, c = cell(x)
113
114    print(f"Input shape: {x.shape}")
115    print(f"Hidden state shape: {h.shape}")
116    print(f"Cell state shape: {c.shape}")
117
118    # Verify shapes
119    assert h.shape == (batch_size, hidden_size)
120    assert c.shape == (batch_size, hidden_size)
121    print("✓ All shape checks passed!")
122
123test_lstm_cell()

This Implementation Is Intentionally Slow

The above implementation uses four separate matrix multiplications (one per gate). Production implementations combine these into a single matrix multiply for efficiency. We'll show that optimization next.
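To see concretely why the fused version is safe, here is a small standalone sketch (the shapes and weights below are illustrative, not taken from the chapter's classes) checking that stacking the four gate weight matrices, doing one matrix multiply, and splitting the result with torch.chunk reproduces the four separate products:

```python
import torch

torch.manual_seed(0)
batch, hidden, inp = 4, 8, 5
combined = torch.randn(batch, hidden + inp)  # [h, x] concatenation

# Four separate gate weight matrices (the "clear" version)
Ws = [torch.randn(hidden, hidden + inp) for _ in range(4)]
bs = [torch.randn(hidden) for _ in range(4)]
separate = [combined @ W.T + b for W, b in zip(Ws, bs)]

# One stacked weight matrix (the "fast" version): a single GEMM, then chunk
W_all = torch.cat(Ws, dim=0)   # (4 * hidden, hidden + inp)
b_all = torch.cat(bs, dim=0)   # (4 * hidden,)
fused = (combined @ W_all.T + b_all).chunk(4, dim=1)

for s, f in zip(separate, fused):
    assert torch.allclose(s, f, atol=1e-5)
print("fused == separate")
```

The chunk order matches the concatenation order of the weight rows, which is why the efficient implementation must fix a gate ordering convention and stick to it.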

Step-by-Step Implementation Walkthrough

Let's trace through what happens when we process a single timestep with concrete numbers.

Example: Processing One Timestep

Suppose we have:

  • Input size: 3 (e.g., 3-dimensional word embedding)
  • Hidden size: 4
  • Current input: xₜ = [0.5, -0.2, 0.8]
  • Previous hidden state: hₜ₋₁ = [0.1, 0.3, -0.1, 0.2]
  • Previous cell state: Cₜ₋₁ = [1.0, -0.5, 0.3, 0.7]

Step 1: Concatenation

First, we concatenate ht1h_{t-1} and xtx_t:

[hₜ₋₁, xₜ] = [0.1, 0.3, -0.1, 0.2, 0.5, -0.2, 0.8]

This combined vector has dimension 4 + 3 = 7.
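In PyTorch this concatenation is a single torch.cat call along the feature dimension; a quick sketch with the example's numbers (a batch dimension of 1 added):

```python
import torch

h_prev = torch.tensor([[0.1, 0.3, -0.1, 0.2]])   # (batch=1, hidden=4)
x_t    = torch.tensor([[0.5, -0.2, 0.8]])        # (batch=1, input=3)

# Concatenate along dim=1 (features), matching [h, x] in the equations
combined = torch.cat([h_prev, x_t], dim=1)
print(combined.shape)  # torch.Size([1, 7])
```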

Step 2: Gate Computations

Each gate computes a linear transformation followed by an activation:

Gate | Formula | Activation | Output range
Forget (fₜ) | W_f @ [h, x] + b_f | sigmoid | [0, 1]
Input (iₜ) | W_i @ [h, x] + b_i | sigmoid | [0, 1]
Candidate (C̃ₜ) | W_c @ [h, x] + b_c | tanh | [-1, 1]
Output (oₜ) | W_o @ [h, x] + b_o | sigmoid | [0, 1]

Step 3: State Updates

The cell state is updated using element-wise operations:

Cₜ = fₜ ⊙ Cₜ₋₁ + iₜ ⊙ C̃ₜ

Suppose after computing the gates we get:

  • fₜ = [0.9, 0.1, 0.8, 0.95] (keep most of C[0], C[2], C[3]; forget C[1])
  • iₜ = [0.3, 0.9, 0.2, 0.1] (add a lot to C[1], little to the others)
  • C̃ₜ = [0.5, 0.8, -0.3, 0.1] (candidate values to potentially add)

Then the new cell state is:

Cₜ = fₜ ⊙ Cₜ₋₁ + iₜ ⊙ C̃ₜ
   = [0.9, 0.1, 0.8, 0.95] ⊙ [1.0, -0.5, 0.3, 0.7] + [0.3, 0.9, 0.2, 0.1] ⊙ [0.5, 0.8, -0.3, 0.1]
   = [0.9, -0.05, 0.24, 0.665] + [0.15, 0.72, -0.06, 0.01]
   = [1.05, 0.67, 0.18, 0.675]

Observing the Dynamics

Notice how C[1]C[1] changed dramatically: the forget gate was low (0.1) so the old value (-0.5) was mostly forgotten, and the input gate was high (0.9) so the new candidate (0.8) was added. Meanwhile, C[0]C[0] and C[3]C[3] barely changed because their forget gates were high and input gates were low.
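The arithmetic above can be checked directly with elementwise tensor operations (values copied from the example):

```python
import torch

c_prev  = torch.tensor([1.0, -0.5, 0.3, 0.7])
f_t     = torch.tensor([0.9, 0.1, 0.8, 0.95])
i_t     = torch.tensor([0.3, 0.9, 0.2, 0.1])
c_tilde = torch.tensor([0.5, 0.8, -0.3, 0.1])

# C_t = f ⊙ C_{t-1} + i ⊙ C̃  (elementwise multiply and add)
c_t = f_t * c_prev + i_t * c_tilde
print(c_t)  # tensor([1.0500, 0.6700, 0.1800, 0.6750])
```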

Quick Check

If the forget gate outputs fₜ = [1, 1, 1, 1] and the input gate outputs iₜ = [0, 0, 0, 0], what is Cₜ? (Answer: Cₜ = Cₜ₋₁; the cell state passes through unchanged.)


Interactive: LSTM Data Flow Visualization

Before diving into code, let's visualize exactly how data flows through an LSTM cell. This interactive diagram shows the step-by-step calculations with actual numbers. Watch how the Long-Term Memory (cell state) flows along the top, while Short-Term Memory (hidden state) flows along the bottom.

LSTM Data Flow: Step-by-Step Calculations

[Interactive diagram: starting from Cₜ₋₁ = 2.00, hₜ₋₁ = 1.00, and xₜ = 1.00, the visualization steps through each gate's pre-activation (h×W + x×W + bias) and activation, then the cell state update and hidden state output. The cell state (long-term memory) flows along the top of the cell; the hidden state (short-term memory) flows along the bottom.]

Key Insight: Notice how the cell state update uses addition: Cₜ = f × Cₜ₋₁ + i × C̃ₜ. This additive structure is what enables gradient flow! When the forget gate f ≈ 1 and input gate i ≈ 0, the cell state passes through almost unchanged, allowing gradients to flow backward through many timesteps without vanishing.
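This gradient behavior is easy to observe with autograd. The toy function below (an illustrative sketch, not part of the chapter's model) unrolls the additive update with the input gate held at 0 and measures dC_T/dC_0: with f = 1 the gradient stays exactly 1 over 50 timesteps, while f = 0.5 shrinks it geometrically:

```python
import torch

def cell_state_grad(f_val: float, steps: int) -> float:
    """Unroll C_t = f * C_{t-1} + i * C̃_t with i = 0 and return dC_T/dC_0."""
    c0 = torch.tensor(2.0, requires_grad=True)
    c = c0
    for _ in range(steps):
        c = f_val * c + 0.0 * torch.tanh(c)  # i = 0: nothing new is written
    c.backward()
    return c0.grad.item()

print(cell_state_grad(1.0, 50))   # 1.0 -> gradient flows unchanged
print(cell_state_grad(0.5, 50))   # 0.5**50 ≈ 8.9e-16 -> gradient vanishes
```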

Building a Full LSTM Layer

An LSTM cell processes one timestep. To process an entire sequence, we need an LSTM layer that iterates the cell across all timesteps. We'll also implement the efficient version that combines all gate computations.

Efficient LSTM Layer Implementation
lstm_layer.py
Line 5: LSTM Layer Class

An LSTM layer wraps the cell logic and adds sequence processing. The batch_first parameter controls tensor shape conventions.

Line 27: Combined Weight Matrices

Instead of 4 separate weight matrices, we use one matrix of size (4*hidden, input) for input weights and one of size (4*hidden, hidden) for recurrent weights. This enables a single GEMM operation per timestep.

Line 40: Initialization Strategy

We use Xavier for input weights and orthogonal for hidden weights. Orthogonal initialization keeps eigenvalues close to 1, which helps gradient flow in recurrent connections.

Line 48: Forget Gate Bias

Setting the forget gate bias to 1 is a crucial initialization trick. It makes the LSTM remember by default, which is especially important early in training when the network hasn't learned what to forget.

Line 68: Transpose for seq_first

We internally process sequences in (seq, batch, features) format for efficiency, but allow users to pass (batch, seq, features) with batch_first=True.

Line 87: Efficient Gate Computation

This is the key optimization: all four gates are computed in a single matrix multiplication. On GPUs this can approach a 4x speedup over computing each gate separately.

Line 94: Chunk Operation

torch.chunk splits the combined gate output into 4 equal parts along the hidden dimension. Each part corresponds to one gate: forget, input, candidate, output.

Line 103: State Updates

Standard LSTM equations: cell state uses additive update (enables gradient flow), hidden state uses multiplicative gating of cell state.

Line 112: Dropout

Dropout is applied to outputs, not recurrent connections. Applying dropout to recurrent connections can hurt performance (see 'variational dropout' for alternatives).

1import torch
2import torch.nn as nn
3from typing import Optional
4
5class LSTMLayer(nn.Module):
6    """
7    Efficient LSTM layer that processes full sequences.
8
9    This implementation combines all four gate computations into
10    a single matrix multiplication for 4x speedup.
11    """
12
13    def __init__(
14        self,
15        input_size: int,
16        hidden_size: int,
17        batch_first: bool = True,
18        dropout: float = 0.0,
19    ):
20        super().__init__()
21        self.input_size = input_size
22        self.hidden_size = hidden_size
23        self.batch_first = batch_first
24        self.dropout = dropout
25
26        # Combined weights for all 4 gates: [f, i, c, o]
27        # Shape: (4 * hidden_size, input_size) for input weights
28        # Shape: (4 * hidden_size, hidden_size) for hidden weights
29        self.weight_ih = nn.Parameter(
30            torch.randn(4 * hidden_size, input_size)
31        )
32        self.weight_hh = nn.Parameter(
33            torch.randn(4 * hidden_size, hidden_size)
34        )
35        self.bias = nn.Parameter(torch.zeros(4 * hidden_size))
36
37        # Dropout layer (applied to outputs, not recurrent connections)
38        self.drop = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
39
40        self._init_weights()
41
42    def _init_weights(self):
43        """Initialize weights using orthogonal initialization."""
44        # Xavier/Glorot for input weights
45        nn.init.xavier_uniform_(self.weight_ih)
46
47        # Orthogonal for hidden weights (helps with gradient flow)
48        nn.init.orthogonal_(self.weight_hh)
49
50        # Bias: zeros, except forget gate bias = 1
51        nn.init.zeros_(self.bias)
52        # Forget gate bias (first hidden_size elements)
53        self.bias.data[:self.hidden_size].fill_(1.0)
54
55    def forward(
56        self,
57        x: torch.Tensor,
58        state: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
59    ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
60        """
61        Process an entire sequence.
62
63        Args:
64            x: Input sequence
65               - If batch_first: (batch, seq_len, input_size)
66               - Otherwise: (seq_len, batch, input_size)
67            state: Initial (h_0, c_0), each of shape (batch, hidden_size)
68
69        Returns:
70            outputs: All hidden states, shape matching input
71            (h_n, c_n): Final hidden and cell states
72        """
73        # Handle batch_first
74        if self.batch_first:
75            x = x.transpose(0, 1)  # (seq, batch, input)
76
77        seq_len, batch_size, _ = x.shape
78
79        # Initialize state if not provided
80        if state is None:
81            h_t = torch.zeros(batch_size, self.hidden_size, device=x.device)
82            c_t = torch.zeros(batch_size, self.hidden_size, device=x.device)
83        else:
84            h_t, c_t = state
85
86        # Collect outputs
87        outputs = []
88
89        # Process sequence timestep by timestep
90        for t in range(seq_len):
91            x_t = x[t]  # (batch, input_size)
92
93            # === Efficient Gate Computation ===
94            # Single matrix multiply computes all 4 gates at once
95            # gates = x_t @ W_ih^T + h_t @ W_hh^T + bias
96            gates = (
97                x_t @ self.weight_ih.T +
98                h_t @ self.weight_hh.T +
99                self.bias
100            )
101
102            # Split into 4 gates, each of shape (batch, hidden_size)
103            f_gate, i_gate, c_gate, o_gate = gates.chunk(4, dim=1)
104
105            # Apply activations
106            f_t = torch.sigmoid(f_gate)
107            i_t = torch.sigmoid(i_gate)
108            c_tilde = torch.tanh(c_gate)
109            o_t = torch.sigmoid(o_gate)
110
111            # Update states
112            c_t = f_t * c_t + i_t * c_tilde
113            h_t = o_t * torch.tanh(c_t)
114
115            outputs.append(h_t)
116
117        # Stack outputs: (seq_len, batch, hidden_size)
118        output = torch.stack(outputs, dim=0)
119
120        # Apply dropout to output (not recurrent connections)
121        output = self.drop(output)
122
123        # Handle batch_first for output
124        if self.batch_first:
125            output = output.transpose(0, 1)  # (batch, seq, hidden)
126
127        return output, (h_t, c_t)
128
129
130# Test and benchmark
131def test_lstm_layer():
132    batch_size = 32
133    seq_len = 100
134    input_size = 64
135    hidden_size = 128
136
137    lstm = LSTMLayer(input_size, hidden_size, batch_first=True)
138
139    # Random sequence
140    x = torch.randn(batch_size, seq_len, input_size)
141
142    # Forward pass
143    output, (h_n, c_n) = lstm(x)
144
145    print(f"Input shape: {x.shape}")
146    print(f"Output shape: {output.shape}")
147    print(f"Final hidden state shape: {h_n.shape}")
148    print(f"Final cell state shape: {c_n.shape}")
149
150    # Verify shapes
151    assert output.shape == (batch_size, seq_len, hidden_size)
152    assert h_n.shape == (batch_size, hidden_size)
153    assert c_n.shape == (batch_size, hidden_size)
154
155    # Verify output[-1] == h_n (last output is final hidden state)
156    assert torch.allclose(output[:, -1, :], h_n, atol=1e-6)
157    print("✓ All checks passed!")
158
159test_lstm_layer()

Why Combine Gates?

Combining all gates into a single matrix multiplication provides significant speedup because:

  1. GPU parallelism: Large matrix multiplications are highly parallelized on GPUs
  2. Memory efficiency: One large matrix operation has less overhead than four small ones
  3. Cache efficiency: Better memory access patterns when reading input once for all gates

On a modern GPU, this optimization can provide 3-4x speedup over the naive implementation.
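A rough micro-benchmark sketches the effect (timings are purely illustrative and depend heavily on hardware, tensor sizes, and BLAS backend; the sizes below are arbitrary):

```python
import time
import torch

batch, inp, hidden = 64, 256, 512
x = torch.randn(batch, inp + hidden)          # a [h, x] concatenation
Ws = [torch.randn(hidden, inp + hidden) for _ in range(4)]
W_all = torch.cat(Ws, dim=0)                  # stacked weights for one GEMM

def bench(fn, iters=200):
    fn()  # warm-up
    t0 = time.perf_counter()
    for _ in range(iters):
        fn()
    return time.perf_counter() - t0

t_sep = bench(lambda: [x @ W.T for W in Ws])              # four small GEMMs
t_one = bench(lambda: (x @ W_all.T).chunk(4, dim=1))      # one large GEMM + split
print(f"separate: {t_sep:.3f}s  fused: {t_one:.3f}s")
```

On CPU the gap is usually smaller than on GPU, where kernel-launch overhead makes the four-small-GEMMs version comparatively worse.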


Interactive: LSTM Implementation Explorer

Trace one timestep through the LSTM cell and see how the gate activations combine into the new cell and hidden states.

LSTM Cell Implementation Explorer

Gate | Activation | Pre-activation | Post-activation
Forget (fₜ) | sigmoid | 1.050 | 0.741
Input (iₜ) | sigmoid | 0.470 | 0.615
Candidate (C̃ₜ) | tanh | 0.480 | 0.446
Output (oₜ) | sigmoid | 0.510 | 0.625

State updates (with Cₜ₋₁ = 0.700):

Cₜ = fₜ × Cₜ₋₁ + iₜ × C̃ₜ = 0.741 × 0.700 + 0.615 × 0.446 = 0.793
hₜ = oₜ × tanh(Cₜ) = 0.625 × tanh(0.793) = 0.412

Key Insight

Notice the addition in the cell state update: Cₜ = f × Cₜ₋₁ + i × C̃ₜ. This is what enables gradient flow! When f ≈ 1 and i ≈ 0, the gradient flows unchanged through Cₜ → Cₜ₋₁, and the cell state acts as a near-perfect memory.


PyTorch's Built-in LSTM

PyTorch provides a highly optimized LSTM implementation. Let's understand how to use it and verify our implementation matches it.

Using PyTorch's Built-in LSTM
pytorch_lstm.py
Line 19: nn.LSTM Parameters

PyTorch's LSTM supports multiple layers, bidirectional processing, and dropout between layers—all in one class.

Line 22: batch_first=True

When True, expects input as (batch, seq, features). Default is False: (seq, batch, features). Match your data format.

Line 23: Dropout Between Layers

This dropout is applied between LSTM layers (for num_layers > 1), NOT after the last layer. Add your own dropout after if needed.

Line 47: Packed Sequences

For variable-length sequences, packing avoids wasted computation on padding. The LSTM skips padded positions entirely.

Line 61: hₙ Shape

hₙ has shape (num_layers × num_directions, batch, hidden). For multi-layer bidirectional: [L0_fwd, L0_bwd, L1_fwd, L1_bwd, ...].

Line 65: Bidirectional Final State

For bidirectional LSTM, concatenate forward and backward final hidden states. hₙ[-2] is last layer forward, hₙ[-1] is last layer backward.

Line 96: Output vs hₙ

output contains hidden states for ALL timesteps. hₙ contains only the FINAL hidden state per layer. For classification, you typically use hₙ.

1import torch
2import torch.nn as nn
3from typing import Optional
4# PyTorch's built-in LSTM
5class ModelWithBuiltinLSTM(nn.Module):
6    """
7    Example model using PyTorch's nn.LSTM.
8    """
9
10    def __init__(
11        self,
12        input_size: int = 64,
13        hidden_size: int = 128,
14        num_layers: int = 2,
15        dropout: float = 0.1,
16        bidirectional: bool = False,
17    ):
18        super().__init__()
19
20        self.lstm = nn.LSTM(
21            input_size=input_size,
22            hidden_size=hidden_size,
23            num_layers=num_layers,
24            batch_first=True,          # (batch, seq, features)
25            dropout=dropout,            # Between layers (not after last)
26            bidirectional=bidirectional,
27        )
28
29        # Calculate output size
30        self.output_size = hidden_size * (2 if bidirectional else 1)
31
32        # Final projection layer
33        self.fc = nn.Linear(self.output_size, 10)  # e.g., 10 classes
34
35    def forward(
36        self,
37        x: torch.Tensor,
38        lengths: Optional[torch.Tensor] = None,
39    ) -> tuple[torch.Tensor, torch.Tensor]:
40        """
41        Args:
42            x: (batch, seq_len, input_size)
43            lengths: Optional sequence lengths for packing
44
45        Returns:
46            logits: (batch, 10) class scores
47            hidden: Final hidden state
48        """
49        # Handle variable-length sequences efficiently
50        if lengths is not None:
51            # Pack for efficient computation
52            x_packed = nn.utils.rnn.pack_padded_sequence(
53                x, lengths.cpu(), batch_first=True, enforce_sorted=False
54            )
55            output_packed, (h_n, c_n) = self.lstm(x_packed)
56            # Unpack
57            output, _ = nn.utils.rnn.pad_packed_sequence(
58                output_packed, batch_first=True
59            )
60        else:
61            output, (h_n, c_n) = self.lstm(x)
62
63        # h_n shape: (num_layers * num_directions, batch, hidden)
64        # For classification, typically use the last layer's hidden state
65
66        if self.lstm.bidirectional:
67            # Concatenate forward and backward final hidden states
68            # h_n[-2] is forward, h_n[-1] is backward
69            hidden = torch.cat([h_n[-2], h_n[-1]], dim=1)
70        else:
71            hidden = h_n[-1]  # Last layer hidden state
72
73        logits = self.fc(hidden)
74        return logits, hidden
75
76
77# Understanding LSTM outputs
78def understand_lstm_outputs():
79    """
80    Demystify nn.LSTM output shapes.
81    """
82    batch_size = 4
83    seq_len = 10
84    input_size = 32
85    hidden_size = 64
86    num_layers = 3
87
88    lstm = nn.LSTM(
89        input_size, hidden_size,
90        num_layers=num_layers,
91        batch_first=True,
92        bidirectional=True
93    )
94
95    x = torch.randn(batch_size, seq_len, input_size)
96    output, (h_n, c_n) = lstm(x)
97
98    print("=== nn.LSTM Output Shapes ===")
99    print(f"Input: {x.shape}")
100    print(f"  → (batch={batch_size}, seq={seq_len}, input={input_size})")
101    print()
102    print(f"Output: {output.shape}")
103    print(f"  → (batch={batch_size}, seq={seq_len}, hidden*directions={hidden_size*2})")
104    print(f"  → Contains hidden states for ALL timesteps")
105    print()
106    print(f"h_n: {h_n.shape}")
107    print(f"  → (layers*directions={num_layers*2}, batch={batch_size}, hidden={hidden_size})")
108    print(f"  → Contains FINAL hidden state for each layer")
109    print()
110    print(f"c_n: {c_n.shape}")
111    print(f"  → (layers*directions={num_layers*2}, batch={batch_size}, hidden={hidden_size})")
112    print(f"  → Contains FINAL cell state for each layer")
113
114    # Key insight: output[:, -1, :] contains the last timestep
115    # but for bidirectional, you need output[:, -1, :hidden] for forward
116    # and output[:, 0, hidden:] for backward (processes in reverse)
117
118understand_lstm_outputs()

Common Confusion: output vs hₙ

Output | Shape | Contains
output | (batch, seq, hidden × dir) | Hidden state at EVERY timestep
hₙ | (layers × dir, batch, hidden) | FINAL hidden state per layer
cₙ | (layers × dir, batch, hidden) | FINAL cell state per layer

For sequence classification, use hₙ (final state). For sequence-to-sequence tasks (like translation), use output (all states).
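These relationships can be checked in a few lines. For a bidirectional LSTM (a sketch with arbitrary sizes), the last layer's final forward state equals the forward half of output at the last timestep, and the final backward state equals the backward half at the first timestep:

```python
import torch
import torch.nn as nn

H = 8
lstm = nn.LSTM(input_size=4, hidden_size=H, num_layers=2,
               batch_first=True, bidirectional=True)
x = torch.randn(3, 10, 4)                     # (batch, seq, features)
output, (h_n, c_n) = lstm(x)

assert output.shape == (3, 10, 2 * H)         # all timesteps, fwd+bwd concatenated
assert h_n.shape == (2 * 2, 3, H)             # layers * directions

# Last layer: h_n[-2] is the forward direction, h_n[-1] the backward one
assert torch.allclose(output[:, -1, :H], h_n[-2], atol=1e-6)  # fwd ends at t = T-1
assert torch.allclose(output[:, 0, H:], h_n[-1], atol=1e-6)   # bwd ends at t = 0
print("output/h_n relationship verified")
```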


Training LSTM on a Real Task

Let's train an LSTM on a practical sequence prediction task: predicting the next character in a sequence. This task clearly demonstrates long-term dependencies.

Training a Character-Level LSTM Language Model
train_lstm.py
Line 8: Character-Level Model

Character-level models predict the next character given previous characters. They can learn spelling, syntax, and even generate reasonable text without explicit word tokenization.

Line 25: Embedding Layer

Converts character indices to dense vectors. The model learns these embeddings during training.

Line 41: Weight Tying

An optional technique where the output projection shares weights with the embedding. It reduces parameters and often improves performance, but requires embedding_dim == hidden_size (or an extra projection), which is why it is left commented out here.

Line 57: Dropout on Embeddings

Apply dropout to embeddings during training. This regularizes the model and prevents overfitting to specific characters.

Line 85: Temperature Sampling

Temperature controls randomness: T=1.0 is standard, T<1.0 is more deterministic (sharper peaks), T>1.0 is more random (flatter distribution).

Line 96: Autoregressive Generation

We feed each generated character back as input to generate the next one. The state is passed along to maintain context.

Line 123: Learning Rate Scheduler

ReduceLROnPlateau decreases the learning rate when validation loss stops improving, helping the model converge better.

Line 149: Gradient Clipping

Essential for RNNs! Clips gradients to prevent explosion. Typical values are 1.0-5.0. Without this, LSTM training often diverges.

Line 180: Perplexity

Perplexity = exp(loss) measures how 'surprised' the model is. Lower is better. PPL of 10 means the model is as uncertain as choosing uniformly from 10 options.

1import torch
2import torch.nn as nn
3import torch.optim as optim
4from torch.utils.data import DataLoader, Dataset
5import numpy as np
6from typing import Optional
7
8class CharacterLSTM(nn.Module):
9    """
10    Character-level language model using LSTM.
11    Predicts the next character given previous characters.
12    """
13
14    def __init__(
15        self,
16        vocab_size: int,
17        embedding_dim: int = 64,
18        hidden_size: int = 256,
19        num_layers: int = 2,
20        dropout: float = 0.2,
21    ):
22        super().__init__()
23        self.hidden_size = hidden_size
24        self.num_layers = num_layers
25
26        # Character embedding
27        self.embedding = nn.Embedding(vocab_size, embedding_dim)
28
29        # LSTM layers
30        self.lstm = nn.LSTM(
31            input_size=embedding_dim,
32            hidden_size=hidden_size,
33            num_layers=num_layers,
34            batch_first=True,
35            dropout=dropout if num_layers > 1 else 0,
36        )
37
38        # Output projection
39        self.fc = nn.Linear(hidden_size, vocab_size)
40
41        # Dropout for regularization
42        self.drop = nn.Dropout(dropout)
43
44        # Weight tying (optional but helpful)
45        # self.fc.weight = self.embedding.weight
46
47    def forward(
48        self,
49        x: torch.Tensor,
50        state: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
51    ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
52        """
53        Args:
54            x: Character indices, shape (batch, seq_len)
55            state: Optional (h, c) for continuation
56
57        Returns:
58            logits: (batch, seq_len, vocab_size)
59            state: (h_n, c_n) for next forward pass
60        """
61        # Embed characters
        embedded = self.drop(self.embedding(x))  # (batch, seq, embed)

        # LSTM forward
        lstm_out, state = self.lstm(embedded, state)  # (batch, seq, hidden)

        # Project to vocabulary
        logits = self.fc(self.drop(lstm_out))  # (batch, seq, vocab)

        return logits, state

    def generate(
        self,
        start_chars: str,
        char_to_idx: dict,
        idx_to_char: dict,
        length: int = 100,
        temperature: float = 1.0,
        device: str = "cpu",
    ) -> str:
        """
        Generate text character by character.

        Args:
            start_chars: Seed string to start generation
            char_to_idx: Vocabulary mapping
            idx_to_char: Inverse vocabulary mapping
            length: Number of characters to generate
            temperature: Sampling temperature (higher = more random)
            device: Device to run on
        """
        self.eval()

        # Convert start chars to tensor
        chars = [char_to_idx.get(c, 0) for c in start_chars]
        x = torch.tensor([chars], device=device)

        generated = list(start_chars)
        state = None

        with torch.no_grad():
            # Process initial context
            logits, state = self(x, state)

            for _ in range(length):
                # Get next character prediction (last position)
                next_logits = logits[0, -1, :] / temperature

                # Sample from distribution
                probs = torch.softmax(next_logits, dim=0)
                next_idx = torch.multinomial(probs, 1).item()

                # Append to generated text
                next_char = idx_to_char[next_idx]
                generated.append(next_char)

                # Prepare next input
                x = torch.tensor([[next_idx]], device=device)
                logits, state = self(x, state)

        return ''.join(generated)


# Training loop
def train_character_lstm(
    model: CharacterLSTM,
    train_loader: DataLoader,
    val_loader: DataLoader,
    epochs: int = 10,
    lr: float = 0.001,
    clip_grad: float = 5.0,
    device: str = "cpu",
):
    """
    Complete training loop with gradient clipping and validation.
    """
    model = model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=2
    )
    criterion = nn.CrossEntropyLoss()

    best_val_loss = float('inf')

    for epoch in range(epochs):
        # Training
        model.train()
        train_loss = 0.0

        for batch_idx, (x, y) in enumerate(train_loader):
            x, y = x.to(device), y.to(device)

            optimizer.zero_grad()

            # Forward pass
            logits, _ = model(x)

            # Reshape for loss: (batch * seq, vocab) vs (batch * seq)
            loss = criterion(
                logits.view(-1, logits.size(-1)),
                y.view(-1)
            )

            # Backward pass
            loss.backward()

            # Gradient clipping (crucial for RNNs!)
            torch.nn.utils.clip_grad_norm_(model.parameters(), clip_grad)

            optimizer.step()

            train_loss += loss.item()

        avg_train_loss = train_loss / len(train_loader)

        # Validation
        model.eval()
        val_loss = 0.0

        with torch.no_grad():
            for x, y in val_loader:
                x, y = x.to(device), y.to(device)
                logits, _ = model(x)
                loss = criterion(
                    logits.view(-1, logits.size(-1)),
                    y.view(-1)
                )
                val_loss += loss.item()

        avg_val_loss = val_loss / len(val_loader)

        # Learning rate scheduling
        scheduler.step(avg_val_loss)

        # Compute perplexity
        train_ppl = np.exp(avg_train_loss)
        val_ppl = np.exp(avg_val_loss)

        print(f"Epoch {epoch+1}/{epochs}")
        print(f"  Train Loss: {avg_train_loss:.4f}, PPL: {train_ppl:.2f}")
        print(f"  Val Loss: {avg_val_loss:.4f}, PPL: {val_ppl:.2f}")

        # Save best model
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            torch.save(model.state_dict(), "best_char_lstm.pt")
            print(f"  ✓ Saved best model (val_loss: {avg_val_loss:.4f})")

    return model

Gradient Clipping Is Essential

Even with LSTM's improved gradient flow, gradients can still explode on difficult sequences. Always use gradient clipping when training RNNs:

  • clip_grad values of 1.0 to 5.0 are typical
  • Monitor gradient norms during training to tune this value
  • If training is unstable, try reducing clip value
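One convenient way to monitor gradient norms: `torch.nn.utils.clip_grad_norm_` returns the total norm measured *before* clipping, so logging its return value each step shows where your clip threshold sits. A minimal sketch with a toy model (the sizes are illustrative):

```python
import torch
import torch.nn as nn

model = nn.LSTM(16, 32)
x = torch.randn(20, 4, 16)  # (seq_len, batch, features)
out, _ = model(x)
out.sum().backward()

# The returned value is the total gradient norm *before* clipping —
# track it over training to pick a sensible clip threshold.
total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)
print(f"pre-clip gradient norm: {float(total_norm):.3f}")
```

If the logged norms rarely exceed your clip value, clipping is nearly a no-op; if they exceed it constantly, the threshold may be masking an instability worth investigating.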

Interactive: Training Dashboard

Observe LSTM training in action. Watch how the loss decreases, gate activations evolve, and generated text improves over epochs.

[Interactive dashboard: train/validation loss curves, perplexity, gradient norm, gate-activation heatmaps for the forget, input, and output gates, sample generation, and the current learning rate.]

As training progresses, gate activations become more structured and specialized. Watch how the patterns evolve from random to meaningful.

Text quality improves as the model learns character-level patterns and long-term dependencies.

Comparing RNN vs LSTM Performance

Let's empirically compare vanilla RNN and LSTM on tasks requiring different memory lengths.

RNN vs LSTM: The Memory Test

The Copy Task

The network must memorize a sequence of random bits, wait through a delay period, then reproduce the sequence exactly. As the delay increases, RNN performance collapses while LSTM maintains its accuracy.

[Interactive demo: memorize a bit sequence, wait through a delay of zeros, then reproduce it. At a 50-step delay the vanilla RNN reaches roughly 30% accuracy with a near-vanished gradient at the first timestep, while the LSTM reaches about 95% with a healthy gradient. An accuracy-vs-sequence-length chart (5 to 200 steps) shows RNN accuracy collapsing as LSTM accuracy stays high.]

Why This Happens

RNN: the gradient is multiplied by Wₕₕ at every timestep, so over T steps its magnitude scales like ||Wₕₕ||ᵀ. With ||Wₕₕ|| = 0.85, after 100 steps: 0.85¹⁰⁰ ≈ 10⁻⁷. The gradient has effectively vanished.

LSTM: the gradient through the cell state is scaled by the forget gate at each timestep, giving roughly a product of fₜ values. With fₜ ≈ 0.98, after 100 steps: 0.98¹⁰⁰ ≈ 0.13. The gradient remains usable for learning.

The visualization above demonstrates the key advantage of LSTM: as the sequence length increases, vanilla RNN performance degrades rapidly due to vanishing gradients, while LSTM maintains its ability to learn long-term dependencies.

| Aspect | Vanilla RNN | LSTM |
| --- | --- | --- |
| Parameters (hidden=128) | ~16K | ~66K (4x more gates) |
| Effective memory | ~10-20 steps | ~100-500+ steps |
| Training speed/step | Faster | Slower (4x computation) |
| Gradient flow | Exponential decay | Controlled by forget gate |
| Best use case | Short sequences, real-time | Long sequences, complex patterns |
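The parameter figures above correspond to the recurrent (hidden-to-hidden) weight matrix alone; a full layer also carries input weights and biases, but the 4x ratio holds either way. A quick check against PyTorch's built-in modules (sizes here are illustrative):

```python
import torch.nn as nn

h = 128
rnn = nn.RNN(h, h)
lstm = nn.LSTM(h, h)

# Recurrent weights: 128*128 = 16384 (~16K) vs 4*128*128 = 65536 (~66K)
print(rnn.weight_hh_l0.numel())
print(lstm.weight_hh_l0.numel())

# Totals including input weights and biases preserve the same 4x ratio
print(sum(p.numel() for p in rnn.parameters()))
print(sum(p.numel() for p in lstm.parameters()))
```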

Practical Tips and Best Practices

Here are essential practices for training LSTM networks effectively:

1. Initialization

| Parameter | Recommendation | Reason |
| --- | --- | --- |
| Input weights | Xavier/Glorot uniform | Maintains variance through layers |
| Hidden weights | Orthogonal | Eigenvalues near 1, good gradient flow |
| Forget gate bias | 1.0 | Start with remembering by default |
| Other biases | 0.0 | Standard initialization |
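Applying these recommendations to a stock `nn.LSTM` might look like the sketch below. PyTorch stores the four gates in (input, forget, cell, output) order, so the forget-gate bias occupies the second quarter of each bias vector:

```python
import torch
import torch.nn as nn

def init_lstm(lstm: nn.LSTM) -> None:
    """Xavier input weights, orthogonal recurrent weights, forget bias = 1."""
    for name, param in lstm.named_parameters():
        if "weight_ih" in name:
            nn.init.xavier_uniform_(param)
        elif "weight_hh" in name:
            nn.init.orthogonal_(param)
        elif "bias" in name:
            nn.init.zeros_(param)
            hidden = param.numel() // 4
            with torch.no_grad():
                # Gate order in PyTorch is (input, forget, cell, output)
                param[hidden:2 * hidden].fill_(1.0)

lstm = nn.LSTM(input_size=64, hidden_size=128, num_layers=2)
init_lstm(lstm)
```

Note that PyTorch keeps two bias vectors per layer (`bias_ih` and `bias_hh`) that are summed at runtime, so with this scheme the effective forget bias starts at 2.0; setting only one of them is an equally reasonable choice.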

2. Regularization

  • Dropout: Apply to inputs and outputs, NOT recurrent connections. Use 0.2-0.5 depending on dataset size.
  • Weight decay: Small values (1e-5 to 1e-4) help prevent overfitting
  • Gradient clipping: Clip norm to 1.0-5.0 to prevent explosion
  • Early stopping: Monitor validation loss and stop when it stops improving

3. Learning Rate

  • Start with 1e-3 for Adam, 1.0 for SGD
  • Use learning rate warmup for the first 1000-5000 steps
  • Use ReduceLROnPlateau or cosine annealing for decay
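Linear warmup can be expressed with `LambdaLR` — a sketch using the Adam learning rate and 1000-step warmup suggested above:

```python
import torch.nn as nn
import torch.optim as optim

model = nn.LSTM(32, 64)
optimizer = optim.Adam(model.parameters(), lr=1e-3)

warmup_steps = 1000
scheduler = optim.lr_scheduler.LambdaLR(
    optimizer,
    # Scale the lr linearly from lr/warmup_steps up to the full lr,
    # then hold it constant (the factor caps at 1.0)
    lr_lambda=lambda step: min(1.0, (step + 1) / warmup_steps),
)
# Call scheduler.step() once per optimizer step during training.
```

After warmup you would typically hand off to a decay schedule such as `ReduceLROnPlateau` or cosine annealing.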

4. Sequence Handling

Handling Variable-Length Sequences
🐍sequence_handling.py
Track Original Lengths

Store the original length of each sequence before padding. This is needed for packing and for computing metrics correctly.

Sort by Length

pack_padded_sequence requires sequences sorted by length (descending). We track the sort indices to unsort later.

Padding

pad_sequence pads all sequences to the length of the longest. The default padding value is 0.

Packing

pack_padded_sequence creates a PackedSequence that allows the LSTM to skip padded positions entirely, saving computation.

Unpacking

After LSTM processing, unpack back to padded format. The output will have zeros at padded positions.
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence, pad_packed_sequence

def prepare_batch(
    sequences: list[torch.Tensor],
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Prepare a batch of variable-length sequences for LSTM.

    Args:
        sequences: List of tensors, each (seq_len_i, features)

    Returns:
        padded: (batch, max_seq_len, features), sorted by length
        lengths: (batch,) original lengths, sorted descending
        sorted_indices: permutation used for sorting (needed to unsort later)
    """
    # Get lengths before padding
    lengths = torch.tensor([seq.size(0) for seq in sequences])

    # Sort by length (descending) for pack_padded_sequence
    sorted_indices = torch.argsort(lengths, descending=True)
    sorted_sequences = [sequences[i] for i in sorted_indices]
    sorted_lengths = lengths[sorted_indices]

    # Pad to same length
    padded = pad_sequence(sorted_sequences, batch_first=True)

    return padded, sorted_lengths, sorted_indices


def lstm_with_packing(lstm: nn.LSTM, x: torch.Tensor, lengths: torch.Tensor):
    """
    Efficiently process variable-length sequences.
    """
    # Pack: skip padded positions during computation (lengths must be on CPU)
    packed = pack_padded_sequence(x, lengths.cpu(), batch_first=True)

    # Process
    packed_output, (h_n, c_n) = lstm(packed)

    # Unpack: restore padding for downstream operations
    output, _ = pad_packed_sequence(packed_output, batch_first=True)

    return output, (h_n, c_n)


# Example with variable-length sequences
sequences = [
    torch.randn(15, 64),  # Length 15
    torch.randn(10, 64),  # Length 10
    torch.randn(22, 64),  # Length 22
    torch.randn(8, 64),   # Length 8
]

padded, lengths, indices = prepare_batch(sequences)
print(f"Padded shape: {padded.shape}")  # (4, 22, 64) - padded to longest
print(f"Lengths: {lengths}")  # tensor([22, 15, 10, 8])
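If the manual sorting bookkeeping feels heavy, `pack_padded_sequence` also accepts `enforce_sorted=False`, which handles the sorting and unsorting internally and returns outputs in the original batch order. A sketch with toy sizes:

```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence, pad_packed_sequence

seqs = [torch.randn(5, 8), torch.randn(3, 8), torch.randn(7, 8)]
lengths = torch.tensor([s.size(0) for s in seqs])

padded = pad_sequence(seqs, batch_first=True)  # (3, 7, 8)
packed = pack_padded_sequence(
    padded, lengths, batch_first=True, enforce_sorted=False  # no manual sort needed
)

lstm = nn.LSTM(8, 16, batch_first=True)
out_packed, _ = lstm(packed)
out, out_lens = pad_packed_sequence(out_packed, batch_first=True)

print(out.shape)          # padded output, back in the original batch order
print(out_lens.tolist())  # original lengths, also in original order
```

The `enforce_sorted=True` default exists mainly for ONNX export; for ordinary training, `enforce_sorted=False` is the simpler choice.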

Debugging LSTM Networks

Here are common issues and how to diagnose them:

1. Loss Not Decreasing

SymptomPossible CauseSolution
Loss flat from startLearning rate too lowIncrease LR by 10x
Loss oscillates wildlyLearning rate too highDecrease LR, add gradient clipping
Loss decreases then plateaus earlyModel capacity too lowAdd layers or increase hidden size
NaN lossGradient explosionAdd gradient clipping, lower LR

2. Gradient Issues

Gradient Diagnostics
🐍debug_gradients.py
Gradient Norm

The L2 norm of all gradients gives a single number summarizing gradient magnitude. Track this during training.

Vanishing Detection

Gradient norm < 1e-6 indicates vanishing gradients. The model is learning very slowly or not at all.

Exploding Detection

Gradient norm > 1000 indicates explosion. Training will be unstable and may produce NaN.
import torch
import torch.nn as nn

def diagnose_gradients(model: nn.Module, loss: torch.Tensor):
    """
    Check for vanishing/exploding gradients.

    Note: calls loss.backward() itself, so don't call it again afterwards.
    """
    loss.backward()

    total_norm = 0.0
    param_norms = {}

    for name, param in model.named_parameters():
        if param.grad is not None:
            param_norm = param.grad.norm(2).item()
            param_norms[name] = param_norm
            total_norm += param_norm ** 2

    total_norm = total_norm ** 0.5

    print(f"Total gradient norm: {total_norm:.4f}")

    # Check for issues
    if total_norm < 1e-6:
        print("⚠️ WARNING: Vanishing gradients detected!")
        print("   Consider: skip connections, different initialization, shorter sequences")
    elif total_norm > 1000:
        print("⚠️ WARNING: Exploding gradients detected!")
        print("   Consider: gradient clipping, lower learning rate")
    else:
        print("✓ Gradient norm looks healthy")

    # Print the five largest per-parameter gradient norms
    print("\nPer-parameter gradient norms:")
    for name, norm in sorted(param_norms.items(), key=lambda x: -x[1])[:5]:
        status = "🔴" if (norm < 1e-6 or norm > 100) else "🟢"
        print(f"  {status} {name}: {norm:.4f}")

3. Memory Issues

  • Out of Memory: Reduce batch size, sequence length, or hidden size. Consider gradient accumulation.
  • Slow Training: Use cuDNN (default with CUDA), enable mixed precision training.
  • Memory Leak: Detach hidden states when not backpropagating through full sequence.
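One of the out-of-memory remedies above, gradient accumulation, is sketched below with hypothetical sizes. The key detail is dividing each micro-batch loss by the accumulation factor so the accumulated gradients average out to those of one large batch:

```python
import torch
import torch.nn as nn
import torch.optim as optim

model = nn.LSTM(16, 32, batch_first=True)
head = nn.Linear(32, 1)
optimizer = optim.Adam(list(model.parameters()) + list(head.parameters()), lr=1e-3)
loss_fn = nn.MSELoss()

accum_steps = 4  # effective batch size = micro-batch size * accum_steps

optimizer.zero_grad()
for step in range(8):
    x = torch.randn(2, 10, 16)  # small micro-batch that fits in memory
    y = torch.randn(2, 10, 1)
    out, _ = model(x)
    loss = loss_fn(head(out), y) / accum_steps  # scale so accumulated grads average
    loss.backward()  # gradients accumulate across micro-batches
    if (step + 1) % accum_steps == 0:
        optimizer.step()
        optimizer.zero_grad()
```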

Detach Hidden States for Long Sequences

When training on very long sequences or continuous text, you need to detach hidden states to prevent the computation graph from growing indefinitely:

Truncated Backpropagation Through Time
🐍truncated_bptt.py
Detach Before Reuse

Without detaching, the computation graph grows with each batch. Memory usage will increase indefinitely until OOM.
def detach_state(state):
    """Detach hidden states from the computation graph."""
    if state is None:
        return None
    h, c = state
    return (h.detach(), c.detach())

# Training loop with truncated BPTT
state = None
for batch, target in data_loader:
    state = detach_state(state)  # Crucial! Stops the graph growing across batches
    optimizer.zero_grad()
    output, state = model(batch, state)
    loss = criterion(output, target)
    loss.backward()
    optimizer.step()

Summary

In this section, we implemented LSTM from scratch and learned how to train it effectively. Key takeaways:

Implementation Insights

| Concept | Key Point | Practical Impact |
| --- | --- | --- |
| Gate combining | 4 gates in 1 matmul | 3-4x speedup on GPU |
| Weight initialization | Orthogonal + forget bias=1 | Faster convergence |
| Gradient clipping | Clip norm to 1-5 | Prevents explosion |
| Sequence packing | Skip padded positions | Faster variable-length training |
| State detachment | Detach for long sequences | Prevents memory leak |

Key Code Patterns

  1. LSTM Cell: $(x_t, h_{t-1}, C_{t-1}) \to (h_t, C_t)$
  2. LSTM Layer: Iterate cell over sequence, collect outputs
  3. Efficient Gates: $\text{gates} = x W_{ih}^T + h W_{hh}^T + b$, then chunk
  4. Training: CrossEntropy + Adam + gradient clipping + LR scheduling
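The efficient-gates pattern as a runnable sketch: one fused matmul produces all four gate pre-activations, and `chunk` splits them. Using PyTorch's (input, forget, cell, output) gate order means the result can be checked directly against `nn.LSTMCell` with its own weights:

```python
import torch
import torch.nn as nn

def lstm_cell_step(x, h, c, w_ih, w_hh, b):
    """One LSTM timestep with all four gates fused into a single matmul."""
    gates = x @ w_ih.T + h @ w_hh.T + b   # (batch, 4 * hidden)
    i, f, g, o = gates.chunk(4, dim=1)    # input, forget, candidate, output
    i, f, o = torch.sigmoid(i), torch.sigmoid(f), torch.sigmoid(o)
    g = torch.tanh(g)
    c_new = f * c + i * g
    h_new = o * torch.tanh(c_new)
    return h_new, c_new

# Check against PyTorch's reference cell using its own weights
cell = nn.LSTMCell(8, 16)
x, h, c = torch.randn(4, 8), torch.randn(4, 16), torch.randn(4, 16)
h_ref, c_ref = cell(x, (h, c))
h_new, c_new = lstm_cell_step(
    x, h, c, cell.weight_ih, cell.weight_hh, cell.bias_ih + cell.bias_hh
)
print(torch.allclose(h_ref, h_new, atol=1e-5))
```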

Looking Forward

In the next section, we'll study the Gated Recurrent Unit (GRU)—a simpler alternative to LSTM that often achieves similar performance with fewer parameters.


Knowledge Check

Test your understanding of LSTM implementation details:

[Interactive quiz: 8 questions, e.g. "What is the main advantage of combining all 4 gate computations into a single matrix multiplication?"]

Exercises

Implementation Exercises

  1. Bidirectional LSTM: Extend the LSTMLayer class to support bidirectional processing. Run the sequence through a forward and backward LSTM, then concatenate their outputs.
  2. Stacked LSTM: Implement a multi-layer LSTM by stacking LSTMLayer instances. The output of layer $i$ becomes the input to layer $i+1$.
  3. Peephole Connections: Modify the LSTM cell to include peephole connections, where gates also look at the cell state: $f_t = \sigma(W_f [h_{t-1}, x_t] + W_{cf} \odot C_{t-1} + b_f)$
  4. Layer Normalization: Add layer normalization to the LSTM cell. Normalize the gate activations before applying sigmoid/tanh.

Training Exercises

  1. Shakespeare Generator: Train a character-level LSTM on Shakespeare text. Generate 500 characters after training and analyze the output.
  2. Sentiment Analysis: Build an LSTM classifier for IMDB movie reviews. Compare performance with 1, 2, and 3 layers.
  3. Sequence Length Study: Train LSTMs on the copy task with delays of 10, 50, 100, and 200 steps. Plot accuracy vs. delay length.

Analysis Exercises

  1. Gate Visualization: For a trained LSTM, visualize the forget gate activations while processing a sentence. Which words trigger forgetting?
  2. Cell State Dynamics: Track cell state values over time while processing text. Identify which dimensions encode what information.
  3. Gradient Flow Comparison: Compare gradient norms at different timesteps for vanilla RNN vs. LSTM on a 100-step sequence.

Exercise Hints

  • Bidirectional: Create two separate LSTMLayer instances, reverse the sequence for the backward one
  • Shakespeare: Use sequences of 100-200 characters, hidden size 512, 2-3 layers
  • Gate visualization: Register forward hooks to capture gate activations during inference

Challenge Project

Build a Mini-Transformer from LSTM: Implement attention over LSTM hidden states to create a simple encoder-decoder model for translation:

  • Encoder: Bidirectional LSTM that processes source sentence
  • Attention: Compute attention weights over encoder states
  • Decoder: LSTM that generates target sentence using attention context
  • Train on a small parallel corpus (e.g., Multi30k English-German)

Now that you can implement and train LSTMs, you're ready to learn about the Gated Recurrent Unit (GRU)—a more recent architecture that simplifies LSTM while maintaining its key benefits. We'll compare the two architectures and discuss when to use each.