Learning Objectives
By the end of this section, you will be able to:
- Understand nn.Module as the foundational building block for neural networks in PyTorch
- Create custom modules by subclassing nn.Module with proper initialization
- Register parameters correctly using nn.Parameter and understand why plain tensors fail
- Implement the forward() method to define data flow through your network
- Navigate module hierarchies using children(), modules(), and named_parameters()
- Manage module state with train(), eval(), to(), and state_dict()
- Save and load models properly for deployment and checkpointing
Why This Matters: Every neural network in PyTorch - from a simple linear classifier to GPT-4 - is built using nn.Module. Understanding this class is essential for building, debugging, and deploying any deep learning model. It provides the infrastructure for parameter management, device placement, serialization, and the training/evaluation lifecycle.
The Big Picture
In the previous chapter, we learned about tensors, autograd, and how PyTorch automatically computes gradients. But managing hundreds of tensors manually would be chaos. We need a structured way to:
- Organize learnable parameters (weights and biases)
- Define how data flows through operations
- Move entire models between CPU and GPU
- Save and load trained models
- Switch between training and inference modes
The nn.Module class solves all of these problems. It's PyTorch's answer to the question: “How do we organize neural network components in a maintainable way?”
Historical Context
The design of nn.Module was influenced by earlier frameworks like Torch (Lua) and Theano, but with a key innovation: dynamic computation graphs. Unlike TensorFlow 1.x where you defined a static graph before execution, PyTorch builds the graph on-the-fly. This means nn.Module is just a regular Python class - you can use Python control flow (if, for, while) directly in your forward pass.
This Pythonic design philosophy made PyTorch the dominant framework for research, and nn.Module is at the heart of it all.
What is nn.Module?
torch.nn.Module is the base class for all neural network modules in PyTorch. Think of it as a container that:
- Holds learnable parameters: Weights and biases that get updated during training
- Defines computation: The forward() method specifies how inputs become outputs
- Organizes submodules: Complex networks are built from simpler modules in a tree structure
- Manages state: Training mode, device placement, and serialization
The Module Contract
When you create a class that inherits from nn.Module, you agree to follow certain conventions:
| Convention | What It Means | Why It Matters |
|---|---|---|
| Call super().__init__() | Initialize the parent class | Enables parameter tracking and hooks |
| Define forward() | Specify the computation | Called when you invoke module(input) |
| Use nn.Parameter for weights | Wrap learnable tensors | Makes them visible to optimizers |
| Assign submodules to self | Register child modules | Enables recursive operations |
Creating Your First Module
Let's create a simple neural network module step by step:
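A minimal sketch of such a module (the layer sizes and class name here are arbitrary, chosen for illustration):

```python
import torch
import torch.nn as nn

class SimpleNet(nn.Module):
    def __init__(self, in_features, hidden, out_features):
        super().__init__()  # must come first: sets up parameter tracking
        self.fc1 = nn.Linear(in_features, hidden)  # submodules auto-register
        self.fc2 = nn.Linear(hidden, out_features)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        return self.fc2(x)

model = SimpleNet(784, 128, 10)
out = model(torch.randn(32, 784))  # invokes forward() via __call__
print(out.shape)                   # torch.Size([32, 10])
```

Each nn.Linear assigned to self is registered as a child, so `model.parameters()` yields all four weight and bias tensors without any manual bookkeeping.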
Always Call super().__init__()
Forgetting to call super().__init__() is one of the most common bugs. Without it:
- Parameters won't be tracked
- state_dict() will be empty
- to(device) won't move parameters
- Hooks won't work
Parameter Registration
Understanding how PyTorch tracks learnable parameters is crucial. There are three ways to store tensors in a module, and only one is correct for learnable weights:
Method 1: nn.Parameter (Correct for Learnable Weights)
```python
class MyLayer(nn.Module):
    def __init__(self, in_features, out_features):
        super().__init__()
        # Correct: Use nn.Parameter for learnable weights
        self.weight = nn.Parameter(torch.randn(out_features, in_features))
        self.bias = nn.Parameter(torch.zeros(out_features))

    def forward(self, x):
        return x @ self.weight.T + self.bias

# These are now tracked!
layer = MyLayer(10, 5)
print(list(layer.parameters()))  # Shows weight and bias
```

nn.Parameter is a special tensor wrapper that:
- Sets requires_grad=True by default
- Registers the tensor with the module's parameter tracking
- Makes the tensor appear in parameters() and state_dict()
Method 2: register_buffer (For Non-Trainable State)
```python
class BatchNormLike(nn.Module):
    def __init__(self, num_features):
        super().__init__()
        # Trainable parameters
        self.gamma = nn.Parameter(torch.ones(num_features))
        self.beta = nn.Parameter(torch.zeros(num_features))

        # Non-trainable but should be saved/loaded
        self.register_buffer('running_mean', torch.zeros(num_features))
        self.register_buffer('running_var', torch.ones(num_features))

    def forward(self, x):
        # Use running stats in eval mode
        ...
```

Use register_buffer() for tensors that:
- Should be saved/loaded with the model (in state_dict)
- Should move with the model (via .to(device))
- Should NOT be trained (not in parameters())
Common Buffer Examples
Typical buffers include BatchNorm's running mean and variance, fixed positional encodings in Transformers, and attention masks: tensors the model needs at inference time and in checkpoints, but that the optimizer should never update.
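As a concrete sketch, a Transformer-style causal mask is a typical buffer (the class here is a hypothetical stub, not a full attention layer):

```python
import torch
import torch.nn as nn

class CausalMaskStub(nn.Module):
    """Illustrative: registers a causal attention mask as a buffer."""
    def __init__(self, max_len):
        super().__init__()
        # Lower-triangular boolean mask: position i may attend to j <= i
        mask = torch.tril(torch.ones(max_len, max_len, dtype=torch.bool))
        self.register_buffer('causal_mask', mask)

m = CausalMaskStub(4)
print([name for name, _ in m.named_buffers()])  # ['causal_mask']
print(list(m.parameters()))                     # [] - buffers are not trainable
print('causal_mask' in m.state_dict())          # True - but they are saved
```

The mask also follows the module through `.to(device)`, which is exactly why a plain tensor attribute would be the wrong choice here.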
Method 3: Plain Tensor (Bug!)
```python
class BuggyLayer(nn.Module):
    def __init__(self, in_features, out_features):
        super().__init__()
        # BUG: Plain tensor, not registered!
        self.weight = torch.randn(out_features, in_features, requires_grad=True)

# This weight is INVISIBLE to PyTorch!
layer = BuggyLayer(10, 5)
print(list(layer.parameters()))  # Empty list!
# Optimizer won't update this weight!
```

Plain Tensors Are Not Tracked
- Won't appear in model.parameters()
- Won't be updated by the optimizer
- Won't be saved in state_dict()
- Won't move with .to(device)
Interactive: Parameter Registration Demo
Explore the difference between correct and incorrect parameter registration. See how mistakes lead to invisible parameters and models that fail to learn.
The forward() Method
The forward() method is where you define the computation your module performs. It takes input tensors and returns output tensors, building the computational graph as it executes.
Key Principles
- Pure function-like: Given the same inputs and module state, produce the same outputs
- Use tensor operations: All operations should be differentiable (unless intentionally not)
- Support batching: Design for batch dimension (usually first dimension)
- Handle variable inputs: Can accept multiple tensors, return multiple tensors
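Because forward() is ordinary Python, it can loop, branch, and return multiple tensors; a short sketch under those principles (the module and its flag are illustrative):

```python
import torch
import torch.nn as nn

class StackedBlock(nn.Module):
    def __init__(self, dim, n_layers):
        super().__init__()
        self.layers = nn.ModuleList([nn.Linear(dim, dim) for _ in range(n_layers)])

    def forward(self, x, return_hidden=False):
        hiddens = []
        for layer in self.layers:   # plain Python loop, traced dynamically
            x = torch.relu(layer(x))
            hiddens.append(x)
        if return_hidden:           # plain Python branch
            return x, hiddens
        return x

block = StackedBlock(16, 3)
out, hiddens = block(torch.randn(8, 16), return_hidden=True)
```

The batch dimension (8 here) passes through untouched, and autograd records whichever path the Python control flow actually took.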
Why model(x) Not model.forward(x)?
When you call model(x), PyTorch invokes __call__ which:
- Executes registered forward pre-hooks
- Calls your forward() method
- Executes registered forward hooks
- Handles any errors with better messages
```python
# Hooks only work when using model(x)
def print_input_shape(module, input):
    print(f"Input shape: {input[0].shape}")

model.register_forward_pre_hook(print_input_shape)

# Using model(x) - hook is called
output = model(x)  # Prints: Input shape: [32, 784]

# Using model.forward(x) - hook is SKIPPED!
output = model.forward(x)  # No print - bypassed the hook!
```

Interactive: Forward Pass Flow
Watch how data flows through a neural network layer by layer. Observe how each layer transforms the tensor shapes and see the corresponding code highlighted.
Module Hierarchy
Complex neural networks are built as trees of modules. When you assign an nn.Module as an attribute, it becomes a child module, and its parameters become part of the parent.
Navigation Methods
| Method | Returns | Recursive? | Use Case |
|---|---|---|---|
| children() | Direct child modules | No | Iterate immediate submodules |
| modules() | All modules (including self) | Yes | Apply operation to every module |
| named_children() | Pairs of (name, child) | No | Access children by name |
| named_modules() | Pairs of (name, module) | Yes | Full module tree with paths |
| parameters() | All parameters | Yes | Pass to optimizer |
| named_parameters() | Pairs of (name, param) | Yes | Inspect or modify specific params |
| buffers() | All buffers | Yes | Access non-trainable state |
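A quick sketch of these accessors on a toy model (the architecture is arbitrary):

```python
import torch.nn as nn

model = nn.Sequential(
    nn.Linear(4, 8),
    nn.ReLU(),
    nn.Linear(8, 2),
)

# children(): direct submodules only
print([type(c).__name__ for c in model.children()])  # ['Linear', 'ReLU', 'Linear']

# modules(): the whole tree, including the Sequential itself
print(len(list(model.modules())))                    # 4

# named_parameters(): dotted paths into the tree
print([name for name, _ in model.named_parameters()])
# ['0.weight', '0.bias', '2.weight', '2.bias']
```

The dotted names mirror the tree structure, which is why state_dict() keys look the same way.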
Interactive: Module Hierarchy Explorer
Explore the hierarchical structure of different neural network architectures. See how modules contain submodules, and how parameters are organized throughout the tree.
Essential Module Methods
Accessing Parameters
```python
model = MyNetwork()

# All parameters (for optimizer)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Count parameters
total = sum(p.numel() for p in model.parameters())
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Total: {total:,}, Trainable: {trainable:,}")

# Access specific parameters by name
for name, param in model.named_parameters():
    if 'bias' in name:
        print(f"{name}: {param.shape}")

# Freeze specific layers
for param in model.encoder.parameters():
    param.requires_grad = False
```

Device Management
```python
# Move entire model to GPU
model = model.to('cuda')

# Or use .cuda() / .cpu() shortcuts
model = model.cuda()  # Same as .to('cuda')
model = model.cpu()   # Same as .to('cpu')

# Chaining works (returns self)
model = MyNetwork().to('cuda')

# Check device of parameters
device = next(model.parameters()).device
print(f"Model is on: {device}")

# Mixed precision
model = model.to(dtype=torch.float16)
model = model.half()  # Shortcut for float16

# Move to specific GPU
model = model.to('cuda:1')  # Second GPU
```

Applying Functions to All Modules
```python
# Custom weight initialization
def init_weights(module):
    if isinstance(module, nn.Linear):
        nn.init.xavier_uniform_(module.weight)
        if module.bias is not None:
            nn.init.zeros_(module.bias)
    elif isinstance(module, nn.Conv2d):
        nn.init.kaiming_normal_(module.weight, mode='fan_out')

# Apply to all modules recursively
model.apply(init_weights)

# Check all module types
for name, module in model.named_modules():
    print(f"{name}: {type(module).__name__}")
```

Module Lifecycle
Understanding the lifecycle of an nn.Module helps you use it correctly throughout training and deployment.
- Initialization: Define architecture in __init__(), register parameters and submodules
- Device placement: Move to GPU/CPU with .to() before training
- Training mode: Call .train() to enable dropout, use batch stats for BatchNorm
- Evaluation mode: Call .eval() for inference, use running stats for BatchNorm
- Persistence: Save/load with state_dict() and load_state_dict()
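The five stages above, condensed into code (the architecture and file name are illustrative; the device line falls back to CPU when no GPU is present):

```python
import torch
import torch.nn as nn

# 1. Initialization: architecture defined, parameters registered
model = nn.Sequential(nn.Linear(784, 256), nn.ReLU(), nn.Linear(256, 10))

# 2. Device placement
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = model.to(device)

# 3. Training mode (also the default right after construction)
model.train()
assert model.training

# 4. Evaluation mode for inference
model.eval()
assert not model.training

# 5. Persistence
torch.save(model.state_dict(), 'model.pt')  # illustrative path
state = torch.load('model.pt', map_location=device)
model.load_state_dict(state)
```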
Interactive: Module Lifecycle
Walk through the complete lifecycle of an nn.Module from initialization to deployment. Understand what happens at each stage.
Training vs Evaluation Mode
Some layers behave differently during training and inference. The .train() and .eval() methods control this behavior.
| Layer | Training Mode (.train()) | Eval Mode (.eval()) |
|---|---|---|
| nn.Dropout | Randomly drops neurons | Identity (no dropout) |
| nn.BatchNorm | Uses batch statistics, updates running stats | Uses frozen running statistics |
| nn.InstanceNorm | Uses per-instance statistics | Same as training (unless track_running_stats=True) |
| Custom layers | self.training is True | self.training is False |
```python
# During training
model.train()  # Sets self.training = True for all modules
for inputs, targets in train_loader:
    optimizer.zero_grad()
    output = model(inputs)
    loss = criterion(output, targets)
    loss.backward()
    optimizer.step()

# During evaluation
model.eval()  # Sets self.training = False for all modules
with torch.no_grad():  # Also disable gradient computation
    for inputs, targets in test_loader:
        output = model(inputs)
        # Compute metrics...

# Custom behavior based on mode
class MyLayer(nn.Module):
    def forward(self, x):
        if self.training:
            # Training-specific behavior
            x = x + torch.randn_like(x) * 0.1  # Add noise
        return x
```

Don't Confuse train()/eval() with no_grad()
- .train()/.eval(): Affect layer behavior (Dropout, BatchNorm)
- torch.no_grad(): Disables gradient computation (saves memory)
For inference, use both: model.eval() AND with torch.no_grad():
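Dropout makes the mode distinction concrete; a small sketch:

```python
import torch
import torch.nn as nn

drop = nn.Dropout(p=0.5)
x = torch.ones(1000)

drop.train()
y_train = drop(x)  # roughly half the entries zeroed, survivors scaled by 1/(1-p) = 2

drop.eval()
y_eval = drop(x)   # identity: input passes through unchanged

print((y_train == 0).float().mean())  # close to 0.5
print(torch.equal(y_eval, x))         # True
```

Note that torch.no_grad() would change neither output here; it only stops autograd from recording the operations.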
Saving and Loading Models
PyTorch provides two approaches to saving models:
Approach 1: Save State Dict (Recommended)
1# Save only the learned parameters
2torch.save(model.state_dict(), 'model_weights.pt')
3
4# Load into a new model instance
5model = MyNetwork() # Create architecture first
6model.load_state_dict(torch.load('model_weights.pt'))
7model.eval() # Set to eval mode for inference
8
9# With device mapping
10model.load_state_dict(
11 torch.load('model_weights.pt', map_location='cuda:0')
12)Approach 2: Save Entire Model (Less Flexible)
```python
# Save entire model (architecture + weights)
torch.save(model, 'full_model.pt')

# Load (requires original class definition to be importable)
model = torch.load('full_model.pt')
model.eval()

# Warning: This pickles the class, so:
# - Code changes break compatibility
# - Path to class definition must be valid
```

Best Practice: State Dict + Architecture
Prefer saving the state_dict together with whatever you need to rebuild the architecture (the class definition and its constructor arguments). This keeps checkpoints portable across code refactors and environments.
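One way to combine state dict and architecture is to store the constructor arguments next to the weights (the checkpoint keys and file name here are illustrative, not a PyTorch convention):

```python
import torch
import torch.nn as nn

class MLP(nn.Module):
    def __init__(self, in_dim, hidden, out_dim):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(in_dim, hidden), nn.ReLU(),
                                 nn.Linear(hidden, out_dim))

    def forward(self, x):
        return self.net(x)

model = MLP(784, 256, 10)

# Save weights plus the config needed to rebuild the architecture
torch.save({
    'config': {'in_dim': 784, 'hidden': 256, 'out_dim': 10},
    'state_dict': model.state_dict(),
}, 'mlp.pt')

# Rebuild from the checkpoint alone
ckpt = torch.load('mlp.pt')
restored = MLP(**ckpt['config'])
restored.load_state_dict(ckpt['state_dict'])
```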
Checkpointing During Training
```python
# Save everything needed to resume training
checkpoint = {
    'epoch': epoch,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'loss': loss,
    'best_accuracy': best_acc,
}
torch.save(checkpoint, f'checkpoint_epoch_{epoch}.pt')

# Resume training
checkpoint = torch.load('checkpoint_epoch_10.pt')
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
start_epoch = checkpoint['epoch'] + 1
```

Common Patterns
Pattern 1: Encoder-Decoder
```python
class AutoEncoder(nn.Module):
    def __init__(self, input_dim, latent_dim):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, 256),
            nn.ReLU(),
            nn.Linear(256, latent_dim),
        )
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, 256),
            nn.ReLU(),
            nn.Linear(256, input_dim),
        )

    def forward(self, x):
        z = self.encoder(x)
        return self.decoder(z)

    def encode(self, x):
        return self.encoder(x)
```

Pattern 2: ModuleList for Dynamic Architectures
```python
class DynamicMLP(nn.Module):
    def __init__(self, layer_sizes):
        super().__init__()
        # ModuleList properly registers all modules
        self.layers = nn.ModuleList([
            nn.Linear(layer_sizes[i], layer_sizes[i+1])
            for i in range(len(layer_sizes) - 1)
        ])

    def forward(self, x):
        for layer in self.layers[:-1]:
            x = torch.relu(layer(x))
        return self.layers[-1](x)  # No activation on last layer

# Flexible architecture
model = DynamicMLP([784, 512, 256, 128, 10])
```

Pattern 3: ModuleDict for Named Components
```python
class MultiTaskNetwork(nn.Module):
    def __init__(self, shared_dim, task_configs):
        super().__init__()
        self.shared = nn.Linear(784, shared_dim)

        # ModuleDict for named task heads
        self.heads = nn.ModuleDict({
            name: nn.Linear(shared_dim, out_dim)
            for name, out_dim in task_configs.items()
        })

    def forward(self, x, task_name):
        shared = torch.relu(self.shared(x))
        return self.heads[task_name](shared)

model = MultiTaskNetwork(256, {'classify': 10, 'regress': 1})
```

Quick Check
Why must you use nn.ModuleList instead of a regular Python list for storing layers?
Common Pitfalls
Pitfall 1: Forgetting super().__init__()
```python
class BrokenModule(nn.Module):
    def __init__(self):
        # BUG: Forgot super().__init__()
        self.fc = nn.Linear(10, 5)  # This will fail!

# AttributeError: cannot assign module before Module.__init__() call
```

Pitfall 2: Using Regular List/Dict for Modules
```python
class BrokenNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        # BUG: Regular list doesn't register modules!
        self.layers = [nn.Linear(10, 10) for _ in range(5)]

model = BrokenNetwork()
print(list(model.parameters()))  # Empty! Layers are invisible.

# FIX: Use nn.ModuleList
self.layers = nn.ModuleList([nn.Linear(10, 10) for _ in range(5)])
```

Pitfall 3: Calling forward() Directly
```python
# DON'T do this
output = model.forward(x)  # Bypasses hooks!

# DO this
output = model(x)  # Correct - uses __call__
```

Pitfall 4: Forgetting eval() for Inference
```python
# BUG: Model still in training mode during inference
model.load_state_dict(torch.load('model.pt'))
predictions = model(test_data)  # Dropout still active!

# FIX: Always call eval() for inference
model.load_state_dict(torch.load('model.pt'))
model.eval()
with torch.no_grad():
    predictions = model(test_data)
```

Pitfall 5: Device Mismatch
```python
# BUG: Model on GPU, input on CPU
model = model.cuda()
x = torch.randn(32, 784)  # CPU tensor
output = model(x)  # RuntimeError: tensors on different devices!

# FIX: Move inputs to same device as model
device = next(model.parameters()).device
x = x.to(device)
output = model(x)  # Works!
```

Knowledge Check
Test your understanding of nn.Module with this comprehensive quiz covering all the key concepts.
nn.Module Knowledge Check
What is the correct way to register a learnable parameter in an nn.Module?
Summary
In this section, we learned about nn.Module, the foundation of all neural networks in PyTorch:
| Concept | Key Points |
|---|---|
| nn.Module | Base class for all neural network components in PyTorch |
| __init__() | Define architecture, ALWAYS call super().__init__() first |
| forward() | Define computation, called via model(x) not model.forward(x) |
| nn.Parameter | Wrapper for learnable tensors, visible to parameters() |
| register_buffer() | For non-trainable tensors that should be saved/moved |
| Submodules | nn.Module attributes are auto-registered as children |
| ModuleList/Dict | Use instead of Python list/dict for proper registration |
| train()/eval() | Toggle layer behavior (Dropout, BatchNorm) |
| to(device) | Move all parameters and buffers to device |
| state_dict() | Serialize parameters and buffers for saving/loading |
Key Takeaways
- nn.Module is the foundation: Every neural network layer, block, and model is an nn.Module.
- Proper registration is critical: Use nn.Parameter for weights, register_buffer() for non-trainable state, and nn.ModuleList/Dict for collections.
- Call model(x), not model.forward(x): The __call__ method handles hooks and other infrastructure.
- Manage state carefully: Remember train()/eval() for layers like Dropout and BatchNorm, and always match devices.
- Save state_dict(), not the model: This makes your checkpoints portable and maintainable.
Exercises
Conceptual Questions
- Explain the difference between nn.Parameter and register_buffer(). When would you use each?
- Why does PyTorch automatically register nn.Module attributes as submodules, but not regular tensors as parameters?
- Describe a scenario where forgetting to call model.eval() would cause incorrect predictions even though the model was trained correctly.
Coding Exercises
- Custom Layer: Implement a LayerNorm layer from scratch using nn.Parameter for gamma and beta. The forward should normalize across the last dimension: y = gamma * (x - mean) / sqrt(var + eps) + beta.
- Residual Block: Create a ResidualBlock module that wraps any layer and adds a skip connection: output = layer(x) + x. Handle the case where input and output dimensions don't match.
- Model Inspector: Write a function that takes any nn.Module and prints: total parameters, trainable parameters, number of layers, and memory usage in MB.
- Freeze/Unfreeze: Implement a function that freezes all layers except those whose names match a given pattern (e.g., freeze everything except classifier.*).
Challenge Exercise
Build a Mini nn.Module: Implement a simplified version of nn.Module from scratch that supports:
- Parameter registration via descriptors or __setattr__
- Submodule registration
- Recursive parameters() and named_parameters()
- train() and eval() mode switching
- to() for device placement
- state_dict() and load_state_dict()
```python
class MiniModule:
    """Simplified nn.Module implementation."""

    def __init__(self):
        self._parameters = {}
        self._modules = {}
        self._buffers = {}
        self.training = True

    def __setattr__(self, name, value):
        # TODO: Detect nn.Parameter and Module instances
        # Register them appropriately
        # (hint: use object.__setattr__ for the internal dicts above,
        #  or __init__ will loop through this stub and set nothing)
        ...

    def parameters(self):
        # TODO: Yield all parameters recursively
        ...

    def to(self, device):
        # TODO: Move all parameters and buffers
        ...

    def state_dict(self):
        # TODO: Return dict of all parameters and buffers
        ...

# Test your implementation
class LinearLayer(MiniModule):
    def __init__(self, in_features, out_features):
        super().__init__()
        self.weight = Parameter(torch.randn(out_features, in_features))
        self.bias = Parameter(torch.zeros(out_features))

    def forward(self, x):
        return x @ self.weight.T + self.bias
```

In the next section, we'll explore Linear Layers - the most fundamental building block of neural networks. We'll understand the mathematics of linear transformations and see how nn.Linear implements them efficiently.