Learning Objectives
By the end of this section, you will:
- Understand why CNN features transfer across domains and tasks
- Know how to choose between feature extraction and fine-tuning based on your data
- Select the right pretrained model from the PyTorch model zoo
- Implement complete transfer learning pipelines with proper preprocessing
- Apply domain-specific strategies for medical, satellite, and specialized imagery
- Avoid common pitfalls that destroy pretrained knowledge
Why Transfer Learning for CNNs?
Convolutional Neural Networks are particularly well-suited for transfer learning because of how they learn visual features. This isn't accidental—it's a direct consequence of the hierarchical nature of visual perception.
The ImageNet Effect
Models trained on ImageNet (1.2 million images, 1000 classes) learn incredibly rich visual representations:
| Training Data | What the Model Learns | Transfer Value |
|---|---|---|
| 1.2M diverse images | Generic visual primitives | Universal edge, texture, shape detectors |
| 1000 categories | Hierarchical object parts | Eyes, wheels, faces, fur patterns |
| Natural distribution | Real-world visual statistics | Understanding of occlusion, lighting, pose |
Key Insight: A model trained to distinguish 1000 object categories necessarily learns features that are useful for any visual recognition task. The early layers learn features so general they transfer even to radically different domains like medical X-rays or satellite imagery.
The Economics of Transfer Learning
Consider what it would take to train a model from scratch for your specific task:
| Resource | Training from Scratch | Transfer Learning |
|---|---|---|
| Data needed | 100K+ labeled images | 1K-10K images |
| Training time | Days to weeks | Hours |
| GPU cost | $100-$1000+ | $5-$50 |
| Risk of failure | High (many hyperparameters) | Low (pretrained baseline) |
| Final accuracy | Unknown until trained | Strong baseline expected |
The CNN Feature Hierarchy
Understanding what each layer of a CNN learns helps us make intelligent decisions about what to freeze and what to fine-tune.
Layer-by-Layer Feature Analysis
Research by Zeiler & Fergus (2014) and others has shown that CNN layers learn increasingly abstract features:
| Layer Depth | Features Learned | Transferability | Action |
|---|---|---|---|
| Conv 1-2 | Edges, gradients, colors, simple textures | ~95% (highly universal) | Almost always freeze |
| Conv 3-4 | Corners, contours, complex textures | ~80% (mostly universal) | Usually freeze |
| Conv 5-7 | Object parts, semantic patterns | ~60% (domain-dependent) | Fine-tune if enough data |
| Conv 8+ | Whole objects, high-level concepts | ~30% (task-specific) | Often fine-tune or replace |
| FC layers | Class-specific decision boundaries | ~0% (ImageNet-specific) | Always replace |
Mathematical Perspective
We can think of a pretrained CNN as a composition of learned functions:

F(x) = f_head(f_late(f_mid(f_early(x))))

Where:
- f_early: Early convolutional layers — highly transferable edge/texture detectors
- f_mid: Middle layers — moderately transferable part detectors
- f_late: Later layers — domain-specific semantic features
- f_head: Final fully-connected layers — task-specific, must be replaced

For transfer learning, we keep f_early fixed, optionally fine-tune f_mid and f_late, and replace f_head entirely.
Interactive: Feature Transfer
Explore how CNN features transfer across different domains. Select a target domain and see how the transferability of each layer changes:
Quick Check
You're building a satellite imagery classifier. Based on the feature hierarchy, which layers would benefit most from fine-tuning?
Transfer Learning Strategies
The right strategy depends on two key factors: dataset size and domain similarity. Let's analyze each approach:
Strategy 1: Feature Extraction
When: Small dataset (<5K images) or very similar domain
```python
import torch
import torch.nn as nn
from torchvision import models

# Load pretrained ResNet
model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V2)

# Freeze ALL backbone parameters
for param in model.parameters():
    param.requires_grad = False

# Replace classifier head
num_classes = 10  # set to your dataset's number of classes
num_features = model.fc.in_features
model.fc = nn.Sequential(
    nn.Dropout(0.5),
    nn.Linear(num_features, 256),
    nn.ReLU(),
    nn.Dropout(0.3),
    nn.Linear(256, num_classes)
)

# Only head parameters are trainable
optimizer = torch.optim.Adam(model.fc.parameters(), lr=1e-3)
```

Strategy 2: Discriminative Fine-Tuning
When: Medium dataset (5K-50K images), similar domain
```python
# Different learning rates for different depths
# Key insight: earlier = smaller LR, later = larger LR

param_groups = [
    # Earliest layers: almost frozen
    {'params': model.conv1.parameters(), 'lr': 1e-6},
    {'params': model.bn1.parameters(), 'lr': 1e-6},
    {'params': model.layer1.parameters(), 'lr': 1e-6},

    # Early-mid layers: very small updates
    {'params': model.layer2.parameters(), 'lr': 5e-6},

    # Mid layers: moderate updates
    {'params': model.layer3.parameters(), 'lr': 1e-5},

    # Late layers: larger updates
    {'params': model.layer4.parameters(), 'lr': 5e-5},

    # New head: normal learning rate
    {'params': model.fc.parameters(), 'lr': 1e-3},
]

optimizer = torch.optim.AdamW(param_groups, weight_decay=0.01)
```

Strategy 3: Gradual Unfreezing
When: Any dataset size, safest approach
```python
def get_layer_groups(model):
    """Get ResNet layer groups from shallow to deep."""
    return [
        [model.conv1, model.bn1],
        [model.layer1],
        [model.layer2],
        [model.layer3],
        [model.layer4],
        [model.fc],
    ]

def unfreeze_layer_group(model, group_idx, layer_groups):
    """Unfreeze every module in a specific layer group."""
    for module in layer_groups[group_idx]:
        for param in module.parameters():
            param.requires_grad = True

# Training schedule
layer_groups = get_layer_groups(model)

# Phase 1: Only head (5 epochs, lr=1e-3)
for group in layer_groups[:-1]:
    for module in group:
        for param in module.parameters():
            param.requires_grad = False
train(model, epochs=5, lr=1e-3)

# Phase 2: Unfreeze layer4 (5 epochs, lr=1e-4)
unfreeze_layer_group(model, 4, layer_groups)
train(model, epochs=5, lr=1e-4)

# Phase 3: Unfreeze layer3 (10 epochs, lr=1e-5)
unfreeze_layer_group(model, 3, layer_groups)
train(model, epochs=10, lr=1e-5)
```

The 10x Learning Rate Rule
Notice that the schedule above drops the learning rate by roughly 10x each time a deeper group is unfrozen: pretrained weights should be nudged, not overwritten, so each layer you unfreeze gets a learning rate about an order of magnitude smaller than the one above it.
Interactive: Strategy Decider
Use this tool to determine the best transfer learning strategy for your situation. Adjust the sliders to match your dataset characteristics:
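The decision logic this section describes — dataset size and domain similarity pick the strategy — can be sketched as a small helper. The function name and the boolean similarity flag are simplifications for illustration; the thresholds mirror the ones used above:

```python
def choose_strategy(num_images: int, domain_similar: bool) -> str:
    """Rough transfer learning strategy picker (thresholds from this section)."""
    if num_images < 5_000:
        # Small data: don't risk destroying pretrained features
        return "feature extraction" if domain_similar else "gradual unfreezing"
    if num_images <= 50_000:
        return "discriminative fine-tuning"
    # Plenty of data: fine-tuning the whole network is viable
    return "full fine-tuning"

print(choose_strategy(1_000, domain_similar=True))   # feature extraction
print(choose_strategy(20_000, domain_similar=True))  # discriminative fine-tuning
```

In practice the boundaries are soft; treat the output as a starting point, not a verdict.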
Pretrained Model Zoo
PyTorch's torchvision.models provides a rich collection of pretrained CNN architectures. Choosing the right one depends on your constraints:
Model Families Overview
| Family | Key Innovation | Best For | Trade-off |
|---|---|---|---|
| ResNet | Skip connections | General purpose, well-studied | Balanced accuracy/speed |
| VGG | Simple stacked convs | Feature visualization, style transfer | Large, slower |
| DenseNet | Dense connectivity | Parameter efficiency | Memory intensive |
| EfficientNet | Compound scaling | Mobile/edge, high accuracy | Best efficiency curve |
| ConvNeXt | Modernized ResNet | Competing with ViT | Excellent performance |
| Vision Transformer | Self-attention | Very large datasets | Needs more data to transfer well |
Quick Selection Guide
- Quick prototyping: ResNet-18 (11M params, fast)
- Production baseline: ResNet-50 (25M params, reliable)
- Mobile deployment: EfficientNet-B0 (5M params, efficient)
- Maximum accuracy: EfficientNet-B4 or ConvNeXt-Base
- Research/large data: ViT-B/16 (86M params, scales well)
Interactive: Model Comparison
Compare pretrained models across accuracy, parameters, and computational cost. Select models to see detailed usage instructions:
Complete Implementation
Here's a production-ready transfer learning pipeline that incorporates all best practices:
Domain-Specific Transfer
Different domains require different transfer strategies. Here's guidance for common specialized applications:
Medical Imaging
| Consideration | Recommendation |
|---|---|
| Domain gap | Moderate - X-rays/CT differ from natural images but share structural features |
| Data availability | Usually small (hundreds to low thousands) |
| Strategy | Feature extraction or careful gradual unfreezing |
| Preprocessing | Grayscale images need 3-channel conversion; consider CLAHE |
| Regularization | Heavy dropout (0.5+), strong augmentation |
```python
from torchvision import transforms

# Medical imaging specific preprocessing
medical_transforms = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.Grayscale(num_output_channels=3),  # Replicate to 3 channels
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],  # Still use ImageNet stats
        std=[0.229, 0.224, 0.225]
    ),
])

# Consider specialized pretrained models:
# RadImageNet, CheXNet, or DenseNet-121 pretrained on chest X-rays
```

Satellite/Aerial Imagery
| Consideration | Recommendation |
|---|---|
| Domain gap | Moderate - different viewpoint, but textures/edges transfer |
| Input channels | May have >3 channels (infrared, multispectral) |
| Strategy | Partial fine-tuning usually works well |
| Resolution | Often higher resolution than ImageNet; use larger input sizes |
Document/Text Images
| Consideration | Recommendation |
|---|---|
| Domain gap | Large - very different from natural images |
| Strategy | Feature extraction may struggle; consider full fine-tuning |
| Alternative | Consider specialized models (document-pretrained) if available |
| Input | Binary/grayscale; ensure proper normalization |
Domain Gap Awareness
- Try transfer learning first as a baseline
- If results are poor, consider training from scratch with more data
- Search for domain-specific pretrained models
- Use self-supervised pretraining on your unlabeled domain data
Best Practices
1. Always Match Preprocessing
The most common source of poor transfer learning results is mismatched preprocessing:
```python
# CORRECT: Use weights-specific transforms
from torchvision import transforms
from torchvision.models import ResNet50_Weights

weights = ResNet50_Weights.IMAGENET1K_V2
preprocess = weights.transforms()

# Or manually specify ImageNet statistics
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    )
])

# WRONG: Using different normalization
# transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])  # Don't do this!
```

2. Handle BatchNorm Carefully
BatchNorm layers store running statistics from ImageNet. For very small datasets:
```python
import torch.nn as nn

# Option 1: Keep BatchNorm in eval mode during training
def train_with_frozen_bn(model, dataloader, optimizer, criterion):
    model.train()

    # But keep BatchNorm layers in eval mode
    for module in model.modules():
        if isinstance(module, nn.BatchNorm2d):
            module.eval()

    # Training loop as usual...

# Option 2: Freeze BatchNorm parameters
for module in model.modules():
    if isinstance(module, nn.BatchNorm2d):
        module.weight.requires_grad = False
        module.bias.requires_grad = False
```

3. Use Proper Learning Rate Scheduling
```python
from torch.optim.lr_scheduler import (
    CosineAnnealingLR,
    OneCycleLR,
    ReduceLROnPlateau,
)

# Cosine annealing - smooth decay
scheduler = CosineAnnealingLR(optimizer, T_max=num_epochs, eta_min=1e-6)

# One cycle - fast convergence
scheduler = OneCycleLR(
    optimizer,
    max_lr=1e-3,
    epochs=num_epochs,
    steps_per_epoch=len(train_loader),
    pct_start=0.3,  # Warmup for 30% of training
)

# Reduce on plateau - adaptive
scheduler = ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.5,
    patience=5,
    min_lr=1e-7,
)
```

4. Monitor for Catastrophic Forgetting
Watch for signs that pretrained features are being destroyed:
- Training loss drops quickly but validation loss increases
- Validation accuracy peaks early then degrades
- Activations become saturated (all zeros or very large)
If you see these signs: reduce learning rate by 10x and try again.
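The "peaks early then degrades" pattern is easy to detect automatically. A small sketch of such a monitor — the class name and thresholds are illustrative choices, not a standard API:

```python
class ForgettingMonitor:
    """Flags the degradation pattern described above: val accuracy peaks, then decays."""

    def __init__(self, patience: int = 3, drop_tolerance: float = 0.02):
        self.best = 0.0
        self.epochs_below = 0
        self.patience = patience
        self.drop_tolerance = drop_tolerance

    def update(self, val_acc: float) -> bool:
        """Call once per epoch; returns True when a sustained drop from the peak is seen."""
        if val_acc > self.best:
            self.best = val_acc
            self.epochs_below = 0
        elif self.best - val_acc > self.drop_tolerance:
            self.epochs_below += 1
        return self.epochs_below >= self.patience

# Example: accuracy rises, then steadily decays from its peak
monitor = ForgettingMonitor()
for acc in [0.60, 0.72, 0.70, 0.65, 0.61, 0.58]:
    if monitor.update(acc):
        print("Possible catastrophic forgetting: cut the learning rate by 10x")
        break
```

Hooking this into the validation loop lets you cut the learning rate (or restore the best checkpoint) as soon as the pattern appears, rather than noticing it after training finishes.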
Common Mistakes
Mistake 1: Wrong Input Size
```python
# WRONG: Using an arbitrary input size
model = models.resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)
x = torch.randn(1, 3, 128, 128)  # Runs (adaptive pooling), but the pretrained features expect 224

# CORRECT: Use the model's expected size (224 for most, 299 for Inception, 380 for EfficientNet-B4)
x = torch.randn(1, 3, 224, 224)
```

Mistake 2: Forgetting to Replace the Head
```python
# WRONG: Training with 1000-class output
model = models.resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)
# Oops, model.fc outputs 1000 classes, not our num_classes!

# CORRECT: Replace the classification head
model.fc = nn.Linear(model.fc.in_features, num_classes)
```

Mistake 3: Training All Layers with Same LR
```python
# WRONG: Same learning rate for everything
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)  # Will destroy early features!

# CORRECT: Lower LR for pretrained layers
optimizer = torch.optim.Adam([
    {'params': model.conv1.parameters(), 'lr': 1e-5},
    {'params': model.fc.parameters(), 'lr': 1e-3},
])
```

Mistake 4: Not Using Data Augmentation
With small datasets, data augmentation is critical to prevent overfitting:
```python
from torchvision import transforms

train_transforms = transforms.Compose([
    transforms.RandomResizedCrop(224, scale=(0.8, 1.0)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    transforms.RandomErasing(p=0.1),  # Cutout-style augmentation
])
```

Summary
| Concept | Key Takeaway |
|---|---|
| Why CNNs transfer | Hierarchical features: early layers learn universal visual primitives |
| Feature extraction | Freeze backbone, train head only - best for small data (<5K) |
| Fine-tuning | Train all/some layers with low LR - best for medium data (5K-50K) |
| Discriminative LR | Early layers: 1e-6, Late layers: 1e-4, Head: 1e-3 |
| Gradual unfreezing | Safest approach - start frozen, progressively unfreeze |
| Preprocessing | MUST match pretrained model (ImageNet: mean=[0.485, 0.456, 0.406]) |
| Model selection | ResNet-50 for baseline, EfficientNet for efficiency, ConvNeXt for SOTA |
| BatchNorm | Keep in eval mode for very small datasets |
Exercises
Conceptual Questions
- Explain why early CNN layers transfer better than later layers. What would you expect if you tried to transfer the final convolutional layer from a model trained on natural images to a document scanner classifier?
- A colleague trained a transfer learning model and found that training accuracy was 99% but validation accuracy was only 45%. What went wrong, and how would you fix it?
- You have 500 labeled medical X-ray images. Describe your transfer learning strategy, including which layers to freeze, learning rate choices, and what preprocessing you would use.
Coding Exercises
- Model comparison: Implement transfer learning with ResNet-18, ResNet-50, and EfficientNet-B0 on CIFAR-10. Compare training time, final accuracy, and inference speed. Use feature extraction mode.
- Strategy comparison: On a subset of 1,000 CIFAR-10 images, compare: (a) feature extraction, (b) discriminative fine-tuning, (c) gradual unfreezing. Plot learning curves for each.
- Learning rate experiment: Fine-tune ResNet-50 with learning rates [1e-2, 1e-3, 1e-4, 1e-5, 1e-6]. Plot training and validation curves. Identify which rates cause catastrophic forgetting and which undertrain.
Exercise Tips
- Exercise 1: Use `torchvision.datasets.CIFAR10` with proper transforms. Resize to 224x224 for fair comparison.
- Exercise 2: Subsample CIFAR-10 to 1,000 images (100 per class). Track both train and val accuracy per epoch.
- Exercise 3: 1e-2 should show forgetting (val drops after initial rise), 1e-6 should undertrain (slow improvement). The sweet spot is usually 1e-4 to 1e-5.
In the next section, we'll dive deeper into using pretrained models with specific code recipes for common tasks like image classification, feature extraction, and fine-tuning on custom datasets.