Why Transfer Learning Works
In Section 2, we saw that ResNet trained on ImageNet achieves superhuman accuracy on 1,000 categories. But most real-world problems are not “classify these 1,000 ImageNet categories.” You might need to classify 5 types of flowers, 3 types of skin lesions, or 20 types of manufactured defects — tasks with far less training data than ImageNet's 1.2 million images.
Transfer learning solves this by reusing a model pretrained on a large dataset (like ImageNet) and adapting it to your specific task. Instead of training from random weights on 500 images, you start from weights that already understand visual concepts from 1.2M images.
The Core Idea: A CNN trained on ImageNet does not just learn to classify dogs and cars. It learns a hierarchy of visual features (edges, textures, shapes, parts, objects) that are useful for most visual tasks. Transfer learning reuses these general-purpose features.
The Feature Hierarchy Insight
Research by Zeiler & Fergus (2013) showed that CNN layers learn increasingly abstract features. The deeper you go, the more task-specific the features become:
| Layer Group (ResNet-18 naming) | What It Learns | Transferability |
|---|---|---|
| Early layers (conv1, layer1) | Edges, corners, color gradients | Highly universal; useful for almost any visual task |
| Middle layers (layer2) | Textures, patterns, simple shapes | Very transferable across most domains |
| Later layers (layer3) | Object parts, complex shapes | Moderately transferable; may need fine-tuning |
| Final layers (layer4, fc) | Task-specific compositions | Least transferable; fc is usually replaced and layer4 fine-tuned |
This gradient of transferability has a profound practical implication: the early layers of a pretrained CNN are a near-universal feature extractor. Whether your task involves flowers, medical images, satellite photos, or factory defects — the edge detectors, texture analyzers, and shape recognizers learned from ImageNet provide an excellent starting point.
Figure: Feature Hierarchy in CNNs. How neural networks build complex features from simple ones: Input (raw pixels) → Layer 1 (edges & gradients) → Layer 2 (textures & patterns) → Layer 3 (object parts) → Layer 4+ (objects & scenes).
Key insight: Each layer combines features from the previous layer. Early layers detect low-level features; deeper layers capture high-level concepts.
The diagram above is the pattern Zeiler & Fergus (2013) documented experimentally: early layers respond to Gabor-like edges and colour blobs, middle layers respond to textures and simple parts, later layers respond to whole objects. Yosinski et al. (2014) quantified the effect: copying and freezing just the first layer of a pretrained network costs almost no accuracy on a new task, while copying and freezing progressively more layers costs progressively more. This is the empirical basis of every transfer-learning strategy we are about to use.
The Mathematics of Why It Works
Formally, consider a pretrained model with ImageNet-trained weights $\theta^{*}$. For a new task with a small dataset $\mathcal{D}$, training from scratch yields weights $\theta_{\text{scratch}}$ that are likely to overfit. Starting from $\theta^{*}$ and fine-tuning on $\mathcal{D}$ instead gives weights $\theta_{\text{ft}}$ that:
- Start in a good region of the loss landscape (pretrained features are already useful)
- Need fewer gradient steps to converge (e.g., 5 epochs instead of 100)
- Generalize better, because the pretrained features act as implicit regularization (they were learned from 1.2M diverse images rather than from $\mathcal{D}$ alone)
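A compact way to state this argument, assuming fine-tuning is plain gradient descent with learning rate $\eta$ for $T$ steps: transfer learning changes only the starting point $\theta_0$. Here $\theta^{*}$ denotes the pretrained ImageNet weights, $\mathcal{D}$ the small new dataset, $f_\theta$ the network and $\ell$ a per-example loss such as cross-entropy.

```latex
\begin{aligned}
\mathcal{L}_{\mathcal{D}}(\theta) &= \frac{1}{|\mathcal{D}|} \sum_{(x, y) \in \mathcal{D}} \ell\big(f_\theta(x),\, y\big) \\
\theta_0 &= \theta^{*} \qquad \text{(from scratch: } \theta_0 \text{ drawn at random)} \\
\theta_{t+1} &= \theta_t - \eta\, \nabla_\theta \mathcal{L}_{\mathcal{D}}(\theta_t), \qquad t = 0, \dots, T-1 \\
\lVert \theta_T - \theta^{*} \rVert &\le \eta \sum_{t=0}^{T-1} \big\lVert \nabla_\theta \mathcal{L}_{\mathcal{D}}(\theta_t) \big\rVert
\end{aligned}
```

The last line follows from the triangle inequality: with a small $\eta$ and few steps, the fine-tuned weights cannot move far from $\theta^{*}$. That proximity is one informal reading of the implicit-regularization point; it is a heuristic, not a generalization bound.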
Two Strategies: Extract vs Fine-Tune
There are two main approaches to transfer learning, and the right choice depends on your dataset size and how similar your domain is to ImageNet:
Strategy 1: Feature Extraction
Freeze the entire pretrained backbone. Only train a new classification head. The pretrained model acts as a fixed feature extractor.
- When to use: Small dataset (<1,000 images per class), or when your domain is similar to ImageNet (natural photos of objects, animals, scenes)
- Pros: Fast training, low risk of overfitting, minimal compute
- Cons: Cannot adapt features to domain-specific patterns
Strategy 2: Fine-Tuning
Unfreeze some or all backbone layers and train them with a small learning rate. The pretrained features are gently adapted to your domain.
- When to use: Medium-to-large dataset (>1,000 images per class), or when your domain is different from ImageNet (medical images, aerial photos, microscopy)
- Pros: Higher accuracy, adapts features to your domain
- Cons: Risk of overfitting if dataset is too small, slower training, needs careful learning rate tuning
| Factor | Feature Extraction | Fine-Tuning |
|---|---|---|
| Trainable parameters | Only new head (~2K) | All layers (~11M) |
| Training time | Minutes | Hours |
| Min dataset size | ~100 images/class | ~500–1000 images/class |
| Risk of overfitting | Very low | Moderate (need regularization) |
| Accuracy ceiling | Good (95%+) | Best possible (98%+) |
| Learning rate | Normal (1e-3) | Differential (1e-4 backbone, 1e-3 head) |
Quick Check
You have 50 images per class and your images are regular photos of animals. Which transfer learning strategy should you use?
Loading Pretrained Models
PyTorch provides a rich model zoo through torchvision.models. Let's load a pretrained ResNet-18 and examine its structure:
Strategy 1: Feature Extraction
The simplest form of transfer learning: freeze the pretrained backbone, replace the classification head, and train only the new head. Three lines of code transform a 1000-class ImageNet model into a 5-class flower classifier:
The key result: 96.5% accuracy on 5 flower classes by training only 2,565 parameters (the new head's 512 × 5 weights plus 5 biases), roughly 0.02% of the model. The pretrained backbone provides such rich features that a simple linear classifier on top achieves excellent performance.
Strategy 2: Fine-Tuning
When you have more data and want to squeeze out every last percentage point of accuracy, fine-tuning adapts the pretrained features to your specific domain. The key technique is differential learning rates: a small rate for the pretrained backbone (preserve features) and a larger rate for the new head (learn fast).
The Warmup + Fine-Tune Pattern
The two-phase approach is critical for stable fine-tuning:
- Phase 1 (Warmup): Freeze backbone, train only the new FC head for 2–3 epochs. This gives the head reasonable weights so that its gradients are meaningful when they flow back through the backbone.
- Phase 2 (Fine-tune): Unfreeze backbone with a small learning rate (10× smaller than head). Train all layers together. The backbone features adapt gently to the new domain.
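The two phases can be sketched as one function that toggles requires_grad and builds a fresh optimizer for each phase. This is an illustrative skeleton, not a full trainer: run_epoch is a hypothetical callback standing in for your own training loop, TinyNet is a toy stand-in for a pretrained backbone with a new fc head, and the epoch counts and learning rates are the defaults suggested in this section.

```python
from typing import Callable

import torch
import torch.nn as nn
import torch.nn.functional as F


def set_backbone_trainable(model: nn.Module, trainable: bool) -> None:
    """Toggle requires_grad on every parameter outside the replaced fc head."""
    for name, p in model.named_parameters():
        if not name.startswith("fc."):
            p.requires_grad = trainable


def warmup_then_finetune(model: nn.Module,
                         run_epoch: Callable[[nn.Module, torch.optim.Optimizer], None],
                         warmup_epochs: int = 2, finetune_epochs: int = 5,
                         head_lr: float = 1e-3) -> None:
    # Phase 1 (warmup): backbone frozen, only the new head is optimised.
    set_backbone_trainable(model, False)
    opt = torch.optim.Adam(model.fc.parameters(), lr=head_lr)
    for _ in range(warmup_epochs):
        run_epoch(model, opt)

    # Phase 2 (fine-tune): unfreeze; backbone lr is 10x smaller than the head's.
    set_backbone_trainable(model, True)
    head = list(model.fc.parameters())
    head_ids = {id(p) for p in head}
    backbone = [p for p in model.parameters() if id(p) not in head_ids]
    opt = torch.optim.Adam([{"params": backbone, "lr": head_lr / 10},
                            {"params": head, "lr": head_lr}])
    for _ in range(finetune_epochs):
        run_epoch(model, opt)


class TinyNet(nn.Module):
    """Toy stand-in for 'pretrained backbone + new fc head'."""

    def __init__(self) -> None:
        super().__init__()
        self.features = nn.Sequential(nn.Linear(8, 16), nn.ReLU())  # 144 params
        self.fc = nn.Linear(16, 3)                                  # 51 params

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.fc(self.features(x))


trainable_log = []


def run_epoch(model: nn.Module, opt: torch.optim.Optimizer) -> None:
    """Hypothetical epoch: one random mini-batch, then log the trainable-parameter count."""
    x, y = torch.randn(4, 8), torch.randint(0, 3, (4,))
    opt.zero_grad()
    F.cross_entropy(model(x), y).backward()
    opt.step()
    trainable_log.append(sum(p.numel() for p in model.parameters() if p.requires_grad))


torch.manual_seed(0)
net = TinyNet()
warmup_then_finetune(net, run_epoch, warmup_epochs=2, finetune_epochs=1)
print(trainable_log)  # [51, 51, 195]
```

The log shows only the 51 head parameters training during warmup and all 195 once the backbone is unfrozen. Creating a new optimizer for phase 2, rather than adding a group to the old one, keeps the head's Adam state from warmup out of the picture; whether to carry it over is a judgment call.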
BatchNorm Under Fine-Tuning
Section 2 introduced BatchNorm and its two modes. Here is where that distinction becomes a production gotcha. A pretrained ResNet carries running statistics from ImageNet: per-channel moving averages of the mean and variance, accumulated over training on 1.2M natural images. Those statistics are buffers, not learnable parameters, so param.requires_grad = False does nothing to them. Whenever a BatchNorm layer is in training mode, every forward pass normalises with the current batch's statistics and nudges the running averages toward them. The optimiser never touches these buffers, so nothing restores the ImageNet values.
The symptom: you freeze the backbone, train the head, reach 95% validation accuracy, then unfreeze for fine-tuning, and accuracy drops. You check gradients, losses, learning rates. They look fine. What changed? The BatchNorm running stats were silently being overwritten by statistics from your much smaller fine-tuning batches, which are noisier and follow a different distribution than ImageNet. The feature extractor's normalisation assumption is now broken.
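The trap is easy to reproduce on a single layer. In this minimal sketch a BatchNorm2d has its learnable parameters frozen, yet one forward pass in training mode still moves running_mean, while the same pass in eval() mode leaves it untouched. The layer and the shifted random input are stand-ins, not pretrained weights or real data.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
bn = nn.BatchNorm2d(3)
for p in bn.parameters():          # freezes gamma (weight) and beta (bias) only
    p.requires_grad = False

before = bn.running_mean.clone()   # freshly initialised: all zeros
x = torch.randn(8, 3, 4, 4) + 5.0  # batch whose channel means sit near 5, far from 0

bn.train()
with torch.no_grad():              # no_grad does not stop buffer updates
    bn(x)                          # training mode: running stats move toward batch stats
drift_train = (bn.running_mean - before).abs().max().item()

snapshot = bn.running_mean.clone()
bn.eval()
with torch.no_grad():
    bn(x)                          # eval mode: uses the stored stats and leaves them alone
drift_eval = (bn.running_mean - snapshot).abs().max().item()

print(f"train-mode drift: {drift_train:.2f}, eval-mode drift: {drift_eval:.2f}")
```

With PyTorch's default momentum of 0.1, each training-mode pass moves the running mean 10% of the way toward the batch mean, so a few hundred fine-tuning steps are enough to replace the ImageNet statistics almost entirely.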
The Fix: Freeze BN Stats Explicitly
```python
import torch.nn as nn


def freeze_bn_running_stats(module: nn.Module) -> None:
    """Keep BatchNorm layers in eval() mode even when the overall model is in train().

    This pins running_mean / running_var to their pretrained (ImageNet) values
    while still allowing gamma and beta to receive gradient updates if they are
    not separately frozen.
    """
    for m in module.modules():
        if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
            m.eval()  # switch BN to running-stats mode
            # Optional: also freeze the learnable gamma/beta if you want BN
            # to behave as an entirely fixed, pretrained normalisation.
            # for p in m.parameters():
            #     p.requires_grad = False


def train_one_epoch(model, train_loader, optimizer, criterion) -> None:
    """Usage inside the fine-tuning loop; call once per epoch."""
    model.train()                   # enables dropout, sets the default mode
    freeze_bn_running_stats(model)  # OVERRIDE just the BN layers back to eval
    for x, y in train_loader:
        optimizer.zero_grad()
        out = model(x)              # BN uses pretrained running stats, not batch stats
        loss = criterion(out, y)
        loss.backward()
        optimizer.step()
```

The call order matters. model.train() sets every submodule to training mode. Then freeze_bn_running_stats walks the module tree and puts each BatchNorm layer back into eval() mode. The rest of the model (conv, linear, dropout) stays in training mode as intended. Any later call to model.train() undoes the override, which is why both calls sit together at the top of every epoch.
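Because model.train() can be called again by other code (a training library, or your own loop after a validation pass), one defensive option is to bake the override into the model itself. A sketch, assuming you are willing to wrap the network in a small subclass; FrozenBNWrapper is a hypothetical name, and it relies only on the standard nn.Module.train(mode) method, which returns self.

```python
import torch.nn as nn

_BN_TYPES = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)


class FrozenBNWrapper(nn.Module):
    """Wraps a model so that every .train() call leaves BatchNorm layers in eval mode."""

    def __init__(self, model: nn.Module) -> None:
        super().__init__()
        self.model = model

    def train(self, mode: bool = True) -> "FrozenBNWrapper":
        super().train(mode)          # sets this module and all submodules to `mode`
        for m in self.modules():
            if isinstance(m, _BN_TYPES):
                m.eval()             # pin running_mean / running_var
        return self

    def forward(self, x):
        return self.model(x)


net = FrozenBNWrapper(nn.Sequential(
    nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU(), nn.Dropout(0.5)))
net.train()
print(net.model[1].training, net.model[3].training)  # False True
```

Since model.eval() is implemented as train(False), evaluation behaves as usual. One trade-off: state_dict keys gain a "model." prefix, so load pretrained weights into the inner network before wrapping it.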
Practical Guidelines
Here is a decision table for transfer learning in practice:
| Your Situation | Recommended Approach | Learning Rate |
|---|---|---|
| Very small data (<100/class), similar domain | Feature extraction | 1e-3 (head only) |
| Small data (100–1000/class), similar domain | Feature extraction | 1e-3 (head only) |
| Medium data (1000–10000/class), similar domain | Fine-tune last 1–2 layers | 1e-4 backbone, 1e-3 head |
| Large data (>10000/class), similar domain | Fine-tune all layers | 1e-4 backbone, 1e-3 head |
| Any size, very different domain (medical, satellite) | Fine-tune all layers + augmentation | 1e-5 early, 1e-4 late, 1e-3 head |
| Huge data (millions), very different domain | Train from scratch (or pretrain on your domain) | 1e-3 everywhere |
Common Pitfalls
- Wrong preprocessing: Always use the same normalization as the pretrained model. ImageNet models expect mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]. Using different values shifts the input distribution away from what the backbone was trained on and noticeably degrades its features.
- Too-high backbone learning rate: If you fine-tune with lr=1e-2, you can largely overwrite the pretrained features within the first epoch. Use 1e-4 or smaller for the backbone.
- No warmup: Skipping the head warmup phase lets large, essentially random gradients from the untrained head corrupt the pretrained features. Always train the head alone for 2–3 epochs first.
- Forgetting model.eval(): During inference, you must call model.eval() to disable dropout and switch batch norm to running statistics. This is critical for pretrained models with batch norm.
- Grayscale input to RGB model: If your images are grayscale, repeat the channel 3 times (for example, x.repeat(3, 1, 1) on a 1 × H × W tensor). The pretrained model expects 3 channels.
Chapter Summary: In this chapter, we built a complete CNN from scratch (Section 1), traced the evolution of CNN architectures from LeNet to ResNet (Section 2), and learned to leverage pretrained models through transfer learning (Section 3). The practical takeaway: almost never train a CNN from scratch. Start with a pretrained backbone, adapt it to your task, and achieve excellent results with a fraction of the data and compute.
References
The transferability claim — that early-layer features are near-universal and late-layer features are task-specific — rests on controlled experiments by Yosinski et al. (2014) and the visualisation work of Zeiler & Fergus (2013). They are the papers to cite for anything beyond informal intuition.
- Deng, J., Dong, W., Socher, R., Li, L.-J., Li, K. & Fei-Fei, L. (2009). ImageNet: A Large-Scale Hierarchical Image Database. CVPR 2009. The ImageNet database; its 1000-class, roughly 1.2M-image ILSVRC subset is what the pretrained ResNet we load was trained on.
- Zeiler, M. D. & Fergus, R. (2013). Visualizing and Understanding Convolutional Networks. ECCV 2014 / arXiv:1311.2901. The classic layer-by-layer feature visualisation (Gabor-like filters in layer 1, object parts in layer 5).
- Yosinski, J., Clune, J., Bengio, Y. & Lipson, H. (2014). How transferable are features in deep neural networks? NeurIPS 2014 / arXiv:1411.1792. Quantifies which layers transfer and why.
- He, K., Zhang, X., Ren, S. & Sun, J. (2015). Deep Residual Learning for Image Recognition. CVPR 2016 / arXiv:1512.03385. The ResNet architecture whose ResNet-18 variant we load.
- Ioffe, S. & Szegedy, C. (2015). Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift. ICML 2015 / arXiv:1502.03167. The source of the train()/eval() running-stats caveat under fine-tuning.
- PyTorch documentation. torchvision.models. pytorch.org/vision/stable/models.html. Weight enums and pretrained weights used in this section.