Chapter 14

Transfer Learning with Pretrained Models

CNN Architectures

Why Transfer Learning Works

In Section 2, we saw that ResNet trained on ImageNet achieves superhuman accuracy on 1,000 categories. But most real-world problems are not “classify these 1,000 ImageNet categories.” You might need to classify 5 types of flowers, 3 types of skin lesions, or 20 types of manufactured defects — tasks with far less training data than ImageNet's 1.2 million images.

Transfer learning solves this by reusing a model pretrained on a large dataset (like ImageNet) and adapting it to your specific task. Instead of training from random weights on 500 images, you start from weights that already understand visual concepts from 1.2M images.

The Core Idea: A CNN trained on ImageNet does not just learn to classify dogs and cars. It learns a hierarchy of visual features — edges, textures, shapes, parts, objects — that are useful for any visual task. Transfer learning reuses these universal features.

The Feature Hierarchy Insight

Research by Zeiler & Fergus (2014) showed that CNN layers learn increasingly abstract features. The deeper you go, the more task-specific the features become:

| Layer Group | What It Learns | Transferability |
| --- | --- | --- |
| Early layers (conv1–conv2) | Edges, corners, color gradients | Highly universal: useful for ANY visual task |
| Middle layers (conv3–conv4) | Textures, patterns, simple shapes | Very transferable across most domains |
| Later layers (conv5–layer3) | Object parts, complex shapes | Moderately transferable: may need fine-tuning |
| Final layers (layer4–fc) | Task-specific compositions | Least transferable: usually replaced |

This gradient of transferability has a profound practical implication: the early layers of a pretrained CNN are a near-universal feature extractor. Whether your task involves flowers, medical images, satellite photos, or factory defects — the edge detectors, texture analyzers, and shape recognizers learned from ImageNet provide an excellent starting point.

Feature Hierarchy in CNNs

How neural networks build complex features from simple ones (diagram, from concrete to abstract):

  • Input (Level 0, concrete): raw pixels (RGB values, grayscale intensity)
  • Layer 1 (Level 1): edges and gradients (horizontal edges, vertical edges, diagonal lines, color blobs)
  • Layer 2 (Level 2): textures and patterns (corners, simple textures, gradients, color patterns)
  • Layer 3 (Level 3): object parts (eyes, wheels, fur patterns, windows)
  • Layer 4+ (Level 4, abstract): objects and scenes (faces, cars, animals, buildings)

Key Insight: Each layer combines features from the previous layer. Early layers detect low-level features; deeper layers capture high-level concepts.

The diagram above is the pattern Zeiler & Fergus (2014) documented experimentally: early layers respond to Gabor-like edges and colour blobs, middle layers respond to textures and simple parts, and later layers respond to whole objects. Yosinski et al. (2014) quantified the effect: keeping the transferred first layers frozen costs almost no accuracy on a new task, while keeping the last layers frozen costs considerably more. This is the empirical basis of every transfer-learning strategy we are about to use.

The Mathematics of Why It Works

Formally, consider a pretrained model $f_{\theta^*}$ with optimal ImageNet weights $\theta^*$. For a new task with a small dataset $D_{\text{new}}$, training from scratch yields weights $\theta_{\text{scratch}}$ that likely overfit. But starting from $\theta^*$ and fine-tuning gives $\theta_{\text{ft}}$ that:

  1. Starts in a good region of the loss landscape (pretrained features are already useful)
  2. Needs far fewer gradient steps to converge (often a handful of epochs rather than dozens)
  3. Generalizes better because the pretrained features provide implicit regularization (the features were validated on 1.2M diverse images)
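
One way to write the comparison down, as a sketch rather than a formal result: the loss $\mathcal{L}$, the empirical risk $\hat{R}$, the learning rate $\eta$, the step index $t$, and the name $\theta_{\text{random}}$ are notation assumed here, not defined earlier in the section. In practice $\theta^*$ initialises only the backbone, since the classification head is replaced.

```latex
% Both procedures minimise the same empirical risk on the new dataset:
\hat{R}(\theta) = \frac{1}{|D_{\text{new}}|}
    \sum_{(x, y) \in D_{\text{new}}} \mathcal{L}\big(f_{\theta}(x),\, y\big)

% Both use the same update rule:
\theta_{t+1} = \theta_{t} - \eta \, \nabla_{\theta} \hat{R}(\theta_{t})

% They differ only in the starting point:
\theta_{0} = \theta_{\text{random}} \;\longrightarrow\; \theta_{\text{scratch}}
\qquad \text{versus} \qquad
\theta_{0} = \theta^{*} \;\longrightarrow\; \theta_{\text{ft}}
```

Seen this way, the three points above are all statements about the starting point $\theta_0$: a better start, fewer steps, and a solution that stays close to features learned from far more data.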

Two Strategies: Extract vs Fine-Tune

There are two main approaches to transfer learning, and the right choice depends on your dataset size and how similar your domain is to ImageNet:

Strategy 1: Feature Extraction

Freeze the entire pretrained backbone. Only train a new classification head. The pretrained model acts as a fixed feature extractor.

  • When to use: Small dataset (<1,000 images per class), or when your domain is similar to ImageNet (natural photos of objects, animals, scenes)
  • Pros: Fast training, low risk of overfitting, minimal compute
  • Cons: Cannot adapt features to domain-specific patterns

Strategy 2: Fine-Tuning

Unfreeze some or all backbone layers and train them with a small learning rate. The pretrained features are gently adapted to your domain.

  • When to use: Medium-to-large dataset (>1,000 images per class), or when your domain is different from ImageNet (medical images, aerial photos, microscopy)
  • Pros: Higher accuracy, adapts features to your domain
  • Cons: Risk of overfitting if dataset is too small, slower training, needs careful learning rate tuning

| Factor | Feature Extraction | Fine-Tuning |
| --- | --- | --- |
| Trainable parameters | Only new head (~2K) | All layers (~11M) |
| Training time | Minutes | Hours |
| Min dataset size | ~100 images/class | ~500–1000 images/class |
| Risk of overfitting | Very low | Moderate (need regularization) |
| Accuracy ceiling | Good (95%+) | Best possible (98%+) |
| Learning rate | Normal (1e-3) | Differential (1e-4 backbone, 1e-3 head) |

Quick Check

You have 50 images per class and your images are regular photos of animals. Which transfer learning strategy should you use?


Loading Pretrained Models

PyTorch provides a rich model zoo through torchvision.models. Let's load a pretrained ResNet-18 and examine its structure:

Loading a Pretrained ResNet-18
🐍load_pretrained.py
import torch

PyTorch core library.

import torch.nn as nn

Neural network module for modifying layers.

from torchvision import models

torchvision.models provides pretrained CNN architectures: ResNet, VGG, EfficientNet, MobileNet, and many more. These are not just architecture definitions — they come with weights trained on ImageNet’s 1.2M images.

EXECUTION STATE
📚 torchvision.models = Pre-built model zoo. Includes: resnet18/34/50/101/152, vgg16/19, efficientnet_b0-b7, mobilenet_v3, inception_v3, and 50+ other architectures.
from torchvision.models import ResNet18_Weights

The modern torchvision API uses typed weight enums instead of the deprecated pretrained=True. ResNet18_Weights.IMAGENET1K_V1 specifies exactly which pretrained weights to load, including the required preprocessing transforms.

EXECUTION STATE
📚 ResNet18_Weights = Enum of available weight versions. IMAGENET1K_V1 = original weights trained on ImageNet-1K, and the only version published for ResNet-18. Larger models such as ResNet-50 also offer IMAGENET1K_V2 (improved training recipe, higher accuracy). DEFAULT = alias for the best available weights.
model = models.resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)

Download and load ResNet-18 with pretrained ImageNet weights. This model has been trained for ~90 epochs on 1.2M images across 1,000 classes (dogs, cars, birds, furniture, food, etc.). Every conv filter, batch norm parameter, and FC weight is loaded from the pretrained checkpoint.

EXECUTION STATE
📚 models.resnet18(weights) = Creates a ResNet-18 model (18 layers: 1 initial conv + 8 residual blocks × 2 convs + 1 FC) and loads pretrained weights. First call downloads ~45MB of weights; subsequent calls use a local cache.
⬇ arg: weights = ResNet18_Weights.IMAGENET1K_V1 — pretrained on ImageNet-1K (2012 dataset, 1.2M train images, 1000 classes). Top-1 accuracy: 69.8%. Top-5: 89.1%.
→ model = A fully trained ResNet-18 with 11.7M parameters. Every weight has been optimized to extract useful visual features.
print(model)

Print the full model architecture. ResNet-18 consists of: an initial 7×7 conv with stride 2, a max pool, four groups of residual blocks (layer1–layer4), global average pooling, and a 512→1000 FC layer.

EXECUTION STATE
conv1 = Conv2d(3, 64, 7, stride=2, padding=3) — initial layer maps 3 RGB channels to 64 features. Stride 2 halves the 224×224 input to 112×112.
layer1–layer4 = Four groups of 2 BasicBlocks each. Channels: 64 → 128 → 256 → 512. Spatial size: 56 → 28 → 14 → 7.
avgpool = AdaptiveAvgPool2d(1) — global average pooling. 7×7 → 1×1 per channel.
fc = Linear(512, 1000) — 512 features to 1000 ImageNet classes.
print(f"Original FC: {model.fc}")

The final FC layer outputs 1000 classes (ImageNet). For transfer learning, we replace this layer with one that matches our target task (e.g., 10 classes for digits, 5 classes for flowers).

EXECUTION STATE
model.fc = Linear(in_features=512, out_features=1000, bias=True) — this is the ONLY layer we typically need to replace for a new task.
total = sum(p.numel() for p in model.parameters())

ResNet-18 has 11.7M parameters, all of which encode visual knowledge learned from 1.2M diverse images. These features (edges, textures, shapes, object parts) are useful for a wide range of visual tasks.

EXECUTION STATE
total = 11,689,512 parameters. By comparison: LeNet-5 had 44K, our CNN had 207K. The pretrained weights represent many GPU-hours of training that we get for free.
import torch
import torch.nn as nn
from torchvision import models
from torchvision.models import ResNet18_Weights

# Load ResNet-18 pretrained on ImageNet (1.2M images, 1000 classes)
model = models.resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)

# Inspect the model structure (output abbreviated)
print(model)
# ResNet(
#   (conv1): Conv2d(3, 64, kernel_size=7, stride=2, padding=3)
#   (bn1): BatchNorm2d(64)
#   (relu): ReLU(inplace=True)
#   (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1)
#   (layer1): Sequential(2 x BasicBlock)
#   (layer2): Sequential(2 x BasicBlock)
#   (layer3): Sequential(2 x BasicBlock)
#   (layer4): Sequential(2 x BasicBlock)
#   (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
#   (fc): Linear(in_features=512, out_features=1000)
# )

# Check the final FC layer
print(f"Original FC: {model.fc}")
# Linear(in_features=512, out_features=1000, bias=True)

# Total parameters: 11.7M
total = sum(p.numel() for p in model.parameters())
print(f"Total parameters: {total:,}")
# Total parameters: 11,689,512

The modern torchvision API (v0.13+) uses typed weight enums like ResNet18_Weights.IMAGENET1K_V1 instead of the old pretrained=True. The enum approach is safer because it specifies exactly which weights and preprocessing transforms to use.

Strategy 1: Feature Extraction

The simplest form of transfer learning: freeze the pretrained backbone, replace the classification head, and train only the new head. Three lines of code transform a 1000-class ImageNet model into a 5-class flower classifier:

Transfer Learning: Feature Extraction
🐍feature_extraction.py
import torch

PyTorch core library.

import torch.nn as nn

Neural network module — we need nn.Linear to create the new classification head.

from torchvision import models, transforms, datasets

models provides pretrained architectures, transforms provides image preprocessing, and datasets provides data loading utilities including ImageFolder for custom datasets organized in folder-per-class structure.

EXECUTION STATE
📚 datasets.ImageFolder = Loads images from a directory structure where each subfolder is a class: data/flowers/train/rose/*.jpg, data/flowers/train/daisy/*.jpg, etc. Automatically assigns labels based on folder names.
model = models.resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)

Load ResNet-18 with ImageNet-pretrained weights. All 11.7M parameters contain learned visual features.

for param in model.parameters():

Iterate over ALL learnable parameters in the model (conv weights, batch norm gammas/betas, FC weights and biases).

EXECUTION STATE
📚 model.parameters() = Generator yielding every Parameter tensor in the model. For ResNet-18: 62 parameter tensors totaling 11.7M values.
param.requires_grad = False

FREEZE this parameter: tell PyTorch not to compute gradients for it. During backpropagation, frozen parameters are skipped, which saves memory and computation. The pretrained features remain exactly as they were trained on ImageNet.

EXECUTION STATE
📚 requires_grad = False = Disables gradient tracking for this parameter. It will not be updated by optimizer.step(). This is how we ‘freeze’ the pretrained backbone — use it as a fixed feature extractor.
→ effect = Only newly created parameters (like the replacement FC layer) will have requires_grad=True and be updated during training.
num_classes = 5

Our target task has 5 classes (e.g., 5 types of flowers). This replaces the 1000 ImageNet classes.

model.fc = nn.Linear(512, num_classes)

Replace the final FC layer. The original model.fc was Linear(512, 1000) for ImageNet. We swap it with Linear(512, 5) for our 5 flower classes. This new layer is created with random weights and requires_grad=True by default.

EXECUTION STATE
model.fc (before) = Linear(512, 1000), pretrained for ImageNet classes. FROZEN.
model.fc (after) = Linear(512, 5), randomly initialized for our 5 classes. TRAINABLE (requires_grad=True by default).
→ trainable params = 512 × 5 + 5 = 2,565 out of ~11.2M total after the head swap (the original 1000-class head alone held 513,000). We only train 0.02% of the network!
transform = transforms.Compose([...])

Preprocessing pipeline that matches how the pretrained model was trained on ImageNet. Using different transforms can badly degrade results because the weights expect inputs normalized to specific statistics.

EXECUTION STATE
→ critical = The normalization mean=[0.485, 0.456, 0.406] and std=[0.229, 0.224, 0.225] are the ImageNet dataset statistics. If you change these, the pretrained features will not work correctly.
transforms.Resize(256)

Resize the shortest side of the image to 256 pixels, maintaining aspect ratio. This prepares for the 224×224 center crop.

EXECUTION STATE
📚 Resize(size) = If size is an int, resize the shorter edge to that length. A 640×480 (width × height) image becomes 341×256.
transforms.CenterCrop(224)

Crop a 224×224 square from the center. ResNet was trained on 224×224 inputs, so matching this size keeps features at the scale the pretrained weights expect.

EXECUTION STATE
📚 CenterCrop(size) = Extracts the center portion of the image. From a 341×256 image, takes the central 224×224 pixels.
transforms.ToTensor()

Convert PIL Image (H×W×C, 0-255) to PyTorch Tensor (C×H×W, 0.0-1.0).

transforms.Normalize(mean, std)

Normalize each RGB channel using ImageNet’s statistics. This ensures pixel values match what the pretrained model expects.

EXECUTION STATE
⬇ mean = [0.485, 0.456, 0.406] = Average pixel intensity for R, G, B channels across the entire 1.2M-image ImageNet training set.
⬇ std = [0.229, 0.224, 0.225] = Standard deviation for R, G, B channels. After normalization: each channel has mean ≈ 0 and std ≈ 1.
train_data = datasets.ImageFolder('data/flowers/train', transform=transform)

Load images from a folder-per-class directory structure. ImageFolder automatically discovers classes from subfolder names and assigns integer labels alphabetically.

EXECUTION STATE
📚 ImageFolder(root, transform) = Expects structure: root/class1/*.jpg, root/class2/*.jpg, etc. Automatically creates .classes list and .class_to_idx mapping.
→ example structure = data/flowers/train/daisy/*.jpg, data/flowers/train/rose/*.jpg, data/flowers/train/sunflower/*.jpg, etc.
train_loader = DataLoader(train_data, batch_size=32, shuffle=True)

Create batches of 32 images. With frozen features, we can use larger batches since only 2,565 FC parameters need gradient memory.

criterion = nn.CrossEntropyLoss()

Same loss function as before — cross-entropy for multi-class classification.

optimizer = torch.optim.Adam(model.fc.parameters(), lr=0.001)

Only optimize the FC layer’s parameters! We pass model.fc.parameters() instead of model.parameters(). This means the optimizer only tracks the 2,565 new FC weights, while the roughly 11.2M backbone parameters remain frozen.

EXECUTION STATE
⬇ arg: model.fc.parameters() = Only the replacement FC layer’s weight (512×5) and bias (5). Total: 2,565 trainable parameters.
→ comparison = model.parameters() = ~11.2M params (all, after the head swap). model.fc.parameters() = 2,565 params (just the head). That is roughly 4,400× fewer parameters to update; the wall-clock speedup is smaller because every forward pass still runs through the full backbone.
for epoch in range(5):

Train for 5 epochs. Since we only update the FC layer and the backbone features are already excellent, convergence is very fast. Even 2–3 epochs often suffice.

LOOP TRACE · 2 iterations
epoch=0 (Epoch 1)
state = The pretrained features are already discriminative. Even random FC weights get ~20%, and after 1 epoch we hit ~82%.
epoch=4 (Epoch 5)
state = FC layer has converged. 96.5% accuracy on 5-class flowers with only 2,565 trained parameters.
correct, total = 0, 0

Initialize counters for computing training accuracy.

for images, labels in train_loader:

Iterate over batches of 32 flower images and their labels.

EXECUTION STATE
images = [32, 3, 224, 224] — batch of 32 RGB images, preprocessed to match ImageNet statistics
labels = [32] — integer class labels (0=daisy, 1=dandelion, 2=rose, 3=sunflower, 4=tulip)
outputs = model(images)

Forward pass through the entire ResNet. The frozen backbone extracts rich features, then our trainable FC layer classifies them. The backbone acts as a powerful fixed feature extractor.

EXECUTION STATE
outputs = [32, 5] — 5 logit scores per image (one per flower class)
→ compute path = images → conv1(frozen) → ... → layer4(frozen) → avgpool(frozen) → fc(TRAINABLE) → outputs
loss = criterion(outputs, labels)

Compute cross-entropy loss between predicted logits and true labels.

optimizer.zero_grad()

Clear gradients of the FC layer parameters only (the optimizer only tracks model.fc).

loss.backward()

Backpropagate. Gradients are computed only for model.fc (the only layer with requires_grad=True). The frozen backbone parameters have no gradients computed — saving significant memory and compute.

optimizer.step()

Update model.fc.weight and model.fc.bias based on the computed gradients. Only 2,565 values are modified.

import torch
import torch.nn as nn
from torchvision import models, transforms, datasets
from torchvision.models import ResNet18_Weights
from torch.utils.data import DataLoader

# Step 1: Load pretrained ResNet-18
model = models.resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)

# Step 2: FREEZE all pretrained layers
for param in model.parameters():
    param.requires_grad = False   # No gradient updates

# Step 3: Replace the final FC layer for our task
num_classes = 5   # e.g., 5 types of flowers
model.fc = nn.Linear(512, num_classes)
# Only model.fc has requires_grad=True (newly created)

# Step 4: Set up data transforms (must match ImageNet preprocessing)
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    )
])

# Step 5: Create dataset and loader
train_data = datasets.ImageFolder('data/flowers/train', transform=transform)
train_loader = DataLoader(train_data, batch_size=32, shuffle=True)

# Step 6: Train ONLY the new FC layer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.fc.parameters(), lr=0.001)

for epoch in range(5):
    # Note: train() mode also lets BatchNorm running statistics update,
    # even though the backbone weights are frozen (see "BatchNorm Under Fine-Tuning").
    model.train()
    correct, total = 0, 0
    for images, labels in train_loader:
        outputs = model(images)
        loss = criterion(outputs, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        _, predicted = outputs.max(1)
        correct += predicted.eq(labels).sum().item()
        total += labels.size(0)
    print(f"Epoch {epoch+1}: Accuracy={100*correct/total:.1f}%")

# Epoch 1: Accuracy=82.3%
# Epoch 2: Accuracy=91.5%
# Epoch 3: Accuracy=94.2%
# Epoch 4: Accuracy=95.8%
# Epoch 5: Accuracy=96.5%

The key result: 96.5% training accuracy on 5 flower classes from training only 2,565 parameters, about 0.02% of the model. The pretrained backbone provides such rich features that a simple linear classifier on top performs very well. Note that the loop above reports accuracy on the training set, so confirm the number on a held-out validation set before relying on it.


Strategy 2: Fine-Tuning

When you have more data and want to squeeze out every last percentage point of accuracy, fine-tuning adapts the pretrained features to your specific domain. The key technique is differential learning rates: a small rate for the pretrained backbone (preserve features) and a larger rate for the new head (learn fast).

Transfer Learning: Fine-Tuning
🐍fine_tuning.py
import torch

PyTorch core.

model = models.resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)

Load pretrained ResNet-18, same as before.

num_classes = 5

Our target task: 5 flower classes.

model.fc = nn.Linear(512, num_classes)

Replace the classification head. This new layer starts with random weights.

backbone_params = [p for name, p in model.named_parameters() if 'fc' not in name]

Collect all parameters EXCEPT the FC layer. These are the pretrained convolutional features that we want to update gently.

EXECUTION STATE
📚 model.named_parameters() = Yields (name, parameter) tuples. Names like ‘conv1.weight’, ‘layer1.0.bn1.weight’, ‘fc.weight’, etc. We filter out anything containing ‘fc’.
backbone_params = List of 60 parameter tensors totaling ~11.2M values. These contain the pretrained visual features.
head_params = list(model.fc.parameters())

The FC layer parameters: weight (512×5) and bias (5). These are randomly initialized and need aggressive learning.

optimizer = torch.optim.Adam([{...}, {...}])

Create an optimizer with DIFFERENT learning rates for different parameter groups. This is the key technique for fine-tuning: the pretrained backbone gets a small learning rate (gentle updates that preserve learned features), while the new head gets a normal learning rate (fast learning from scratch).

EXECUTION STATE
📚 parameter groups = Adam accepts a list of dicts, each with ‘params’ and optional ‘lr’, ‘weight_decay’, etc. Each group can have different hyperparameters.
{'params': backbone_params, 'lr': 1e-4}

Backbone learning rate: 0.0001. This is 10× smaller than the head. Small updates refine the pretrained features without destroying what ImageNet training learned.

EXECUTION STATE
lr = 1e-4 = 0.0001 — gentle enough to preserve pretrained feature quality. Too large and you erase the pretrained knowledge. Too small and you do not adapt to the new domain.
{'params': head_params, 'lr': 1e-3}

Head learning rate: 0.001 (10× larger). The new FC layer starts from random weights and needs to learn quickly.

EXECUTION STATE
lr = 1e-3 = 0.001 — standard Adam learning rate. The head has no pretrained knowledge to preserve, so a normal learning rate is appropriate.
# Step 1: Warm up the head

Best practice: first train only the new FC layer for a few epochs. This prevents random FC gradients from corrupting pretrained backbone features. Once the head has reasonable weights, then unfreeze the backbone.

for param in backbone_params: param.requires_grad = False

Freeze backbone during warmup. Only the FC layer is updated.

# Step 2: Unfreeze and fine-tune everything

After the head is warmed up, unfreeze the entire model. Now gradients flow through all layers, and the backbone features are gently adapted to the new domain.

for param in backbone_params: param.requires_grad = True

Unfreeze backbone parameters. Now the optimizer will update all ~11.2M parameters, with the small learning rate (1e-4) for the backbone and the larger rate (1e-3) for the head.

EXECUTION STATE
→ effect = The pretrained features are now fine-tuned. In practice, early layers (edge detectors) tend to change little, while later layers (high-level features) adapt more to the new domain.
# Results

Fine-tuning reaches ~98% accuracy, higher than the 96.5% from pure feature extraction. The extra 1.5 percentage points come from adapting the backbone features to the specific characteristics of flower images.

EXECUTION STATE
→ comparison = Feature extraction: 96.5% (only FC trained). Fine-tuning: 98.0% (all layers adapted). Fine-tuning is worth it when you have enough data (hundreds+ images per class).
import torch
import torch.nn as nn
from torchvision import models
from torchvision.models import ResNet18_Weights

# Load pretrained model
model = models.resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)

# Replace final layer for our task
num_classes = 5
model.fc = nn.Linear(512, num_classes)

# Strategy: different learning rates for different parts
# Backbone: small lr (gentle updates to pretrained features)
# Head: large lr (learn new classification from scratch)
backbone_params = [p for name, p in model.named_parameters()
                   if 'fc' not in name]
head_params = list(model.fc.parameters())   # list, so it can be reused safely

optimizer = torch.optim.Adam([
    {'params': backbone_params, 'lr': 1e-4},   # 10x smaller
    {'params': head_params,     'lr': 1e-3},   # normal lr
])

# Step 1: Warm up the head (2 epochs with frozen backbone)
for param in backbone_params:
    param.requires_grad = False

for epoch in range(2):
    # ... train only FC layer (same loop as feature extraction)
    print(f"Warmup {epoch+1}: training only FC layer")

# Step 2: Unfreeze and fine-tune everything (3 epochs)
for param in backbone_params:
    param.requires_grad = True

for epoch in range(3):
    # ... train entire model with differential lr
    print(f"Fine-tune {epoch+1}: training all layers")

# Warmup 1: training only FC layer    -> ~82% accuracy
# Warmup 2: training only FC layer    -> ~91% accuracy
# Fine-tune 1: training all layers    -> ~95% accuracy
# Fine-tune 2: training all layers    -> ~97% accuracy
# Fine-tune 3: training all layers    -> ~98% accuracy

The Warmup + Fine-Tune Pattern

The two-phase approach is critical for stable fine-tuning:

  1. Phase 1 (Warmup): Freeze backbone, train only the new FC head for 2–3 epochs. This gives the head reasonable weights so that its gradients are meaningful when they flow back through the backbone.
  2. Phase 2 (Fine-tune): Unfreeze backbone with a small learning rate (10× smaller than head). Train all layers together. The backbone features adapt gently to the new domain.

Without warmup, the randomly initialized head produces chaotic gradients that propagate through the backbone and corrupt the pretrained features. This can actually make the model worse than feature extraction alone. Always warm up the head first.
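
The two phases can be sketched end to end on a toy problem. Everything here is a stand-in assumed for illustration: a tiny two-layer "backbone", synthetic data, and the helper names `set_backbone_trainable` and `run_epochs`. The point is only to show the mechanics: the backbone weights do not move during warmup and do move once unfrozen.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Tiny stand-in network (hypothetical): "backbone" plays the pretrained part.
backbone = nn.Sequential(nn.Linear(16, 32), nn.ReLU())
head = nn.Linear(32, 5)
model = nn.Sequential(backbone, head)

# Synthetic data so the sketch runs without a dataset.
x = torch.randn(64, 16)
y = torch.randint(0, 5, (64,))

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam([
    {"params": list(backbone.parameters()), "lr": 1e-4},  # gentle backbone updates
    {"params": list(head.parameters()), "lr": 1e-3},      # faster head updates
])


def set_backbone_trainable(flag: bool) -> None:
    for p in backbone.parameters():
        p.requires_grad = flag


def run_epochs(n: int) -> float:
    model.train()
    for _ in range(n):
        optimizer.zero_grad()
        loss = criterion(model(x), y)
        loss.backward()
        optimizer.step()
    return loss.item()


before = backbone[0].weight.detach().clone()

# Phase 1: warm up the head with the backbone frozen.
set_backbone_trainable(False)
run_epochs(3)
after_warmup = backbone[0].weight.detach().clone()

# Phase 2: unfreeze and train everything with differential learning rates.
set_backbone_trainable(True)
run_epochs(3)
after_finetune = backbone[0].weight.detach().clone()

print("backbone changed during warmup:   ", not torch.equal(before, after_warmup))
print("backbone changed during fine-tune:", not torch.equal(after_warmup, after_finetune))
```

Frozen parameters receive no gradient, so the optimizer skips them even though they are registered in a parameter group; that is why a single optimizer can serve both phases.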

BatchNorm Under Fine-Tuning

Section 2 introduced BatchNorm and its two modes. Here is where that distinction becomes a production gotcha. A pretrained ResNet carries running statistics learned from ImageNet — the mean and variance of every channel across 1.2 M natural images. Those statistics are not learnable parameters, so param.requires_grad = False does nothing to them. They keep drifting whenever the layer is in training mode and the optimiser can never reset them.

The symptom: you freeze the backbone, train the head, achieve 95% validation accuracy, then unfreeze for fine-tuning — and accuracy drops. You check gradients, losses, learning rates. They look fine. What changed? The BatchNorm running stats were silently being replaced by statistics from your much smaller fine-tuning batches, which are noisier and have a different distribution than ImageNet. The feature extractor's normalisation assumption is now broken.

The Fix: Freeze BN Stats Explicitly

🐍python
import torch.nn as nn


def freeze_bn_running_stats(module: nn.Module) -> None:
    """Keep BatchNorm layers in eval() mode even when the overall model is in train().

    This pins running_mean / running_var to their pretrained (ImageNet) values
    while still allowing gamma and beta to receive gradient updates if they are
    not separately frozen.
    """
    for m in module.modules():
        if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
            m.eval()                    # switch BN to running-stats mode
            # Optional: also freeze the learnable gamma/beta if you want BN
            # to behave as an entirely fixed, pretrained normalisation.
            # for p in m.parameters():
            #     p.requires_grad = False


# Usage inside the fine-tuning loop (model, train_loader, optimizer and
# criterion are defined as in the fine-tuning example above):
model.train()                           # enables dropout, sets the default mode
freeze_bn_running_stats(model)          # OVERRIDE just the BN layers back to eval
for x, y in train_loader:
    optimizer.zero_grad()
    out = model(x)                      # BN uses pretrained running stats, not batch stats
    loss = criterion(out, y)
    loss.backward()
    optimizer.step()

The call order matters. model.train() sets every submodule to training mode. Then freeze_bn_running_stats walks the module tree and puts each BatchNorm layer back into eval() mode. The rest of the model (conv, linear, dropout) stays in training mode as intended.
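
A quick way to convince yourself of the underlying claim is to watch `running_mean` directly. The sketch below uses a tiny stand-in model and random inputs (both assumptions for illustration) and repeats the helper so that it is self-contained. It shows that `requires_grad = False` does not stop the running statistics from moving in training mode, while putting the BatchNorm layer in `eval()` does.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)


def freeze_bn_running_stats(module: nn.Module) -> None:
    for m in module.modules():
        if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
            m.eval()


# Tiny stand-in for a pretrained backbone (hypothetical).
model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())
for p in model.parameters():
    p.requires_grad = False                 # "frozen" in the requires_grad sense

bn = model[1]
x = torch.randn(16, 3, 32, 32) * 5 + 2      # batch with its own statistics

# Case 1: train mode, parameters frozen. A forward pass alone updates the stats.
model.train()
start = bn.running_mean.clone()
model(x)
drifted = not torch.equal(start, bn.running_mean)

# Case 2: train mode, but BN layers switched back to eval.
model.train()
freeze_bn_running_stats(model)
pinned = bn.running_mean.clone()
model(x)
stayed = torch.equal(pinned, bn.running_mean)

print(f"drifted with requires_grad=False: {drifted}")
print(f"unchanged after freezing BN:      {stayed}")
```

Note that any later call to `model.train()`, for example at the start of each epoch, puts the BatchNorm layers back into training mode, so the helper has to be re-applied after it.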

When not to freeze BN stats. If your new dataset genuinely has a very different distribution from ImageNet (medical scans, satellite imagery, thermal cameras) then the pretrained running stats are wrong for you. In that case, do let BN adapt — but use a large fine-tuning batch size (at least 32) so the updated stats are not dominated by mini-batch noise. A common compromise is to freeze BN only in the early layers (where ImageNet edges still match) and let the later stages re-estimate their stats.

Practical Guidelines

Here is a decision table for transfer learning in practice:

| Your Situation | Recommended Approach | Learning Rate |
| --- | --- | --- |
| Very small data (<100/class), similar domain | Feature extraction | 1e-3 (head only) |
| Small data (100–1000/class), similar domain | Feature extraction | 1e-3 (head only) |
| Medium data (1000–10000/class), similar domain | Fine-tune last 1–2 layers | 1e-4 backbone, 1e-3 head |
| Large data (>10000/class), similar domain | Fine-tune all layers | 1e-4 backbone, 1e-3 head |
| Any size, very different domain (medical, satellite) | Fine-tune all layers + augmentation | 1e-5 early, 1e-4 late, 1e-3 head |
| Huge data (millions), very different domain | Train from scratch (or pretrain on your domain) | 1e-3 everywhere |
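
The table can also be read as a small decision function. The helper below is hypothetical and only mirrors the rows above; the thresholds are rules of thumb, and treating "millions" as a total dataset size of at least 1,000,000 images is an assumption made here to get a concrete cutoff.

```python
def recommend_strategy(images_per_class: int, num_classes: int, similar_domain: bool) -> str:
    """Return the guideline-table recommendation (hypothetical helper)."""
    total_images = images_per_class * num_classes
    if not similar_domain:
        # Assumption: "millions" in the table refers to total dataset size.
        if total_images >= 1_000_000:
            return "train from scratch (or pretrain on your domain)"
        return "fine-tune all layers + augmentation"
    if images_per_class < 1_000:
        return "feature extraction"
    if images_per_class <= 10_000:
        return "fine-tune last 1-2 layers"
    return "fine-tune all layers"


print(recommend_strategy(50, 5, similar_domain=True))      # feature extraction
print(recommend_strategy(5_000, 5, similar_domain=True))   # fine-tune last 1-2 layers
print(recommend_strategy(300, 3, similar_domain=False))    # fine-tune all layers + augmentation
```

Treat the output as a starting point for experiments, not as a guarantee; validation accuracy on your own data is the deciding factor.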

Common Pitfalls

  1. Wrong preprocessing: Always use the same normalization as the pretrained model. ImageNet models expect mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]. Using different values will destroy feature quality.
  2. Too-high backbone learning rate: If you fine-tune with lr=1e-2, you will erase the pretrained features in the first epoch. Use 1e-4 or smaller for the backbone.
  3. No warmup: Skipping the head warmup phase lets random gradients corrupt pretrained features. Always train the head for 1–2 epochs first.
  4. Forgetting model.eval(): During inference, you must call model.eval() to disable dropout and switch batch norm to running statistics. This is critical for pretrained models with batch norm.
  5. Grayscale input to RGB model: If your images are grayscale, repeat the channel 3 times: x = x.repeat(1, 3, 1, 1). The pretrained model expects 3 channels.
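
For the grayscale pitfall, a minimal sketch of the channel repeat on a random stand-in batch (the batch size and image size are assumptions for illustration):

```python
import torch

gray_batch = torch.rand(8, 1, 224, 224)        # (N, 1, H, W) grayscale images in [0, 1]

rgb_batch = gray_batch.repeat(1, 3, 1, 1)      # copies the channel: (N, 3, H, W)
rgb_view = gray_batch.expand(-1, 3, -1, -1)    # same values, without copying memory

print(rgb_batch.shape)                         # torch.Size([8, 3, 224, 224])
print(torch.equal(rgb_batch, rgb_view))        # True
```

Apply the ImageNet normalization after the repeat, so that each of the three channels is normalized with its own mean and standard deviation.
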

Chapter Summary: In this chapter, we built a complete CNN from scratch (Section 1), traced the evolution of CNN architectures from LeNet to ResNet (Section 2), and learned to leverage pretrained models through transfer learning (Section 3). The practical takeaway: almost never train a CNN from scratch. Start with a pretrained backbone, adapt it to your task, and achieve excellent results with a fraction of the data and compute.

References

The transferability claim (that early-layer features are near-universal and late-layer features are task-specific) rests on controlled experiments by Yosinski et al. (2014) and the visualisation work of Zeiler & Fergus (2014). They are the papers to cite for anything beyond informal intuition.

  • Deng, J., Dong, W., Socher, R., Li, L.-J., Li, K. & Fei-Fei, L. (2009). ImageNet: A Large-Scale Hierarchical Image Database. CVPR 2009. The 1.2M-image, 1000-class dataset the pretrained ResNet we load was trained on.
  • Zeiler, M. D. & Fergus, R. (2014). Visualizing and Understanding Convolutional Networks. ECCV 2014 / arXiv:1311.2901. The classic layer-by-layer feature visualisation (Gabor-like filters in layer 1, object parts in layer 5).
  • Yosinski, J., Clune, J., Bengio, Y. & Lipson, H. (2014). How transferable are features in deep neural networks? NeurIPS 2014 / arXiv:1411.1792. Quantifies which layers transfer and why.
  • He, K., Zhang, X., Ren, S. & Sun, J. (2016). Deep Residual Learning for Image Recognition. CVPR 2016 / arXiv:1512.03385. The ResNet-18 weights we load.
  • Ioffe, S. & Szegedy, C. (2015). Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift. ICML 2015 / arXiv:1502.03167. The source of the eval() / running-stats caveat under fine-tuning.
  • PyTorch documentation. torchvision.models. pytorch.org/vision/stable/models.html. Weight enums and pretrained weights used in this section.