Chapter 5

The nn.Module Class

Neural Network Building Blocks

Learning Objectives

By the end of this section, you will be able to:

  1. Understand nn.Module as the foundational building block for neural networks in PyTorch
  2. Create custom modules by subclassing nn.Module with proper initialization
  3. Register parameters correctly using nn.Parameter and understand why plain tensors fail
  4. Implement the forward() method to define data flow through your network
  5. Navigate module hierarchies using children(), modules(), and named_parameters()
  6. Manage module state with train(), eval(), to(), and state_dict()
  7. Save and load models properly for deployment and checkpointing
Why This Matters: Every neural network in PyTorch - from a simple linear classifier to GPT-4 - is built using nn.Module. Understanding this class is essential for building, debugging, and deploying any deep learning model. It provides the infrastructure for parameter management, device placement, serialization, and the training/evaluation lifecycle.

The Big Picture

In the previous chapter, we learned about tensors, autograd, and how PyTorch automatically computes gradients. But managing hundreds of tensors manually would be chaos. We need a structured way to:

  • Organize learnable parameters (weights and biases)
  • Define how data flows through operations
  • Move entire models between CPU and GPU
  • Save and load trained models
  • Switch between training and inference modes

The nn.Module class solves all of these problems. It's PyTorch's answer to the question: “How do we organize neural network components in a maintainable way?”

Historical Context

The design of nn.Module was influenced by earlier frameworks like Torch (Lua) and Theano, but with a key innovation: dynamic computation graphs. Unlike TensorFlow 1.x where you defined a static graph before execution, PyTorch builds the graph on-the-fly. This means nn.Module is just a regular Python class - you can use Python control flow (if, for, while) directly in your forward pass.

This Pythonic design philosophy made PyTorch the dominant framework for research, and nn.Module is at the heart of it all.


What is nn.Module?

torch.nn.Module is the base class for all neural network modules in PyTorch. Think of it as a container that:

  1. Holds learnable parameters: Weights and biases that get updated during training
  2. Defines computation: The forward() method specifies how inputs become outputs
  3. Organizes submodules: Complex networks are built from simpler modules in a tree structure
  4. Manages state: Training mode, device placement, and serialization

The Module Contract

When you create a class that inherits from nn.Module, you agree to follow certain conventions:

Convention | What It Means | Why It Matters
Call super().__init__() | Initializes the parent class | Enables parameter tracking and hooks
Define forward() | Specifies the computation | Called when you invoke module(input)
Use nn.Parameter for weights | Wraps learnable tensors | Makes them visible to optimizers
Assign submodules to self | Registers child modules | Enables recursive operations

Creating Your First Module

Let's create a simple neural network module step by step:

Creating a Custom nn.Module
🐍simple_classifier.py
super().__init__() - Critical First Step

This initializes the internal machinery of nn.Module: parameter tracking, hooks, module registry, and more. Forgetting this breaks everything!

Submodule Registration

When you assign an nn.Module to self, PyTorch automatically registers it as a submodule. Its parameters become part of your module's parameters().

The forward() Method

This method defines the computation. It takes input tensors and returns output tensors. PyTorch builds the computational graph as this executes.

Calling the Module

Use model(x), not model.forward(x). The __call__ method handles hooks and then calls forward(). Direct forward() calls bypass important functionality.
import torch
import torch.nn as nn

class SimpleClassifier(nn.Module):
    """A simple two-layer neural network classifier."""

    def __init__(self, input_dim, hidden_dim, num_classes):
        # Step 1: ALWAYS call parent __init__ first
        super().__init__()

        # Step 2: Define layers as attributes
        # These are automatically registered as submodules
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_dim, num_classes)

    def forward(self, x):
        """Define the forward pass."""
        # Step 3: Specify how data flows through layers
        x = self.fc1(x)      # Linear transformation
        x = self.relu(x)     # Non-linear activation
        x = self.fc2(x)      # Output layer
        return x

# Create an instance
model = SimpleClassifier(input_dim=784, hidden_dim=256, num_classes=10)

# Use it! This calls forward() internally
x = torch.randn(32, 784)  # Batch of 32 samples
output = model(x)         # Shape: [32, 10]
print(output.shape)

Always Call super().__init__()

Forgetting super().__init__() is one of the most common bugs. Without it:
  • Parameters won't be tracked
  • state_dict() will be empty
  • to(device) won't move parameters
  • Hooks won't work
PyTorch will raise errors when you try to use the module.

Parameter Registration

Understanding how PyTorch tracks learnable parameters is crucial. There are three ways to store tensors in a module, and only one is correct for learnable weights:

Method 1: nn.Parameter (Correct for Learnable Weights)

🐍parameter_registration.py
class MyLayer(nn.Module):
    def __init__(self, in_features, out_features):
        super().__init__()
        # Correct: Use nn.Parameter for learnable weights
        self.weight = nn.Parameter(torch.randn(out_features, in_features))
        self.bias = nn.Parameter(torch.zeros(out_features))

    def forward(self, x):
        return x @ self.weight.T + self.bias

# These are now tracked!
layer = MyLayer(10, 5)
print(list(layer.parameters()))  # Shows weight and bias

nn.Parameter is a special tensor wrapper that:

  • Sets requires_grad=True by default
  • Registers the tensor with the module's parameter tracking
  • Makes the tensor appear in parameters() and state_dict()
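These three properties are easy to verify directly. A quick sketch (TinyLayer is a made-up module for illustration):

```python
import torch
import torch.nn as nn

class TinyLayer(nn.Module):
    def __init__(self):
        super().__init__()
        # nn.Parameter: requires_grad defaults to True
        self.weight = nn.Parameter(torch.randn(3, 3))
        # A plain tensor attribute, for contrast: not registered
        self.note = torch.zeros(3)

layer = TinyLayer()
print(layer.weight.requires_grad)                # True
print([n for n, _ in layer.named_parameters()])  # ['weight']
print(list(layer.state_dict().keys()))           # ['weight']
```

Note that the plain-tensor attribute `note` appears in neither list, which previews the bug discussed in Method 3 below.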

Method 2: register_buffer (For Non-Trainable State)

🐍register_buffer.py
class BatchNormLike(nn.Module):
    def __init__(self, num_features):
        super().__init__()
        # Trainable parameters
        self.gamma = nn.Parameter(torch.ones(num_features))
        self.beta = nn.Parameter(torch.zeros(num_features))

        # Non-trainable but should be saved/loaded
        self.register_buffer('running_mean', torch.zeros(num_features))
        self.register_buffer('running_var', torch.ones(num_features))

    def forward(self, x):
        # Use running stats in eval mode
        ...

Use register_buffer() for tensors that:

  • Should be saved/loaded with the model (in state_dict)
  • Should move with the model (via .to(device))
  • Should NOT be trained (not in parameters())

Common Buffer Examples

Buffers are used for: running statistics (BatchNorm), positional encodings (Transformer), attention masks, frozen embeddings, and any constant tensors that need to move with the model.
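The parameter/buffer split is visible in a few lines. A small sketch (WithBuffer is a hypothetical module):

```python
import torch
import torch.nn as nn

class WithBuffer(nn.Module):
    def __init__(self):
        super().__init__()
        self.scale = nn.Parameter(torch.ones(4))              # trainable
        self.register_buffer('running_mean', torch.zeros(4))  # saved, not trained

m = WithBuffer()
print([n for n, _ in m.named_parameters()])  # ['scale'] - buffer excluded
print(sorted(m.state_dict().keys()))         # ['running_mean', 'scale'] - both saved
print([n for n, _ in m.named_buffers()])     # ['running_mean']
```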

Method 3: Plain Tensor (Bug!)

🐍plain_tensor_bug.py
class BuggyLayer(nn.Module):
    def __init__(self, in_features, out_features):
        super().__init__()
        # BUG: Plain tensor, not registered!
        self.weight = torch.randn(out_features, in_features, requires_grad=True)

# This weight is INVISIBLE to PyTorch!
layer = BuggyLayer(10, 5)
print(list(layer.parameters()))  # Empty list!
# Optimizer won't update this weight!

Plain Tensors Are Not Tracked

If you use a plain tensor instead of nn.Parameter, the tensor:
  • Won't appear in model.parameters()
  • Won't be updated by the optimizer
  • Won't be saved in state_dict()
  • Won't move with .to(device)
Your model will appear to train but learn nothing!

Key Insight: Only tensors wrapped in nn.Parameter or submodules (nn.Module) assigned as attributes are properly registered. Use register_buffer() for non-trainable state.

The forward() Method

The forward() method is where you define the computation your module performs. It takes input tensors and returns output tensors, building the computational graph as it executes.

Key Principles

  1. Pure function-like: Given the same inputs and module state, produce the same outputs
  2. Use tensor operations: All operations should be differentiable (unless intentionally not)
  3. Support batching: Design for batch dimension (usually first dimension)
  4. Handle variable inputs: Can accept multiple tensors, return multiple tensors
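Principle 4 builds on PyTorch's dynamic graphs: because the graph is rebuilt on every call, ordinary Python control flow inside forward() just works. A sketch (IterativeRefiner is a made-up module for illustration):

```python
import torch
import torch.nn as nn

class IterativeRefiner(nn.Module):
    """Hypothetical module: applies the same layer a variable number
    of times. Plain Python loops and ifs are fine inside forward()
    because the graph is rebuilt on every call."""

    def __init__(self, dim, num_steps):
        super().__init__()
        self.step = nn.Linear(dim, dim)
        self.num_steps = num_steps

    def forward(self, x, extra_steps=0):
        for _ in range(self.num_steps + extra_steps):  # ordinary loop
            x = torch.relu(self.step(x))
        if x.shape[0] > 1:             # ordinary if on runtime shapes
            x = x - x.mean(dim=0)      # center across the batch
        return x

model = IterativeRefiner(dim=8, num_steps=2)
out = model(torch.randn(4, 8), extra_steps=1)  # three refinement steps
print(out.shape)  # torch.Size([4, 8])
```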
Forward with Multiple Inputs
🐍multi_input.py
Multiple Input Parameters

forward() can accept any number of positional or keyword arguments. Design the signature to match your use case.

Python Control Flow Works

You can use if statements, loops, and any Python logic. The graph is built dynamically as code executes.

Calling with Multiple Inputs

Pass multiple tensors matching the forward() signature. PyTorch handles gradient tracking for all.
class MultiInputNetwork(nn.Module):
    def __init__(self, text_dim, image_dim, hidden_dim, num_classes):
        super().__init__()
        self.text_encoder = nn.Linear(text_dim, hidden_dim)
        self.image_encoder = nn.Linear(image_dim, hidden_dim)
        self.classifier = nn.Linear(hidden_dim * 2, num_classes)

    def forward(self, text_features, image_features):
        """Forward can accept multiple inputs."""
        # Process each modality
        text_hidden = torch.relu(self.text_encoder(text_features))
        image_hidden = torch.relu(self.image_encoder(image_features))

        # Combine features
        combined = torch.cat([text_hidden, image_hidden], dim=-1)

        # Classify
        logits = self.classifier(combined)

        return logits

# Using the model
model = MultiInputNetwork(768, 2048, 256, 100)
text = torch.randn(32, 768)
image = torch.randn(32, 2048)
output = model(text, image)  # Both inputs passed

Why model(x) Not model.forward(x)?

When you call model(x), PyTorch invokes __call__ which:

  1. Executes registered forward pre-hooks
  2. Calls your forward() method
  3. Executes registered forward hooks
  4. Handles any errors with better messages
🐍hooks_example.py
# Hooks only work when using model(x)
def print_input_shape(module, input):
    print(f"Input shape: {input[0].shape}")

model.register_forward_pre_hook(print_input_shape)

# Using model(x) - hook is called
output = model(x)  # Prints: Input shape: [32, 784]

# Using model.forward(x) - hook is SKIPPED!
output = model.forward(x)  # No print - bypassed the hook!

Key Point: The forward() method defines how data flows through your module. When you call model(x), PyTorch invokes forward(x) through the __call__ method, which also handles hooks.

Module Hierarchy

Complex neural networks are built as trees of modules. When you assign an nn.Module as an attribute, it becomes a child module, and its parameters become part of the parent.

Building Module Hierarchies
🐍module_hierarchy.py
Nested Modules

ResNet contains ResidualBlocks, which contain Conv2d and BatchNorm. This creates a tree structure.

nn.Sequential for Lists

nn.Sequential is a container that runs modules in order. It's perfect for stacking layers.

named_children()

Iterates over direct children only (one level). Use named_modules() for recursive iteration.

parameters() is Recursive

parameters() returns ALL parameters from ALL nested submodules. That's why one call gives you everything.
class ResidualBlock(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(channels)
        self.conv2 = nn.Conv2d(channels, channels, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(channels)

    def forward(self, x):
        residual = x
        out = torch.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        return torch.relu(out + residual)

class ResNet(nn.Module):
    def __init__(self, num_blocks, num_classes):
        super().__init__()
        self.stem = nn.Conv2d(3, 64, 7, stride=2, padding=3)

        # Create multiple residual blocks
        self.blocks = nn.Sequential(
            *[ResidualBlock(64) for _ in range(num_blocks)]
        )

        self.classifier = nn.Linear(64, num_classes)

    def forward(self, x):
        x = self.stem(x)
        x = self.blocks(x)
        x = x.mean(dim=[2, 3])  # Global average pooling
        return self.classifier(x)

# Explore the hierarchy
model = ResNet(num_blocks=3, num_classes=10)

# All child modules
for name, module in model.named_children():
    print(f"{name}: {type(module).__name__}")

# All parameters (recursive)
total_params = sum(p.numel() for p in model.parameters())
print(f"Total parameters: {total_params:,}")

Method | Returns | Recursive? | Use Case
children() | Direct child modules | No | Iterate immediate submodules
modules() | All modules (including self) | Yes | Apply operation to every module
named_children() | Pairs of (name, child) | No | Access children by name
named_modules() | Pairs of (name, module) | Yes | Full module tree with paths
parameters() | All parameters | Yes | Pass to optimizer
named_parameters() | Pairs of (name, param) | Yes | Inspect or modify specific params
buffers() | All buffers | Yes | Access non-trainable state
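The recursive/non-recursive distinction is easy to check on a small nested model; a sketch:

```python
import torch.nn as nn

model = nn.Sequential(
    nn.Linear(4, 8),
    nn.Sequential(nn.ReLU(), nn.Linear(8, 2)),
)

# children(): one level deep - the outer Linear and the inner Sequential
print(len(list(model.children())))  # 2
# modules(): recursive, and includes the model itself
print(len(list(model.modules())))   # 5
# named_modules(): dotted paths into the tree
print([name for name, _ in model.named_modules()])
# ['', '0', '1', '1.0', '1.1']
```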


Essential Module Methods

Accessing Parameters

🐍accessing_params.py
model = MyNetwork()

# All parameters (for optimizer)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Count parameters
total = sum(p.numel() for p in model.parameters())
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Total: {total:,}, Trainable: {trainable:,}")

# Access specific parameters by name
for name, param in model.named_parameters():
    if 'bias' in name:
        print(f"{name}: {param.shape}")

# Freeze specific layers
for param in model.encoder.parameters():
    param.requires_grad = False

Device Management

🐍device_management.py
# Move entire model to GPU
model = model.to('cuda')

# Or use .cuda() / .cpu() shortcuts
model = model.cuda()  # Same as .to('cuda')
model = model.cpu()   # Same as .to('cpu')

# Chaining works (returns self)
model = MyNetwork().to('cuda')

# Check device of parameters
device = next(model.parameters()).device
print(f"Model is on: {device}")

# Mixed precision
model = model.to(dtype=torch.float16)
model = model.half()  # Shortcut for float16

# Move to specific GPU
model = model.to('cuda:1')  # Second GPU

Applying Functions to All Modules

🐍apply_function.py
# Custom weight initialization
def init_weights(module):
    if isinstance(module, nn.Linear):
        nn.init.xavier_uniform_(module.weight)
        if module.bias is not None:
            nn.init.zeros_(module.bias)
    elif isinstance(module, nn.Conv2d):
        nn.init.kaiming_normal_(module.weight, mode='fan_out')

# Apply to all modules recursively
model.apply(init_weights)

# Check all module types
for name, module in model.named_modules():
    print(f"{name}: {type(module).__name__}")

Module Lifecycle

Understanding the lifecycle of an nn.Module helps you use it correctly throughout training and deployment.

  1. Initialization: Define architecture in __init__(), register parameters and submodules
  2. Device placement: Move to GPU/CPU with .to() before training
  3. Training mode: Call .train() to enable dropout, use batch stats for BatchNorm
  4. Evaluation mode: Call .eval() for inference, use running stats for BatchNorm
  5. Persistence: Save/load with state_dict() and load_state_dict()
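The five stages above can be walked end-to-end in a few lines (TinyNet is a hypothetical stand-in for your model):

```python
import torch
import torch.nn as nn

# 1. Initialization: architecture defined, parameters registered
class TinyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)
        self.drop = nn.Dropout(p=0.5)

    def forward(self, x):
        return self.fc(self.drop(x))

model = TinyNet()
print(model.training)  # True - modules start in training mode

# 2. Device placement (falls back to CPU when no GPU is present)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = model.to(device)

# 3./4. Mode switching propagates to every submodule
model.eval()
print(model.drop.training)  # False

# 5. Persistence: round-trip through state_dict
restored = TinyNet()
restored.load_state_dict(model.state_dict())
print(torch.equal(restored.fc.weight, model.fc.weight.cpu()))  # True
```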

Remember: The lifecycle methods work together. .train() and .eval() affect behavior, .to() moves data, and state_dict() handles persistence.

Training vs Evaluation Mode

Some layers behave differently during training and inference. The .train() and .eval() methods control this behavior.

Layer | Training Mode (.train()) | Eval Mode (.eval())
nn.Dropout | Randomly drops neurons | Identity (no dropout)
nn.BatchNorm | Uses batch statistics, updates running stats | Uses frozen running statistics
nn.InstanceNorm | Uses instance statistics | Same behavior (unless track_running_stats=True)
Custom layers | self.training is True | self.training is False
🐍train_eval_modes.py
# During training
model.train()  # Sets self.training = True for all modules
for inputs, targets in train_loader:
    optimizer.zero_grad()
    output = model(inputs)
    loss = criterion(output, targets)  # Compute the loss before backward
    loss.backward()
    optimizer.step()

# During evaluation
model.eval()  # Sets self.training = False for all modules
with torch.no_grad():  # Also disable gradient computation
    for inputs, targets in test_loader:
        output = model(inputs)
        # Compute metrics...

# Custom behavior based on mode
class MyLayer(nn.Module):
    def forward(self, x):
        if self.training:
            # Training-specific behavior
            x = x + torch.randn_like(x) * 0.1  # Add noise
        return x

Don't Confuse train()/eval() with no_grad()

  • .train() / .eval(): Affect layer behavior (Dropout, BatchNorm)
  • torch.no_grad(): Disables gradient computation (saves memory)
For inference, you typically want BOTH: model.eval() AND with torch.no_grad():
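The two flags are independent, which a short experiment makes concrete (a sketch using a bare nn.Dropout):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)  # make the dropout mask reproducible
drop = nn.Dropout(p=0.5)
x = torch.ones(1000, requires_grad=True)

# eval() makes Dropout an identity, but autograd still tracks the op
drop.eval()
out = drop(x)
print(out.requires_grad)    # True - eval() alone doesn't stop autograd
print(torch.equal(out, x))  # True - no units dropped

# no_grad() stops tracking, but the layer still behaves per its mode
drop.train()
with torch.no_grad():
    out = drop(x)
    print(out.requires_grad)      # False - no graph built
    print(bool((out == 0).any())) # True - dropout still active
```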

Deep Dive: Layer-Specific Behavior

We'll explore the train/eval mode differences in detail for specific layers: BatchNorm and other normalization layers in Section 5.5, and Dropout and regularization techniques in Section 5.6.

Saving and Loading Models

PyTorch provides two approaches to saving models:

Approach 1: Save the State Dict (Recommended)

🐍save_state_dict.py
# Save only the learned parameters
torch.save(model.state_dict(), 'model_weights.pt')

# Load into a new model instance
model = MyNetwork()  # Create architecture first
model.load_state_dict(torch.load('model_weights.pt'))
model.eval()  # Set to eval mode for inference

# With device mapping
model.load_state_dict(
    torch.load('model_weights.pt', map_location='cuda:0')
)

Approach 2: Save Entire Model (Less Flexible)

🐍save_entire_model.py
# Save entire model (architecture + weights)
torch.save(model, 'full_model.pt')

# Load (requires original class definition to be importable;
# recent PyTorch versions default torch.load to weights_only=True,
# so unpickling a full model needs weights_only=False)
model = torch.load('full_model.pt', weights_only=False)
model.eval()

# Warning: This pickles the class, so:
# - Code changes break compatibility
# - Path to class definition must be valid

Best Practice: State Dict + Architecture

Always save state_dict() rather than the entire model. Save architecture separately (as code or config file). This makes your checkpoints portable and version-independent.
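One way to follow this practice is a small config file alongside the weights; a sketch (the build_model helper and config keys are made up for illustration):

```python
import json
import os
import tempfile
import torch
import torch.nn as nn

def build_model(cfg):
    # Architecture lives in code; only hyperparameters go in the config
    return nn.Sequential(
        nn.Linear(cfg['input_dim'], cfg['hidden_dim']),
        nn.ReLU(),
        nn.Linear(cfg['hidden_dim'], cfg['num_classes']),
    )

config = {'input_dim': 784, 'hidden_dim': 256, 'num_classes': 10}
model = build_model(config)

ckpt_dir = tempfile.mkdtemp()
torch.save(model.state_dict(), os.path.join(ckpt_dir, 'weights.pt'))
with open(os.path.join(ckpt_dir, 'config.json'), 'w') as f:
    json.dump(config, f)

# Later (or on another machine): rebuild from config, then load weights
with open(os.path.join(ckpt_dir, 'config.json')) as f:
    cfg = json.load(f)
restored = build_model(cfg)
restored.load_state_dict(torch.load(os.path.join(ckpt_dir, 'weights.pt')))
print(torch.equal(restored[0].weight, model[0].weight))  # True
```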

Checkpointing During Training

🐍checkpointing.py
# Save everything needed to resume training
checkpoint = {
    'epoch': epoch,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'loss': loss,
    'best_accuracy': best_acc,
}
torch.save(checkpoint, f'checkpoint_epoch_{epoch}.pt')

# Resume training
checkpoint = torch.load('checkpoint_epoch_10.pt')
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
start_epoch = checkpoint['epoch'] + 1

Common Patterns

Pattern 1: Encoder-Decoder

🐍encoder_decoder.py
class AutoEncoder(nn.Module):
    def __init__(self, input_dim, latent_dim):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, 256),
            nn.ReLU(),
            nn.Linear(256, latent_dim),
        )
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, 256),
            nn.ReLU(),
            nn.Linear(256, input_dim),
        )

    def forward(self, x):
        z = self.encoder(x)
        return self.decoder(z)

    def encode(self, x):
        return self.encoder(x)

Pattern 2: ModuleList for Dynamic Architectures

🐍modulelist.py
class DynamicMLP(nn.Module):
    def __init__(self, layer_sizes):
        super().__init__()
        # ModuleList properly registers all modules
        self.layers = nn.ModuleList([
            nn.Linear(layer_sizes[i], layer_sizes[i+1])
            for i in range(len(layer_sizes) - 1)
        ])

    def forward(self, x):
        for layer in self.layers[:-1]:
            x = torch.relu(layer(x))
        return self.layers[-1](x)  # No activation on last layer

# Flexible architecture
model = DynamicMLP([784, 512, 256, 128, 10])

Pattern 3: ModuleDict for Named Components

🐍moduledict.py
class MultiTaskNetwork(nn.Module):
    def __init__(self, shared_dim, task_configs):
        super().__init__()
        self.shared = nn.Linear(784, shared_dim)

        # ModuleDict for named task heads
        self.heads = nn.ModuleDict({
            name: nn.Linear(shared_dim, out_dim)
            for name, out_dim in task_configs.items()
        })

    def forward(self, x, task_name):
        shared = torch.relu(self.shared(x))
        return self.heads[task_name](shared)

model = MultiTaskNetwork(256, {'classify': 10, 'regress': 1})



Common Pitfalls

Pitfall 1: Forgetting super().__init__()

🐍pitfall_super.py
class BrokenModule(nn.Module):
    def __init__(self):
        # BUG: Forgot super().__init__()
        self.fc = nn.Linear(10, 5)  # This will fail!

# AttributeError: cannot assign module before Module.__init__() call

Pitfall 2: Using Regular List/Dict for Modules

🐍pitfall_list.py
class BrokenNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        # BUG: Regular list doesn't register modules!
        self.layers = [nn.Linear(10, 10) for _ in range(5)]

model = BrokenNetwork()
print(list(model.parameters()))  # Empty! Layers are invisible.

# FIX (inside __init__): Use nn.ModuleList
self.layers = nn.ModuleList([nn.Linear(10, 10) for _ in range(5)])

Pitfall 3: Calling forward() Directly

🐍pitfall_forward.py
# DON'T do this
output = model.forward(x)  # Bypasses hooks!

# DO this
output = model(x)  # Correct - uses __call__

Pitfall 4: Forgetting eval() for Inference

🐍pitfall_eval.py
# BUG: Model still in training mode during inference
model.load_state_dict(torch.load('model.pt'))
predictions = model(test_data)  # Dropout still active!

# FIX: Always call eval() for inference
model.load_state_dict(torch.load('model.pt'))
model.eval()
with torch.no_grad():
    predictions = model(test_data)

Pitfall 5: Device Mismatch

🐍pitfall_device.py
# BUG: Model on GPU, input on CPU
model = model.cuda()
x = torch.randn(32, 784)  # CPU tensor
output = model(x)  # RuntimeError: tensors on different devices!

# FIX: Move inputs to same device as model
device = next(model.parameters()).device
x = x.to(device)
output = model(x)  # Works!

Summary

In this section, we learned about nn.Module, the foundation of all neural networks in PyTorch:

Concept | Key Points
nn.Module | Base class for all neural network components in PyTorch
__init__() | Define architecture, ALWAYS call super().__init__() first
forward() | Define computation, called via model(x) not model.forward(x)
nn.Parameter | Wrapper for learnable tensors, visible to parameters()
register_buffer() | For non-trainable tensors that should be saved/moved
Submodules | nn.Module attributes are auto-registered as children
ModuleList/Dict | Use instead of Python list/dict for proper registration
train()/eval() | Toggle layer behavior (Dropout, BatchNorm)
to(device) | Move all parameters and buffers to device
state_dict() | Serialize parameters and buffers for saving/loading

Key Takeaways

  1. nn.Module is the foundation: Every neural network layer, block, and model is an nn.Module.
  2. Proper registration is critical: Use nn.Parameter for weights, register_buffer() for non-trainable state, and nn.ModuleList/Dict for collections.
  3. Call model(x), not model.forward(x): The __call__ method handles hooks and other infrastructure.
  4. Manage state carefully: Remember train()/eval() for layers like Dropout and BatchNorm, and always match devices.
  5. Save state_dict(), not the model: This makes your checkpoints portable and maintainable.

Exercises

Conceptual Questions

  1. Explain the difference between nn.Parameter and register_buffer(). When would you use each?
  2. Why does PyTorch automatically register nn.Module attributes as submodules, but not regular tensors as parameters?
  3. Describe a scenario where forgetting to call model.eval() would cause incorrect predictions even though the model was trained correctly.

Coding Exercises

  1. Custom Layer: Implement a LayerNorm layer from scratch using nn.Parameter for gamma and beta. The forward should normalize across the last dimension:
    \text{LayerNorm}(x) = \gamma \cdot \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta
  2. Residual Block: Create a ResidualBlock module that wraps any layer and adds a skip connection: \text{output} = \text{layer}(x) + x. Handle the case where input and output dimensions don't match.
  3. Model Inspector: Write a function that takes any nn.Module and prints: total parameters, trainable parameters, number of layers, and memory usage in MB.
  4. Freeze/Unfreeze: Implement a function that freezes all layers except those whose names match a given pattern (e.g., freeze everything except classifier.*).

Challenge Exercise

Build a Mini nn.Module: Implement a simplified version of nn.Module from scratch that supports:

  • Parameter registration via descriptors or __setattr__
  • Submodule registration
  • Recursive parameters() and named_parameters()
  • train() and eval() mode switching
  • to() for device placement
  • state_dict() and load_state_dict()
🐍challenge_starter.py
import torch
from torch.nn import Parameter  # reuse torch's Parameter wrapper for the demo

class MiniModule:
    """Simplified nn.Module implementation."""

    def __init__(self):
        self._parameters = {}
        self._modules = {}
        self._buffers = {}
        self.training = True

    def __setattr__(self, name, value):
        # TODO: Detect Parameter and MiniModule instances
        # Register them appropriately
        ...

    def parameters(self):
        # TODO: Yield all parameters recursively
        ...

    def to(self, device):
        # TODO: Move all parameters and buffers
        ...

    def state_dict(self):
        # TODO: Return dict of all parameters and buffers
        ...

# Test your implementation
class LinearLayer(MiniModule):
    def __init__(self, in_features, out_features):
        super().__init__()
        self.weight = Parameter(torch.randn(out_features, in_features))
        self.bias = Parameter(torch.zeros(out_features))

    def forward(self, x):
        return x @ self.weight.T + self.bias

In the next section, we'll explore Linear Layers - the most fundamental building block of neural networks. We'll understand the mathematics of linear transformations and see how nn.Linear implements them efficiently.