Introduction
In the previous section, we explored the theory of transfer learning—how knowledge from one task can accelerate learning on another. Now we turn to the practical side: how do we actually use pretrained models in PyTorch?
The deep learning community has invested millions of GPU hours training powerful models on massive datasets like ImageNet. These pretrained models are freely available, and learning to use them effectively is one of the most valuable skills in modern deep learning.
The Practical Reality: Most production computer vision systems don't train from scratch. They leverage pretrained models, saving weeks of training time and achieving better results with less data.
By the end of this section, you'll know how to load any pretrained model, adapt it to your specific task, and choose the optimal training strategy based on your data and computational constraints.
Learning Objectives
After completing this section, you will be able to:
- Load pretrained models from torchvision, torch.hub, and the timm library
- Understand model architectures and identify which layers to modify for your task
- Implement feature extraction by freezing backbone weights and training only the classifier
- Implement fine-tuning with differential learning rates and gradual unfreezing
- Choose the right strategy based on dataset size, domain similarity, and computational budget
- Apply best practices for preprocessing, learning rates, and avoiding common pitfalls
Why This Matters
Why Use Pretrained Models?
Before diving into code, let's understand the compelling reasons to use pretrained models rather than training from scratch.
The Economics of Training
Training a state-of-the-art image classification model from scratch requires:
| Resource | Training from Scratch | Using Pretrained |
|---|---|---|
| GPU Hours | 100-1000+ hours | 1-10 hours |
| Training Data | 1M+ images | 100-10K images |
| Cloud Cost | $1,000-$50,000+ | $10-$100 |
| Engineering Time | Weeks to months | Hours to days |
| Expertise Required | Architecture design, hyperparameter tuning | Basic transfer learning |
The Feature Hierarchy Advantage
Pretrained models have already learned a rich hierarchy of visual features:
These early features (edges, textures, shapes) are universal—they transfer to virtually any visual domain. Even if you're classifying satellite imagery or medical scans, the low-level features learned from natural images provide an excellent starting point.
| Layer Depth | Features Learned | Transferability |
|---|---|---|
| Early (1-2) | Edges, color gradients, Gabor-like filters | Universal (99%+ domains) |
| Middle (3-4) | Textures, corners, simple shapes | Very high (90%+ domains) |
| Late (5-6) | Object parts, complex patterns | High for similar domains |
| Final (FC) | Task-specific combinations | Low (must retrain) |
Mathematical Intuition
From an optimization perspective, pretrained weights provide a better initialization in the loss landscape. Let $\mathcal{L}(\theta)$ be the loss function and $\theta^*$ the optimal parameters. Starting from pretrained weights $\theta_{\text{pre}}$ vs. random initialization $\theta_{\text{rand}}$:

$$\|\theta_{\text{pre}} - \theta^*\| \ll \|\theta_{\text{rand}} - \theta^*\|$$
This means gradient descent has a much shorter path to travel, leading to faster convergence and often finding better local minima.
Quick Check
Which layer type from a pretrained ImageNet model is LEAST likely to transfer well to a medical X-ray classification task?
The ImageNet Revolution
To understand pretrained models, we need to appreciate their origin: ImageNet—a dataset that transformed computer vision.
What is ImageNet?
ImageNet is a large-scale visual recognition dataset created by Fei-Fei Li and her team at Stanford:
| Property | Value |
|---|---|
| Total Images | 14+ million |
| Classes | 21,841 categories (WordNet synsets) |
| ILSVRC Subset | 1.2 million images, 1000 classes |
| Image Resolution | Variable, typically 224×224 after preprocessing |
| Annotation | Human-verified bounding boxes and labels |
ILSVRC: The Competition That Changed Everything
The ImageNet Large Scale Visual Recognition Challenge (ILSVRC) ran from 2010-2017 and drove dramatic improvements in image classification:
| Year | Winner | Top-5 Error | Key Innovation |
|---|---|---|---|
| 2010 | NEC-UIUC | 28.2% | SIFT + Fisher vectors |
| 2011 | XRCE | 25.8% | Compressed Fisher vectors |
| 2012 | AlexNet | 16.4% | Deep CNNs + GPU training |
| 2014 | VGGNet (runner-up) | 7.3% | Very deep (19 layers), 3×3 convs |
| 2014 | GoogLeNet | 6.7% | Inception modules |
| 2015 | ResNet | 3.6% | Skip connections (152 layers) |
| 2017 | SENet | 2.3% | Squeeze-and-excitation blocks |
Human-Level Performance
Human top-5 error on ImageNet is estimated at roughly 5%, a level deep networks first surpassed in 2015.
Why ImageNet Weights Transfer So Well
ImageNet pretraining works remarkably well across domains because:
- Diversity: 1000 classes spanning animals, objects, scenes, and textures force the model to learn general-purpose features
- Scale: 1.2 million images provide enough data to learn robust, non-spurious features
- Hierarchy: The category structure (from breeds to species to animals) encourages learning at multiple abstraction levels
- Community Effort: Decades of research have optimized architectures specifically for ImageNet, resulting in well-tuned, efficient models
The PyTorch Model Zoo
PyTorch provides access to dozens of pretrained models through torchvision.models. Let's explore what's available and how to choose the right model.
Model Zoo Comparison
| Model | Params (M) | GFLOPs | Top-1 (%) | Inference (ms) | Efficiency |
|---|---|---|---|---|---|
| EfficientNet-B4 | 19 | 4.2 | 82.9 | 21.3 | 4.4 acc/M |
| ResNet-101 | 44.5 | 7.8 | 77.4 | 14.2 | 1.7 acc/M |
| EfficientNet-B0 | 5.3 | 0.4 | 77.1 | 5.8 | 14.5 acc/M |
| ResNet-50 | 25.6 | 4.1 | 76.1 | 8.3 | 3.0 acc/M |
| MobileNetV3-L | 5.4 | 0.22 | 75.2 | 3.8 | 13.9 acc/M |
| DenseNet-121 | 8 | 2.9 | 74.4 | 11.5 | 9.3 acc/M |
| ResNet-34 | 21.8 | 3.7 | 73.3 | 6.5 | 3.4 acc/M |
| VGG-19 | 143.7 | 19.6 | 72.4 | 15.2 | 0.5 acc/M |
| VGG-16 | 138.4 | 15.5 | 71.6 | 12.8 | 0.5 acc/M |
| ResNet-18 | 11.7 | 1.8 | 69.8 | 4.2 | 6.0 acc/M |
| MobileNetV3-S | 2.5 | 0.06 | 67.7 | 2.1 | 27.1 acc/M |
Efficiency score = Top-1 accuracy per million parameters. Higher is better for resource-constrained deployments.
Model Families Overview
ResNet Family
The ResNet (Residual Network) family introduced skip connections that enabled training of very deep networks. The key insight: learning the residual $F(x) = H(x) - x$ is easier than learning the underlying mapping $H(x)$ directly.
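A minimal residual block makes the idea concrete (an illustrative sketch without the downsampling variant used between stages):

```python
import torch
import torch.nn as nn

class BasicResidualBlock(nn.Module):
    """Minimal residual block: output = ReLU(F(x) + x)."""
    def __init__(self, channels: int):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, 3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(channels)
        self.conv2 = nn.Conv2d(channels, channels, 3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(channels)

    def forward(self, x):
        residual = x                                 # identity path
        out = torch.relu(self.bn1(self.conv1(x)))    # F(x), first conv
        out = self.bn2(self.conv2(out))              # F(x), second conv
        return torch.relu(out + residual)            # skip connection: F(x) + x

block = BasicResidualBlock(64)
x = torch.randn(1, 64, 56, 56)
print(block(x).shape)  # torch.Size([1, 64, 56, 56]) - shape is preserved
```

Because the block's output has the same shape as its input, dozens of these blocks can be stacked without the gradient signal vanishing through the identity path.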
```python
# ResNet variants - trade accuracy for speed
models.resnet18()   # 11.7M params, 69.8% top-1
models.resnet34()   # 21.8M params, 73.3% top-1
models.resnet50()   # 25.6M params, 76.1% top-1 (most popular)
models.resnet101()  # 44.5M params, 77.4% top-1
models.resnet152()  # 60.2M params, 78.3% top-1
```

EfficientNet Family
EfficientNet uses compound scaling—systematically scaling depth, width, and resolution together for optimal efficiency. It achieves better accuracy with fewer parameters than ResNet.
MobileNet Family
Designed for mobile and edge deployment. Uses depthwise separable convolutions to dramatically reduce parameters while maintaining reasonable accuracy.
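The parameter savings are easy to verify (a sketch; the 64-in, 128-out channel sizes are arbitrary):

```python
import torch.nn as nn

# Standard 3x3 conv: in_channels * out_channels * 3 * 3 weights
standard = nn.Conv2d(64, 128, kernel_size=3, padding=1, bias=False)

# Depthwise separable: per-channel 3x3 (groups=in_channels) + 1x1 pointwise
depthwise = nn.Conv2d(64, 64, kernel_size=3, padding=1, groups=64, bias=False)
pointwise = nn.Conv2d(64, 128, kernel_size=1, bias=False)

def n_params(*mods):
    return sum(p.numel() for m in mods for p in m.parameters())

print(n_params(standard))              # 73728
print(n_params(depthwise, pointwise))  # 8768 (576 depthwise + 8192 pointwise)
```

The separable pair computes a similar receptive field with roughly 8× fewer parameters, which is where MobileNet's size advantage comes from.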
VGG Family
Simple, uniform architecture using only 3×3 convolutions. Warning: Very large (138M+ params) due to fully connected layers. Mostly superseded by newer architectures but still useful for feature extraction.
Choosing the Right Model
| Use Case | Recommended Model | Reasoning |
|---|---|---|
| Learning/Prototyping | ResNet-18/34 | Fast training, easy to understand |
| Production (accuracy) | EfficientNet-B4/B5 | Best accuracy/compute tradeoff |
| Mobile/Edge | MobileNetV3 | Tiny size, fast inference |
| Feature Extraction | ResNet-50 | Rich features, well-studied |
| Real-time Video | MobileNetV3-Large | 30+ FPS on mobile devices |
Loading Pretrained Models
There are three main ways to load pretrained models in PyTorch. Each has its strengths.
Method 1: torchvision.models (Recommended for Beginners)
Official PyTorch model zoo with pretrained ImageNet weights.

```python
import torch
import torchvision.models as models

# Load pretrained ResNet-50
model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V2)

# For feature extraction: freeze backbone
for param in model.parameters():
    param.requires_grad = False

# Replace classifier for your task
num_classes = 10
model.fc = torch.nn.Linear(model.fc.in_features, num_classes)

# Only the new fc layer will be trained
print(f"Trainable params: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")
```
Method 2: torch.hub (For Third-Party Models)
```python
import torch

# Load from official repos
model = torch.hub.load('pytorch/vision', 'resnet50', weights='IMAGENET1K_V2')

# Load cutting-edge models from research repos
deit = torch.hub.load('facebookresearch/deit:main', 'deit_base_patch16_224', pretrained=True)
swin = torch.hub.load('microsoft/Swin-Transformer', 'swin_base_patch4_window7_224')

# List available models in a repo
torch.hub.list('pytorch/vision')  # Returns list of model names

# Get model documentation
torch.hub.help('pytorch/vision', 'resnet50')
```

Method 3: timm Library (For Power Users)
```python
import timm

# Search 800+ available models
print(timm.list_models('*efficient*', pretrained=True)[:10])
print(timm.list_models('*vit*', pretrained=True)[:10])

# Load with automatic num_classes handling
model = timm.create_model('efficientnet_b4', pretrained=True, num_classes=10)

# Get preprocessing config
data_config = timm.data.resolve_model_data_config(model)
transforms = timm.data.create_transform(**data_config, is_training=False)

# Get just the feature extractor (no classifier)
backbone = timm.create_model('efficientnet_b4', pretrained=True,
                             num_classes=0, global_pool='')
# Returns features with spatial dimensions: (B, C, H, W)
```

Which Method to Use?
- torchvision.models: Best for standard models, official support, reproducibility
- torch.hub: Best for loading models from research papers and GitHub repos
- timm: Best for exploring many architectures, getting latest models, and flexibility
Quick Check
What does model.eval() do before inference?
Feature Extraction Strategy
Feature extraction treats the pretrained model as a fixed feature extractor. We freeze the backbone and only train a new classification head.
Pretrained Model Usage Flow
| Aspect | Feature Extraction | Fine-Tuning |
|---|---|---|
| Training Speed | Fast (only classifier) | Slower (all/most params) |
| Data Required | Small dataset OK | Needs more data |
| GPU Memory | Low (no gradient storage) | High (full gradients) |
| Performance | Good baseline | Best results |
| Risk of Overfitting | Low | Higher (if small data) |
When to Use Feature Extraction
- Your dataset is small (hundreds to low thousands of images)
- Your domain is similar to ImageNet (natural images, objects, animals)
- You have limited compute or need fast iteration
- You want to avoid overfitting
Implementation
Architecture-Specific Modifications
Different architectures have different classifier layer names:
```python
# ResNet family
model.fc = nn.Linear(model.fc.in_features, num_classes)

# VGG family
model.classifier[6] = nn.Linear(4096, num_classes)

# EfficientNet (torchvision)
model.classifier[1] = nn.Linear(model.classifier[1].in_features, num_classes)

# EfficientNet (timm)
model.classifier = nn.Linear(model.classifier.in_features, num_classes)

# MobileNetV3
model.classifier[3] = nn.Linear(model.classifier[3].in_features, num_classes)

# DenseNet
model.classifier = nn.Linear(model.classifier.in_features, num_classes)
```

Finding the Classifier Layer
Run `print(model)` to see the full architecture, then identify the final classification layer. Look for a Linear layer with 1000 output features (the ImageNet classes).

Fine-Tuning Strategy
Fine-tuning unfreezes some or all of the pretrained layers, allowing them to adapt to your specific task. This can achieve better results but requires more care to avoid catastrophic forgetting.
When to Use Fine-Tuning
- Your dataset is larger (thousands to hundreds of thousands of images)
- Your domain is different from ImageNet (medical, satellite, microscopy)
- Feature extraction has plateaued and you want more performance
- You have sufficient compute for longer training
The Fine-Tuning Spectrum
Fine-tuning exists on a spectrum from minimal to full adaptation:
| Strategy | Unfrozen Layers | Data Required | Risk |
|---|---|---|---|
| Feature Extraction | Only classifier head | Small (100s) | Low |
| Partial Fine-Tuning | Last 1-2 blocks + head | Medium (1000s) | Medium |
| Full Fine-Tuning | All layers | Large (10K+) | High |
Implementing Gradual Unfreezing
Differential Learning Rates
A critical technique: use smaller learning rates for pretrained layers and larger rates for new layers. This preserves the pretrained knowledge while allowing the new classifier to learn quickly.
Common Fine-Tuning Pitfalls
- Learning rate too high: Destroys pretrained features (catastrophic forgetting)
- Unfreezing too early: The classifier hasn't learned yet, so gradients are noisy
- No differential LR: All layers update at the same rate, hurting early layers
- Forgetting to use model.eval(): BatchNorm statistics are wrong during inference
Choosing the Right Strategy
Use this decision framework to choose between feature extraction and fine-tuning:
| Your Situation | Recommendation | Reasoning |
|---|---|---|
| Small data + similar domain | Feature Extraction | Pretrained features work well, avoid overfitting |
| Small data + different domain | Feature Extraction + augmentation | Limited data makes fine-tuning risky |
| Large data + similar domain | Light Fine-Tuning | Adapt high-level features to your task |
| Large data + different domain | Full Fine-Tuning | Need to adapt all features to new domain |
| Very different domain (e.g., medical) | Fine-tune from scratch or use domain-specific pretrained | ImageNet features may not transfer well |
Dataset Size Guidelines
Here $N$ denotes the number of training samples per class. These are rough guidelines; always validate with your specific data.
Domain Similarity Assessment
| Very Similar | Somewhat Similar | Very Different |
|---|---|---|
| Natural photos | Medical X-rays | Satellite imagery |
| Animals | Microscopy | Radar signals |
| Objects | Handwritten text | Spectrograms |
| Scenes | Industrial defects | Scientific visualizations |
Quick Check
You have 500 images of manufacturing defects (cracks, scratches) to classify. What strategy should you use?
Practical Workflow
Here's a complete workflow for adapting a pretrained model to a new classification task.
Step 1: Prepare Your Data
```python
from torchvision import transforms, datasets
from torchvision.models import ResNet50_Weights
from torch.utils.data import DataLoader

# Get the preprocessing transforms from the weights
weights = ResNet50_Weights.IMAGENET1K_V2
preprocess = weights.transforms()

# For training: add augmentation BEFORE the model's preprocessing
train_transforms = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.2, contrast=0.2),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

# For validation: use model's exact preprocessing
val_transforms = preprocess

# Load datasets
train_dataset = datasets.ImageFolder('data/train', transform=train_transforms)
val_dataset = datasets.ImageFolder('data/val', transform=val_transforms)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=4)

print(f"Classes: {train_dataset.classes}")
print(f"Training samples: {len(train_dataset)}")
```

Use the Correct Normalization
Pretrained models expect inputs normalized with the statistics they were trained on (for ImageNet: mean [0.485, 0.456, 0.406], std [0.229, 0.224, 0.225]). Using weights.transforms() for validation guarantees a match; if you build your own training pipeline, copy these values exactly.
Step 2: Complete Training Pipeline
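A minimal training/validation loop for either strategy might look like this (a sketch; `model`, the loaders, `optimizer`, and `criterion` come from Step 1 and the loading section):

```python
import torch
import torch.nn as nn

def train_one_epoch(model, loader, optimizer, criterion, device):
    """One pass over the training set; returns (avg_loss, accuracy)."""
    model.train()
    total_loss, correct, total = 0.0, 0, 0
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * images.size(0)
        correct += (outputs.argmax(dim=1) == labels).sum().item()
        total += labels.size(0)
    return total_loss / total, correct / total

@torch.no_grad()
def evaluate(model, loader, criterion, device):
    """Validation pass; model.eval() fixes BatchNorm/Dropout behavior."""
    model.eval()
    total_loss, correct, total = 0.0, 0, 0
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        total_loss += criterion(outputs, labels).item() * images.size(0)
        correct += (outputs.argmax(dim=1) == labels).sum().item()
        total += labels.size(0)
    return total_loss / total, correct / total

# Typical driver, tracking the best validation accuracy:
# best_acc = 0.0
# for epoch in range(num_epochs):
#     train_loss, train_acc = train_one_epoch(model, train_loader, optimizer, criterion, device)
#     val_loss, val_acc = evaluate(model, val_loader, criterion, device)
#     if val_acc > best_acc:
#         best_acc = val_acc
#         torch.save(model.state_dict(), 'best_model.pth')
```

The same two functions serve feature extraction and fine-tuning; only which parameters the optimizer holds changes between the strategies.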
Step 3: Save and Load Your Model
```python
# Save the entire model (architecture + weights)
torch.save(model, 'model_complete.pth')

# Recommended: Save only state dict (more flexible)
torch.save({
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'epoch': epoch,
    'val_acc': val_acc,
}, 'checkpoint.pth')

# Load for inference
model = models.resnet50()
model.fc = nn.Linear(model.fc.in_features, num_classes)  # Must match!
checkpoint = torch.load('checkpoint.pth')
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

# Load for continued training
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
start_epoch = checkpoint['epoch'] + 1
```

Advanced Techniques
Using Models as Feature Extractors
Sometimes you want to extract features without any classification head, for example to use with a different ML algorithm or for similarity search.
```python
import torch
from torchvision import models
from torchvision.models import ResNet50_Weights

# Method 1: Remove the fc layer
model = models.resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)
model.fc = torch.nn.Identity()  # Replace fc with identity (pass-through)

# Extract features
model.eval()
with torch.no_grad():
    features = model(input_tensor)  # Shape: (batch, 2048)

# Method 2: Use hooks to get intermediate features
features_dict = {}

def get_features(name):
    def hook(model, input, output):
        features_dict[name] = output.detach()
    return hook

# Register hooks on layers of interest
model.layer3.register_forward_hook(get_features('layer3'))
model.layer4.register_forward_hook(get_features('layer4'))

# Forward pass
_ = model(input_tensor)

# Access intermediate features
layer3_features = features_dict['layer3']  # Shape: (batch, 1024, 14, 14)
layer4_features = features_dict['layer4']  # Shape: (batch, 2048, 7, 7)
```

Multi-Task Learning with Pretrained Backbones
```python
import torch.nn as nn
from torchvision import models

class MultiTaskResNet(nn.Module):
    def __init__(self, num_classes_task1, num_classes_task2):
        super().__init__()

        # Shared backbone (everything except the fc layer)
        resnet = models.resnet50(weights='IMAGENET1K_V2')
        self.backbone = nn.Sequential(*list(resnet.children())[:-1])

        # Task-specific heads
        self.head_task1 = nn.Linear(2048, num_classes_task1)
        self.head_task2 = nn.Linear(2048, num_classes_task2)

    def forward(self, x):
        features = self.backbone(x)
        features = features.flatten(1)

        out1 = self.head_task1(features)
        out2 = self.head_task2(features)

        return out1, out2

# Training with multiple losses
model = MultiTaskResNet(num_classes_task1=10, num_classes_task2=5)
outputs1, outputs2 = model(images)

loss = criterion1(outputs1, labels1) + 0.5 * criterion2(outputs2, labels2)
```

Knowledge Distillation from Larger Models
Use a large pretrained model as a "teacher" to train a smaller "student" model:
```python
import torch
import torch.nn.functional as F
from torchvision import models

# Teacher: large pretrained model (frozen)
teacher = models.resnet152(weights='IMAGENET1K_V2')
teacher.eval()
for param in teacher.parameters():
    param.requires_grad = False

# Student: smaller model to train
student = models.resnet18()  # Or MobileNet for efficiency

# Distillation loss
def distillation_loss(student_logits, teacher_logits, labels, T=4.0, alpha=0.7):
    """Combine soft targets from teacher with hard labels."""

    # Soft targets: teacher's softened predictions
    soft_targets = F.softmax(teacher_logits / T, dim=1)
    soft_loss = F.kl_div(
        F.log_softmax(student_logits / T, dim=1),
        soft_targets,
        reduction='batchmean'
    ) * (T * T)

    # Hard targets: ground truth labels
    hard_loss = F.cross_entropy(student_logits, labels)

    # Combine
    return alpha * soft_loss + (1 - alpha) * hard_loss

# Training
with torch.no_grad():
    teacher_logits = teacher(images)
student_logits = student(images)
loss = distillation_loss(student_logits, teacher_logits, labels)
```

When to Use Distillation
Distillation pays off when you need a compact model for deployment but can afford to run a large teacher during training, for example compressing a ResNet-152 into a MobileNet for mobile inference.
Summary
We've covered the practical aspects of using pretrained models:
| Topic | Key Points |
|---|---|
| Why Pretrained? | Saves time, works with less data, better initialization |
| Model Sources | torchvision.models, torch.hub, timm library |
| Feature Extraction | Freeze backbone, train only classifier, good for small data |
| Fine-Tuning | Unfreeze layers, use differential LR, good for larger data |
| Choosing Strategy | Depends on data size and domain similarity |
| Best Practices | Match preprocessing, use dropout, save checkpoints |
Key Takeaways
- Pretrained models are essential—don't train from scratch unless you have millions of images and specific requirements
- Start with feature extraction—it's faster, less prone to overfitting, and often sufficient
- Fine-tune carefully—use differential learning rates and gradual unfreezing to preserve pretrained knowledge
- Match the preprocessing—always use the same normalization and transforms that the model was trained with
- Choose models wisely—ResNet-50 is a great default, but consider EfficientNet for better efficiency or MobileNet for edge deployment
Quick Check
When should you use feature extraction (frozen backbone) over fine-tuning?
Exercises
Conceptual Questions
- Explain why using a learning rate of 0.1 for fine-tuning a pretrained model would be problematic. What learning rate range is typically appropriate?
- You're fine-tuning a ResNet-50 and notice the validation accuracy dropping after epoch 5. What might be happening and how would you address it?
- Compare the memory requirements of feature extraction vs fine-tuning. Why does fine-tuning require more GPU memory?
- Why is it important to call `model.eval()` before inference? What specific layers behave differently in training vs evaluation mode?
Coding Exercises
- Multi-Model Comparison: Load ResNet-18, ResNet-50, and EfficientNet-B0. Compare their parameter counts, inference speed (time 100 forward passes), and top-5 predictions on a sample image.
- Feature Extraction Pipeline: Implement a complete feature extraction pipeline for a 5-class flower classification task. Include data augmentation, training/validation splits, early stopping, and model checkpointing.
- Gradual Unfreezing: Implement a training script that starts with feature extraction (frozen backbone), then gradually unfreezes layer4 after 3 epochs and layer3 after 6 epochs. Plot the training curves to show the effect of each unfreezing step.
- Learning Rate Finder: Implement a learning rate range test that gradually increases the learning rate from 1e-7 to 1 while recording the loss. Plot loss vs learning rate to find the optimal LR for fine-tuning.
Challenge Project
Domain-Specific Transfer Learning Study:
Choose a domain significantly different from ImageNet (e.g., chest X-rays, satellite imagery, or microscopy images). Systematically compare:
- Training from scratch vs transfer learning
- Feature extraction vs full fine-tuning
- Different pretrained backbones (ResNet, EfficientNet, ViT)
- The effect of different amounts of training data (10%, 25%, 50%, 100%)
Document your findings with learning curves and a final report comparing accuracy, training time, and computational cost for each approach.
In the next section, we'll explore techniques for visualizing what CNNs have learned, including activation maximization, gradient-based attribution, and feature visualization.