Chapter 24

Pretext Tasks for Images

Self-Supervised Learning

Learning Objectives

By the end of this section, you will be able to:

  1. Understand pretext tasks — Learn what pretext tasks are and why they enable learning from unlabeled images
  2. Master classic image pretext tasks — Implement rotation prediction, jigsaw puzzles, colorization, and context prediction
  3. Analyze the mathematics — Understand the loss functions and optimization objectives for each pretext task
  4. Build intuition — Know why solving these "puzzle" tasks forces networks to learn useful visual representations
  5. Compare approaches — Evaluate trade-offs between different pretext tasks and their effectiveness for downstream applications

The Big Picture

In 2015-2018, deep learning faced a paradox: neural networks were getting better, but they were also getting hungrier for labeled data. Training a state-of-the-art image classifier required millions of human-annotated images. Meanwhile, the internet was overflowing with billions of unlabeled images. How could we tap into this vast resource?

The breakthrough came from a simple observation: images contain their own supervision signals. If you rotate an image, the image "knows" it was rotated. If you remove a patch, the surrounding context "knows" what should be there. If you convert to grayscale, the original colors are implicit in the structure.

The Core Insight: By designing tasks where the labels can be automatically generated from the data itself, we can train networks on unlimited unlabeled data. The representations learned to solve these "pretext tasks" transfer remarkably well to real downstream tasks like classification and detection.

This section explores the pioneering pretext tasks that launched the self-supervised learning revolution for computer vision. Each task exploits a different structural property of images to create free supervision. While newer methods like contrastive learning have largely superseded these approaches, understanding pretext tasks provides essential intuition for why self-supervised learning works.


What Are Pretext Tasks?

A pretext task is a task where the "labels" are automatically derived from the data itself, without any human annotation. The network learns to solve this task, and in doing so, learns representations that are useful for other tasks we actually care about (the "downstream tasks").

The Pretext Task Framework

The general framework for pretext task learning consists of three steps:

  1. Transform: Apply a transformation $T$ to an image $x$ to get $\tilde{x} = T(x)$
  2. Predict: Train a network $f_\theta$ to predict some property of $T$ from $\tilde{x}$
  3. Transfer: Use the learned representations for downstream tasks

Mathematically, we minimize:

$$\mathcal{L}_{\text{pretext}} = \mathbb{E}_{x \sim \mathcal{D}} \left[ \ell(f_\theta(T(x)), y_T) \right]$$

where $y_T$ is the automatically generated label encoding information about transformation $T$.
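The three steps above can be sketched as a generic training step. Here `transform_fn` and `pseudo_label_fn` are hypothetical hooks standing in for whichever transformation $T$ and automatic label generator a particular pretext task defines:

```python
import torch
import torch.nn as nn

def pretext_step(model, images, transform_fn, pseudo_label_fn, optimizer):
    """One optimization step of the generic pretext framework.

    transform_fn and pseudo_label_fn are placeholder hooks: they apply
    the transformation T and generate the free label y_T, respectively.
    """
    x_tilde = transform_fn(images)        # step 1: transform the data
    y_t = pseudo_label_fn(images)         # automatic labels, no humans
    logits = model(x_tilde)               # step 2: predict a property of T
    loss = nn.functional.cross_entropy(logits, y_t)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
```

For rotation prediction, for example, `transform_fn` would rotate each image and `pseudo_label_fn` would return the rotation index.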

Key Design Principles

| Principle | Why It Matters | Example |
|---|---|---|
| Non-trivial | Task must require understanding visual content, not just low-level statistics | Rotation requires understanding object orientation |
| Learnable | Task must be solvable by a neural network | Colorization has natural image statistics to exploit |
| Transferable | Features learned should be useful for downstream tasks | Spatial reasoning transfers to detection |
| No shortcuts | Network cannot cheat using trivial cues | Chromatic aberration reveals rotation → must remove it |

Rotation Prediction (RotNet)

Rotation prediction, introduced by Gidaris et al. in 2018, is one of the simplest yet surprisingly effective pretext tasks. The idea: rotate images by 0°, 90°, 180°, or 270°, and train a network to predict which rotation was applied.

The Intuition

Why would predicting rotation teach useful features? Consider what the network must learn:

  • Object recognition: A rotated cat is still recognizable as a cat, but its orientation matters
  • Scene understanding: Skies are typically at the top, ground at the bottom
  • Semantic parts: Faces have eyes above mouth, buildings have roofs on top

To predict rotation correctly, the network must implicitly learn these semantic concepts.

Mathematical Formulation

Let $R_k$ denote rotation by $k \times 90°$ for $k \in \{0, 1, 2, 3\}$. The rotation prediction task is a 4-class classification problem:

$$\mathcal{L}_{\text{rot}} = -\sum_{k=0}^{3} \mathbb{E}_{x \sim \mathcal{D}} \left[ \log p_\theta(y=k \mid R_k(x)) \right]$$

where $p_\theta(y=k \mid R_k(x))$ is the softmax probability for rotation class $k$.

How RotNet Works:
  1. Take an image and rotate it by 0°, 90°, 180°, or 270°
  2. The network predicts which rotation was applied
  3. Learning to predict rotation forces understanding of object structure
  4. The learned features transfer well to downstream tasks
Implementation

Here's a complete PyTorch implementation of RotNet:

```python
import torch
import torch.nn as nn
from torchvision.models import resnet18

class RotNet(nn.Module):
    """Rotation prediction network for self-supervised learning."""

    def __init__(self, num_classes=4):
        super().__init__()
        # Use ResNet18 as backbone (without pretrained weights --
        # we want to learn features from scratch via the rotation task)
        self.backbone = resnet18(weights=None)
        # Replace the 1000-class ImageNet head with a 4-way rotation classifier
        num_features = self.backbone.fc.in_features
        self.backbone.fc = nn.Linear(num_features, num_classes)

    def forward(self, x):
        return self.backbone(x)

def create_rotation_batch(images):
    """Create a batch with 4 rotations per image."""
    rotated_images = []
    labels = []
    for img in images:
        for rotation in range(4):  # k=0: 0°, k=1: 90°, k=2: 180°, k=3: 270°
            rotated_images.append(torch.rot90(img, k=rotation, dims=[1, 2]))
            labels.append(rotation)
    return torch.stack(rotated_images), torch.tensor(labels)

def train_rotnet(model, dataloader, optimizer, device, epochs=10):
    criterion = nn.CrossEntropyLoss()
    model.train()

    for epoch in range(epochs):
        total_loss = 0
        correct = 0
        total = 0

        # Ignore original labels: the rotations provide the supervision
        for images, _ in dataloader:
            images = images.to(device)

            # Create the rotated batch (4x the samples, with rotation labels)
            rotated_images, rotation_labels = create_rotation_batch(images)
            rotated_images = rotated_images.to(device)
            rotation_labels = rotation_labels.to(device)

            optimizer.zero_grad()
            outputs = model(rotated_images)
            loss = criterion(outputs, rotation_labels)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            _, predicted = outputs.max(1)
            total += rotation_labels.size(0)
            correct += predicted.eq(rotation_labels).sum().item()

        # High rotation accuracy (>90%) suggests useful features are being learned
        accuracy = 100. * correct / total
        print(f"Epoch {epoch+1}: Loss={total_loss:.4f}, Acc={accuracy:.2f}%")
```

Avoiding Shortcuts

Certain image artifacts can reveal the rotation without understanding content. For example, chromatic aberration (color fringing) at image edges indicates orientation. Proper data augmentation and center cropping are essential to prevent the network from cheating.

Jigsaw Puzzle Solving

The jigsaw puzzle task, proposed by Noroozi and Favaro in 2016, takes self-supervision further by exploiting spatial relationships between image regions. Split an image into a 3×3 grid, shuffle the patches according to one of many possible permutations, and train the network to identify which permutation was used.

Why Jigsaw Works

To solve a jigsaw puzzle, you must understand:

  • Object parts: Which patch contains the head, body, tail?
  • Spatial coherence: Parts must fit together semantically
  • Texture continuity: Adjacent patches should have matching textures
  • Edge alignment: Object boundaries should continue across patches

Mathematical Formulation

Let $\Pi = \{\pi_1, \pi_2, ..., \pi_N\}$ be a set of $N$ permutations (typically $N = 100$ to $1000$, selected from the $9! = 362{,}880$ possibilities).

For an image split into 9 patches $\{p_1, ..., p_9\}$, applying permutation $\pi_k$ gives shuffled patches. The task is:

$$\mathcal{L}_{\text{jigsaw}} = -\mathbb{E}_{x, \pi_k} \left[ \log p_\theta(y=k \mid \{p_{\pi_k(1)}, ..., p_{\pi_k(9)}\}) \right]$$

How Jigsaw Puzzle Works:
  1. Split the image into a 3×3 grid (9 patches)
  2. Shuffle the patches using one of N predefined permutations
  3. The network predicts which permutation class was used
  4. N is typically 100-1000 (a subset of 9! = 362,880)
  5. This forces the network to understand spatial relationships

Key Insight: To solve the puzzle, the network must learn features that capture object parts and their spatial relationships - the same features useful for recognition tasks.

Architecture Considerations

The jigsaw network uses a Siamese architecture where all 9 patches are processed by the same encoder with shared weights:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from itertools import permutations

class JigsawNet(nn.Module):
    """Jigsaw puzzle solving for self-supervised learning."""

    def __init__(self, num_permutations=100):
        super().__init__()
        # Siamese network: process each patch with shared weights
        self.patch_encoder = self._create_encoder()
        # Classifier predicts the permutation class
        self.classifier = nn.Sequential(
            nn.Linear(512 * 9, 4096),  # 9 patches × 512 features
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(4096, num_permutations)
        )
        # Generate the permutation set (subset of 9! = 362,880)
        self.permutations = self._generate_permutations(num_permutations)

    def _create_encoder(self):
        """Create a patch encoder using an AlexNet-style architecture."""
        return nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(64, 192, kernel_size=5, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(),
            nn.Linear(192, 512)
        )

    def _generate_permutations(self, num_perms):
        """Greedily select permutations with maximal pairwise Hamming distance.

        Note: scanning all 9! permutations is slow; in practice this set
        is precomputed once and cached.
        """
        all_perms = list(permutations(range(9)))
        selected = [all_perms[0]]  # Start with the identity permutation
        while len(selected) < num_perms:
            max_dist = -1
            best_perm = None
            for perm in all_perms:
                if perm not in selected:
                    min_dist = min(
                        sum(p1 != p2 for p1, p2 in zip(perm, s))
                        for s in selected
                    )
                    if min_dist > max_dist:
                        max_dist = min_dist
                        best_perm = perm
            selected.append(best_perm)
        return torch.tensor(selected)

    def forward(self, patches):
        # patches: (batch, 9, 3, patch_h, patch_w)
        features = []
        for i in range(9):
            features.append(self.patch_encoder(patches[:, i]))
        # Concatenate all patch features for the permutation classifier
        combined = torch.cat(features, dim=1)
        return self.classifier(combined)

def extract_patches(image, grid_size=3, patch_size=64):
    """Extract a grid of patches from an image of shape (C, H, W).

    In practice patches are cropped with small gaps and random jitter
    to prevent the network from matching patch boundaries.
    """
    patches = []
    h_step = image.shape[1] // grid_size
    w_step = image.shape[2] // grid_size
    for i in range(grid_size):
        for j in range(grid_size):
            patch = image[:, i*h_step:(i+1)*h_step, j*w_step:(j+1)*w_step]
            # Resize to a standard patch size
            patch = F.interpolate(patch.unsqueeze(0), size=patch_size)[0]
            patches.append(patch)
    return torch.stack(patches)
```

Permutation Selection

The choice of permutation set matters. Selecting permutations with maximum Hamming distance from each other ensures the task is neither too easy (similar permutations) nor arbitrary (random subset). This forces the network to learn fine-grained spatial reasoning.

Colorization

Colorization as a pretext task was explored by Zhang et al. (2016) and Larsson et al. (2016). The task is intuitive: given a grayscale image (L channel in Lab color space), predict the color channels (ab channels).

Why Colorization Works

To colorize an image correctly, the network must learn:

  • Object recognition: Grass is green, sky is blue, skin has characteristic tones
  • Scene understanding: Indoor vs outdoor, day vs night
  • Texture semantics: Wood grain, fabric patterns, metal surfaces
  • Context: A ball on grass is likely not purple

Mathematical Formulation

In Lab color space, L represents luminance (grayscale) and ab represents color. The colorization task can be formulated as either regression or classification:

Regression approach:

$$\mathcal{L}_{\text{color-reg}} = \mathbb{E}_{x} \left[ \| f_\theta(L(x)) - ab(x) \|^2 \right]$$

Classification approach (quantize ab space into bins):

$$\mathcal{L}_{\text{color-cls}} = -\mathbb{E}_{x} \left[ \sum_{h,w} \log p_\theta(\hat{c}_{h,w} \mid L(x)) \right]$$

The classification approach handles the inherent multimodality of colorization better—a dress could legitimately be red, blue, or black.
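Quantization can be sketched with a uniform grid over the ab plane. Note this is an illustrative simplification: Zhang et al. actually keep only the roughly 313 bins that fall inside the sRGB gamut, whereas the grid below covers the whole plane:

```python
import torch

def ab_to_bins(ab, grid_size=23, ab_min=-110.0, ab_max=110.0):
    """Map continuous ab values to class indices on a uniform grid.

    ab: tensor of shape (..., 2) holding a and b channels in [ab_min, ab_max].
    Returns integer bin indices in [0, grid_size**2).
    The original paper restricts to the ~313 in-gamut bins; this uniform
    grid is a simplified stand-in for illustration.
    """
    step = (ab_max - ab_min) / grid_size
    idx = ((ab - ab_min) / step).long().clamp(0, grid_size - 1)
    return idx[..., 0] * grid_size + idx[..., 1]
```

The per-pixel cross-entropy target $\hat{c}_{h,w}$ in the classification loss is exactly this bin index computed from the ground-truth ab values.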

Colorization Pipeline: the L channel (grayscale) passes through a CNN encoder-decoder, which predicts the ab (color) channels.

Training objective:

$$\mathcal{L} = \| f(I_L) - I_{ab} \|^2 + \lambda \cdot \text{CrossEntropy}(\hat{c}, c)$$

  • First term: pixel-wise reconstruction loss (L2)
  • Second term: classification loss over quantized color bins
  • Color space: Lab (L = luminance, ab = color)

Why Colorization Works:
  1. The network must understand what objects are present
  2. It must recognize object boundaries and textures
  3. It learns semantic features (grass is green, sky is blue)
  4. Cross-entropy loss treats colorization as a classification problem
  5. It learns rich, transferable visual representations

Challenge: Colorization is an ill-posed problem - many valid colorizations exist for a single grayscale image. This ambiguity actually helps the network learn more robust features!

Implementation

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from skimage.color import rgb2lab

class ColorizationNet(nn.Module):
    """Colorization network: predict ab channels from the L channel."""

    def __init__(self, num_bins=313):
        super().__init__()
        # Encoder: downsample and extract features from the L channel
        self.encoder = nn.Sequential(
            self._conv_block(1, 64, 3, 1, 1),     # single-channel (L) input
            self._conv_block(64, 128, 3, 2, 1),
            self._conv_block(128, 256, 3, 2, 1),
            self._conv_block(256, 512, 3, 2, 1),
            # Dilated convolution: larger receptive field without more downsampling
            self._conv_block(512, 512, 3, 1, 2, dilation=2),
        )
        # Decoder: upsample and classify each pixel into one of 313 ab bins
        self.decoder = nn.Sequential(
            self._deconv_block(512, 256),
            self._deconv_block(256, 128),
            self._deconv_block(128, 64),
            nn.Conv2d(64, num_bins, kernel_size=1),
        )
        # Alternative regression head: predict ab directly (tends to desaturate)
        self.ab_regression = nn.Conv2d(64, 2, kernel_size=1)

    def _conv_block(self, in_ch, out_ch, k, s, p, dilation=1):
        return nn.Sequential(
            nn.Conv2d(in_ch, out_ch, k, s, p, dilation=dilation),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True)
        )

    def _deconv_block(self, in_ch, out_ch):
        return nn.Sequential(
            nn.ConvTranspose2d(in_ch, out_ch, 4, 2, 1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True)
        )

    def forward(self, L):
        features = self.encoder(L)
        return self.decoder(features)

def rgb_to_lab_tensors(rgb_images):
    """Convert RGB images (B, 3, H, W) in [0, 1] to normalized Lab tensors."""
    L_batch, ab_batch = [], []
    for img in rgb_images:
        lab = rgb2lab(img.permute(1, 2, 0).numpy())  # L: [0, 100], ab: ~[-128, 127]
        L = lab[:, :, 0:1] / 100.0           # Normalize L to [0, 1]
        ab = (lab[:, :, 1:3] + 128) / 255.0  # Normalize ab to [0, 1]
        L_batch.append(torch.from_numpy(L).permute(2, 0, 1))
        ab_batch.append(torch.from_numpy(ab).permute(2, 0, 1))
    return torch.stack(L_batch).float(), torch.stack(ab_batch).float()

def colorization_loss(pred_ab, true_ab, pixel_weights=None):
    """L2 colorization loss with optional class rebalancing for rare colors."""
    if pixel_weights is not None:
        # Weight rare colors more heavily to avoid desaturated predictions
        return (pixel_weights * (pred_ab - true_ab) ** 2).mean()
    return F.mse_loss(pred_ab, true_ab)
```

Class Rebalancing

Natural images have highly imbalanced color distributions—lots of grays, browns, and blues (sky, roads), but few bright colors. Without rebalancing, networks learn to predict "safe" desaturated colors. Reweighting rare colors in the loss function encourages more vibrant predictions.
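Zhang et al. compute per-bin weights by smoothing the empirical color distribution with a uniform one and inverting; a sketch of that reweighting (the mixing coefficient `lam` plays the role of the paper's λ):

```python
import torch

def rebalancing_weights(p, lam=0.5):
    """Per-bin loss weights from an empirical color distribution.

    p: tensor of shape (Q,) with empirical bin probabilities (sums to 1).
    Mixes p with a uniform distribution, inverts, and normalizes so the
    expected weight under p equals 1 -- rare colors get weights above 1.
    """
    q = p.numel()
    w = 1.0 / ((1 - lam) * p + lam / q)
    w = w / (p * w).sum()   # normalize: E_p[w] = 1
    return w
```

During training, each pixel's loss is scaled by the weight of its ground-truth color bin, which pushes predictions away from "safe" grays and browns.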

Context Prediction

Context prediction, introduced by Doersch et al. in 2015, was one of the first successful pretext tasks. Given two patches from the same image, predict their relative spatial position (one of 8 possible locations around a center patch).

The Task

Extract a center patch and one of 8 surrounding patches. The network must classify which of the 8 positions (top-left, top, top-right, left, right, bottom-left, bottom, bottom-right) the second patch came from.

$$\mathcal{L}_{\text{context}} = -\mathbb{E}_{x, i, j} \left[ \log p_\theta(\text{pos}(p_j) \mid p_i, p_j) \right]$$

where $p_i$ is the center patch and $p_j$ is a randomly selected neighbor.
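The patch-pair sampling can be sketched as follows; the patch size and gap are illustrative defaults (the original paper also adds random jitter between patches to prevent trivial edge matching):

```python
import torch

def sample_patch_pair(image, patch=32, gap=8):
    """Sample a center patch and one of its 8 neighbors.

    image: tensor (3, H, W), large enough to fit the 3x3 patch layout.
    Returns (center, neighbor, position_label) where position_label in
    [0, 8) indexes the 8 cells surrounding the center.
    """
    offsets = [(-1, -1), (-1, 0), (-1, 1), (0, -1),
               (0, 1), (1, -1), (1, 0), (1, 1)]
    _, h, w = image.shape
    stride = patch + gap                    # gap discourages edge matching
    cy, cx = h // 2 - patch // 2, w // 2 - patch // 2
    label = torch.randint(0, 8, (1,)).item()
    dy, dx = offsets[label]
    center = image[:, cy:cy + patch, cx:cx + patch]
    ny, nx = cy + dy * stride, cx + dx * stride
    neighbor = image[:, ny:ny + patch, nx:nx + patch]
    return center, neighbor, label
```

The two patches are then fed through a shared encoder and an 8-way classifier predicts `label`.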

How It Works: Given two patches from the same image, predict their relative position. The network must learn spatial relationships and object structure to succeed. Random chance: 12.5% (1 of 8 positions).

Avoiding Shortcuts

The Chromatic Aberration Problem

The original context prediction paper discovered that networks could "cheat" using chromatic aberration—color fringes at the edges of images that reveal absolute position. Solutions include:
  • Converting to grayscale or projecting to a different color space
  • Adding random jitter to patch positions
  • Using patches from the center region only

Context Encoders (Inpainting)

Context Encoders by Pathak et al. (2016) take a different approach: mask out a region of an image and train the network to fill it in. This is essentially "inpainting" as a pretext task.

Mathematical Formulation

Let MM be a binary mask indicating the region to be filled. The context encoder learns to reconstruct masked regions:

$$\mathcal{L}_{\text{inpaint}} = \lambda_{\text{rec}} \cdot \| M \odot (f_\theta(x \odot (1-M)) - x) \|^2 + \lambda_{\text{adv}} \cdot \mathcal{L}_{\text{adv}}$$

where:

  • $\odot$ denotes element-wise multiplication
  • The first term is reconstruction loss on the masked region
  • $\mathcal{L}_{\text{adv}}$ is an adversarial loss for realism
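The reconstruction term can be written directly from the formula; the adversarial term would come from a separate discriminator and is omitted in this sketch:

```python
import torch

def masked_reconstruction_loss(model, x, mask):
    """Reconstruction term of the context-encoder loss.

    x: images (B, 3, H, W); mask: (B, 1, H, W) with 1 inside the hole.
    The model sees only the unmasked context and is penalized only on
    the region it had to fill in.
    """
    x_masked = x * (1 - mask)            # remove the region to be filled
    x_hat = model(x_masked)              # network's reconstruction
    diff = mask * (x_hat - x)            # error counted on the hole only
    return (diff ** 2).sum() / mask.sum().clamp(min=1)
```

Restricting the loss to the hole matters: outside the mask the network can trivially copy its input, so only the masked region carries learning signal.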

Context Encoder Architecture: the masked image passes through a convolutional encoder (conv + pooling) to a latent code $z$, and a deconvolutional decoder fills in the missing region.

Training loss:

$$\mathcal{L} = \mathcal{L}_{\text{rec}} + \lambda \cdot \mathcal{L}_{\text{adv}}$$

  • $\mathcal{L}_{\text{rec}}$: L2 reconstruction loss on the masked region
  • $\mathcal{L}_{\text{adv}}$: adversarial loss from a discriminator that distinguishes real from inpainted regions

Why Inpainting Works for SSL:
  1. The network must understand global context to fill missing regions
  2. It learns semantic content (what should be there)
  3. It learns texture and style consistency
  4. It is forced to learn object boundaries and shapes
  5. Adversarial training improves realism

Evolution to Masked Autoencoders

The inpainting idea evolved dramatically with Masked Autoencoders (MAE) in 2021. MAE masks 75% of image patches and uses a Vision Transformer to reconstruct them. This has become one of the most effective self-supervised methods:

| Aspect | Context Encoder (2016) | MAE (2021) |
|---|---|---|
| Mask ratio | ~25% (one region) | 75% (random patches) |
| Architecture | CNN encoder-decoder | Vision Transformer |
| Reconstruction | Pixel space + adversarial | Pixel space only |
| Downstream performance | Good for detection | State-of-the-art |

PyTorch Implementation

Here's a complete training pipeline that combines multiple pretext tasks:

```python
import torch
import torch.nn as nn
from torchvision import transforms, datasets

class MultiTaskPretext(nn.Module):
    """Combined pretext task learning with a shared backbone."""

    def __init__(self, backbone):
        super().__init__()
        self.backbone = backbone  # assumed to output 512-d feature vectors
        self.rotation_head = nn.Linear(512, 4)
        self.jigsaw_head = nn.Linear(512 * 9, 100)

    def forward(self, x, task="rotation"):
        if task == "rotation":
            return self.rotation_head(self.backbone(x))
        elif task == "jigsaw":
            # x: (batch, 9, 3, H, W); encode each patch, then concatenate
            features = [self.backbone(x[:, i]) for i in range(9)]
            return self.jigsaw_head(torch.cat(features, dim=1))
        return self.backbone(x)

def create_pretext_dataloaders(data_path, batch_size=64):
    """Create dataloaders for pretext training."""
    transform = transforms.Compose([
        transforms.RandomResizedCrop(224, scale=(0.2, 1.0)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])
    dataset = datasets.ImageFolder(data_path, transform=transform)
    return torch.utils.data.DataLoader(
        dataset, batch_size=batch_size, shuffle=True,
        num_workers=4, pin_memory=True, drop_last=True
    )

# Transfer learning after pretext training
def transfer_to_classification(pretrained_backbone, num_classes):
    """Fine-tune or linearly probe the pretrained backbone for classification."""
    # Optionally freeze the backbone for linear probing
    for param in pretrained_backbone.parameters():
        param.requires_grad = False
    return nn.Sequential(
        pretrained_backbone,
        nn.Linear(512, num_classes)
    )
```

Comparing Pretext Tasks

Different pretext tasks learn different types of features. Here's how they compare:

| Pretext Task | Key Feature | Best For | Limitation |
|---|---|---|---|
| Rotation | Global orientation | Scene recognition | Orientation-invariant objects |
| Jigsaw | Spatial relationships | Object detection | Computational cost (9× patches) |
| Colorization | Semantic understanding | Scene classification | Desaturated predictions |
| Context | Local spatial reasoning | Object parts | Chromatic aberration shortcuts |
| Inpainting | Global context | Segmentation | Blurry reconstructions |

Transfer Learning Performance

When these pretext tasks were evaluated on downstream tasks (using linear probes or fine-tuning), they showed varying effectiveness:

| Method | ImageNet Linear (Top-1) | VOC07 Classification |
|---|---|---|
| Random init | ~11% | ~35% |
| Rotation | ~48% | ~67% |
| Jigsaw | ~45% | ~65% |
| Colorization | ~40% | ~62% |
| Context prediction | ~35% | ~55% |
| Supervised (upper bound) | ~75% | ~87% |

Historical Context

These results were state-of-the-art for self-supervised learning in 2016-2018. Modern contrastive methods like SimCLR and MoCo achieve 70%+ on ImageNet linear probing, nearly closing the gap with supervised learning.

Limitations and Evolution

Why Pretext Tasks Were Superseded

Despite their success, pretext tasks have fundamental limitations:

  1. Task-specific features: Features may overfit to the pretext task. Rotation prediction might learn chromatic aberration rather than semantics.
  2. Limited supervision signal: A 4-class rotation task provides much less information per image than comparing to thousands of other images.
  3. Shortcut solutions: Networks find unexpected ways to solve tasks without learning meaningful representations.
  4. Design effort: Each new pretext task requires careful design to avoid shortcuts and ensure transferability.

The Contrastive Learning Revolution

Modern self-supervised learning largely moved to contrastive methods that:

  • Compare images to each other rather than solving fixed tasks
  • Create supervision through data augmentation invariance
  • Scale better with more data and larger models
  • Achieve near-supervised performance on benchmarks

We'll explore contrastive learning methods like SimCLR, MoCo, and BYOL in Chapter 25.


Summary

Pretext tasks represent a foundational approach to self-supervised learning that exploits the inherent structure of images to create free supervision. Key takeaways:

  • Pretext tasks automatically generate labels from data structure, enabling learning from unlimited unlabeled images
  • Rotation prediction teaches orientation and scene understanding through 4-way classification
  • Jigsaw puzzles force learning of spatial relationships between object parts
  • Colorization requires semantic understanding to predict plausible colors
  • Context prediction and inpainting leverage spatial context
  • Features learned via pretext tasks transfer well to downstream tasks, though not as well as modern contrastive methods
Historical Significance: While pretext tasks have been largely superseded by contrastive learning, they provided crucial insights: (1) useful features can be learned without labels, (2) the right "proxy task" matters enormously, and (3) avoiding shortcuts is essential. These lessons continue to guide self-supervised learning research.

Knowledge Check

Test your understanding of pretext tasks for images:
