Chapter 12

Using Pretrained Models

CNNs in Practice

Introduction

In the previous section, we explored the theory of transfer learning—how knowledge from one task can accelerate learning on another. Now we turn to the practical side: how do we actually use pretrained models in PyTorch?

The deep learning community has invested millions of GPU hours training powerful models on massive datasets like ImageNet. These pretrained models are freely available, and learning to use them effectively is one of the most valuable skills in modern deep learning.

The Practical Reality: Most production computer vision systems don't train from scratch. They leverage pretrained models, saving weeks of training time and achieving better results with less data.

By the end of this section, you'll know how to load any pretrained model, adapt it to your specific task, and choose the optimal training strategy based on your data and computational constraints.


Learning Objectives

After completing this section, you will be able to:

  1. Load pretrained models from torchvision, torch.hub, and the timm library
  2. Understand model architectures and identify which layers to modify for your task
  3. Implement feature extraction by freezing backbone weights and training only the classifier
  4. Implement fine-tuning with differential learning rates and gradual unfreezing
  5. Choose the right strategy based on dataset size, domain similarity, and computational budget
  6. Apply best practices for preprocessing, learning rates, and avoiding common pitfalls

Why This Matters

In industry, the ability to quickly adapt pretrained models to new domains is highly valued. A medical imaging startup, for instance, can leverage ImageNet-pretrained models to achieve excellent results on X-ray classification with just hundreds of labeled examples.

Why Use Pretrained Models?

Before diving into code, let's understand the compelling reasons to use pretrained models rather than training from scratch.

The Economics of Training

Training a state-of-the-art image classification model from scratch requires:

| Resource | Training from Scratch | Using Pretrained |
|---|---|---|
| GPU Hours | 100-1000+ hours | 1-10 hours |
| Training Data | 1M+ images | 100-10K images |
| Cloud Cost | $1,000-$50,000+ | $10-$100 |
| Engineering Time | Weeks to months | Hours to days |
| Expertise Required | Architecture design, hyperparameter tuning | Basic transfer learning |

The Feature Hierarchy Advantage

Pretrained models have already learned a rich hierarchy of visual features:

$$\underbrace{\text{Edges, Colors}}_{\text{Layer 1}} \rightarrow \underbrace{\text{Textures}}_{\text{Layer 2}} \rightarrow \underbrace{\text{Parts}}_{\text{Layers 3-4}} \rightarrow \underbrace{\text{Objects}}_{\text{Layers 5+}}$$

These early features (edges, textures, shapes) are universal—they transfer to virtually any visual domain. Even if you're classifying satellite imagery or medical scans, the low-level features learned from natural images provide an excellent starting point.

| Layer Depth | Features Learned | Transferability |
|---|---|---|
| Early (1-2) | Edges, color gradients, Gabor-like filters | Universal (99%+ of domains) |
| Middle (3-4) | Textures, corners, simple shapes | Very high (90%+ of domains) |
| Late (5-6) | Object parts, complex patterns | High for similar domains |
| Final (FC) | Task-specific combinations | Low (must retrain) |

Mathematical Intuition

From an optimization perspective, pretrained weights provide a better initialization in the loss landscape. Let $\mathcal{L}(\theta)$ be the loss function and $\theta^*$ the optimal parameters. Comparing pretrained weights $\theta_{\text{pre}}$ against a random initialization $\theta_{\text{rand}}$, we typically have:

$$\|\theta_{\text{pre}} - \theta^*\| \ll \|\theta_{\text{rand}} - \theta^*\|$$

This means gradient descent has a much shorter path to travel, leading to faster convergence and often finding better local minima.
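The effect of a closer starting point is easy to see even in a toy one-dimensional problem. The sketch below (illustrative only, not part of any real training pipeline) runs gradient descent on a quadratic loss from a near start ("pretrained") and a far start ("random init"):

```python
# Toy illustration: gradient descent on L(theta) = (theta - theta_star)^2,
# started near the optimum ("pretrained") vs far away ("random init").
def steps_to_converge(theta, theta_star=0.0, lr=0.1, tol=1e-3):
    steps = 0
    while abs(theta - theta_star) > tol:
        theta -= lr * 2 * (theta - theta_star)  # gradient of the quadratic
        steps += 1
    return steps

print(steps_to_converge(0.5))    # near start: 28 steps
print(steps_to_converge(100.0))  # far start: 52 steps
```

Real loss landscapes are far from quadratic, but the intuition carries over: a shorter distance to a good minimum means fewer update steps.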

Quick Check

Which layer type from a pretrained ImageNet model is LEAST likely to transfer well to a medical X-ray classification task?


The ImageNet Revolution

To understand pretrained models, we need to appreciate their origin: ImageNet—a dataset that transformed computer vision.

What is ImageNet?

ImageNet is a large-scale visual recognition dataset created by Fei-Fei Li and her team at Stanford:

| Property | Value |
|---|---|
| Total Images | 14+ million |
| Classes | 21,841 categories (WordNet synsets) |
| ILSVRC Subset | 1.2 million images, 1000 classes |
| Image Resolution | Variable; typically 224×224 after preprocessing |
| Annotation | Human-verified bounding boxes and labels |

ILSVRC: The Competition That Changed Everything

The ImageNet Large Scale Visual Recognition Challenge (ILSVRC) ran from 2010-2017 and drove dramatic improvements in image classification:

| Year | Winner | Top-5 Error | Key Innovation |
|---|---|---|---|
| 2010 | NEC-UIUC | 28.2% | SIFT + Fisher vectors |
| 2011 | XRCE | 25.8% | Compressed Fisher vectors |
| 2012 | AlexNet | 16.4% | Deep CNNs + GPU training |
| 2014 | VGGNet | 7.3% | Very deep (19 layers), 3×3 convs |
| 2014 | GoogLeNet | 6.7% | Inception modules |
| 2015 | ResNet | 3.6% | Skip connections (152 layers) |
| 2017 | SENet | 2.3% | Squeeze-and-excitation blocks |

Human-Level Performance

Human annotators achieve approximately 5% top-5 error on ImageNet. ResNet surpassed this in 2015, demonstrating that CNNs can exceed human performance on certain visual recognition tasks.

Why ImageNet Weights Transfer So Well

ImageNet pretraining works remarkably well across domains because:

  1. Diversity: 1000 classes spanning animals, objects, scenes, and textures force the model to learn general-purpose features
  2. Scale: 1.2 million images provide enough data to learn robust, non-spurious features
  3. Hierarchy: The category structure (from breeds to species to animals) encourages learning at multiple abstraction levels
  4. Community Effort: Decades of research have optimized architectures specifically for ImageNet, resulting in well-tuned, efficient models

The PyTorch Model Zoo

PyTorch provides access to dozens of pretrained models through torchvision.models. Let's explore what's available and how to choose the right model.

Model Zoo Comparison

| Model | Params (M) | GFLOPs | Top-1 (%) | Inference (ms) | Efficiency (acc/M) |
|---|---|---|---|---|---|
| EfficientNet-B4 | 19.0 | 4.2 | 82.9 | 21.3 | 4.4 |
| ResNet-101 | 44.5 | 7.8 | 77.4 | 14.2 | 1.7 |
| EfficientNet-B0 | 5.3 | 0.4 | 77.1 | 5.8 | 14.5 |
| ResNet-50 | 25.6 | 4.1 | 76.1 | 8.3 | 3.0 |
| MobileNetV3-L | 5.4 | 0.22 | 75.2 | 3.8 | 13.9 |
| DenseNet-121 | 8.0 | 2.9 | 74.4 | 11.5 | 9.3 |
| ResNet-34 | 21.8 | 3.7 | 73.3 | 6.5 | 3.4 |
| VGG-19 | 143.7 | 19.6 | 72.4 | 15.2 | 0.5 |
| VGG-16 | 138.4 | 15.5 | 71.6 | 12.8 | 0.5 |
| ResNet-18 | 11.7 | 1.8 | 69.8 | 4.2 | 6.0 |
| MobileNetV3-S | 2.5 | 0.06 | 67.7 | 2.1 | 27.1 |

Efficiency score = Top-1 accuracy per million parameters. Higher is better for resource-constrained deployments.
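The efficiency column can be recomputed directly from the other two; a quick sanity check:

```python
# Efficiency score from the table above: top-1 accuracy (%) per
# million parameters, rounded to one decimal place.
def efficiency(top1_pct, params_millions):
    return round(top1_pct / params_millions, 1)

print(efficiency(76.1, 25.6))  # ResNet-50 -> 3.0
print(efficiency(67.7, 2.5))   # MobileNetV3-S -> 27.1
```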

Model Families Overview

ResNet Family

The ResNet (Residual Network) family introduced skip connections that enabled training of very deep networks. The key insight: learning F(x)+xF(x) + x is easier than learning H(x)H(x) directly.

🐍resnet_variants.py
```python
# ResNet variants - trade accuracy for speed
models.resnet18()   # 11.7M params, 69.8% top-1
models.resnet34()   # 21.8M params, 73.3% top-1
models.resnet50()   # 25.6M params, 76.1% top-1 (most popular)
models.resnet101()  # 44.5M params, 77.4% top-1
models.resnet152()  # 60.2M params, 78.3% top-1
```
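The residual idea itself fits in a few lines. Below is a minimal sketch of a BasicBlock (simplified: fixed channel count, no downsampling) showing the $F(x) + x$ skip connection:

```python
import torch
import torch.nn as nn

# Minimal sketch of the residual idea: the block outputs F(x) + x,
# so it only has to learn the residual F(x). Real BasicBlocks also
# handle stride and channel changes via a downsample branch.
class BasicBlock(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, 3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(channels)
        self.conv2 = nn.Conv2d(channels, channels, 3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(channels)

    def forward(self, x):
        out = torch.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        return torch.relu(out + x)  # skip connection: F(x) + x

block = BasicBlock(64)
x = torch.randn(1, 64, 56, 56)
print(block(x).shape)  # torch.Size([1, 64, 56, 56])
```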

EfficientNet Family

EfficientNet uses compound scaling—systematically scaling depth, width, and resolution together for optimal efficiency. It achieves better accuracy with fewer parameters than ResNet.

MobileNet Family

Designed for mobile and edge deployment. Uses depthwise separable convolutions to dramatically reduce parameters while maintaining reasonable accuracy.

VGG Family

Simple, uniform architecture using only 3×3 convolutions. Warning: Very large (138M+ params) due to fully connected layers. Mostly superseded by newer architectures but still useful for feature extraction.
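The warning about VGG's size is easy to verify: its first fully connected layer alone, mapping the flattened 512 × 7 × 7 conv output to 4096 units, dominates the parameter count (bias terms ignored for simplicity):

```python
# VGG-16's first FC layer: (512 * 7 * 7) inputs -> 4096 outputs
fc1_params = 512 * 7 * 7 * 4096
print(f"{fc1_params / 1e6:.1f}M")  # 102.8M of VGG-16's ~138M params
```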

Example Architecture: ResNet-18

ResNet-18 at a glance: 11.7M parameters, 3 × 224 × 224 input, published 2015, 69.8% top-1 accuracy.

Key Features:

  • Skip connections enable very deep networks
  • Residual learning: F(x) + x
  • Batch normalization after each conv

Layer Architecture:

| # | Layer | Type | Params | Output Shape | Details |
|---|---|---|---|---|---|
| 1 | conv1 | conv | 9.4K | 64 × 112 × 112 | 7×7 conv, stride 2 |
| 2 | bn1 | norm | 128 | 64 × 112 × 112 | BatchNorm2d |
| 3 | relu | activation | - | 64 × 112 × 112 | ReLU activation |
| 4 | maxpool | pool | - | 64 × 56 × 56 | 3×3 max pool, stride 2 |
| 5 | layer1 | block | 148.0K | 64 × 56 × 56 | 2 BasicBlocks |
| 6 | layer2 | block | 525.6K | 128 × 28 × 28 | 2 BasicBlocks (downsample) |
| 7 | layer3 | block | 2.1M | 256 × 14 × 14 | 2 BasicBlocks (downsample) |
| 8 | layer4 | block | 8.4M | 512 × 7 × 7 | 2 BasicBlocks (downsample) |
| 9 | avgpool | pool | - | 512 × 1 × 1 | Adaptive average pool |
| 10 | fc | fc | 513.0K | 1000 | 512 → 1000 classifier |

Choosing the Right Model

| Use Case | Recommended Model | Reasoning |
|---|---|---|
| Learning/Prototyping | ResNet-18/34 | Fast training, easy to understand |
| Production (accuracy) | EfficientNet-B4/B5 | Best accuracy/compute tradeoff |
| Mobile/Edge | MobileNetV3 | Tiny size, fast inference |
| Feature Extraction | ResNet-50 | Rich features, well-studied |
| Real-time Video | MobileNetV3-Large | 30+ FPS on mobile devices |

Loading Pretrained Models

There are three main ways to load pretrained models in PyTorch. Each has its strengths.

Method 1: torchvision.models (The Standard Choice)

The official PyTorch model zoo ships pretrained ImageNet weights for dozens of architectures. A minimal transfer-learning setup looks like this:

```python
import torch
import torchvision.models as models

# Load pretrained ResNet-50
model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V2)

# For feature extraction: freeze the backbone
for param in model.parameters():
    param.requires_grad = False

# Replace the classifier head for your task
num_classes = 10
model.fc = torch.nn.Linear(model.fc.in_features, num_classes)

# Only the new fc layer will be trained
print(f"Trainable params: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")
```
Loading Models with torchvision.models

🐍load_torchvision.py

```python
import torch
import torchvision.models as models
from torchvision.models import ResNet50_Weights

# Modern API (PyTorch 1.13+): specify weights explicitly
model = models.resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)

# See available weight variants
print(ResNet50_Weights.IMAGENET1K_V1)  # 76.1% top-1
print(ResNet50_Weights.IMAGENET1K_V2)  # 80.9% top-1 (better augmentation)
print(ResNet50_Weights.DEFAULT)        # Always the best available

# Get preprocessing transforms for the weights
preprocess = ResNet50_Weights.IMAGENET1K_V2.transforms()

# Apply to an image
from PIL import Image
img = Image.open("cat.jpg")
input_tensor = preprocess(img).unsqueeze(0)  # Add batch dimension

# Run inference
model.eval()
with torch.no_grad():
    output = model(input_tensor)

# Get predictions
probabilities = torch.nn.functional.softmax(output[0], dim=0)
top5_prob, top5_idx = torch.topk(probabilities, 5)
```

Key points:

  • Modern weights API: PyTorch 1.13+ uses explicit weight enums instead of pretrained=True, letting you pin exact weight versions for reproducibility (e.g., weights=ResNet50_Weights.IMAGENET1K_V2).
  • Weight variants: V2 weights use better training recipes (stronger augmentation, longer training), reaching 80.9% vs 76.1% top-1 accuracy.
  • Matching transforms: each weight variant ships its exact preprocessing (Resize(232) → CenterCrop(224) → Normalize(mean, std)), so your input matches what the model expects.
  • Evaluation mode: model.eval() disables dropout and switches BatchNorm to running statistics. Critical for inference!
  • No-gradient context: torch.no_grad() disables gradient computation, reducing memory usage and speeding up inference.

Method 2: torch.hub (For Third-Party Models)

🐍load_hub.py
```python
import torch

# Load from official repos
model = torch.hub.load('pytorch/vision', 'resnet50', weights='IMAGENET1K_V2')

# Load cutting-edge models from research repos
deit = torch.hub.load('facebookresearch/deit:main', 'deit_base_patch16_224', pretrained=True)
swin = torch.hub.load('microsoft/Swin-Transformer', 'swin_base_patch4_window7_224')

# List available models in a repo
torch.hub.list('pytorch/vision')  # Returns list of model names

# Get model documentation
torch.hub.help('pytorch/vision', 'resnet50')
```

Method 3: timm Library (For Power Users)

🐍load_timm.py
```python
import timm

# Search 800+ available models
print(timm.list_models('*efficient*', pretrained=True)[:10])
print(timm.list_models('*vit*', pretrained=True)[:10])

# Load with automatic num_classes handling
model = timm.create_model('efficientnet_b4', pretrained=True, num_classes=10)

# Get preprocessing config
data_config = timm.data.resolve_model_data_config(model)
transforms = timm.data.create_transform(**data_config, is_training=False)

# Get just the feature extractor (no classifier)
backbone = timm.create_model('efficientnet_b4', pretrained=True,
                             num_classes=0, global_pool='')
# Returns features with spatial dimensions: (B, C, H, W)
```

Which Method to Use?

  • torchvision.models: Best for standard models, official support, reproducibility
  • torch.hub: Best for loading models from research papers and GitHub repos
  • timm: Best for exploring many architectures, getting latest models, and flexibility

Quick Check

What does model.eval() do before inference?


Feature Extraction Strategy

Feature extraction treats the pretrained model as a fixed feature extractor. We freeze the backbone and only train a new classification head.

Pretrained Model Usage Flow

Feature Extraction: Freeze the pretrained backbone and only train a new classifier head. Fast training, works well with small datasets. The backbone acts as a fixed feature extractor.
Input Image → Frozen Backbone → Feature Vector → New Classifier → Prediction
| Aspect | Feature Extraction | Fine-Tuning |
|---|---|---|
| Training Speed | Fast (only classifier) | Slower (all/most params) |
| Data Required | Small dataset OK | Needs more data |
| GPU Memory | Low (no gradient storage) | High (full gradients) |
| Performance | Good baseline | Best results |
| Risk of Overfitting | Low | Higher (if small data) |

When to Use Feature Extraction

  • Your dataset is small (hundreds to low thousands of images)
  • Your domain is similar to ImageNet (natural images, objects, animals)
  • You have limited compute or need fast iteration
  • You want to avoid overfitting

Implementation

Feature Extraction Implementation
🐍feature_extraction.py
```python
import torch
import torch.nn as nn
import torchvision.models as models
from torchvision.models import ResNet50_Weights

# Step 1: Load pretrained model
model = models.resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)

# Step 2: Freeze ALL parameters
for param in model.parameters():
    param.requires_grad = False

# Step 3: Replace the classifier
# ResNet-50's fc layer: Linear(2048, 1000)
num_classes = 10  # Your number of classes
model.fc = nn.Sequential(
    nn.Dropout(0.2),
    nn.Linear(model.fc.in_features, num_classes)
)

# Step 4: Only new layers are trainable
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
total_params = sum(p.numel() for p in model.parameters())
print(f"Trainable: {trainable_params:,} / {total_params:,} ({100*trainable_params/total_params:.2f}%)")
# Output: Trainable: 20,490 / 23,528,522 (0.09%)

# Step 5: Optimizer sees only the new classifier's parameters
optimizer = torch.optim.Adam(model.fc.parameters(), lr=1e-3)
```

Notes:

  • Freezing: setting requires_grad=False prevents gradients from flowing into those parameters, so they won't be updated during training.
  • model.fc.in_features gives the size of the feature vector feeding the classifier (ResNet-18: 512, ResNet-50: 2048, VGG: 4096).
  • New classifier head: we replace the 1000-class ImageNet head with our own. Adding dropout helps prevent overfitting the small new dataset.
  • Only ~20K parameters (0.09%) are trainable vs 23.5M total, which dramatically reduces overfitting risk and training time.
  • Optimizer setup: pass only the new layer's parameters. Passing all parameters would still work, but wastes optimizer bookkeeping on layers that never receive gradients.

Architecture-Specific Modifications

Different architectures have different classifier layer names:

🐍arch_specific.py
```python
import torch.nn as nn

# ResNet family
model.fc = nn.Linear(model.fc.in_features, num_classes)

# VGG family
model.classifier[6] = nn.Linear(4096, num_classes)

# EfficientNet (torchvision)
model.classifier[1] = nn.Linear(model.classifier[1].in_features, num_classes)

# EfficientNet (timm)
model.classifier = nn.Linear(model.classifier.in_features, num_classes)

# MobileNetV3
model.classifier[3] = nn.Linear(model.classifier[3].in_features, num_classes)

# DenseNet
model.classifier = nn.Linear(model.classifier.in_features, num_classes)
```

Finding the Classifier Layer

Use print(model) to see the full architecture, then identify the final classification layer. Look for a Linear layer with 1000 output features (ImageNet classes).

Fine-Tuning Strategy

Fine-tuning unfreezes some or all of the pretrained layers, allowing them to adapt to your specific task. This can achieve better results but requires more care to avoid catastrophic forgetting.

When to Use Fine-Tuning

  • Your dataset is larger (thousands to hundreds of thousands of images)
  • Your domain is different from ImageNet (medical, satellite, microscopy)
  • Feature extraction has plateaued and you want more performance
  • You have sufficient compute for longer training

The Fine-Tuning Spectrum

Fine-tuning exists on a spectrum from minimal to full adaptation:

| Strategy | Unfrozen Layers | Data Required | Risk |
|---|---|---|---|
| Feature Extraction | Only classifier head | Small (100s) | Low |
| Partial Fine-Tuning | Last 1-2 blocks + head | Medium (1000s) | Medium |
| Full Fine-Tuning | All layers | Large (10K+) | High |

Implementing Gradual Unfreezing

Fine-Tuning with Gradual Unfreezing
🐍fine_tuning.py
```python
import torch
import torch.nn as nn
import torchvision.models as models

model = models.resnet50(weights='IMAGENET1K_V2')
num_classes = 10  # your number of classes

# Strategy 1: Unfreeze only the last block
for param in model.parameters():
    param.requires_grad = False

# Unfreeze layer4 (the last residual block)
for param in model.layer4.parameters():
    param.requires_grad = True

# Replace the classifier (new layers are trainable by default)
model.fc = nn.Linear(model.fc.in_features, num_classes)

# Strategy 2: Gradual unfreezing during training
def unfreeze_layers(model, layers_to_unfreeze):
    """Gradually unfreeze layers during training."""
    layer_groups = [model.layer1, model.layer2, model.layer3, model.layer4]

    for i, layer in enumerate(layer_groups):
        if i >= len(layer_groups) - layers_to_unfreeze:
            for param in layer.parameters():
                param.requires_grad = True
            print(f"Unfroze layer{i+1}")

# Training schedule:
# Epochs 1-3: Train only the classifier (frozen backbone)
# Epochs 4-6: Unfreeze layer4
# Epochs 7-10: Unfreeze layer3 and layer4
```

Notes:

  • Selective unfreezing: unfreezing only the last residual block (layer4) learns task-specific high-level features while preserving the low-level feature detectors.
  • Gradual unfreezing: the helper unfreezes layers from the end backwards. Start training with just the classifier, then progressively unfreeze deeper layers.
  • Training schedule: first train the classifier to convergence, then unfreeze and fine-tune deeper layers with a lower learning rate.

Differential Learning Rates

A critical technique: use smaller learning rates for pretrained layers and larger rates for new layers. This preserves the pretrained knowledge while allowing the new classifier to learn quickly.

Differential Learning Rates
🐍differential_lr.py
```python
import torch.optim as optim

# Different learning rates for different layer groups
param_groups = [
    # Pretrained early layers: very small LR (preserve features)
    {'params': model.layer1.parameters(), 'lr': 1e-5},
    {'params': model.layer2.parameters(), 'lr': 1e-5},

    # Pretrained later layers: small LR (allow some adaptation)
    {'params': model.layer3.parameters(), 'lr': 1e-4},
    {'params': model.layer4.parameters(), 'lr': 1e-4},

    # New classifier: normal LR (learn from scratch)
    {'params': model.fc.parameters(), 'lr': 1e-3},
]

optimizer = optim.Adam(param_groups)

# Alternative: 3-group strategy (simpler)
param_groups_simple = [
    {'params': list(model.layer1.parameters()) +
               list(model.layer2.parameters()) +
               list(model.layer3.parameters()), 'lr': 1e-5},
    {'params': model.layer4.parameters(), 'lr': 1e-4},
    {'params': model.fc.parameters(), 'lr': 1e-3},
]

# Rule of thumb: 10x smaller LR for each "older" layer group
```

Notes:

  • Early layers detect universal features (edges, textures). A very small LR (1e-5) minimally disturbs these well-learned representations.
  • Later layers detect more task-specific features; allow moderate adaptation with a 10× larger learning rate.
  • The new classifier must learn from scratch, so give it the largest learning rate (100× that of the early layers).
  • A common heuristic: divide the LR by 10 for each layer group moving toward the input. This creates a learning-rate gradient that respects feature universality.
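The rule of thumb can be turned into a tiny helper that generates the per-group learning rates (a sketch; group 0 is the earliest layers, the last group is the new head):

```python
# Generate LRs that shrink 10x per group moving toward the input.
def decayed_lrs(num_groups, head_lr=1e-3, factor=10.0):
    return [head_lr / factor ** (num_groups - 1 - i) for i in range(num_groups)]

print(decayed_lrs(3))  # smallest LR for the earliest group, 1e-3 for the head
```

These values can then be zipped with the layer groups when building the optimizer's param_groups list.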

Common Fine-Tuning Pitfalls

  • Learning rate too high: Destroys pretrained features (catastrophic forgetting)
  • Unfreezing too early: The classifier hasn't learned yet, so gradients are noisy
  • No differential LR: All layers update at the same rate, hurting early layers
  • Forgetting to use model.eval(): BatchNorm statistics are wrong during inference

Choosing the Right Strategy

Use this decision framework to choose between feature extraction and fine-tuning:

| Your Situation | Recommendation | Reasoning |
|---|---|---|
| Small data + similar domain | Feature Extraction | Pretrained features work well; avoids overfitting |
| Small data + different domain | Feature Extraction + augmentation | Limited data makes fine-tuning risky |
| Large data + similar domain | Light Fine-Tuning | Adapt high-level features to your task |
| Large data + different domain | Full Fine-Tuning | Need to adapt all features to the new domain |
| Very different domain (e.g., medical) | Full fine-tuning, or domain-specific pretrained weights | ImageNet features may not transfer well |

Dataset Size Guidelines

$$\text{Strategy} = \begin{cases} \text{Feature Extraction} & \text{if } N < 1000 \\ \text{Partial Fine-Tuning} & \text{if } 1000 \leq N < 10000 \\ \text{Full Fine-Tuning} & \text{if } N \geq 10000 \end{cases}$$

where $N$ is the number of training samples per class. These are rough guidelines; always validate with your specific data.
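Encoded as code, the guideline is just a cascade of thresholds on the per-class sample count (treat the boundaries as soft):

```python
def choose_strategy(n_per_class):
    """Rough guideline from this section; always validate on your own data."""
    if n_per_class < 1000:
        return "feature extraction"
    if n_per_class < 10000:
        return "partial fine-tuning"
    return "full fine-tuning"

print(choose_strategy(500))    # feature extraction
print(choose_strategy(5000))   # partial fine-tuning
```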

Domain Similarity Assessment

| Very Similar | Somewhat Similar | Very Different |
|---|---|---|
| Natural photos | Medical X-rays | Satellite imagery |
| Animals | Microscopy | Radar signals |
| Objects | Handwritten text | Spectrograms |
| Scenes | Industrial defects | Scientific visualizations |

Quick Check

You have 500 images of manufacturing defects (cracks, scratches) to classify. What strategy should you use?


Practical Workflow

Here's a complete workflow for adapting a pretrained model to a new classification task.

Step 1: Prepare Your Data

🐍prepare_data.py
```python
from torchvision import transforms, datasets
from torchvision.models import ResNet50_Weights
from torch.utils.data import DataLoader

# Get the preprocessing transforms from the weights
weights = ResNet50_Weights.IMAGENET1K_V2
preprocess = weights.transforms()

# For training: add augmentation BEFORE the model's preprocessing
train_transforms = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.2, contrast=0.2),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

# For validation: use the model's exact preprocessing
val_transforms = preprocess

# Load datasets
train_dataset = datasets.ImageFolder('data/train', transform=train_transforms)
val_dataset = datasets.ImageFolder('data/val', transform=val_transforms)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=4)

print(f"Classes: {train_dataset.classes}")
print(f"Training samples: {len(train_dataset)}")
```

Use the Correct Normalization

ImageNet models expect inputs normalized with mean=[0.485, 0.456, 0.406] and std=[0.229, 0.224, 0.225]. Using different normalization will significantly hurt performance!

Step 2: Complete Training Pipeline

Complete Training Pipeline
🐍training_pipeline.py
```python
import torch
import torch.nn as nn
from torchvision import models
from torchvision.models import ResNet50_Weights
from tqdm import tqdm

# Setup
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
num_classes = len(train_dataset.classes)

# Load and modify model
model = models.resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)

# Freeze backbone for feature extraction
for param in model.parameters():
    param.requires_grad = False

# New classifier with dropout
model.fc = nn.Sequential(
    nn.Dropout(0.5),
    nn.Linear(model.fc.in_features, num_classes)
)
model = model.to(device)

# Loss and optimizer (only train the fc layer)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.fc.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)

# Training loop
def train_epoch(model, loader, criterion, optimizer):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0

    for images, labels in tqdm(loader, desc='Training'):
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()

    return running_loss / len(loader), 100. * correct / total

# Validation loop
def validate(model, loader, criterion):
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0

    with torch.no_grad():
        for images, labels in tqdm(loader, desc='Validating'):
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)

            running_loss += loss.item()
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

    return running_loss / len(loader), 100. * correct / total

# Train
for epoch in range(10):
    train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer)
    val_loss, val_acc = validate(model, val_loader, criterion)
    scheduler.step()

    print(f'Epoch {epoch+1}: Train Loss={train_loss:.4f}, Acc={train_acc:.2f}%')
    print(f'         Val Loss={val_loss:.4f}, Acc={val_acc:.2f}%')
```

Notes:

  • Dropout for regularization: Dropout(0.5) randomly zeroes 50% of features before the classifier, helping prevent overfitting. Especially important when the new dataset is small.
  • Learning rate scheduler: StepLR reduces the learning rate by 0.1× every 5 epochs, letting the classifier fine-tune more precisely as training progresses.
  • model.train() enables dropout and uses batch statistics for BatchNorm. Always call it before training iterations.
  • model.eval() disables dropout and uses running statistics for BatchNorm. Always call it before validation/inference.
  • torch.no_grad() disables gradient computation during validation, saving memory and speeding up inference.

Step 3: Save and Load Your Model

🐍save_load.py
```python
# Save the entire model (architecture + weights)
torch.save(model, 'model_complete.pth')

# Recommended: Save only the state dict (more flexible)
torch.save({
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'epoch': epoch,
    'val_acc': val_acc,
}, 'checkpoint.pth')

# Load for inference
model = models.resnet50()
model.fc = nn.Linear(model.fc.in_features, num_classes)  # Must match!
checkpoint = torch.load('checkpoint.pth')
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

# Load for continued training
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
start_epoch = checkpoint['epoch'] + 1
```

Advanced Techniques

Using Models as Feature Extractors

Sometimes you want to extract features without any classification head, for example to use with a different ML algorithm or for similarity search.

🐍extract_features.py
```python
import torch
from torchvision import models
from torchvision.models import ResNet50_Weights

# Method 1: Remove the fc layer
model = models.resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)
model.fc = torch.nn.Identity()  # Replace fc with identity (pass-through)

# Extract features
model.eval()
with torch.no_grad():
    features = model(input_tensor)  # Shape: (batch, 2048)

# Method 2: Use hooks to get intermediate features
features_dict = {}

def get_features(name):
    def hook(model, input, output):
        features_dict[name] = output.detach()
    return hook

# Register hooks on layers of interest
model.layer3.register_forward_hook(get_features('layer3'))
model.layer4.register_forward_hook(get_features('layer4'))

# Forward pass
_ = model(input_tensor)

# Access intermediate features
layer3_features = features_dict['layer3']  # Shape: (batch, 1024, 14, 14)
layer4_features = features_dict['layer4']  # Shape: (batch, 2048, 7, 7)
```

Multi-Task Learning with Pretrained Backbones

🐍multi_task.py
```python
import torch.nn as nn
import torchvision.models as models

class MultiTaskResNet(nn.Module):
    def __init__(self, num_classes_task1, num_classes_task2):
        super().__init__()

        # Shared backbone (everything except the fc layer)
        resnet = models.resnet50(weights='IMAGENET1K_V2')
        self.backbone = nn.Sequential(*list(resnet.children())[:-1])

        # Task-specific heads
        self.head_task1 = nn.Linear(2048, num_classes_task1)
        self.head_task2 = nn.Linear(2048, num_classes_task2)

    def forward(self, x):
        features = self.backbone(x)
        features = features.flatten(1)

        out1 = self.head_task1(features)
        out2 = self.head_task2(features)

        return out1, out2

# Training with multiple losses
model = MultiTaskResNet(num_classes_task1=10, num_classes_task2=5)
outputs1, outputs2 = model(images)

loss = criterion1(outputs1, labels1) + 0.5 * criterion2(outputs2, labels2)
```

Knowledge Distillation from Larger Models

Use a large pretrained model as a "teacher" to train a smaller "student" model:

🐍distillation.py
```python
import torch
import torch.nn.functional as F
from torchvision import models

# Teacher: large pretrained model (frozen)
teacher = models.resnet152(weights='IMAGENET1K_V2')
teacher.eval()
for param in teacher.parameters():
    param.requires_grad = False

# Student: smaller model to train
student = models.resnet18()  # Or MobileNet for efficiency

# Distillation loss
def distillation_loss(student_logits, teacher_logits, labels, T=4.0, alpha=0.7):
    """Combine soft targets from the teacher with hard labels."""

    # Soft targets: teacher's softened predictions
    soft_targets = F.softmax(teacher_logits / T, dim=1)
    soft_loss = F.kl_div(
        F.log_softmax(student_logits / T, dim=1),
        soft_targets,
        reduction='batchmean'
    ) * (T * T)

    # Hard targets: ground truth labels
    hard_loss = F.cross_entropy(student_logits, labels)

    # Combine
    return alpha * soft_loss + (1 - alpha) * hard_loss

# Training step
with torch.no_grad():
    teacher_logits = teacher(images)
student_logits = student(images)
loss = distillation_loss(student_logits, teacher_logits, labels)
```

When to Use Distillation

Knowledge distillation is useful when you need a smaller, faster model for deployment but want to retain the accuracy of a larger model. The teacher's soft predictions contain richer information than hard labels alone.

Summary

We've covered the practical aspects of using pretrained models:

| Topic | Key Points |
| --- | --- |
| Why Pretrained? | Saves time, works with less data, better initialization |
| Model Sources | torchvision.models, torch.hub, timm library |
| Feature Extraction | Freeze backbone, train only classifier, good for small data |
| Fine-Tuning | Unfreeze layers, use differential LR, good for larger data |
| Choosing Strategy | Depends on data size and domain similarity |
| Best Practices | Match preprocessing, use dropout, save checkpoints |

Key Takeaways

  1. Pretrained models are essential—don't train from scratch unless you have millions of images and specific requirements
  2. Start with feature extraction—it's faster, less prone to overfitting, and often sufficient
  3. Fine-tune carefully—use differential learning rates and gradual unfreezing to preserve pretrained knowledge
  4. Match the preprocessing—always use the same normalization and transforms that the model was trained with
  5. Choose models wisely—ResNet-50 is a great default, but consider EfficientNet for better efficiency or MobileNet for edge deployment

Quick Check


When should you use feature extraction (frozen backbone) over fine-tuning?


Exercises

Conceptual Questions

  1. Explain why using a learning rate of 0.1 for fine-tuning a pretrained model would be problematic. What learning rate range is typically appropriate?
  2. You're fine-tuning a ResNet-50 and notice the validation accuracy dropping after epoch 5. What might be happening and how would you address it?
  3. Compare the memory requirements of feature extraction vs fine-tuning. Why does fine-tuning require more GPU memory?
  4. Why is it important to call model.eval() before inference? What specific layers behave differently in training vs evaluation mode?

Coding Exercises

  1. Multi-Model Comparison: Load ResNet-18, ResNet-50, and EfficientNet-B0. Compare their parameter counts, inference speed (time 100 forward passes), and top-5 predictions on a sample image.
  2. Feature Extraction Pipeline: Implement a complete feature extraction pipeline for a 5-class flower classification task. Include data augmentation, training/validation splits, early stopping, and model checkpointing.
  3. Gradual Unfreezing: Implement a training script that starts with feature extraction (frozen backbone), then gradually unfreezes layer4 after 3 epochs and layer3 after 6 epochs. Plot the training curves to show the effect of each unfreezing step.
  4. Learning Rate Finder: Implement a learning rate range test that gradually increases the learning rate from 1e-7 to 1 while recording the loss. Plot loss vs learning rate to find the optimal LR for fine-tuning.

Challenge Project

Domain-Specific Transfer Learning Study:

Choose a domain significantly different from ImageNet (e.g., chest X-rays, satellite imagery, or microscopy images). Systematically compare:

  1. Training from scratch vs transfer learning
  2. Feature extraction vs full fine-tuning
  3. Different pretrained backbones (ResNet, EfficientNet, ViT)
  4. The effect of different amounts of training data (10%, 25%, 50%, 100%)

Document your findings with learning curves and a final report comparing accuracy, training time, and computational cost for each approach.


In the next section, we'll explore techniques for visualizing what CNNs have learned, including activation maximization, gradient-based attribution, and feature visualization.