Chapter 12

Transfer Learning

CNNs in Practice

Learning Objectives

By the end of this section, you will:

  • Understand why CNN features transfer across domains and tasks
  • Know how to choose between feature extraction and fine-tuning based on your data
  • Select the right pretrained model from the PyTorch model zoo
  • Implement complete transfer learning pipelines with proper preprocessing
  • Apply domain-specific strategies for medical, satellite, and specialized imagery
  • Avoid common pitfalls that destroy pretrained knowledge

Building on Fundamentals

This section applies the transfer learning principles from Chapter 9, Section 8 specifically to CNNs. We focus on practical implementation with pretrained vision models. If you haven't read the fundamentals section, review it for background on why transfer learning works.

Why Transfer Learning for CNNs?

Convolutional Neural Networks are particularly well-suited for transfer learning because of how they learn visual features. This isn't accidental—it's a direct consequence of the hierarchical nature of visual perception.

The ImageNet Effect

Models trained on ImageNet (1.2 million images, 1000 classes) learn incredibly rich visual representations:

| Training Data | What the Model Learns | Transfer Value |
|---|---|---|
| 1.2M diverse images | Generic visual primitives | Universal edge, texture, shape detectors |
| 1000 categories | Hierarchical object parts | Eyes, wheels, faces, fur patterns |
| Natural distribution | Real-world visual statistics | Understanding of occlusion, lighting, pose |

Key Insight: A model trained to distinguish 1000 object categories necessarily learns features that are useful for any visual recognition task. The early layers learn features so general they transfer even to radically different domains like medical X-rays or satellite imagery.

The Economics of Transfer Learning

Consider what it would take to train a model from scratch for your specific task:

| Resource | Training from Scratch | Transfer Learning |
|---|---|---|
| Data needed | 100K+ labeled images | 1K-10K images |
| Training time | Days to weeks | Hours |
| GPU cost | $100-$1000+ | $5-$50 |
| Risk of failure | High (many hyperparameters) | Low (pretrained baseline) |
| Final accuracy | Unknown until trained | Strong baseline guaranteed |

The Practical Reality

In almost every real-world computer vision project, transfer learning from ImageNet-pretrained models outperforms training from scratch. The exception is when you have millions of domain-specific images (like autonomous driving datasets).

The CNN Feature Hierarchy

Understanding what each layer of a CNN learns helps us make intelligent decisions about what to freeze and what to fine-tune.

Layer-by-Layer Feature Analysis

Research by Zeiler & Fergus (2014) and others has shown that CNN layers learn increasingly abstract features:

| Layer Depth | Features Learned | Transferability | Action |
|---|---|---|---|
| Conv 1-2 | Edges, gradients, colors, simple textures | ~95% (highly universal) | Almost always freeze |
| Conv 3-4 | Corners, contours, complex textures | ~80% (mostly universal) | Usually freeze |
| Conv 5-7 | Object parts, semantic patterns | ~60% (domain-dependent) | Fine-tune if enough data |
| Conv 8+ | Whole objects, high-level concepts | ~30% (task-specific) | Often fine-tune or replace |
| FC layers | Class-specific decision boundaries | ~0% (ImageNet-specific) | Always replace |

Mathematical Perspective

We can think of a pretrained CNN as a composition of learned functions:

f_{\text{pretrained}}(\mathbf{x}) = f_{\text{classifier}} \circ f_{\text{high}} \circ f_{\text{mid}} \circ f_{\text{low}}(\mathbf{x})

Where:

  • f_{\text{low}}: Early convolutional layers — highly transferable edge/texture detectors
  • f_{\text{mid}}: Middle layers — moderately transferable part detectors
  • f_{\text{high}}: Later layers — domain-specific semantic features
  • f_{\text{classifier}}: Final fully-connected layers — task-specific, must be replaced

For transfer learning, we keep f_{\text{low}} fixed, optionally fine-tune f_{\text{mid}} and f_{\text{high}}, and replace f_{\text{classifier}} entirely.



Quick Check

You're building a satellite imagery classifier. Based on the feature hierarchy, which layers would benefit most from fine-tuning?


Transfer Learning Strategies

The right strategy depends on two key factors: dataset size and domain similarity. Let's analyze each approach:

Strategy 1: Feature Extraction

When: Small dataset (<5K images) or very similar domain

feature_extraction.py

import torch
import torch.nn as nn
from torchvision import models

# Load pretrained ResNet
model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V2)

# Freeze ALL backbone parameters
for param in model.parameters():
    param.requires_grad = False

# Replace classifier head
num_features = model.fc.in_features
model.fc = nn.Sequential(
    nn.Dropout(0.5),
    nn.Linear(num_features, 256),
    nn.ReLU(),
    nn.Dropout(0.3),
    nn.Linear(256, num_classes)
)

# Only head parameters are trainable
optimizer = torch.optim.Adam(model.fc.parameters(), lr=1e-3)

Strategy 2: Discriminative Fine-Tuning

When: Medium dataset (5K-50K images), similar domain

discriminative_fine_tuning.py

# Different learning rates for different depths
# Key insight: earlier = smaller LR, later = larger LR

param_groups = [
    # Earliest layers: almost frozen
    {'params': model.conv1.parameters(), 'lr': 1e-6},
    {'params': model.bn1.parameters(), 'lr': 1e-6},
    {'params': model.layer1.parameters(), 'lr': 1e-6},

    # Early-mid layers: very small updates
    {'params': model.layer2.parameters(), 'lr': 5e-6},

    # Mid layers: moderate updates
    {'params': model.layer3.parameters(), 'lr': 1e-5},

    # Late layers: larger updates
    {'params': model.layer4.parameters(), 'lr': 5e-5},

    # New head: normal learning rate
    {'params': model.fc.parameters(), 'lr': 1e-3},
]

optimizer = torch.optim.AdamW(param_groups, weight_decay=0.01)

Strategy 3: Gradual Unfreezing

When: Any dataset size, safest approach

gradual_unfreezing.py

def get_layer_groups(model):
    """Get ResNet layer groups from shallow to deep."""
    return [
        [model.conv1, model.bn1],
        [model.layer1],
        [model.layer2],
        [model.layer3],
        [model.layer4],
        [model.fc]
    ]

def unfreeze_layer_group(layer_groups, group_idx):
    """Unfreeze every module in a specific layer group."""
    for module in layer_groups[group_idx]:
        for param in module.parameters():
            param.requires_grad = True

# Training schedule
layer_groups = get_layer_groups(model)

# Phase 1: Only head (5 epochs, lr=1e-3)
for group in layer_groups[:-1]:
    for module in group:
        for param in module.parameters():
            param.requires_grad = False
train(model, epochs=5, lr=1e-3)

# Phase 2: Unfreeze layer4 (5 epochs, lr=1e-4)
unfreeze_layer_group(layer_groups, 4)
train(model, epochs=5, lr=1e-4)

# Phase 3: Unfreeze layer3 (10 epochs, lr=1e-5)
unfreeze_layer_group(layer_groups, 3)
train(model, epochs=10, lr=1e-5)

The 10x Learning Rate Rule

When fine-tuning, use a learning rate 10-100x smaller than you would for training from scratch. This prevents catastrophic forgetting of pretrained features.



Pretrained Model Zoo

PyTorch's torchvision.models provides a rich collection of pretrained CNN architectures. Choosing the right one depends on your constraints:

Model Families Overview

| Family | Key Innovation | Best For | Trade-off |
|---|---|---|---|
| ResNet | Skip connections | General purpose, well-studied | Balanced accuracy/speed |
| VGG | Simple stacked convs | Feature visualization, style transfer | Large, slower |
| DenseNet | Dense connectivity | Parameter efficiency | Memory intensive |
| EfficientNet | Compound scaling | Mobile/edge, high accuracy | Best efficiency curve |
| ConvNeXt | Modernized ResNet | Competing with ViT | Excellent performance |
| Vision Transformer | Self-attention | Very large datasets | Needs more data to transfer well |

Quick Selection Guide

  • Quick prototyping: ResNet-18 (11M params, fast)
  • Production baseline: ResNet-50 (25M params, reliable)
  • Mobile deployment: EfficientNet-B0 (5M params, efficient)
  • Maximum accuracy: EfficientNet-B4 or ConvNeXt-Base
  • Research/large data: ViT-B/16 (86M params, scales well)



Complete Implementation

Here's a production-ready transfer learning pipeline that incorporates all best practices:

Production Transfer Learning Pipeline
transfer_learning_pipeline.py

Notes on the pipeline code below:

  • Model Selection: Choose a pretrained architecture. ResNet-50 is a good default with strong performance and well-understood transfer behavior.
  • Architecture-Specific Extraction: Different architectures name their final layer differently: ResNet uses fc, while VGG, EfficientNet, and ConvNeXt use classifier.
  • Custom Head Design: Replace the ImageNet classifier (1000 classes) with a custom head. Adding dropout helps prevent overfitting on small datasets.
  • Freeze Backbone: Setting requires_grad=False prevents gradient computation for pretrained layers. This is feature extraction mode.
  • Discriminative Learning Rates: Use different learning rates for different layer groups. Early layers need tiny updates (around 1e-6); later layers can adapt more (around 1e-4).
  • Preprocessing Pipeline: CRITICAL: use the same normalization the pretrained model was trained with. ImageNet models expect mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225].
import torch
import torch.nn as nn
from torchvision import models, transforms
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR

class TransferLearningPipeline:
    ARCHITECTURES = {
        'resnet18': (models.resnet18, models.ResNet18_Weights.IMAGENET1K_V1, 'fc'),
        'resnet50': (models.resnet50, models.ResNet50_Weights.IMAGENET1K_V2, 'fc'),
        'efficientnet_b0': (models.efficientnet_b0, models.EfficientNet_B0_Weights.IMAGENET1K_V1, 'classifier'),
        'convnext_tiny': (models.convnext_tiny, models.ConvNeXt_Tiny_Weights.IMAGENET1K_V1, 'classifier'),
    }

    def __init__(self, architecture, num_classes, strategy='feature_extraction'):
        model_fn, weights, head_name = self.ARCHITECTURES[architecture]

        self.model = model_fn(weights=weights)
        self.head_name = head_name
        self.strategy = strategy

        # Replace classification head
        self._replace_head(num_classes)

        # Apply strategy
        if strategy == 'feature_extraction':
            self._freeze_backbone()
        elif strategy == 'discriminative':
            pass  # All layers trainable, with different LRs per depth

    def _replace_head(self, num_classes):
        """Replace classifier with custom head."""
        old_head = getattr(self.model, self.head_name)

        if isinstance(old_head, nn.Linear):
            in_features = old_head.in_features
        else:  # Sequential (EfficientNet, ConvNeXt)
            in_features = old_head[-1].in_features

        new_head = nn.Sequential(
            nn.Dropout(0.4),
            nn.Linear(in_features, 512),
            nn.ReLU(inplace=True),
            nn.Dropout(0.2),
            nn.Linear(512, num_classes)
        )
        setattr(self.model, self.head_name, new_head)

    def _freeze_backbone(self):
        """Freeze all parameters except the classification head."""
        for name, param in self.model.named_parameters():
            if not name.startswith(self.head_name):
                param.requires_grad = False

    def get_optimizer(self, base_lr=1e-3):
        """Get optimizer with strategy-appropriate learning rates."""
        if self.strategy == 'feature_extraction':
            params = getattr(self.model, self.head_name).parameters()
            return AdamW(params, lr=base_lr, weight_decay=0.01)

        # Discriminative learning rates
        param_groups = self._get_layer_groups(base_lr)
        return AdamW(param_groups, weight_decay=0.01)

    def _get_layer_groups(self, base_lr):
        """Create parameter groups with discriminative LRs (ResNet-style names)."""
        groups = []

        for name, param in self.model.named_parameters():
            if not param.requires_grad:
                continue

            # Match on the top-level module name so e.g. 'layer2.0.conv1'
            # isn't mistaken for the stem conv1
            if name.startswith(('conv1', 'bn1', 'layer1')):
                lr = base_lr * 0.001
            elif name.startswith('layer2'):
                lr = base_lr * 0.01
            elif name.startswith('layer3'):
                lr = base_lr * 0.1
            elif name.startswith('layer4'):
                lr = base_lr * 0.5
            else:  # Head
                lr = base_lr

            groups.append({'params': [param], 'lr': lr})

        return groups

    @staticmethod
    def get_transforms(architecture):
        """Get the transforms the pretrained weights expect."""
        _, weights, _ = TransferLearningPipeline.ARCHITECTURES[architecture]
        return weights.transforms()

# Usage example
pipeline = TransferLearningPipeline(
    architecture='resnet50',
    num_classes=10,
    strategy='discriminative'
)

optimizer = pipeline.get_optimizer(base_lr=1e-3)
scheduler = CosineAnnealingLR(optimizer, T_max=50)
preprocess = pipeline.get_transforms('resnet50')

# Training loop
model = pipeline.model
model.train()

Domain-Specific Transfer

Different domains require different transfer strategies. Here's guidance for common specialized applications:

Medical Imaging

| Consideration | Recommendation |
|---|---|
| Domain gap | Moderate: X-rays/CT differ from natural images but share structural features |
| Data availability | Usually small (hundreds to low thousands) |
| Strategy | Feature extraction or careful gradual unfreezing |
| Preprocessing | Grayscale images need 3-channel conversion; consider CLAHE |
| Regularization | Heavy dropout (0.5+), strong augmentation |
medical_imaging_transfer.py

from torchvision import transforms

# Medical imaging specific preprocessing
medical_transforms = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.Grayscale(num_output_channels=3),  # Replicate grayscale to 3 channels
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],  # Still use ImageNet stats
        std=[0.229, 0.224, 0.225]
    ),
])

# Consider specialized pretrained models:
# RadImageNet, CheXNet, or DenseNet-121 pretrained on chest X-rays

Satellite/Aerial Imagery

| Consideration | Recommendation |
|---|---|
| Domain gap | Moderate: different viewpoint, but textures/edges transfer |
| Input channels | May have >3 channels (infrared, multispectral) |
| Strategy | Partial fine-tuning usually works well |
| Resolution | Often higher resolution than ImageNet; use larger input sizes |

Document/Text Images

| Consideration | Recommendation |
|---|---|
| Domain gap | Large: very different from natural images |
| Strategy | Feature extraction may struggle; consider full fine-tuning |
| Alternative | Consider specialized (document-pretrained) models if available |
| Input | Binary/grayscale; ensure proper normalization |

Domain Gap Awareness

When the domain gap is large (like documents vs natural images), transfer learning may provide minimal benefit. In such cases:
  1. Try transfer learning first as a baseline
  2. If results are poor, consider training from scratch with more data
  3. Search for domain-specific pretrained models
  4. Use self-supervised pretraining on your unlabeled domain data

Best Practices

1. Always Match Preprocessing

The most common source of poor transfer learning results is mismatched preprocessing:

preprocessing_match.py

# CORRECT: Use weights-specific transforms
from torchvision import transforms
from torchvision.models import ResNet50_Weights

weights = ResNet50_Weights.IMAGENET1K_V2
preprocess = weights.transforms()

# Or manually specify ImageNet statistics
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    )
])

# WRONG: Using different normalization
# transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])  # Don't do this!

2. Handle BatchNorm Carefully

BatchNorm layers store running statistics from ImageNet. For very small datasets:

batchnorm_handling.py

import torch.nn as nn

# Option 1: Keep BatchNorm in eval mode during training
def train_with_frozen_bn(model, dataloader, optimizer, criterion):
    model.train()

    # But keep BatchNorm layers in eval mode so running stats aren't updated
    for module in model.modules():
        if isinstance(module, nn.BatchNorm2d):
            module.eval()

    # Training loop as usual...

# Option 2: Freeze BatchNorm affine parameters
for module in model.modules():
    if isinstance(module, nn.BatchNorm2d):
        module.weight.requires_grad = False
        module.bias.requires_grad = False

3. Use Proper Learning Rate Scheduling

lr_scheduling.py

from torch.optim.lr_scheduler import (
    CosineAnnealingLR,
    OneCycleLR,
    ReduceLROnPlateau
)

# Cosine annealing - smooth decay
scheduler = CosineAnnealingLR(optimizer, T_max=num_epochs, eta_min=1e-6)

# One cycle - fast convergence
scheduler = OneCycleLR(
    optimizer,
    max_lr=1e-3,
    epochs=num_epochs,
    steps_per_epoch=len(train_loader),
    pct_start=0.3,  # Warmup for 30% of training
)

# Reduce on plateau - adaptive
scheduler = ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.5,
    patience=5,
    min_lr=1e-7
)

4. Monitor for Catastrophic Forgetting

Watch for signs that pretrained features are being destroyed:

  • Training loss drops quickly but validation loss increases
  • Validation accuracy peaks early then degrades
  • Activations become saturated (all zeros or very large)

If you see these signs: reduce learning rate by 10x and try again.
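One lightweight diagnostic (our own suggestion, not a standard tool): snapshot the pretrained weights before fine-tuning and log each layer's relative L2 drift. Large early-layer drift after only a few epochs usually means the learning rate is too high:

```python
import torch
import torch.nn as nn

def snapshot_weights(model):
    """Copy current parameters so drift can be measured later."""
    return {name: p.detach().clone() for name, p in model.named_parameters()}

def weight_drift(model, snapshot):
    """Relative L2 distance of each parameter tensor from its snapshot."""
    return {
        name: ((p.detach() - snapshot[name]).norm()
               / (snapshot[name].norm() + 1e-12)).item()
        for name, p in model.named_parameters()
    }

# Tiny demonstration on a stand-in model
model = nn.Linear(4, 2)
initial = snapshot_weights(model)
with torch.no_grad():
    model.weight += 0.5  # simulate an aggressive update
drift = weight_drift(model, initial)
print(sorted(drift, key=drift.get, reverse=True))  # largest movers first
```

In a real run, call snapshot_weights once before training and log the top movers each epoch; if conv1 or layer1 parameters dominate the list, cut the learning rate.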


Common Mistakes

Mistake 1: Wrong Input Size

# WRONG: Using an arbitrary input size
model = models.resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)
x = torch.randn(1, 3, 128, 128)  # Too small: pretrained filters expect ~224px statistics

# CORRECT: Use the model's expected size (224 for most, 299 for Inception, 380 for EfficientNet-B4)
x = torch.randn(1, 3, 224, 224)

Mistake 2: Forgetting to Replace the Head

# WRONG: Training with 1000-class output
model = models.resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)
# Oops, model.fc outputs 1000 classes, not our num_classes!

# CORRECT: Replace the classification head
model.fc = nn.Linear(model.fc.in_features, num_classes)

Mistake 3: Training All Layers with Same LR

# WRONG: Same learning rate for everything
optimizer = Adam(model.parameters(), lr=1e-3)  # Will destroy early features!

# CORRECT: Lower LR for pretrained layers
optimizer = Adam([
    {'params': model.conv1.parameters(), 'lr': 1e-5},
    {'params': model.fc.parameters(), 'lr': 1e-3}
])

Mistake 4: Not Using Data Augmentation

With small datasets, data augmentation is critical to prevent overfitting:

augmentation.py

from torchvision import transforms

train_transforms = transforms.Compose([
    transforms.RandomResizedCrop(224, scale=(0.8, 1.0)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    transforms.RandomErasing(p=0.1),  # Cutout-style augmentation (operates on tensors)
])

Summary

| Concept | Key Takeaway |
|---|---|
| Why CNNs transfer | Hierarchical features: early layers learn universal visual primitives |
| Feature extraction | Freeze backbone, train head only; best for small data (<5K) |
| Fine-tuning | Train all/some layers with low LR; best for medium data (5K-50K) |
| Discriminative LR | Early layers: 1e-6, late layers: 1e-4, head: 1e-3 |
| Gradual unfreezing | Safest approach: start frozen, progressively unfreeze |
| Preprocessing | Must match the pretrained model (ImageNet: mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) |
| Model selection | ResNet-50 for baseline, EfficientNet for efficiency, ConvNeXt for SOTA |
| BatchNorm | Keep in eval mode for very small datasets |

Exercises

Conceptual Questions

  1. Explain why early CNN layers transfer better than later layers. What would you expect if you tried to transfer the final convolutional layer from a model trained on natural images to a document scanner classifier?
  2. A colleague trained a transfer learning model and found that training accuracy was 99% but validation accuracy was only 45%. What went wrong, and how would you fix it?
  3. You have 500 labeled medical X-ray images. Describe your transfer learning strategy, including which layers to freeze, learning rate choices, and what preprocessing you would use.

Coding Exercises

  1. Model comparison: Implement transfer learning with ResNet-18, ResNet-50, and EfficientNet-B0 on CIFAR-10. Compare training time, final accuracy, and inference speed. Use feature extraction mode.
  2. Strategy comparison: On a subset of 1,000 CIFAR-10 images, compare: (a) feature extraction, (b) discriminative fine-tuning, (c) gradual unfreezing. Plot learning curves for each.
  3. Learning rate experiment: Fine-tune ResNet-50 with learning rates [1e-2, 1e-3, 1e-4, 1e-5, 1e-6]. Plot training and validation curves. Identify which rates cause catastrophic forgetting and which undertrain.

Exercise Tips

  • Exercise 1: Use torchvision.datasets.CIFAR10 with proper transforms. Resize to 224x224 for fair comparison.
  • Exercise 2: Subsample CIFAR-10 to 1000 images (100 per class). Track both train and val accuracy per epoch.
  • Exercise 3: 1e-2 should show forgetting (val drops after initial rise), 1e-6 should undertrain (slow improvement). Sweet spot is usually 1e-4 to 1e-5.
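For Exercise 2, the balanced subsample can be drawn with a small helper (a sketch; `balanced_indices` is our own name, intended to feed `torch.utils.data.Subset` with `dataset.targets` as the label list):

```python
import random
from collections import defaultdict

def balanced_indices(labels, per_class, seed=0):
    """Pick `per_class` indices for each label value, reproducibly."""
    by_class = defaultdict(list)
    for idx, label in enumerate(labels):
        by_class[label].append(idx)
    rng = random.Random(seed)
    chosen = []
    for label in sorted(by_class):
        chosen.extend(rng.sample(by_class[label], per_class))
    return chosen

# With CIFAR-10: Subset(train_set, balanced_indices(train_set.targets, 100))
labels = [i % 10 for i in range(500)]        # stand-in labels, 50 per class
subset = balanced_indices(labels, per_class=10)
print(len(subset))  # 100
```

Fixing the seed keeps the 100-per-class subset identical across the three strategies being compared, so the learning curves differ only because of the strategy.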

In the next section, we'll dive deeper into using pretrained models with specific code recipes for common tasks like image classification, feature extraction, and fine-tuning on custom datasets.