Chapter 7

Data Transforms

Data Loading and Processing

Learning Objectives

By the end of this section, you will be able to:

  1. Understand the mathematical foundations of data transforms, including normalization, standardization, and geometric transformations
  2. Build transform pipelines using transforms.Compose to chain operations efficiently
  3. Apply data augmentation strategies to increase training data diversity and improve generalization
  4. Implement custom transforms for domain-specific preprocessing needs
  5. Choose appropriate transforms for training versus validation/inference
  6. Optimize transform performance using GPU acceleration and efficient implementations
Why This Matters: Neural networks are notoriously sensitive to the scale and distribution of their inputs. A simple change in normalization can mean the difference between a model that converges smoothly and one that diverges catastrophically. Data transforms are the bridge between raw data and model-ready tensors.

The Big Picture

The Problem: Raw Data is Messy

Consider the challenges of training on real-world images:

| Challenge | Raw Data Problem | Transform Solution |
| --- | --- | --- |
| Inconsistent sizes | Images range from 100×100 to 4000×3000 | Resize, CenterCrop, RandomResizedCrop |
| Variable scale | Pixel values 0-255; features span orders of magnitude | Normalize, Standardize, ToTensor |
| Limited data | Only 1000 training images for 100 classes | RandomFlip, RandomRotation, ColorJitter |
| Format mismatch | PIL Images, NumPy arrays, file paths | ToTensor, ConvertImageDtype |
| Domain shift | Training on lab photos, testing on phone photos | Color augmentation, blur, noise injection |

Transforms solve all of these problems systematically. They are pure functions that take an input (image, tensor, label) and return a transformed version, composable into pipelines that run efficiently during training.

A Brief History

The importance of data preprocessing was recognized early in machine learning:

  1. 1990s (Traditional ML): Feature standardization and whitening for SVMs and neural networks
  2. 2012 (AlexNet): Random cropping and horizontal flips became standard for ImageNet
  3. 2015 (Batch Normalization): Reduced sensitivity to input scale, but transforms still critical
  4. 2018-2020 (AutoAugment, RandAugment): Learned and random augmentation policies
  5. 2021+ (Modern frameworks): GPU-accelerated transforms, unified APIs for images/video/labels
Key Insight: Data augmentation is one of the most powerful regularization techniques. By artificially expanding your training set with transformed versions of existing samples, you effectively tell the model: "These variations don't change the label, so learn to be invariant to them."

Why Transforms Matter

The Scale Problem

Neural networks learn by adjusting weights through gradient descent. When input features have wildly different scales, gradients become problematic:

🐍scale_problem.py
# Feature 1: Age (0-100)
# Feature 2: Income ($0-$1,000,000)

# Without normalization:
# - Gradient for the income weight is ~10,000x larger than for the age weight
# - Optimizer steps are dominated by income
# - No single learning rate works: income overshoots while age barely moves

# With normalization:
# - Both features have similar gradient magnitudes
# - Optimizer treats all features equally

Mathematically, consider the loss landscape. For a simple linear model $f(x) = w_1 x_1 + w_2 x_2$, the mean-squared-error loss surface is:

$$\mathcal{L}(w_1, w_2) = \frac{1}{N} \sum_{i=1}^{N} (y_i - w_1 x_{i1} - w_2 x_{i2})^2$$

If $\text{Var}(x_1) \gg \text{Var}(x_2)$, the loss surface becomes an elongated ellipse: gradients point nearly perpendicular to the direction of the optimum, requiring many zigzagging steps.
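To make the imbalance concrete, here is a one-sample sketch (the age and income values are invented for illustration): the squared-error gradient for each weight is proportional to that weight's raw feature value, so the income gradient dwarfs the age gradient.

```python
# One hypothetical sample: age on the order of tens, income on the order of 100,000s.
age, income = 50.0, 500_000.0
y = 1.0                       # target
w1 = w2 = 0.0                 # initial weights

# For L = (y - w1*age - w2*income)^2:
residual = y - (w1 * age + w2 * income)
grad_w1 = -2 * residual * age      # dL/dw1, proportional to the age value
grad_w2 = -2 * residual * income   # dL/dw2, proportional to the income value

print(grad_w2 / grad_w1)  # 10000.0 -- income dominates every optimizer step
```

Standardizing both features to unit variance makes the two gradients comparable, which is exactly what Normalize does for pixel channels.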

The Distribution Problem

Activation functions have "sweet spots" where they're most effective:

| Activation | Optimal Input Range | Problem with Large Inputs |
| --- | --- | --- |
| Sigmoid | [-4, 4] | Saturates to 0 or 1, gradients vanish |
| Tanh | [-2, 2] | Saturates to ±1, gradients vanish |
| ReLU | [0, ∞) | Large values → large gradients → instability |
| GELU | [-3, 3] | Loses non-linearity for very large inputs |

Normalization keeps inputs in these effective ranges. For ImageNet, the standard normalization

$$x_{\text{norm}} = \frac{x - \mu}{\sigma} = \frac{x - [0.485, 0.456, 0.406]}{[0.229, 0.224, 0.225]}$$

transforms RGB values from [0, 1] to approximately [-2.1, 2.6], well within the activation sweet spots.
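You can check the resulting range yourself by mapping the endpoints of [0, 1] through the normalization formula for each channel:

```python
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]

# Map the endpoints of [0, 1] through (x - mean) / std, channel by channel.
lows = [(0.0 - m) / s for m, s in zip(mean, std)]
highs = [(1.0 - m) / s for m, s in zip(mean, std)]

print(round(min(lows), 2), round(max(highs), 2))  # -2.12 2.64
```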


The Transform Abstraction

In PyTorch, a transform is simply a callable that takes an input and returns a transformed output:

🐍transform_abstraction.py
from torchvision import transforms

# A transform is just a callable
normalize = transforms.Normalize(mean=[0.485], std=[0.229])

# Input: tensor of shape [C, H, W]
# Output: normalized tensor of the same shape
output = normalize(input_tensor)

# Transforms can be chained
pipeline = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

# The entire pipeline is also callable
output = pipeline(pil_image)

Transform Categories

| Category | Purpose | Examples |
| --- | --- | --- |
| Geometric | Change spatial arrangement | Resize, Crop, Rotate, Flip, Affine |
| Photometric | Change pixel values | ColorJitter, Normalize, GaussianBlur |
| Type Conversion | Change data format | ToTensor, ConvertImageDtype, PILToTensor |
| Composition | Combine transforms | Compose, RandomApply, RandomChoice |
| Augmentation | Add training-time variation | RandomCrop, RandomRotation, RandAugment |

The ToTensor Transform

The most fundamental transform converts PIL Images to PyTorch tensors:

🐍to_tensor.py
from torchvision.transforms import ToTensor

# What ToTensor does:
# 1. Converts a PIL Image (H, W, C) to a Tensor (C, H, W)
# 2. Scales pixel values from [0, 255] to [0.0, 1.0]
# 3. Converts to torch.float32

to_tensor = ToTensor()

pil_image  # PIL Image: (224, 224), RGB, values 0-255
tensor = to_tensor(pil_image)
# tensor: torch.Tensor, shape [3, 224, 224], values 0.0-1.0

PIL Image Format

PIL Images use (Height, Width, Channels) ordering, but PyTorch tensors use (Channels, Height, Width). The ToTensor transform handles this transposition automatically. Forgetting this conversion leads to cryptic shape mismatch errors.
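The layout conversion can be mirrored by hand with NumPy (the tiny 2×2 "image" below is invented purely to make the shapes visible):

```python
import numpy as np

# A tiny 2x2 RGB "image" in PIL's (H, W, C) layout, uint8 in [0, 255].
hwc = np.arange(12, dtype=np.uint8).reshape(2, 2, 3)

# What ToTensor does, reproduced by hand:
# transpose (H, W, C) -> (C, H, W), convert to float32, scale to [0, 1].
chw = hwc.transpose(2, 0, 1).astype(np.float32) / 255.0

print(hwc.shape, chw.shape)  # (2, 2, 3) (3, 2, 2)
```

If you feed a (H, W, C) array straight into a convolution expecting (C, H, W), the layer sees "channels" of size H, which is exactly the kind of cryptic shape mismatch mentioned above.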

Interactive: Image Transforms

Explore how different transforms modify an image. Adjust parameters to see how resize, crop, rotation, and color adjustments affect the visual appearance. Notice how some transforms preserve spatial relationships while others dramatically alter the image.


Why This Matters

Transforms artificially expand your training set. A model that only sees one version of each image will overfit to those exact pixel patterns.

Key observations from the visualizer:

  • Resize: Changes dimensions; can distort the aspect ratio if both sides are forced
  • CenterCrop: Extracts a fixed region from the center, discarding the borders
  • RandomCrop: Like CenterCrop, but at a random position (different on every call)
  • Rotation: Rotates the image; may introduce black corners or require cropping
  • ColorJitter: Randomly adjusts brightness, contrast, saturation, hue

Quick Check

If you apply RandomHorizontalFlip with p=0.5 to a batch of 100 images, approximately how many will be flipped?


Mathematics of Normalization

Normalization is perhaps the most important transform for training stability. Let's understand it deeply.

Standard Normalization (Z-Score)

Given a feature $x$ with population mean $\mu$ and standard deviation $\sigma$, the z-score normalization is:

$$z = \frac{x - \mu}{\sigma}$$

This transforms the distribution to have zero mean and unit variance:

$$\mathbb{E}[z] = \mathbb{E}\left[\frac{x - \mu}{\sigma}\right] = \frac{\mathbb{E}[x] - \mu}{\sigma} = \frac{\mu - \mu}{\sigma} = 0$$

$$\text{Var}(z) = \text{Var}\left(\frac{x - \mu}{\sigma}\right) = \frac{\text{Var}(x)}{\sigma^2} = \frac{\sigma^2}{\sigma^2} = 1$$
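The zero-mean, unit-variance property is easy to confirm numerically (the feature distribution below is arbitrary, chosen only for illustration):

```python
import numpy as np

# An arbitrary feature: mean 50, std 12, picked purely for illustration.
rng = np.random.default_rng(0)
x = rng.normal(loc=50.0, scale=12.0, size=100_000)

# Z-score normalization using the sample statistics.
z = (x - x.mean()) / x.std()

print(round(abs(float(z.mean())), 6), round(float(z.std()), 6))  # 0.0 1.0
```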

Per-Channel Normalization for Images

For RGB images, normalization is applied per channel:

$$x^c_{\text{norm}} = \frac{x^c - \mu_c}{\sigma_c} \quad \text{for } c \in \{R, G, B\}$$

where $\mu_c$ and $\sigma_c$ are the mean and standard deviation of channel $c$ across the entire training dataset.

ImageNet Statistics

The famous ImageNet normalization constants:

$$\mu = [0.485, 0.456, 0.406], \quad \sigma = [0.229, 0.224, 0.225]$$

were computed by averaging across all 1.2 million ImageNet training images. These values are now standard for any model pretrained on ImageNet.

🐍normalization.py
from torchvision import transforms

# ImageNet normalization
normalize = transforms.Normalize(
    mean=[0.485, 0.456, 0.406],  # Per-channel means
    std=[0.229, 0.224, 0.225]    # Per-channel stds
)

# For a tensor x of shape [C, H, W]:
# output[c] = (x[c] - mean[c]) / std[c]

# Computing your own dataset statistics:
def compute_channel_stats(dataloader):
    """Compute per-channel mean and std for a dataset.

    Note: averaging per-image stds is a common approximation; it slightly
    underestimates the true dataset-wide std but is usually close enough.
    """
    mean = 0.0
    std = 0.0
    n_samples = 0

    for images, _ in dataloader:
        # images: [B, C, H, W]
        batch_samples = images.size(0)
        images = images.view(batch_samples, images.size(1), -1)
        mean += images.mean(2).sum(0)
        std += images.std(2).sum(0)
        n_samples += batch_samples

    mean /= n_samples
    std /= n_samples
    return mean, std

Why Not Min-Max Scaling?

An alternative normalization is min-max scaling:

$$x_{\text{scaled}} = \frac{x - x_{\min}}{x_{\max} - x_{\min}}$$

This maps values to [0, 1], but has problems:

| Aspect | Z-Score Normalization | Min-Max Scaling |
| --- | --- | --- |
| Range | Approximately [-3, 3] | Exactly [0, 1] |
| Outlier sensitivity | Low (just shifts the tail) | High (a single outlier rescales everything) |
| Information preservation | Preserves relative distances | Compresses if outliers exist |
| For neural networks | Preferred | Avoid unless bounded outputs are needed |

When to Use Each

Use z-score normalization for neural network inputs—it's robust to outliers and works well with gradient-based optimization. Use min-max scaling only when you need outputs in a specific range (e.g., for image generation where pixels must be in [0, 255]).
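The outlier sensitivity is easy to see on a toy feature (the numbers are invented): one extreme value defines the min-max range and crushes all the informative values together, while z-scores keep the inliers distinguishable.

```python
import numpy as np

inliers = np.array([10.0, 12.0, 11.0, 13.0, 12.0])
data = np.append(inliers, 1000.0)   # one extreme outlier

# Min-max: the outlier defines the range, crushing the inliers near zero.
scaled = (data - data.min()) / (data.max() - data.min())
print(scaled[:5].max())   # every inlier lands below 0.004

# Z-score: the inliers keep distinct, usable values.
z = (data - data.mean()) / data.std()
print(np.round(z[:5], 2))
```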

Interactive: Normalization

Visualize how normalization transforms image pixel distributions. Compare raw pixel values to normalized values and see how different normalization parameters affect the output distribution.

How Normalization Works

Trace a single pixel, RGB(128, 100, 200), through the pipeline using the ImageNet statistics (mean = [0.485, 0.456, 0.406], std = [0.229, 0.224, 0.225]):

  1. Raw pixel value — uint8, range [0, 255]: R=128, G=100, B=200
  2. ToTensor() — float32, range [0, 1], formula x / 255: R=0.5020, G=0.3922, B=0.7843 (e.g., 128 / 255 = 0.5020)
  3. Normalize(mean, std) — float32, approximately N(0, 1), formula (x - mean) / std: R=+0.0741, G=-0.2850, B=+1.6814 (e.g., (0.5020 - 0.485) / 0.229 = 0.0741)

Value ranges at each stage: raw [0, 255] → tensor [0.0, 1.0] → normalized (centered near 0, roughly [-3.0, +3.0]).

Why Normalize?

  • Stable gradients: centered values (mean ≈ 0) prevent gradient explosion/vanishing in deep networks.
  • Pretrained compatibility: ImageNet-pretrained models expect ImageNet-normalized inputs. Using the wrong statistics hurts performance.
  • Faster convergence: normalized inputs help optimizers find good solutions faster with consistent learning rates.

Notice how normalization:

  • Centers the distribution around zero (important for sigmoid/tanh activations)
  • Scales the variance to approximately 1 (prevents gradient explosion/vanishing)
  • Operates per-channel (R, G, B are normalized independently)
  • Uses dataset statistics (not individual image statistics)

Quick Check

Why do we use dataset-wide mean and std for normalization, rather than computing them per-image?


Composing Transforms

Real data pipelines chain multiple transforms. The transforms.Compose class creates a callable that applies transforms sequentially:

Composing a Training Transform Pipeline
🐍compose_example.py
  • RandomResizedCrop: the workhorse of ImageNet training. Randomly crops a region (8-100% of the original area) and resizes it to 224×224, providing scale augmentation and size normalization at once.
  • RandomHorizontalFlip: mirrors the image horizontally with 50% probability. Valid for natural images (a cat facing left is still a cat) but would break text recognition tasks.
  • ColorJitter: randomly adjusts brightness, contrast, saturation, and hue, making the model invariant to lighting conditions and camera settings. Values are fractional variations from the original.
  • ToTensor: converts the PIL Image to a PyTorch tensor and scales values from [0, 255] to [0, 1]. Must come before Normalize, which expects tensor input.
  • Normalize: applies z-score normalization with ImageNet statistics. Comes last since it operates on tensor values.
from torchvision import transforms

# A typical ImageNet training pipeline
train_transform = transforms.Compose([
    # Step 1: Resize with random crop for scale augmentation
    transforms.RandomResizedCrop(
        size=224,
        scale=(0.08, 1.0),    # Crop 8% to 100% of original area
        ratio=(3/4, 4/3),     # Aspect ratio range
    ),

    # Step 2: Random horizontal flip (50% probability)
    transforms.RandomHorizontalFlip(p=0.5),

    # Step 3: Color augmentation
    transforms.ColorJitter(
        brightness=0.4,
        contrast=0.4,
        saturation=0.4,
        hue=0.1
    ),

    # Step 4: Convert PIL Image to Tensor
    transforms.ToTensor(),

    # Step 5: Normalize using ImageNet statistics
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    ),
])

# Apply the entire pipeline
tensor = train_transform(pil_image)
# Returns: torch.Tensor of shape [3, 224, 224]

Transform Ordering Matters

The order of transforms is critical:

| Correct Order | Why This Order |
| --- | --- |
| Geometric → Photometric → ToTensor → Normalize | PIL operations before tensor conversion |
| RandomCrop before Resize | Crop first if you want random regions |
| Resize before RandomCrop | Resize first if you want fixed-size crops |
| ToTensor before Normalize | Normalize operates on tensors, not PIL Images |
| Normalize last | Other tensor ops expect the original scale |

Common Ordering Mistake

Applying Normalize before ToTensor will fail because Normalize expects a tensor input, not a PIL Image. Similarly, applying geometric transforms after ToTensor requires using the functional API or tensor-specific transforms.

Interactive: Transform Pipeline

Build your own transform pipeline by selecting and ordering transforms. Watch how each step modifies the image as it flows through the pipeline. Experiment with different orderings to understand why transform order matters.


Generated PyTorch Code

from torchvision import transforms

transform = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.ColorJitter(0.2, 0.2, 0.2),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

Pipeline Best Practices

  • Always end with ToTensor() and Normalize()
  • Use RandomCrop for training, CenterCrop for validation
  • Color augmentations should come before ToTensor()
  • Order matters: spatial transforms usually come first

Key insights from building pipelines:

  • Early augmentation affects all subsequent transforms
  • Deterministic transforms (Resize, CenterCrop) should come before random ones for consistency
  • ToTensor is the conversion boundary between PIL and tensor operations
  • Heavy augmentation during training, minimal transforms during validation

Data Augmentation

Data augmentation artificially expands the training set by applying random transformations to existing samples. This is one of the most effective regularization techniques in deep learning.

The Invariance Principle

Augmentation encodes prior knowledge about invariances:

  • Horizontal flip: A cat facing left is still a cat
  • Random crop: Object identity doesn't depend on exact position
  • Color jitter: A red car under bright sun is still a car
  • Scale variation: An elephant close up is still an elephant

By training on augmented data, the model learns to be invariant to these transformations.

Standard Augmentations by Domain

| Domain | Standard Augmentations | Why These Work |
| --- | --- | --- |
| Image Classification | RandomResizedCrop, HorizontalFlip, ColorJitter | Natural images have these invariances |
| Object Detection | RandomCrop, Scale, Flip + bbox adjustments | Objects can appear at any position/scale |
| Semantic Segmentation | Same transforms applied to image AND mask | Pixel labels must stay aligned |
| Medical Imaging | Rotation, elastic deformation, intensity shift | Organs appear at various orientations |
| Text/OCR | Random scale, slight rotation, noise | But NOT horizontal flip (reverses text) |

Common Augmentation Transforms

🐍common_augmentations.py
from torchvision import transforms

# Geometric augmentations
transforms.RandomHorizontalFlip(p=0.5)        # Mirror horizontally
transforms.RandomVerticalFlip(p=0.5)          # Mirror vertically
transforms.RandomRotation(degrees=15)         # Rotate ±15°
transforms.RandomAffine(degrees=0, translate=(0.1, 0.1))  # Shift up to 10%
transforms.RandomPerspective(distortion_scale=0.2)  # Perspective warp
transforms.RandomResizedCrop(224, scale=(0.8, 1.0))  # Random scale + crop

# Photometric augmentations
transforms.ColorJitter(brightness=0.2, contrast=0.2)
transforms.GaussianBlur(kernel_size=3)
transforms.RandomGrayscale(p=0.1)             # 10% chance of grayscale
transforms.RandomAutocontrast(p=0.5)
transforms.RandomEqualize(p=0.5)

# Composition helpers
transforms.RandomApply([transform], p=0.5)    # Apply with probability
transforms.RandomChoice([t1, t2, t3])         # Pick one randomly
transforms.RandomOrder([t1, t2, t3])          # Random ordering

Advanced Augmentation Strategies

RandAugment

RandAugment randomly selects N transforms from a pool and applies them with magnitude M:

🐍randaugment.py
from torchvision import transforms
from torchvision.transforms import autoaugment

# RandAugment with N=2 transforms, magnitude M=9
randaugment = autoaugment.RandAugment(
    num_ops=2,           # Apply 2 random transforms
    magnitude=9,         # Strength 0-30 (9 is standard)
    interpolation=transforms.InterpolationMode.BILINEAR
)

# ImageNet statistics
mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]

train_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.RandomCrop(224),
    randaugment,         # Apply RandAugment
    transforms.ToTensor(),
    transforms.Normalize(mean, std)
])

Cutout / Random Erasing

Randomly erases rectangular regions of the input image, forcing the model to focus on the entire image rather than a single discriminative region:

🐍cutout.py
from torchvision import transforms
from torchvision.transforms import RandomErasing

# Randomly erase a patch after ToTensor
random_erasing = RandomErasing(
    p=0.5,                # 50% probability
    scale=(0.02, 0.33),   # Erase 2-33% of image area
    ratio=(0.3, 3.3),     # Aspect ratio of the erased region
    value='random'        # Fill with random values
)

# ImageNet statistics
mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]

train_transform = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean, std),
    random_erasing        # Must come after ToTensor (operates on tensors)
])

MixUp and CutMix

These augmentations operate on pairs of samples, creating interpolated training examples:

🐍mixup_cutmix.py
import torch

def mixup(images, labels, alpha=0.2):
    """MixUp: linearly interpolate pairs of samples.

    λ ~ Beta(α, α)
    x' = λ * x_i + (1-λ) * x_j
    y' = λ * y_i + (1-λ) * y_j
    """
    lam = torch.distributions.Beta(alpha, alpha).sample()
    batch_size = images.size(0)
    index = torch.randperm(batch_size)

    mixed_images = lam * images + (1 - lam) * images[index]
    labels_a, labels_b = labels, labels[index]

    return mixed_images, labels_a, labels_b, lam


def cutmix(images, labels, alpha=1.0):
    """CutMix: replace a rectangular region with another sample.

    Region size determined by λ ~ Beta(α, α)
    """
    lam = torch.distributions.Beta(alpha, alpha).sample()
    batch_size = images.size(0)
    index = torch.randperm(batch_size)

    # Compute region to cut (images are [B, C, H, W])
    H, W = images.shape[2], images.shape[3]
    cut_h = int(H * torch.sqrt(1 - lam))
    cut_w = int(W * torch.sqrt(1 - lam))

    # Random center position
    cy = torch.randint(0, H, (1,)).item()
    cx = torch.randint(0, W, (1,)).item()
    y1, y2 = max(0, cy - cut_h // 2), min(H, cy + cut_h // 2)
    x1, x2 = max(0, cx - cut_w // 2), min(W, cx + cut_w // 2)

    # Cut and paste
    mixed_images = images.clone()
    mixed_images[:, :, y1:y2, x1:x2] = images[index, :, y1:y2, x1:x2]

    # Adjust lambda to the actual area ratio
    lam = 1 - (y2 - y1) * (x2 - x1) / (H * W)

    return mixed_images, labels, labels[index], lam
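Both functions return two label sets plus λ, and the training loss interpolates the per-label losses with that same λ (this is the objective from the MixUp/CutMix papers). A minimal arithmetic sketch, with made-up loss values standing in for `criterion(pred, labels_a)` and `criterion(pred, labels_b)`:

```python
# Hypothetical per-batch losses against the two original label sets.
loss_a, loss_b = 0.9, 0.3   # e.g. criterion(pred, labels_a), criterion(pred, labels_b)
lam = 0.7                   # the lambda returned by mixup() / cutmix()

# Interpolate the two losses by lambda -- the model is rewarded for
# predicting a blend of both labels, matching the blended input.
loss = lam * loss_a + (1 - lam) * loss_b
print(round(loss, 2))  # 0.72
```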

Mathematics of Augmentation

Let's formalize why augmentation works from a theoretical perspective.

Augmentation as Regularization

Consider augmentation as expanding the training set:

$$\mathcal{D}_{\text{aug}} = \{(T(x_i), y_i) \mid (x_i, y_i) \in \mathcal{D},\ T \in \mathcal{T}\}$$

where $\mathcal{T}$ is the set of valid augmentation transforms. Training on $\mathcal{D}_{\text{aug}}$ is equivalent to minimizing an expected loss over transforms:

$$\mathcal{L}_{\text{aug}}(\theta) = \mathbb{E}_{T \sim \mathcal{T}} \left[ \mathcal{L}(f_\theta(T(x)), y) \right]$$

This expectation encourages the model to be invariant to the transforms in $\mathcal{T}$.

Geometric Transform Mathematics

Most geometric transforms can be expressed as affine transformations:

$$\begin{bmatrix} x' \\ y' \\ 1 \end{bmatrix} = \begin{bmatrix} a & b & t_x \\ c & d & t_y \\ 0 & 0 & 1 \end{bmatrix} \begin{bmatrix} x \\ y \\ 1 \end{bmatrix}$$

Common transforms and their matrices:

| Transform | Matrix |
| --- | --- |
| Translation (tx, ty) | [[1, 0, tx], [0, 1, ty], [0, 0, 1]] |
| Scaling (sx, sy) | [[sx, 0, 0], [0, sy, 0], [0, 0, 1]] |
| Rotation (θ) | [[cos θ, -sin θ, 0], [sin θ, cos θ, 0], [0, 0, 1]] |
| Horizontal Flip | [[-1, 0, W], [0, 1, 0], [0, 0, 1]] |
| Shear (k) | [[1, k, 0], [0, 1, 0], [0, 0, 1]] |
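As a quick sanity check of the rotation matrix, here is a sketch in NumPy (the 90° angle and the test point are chosen purely for illustration):

```python
import numpy as np

theta = np.deg2rad(90.0)
R = np.array([[np.cos(theta), -np.sin(theta), 0.0],
              [np.sin(theta),  np.cos(theta), 0.0],
              [0.0,            0.0,           1.0]])

p = np.array([1.0, 0.0, 1.0])  # the point (1, 0) in homogeneous coordinates
print(np.round(R @ p, 6))      # [0. 1. 1.] -- rotated 90 degrees onto the y-axis
```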

The RandomAffine transform composes these operations:

🐍affine_math.py
import torch
import torch.nn.functional as F

def apply_affine_transform(image, angle, translate, scale, shear):
    """Apply an affine transformation to an image tensor.

    Args:
        image: [B, C, H, W] tensor
        angle: rotation angle in degrees
        translate: (tx, ty) translation in normalized grid coordinates
            (affine_grid uses a [-1, 1] coordinate system)
        scale: scale factor
        shear: shear angle in degrees
    """
    theta = torch.deg2rad(torch.tensor(angle))
    shear_rad = torch.deg2rad(torch.tensor(shear))

    # Rotation
    R = torch.tensor([
        [torch.cos(theta), -torch.sin(theta)],
        [torch.sin(theta), torch.cos(theta)]
    ])

    # Scale
    S = scale * torch.eye(2)

    # Shear
    Sh = torch.tensor([
        [1.0, torch.tan(shear_rad)],
        [0.0, 1.0]
    ])

    # Combined: Scale @ Shear @ Rotation
    M = S @ Sh @ R

    # Translation
    t = torch.tensor([translate[0], translate[1]])

    # Affine grid sampling
    theta_matrix = torch.zeros(1, 2, 3)
    theta_matrix[0, :2, :2] = M
    theta_matrix[0, :2, 2] = t

    grid = F.affine_grid(theta_matrix, image.size(), align_corners=False)
    return F.grid_sample(image, grid, align_corners=False)

Color Augmentation Mathematics

Color augmentations modify pixel values while preserving spatial structure:

$$\text{Brightness:} \quad I' = I \cdot (1 + \beta), \quad \beta \sim U(-b, b)$$

$$\text{Contrast:} \quad I' = \mu + (I - \mu) \cdot (1 + c), \quad c \sim U(-c_{\max}, c_{\max})$$

$$\text{Saturation (HSV):} \quad S' = S \cdot (1 + s), \quad s \sim U(-s_{\max}, s_{\max})$$

$$\text{Hue (HSV):} \quad H' = (H + h) \bmod 360, \quad h \sim U(-h_{\max}, h_{\max})$$
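The brightness and contrast formulas can be sketched directly on a toy array (the image values, shapes, and parameters below are arbitrary; values are kept away from 0 and 1 so that clipping never triggers and the formulas' properties show cleanly):

```python
import numpy as np

rng = np.random.default_rng(0)
img = rng.uniform(0.2, 0.8, size=(3, 4, 4))  # toy [C, H, W] image in [0.2, 0.8]

# Brightness: I' = I * (1 + beta) -- a pure rescale of every pixel.
beta = 0.2
bright = np.clip(img * (1 + beta), 0.0, 1.0)

# Contrast: I' = mu + (I - mu) * (1 + c) -- stretch deviations from the mean.
c = 0.2
mu = img.mean()
contrast = np.clip(mu + (img - mu) * (1 + c), 0.0, 1.0)

print(round(float(bright.mean() / img.mean()), 2))  # 1.2 (mean scales with beta)
print(abs(float(contrast.mean() - mu)) < 1e-9)      # True (mean is preserved)
```

Note the asymmetry: brightness shifts the mean intensity, while contrast keeps the mean fixed and only spreads values around it.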

Functional Transforms

PyTorch offers two APIs for transforms: class-based and functional.

🐍functional_transforms.py
import random

import torchvision.transforms.functional as F
from torchvision import transforms

# Class-based: stateful, composable
resize = transforms.Resize((224, 224))
output = resize(image)

# Functional: pure functions, more control
output = F.resize(image, size=(224, 224))

# The functional API is useful when you need:
# 1. The same random parameters for image and label
# 2. Conditional transforms
# 3. Custom transform logic

def synchronized_transform(image, mask):
    """Apply the same random transform to image and mask."""
    # Draw random parameters once
    angle = random.uniform(-30, 30)
    i, j, h, w = transforms.RandomCrop.get_params(image, (200, 200))

    # Apply with the same parameters to both
    image = F.rotate(image, angle)
    mask = F.rotate(mask, angle)

    image = F.crop(image, i, j, h, w)
    mask = F.crop(mask, i, j, h, w)

    return image, mask

When to Use Functional API

| Use Case | Class-Based | Functional |
| --- | --- | --- |
| Simple pipelines | ✓ transforms.Compose | |
| Synchronized image + mask | | ✓ Same parameters for both |
| Conditional transforms | | ✓ Full control flow |
| Custom augmentations | | ✓ Build from primitives |
| DataLoader integration | ✓ Direct use | Wrap in a custom class |

Modern Transforms (v2)

PyTorch Vision 0.15+ introduced torchvision.transforms.v2, a modernized API with significant improvements:

🐍transforms_v2.py
import torch
from torchvision.transforms import v2

# v2 transforms work with multiple inputs simultaneously
transform = v2.Compose([
    v2.RandomResizedCrop(224, antialias=True),
    v2.RandomHorizontalFlip(p=0.5),
    v2.ColorJitter(brightness=0.2, contrast=0.2),
    v2.ToDtype(torch.float32, scale=True),  # Replaces ToTensor
    v2.Normalize(mean=[0.485, 0.456, 0.406],
                 std=[0.229, 0.224, 0.225])
])

# Key improvements in v2:

# 1. Unified handling of images + bounding boxes + masks
images, bboxes, masks = transform(images, bboxes, masks)

# 2. Built-in support for batched inputs
batch_images = transform(batch_images)

# 3. GPU acceleration (via TorchScript)
transform = torch.jit.script(transform)

# 4. Better performance with TVTensor types
from torchvision import tv_tensors
image = tv_tensors.Image(image_tensor)
bbox = tv_tensors.BoundingBoxes(boxes, format="XYXY", canvas_size=(H, W))  # H, W: image size

# 5. MixUp and CutMix as transforms
mixup = v2.MixUp(alpha=0.2, num_classes=1000)
cutmix = v2.CutMix(alpha=1.0, num_classes=1000)
mixup_cutmix = v2.RandomChoice([mixup, cutmix])

Migration to v2

The v2 API is backward compatible—existing pipelines will work. However, for new projects, especially object detection or segmentation, use v2 for synchronized transforms across images and labels.

Custom Transforms

When built-in transforms don't meet your needs, create custom ones:

Creating Custom Transforms
🐍custom_transforms.py
  • Class-based custom transform: custom transforms are classes with __init__ (for parameters) and __call__ (for the transformation logic). This pattern integrates with Compose.
  • Tensor operations: for transforms placed after ToTensor, work with PyTorch tensors. randn_like creates noise matching the input shape; clamp keeps values in the valid range.
  • __repr__ for debugging: implementing __repr__ helps when printing the transform pipeline. Print the pipeline to verify your transforms are correctly configured.
  • Probabilistic transforms: many augmentations apply with a probability. Use torch.rand() to randomly skip the transform, preserving some original samples.
  • PIL-based transform: transforms placed before ToTensor receive PIL Images. Use PIL operations or convert to a tensor temporarily. The example below simulates JPEG artifacts.
import io

import torch
from PIL import Image
from torchvision import transforms

class AddGaussianNoise:
    """Add Gaussian noise to a tensor image.

    Useful for:
    - Making models robust to sensor noise
    - Regularization through input perturbation
    - Simulating low-light conditions
    """

    def __init__(self, mean: float = 0.0, std: float = 0.1):
        self.mean = mean
        self.std = std

    def __call__(self, tensor: torch.Tensor) -> torch.Tensor:
        noise = torch.randn_like(tensor) * self.std + self.mean
        return torch.clamp(tensor + noise, 0.0, 1.0)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(mean={self.mean}, std={self.std})"


class RandomChannelShuffle:
    """Randomly shuffle color channels.

    A cat is still a cat whether viewed in RGB, BGR, or GBR.
    """

    def __init__(self, p: float = 0.5):
        self.p = p

    def __call__(self, tensor: torch.Tensor) -> torch.Tensor:
        if torch.rand(1).item() < self.p:
            perm = torch.randperm(tensor.size(0))
            return tensor[perm]
        return tensor


class RandomJPEGCompression:
    """Simulate JPEG compression artifacts.

    Useful for training models robust to image compression.
    """

    def __init__(self, quality_range: tuple = (30, 95)):
        self.quality_range = quality_range

    def __call__(self, image):
        # Random quality
        quality = torch.randint(*self.quality_range, (1,)).item()

        # Save and reload with compression
        buffer = io.BytesIO()
        image.save(buffer, format='JPEG', quality=quality)
        buffer.seek(0)

        return Image.open(buffer)


# ImageNet statistics
mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]

# Usage in a pipeline
train_transform = transforms.Compose([
    transforms.RandomResizedCrop(224),
    RandomJPEGCompression(quality_range=(50, 95)),  # Custom transform
    transforms.ToTensor(),
    AddGaussianNoise(std=0.05),                     # Custom transform
    transforms.Normalize(mean, std)
])

Transform Best Practices

  • Keep randomness reproducible: drawing random parameters inside __call__ is fine, but running with the same RNG seed should reproduce the same output
  • Implement __repr__ for debugging
  • Handle both PIL Images and Tensors if your transform might be used in different pipeline positions
  • Document what input types are expected (PIL, tensor, specific shape)
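These practices can be sketched with a toy transform (the class name, parameter, and use of Python's `random` module are all invented for illustration; real transforms would typically use torch's RNG):

```python
import random

class RandomScale:
    """Toy transform: multiply the input by a random factor in [1 - s, 1 + s]."""

    def __init__(self, s: float = 0.1):
        self.s = s

    def __call__(self, x):
        # Random parameter drawn per call -- fine, as long as it is seedable.
        return x * random.uniform(1 - self.s, 1 + self.s)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(s={self.s})"

t = RandomScale(s=0.2)
print(t)  # RandomScale(s=0.2) -- __repr__ makes printed pipelines readable

random.seed(42)
a = t(1.0)
random.seed(42)
b = t(1.0)
print(a == b)  # True -- the same seed reproduces the same random draw
```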

Training vs Validation Transforms

A critical distinction: training transforms include randomness for augmentation, while validation transforms must be deterministic for reproducible evaluation.

Training vs Validation Transforms
🐍train_val_transforms.py
  • Training: RandomResizedCrop crops a random region (8-100% of the image) and resizes it to 224×224. Each epoch sees different crops of the same image, providing data augmentation.
  • Validation: Resize + CenterCrop is deterministic preprocessing: resize so the shortest edge is 256 px, then crop the center 224×224 region. Every evaluation of the same image produces identical input.
  • Test-Time Augmentation: a technique to improve test accuracy: run the model on multiple augmented versions of the same image and average the predictions. Commonly cited as worth around 1-2% accuracy on image benchmarks, at the cost of extra inference time.
38 lines without explanation
import torch
from torchvision import transforms
from torchvision.datasets import ImageFolder

# ImageNet standard preprocessing

train_transform = transforms.Compose([
    # Augmentation: random scale, crop, flip
    transforms.RandomResizedCrop(224, scale=(0.08, 1.0)),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(0.4, 0.4, 0.4, 0.1),
    transforms.RandomGrayscale(p=0.1),

    # Conversion and normalization
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

val_transform = transforms.Compose([
    # Deterministic: resize to slightly larger, center crop
    transforms.Resize(256),
    transforms.CenterCrop(224),

    # Same conversion and normalization
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

# Use different transforms for train and val datasets
train_dataset = ImageFolder('train/', transform=train_transform)
val_dataset = ImageFolder('val/', transform=val_transform)

# Test-Time Augmentation (TTA): use train transforms at test time
# Average predictions over multiple augmented versions
def predict_with_tta(model, image, n_augments=5):
    predictions = []
    with torch.no_grad():
        for _ in range(n_augments):
            augmented = train_transform(image)
            pred = model(augmented.unsqueeze(0))
            predictions.append(pred)
    return torch.stack(predictions).mean(dim=0)

Never Use Random Transforms During Validation

Random augmentations during validation produce inconsistent metrics. Your validation loss will vary between runs even with the same model, making it impossible to compare experiments or detect overfitting.

Quick Check

Why does the validation transform use Resize(256) followed by CenterCrop(224) instead of just Resize(224)?


Performance Optimization

Transforms can become a bottleneck. Here are optimization strategies:

1. GPU Acceleration

🐍gpu_transforms.py
import torch
import kornia  # GPU-accelerated vision operations

# kornia augmentations operate on batched float tensors (B, C, H, W);
# mean and std must be tensors, e.g. torch.tensor([0.485, 0.456, 0.406])
gpu_transform = torch.nn.Sequential(
    kornia.augmentation.RandomResizedCrop((224, 224), scale=(0.08, 1.0)),
    kornia.augmentation.RandomHorizontalFlip(p=0.5),
    kornia.augmentation.ColorJitter(0.4, 0.4, 0.4, 0.1),
    kornia.augmentation.Normalize(mean, std)
)

# Apply to a whole batch on GPU
images = images.to('cuda')
augmented = gpu_transform(images)

# Faster than CPU transforms for large batches

2. Prefetching and Caching

🐍prefetch_cache.py
from torch.utils.data import Dataset

# For datasets that are read repeatedly, cache transformed samples
class CachedDataset(Dataset):
    def __init__(self, dataset, cache_transforms=True):
        self.dataset = dataset
        self.cache = {} if cache_transforms else None

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        if self.cache is not None and idx in self.cache:
            return self.cache[idx]

        sample = self.dataset[idx]

        if self.cache is not None:
            self.cache[idx] = sample

        return sample

# Note: Only cache deterministic transforms (validation)
# Never cache random augmentations!

3. Optimize Transform Choice

| Transform | Relative Speed | Notes |
|---|---|---|
| Resize (LANCZOS) | 1.0× (baseline) | Highest-quality filter, slowest |
| Resize (bilinear) | ~2.5× faster | Good quality, much faster |
| RandomCrop | ~5× faster than RandomResizedCrop | No resize step |
| ColorJitter | Slow (4 operations) | Consider limiting parameters |
| ToTensor | Fast | Just memory copy + scale |
| Normalize | Fast | Element-wise operations |

4. Profiling Transform Pipelines

🐍profile_transforms.py
import time
from PIL import Image

def profile_transforms(transform_pipeline, image, n_runs=100):
    """Profile transform execution time."""
    # Warmup
    for _ in range(10):
        _ = transform_pipeline(image.copy())

    # Measure
    times = []
    for _ in range(n_runs):
        start = time.perf_counter()
        _ = transform_pipeline(image.copy())
        times.append(time.perf_counter() - start)

    mean_time = sum(times) / len(times)
    print(f"Mean transform time: {mean_time*1000:.2f}ms")
    print(f"Throughput: {1/mean_time:.1f} images/sec")

# Profile your pipeline
image = Image.open("test.jpg")
profile_transforms(train_transform, image)

Common Pitfalls

1. Wrong Normalization Statistics

🐍wrong_stats.py
# ❌ WRONG: Using ImageNet stats for non-ImageNet data
# If your images have different statistics, normalization is suboptimal
normalize = transforms.Normalize(
    mean=[0.485, 0.456, 0.406],  # ImageNet means
    std=[0.229, 0.224, 0.225]    # ImageNet stds
)

# ✅ CORRECT: Compute your dataset's statistics
mean, std = compute_channel_stats(your_dataloader)
normalize = transforms.Normalize(mean=mean.tolist(), std=std.tolist())

# Exception: When using pretrained ImageNet models, use ImageNet stats
# The model expects inputs normalized with those exact values
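The `compute_channel_stats` helper above is not part of torchvision; a minimal sketch of one possible implementation, assuming the loader yields `(images, labels)` batches of float tensors shaped `(B, C, H, W)`:

```python
import torch

def compute_channel_stats(loader):
    """Accumulate per-channel mean and std over a DataLoader without
    holding the whole dataset in memory."""
    n_pixels = 0
    channel_sum = None
    channel_sq_sum = None

    for images, _ in loader:
        b, c, h, w = images.shape
        if channel_sum is None:
            channel_sum = torch.zeros(c)
            channel_sq_sum = torch.zeros(c)
        # Sum over batch and spatial dims, keeping the channel dim
        channel_sum += images.sum(dim=[0, 2, 3])
        channel_sq_sum += (images ** 2).sum(dim=[0, 2, 3])
        n_pixels += b * h * w

    mean = channel_sum / n_pixels
    # Var = E[x^2] - E[x]^2; clamp guards against tiny negative float error
    std = (channel_sq_sum / n_pixels - mean ** 2).clamp(min=0).sqrt()
    return mean, std
```

This computes the population statistics in one pass; for very large datasets, sampling a subset of batches usually gives estimates that are close enough for normalization.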

2. Forgetting to Disable Augmentation for Validation

🐍val_augment_mistake.py
# ❌ WRONG: Same transform for train and val
transform = transforms.Compose([
    transforms.RandomResizedCrop(224),  # Random!
    transforms.RandomHorizontalFlip(),  # Random!
    transforms.ToTensor(),
    transforms.Normalize(mean, std)
])
train_dataset = ImageFolder('train/', transform=transform)
val_dataset = ImageFolder('val/', transform=transform)  # Bug: random transforms!

# ✅ CORRECT: Separate pipelines
train_transform = transforms.Compose([...])  # random augmentations
val_transform = transforms.Compose([...])    # deterministic only

3. Augmentation That Destroys Labels

🐍label_destroying.py
# ❌ DANGEROUS: Flipping digit images
# A mirrored "3" resembles "E"; an upside-down "6" reads as "9"
transforms.RandomHorizontalFlip()  # Don't use for digit recognition!
transforms.RandomVerticalFlip()

# ❌ DANGEROUS: Heavy rotation for text/documents
transforms.RandomRotation(90)  # Text becomes unreadable

# ✅ SAFE: Know your domain
# For digits: slight rotation, scale, translation
# For text: small angles, no flip
# For faces: horizontal flip OK (usually), vertical flip BAD

4. PIL vs Tensor Transform Confusion

🐍pil_tensor_confusion.py
# ❌ WRONG: Normalize on PIL Image
transforms.Compose([
    transforms.Normalize(mean, std),  # ERROR: expects tensor!
    transforms.ToTensor()
])

# ❌ WRONG: PIL operations after ToTensor
transforms.Compose([
    transforms.ToTensor(),
    transforms.RandomCrop(224)  # ERROR: expects PIL Image!
])

# ✅ CORRECT: PIL ops first, then ToTensor, then tensor ops
transforms.Compose([
    transforms.RandomCrop(224),      # PIL operation
    transforms.ToTensor(),           # Convert to tensor
    transforms.Normalize(mean, std)  # Tensor operation
])

5. Not Applying Same Transform to Label

🐍unsync_transforms.py
import torchvision.transforms.functional as F

# ❌ WRONG: For segmentation, transform image but not mask
class BrokenSegmentationDataset(Dataset):
    def __getitem__(self, idx):
        image = Image.open(self.images[idx])
        mask = Image.open(self.masks[idx])

        image = self.transform(image)  # Random crop at position A
        mask = transforms.ToTensor()(mask)  # No crop! Misaligned!

        return image, mask

# ✅ CORRECT: Synchronized transforms
class CorrectSegmentationDataset(Dataset):
    def __getitem__(self, idx):
        image = Image.open(self.images[idx])
        mask = Image.open(self.masks[idx])

        # Sample random crop parameters once...
        i, j, h, w = transforms.RandomCrop.get_params(image, (256, 256))

        # ...and apply the same crop to both
        image = F.crop(image, i, j, h, w)
        mask = F.crop(mask, i, j, h, w)

        # (torchvision's transforms.v2 can synchronize these automatically)
        return image, mask

Summary

Data transforms are the preprocessing layer that converts raw data into model-ready tensors. They solve critical problems: inconsistent sizes, variable scales, limited data, and format mismatches.

Key Concepts

| Concept | Purpose | Key Insight |
|---|---|---|
| Normalization | Scale inputs to standard range | Zero mean, unit variance → stable gradients |
| Geometric Transforms | Change spatial arrangement | Affine matrices compose efficiently |
| Data Augmentation | Expand training set virtually | Encodes invariances we want the model to learn |
| Compose | Chain transforms into pipelines | Order matters: PIL ops → ToTensor → Normalize |
| Train vs Val | Different transforms for each | Random for training, deterministic for validation |

Mathematical Foundations

The key equations we covered:

  • Z-score normalization: $z = (x - \mu) / \sigma$ transforms to zero mean, unit variance
  • Per-channel normalization: Each RGB channel normalized independently
  • Affine transforms: Rotation, scaling, translation as matrix operations
  • Augmentation as regularization: $\mathcal{L}_{\text{aug}} = \mathbb{E}_{T}[\mathcal{L}(f(T(x)), y)]$

Best Practices

  1. Always normalize inputs—z-score normalization is standard for neural networks
  2. Use ImageNet statistics when working with pretrained models
  3. Match augmentation to domain—horizontal flip for objects, not for text
  4. Separate train and val transforms—random only during training
  5. Synchronize transforms for detection/segmentation—same crop for image and labels
  6. Profile your pipeline—transforms can become the bottleneck

Looking Ahead

In the next section, we'll explore Custom Datasets—building data loaders for your own data formats, from CSV files to complex multi-modal datasets.


Exercises

Conceptual Questions

  1. Explain why z-score normalization is preferred over min-max scaling for neural network inputs. What problems can occur with min-max scaling?
  2. Why do we use per-channel normalization for RGB images rather than normalizing all pixels together? How would the behavior differ?
  3. A colleague proposes using heavy data augmentation during validation to "test the model's robustness." Explain why this is problematic and what the correct approach would be.
  4. For a medical imaging task (classifying X-rays), which augmentations would be appropriate and which would be harmful? Consider rotations, flips, color changes, and scale.

Coding Exercises

  1. Dataset Statistics: Write a function that computes per-channel mean and standard deviation for a custom image dataset. Test it on CIFAR-10 and verify your values match the known statistics.
  2. Custom Augmentation: Implement a RandomGridDistortion transform that warps an image using a random displacement grid. This is useful for simulating lens distortion or training OCR models.
  3. Synchronized Transforms: Create a dataset class for semantic segmentation that applies the same random transforms (crop, flip, rotation) to both the image and its corresponding segmentation mask.

Solution Hints

  • Statistics: Iterate through DataLoader, accumulate running mean/std, don't load all data into memory
  • Grid Distortion: Use F.grid_sample with a displacement field added to the identity grid
  • Synchronized: Use transforms.RandomCrop.get_params() to get random values, then apply to both image and mask

Challenge Exercise

AutoAugment from Scratch: Implement a simplified version of AutoAugment:

  • Define a pool of 15 transforms (rotate, shear, color, etc.)
  • Each transform has a magnitude parameter (0 to 10)
  • Randomly select 2 transforms and random magnitudes for each image
  • Apply them sequentially
  • Compare training curves with and without your augmentation on CIFAR-10

This exercise teaches you to think about augmentation as a search space and understand how modern auto-augmentation methods work.
