Learning Objectives
By the end of this section, you will be able to:
- Understand transfer learning: Know why knowledge from one task helps with another
- Choose the right approach: Decide between feature extraction and fine-tuning based on your data
- Implement transfer learning: Freeze layers, replace heads, and tune learning rates in PyTorch
- Avoid common mistakes: Recognize and prevent typical transfer learning pitfalls
What is Transfer Learning?
Transfer learning is the practice of taking a model trained on one task (the source task) and reusing it for a different but related task (the target task). Instead of training from scratch, you start from a model that has already learned useful representations.
Analogy: Learning to play piano makes learning guitar easier because you've already developed fine motor control, rhythm sense, and music theory understanding. Similarly, a model trained to recognize animals can more easily learn to recognize dog breeds because it already understands edges, textures, and shapes.
The Traditional Machine Learning Problem
Traditional machine learning assumes:
- Training and test data come from the same distribution
- Each task requires training from scratch
- More data is always better
But in practice:
- Collecting labeled data is expensive and time-consuming
- Many domains have limited labeled data
- Training large models from scratch requires massive compute resources
Transfer learning solves these problems by reusing knowledge from data-rich tasks.
Why Transfer Learning Works
Transfer learning works because neural networks learn hierarchical representations:
The Representation Hierarchy
| Layer Depth | What It Learns | Transferability |
|---|---|---|
| Early layers | Low-level features (edges, textures, colors) | Highly transferable across tasks |
| Middle layers | Mid-level features (shapes, patterns, parts) | Moderately transferable |
| Late layers | High-level features (task-specific concepts) | Less transferable, often replaced |
The key insight: early layers learn features that are useful for many tasks, while later layers specialize for the specific training task.
Mathematical Intuition
Consider a network as a composition of feature extractors:

$$f(x) = f_{\text{late}}(f_{\text{mid}}(f_{\text{early}}(x)))$$

Where:
- $f_{\text{early}}$: Generic, transferable features
- $f_{\text{mid}}$: Moderately specific features
- $f_{\text{late}}$: Task-specific classifier

Transfer learning reuses $f_{\text{early}}$ (and possibly $f_{\text{mid}}$) while replacing $f_{\text{late}}$ for the new task.
When to Use Transfer Learning
| Your Situation | Recommendation | Why |
|---|---|---|
| Small dataset, similar to source | Transfer learning (feature extraction) | Model already knows relevant features |
| Medium dataset, similar to source | Transfer learning (fine-tuning) | Adapt learned features to your specific task |
| Large dataset, similar to source | Transfer learning or train from scratch | Either works; transfer gives faster start |
| Small dataset, different from source | Transfer learning (careful fine-tuning) | Some features may still help; proceed cautiously |
| Large dataset, very different from source | Train from scratch | Source features may not transfer well |
The Data Quantity Rule of Thumb
A rough guide: with fewer than about 1,000 examples, use feature extraction only; with 1,000 to 10,000, fine-tune just the last few layers; beyond 10,000, fine-tuning all layers becomes safe.
Quick Check
You're building a medical X-ray classifier with 5,000 images. A model pretrained on natural images (ImageNet) is available. What should you do?
Types of Transfer Learning
There are two main approaches to transfer learning:
Feature Extraction
Freeze the pretrained model and only train a new classification head:
- All pretrained weights are fixed (frozen)
- Only the new head layer(s) are trained
- Fast training, low risk of overfitting
- Best when you have very limited data
```python
import torch.nn as nn
from torchvision import models

# Load pretrained model
model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)

# Freeze all pretrained layers
for param in model.parameters():
    param.requires_grad = False

# Replace the classification head (num_classes is your task's class count)
num_features = model.fc.in_features
model.fc = nn.Linear(num_features, num_classes)  # only this layer is trained

# Only the new head's parameters are trainable
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total = sum(p.numel() for p in model.parameters())
print(f"Training {trainable:,} / {total:,} parameters ({trainable/total:.1%})")
```

Fine-Tuning
Unfreeze some or all pretrained layers and train with a low learning rate:
- Some or all pretrained weights are updated
- Typically use a much lower learning rate than training from scratch
- Can achieve better performance but risks overfitting
- Best when you have moderate data and task similarity
```python
import torch
import torch.nn as nn
from torchvision import models

# Load pretrained model
model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)

# Replace the head first
num_features = model.fc.in_features
model.fc = nn.Linear(num_features, num_classes)

# Strategy 1: fine-tune all layers with a low learning rate
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)  # lower than the typical 1e-3

# Strategy 2: different learning rates at different depths
param_groups = [
    {'params': model.conv1.parameters(), 'lr': 1e-5},   # nearly frozen
    {'params': model.layer1.parameters(), 'lr': 1e-5},
    {'params': model.layer2.parameters(), 'lr': 5e-5},
    {'params': model.layer3.parameters(), 'lr': 1e-4},
    {'params': model.layer4.parameters(), 'lr': 1e-4},
    {'params': model.fc.parameters(), 'lr': 1e-3},      # train the head faster
]
optimizer = torch.optim.Adam(param_groups)
```

Choosing Between Approaches
| Dataset Size | Task Similarity | Approach | Learning Rate |
|---|---|---|---|
| Very small (< 1k) | Similar | Feature extraction only | 1e-3 for head |
| Small (1k - 10k) | Similar | Fine-tune last few layers | 1e-4 overall |
| Medium (10k - 100k) | Similar | Fine-tune all layers | 1e-4 to 1e-5 |
| Small | Different | Feature extraction, careful eval | 1e-3 for head |
| Large | Different | Fine-tune or train from scratch | Experiment both |
Practical Considerations
Learning Rate Selection
The most important hyperparameter in fine-tuning is the learning rate:
- Too high: Destroys pretrained features ("catastrophic forgetting")
- Too low: Barely adapts to new task, wasting compute
- Just right: Gently adapts features while preserving useful knowledge
The 10x Rule
A common heuristic: give pretrained layers a learning rate roughly 10x lower than the new head, and start the overall fine-tuning rate 10-100x lower than you would use when training from scratch.
Gradual Unfreezing
A popular technique is to gradually unfreeze layers during training:
- Phase 1: Freeze all layers, train only the head (1-5 epochs)
- Phase 2: Unfreeze the last block, continue training with low LR (5-10 epochs)
- Phase 3: Optionally unfreeze more layers (if you have enough data)
```python
def unfreeze_layers(model, num_layers_to_unfreeze):
    """Gradually unfreeze layers from the end."""
    # Get all child modules (for ResNet these are coarse blocks;
    # adjust the granularity for your architecture)
    children = list(model.children())

    # Freeze everything first
    for child in children:
        for param in child.parameters():
            param.requires_grad = False

    # Unfreeze the last N children
    for child in children[-num_layers_to_unfreeze:]:
        for param in child.parameters():
            param.requires_grad = True

# Training loop with gradual unfreezing
# (train_for_epochs is a placeholder for your own training loop)
# Phase 1: only the head
model.fc = nn.Linear(num_features, num_classes)
for param in model.parameters():
    param.requires_grad = False
for param in model.fc.parameters():
    param.requires_grad = True
train_for_epochs(model, num_epochs=5, lr=1e-3)

# Phase 2: unfreeze the last 2 children
unfreeze_layers(model, num_layers_to_unfreeze=2)
train_for_epochs(model, num_epochs=10, lr=1e-4)
```

Batch Normalization Considerations
BatchNorm During Transfer Learning
- If you keep BatchNorm frozen, put those layers in eval mode (`model.eval()`) so the pretrained running statistics are used
- If you unfreeze BatchNorm, the running statistics will update during training
- For very small datasets, keeping BatchNorm frozen often works better
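One way to implement the frozen option is a small helper that walks the model (a sketch; `freeze_batchnorm` is a name of our choosing, not a PyTorch built-in):

```python
import torch.nn as nn

def freeze_batchnorm(model):
    """Keep BatchNorm layers in eval mode with frozen affine parameters."""
    for module in model.modules():
        if isinstance(module, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
            module.eval()  # use the pretrained running statistics
            for param in module.parameters():
                param.requires_grad = False  # don't update gamma/beta
    return model
```

Note that `model.train()` flips BatchNorm layers back into training mode, so call `freeze_batchnorm` again after each `model.train()` in your loop.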
PyTorch Implementation
Here's a complete transfer learning workflow in PyTorch:
Common Pitfalls
1. Using Too High Learning Rate
Problem: Pretrained features are destroyed in the first few epochs.
Solution: Start with 10-100x lower learning rate than normal. Use learning rate finder on the new task.
2. Forgetting to Freeze BatchNorm
Problem: BatchNorm running statistics are corrupted by small dataset, leading to train/test discrepancy.
Solution: For small datasets, keep BatchNorm layers in eval mode even during training.
3. Not Matching Input Preprocessing
Problem: Using different normalization than the pretrained model expects.
Solution: Always use the same preprocessing (mean, std, size) that the pretrained model was trained with.
```python
from torchvision import transforms

# ImageNet pretrained models expect this normalization
imagenet_normalize = transforms.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225]
)

# Correct preprocessing for transfer learning
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),  # ImageNet models expect 224x224 inputs
    transforms.ToTensor(),
    imagenet_normalize  # MUST use the pretrained normalization
])
```

4. Overfitting the Head
Problem: With very small data, even the head overfits quickly.
Solution: Add dropout to the head, use strong augmentation, consider simpler heads.
Summary
| Concept | Key Takeaway |
|---|---|
| Transfer Learning | Reuse pretrained models to save data and compute |
| Why It Works | Early layers learn generic features; later layers specialize |
| Feature Extraction | Freeze backbone, train only new head - best for small data |
| Fine-Tuning | Train all layers with low LR - best for medium data |
| Learning Rate | Use 10-100x lower than training from scratch |
| Gradual Unfreezing | Start frozen, progressively unfreeze layers |
| Preprocessing | Must match pretrained model's expected input format |
Exercises
Conceptual Questions
- You have 500 images for a plant disease classification task. Would you use feature extraction or fine-tuning? Why?
- A colleague fine-tuned a pretrained model but got worse results than training from scratch. What might have gone wrong?
- Why do we typically use lower learning rates for pretrained layers than for the new classification head?
Solution Hints
- Q1: Feature extraction. With only 500 images, fine-tuning risks overfitting. Freeze the backbone and only train the head with strong augmentation.
- Q2: Possible causes: (1) Learning rate too high, destroyed features; (2) Wrong preprocessing/normalization; (3) Task very different from source task; (4) BatchNorm issues.
- Q3: Pretrained layers already contain useful features. High learning rates can destroy this knowledge (catastrophic forgetting). The new head needs larger updates since it starts randomly initialized.
Coding Exercises
- Implement gradual unfreezing: Create a training script that trains only the head for 5 epochs, then unfreezes the last block for 5 more epochs, then unfreezes everything for 10 epochs. Track validation accuracy at each phase.
- Compare approaches: On CIFAR-10, compare (a) training from scratch, (b) feature extraction with frozen ResNet18, and (c) fine-tuning. Use only 1000 training samples to make the comparison clearer.
- Learning rate experiment: Fine-tune a pretrained model with learning rates [1e-2, 1e-3, 1e-4, 1e-5]. Plot training and validation curves. Find the best rate and explain why the others fail.
Exercise Tips
- Exercise 1: Use PyTorch's `requires_grad` to control which layers are trainable at each phase.
- Exercise 2: Subsample CIFAR-10 to 1000 images (100 per class) for a clear comparison. Track both train and val accuracy.
- Exercise 3: With 1e-2, you'll see catastrophic forgetting (val accuracy drops). With 1e-5, training will be too slow. The sweet spot is usually around 1e-4.
Congratulations! You've completed Chapter 9: Training Neural Networks. In the next part, we'll dive into Convolutional Neural Networks—starting with the convolution operation and building up to modern CNN architectures.