Implement an LSTM cell from scratch using only basic PyTorch tensor operations
Build a full LSTM layer that processes sequences of any length
Understand the efficiency tricks used in production LSTM implementations
Use PyTorch's built-in LSTM and understand its parameters and outputs
Train an LSTM on a real sequence prediction task from start to finish
Debug common LSTM issues including vanishing gradients, exploding gradients, and memory problems
Why This Matters: Understanding the implementation details of LSTM is crucial for several reasons: (1) you'll gain deep intuition about what each component does by building it yourself, (2) you'll be able to debug issues that arise in production, (3) you'll understand the trade-offs in different implementation choices, and (4) you'll be prepared to modify or extend LSTM for specialized applications. This section bridges the gap between mathematical understanding and practical application.
The Implementation Journey
In the previous section, we understood what LSTM does and why it works. Now we'll learn how to build one. We'll start from the ground up, implementing each component step by step.
Our Implementation Roadmap
| Step | Component | Purpose |
|---|---|---|
| 1 | Single LSTM Cell | Process one timestep: (xₜ, hₜ₋₁, Cₜ₋₁) → (hₜ, Cₜ) |
| 2 | LSTM Layer | Process an entire sequence by iterating the cell |
| 3 | Bidirectional LSTM | Process the sequence in both directions |
| 4 | Stacked LSTM | Multiple LSTM layers for deeper representations |
Let's start with the most fundamental component: the LSTM cell.
Building an LSTM Cell from Scratch
An LSTM cell takes three inputs and produces two outputs:
Inputs: current input xₜ, previous hidden state hₜ₋₁, previous cell state Cₜ₋₁
Outputs: new hidden state hₜ, new cell state Cₜ
Internally, it computes four components: forget gate, input gate, candidate cell state, and output gate. Let's implement this step by step.
LSTM Cell: The Clear Implementation
🐍lstm_cell_scratch.py
- **Class definition.** We inherit from nn.Module to get automatic parameter registration and gradient tracking.
- **Forget gate weights.** W_f has shape (hidden_size, hidden_size + input_size) because it takes the concatenation of hₜ₋₁ and xₜ as input.
- **Xavier initialization.** Xavier (Glorot) initialization helps prevent vanishing/exploding gradients at the start of training by keeping variance stable across layers.
- **Forget gate bias = 1.** Initializing the forget gate bias to 1 makes the initial forget gate output close to 1 (after sigmoid). This encourages the LSTM to remember information by default, which helps with learning long-term dependencies.
- **Concatenation.** We concatenate hₜ₋₁ and xₜ into a single vector. This allows us to use a single weight matrix for each gate instead of two separate matrices.
- **Forget gate.** fₜ = σ(W_f @ [h, x] + b_f). Sigmoid outputs values in [0, 1], where 0 means "forget everything" and 1 means "remember everything".
- **Input gate.** iₜ = σ(Wᵢ @ [h, x] + bᵢ). Controls how much of the new candidate information to add to the cell state.
- **Candidate cell state.** C̃ₜ = tanh(W_c @ [h, x] + b_c). Tanh outputs values in [-1, 1], allowing both addition and subtraction from the cell state.
- **Output gate.** oₜ = σ(Wₒ @ [h, x] + bₒ). Controls which parts of the cell state are exposed in the hidden state output.
- **Cell state update (the key equation!).** Cₜ = fₜ × Cₜ₋₁ + iₜ × C̃ₜ. The ADDITION is what enables gradient flow. When f ≈ 1 and i ≈ 0, Cₜ ≈ Cₜ₋₁ and gradients pass through unchanged.
- **Hidden state output.** hₜ = oₜ × tanh(Cₜ). The cell state is squashed by tanh and filtered by the output gate. This is what the LSTM "exposes" to the outside world.
```python
import torch
import torch.nn as nn
from typing import Optional


class LSTMCellFromScratch(nn.Module):
    """
    LSTM cell implemented from first principles.

    This implementation prioritizes clarity over efficiency.
    Each gate is computed separately to match the mathematical formulation.
    """

    def __init__(self, input_size: int, hidden_size: int):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size

        # Forget gate parameters: W_f @ [h, x] + b_f
        self.W_f = nn.Parameter(torch.randn(hidden_size, hidden_size + input_size))
        self.b_f = nn.Parameter(torch.zeros(hidden_size))

        # Input gate parameters: W_i @ [h, x] + b_i
        self.W_i = nn.Parameter(torch.randn(hidden_size, hidden_size + input_size))
        self.b_i = nn.Parameter(torch.zeros(hidden_size))

        # Candidate cell state parameters: W_c @ [h, x] + b_c
        self.W_c = nn.Parameter(torch.randn(hidden_size, hidden_size + input_size))
        self.b_c = nn.Parameter(torch.zeros(hidden_size))

        # Output gate parameters: W_o @ [h, x] + b_o
        self.W_o = nn.Parameter(torch.randn(hidden_size, hidden_size + input_size))
        self.b_o = nn.Parameter(torch.zeros(hidden_size))

        # Initialize weights using Xavier initialization
        self._init_weights()

    def _init_weights(self):
        """Initialize weights for stable training."""
        for param in self.parameters():
            if param.dim() > 1:
                nn.init.xavier_uniform_(param)
            else:
                nn.init.zeros_(param)

        # Bias forget gate towards remembering (common practice)
        nn.init.ones_(self.b_f)

    def forward(
        self,
        x: torch.Tensor,
        state: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Forward pass for one timestep.

        Args:
            x: Input tensor of shape (batch_size, input_size)
            state: Tuple of (h_prev, c_prev), each (batch_size, hidden_size).
                If None, initializes to zeros.

        Returns:
            Tuple of (h_t, c_t), each (batch_size, hidden_size)
        """
        batch_size = x.size(0)

        # Initialize state if not provided
        if state is None:
            h_prev = torch.zeros(batch_size, self.hidden_size, device=x.device)
            c_prev = torch.zeros(batch_size, self.hidden_size, device=x.device)
        else:
            h_prev, c_prev = state

        # Concatenate h_{t-1} and x_t: [h, x]
        combined = torch.cat([h_prev, x], dim=1)

        # ===== Gate Computations =====

        # Forget gate: f_t = σ(W_f @ [h, x] + b_f)
        f_t = torch.sigmoid(combined @ self.W_f.T + self.b_f)

        # Input gate: i_t = σ(W_i @ [h, x] + b_i)
        i_t = torch.sigmoid(combined @ self.W_i.T + self.b_i)

        # Candidate cell state: C̃_t = tanh(W_c @ [h, x] + b_c)
        c_tilde = torch.tanh(combined @ self.W_c.T + self.b_c)

        # Output gate: o_t = σ(W_o @ [h, x] + b_o)
        o_t = torch.sigmoid(combined @ self.W_o.T + self.b_o)

        # ===== State Updates =====

        # Cell state update: C_t = f_t ⊙ C_{t-1} + i_t ⊙ C̃_t
        c_t = f_t * c_prev + i_t * c_tilde

        # Hidden state output: h_t = o_t ⊙ tanh(C_t)
        h_t = o_t * torch.tanh(c_t)

        return h_t, c_t


# Test the implementation
def test_lstm_cell():
    batch_size = 4
    input_size = 10
    hidden_size = 20

    cell = LSTMCellFromScratch(input_size, hidden_size)

    # Single input
    x = torch.randn(batch_size, input_size)
    h, c = cell(x)

    print(f"Input shape: {x.shape}")
    print(f"Hidden state shape: {h.shape}")
    print(f"Cell state shape: {c.shape}")

    # Verify shapes
    assert h.shape == (batch_size, hidden_size)
    assert c.shape == (batch_size, hidden_size)
    print("✓ All shape checks passed!")


test_lstm_cell()
```
This Implementation Is Intentionally Slow
The above implementation uses four separate matrix multiplications (one per gate). Production implementations combine these into a single matrix multiply for efficiency. We'll show that optimization next.
Step-by-Step Implementation Walkthrough
Let's trace through what happens when we process a single timestep with concrete numbers.
Example: Processing One Timestep
Suppose we have:
Input size: 3 (e.g., a 3-dimensional word embedding)
Hidden size: 4
Current input: xₜ = [0.5, −0.2, 0.8]
Previous hidden state: hₜ₋₁ = [0.1, 0.3, −0.1, 0.2]
Previous cell state: Cₜ₋₁ = [1.0, −0.5, 0.3, 0.7]
Step 1: Concatenation
First, we concatenate hₜ₋₁ and xₜ:
[hₜ₋₁, xₜ] = [0.1, 0.3, −0.1, 0.2, 0.5, −0.2, 0.8]
This combined vector has dimension 4 + 3 = 7.
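As a quick sanity check (not part of the lesson code), the same concatenation can be reproduced with two small tensors:

```python
import torch

# Values from the example: hidden state of size 4, input of size 3
h_prev = torch.tensor([0.1, 0.3, -0.1, 0.2])
x_t = torch.tensor([0.5, -0.2, 0.8])

# [h, x]: hidden state first, then input (dim=0 here because these are
# single vectors; the batched cell code concatenates along dim=1)
combined = torch.cat([h_prev, x_t], dim=0)

print(combined.shape)  # torch.Size([7])
```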
Step 2: Gate Computations
Each gate computes a linear transformation followed by an activation:
| Gate | Formula | Activation | Output Range |
|---|---|---|---|
| Forget (fₜ) | Wf @ [h,x] + bf | Sigmoid | [0, 1] |
| Input (iₜ) | Wᵢ @ [h,x] + bᵢ | Sigmoid | [0, 1] |
| Candidate (C̃ₜ) | Wc @ [h,x] + bc | Tanh | [-1, 1] |
| Output (oₜ) | Wₒ @ [h,x] + bₒ | Sigmoid | [0, 1] |
Step 3: State Updates
The cell state is updated using element-wise operations:
Cₜ = fₜ ⊙ Cₜ₋₁ + iₜ ⊙ C̃ₜ
Suppose after computing the gates we get:
fₜ = [0.9, 0.1, 0.8, 0.95] (keep most of C[0], C[2], C[3]; forget C[1])
iₜ = [0.3, 0.9, 0.2, 0.1] (add a lot to C[1], little to the others)
C̃ₜ = [0.5, 0.8, −0.3, 0.1] (candidate values to potentially add)
Notice how C[1] changed dramatically: the forget gate was low (0.1) so the old value (-0.5) was mostly forgotten, and the input gate was high (0.9) so the new candidate (0.8) was added. Meanwhile, C[0] and C[3] barely changed because their forget gates were high and input gates were low.
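These element-wise updates are easy to verify numerically (a quick check using the gate values above, not part of the lesson code):

```python
import torch

c_prev = torch.tensor([1.0, -0.5, 0.3, 0.7])    # C_{t-1}
f_t = torch.tensor([0.9, 0.1, 0.8, 0.95])       # forget gate
i_t = torch.tensor([0.3, 0.9, 0.2, 0.1])        # input gate
c_tilde = torch.tensor([0.5, 0.8, -0.3, 0.1])   # candidate values

# C_t = f_t ⊙ C_{t-1} + i_t ⊙ C̃_t (element-wise)
c_t = f_t * c_prev + i_t * c_tilde
print(c_t)  # tensor([1.0500, 0.6700, 0.1800, 0.6750])
```

C[1] moves from −0.5 to 0.67, while C[0] and C[3] barely change — exactly the behavior described above.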
Quick Check
If the forget gate outputs fₜ = [1, 1, 1, 1] and the input gate outputs iₜ = [0, 0, 0, 0], what is Cₜ? (Answer: Cₜ = Cₜ₋₁ — everything is remembered and nothing new is added, so the cell state passes through unchanged.)
Interactive: LSTM Data Flow Visualization
Let's visualize exactly how data flows through an LSTM cell, step by step and with actual numbers. The Long-Term Memory (cell state) flows along the top path, while the Short-Term Memory (hidden state) flows along the bottom.
Key Insight: Notice how the cell state update uses addition: Ct = f × Ct-1 + i × C̃. This additive structure is what enables gradient flow! When the forget gate f ≈ 1 and input gate i ≈ 0, the cell state passes through almost unchanged, allowing gradients to flow backward through many timesteps without vanishing.
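This claim about gradient flow can be checked with autograd. The sketch below (my own illustration, not from the lesson code) runs 100 additive updates with f = 1 and i = 0 and confirms the gradient survives intact:

```python
import torch

c0 = torch.tensor(2.0, requires_grad=True)
f, i = torch.tensor(1.0), torch.tensor(0.0)  # remember everything, add nothing

c = c0
for _ in range(100):
    c_tilde = torch.tanh(c)   # stand-in for the candidate value
    c = f * c + i * c_tilde   # additive cell-state update

c.backward()
print(c.item())        # 2.0 — the cell state passed through unchanged
print(c0.grad.item())  # 1.0 — the gradient is preserved across all 100 steps
```

With any multiplicative update in place of the addition, the same experiment would show the gradient shrinking or growing exponentially with the number of steps.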
Building a Full LSTM Layer
An LSTM cell processes one timestep. To process an entire sequence, we need an LSTM layer that iterates the cell across all timesteps. We'll also implement the efficient version that combines all gate computations.
Efficient LSTM Layer Implementation
🐍lstm_layer.py
- **LSTM layer class.** An LSTM layer wraps the cell logic and adds sequence processing. The batch_first parameter controls tensor shape conventions.
- **Combined weight matrices.** Instead of 4 separate weight matrices, we use one matrix of size (4 × hidden, input) for input weights and one of size (4 × hidden, hidden) for recurrent weights. This enables a single GEMM operation per timestep.
- **Initialization strategy.** We use Xavier for input weights and orthogonal for hidden weights. Orthogonal initialization keeps eigenvalues close to 1, which helps gradient flow in recurrent connections.
- **Forget gate bias.** Setting the forget gate bias to 1 is a crucial initialization trick. It makes the LSTM remember by default, which is especially important early in training when the network hasn't learned what to forget.
- **Transpose for seq-first.** We internally process sequences in (seq, batch, features) format for efficiency, but allow users to pass (batch, seq, features) with batch_first=True.
- **Efficient gate computation.** This is the key optimization: computing all 4 gates in a single matrix multiplication, roughly 3-4x faster than computing each gate separately.
- **Chunk operation.** torch.chunk splits the combined gate output into 4 equal parts along the hidden dimension. Each part corresponds to one gate: forget, input, candidate, output.
- **State updates.** Standard LSTM equations: the cell state uses an additive update (which enables gradient flow), and the hidden state applies multiplicative gating to the cell state.
- **Dropout.** Dropout is applied to outputs, not recurrent connections. Applying dropout to recurrent connections can hurt performance (see "variational dropout" for alternatives).
```python
import torch
import torch.nn as nn
from typing import Optional


class LSTMLayer(nn.Module):
    """
    Efficient LSTM layer that processes full sequences.

    This implementation combines all four gate computations into
    a single matrix multiplication for a significant speedup.
    """

    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        batch_first: bool = True,
        dropout: float = 0.0,
    ):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.batch_first = batch_first
        self.dropout = dropout

        # Combined weights for all 4 gates: [f, i, c, o]
        # Shape: (4 * hidden_size, input_size) for input weights
        # Shape: (4 * hidden_size, hidden_size) for hidden weights
        self.weight_ih = nn.Parameter(torch.randn(4 * hidden_size, input_size))
        self.weight_hh = nn.Parameter(torch.randn(4 * hidden_size, hidden_size))
        self.bias = nn.Parameter(torch.zeros(4 * hidden_size))

        # Dropout layer (applied to outputs, not recurrent connections)
        self.drop = nn.Dropout(dropout) if dropout > 0 else nn.Identity()

        self._init_weights()

    def _init_weights(self):
        """Initialize weights for stable training."""
        # Xavier/Glorot for input weights
        nn.init.xavier_uniform_(self.weight_ih)

        # Orthogonal for hidden weights (helps with gradient flow)
        nn.init.orthogonal_(self.weight_hh)

        # Bias: zeros, except forget gate bias = 1
        nn.init.zeros_(self.bias)
        # Forget gate bias (first hidden_size elements)
        self.bias.data[:self.hidden_size].fill_(1.0)

    def forward(
        self,
        x: torch.Tensor,
        state: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
    ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
        """
        Process an entire sequence.

        Args:
            x: Input sequence
                - If batch_first: (batch, seq_len, input_size)
                - Otherwise: (seq_len, batch, input_size)
            state: Initial (h_0, c_0), each of shape (batch, hidden_size)

        Returns:
            outputs: All hidden states, shape matching input
            (h_n, c_n): Final hidden and cell states
        """
        # Handle batch_first
        if self.batch_first:
            x = x.transpose(0, 1)  # (seq, batch, input)

        seq_len, batch_size, _ = x.shape

        # Initialize state if not provided
        if state is None:
            h_t = torch.zeros(batch_size, self.hidden_size, device=x.device)
            c_t = torch.zeros(batch_size, self.hidden_size, device=x.device)
        else:
            h_t, c_t = state

        # Collect outputs
        outputs = []

        # Process sequence timestep by timestep
        for t in range(seq_len):
            x_t = x[t]  # (batch, input_size)

            # === Efficient Gate Computation ===
            # Single matrix multiply computes all 4 gates at once:
            # gates = x_t @ W_ih^T + h_t @ W_hh^T + bias
            gates = x_t @ self.weight_ih.T + h_t @ self.weight_hh.T + self.bias

            # Split into 4 gates, each of shape (batch, hidden_size)
            f_gate, i_gate, c_gate, o_gate = gates.chunk(4, dim=1)

            # Apply activations
            f_t = torch.sigmoid(f_gate)
            i_t = torch.sigmoid(i_gate)
            c_tilde = torch.tanh(c_gate)
            o_t = torch.sigmoid(o_gate)

            # Update states
            c_t = f_t * c_t + i_t * c_tilde
            h_t = o_t * torch.tanh(c_t)

            outputs.append(h_t)

        # Stack outputs: (seq_len, batch, hidden_size)
        output = torch.stack(outputs, dim=0)

        # Apply dropout to output (not recurrent connections)
        output = self.drop(output)

        # Handle batch_first for output
        if self.batch_first:
            output = output.transpose(0, 1)  # (batch, seq, hidden)

        return output, (h_t, c_t)


# Test and benchmark
def test_lstm_layer():
    batch_size = 32
    seq_len = 100
    input_size = 64
    hidden_size = 128

    lstm = LSTMLayer(input_size, hidden_size, batch_first=True)

    # Random sequence
    x = torch.randn(batch_size, seq_len, input_size)

    # Forward pass
    output, (h_n, c_n) = lstm(x)

    print(f"Input shape: {x.shape}")
    print(f"Output shape: {output.shape}")
    print(f"Final hidden state shape: {h_n.shape}")
    print(f"Final cell state shape: {c_n.shape}")

    # Verify shapes
    assert output.shape == (batch_size, seq_len, hidden_size)
    assert h_n.shape == (batch_size, hidden_size)
    assert c_n.shape == (batch_size, hidden_size)

    # Verify output[:, -1] == h_n (last output is final hidden state)
    assert torch.allclose(output[:, -1, :], h_n, atol=1e-6)
    print("✓ All checks passed!")


test_lstm_layer()
```
Why Combine Gates?
Combining all gates into a single matrix multiplication provides significant speedup because:
GPU parallelism: Large matrix multiplications are highly parallelized on GPUs
Memory efficiency: One large matrix operation has less overhead than four small ones
Cache efficiency: Better memory access patterns when reading input once for all gates
On a modern GPU, this optimization can provide 3-4x speedup over the naive implementation.
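The speedup rests on a purely algebraic fact: stacking the four weight matrices row-wise and doing one multiply yields exactly the same gate pre-activations as four separate multiplies. A quick check (sizes chosen arbitrarily):

```python
import torch

hidden, inp, batch = 32, 16, 8
combined = torch.randn(batch, hidden + inp)  # [h, x] for a whole batch

# Four separate gate weight matrices...
W_f, W_i, W_c, W_o = (torch.randn(hidden, hidden + inp) for _ in range(4))

# ...stacked row-wise into one (4 * hidden, hidden + inp) matrix
W_all = torch.cat([W_f, W_i, W_c, W_o], dim=0)

# One GEMM instead of four, then split back into the four gates
gates = combined @ W_all.T
f, i, c, o = gates.chunk(4, dim=1)

assert torch.allclose(f, combined @ W_f.T, atol=1e-5)
assert torch.allclose(o, combined @ W_o.T, atol=1e-5)
print("Fused and separate gate computations agree.")
```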
Interactive: LSTM Implementation Explorer
Let's trace one concrete set of inputs through the LSTM cell and observe how the gates and states change.
LSTM Cell Implementation Explorer
Worked example: one scalar pass through the cell with Cₜ₋₁ = 0.700.

| Gate | Activation | Pre-activation | Post-activation |
|---|---|---|---|
| Forget (fₜ) | sigmoid | 1.050 | 0.741 |
| Input (iₜ) | sigmoid | 0.470 | 0.615 |
| Candidate (C̃ₜ) | tanh | 0.480 | 0.446 |
| Output (oₜ) | sigmoid | 0.510 | 0.625 |

State updates:
Cₜ = fₜ × Cₜ₋₁ + iₜ × C̃ₜ = 0.741 × 0.700 + 0.615 × 0.446 = 0.793
hₜ = oₜ × tanh(Cₜ) = 0.625 × tanh(0.793) = 0.412
Key Insight
Notice the addition in the cell state update: Cₜ = f × Cₜ₋₁ + i × C̃. This is what enables gradient flow! When f ≈ 1 and i ≈ 0, the gradient flows unchanged from Cₜ back to Cₜ₋₁, and the forget gate dominates the update.
PyTorch's Built-in LSTM
PyTorch provides a highly optimized LSTM implementation. Let's understand how to use it and verify our implementation matches it.
Using PyTorch's Built-in LSTM
🐍pytorch_lstm.py
- **nn.LSTM parameters.** PyTorch's LSTM supports multiple layers, bidirectional processing, and dropout between layers, all in one class.
- **batch_first=True.** When True, expects input as (batch, seq, features). The default is False: (seq, batch, features). Match your data format.
- **Dropout between layers.** This dropout is applied between LSTM layers (for num_layers > 1), NOT after the last layer. Add your own dropout after it if needed.
- **Packed sequences.** For variable-length sequences, packing avoids wasted computation on padding. The LSTM skips padded positions entirely.
- **hₙ shape.** hₙ has shape (num_layers × num_directions, batch, hidden). For a multi-layer bidirectional LSTM the order is [L0_fwd, L0_bwd, L1_fwd, L1_bwd, ...].
- **Bidirectional final state.** For a bidirectional LSTM, concatenate the forward and backward final hidden states: hₙ[-2] is the last layer forward, hₙ[-1] is the last layer backward.
- **output vs hₙ.** output contains hidden states for ALL timesteps; hₙ contains only the FINAL hidden state per layer. For classification, you typically use hₙ.
```python
import torch
import torch.nn as nn
from typing import Optional


# PyTorch's built-in LSTM
class ModelWithBuiltinLSTM(nn.Module):
    """
    Example model using PyTorch's nn.LSTM.
    """

    def __init__(
        self,
        input_size: int = 64,
        hidden_size: int = 128,
        num_layers: int = 2,
        dropout: float = 0.1,
        bidirectional: bool = False,
    ):
        super().__init__()

        self.lstm = nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,        # (batch, seq, features)
            dropout=dropout,         # Between layers (not after last)
            bidirectional=bidirectional,
        )

        # Calculate output size
        self.output_size = hidden_size * (2 if bidirectional else 1)

        # Final projection layer
        self.fc = nn.Linear(self.output_size, 10)  # e.g., 10 classes

    def forward(
        self,
        x: torch.Tensor,
        lengths: Optional[torch.Tensor] = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            x: (batch, seq_len, input_size)
            lengths: Optional sequence lengths for packing

        Returns:
            logits: (batch, 10) class scores
            hidden: Final hidden state
        """
        # Handle variable-length sequences efficiently
        if lengths is not None:
            # Pack for efficient computation
            x_packed = nn.utils.rnn.pack_padded_sequence(
                x, lengths.cpu(), batch_first=True, enforce_sorted=False
            )
            output_packed, (h_n, c_n) = self.lstm(x_packed)
            # Unpack
            output, _ = nn.utils.rnn.pad_packed_sequence(
                output_packed, batch_first=True
            )
        else:
            output, (h_n, c_n) = self.lstm(x)

        # h_n shape: (num_layers * num_directions, batch, hidden)
        # For classification, typically use the last layer's hidden state

        if self.lstm.bidirectional:
            # Concatenate forward and backward final hidden states:
            # h_n[-2] is forward, h_n[-1] is backward
            hidden = torch.cat([h_n[-2], h_n[-1]], dim=1)
        else:
            hidden = h_n[-1]  # Last layer hidden state

        logits = self.fc(hidden)
        return logits, hidden


# Understanding LSTM outputs
def understand_lstm_outputs():
    """
    Demystify nn.LSTM output shapes.
    """
    batch_size = 4
    seq_len = 10
    input_size = 32
    hidden_size = 64
    num_layers = 3

    lstm = nn.LSTM(
        input_size, hidden_size,
        num_layers=num_layers,
        batch_first=True,
        bidirectional=True,
    )

    x = torch.randn(batch_size, seq_len, input_size)
    output, (h_n, c_n) = lstm(x)

    print("=== nn.LSTM Output Shapes ===")
    print(f"Input: {x.shape}")
    print(f" → (batch={batch_size}, seq={seq_len}, input={input_size})")
    print()
    print(f"Output: {output.shape}")
    print(f" → (batch={batch_size}, seq={seq_len}, hidden*directions={hidden_size * 2})")
    print(" → Contains hidden states for ALL timesteps")
    print()
    print(f"h_n: {h_n.shape}")
    print(f" → (layers*directions={num_layers * 2}, batch={batch_size}, hidden={hidden_size})")
    print(" → Contains FINAL hidden state for each layer")
    print()
    print(f"c_n: {c_n.shape}")
    print(f" → (layers*directions={num_layers * 2}, batch={batch_size}, hidden={hidden_size})")
    print(" → Contains FINAL cell state for each layer")

    # Key insight: output[:, -1, :] contains the last timestep,
    # but for bidirectional you need output[:, -1, :hidden] for forward
    # and output[:, 0, hidden:] for backward (it processes in reverse)


understand_lstm_outputs()
```
Common Confusion: output vs hₙ
| Output | Shape | Contains |
|---|---|---|
| output | (batch, seq, hidden×dir) | Hidden state at EVERY timestep |
| hₙ | (layers×dir, batch, hidden) | FINAL hidden state per layer |
| cₙ | (layers×dir, batch, hidden) | FINAL cell state per layer |
For sequence classification, use hₙ (final state). For sequence-to-sequence tasks (like translation), use output (all states).
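The distinction can be verified directly (a small sketch; sizes are arbitrary):

```python
import torch
import torch.nn as nn

lstm = nn.LSTM(input_size=8, hidden_size=16, num_layers=2, batch_first=True)
x = torch.randn(4, 10, 8)

output, (h_n, c_n) = lstm(x)

print(output.shape)  # torch.Size([4, 10, 16]) — every timestep, last layer
print(h_n.shape)     # torch.Size([2, 4, 16]) — final state, every layer

# For a unidirectional LSTM, the last timestep of `output` is the
# final hidden state of the last layer
assert torch.allclose(output[:, -1, :], h_n[-1], atol=1e-6)
```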
Training LSTM on a Real Task
Let's train an LSTM on a practical sequence prediction task: predicting the next character in a sequence. This task clearly demonstrates long-term dependencies.
Training a Character-Level LSTM Language Model
🐍train_lstm.py
- **Character-level model.** Character-level models predict the next character given previous characters. They can learn spelling, syntax, and even generate reasonable text without explicit word tokenization.
- **Embedding layer.** Converts character indices to dense vectors. The model learns these embeddings during training.
- **Weight tying.** An optional technique where the output projection shares weights with the embedding. This reduces parameters and often improves performance.
- **Dropout on embeddings.** Apply dropout to embeddings during training. This regularizes the model and prevents overfitting to specific characters.
- **Temperature sampling.** Temperature controls randomness: T = 1.0 is standard, T < 1.0 is more deterministic (sharper peaks), T > 1.0 is more random (flatter distribution).
- **Autoregressive generation.** We feed each generated character back as input to generate the next one. The state is passed along to maintain context.
- **Learning rate scheduler.** ReduceLROnPlateau decreases the learning rate when validation loss stops improving, helping the model converge better.
- **Gradient clipping.** Essential for RNNs! Clips gradients to prevent explosion. Typical values are 1.0-5.0. Without this, LSTM training often diverges.
- **Perplexity.** Perplexity = exp(loss) measures how "surprised" the model is. Lower is better. A PPL of 10 means the model is as uncertain as choosing uniformly among 10 options.
```python
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import numpy as np
from typing import Optional


class CharacterLSTM(nn.Module):
    """
    Character-level language model using LSTM.
    Predicts the next character given previous characters.
    """

    def __init__(
        self,
        vocab_size: int,
        embedding_dim: int = 64,
        hidden_size: int = 256,
        num_layers: int = 2,
        dropout: float = 0.2,
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        # Character embedding
        self.embedding = nn.Embedding(vocab_size, embedding_dim)

        # LSTM layers
        self.lstm = nn.LSTM(
            input_size=embedding_dim,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
            dropout=dropout if num_layers > 1 else 0,
        )

        # Output projection
        self.fc = nn.Linear(hidden_size, vocab_size)

        # Dropout for regularization
        self.drop = nn.Dropout(dropout)

        # Weight tying (optional but helpful)
        # self.fc.weight = self.embedding.weight

    def forward(
        self,
        x: torch.Tensor,
        state: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
    ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
        """
        Args:
            x: Character indices, shape (batch, seq_len)
            state: Optional (h, c) for continuation

        Returns:
            logits: (batch, seq_len, vocab_size)
            state: (h_n, c_n) for the next forward pass
        """
        # Embed characters
        embedded = self.drop(self.embedding(x))  # (batch, seq, embed)

        # LSTM forward
        lstm_out, state = self.lstm(embedded, state)  # (batch, seq, hidden)

        # Project to vocabulary
        logits = self.fc(self.drop(lstm_out))  # (batch, seq, vocab)

        return logits, state

    def generate(
        self,
        start_chars: str,
        char_to_idx: dict,
        idx_to_char: dict,
        length: int = 100,
        temperature: float = 1.0,
        device: str = "cpu",
    ) -> str:
        """
        Generate text character by character.

        Args:
            start_chars: Seed string to start generation
            char_to_idx: Vocabulary mapping
            idx_to_char: Inverse vocabulary mapping
            length: Number of characters to generate
            temperature: Sampling temperature (higher = more random)
            device: Device to run on
        """
        self.eval()

        # Convert start chars to tensor
        chars = [char_to_idx.get(c, 0) for c in start_chars]
        x = torch.tensor([chars], device=device)

        generated = list(start_chars)
        state = None

        with torch.no_grad():
            # Process initial context
            logits, state = self(x, state)

            for _ in range(length):
                # Get next character prediction (last position)
                next_logits = logits[0, -1, :] / temperature

                # Sample from the distribution
                probs = torch.softmax(next_logits, dim=0)
                next_idx = torch.multinomial(probs, 1).item()

                # Append to generated text
                next_char = idx_to_char[next_idx]
                generated.append(next_char)

                # Prepare next input
                x = torch.tensor([[next_idx]], device=device)
                logits, state = self(x, state)

        return ''.join(generated)


# Training loop
def train_character_lstm(
    model: CharacterLSTM,
    train_loader: DataLoader,
    val_loader: DataLoader,
    epochs: int = 10,
    lr: float = 0.001,
    clip_grad: float = 5.0,
    device: str = "cpu",
):
    """
    Complete training loop with gradient clipping and validation.
    """
    model = model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=2
    )
    criterion = nn.CrossEntropyLoss()

    best_val_loss = float('inf')

    for epoch in range(epochs):
        # Training
        model.train()
        train_loss = 0.0

        for batch_idx, (x, y) in enumerate(train_loader):
            x, y = x.to(device), y.to(device)

            optimizer.zero_grad()

            # Forward pass
            logits, _ = model(x)

            # Reshape for loss: (batch * seq, vocab) vs (batch * seq)
            loss = criterion(
                logits.view(-1, logits.size(-1)),
                y.view(-1)
            )

            # Backward pass
            loss.backward()

            # Gradient clipping (crucial for RNNs!)
            torch.nn.utils.clip_grad_norm_(model.parameters(), clip_grad)

            optimizer.step()

            train_loss += loss.item()

        avg_train_loss = train_loss / len(train_loader)

        # Validation
        model.eval()
        val_loss = 0.0

        with torch.no_grad():
            for x, y in val_loader:
                x, y = x.to(device), y.to(device)
                logits, _ = model(x)
                loss = criterion(
                    logits.view(-1, logits.size(-1)),
                    y.view(-1)
                )
                val_loss += loss.item()

        avg_val_loss = val_loss / len(val_loader)

        # Learning rate scheduling
        scheduler.step(avg_val_loss)

        # Compute perplexity
        train_ppl = np.exp(avg_train_loss)
        val_ppl = np.exp(avg_val_loss)

        print(f"Epoch {epoch + 1}/{epochs}")
        print(f"  Train Loss: {avg_train_loss:.4f}, PPL: {train_ppl:.2f}")
        print(f"  Val Loss: {avg_val_loss:.4f}, PPL: {val_ppl:.2f}")

        # Save best model
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            torch.save(model.state_dict(), "best_char_lstm.pt")
            print(f"  ✓ Saved best model (val_loss: {avg_val_loss:.4f})")

    return model
```
Gradient Clipping Is Essential
Even with LSTM's improved gradient flow, gradients can still explode on difficult sequences. Always use gradient clipping when training RNNs:
clip_grad=1.0 to 5.0 are typical values
Monitor gradient norms during training to tune this value
If training is unstable, try reducing clip value
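Monitoring is cheap: `clip_grad_norm_` returns the total norm it computed before clipping, so the value can be logged for free. A minimal sketch:

```python
import torch
import torch.nn as nn

model = nn.LSTM(input_size=8, hidden_size=16, batch_first=True)
x = torch.randn(4, 20, 8)

output, _ = model(x)
loss = output.pow(2).mean()
loss.backward()

# clip_grad_norm_ returns the TOTAL norm measured BEFORE clipping —
# log this over training to choose a sensible clip value
total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
print(f"Gradient norm before clipping: {float(total_norm):.4f}")

# After clipping, the norm is at most max_norm (up to rounding)
clipped = torch.sqrt(sum(p.grad.norm() ** 2 for p in model.parameters()))
print(f"Gradient norm after clipping:  {float(clipped):.4f}")
```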
Interactive: Training Dashboard
Observe LSTM training in action. Watch how the loss decreases, gate activations evolve, and generated text improves over epochs.
LSTM Training Dashboard
Early-training snapshot (epoch 1/10, step 0/100): train loss 4.488, val loss 4.932, perplexity 89.0, gradient norm 2.70.
[Charts: train/validation loss history; forget (fₜ), input (iₜ), and output (oₜ) gate activations for 10 hidden units]
As training progresses, gate activations become more structured and specialized. Watch how the patterns evolve from random to meaningful.
Sample Generation
Prompt: "The cat" → "thxe caat sxt on the maxt" (early in training). Text quality improves as the model learns character-level patterns and long-term dependencies.
Comparing RNN vs LSTM Performance
Let's empirically compare vanilla RNN and LSTM on tasks requiring different memory lengths.
RNN vs LSTM: The Memory Test
The Copy Task
The network must memorize a sequence of random bits, wait through a delay period, then reproduce the sequence exactly. As the delay increases, RNN performance collapses while LSTM maintains its accuracy.
RNN: the gradient shrinks by a factor of roughly ‖Wₕₕ‖ at every timestep, so over T steps it scales like ‖Wₕₕ‖ᵀ. With ‖Wₕₕ‖ = 0.85, after 100 steps: 0.85¹⁰⁰ ≈ 10⁻⁷. The gradient has effectively vanished.
LSTM: the gradient through the cell state scales like the product of the forget-gate activations, roughly fₜᵀ. With fₜ ≈ 0.98, after 100 steps: 0.98¹⁰⁰ ≈ 0.13. The gradient remains usable for learning.
This demonstrates the key advantage of LSTM: as the sequence length increases, vanilla RNN performance degrades rapidly due to vanishing gradients, while LSTM maintains its ability to learn long-term dependencies.
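The arithmetic behind these numbers is easy to reproduce:

```python
# Per-step gradient scaling factors from the comparison above
rnn_factor = 0.85    # ~ ||W_hh|| for the vanilla RNN
lstm_factor = 0.98   # ~ typical forget-gate activation f_t

steps = 100
rnn_grad = rnn_factor ** steps
lstm_grad = lstm_factor ** steps

print(f"RNN after {steps} steps:  {rnn_grad:.2e}")   # ~ 8.7e-08 — vanished
print(f"LSTM after {steps} steps: {lstm_grad:.4f}")  # ~ 0.1326 — still usable
```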
| Aspect | Vanilla RNN | LSTM |
|---|---|---|
| Parameters (hidden=128) | ~16K | ~66K (4x more gates) |
| Effective memory | ~10-20 steps | ~100-500+ steps |
| Training speed/step | Faster | Slower (4x computation) |
| Gradient flow | Exponential decay | Controlled by forget gate |
| Best use case | Short sequences, real-time | Long sequences, complex patterns |
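The parameter figures in the table match the recurrent (hidden-to-hidden) weight counts; they can be confirmed from PyTorch's own modules (full parameter counts, with input weights and biases, are larger):

```python
import torch.nn as nn

rnn = nn.RNN(input_size=128, hidden_size=128)
lstm = nn.LSTM(input_size=128, hidden_size=128)

# Recurrent (hidden-to-hidden) weights only
print(rnn.weight_hh_l0.numel())   # 16384 = 128 * 128 (~16K)
print(lstm.weight_hh_l0.numel())  # 65536 = 4 * 128 * 128 (~66K)
```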
Practical Tips and Best Practices
Here are essential practices for training LSTM networks effectively:
1. Initialization
| Parameter | Recommendation | Reason |
|---|---|---|
| Input weights | Xavier/Glorot uniform | Maintains variance through layers |
| Hidden weights | Orthogonal | Eigenvalues near 1, good gradient flow |
| Forget gate bias | 1.0 | Start with remembering by default |
| Other biases | 0.0 | Standard initialization |
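Applied to PyTorch's nn.LSTM, the recipe looks like the sketch below. One caveat: nn.LSTM orders its gates (input, forget, cell, output), so the forget-gate bias is the second quarter of each bias vector; and since PyTorch keeps two bias vectors (bias_ih and bias_hh) that are added together, setting both to 1 would give an effective forget bias of 2, so this sketch sets only one.

```python
import torch
import torch.nn as nn

def init_lstm(lstm: nn.LSTM) -> None:
    """Apply the recommended initialization scheme to an nn.LSTM."""
    hidden = lstm.hidden_size
    for name, param in lstm.named_parameters():
        if "weight_ih" in name:
            nn.init.xavier_uniform_(param)   # input weights: Xavier/Glorot
        elif "weight_hh" in name:
            nn.init.orthogonal_(param)       # hidden weights: orthogonal
        elif "bias_ih" in name:
            nn.init.zeros_(param)
            # Forget gate = SECOND quarter in PyTorch's (i, f, g, o) layout
            param.data[hidden:2 * hidden].fill_(1.0)
        elif "bias_hh" in name:
            nn.init.zeros_(param)  # leave at 0 so the effective forget bias is 1

lstm = nn.LSTM(input_size=64, hidden_size=128, num_layers=2)
init_lstm(lstm)
print(lstm.bias_ih_l0[128:256].min().item())  # 1.0
```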
2. Regularization
Dropout: Apply to inputs and outputs, NOT recurrent connections. Use 0.2-0.5 depending on dataset size.
Weight decay: Small values (1e-5 to 1e-4) help prevent overfitting
Gradient clipping: Clip norm to 1.0-5.0 to prevent explosion
Early stopping: Monitor validation loss and stop when it stops improving
3. Learning Rate
Start with 1e-3 for Adam, 1.0 for SGD
Use learning rate warmup for the first 1000-5000 steps
Use ReduceLROnPlateau or cosine annealing for decay
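A linear warmup can be wired up with LambdaLR; the model and step counts here are illustrative:

```python
import torch.nn as nn
import torch.optim as optim

model = nn.LSTM(input_size=64, hidden_size=128, batch_first=True)
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# Linear warmup over the first 1000 optimizer steps, then constant lr
warmup_steps = 1000
scheduler = optim.lr_scheduler.LambdaLR(
    optimizer,
    lr_lambda=lambda step: min(1.0, (step + 1) / warmup_steps),
)

for step in range(3):
    optimizer.step()   # (in real training: forward + backward first)
    scheduler.step()

print(optimizer.param_groups[0]["lr"])  # still ramping toward 1e-3
```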
4. Sequence Handling
Handling Variable-Length Sequences
🐍sequence_handling.py
- **Track original lengths.** Store the original length of each sequence before padding. This is needed for packing and for computing metrics correctly.
- **Sort by length.** pack_padded_sequence requires sequences sorted by length (descending) unless you pass enforce_sorted=False. We track the sort indices to unsort later.
- **Padding.** pad_sequence pads all sequences to the length of the longest. The default padding value is 0.
- **Packing.** pack_padded_sequence creates a PackedSequence that allows the LSTM to skip padded positions entirely, saving computation.
- **Unpacking.** After LSTM processing, unpack back to padded format. The output will have zeros at padded positions.
```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence, pad_packed_sequence


def prepare_batch(
    sequences: list[torch.Tensor],
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Prepare a batch of variable-length sequences for LSTM.

    Args:
        sequences: List of tensors, each (seq_len_i, features)

    Returns:
        padded: (batch, max_seq_len, features)
        lengths: (batch,) original lengths, sorted descending
        sorted_indices: (batch,) permutation used for sorting (to unsort later)
    """
    # Get lengths before padding
    lengths = torch.tensor([seq.size(0) for seq in sequences])

    # Sort by length (descending) for pack_padded_sequence
    sorted_indices = torch.argsort(lengths, descending=True)
    sorted_sequences = [sequences[i] for i in sorted_indices]
    sorted_lengths = lengths[sorted_indices]

    # Pad to same length
    padded = pad_sequence(sorted_sequences, batch_first=True)

    return padded, sorted_lengths, sorted_indices


def lstm_with_packing(lstm: nn.LSTM, x: torch.Tensor, lengths: torch.Tensor):
    """Efficiently process variable-length sequences."""
    # Pack: skip padded positions during computation
    packed = pack_padded_sequence(x, lengths.cpu(), batch_first=True)

    # Process
    packed_output, (h_n, c_n) = lstm(packed)

    # Unpack: restore padding for downstream operations
    output, _ = pad_packed_sequence(packed_output, batch_first=True)

    return output, (h_n, c_n)


# Example with variable-length sequences
sequences = [
    torch.randn(15, 64),  # Length 15
    torch.randn(10, 64),  # Length 10
    torch.randn(22, 64),  # Length 22
    torch.randn(8, 64),   # Length 8
]

padded, lengths, indices = prepare_batch(sequences)
print(f"Padded shape: {padded.shape}")  # (4, 22, 64) - padded to longest
print(f"Lengths: {lengths}")            # tensor([22, 15, 10, 8])
```
Debugging LSTM Networks
Here are common issues and how to diagnose them:
1. Loss Not Decreasing
| Symptom | Possible Cause | Solution |
| --- | --- | --- |
| Loss flat from start | Learning rate too low | Increase LR by 10x |
| Loss oscillates wildly | Learning rate too high | Decrease LR, add gradient clipping |
| Loss decreases then plateaus early | Model capacity too low | Add layers or increase hidden size |
| NaN loss | Gradient explosion | Add gradient clipping, lower LR |
2. Gradient Issues
Gradient Diagnostics
🐍debug_gradients.py
Gradient Norm: The L2 norm of all gradients gives a single number summarizing gradient magnitude. Track this during training.
Vanishing Detection: Gradient norm < 1e-6 indicates vanishing gradients. The model is learning very slowly or not at all.
Exploding Detection: Gradient norm > 1000 indicates explosion. Training will be unstable and may produce NaN.
Out of Memory: Reduce batch size, sequence length, or hidden size. Consider gradient accumulation.
Slow Training: Use cuDNN (default with CUDA), enable mixed precision training.
Memory Leak: Detach hidden states when not backpropagating through full sequence.
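The gradient-accumulation suggestion can be sketched as follows (the stand-in model and the accumulation factor of 4 are arbitrary):

```python
import torch
import torch.nn as nn

model = nn.Linear(8, 2)                     # stand-in for the LSTM model
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

accum_steps = 4                             # effective batch = 4x micro-batch
optimizer.zero_grad()
for step in range(8):
    x = torch.randn(2, 8)                   # small micro-batch fits in memory
    y = torch.randint(0, 2, (2,))
    loss = criterion(model(x), y) / accum_steps  # scale so gradients average
    loss.backward()                         # gradients accumulate across steps
    if (step + 1) % accum_steps == 0:
        optimizer.step()
        optimizer.zero_grad()
```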
Detach Hidden States for Long Sequences
When training on very long sequences or continuous text, you need to detach hidden states to prevent the computation graph from growing indefinitely:
Truncated Backpropagation Through Time
🐍truncated_bptt.py
Detach Before Reuse: Without detaching, the computation graph grows with each batch. Memory usage will increase indefinitely until OOM.
```python
def detach_state(state):
    """Detach hidden states from computation graph."""
    if state is None:
        return None
    h, c = state
    return (h.detach(), c.detach())


# Training loop with truncated BPTT
# (optimizer.step()/zero_grad() omitted for brevity)
state = None
for batch in data_loader:
    if state is not None:
        state = detach_state(state)  # Crucial!
    output, state = model(batch, state)
    loss = criterion(output, target)
    loss.backward()
```
Summary
In this section, we implemented LSTM from scratch and learned how to train it effectively. Key takeaways:
Implementation Insights
| Concept | Key Point | Practical Impact |
| --- | --- | --- |
| Gate combining | 4 gates in 1 matmul | 3-4x speedup on GPU |
| Weight initialization | Orthogonal + forget bias=1 | Faster convergence |
| Gradient clipping | Clip norm to 1-5 | Prevents explosion |
| Sequence packing | Skip padded positions | Faster variable-length training |
| State detachment | Detach for long sequences | Prevents memory leak |
Key Code Patterns
LSTM Cell: (xₜ, hₜ₋₁, Cₜ₋₁) → (hₜ, Cₜ)
LSTM Layer: Iterate cell over sequence, collect outputs
Efficient Gates: gates = x·Wᵢₕᵀ + h·Wₕₕᵀ + b, then chunk into 4
Training: CrossEntropy + Adam + gradient clipping + LR scheduling
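The "efficient gates" pattern can be shown standalone; a minimal sketch with illustrative weight names:

```python
import torch

batch, input_size, hidden = 4, 16, 32
x = torch.randn(batch, input_size)
h = torch.randn(batch, hidden)

# One fused matmul produces all four gate pre-activations at once
W_ih = torch.randn(4 * hidden, input_size)
W_hh = torch.randn(4 * hidden, hidden)
b = torch.zeros(4 * hidden)

gates = x @ W_ih.T + h @ W_hh.T + b   # (batch, 4 * hidden)
i, f, g, o = gates.chunk(4, dim=1)    # split into the four gates

# Apply the per-gate nonlinearities
i, f, o = torch.sigmoid(i), torch.sigmoid(f), torch.sigmoid(o)
g = torch.tanh(g)
print(i.shape)  # torch.Size([4, 32])
```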
Looking Forward
In the next section, we'll study the Gated Recurrent Unit (GRU)—a simpler alternative to LSTM that often achieves similar performance with fewer parameters.
Knowledge Check
Test your understanding of LSTM implementation details:
What is the main advantage of combining all 4 gate computations into a single matrix multiplication?
Exercises
Implementation Exercises
Bidirectional LSTM: Extend the LSTMLayer class to support bidirectional processing. Run the sequence through a forward and backward LSTM, then concatenate their outputs.
Stacked LSTM: Implement a multi-layer LSTM by stacking LSTMLayer instances. The output of layer i becomes the input to layer i+1.
Peephole Connections: Modify the LSTM cell to include peephole connections, where gates also look at the cell state: fₜ = σ(W_f[hₜ₋₁, xₜ] + W_cf ⊙ Cₜ₋₁ + b_f)
Layer Normalization: Add layer normalization to the LSTM cell. Normalize the gate activations before applying sigmoid/tanh.
Training Exercises
Shakespeare Generator: Train a character-level LSTM on Shakespeare text. Generate 500 characters after training and analyze the output.
Sentiment Analysis: Build an LSTM classifier for IMDB movie reviews. Compare performance with 1, 2, and 3 layers.
Sequence Length Study: Train LSTMs on the copy task with delays of 10, 50, 100, and 200 steps. Plot accuracy vs. delay length.
Analysis Exercises
Gate Visualization: For a trained LSTM, visualize the forget gate activations while processing a sentence. Which words trigger forgetting?
Cell State Dynamics: Track cell state values over time while processing text. Identify which dimensions encode what information.
Gradient Flow Comparison: Compare gradient norms at different timesteps for vanilla RNN vs. LSTM on a 100-step sequence.
Exercise Hints
Bidirectional: Create two separate LSTMLayer instances, reverse the sequence for the backward one
Shakespeare: Use sequences of 100-200 characters, hidden size 512, 2-3 layers
Gate visualization: Register forward hooks to capture gate activations during inference
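For the forward-hook hint: PyTorch's built-in nn.LSTM does not expose its gates as submodules, so in practice you would hook the from-scratch cell built earlier. A generic sketch of the hook mechanism itself, with a stand-in module:

```python
import torch
import torch.nn as nn

captured = []

def hook(module, inputs, output):
    # Store a detached copy of the module's output each forward pass
    captured.append(output.detach().clone())

layer = nn.Sigmoid()            # stand-in for a gate's activation module
handle = layer.register_forward_hook(hook)

layer(torch.randn(2, 8))        # triggers the hook
layer(torch.randn(2, 8))
print(len(captured))            # 2
handle.remove()                 # always remove hooks when done
```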
Challenge Project
Build a Mini-Transformer from LSTM: Implement attention over LSTM hidden states to create a simple encoder-decoder model for translation:
Encoder: Bidirectional LSTM that processes source sentence
Attention: Compute attention weights over encoder states
Decoder: LSTM that generates target sentence using attention context
Train on a small parallel corpus (e.g., Multi30k English-German)
Now that you can implement and train LSTMs, you're ready to learn about the Gated Recurrent Unit (GRU)—a more recent architecture that simplifies LSTM while maintaining its key benefits. We'll compare the two architectures and discuss when to use each.