Chapter 9: Training Neural Networks

Transfer Learning Fundamentals

Learning Objectives

By the end of this section, you will be able to:

  1. Understand transfer learning: Know why knowledge from one task helps with another
  2. Choose the right approach: Decide between feature extraction and fine-tuning based on your data
  3. Implement transfer learning: Freeze layers, replace heads, and tune learning rates in PyTorch
  4. Avoid common mistakes: Recognize and prevent typical transfer learning pitfalls

Practical Applications Coming Up

This section covers the fundamentals of transfer learning applicable to any neural network. For detailed practical applications with CNNs and pretrained models (ResNet, VGG, etc.), see Chapter 12: CNNs in Practice.

What is Transfer Learning?

Transfer learning is the practice of taking a model trained on one task (the source task) and reusing it for a different but related task (the target task). Instead of training from scratch, you start from a model that has already learned useful representations.

Analogy: Learning to play piano makes learning guitar easier because you've already developed fine motor control, rhythm sense, and music theory understanding. Similarly, a model trained to recognize animals can more easily learn to recognize dog breeds because it already understands edges, textures, and shapes.

The Traditional Machine Learning Problem

Traditional machine learning assumes:

  • Training and test data come from the same distribution
  • Each task requires training from scratch
  • More data is always better

But in practice:

  • Collecting labeled data is expensive and time-consuming
  • Many domains have limited labeled data
  • Training large models from scratch requires massive compute resources

Transfer learning solves these problems by reusing knowledge from data-rich tasks.


Why Transfer Learning Works

Transfer learning works because neural networks learn hierarchical representations:

The Representation Hierarchy

| Layer Depth | What It Learns | Transferability |
|---|---|---|
| Early layers | Low-level features (edges, textures, colors) | Highly transferable across tasks |
| Middle layers | Mid-level features (shapes, patterns, parts) | Moderately transferable |
| Late layers | High-level features (task-specific concepts) | Less transferable, often replaced |

The key insight: early layers learn features that are useful for many tasks, while later layers specialize for the specific training task.

Mathematical Intuition

Consider a network as a composition of feature extractors:

f(\mathbf{x}) = f_{\text{head}}(f_{\text{body}}(f_{\text{base}}(\mathbf{x})))

Where:

  • f_{\text{base}}: Generic, transferable features
  • f_{\text{body}}: Moderately specific features
  • f_{\text{head}}: Task-specific classifier

Transfer learning reuses f_{\text{base}} (and possibly f_{\text{body}}) while replacing f_{\text{head}} for the new task.
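The composition above can be sketched directly with toy modules (the layer sizes here are arbitrary, chosen only for illustration):

```python
import torch
import torch.nn as nn

# f(x) = f_head(f_body(f_base(x))) as three stages (toy sizes)
f_base = nn.Sequential(nn.Linear(32, 64), nn.ReLU())   # generic features
f_body = nn.Sequential(nn.Linear(64, 64), nn.ReLU())   # mid-level features
f_head = nn.Linear(64, 10)                             # task-specific classifier

x = torch.randn(4, 32)
logits = f_head(f_body(f_base(x)))       # original task: 10 classes

# Transfer: reuse f_base and f_body, replace only the head for a 3-class task
new_head = nn.Linear(64, 3)
new_logits = new_head(f_body(f_base(x)))
print(logits.shape, new_logits.shape)
```

Only `new_head` starts from random initialization; everything before it carries over unchanged.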


When to Use Transfer Learning

| Your Situation | Recommendation | Why |
|---|---|---|
| Small dataset, similar to source | Transfer learning (feature extraction) | Model already knows relevant features |
| Medium dataset, similar to source | Transfer learning (fine-tuning) | Adapt learned features to your specific task |
| Large dataset, similar to source | Transfer learning or train from scratch | Either works; transfer gives faster start |
| Small dataset, different from source | Transfer learning (careful fine-tuning) | Some features may still help; proceed cautiously |
| Large dataset, very different from source | Train from scratch | Source features may not transfer well |

The Data Quantity Rule of Thumb

If you have fewer than 10,000 labeled examples, transfer learning almost always outperforms training from scratch. Modern pretrained models encode knowledge from millions of examples—you can't match that with limited data.

Quick Check

You're building a medical X-ray classifier with 5,000 images. A model pretrained on natural images (ImageNet) is available. What should you do?


Types of Transfer Learning

There are two main approaches to transfer learning:

Feature Extraction

Freeze the pretrained model and only train a new classification head:

  • All pretrained weights are fixed (frozen)
  • Only the new head layer(s) are trained
  • Fast training, low risk of overfitting
  • Best when you have very limited data
🐍feature_extraction.py
import torch.nn as nn
from torchvision import models

num_classes = 10  # e.g., 10 target classes

# Load pretrained model
model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)

# Freeze all pretrained layers
for param in model.parameters():
    param.requires_grad = False

# Replace the classification head
num_features = model.fc.in_features
model.fc = nn.Linear(num_features, num_classes)  # Only this is trained

# Only the new head's parameters are trainable
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total = sum(p.numel() for p in model.parameters())
print(f"Training {trainable:,} / {total:,} parameters ({trainable/total:.1%})")

Fine-Tuning

Unfreeze some or all pretrained layers and train with a low learning rate:

  • Some or all pretrained weights are updated
  • Typically use a much lower learning rate than training from scratch
  • Can achieve better performance but risks overfitting
  • Best when you have moderate data and task similarity
🐍fine_tuning.py
import torch
import torch.nn as nn
from torchvision import models

num_classes = 10  # e.g., 10 target classes

# Load pretrained model
model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)

# Replace the head first
num_features = model.fc.in_features
model.fc = nn.Linear(num_features, num_classes)

# Strategy 1: fine-tune all layers with a low learning rate
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)  # lower than the typical 1e-3

# Strategy 2: different learning rates for different layer groups
param_groups = [
    {'params': model.conv1.parameters(), 'lr': 1e-5},   # near-frozen early layers
    {'params': model.bn1.parameters(), 'lr': 1e-5},
    {'params': model.layer1.parameters(), 'lr': 1e-5},
    {'params': model.layer2.parameters(), 'lr': 5e-5},
    {'params': model.layer3.parameters(), 'lr': 1e-4},
    {'params': model.layer4.parameters(), 'lr': 1e-4},
    {'params': model.fc.parameters(), 'lr': 1e-3},      # train the new head faster
]
optimizer = torch.optim.Adam(param_groups)

Choosing Between Approaches

| Dataset Size | Task Similarity | Approach | Learning Rate |
|---|---|---|---|
| Very small (< 1k) | Similar | Feature extraction only | 1e-3 for head |
| Small (1k - 10k) | Similar | Fine-tune last few layers | 1e-4 overall |
| Medium (10k - 100k) | Similar | Fine-tune all layers | 1e-4 to 1e-5 |
| Small | Different | Feature extraction, careful eval | 1e-3 for head |
| Large | Different | Fine-tune or train from scratch | Experiment with both |
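These rules of thumb can be encoded as a small helper (a sketch only; real projects should validate the choice empirically):

```python
def recommend_strategy(n_examples: int, similar_to_source: bool) -> str:
    """Rule-of-thumb transfer learning strategy picker."""
    if similar_to_source:
        if n_examples < 1_000:
            return "feature extraction only (head lr ~1e-3)"
        if n_examples < 10_000:
            return "fine-tune last few layers (lr ~1e-4)"
        return "fine-tune all layers (lr 1e-4 to 1e-5)"
    if n_examples < 10_000:
        return "feature extraction with careful evaluation"
    return "fine-tune or train from scratch; try both"

print(recommend_strategy(5_000, similar_to_source=True))
# fine-tune last few layers (lr ~1e-4)
```

The thresholds (1k, 10k) mirror the table; treat them as starting points, not hard boundaries.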

Practical Considerations

Learning Rate Selection

The most important hyperparameter in fine-tuning is the learning rate:

  • Too high: Destroys pretrained features ("catastrophic forgetting")
  • Too low: Barely adapts to new task, wasting compute
  • Just right: Gently adapts features while preserving useful knowledge

The 10x Rule

Start with a learning rate 10x lower than you would use for training from scratch. If training from scratch uses 1e-3, try 1e-4 for fine-tuning.

Gradual Unfreezing

A popular technique is to gradually unfreeze layers during training:

  1. Phase 1: Freeze all layers, train only the head (1-5 epochs)
  2. Phase 2: Unfreeze the last block, continue training with low LR (5-10 epochs)
  3. Phase 3: Optionally unfreeze more layers (if you have enough data)
🐍gradual_unfreezing.py
import torch.nn as nn
from torchvision import models

def unfreeze_layers(model, num_layers_to_unfreeze):
    """Gradually unfreeze top-level children from the end."""
    children = list(model.children())

    # Freeze everything first
    for child in children:
        for param in child.parameters():
            param.requires_grad = False

    # Unfreeze the last N top-level children
    for child in children[-num_layers_to_unfreeze:]:
        for param in child.parameters():
            param.requires_grad = True

# Training loop with gradual unfreezing
# (train_for_epochs is a placeholder for your own training loop)
model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
num_classes = 10  # e.g., 10 target classes

# Phase 1: train only the head
num_features = model.fc.in_features
model.fc = nn.Linear(num_features, num_classes)
for param in model.parameters():
    param.requires_grad = False
for param in model.fc.parameters():
    param.requires_grad = True
train_for_epochs(model, num_epochs=5, lr=1e-3)

# Phase 2: unfreeze the last block as well
# (for ResNet18 the last children are layer4, avgpool, fc, so use 3)
unfreeze_layers(model, num_layers_to_unfreeze=3)
train_for_epochs(model, num_epochs=10, lr=1e-4)

Batch Normalization Considerations

BatchNorm During Transfer Learning

BatchNorm layers maintain running statistics from pretraining. When fine-tuning:
  • If you keep BatchNorm frozen, put the BatchNorm modules in eval mode so they keep using the pretrained running statistics (calling model.eval() on the whole model also disables dropout, so prefer switching just the BatchNorm modules)
  • If you unfreeze BatchNorm, the running statistics will update during training
  • For very small datasets, keeping BatchNorm frozen often works better

PyTorch Implementation

Here's a complete transfer learning workflow in PyTorch:

Complete Transfer Learning Implementation
Three points to note in the code below:

  • Remove the original head: replacing the pretrained classification head with Identity() exposes the raw backbone features; we then add our own head.
  • Custom head: dropout for regularization plus a hidden layer, trained from scratch.
  • Differential learning rates: a lower learning rate for the pretrained backbone (1e-5) than for the new head (1e-4) prevents destroying pretrained features.

🐍transfer_learning_model.py
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import models

class TransferLearningModel(nn.Module):
    """Wrapper for transfer learning with flexible unfreezing."""

    def __init__(self, num_classes, pretrained_model='resnet18'):
        super().__init__()

        # Load pretrained model
        if pretrained_model == 'resnet18':
            self.backbone = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
            num_features = self.backbone.fc.in_features
            self.backbone.fc = nn.Identity()  # Remove original head
        else:
            raise ValueError(f"Unknown model: {pretrained_model}")

        # New classification head
        self.head = nn.Sequential(
            nn.Dropout(0.3),
            nn.Linear(num_features, 256),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(256, num_classes)
        )

        # Start with backbone frozen
        self.freeze_backbone()

    def freeze_backbone(self):
        """Freeze all backbone parameters."""
        for param in self.backbone.parameters():
            param.requires_grad = False

    def unfreeze_backbone(self, from_layer=None):
        """Unfreeze the backbone, optionally only from a named child onward
        (e.g. from_layer='layer3' unfreezes layer3, layer4, ...)."""
        unfreezing = from_layer is None
        for name, child in self.backbone.named_children():
            if name == from_layer:
                unfreezing = True
            if unfreezing:
                for param in child.parameters():
                    param.requires_grad = True

    def forward(self, x):
        features = self.backbone(x)
        return self.head(features)

# Usage
model = TransferLearningModel(num_classes=10)

# Phase 1: feature extraction (head only)
optimizer = optim.Adam(model.head.parameters(), lr=1e-3)
# Train for a few epochs...

# Phase 2: fine-tuning (all layers)
model.unfreeze_backbone()
optimizer = optim.Adam([
    {'params': model.backbone.parameters(), 'lr': 1e-5},
    {'params': model.head.parameters(), 'lr': 1e-4}
])
# Continue training...

Common Pitfalls

1. Using Too High Learning Rate

Problem: Pretrained features are destroyed in the first few epochs.

Solution: Start with 10-100x lower learning rate than normal. Use learning rate finder on the new task.

2. Forgetting to Freeze BatchNorm

Problem: BatchNorm running statistics are corrupted by small dataset, leading to train/test discrepancy.

Solution: For small datasets, keep BatchNorm layers in eval mode even during training.

3. Not Matching Input Preprocessing

Problem: Using different normalization than the pretrained model expects.

Solution: Always use the same preprocessing (mean, std, size) that the pretrained model was trained with.

🐍preprocessing.py
from torchvision import transforms

# ImageNet pretrained models expect this normalization
imagenet_normalize = transforms.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225]
)

# Correct preprocessing for transfer learning
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),  # ImageNet models expect 224x224
    transforms.ToTensor(),
    imagenet_normalize  # MUST use pretrained normalization
])

4. Overfitting the Head

Problem: With very small data, even the head overfits quickly.

Solution: Add dropout to the head, use strong augmentation, consider simpler heads.


Summary

| Concept | Key Takeaway |
|---|---|
| Transfer Learning | Reuse pretrained models to save data and compute |
| Why It Works | Early layers learn generic features; later layers specialize |
| Feature Extraction | Freeze backbone, train only new head - best for small data |
| Fine-Tuning | Train all layers with low LR - best for medium data |
| Learning Rate | Use 10-100x lower than training from scratch |
| Gradual Unfreezing | Start frozen, progressively unfreeze layers |
| Preprocessing | Must match pretrained model's expected input format |

Coming Up: Practical Applications

In Chapter 12 (CNNs in Practice), we'll apply these concepts with specific pretrained architectures like ResNet, VGG, and EfficientNet, including code for image classification, object detection, and more.

Exercises

Conceptual Questions

  1. You have 500 images for a plant disease classification task. Would you use feature extraction or fine-tuning? Why?
  2. A colleague fine-tuned a pretrained model but got worse results than training from scratch. What might have gone wrong?
  3. Why do we typically use lower learning rates for pretrained layers than for the new classification head?

Solution Hints

  1. Q1: Feature extraction. With only 500 images, fine-tuning risks overfitting. Freeze the backbone and only train the head with strong augmentation.
  2. Q2: Possible causes: (1) Learning rate too high, destroyed features; (2) Wrong preprocessing/normalization; (3) Task very different from source task; (4) BatchNorm issues.
  3. Q3: Pretrained layers already contain useful features. High learning rates can destroy this knowledge (catastrophic forgetting). The new head needs larger updates since it starts randomly initialized.

Coding Exercises

  1. Implement gradual unfreezing: Create a training script that trains only the head for 5 epochs, then unfreezes the last block for 5 more epochs, then unfreezes everything for 10 epochs. Track validation accuracy at each phase.
  2. Compare approaches: On CIFAR-10, compare (a) training from scratch, (b) feature extraction with frozen ResNet18, and (c) fine-tuning. Use only 1000 training samples to make the comparison clearer.
  3. Learning rate experiment: Fine-tune a pretrained model with learning rates [1e-2, 1e-3, 1e-4, 1e-5]. Plot training and validation curves. Find the best rate and explain why the others fail.

Exercise Tips

  • Exercise 1: Use PyTorch's requires_grad to control which layers are trainable at each phase.
  • Exercise 2: Subsample CIFAR-10 to 1000 images (100 per class) for clear comparison. Track both train and val accuracy.
  • Exercise 3: With 1e-2, you'll see catastrophic forgetting (val acc drops). With 1e-5, training will be too slow. Sweet spot is usually around 1e-4.

Congratulations! You've completed Chapter 9: Training Neural Networks. In the next part, we'll dive into Convolutional Neural Networks—starting with the convolution operation and building up to modern CNN architectures.