Chapter 10

Data Preprocessing

Dataset Preparation

Learning Objectives

By the end of this section, you will be able to:

  1. Compare normalization strategies ([-1, 1] vs [0, 1]) and understand why [-1, 1] is standard for diffusion
  2. Apply appropriate data augmentation techniques that improve model generalization without corrupting the learning signal
  3. Handle different image resolutions and aspect ratios consistently across your dataset
  4. Implement a complete, production-ready preprocessing pipeline in PyTorch

The Big Picture

Data preprocessing is often underestimated, but it directly impacts training stability and generation quality. A seemingly minor choice like normalizing to [0, 1] instead of [-1, 1] can cause training to fail or produce washed-out images. Understanding the "why" behind preprocessing choices helps you debug issues and adapt to new datasets.

The Core Principle: Preprocessing should make the model's job easier. For diffusion models, this means centering data at zero (matching the Gaussian noise distribution) and ensuring consistent scales across channels.

In this section, we'll build a principled preprocessing pipeline, explaining each choice and its impact on training.


Normalization Strategies

Why [-1, 1] is the Standard

Raw image pixels are typically in [0, 255]. We need to normalize them, but the choice of target range matters significantly for diffusion models:

Range | Formula | Pros | Cons
--- | --- | --- | ---
[0, 1] | x / 255 | Simple, intuitive | Mean = 0.5, not centered at 0
[-1, 1] | (x / 127.5) - 1 | Centered at 0; standard for diffusion | Needs denormalizing for display
ImageNet | (x - μ) / σ | Matches pretrained feature extractors | Dataset-specific; asymmetric range

The key reason [-1, 1] works best for diffusion models is the Gaussian noise process:

  1. The noise $\epsilon \sim \mathcal{N}(0, 1)$ is centered at 0
  2. The final noisy state $\mathbf{x}_T$ should be approximately $\mathcal{N}(0, 1)$
  3. Data centered at 0 mixes naturally with zero-centered noise
🐍python
import torch
from torchvision import transforms


def normalize_to_minus_one_one(x: torch.Tensor) -> torch.Tensor:
    """
    Normalize tensor from [0, 1] to [-1, 1].

    ToTensor() already converts [0, 255] -> [0, 1], so we just shift:
    y = 2x - 1
    """
    return x * 2.0 - 1.0


def denormalize_from_minus_one_one(x: torch.Tensor) -> torch.Tensor:
    """
    Reverse normalization: [-1, 1] -> [0, 1].

    For display/saving: x_display = (x + 1) / 2
    Then scale to [0, 255] if needed for saving.
    """
    return (x + 1.0) / 2.0


# Standard approach using transforms.Normalize
# Normalize(mean, std) computes: (x - mean) / std
# For [0,1] -> [-1,1]: mean=0.5, std=0.5
# Because: (x - 0.5) / 0.5 = 2x - 1

standard_normalize = transforms.Normalize(
    mean=(0.5, 0.5, 0.5),  # RGB means
    std=(0.5, 0.5, 0.5),   # RGB stds
)

# Verify the math
test_tensor = torch.tensor([0.0, 0.5, 1.0])  # [0,1] range
normalized = (test_tensor - 0.5) / 0.5       # expect [-1, 0, 1]
print(f"Input: {test_tensor.tolist()}")
print(f"Output: {normalized.tolist()}")  # [-1.0, 0.0, 1.0]
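To see why centering matters, we can simulate the forward process directly. The sketch below assumes a DDPM-style forward step $\mathbf{x}_t = \sqrt{\bar\alpha_t}\,\mathbf{x}_0 + \sqrt{1 - \bar\alpha_t}\,\epsilon$ with an illustrative late-timestep value $\bar\alpha_t = 0.01$, and compares the mean of the noisy state for [0, 1] vs [-1, 1] data:

```python
import torch

# Simulate the DDPM forward step x_t = sqrt(abar)*x0 + sqrt(1-abar)*eps
# at a late timestep (abar small, nearly pure noise) for data normalized
# two different ways.
torch.manual_seed(0)
abar = 0.01  # illustrative cumulative alpha at a late timestep

x0_unit = torch.rand(10000)       # data in [0, 1], mean ~ 0.5
x0_sym = x0_unit * 2.0 - 1.0      # same data in [-1, 1], mean ~ 0.0
eps = torch.randn(10000)          # zero-centered Gaussian noise

xt_unit = abar**0.5 * x0_unit + (1 - abar)**0.5 * eps
xt_sym = abar**0.5 * x0_sym + (1 - abar)**0.5 * eps

print(f"[0,1] data:  mean(x_T) = {xt_unit.mean():.3f}")  # offset from 0
print(f"[-1,1] data: mean(x_T) = {xt_sym.mean():.3f}")   # close to 0
```

With [0, 1] data, the noisy state inherits a systematic offset (here roughly $\sqrt{\bar\alpha_t} \cdot 0.5 = 0.05$), so $\mathbf{x}_T$ never quite matches the $\mathcal{N}(0, 1)$ prior the sampler starts from; zero-centered data avoids this mismatch.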

Handling Grayscale Images

For single-channel images like MNIST, use a single mean and std value:

🐍python
# Grayscale normalization
grayscale_normalize = transforms.Normalize(
    mean=(0.5,),
    std=(0.5,),
)

# Complete transform for MNIST
mnist_transform = transforms.Compose([
    transforms.ToTensor(),           # [0, 255] -> [0, 1], HWC -> CHW
    transforms.Normalize((0.5,), (0.5,)),  # [0, 1] -> [-1, 1]
])

# For RGB images
rgb_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

Common Bug

If you accidentally use RGB normalization on grayscale images (or vice versa), you'll get a shape mismatch error. Always match your normalization to your image format!

ImageNet Normalization: When to Use It

Some practitioners use ImageNet statistics for normalization:

🐍python
# ImageNet normalization (per-channel statistics)
imagenet_normalize = transforms.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225],
)

# This produces values roughly in [-2.5, 2.5] for most images
# Pros: Matches pretrained feature extractors (useful for FID calculation)
# Cons: Asymmetric range, not centered at 0

Our recommendation: Use [-1, 1] normalization for training, but apply ImageNet normalization when computing FID (since the Inception network was trained on ImageNet-normalized data).
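A sketch of that workflow: samples generated in [-1, 1] can be re-mapped to ImageNet statistics before the Inception pass. The helper name `to_imagenet_norm` is illustrative, not a library function:

```python
import torch

# ImageNet per-channel statistics, shaped for (N, C, H, W) broadcasting
IMAGENET_MEAN = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
IMAGENET_STD = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)

def to_imagenet_norm(x: torch.Tensor) -> torch.Tensor:
    """Re-normalize a [-1, 1] batch (N, 3, H, W) to ImageNet statistics."""
    x01 = (x + 1.0) / 2.0          # [-1, 1] -> [0, 1]
    x01 = x01.clamp(0.0, 1.0)      # guard against slight sampler overshoot
    return (x01 - IMAGENET_MEAN) / IMAGENET_STD

batch = torch.rand(4, 3, 64, 64) * 2.0 - 1.0  # fake samples in [-1, 1]
print(to_imagenet_norm(batch).shape)          # same (N, C, H, W) shape
```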


Data Augmentation

Data augmentation increases effective dataset size and improves generalization. However, not all augmentations are appropriate for diffusion models.

Safe Augmentations

These augmentations preserve semantic content and are widely used:

🐍python
from torchvision import transforms

# Safe augmentations for most image datasets
safe_augmentations = transforms.Compose([
    # Horizontal flip: Almost always safe for natural images
    # (Exception: text, directional content)
    transforms.RandomHorizontalFlip(p=0.5),

    # Random crop: Creates variety while keeping content
    transforms.RandomResizedCrop(
        size=64,
        scale=(0.8, 1.0),      # Use 80-100% of image
        ratio=(0.9, 1.1),      # Keep aspect ratio close to original
    ),

    # Color jitter: Subtle changes are generally safe
    transforms.ColorJitter(
        brightness=0.1,
        contrast=0.1,
        saturation=0.1,
        hue=0.02,              # Hue changes should be minimal
    ),
])

# Complete training transform
training_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

Augmentations to Avoid

Some augmentations can hurt diffusion model training:

Augmentation | Problem | When It Might Be OK
--- | --- | ---
Vertical flip | Unnatural for most objects | Aerial/satellite imagery
Rotation (large) | Creates border artifacts | Medical imaging, symmetric objects
Heavy color jitter | Produces unrealistic colors | Style transfer datasets
Cutout/Erasing | Model learns to generate blanks | Never for unconditional generation
Mixup/CutMix | Blends images unnaturally | Classification only
🐍python
# DON'T do this for diffusion models:
bad_augmentations = transforms.Compose([
    transforms.RandomVerticalFlip(p=0.5),  # Unnatural
    transforms.RandomRotation(45),         # Border artifacts
    transforms.ColorJitter(hue=0.5),       # Unrealistic colors
    transforms.ToTensor(),                 # (RandomErasing needs a tensor)
    transforms.RandomErasing(p=0.5),       # Model learns to generate holes
])

# These can corrupt your training signal!

Dataset-Specific Considerations

🐍python
def get_augmentation_for_dataset(dataset_name: str) -> transforms.Compose:
    """
    Return appropriate augmentations for each dataset type.
    """
    if dataset_name == "mnist":
        # MNIST: No flip (handwriting has orientation)
        return transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5,), (0.5,)),
        ])

    elif dataset_name == "cifar10":
        # CIFAR-10: Horizontal flip is standard
        return transforms.Compose([
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])

    elif dataset_name == "celeba":
        # CelebA: Flip is OK for faces
        return transforms.Compose([
            transforms.CenterCrop(178),
            transforms.Resize(64),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])

    elif dataset_name == "imagenet":
        # ImageNet: Random crop + flip is standard
        return transforms.Compose([
            transforms.RandomResizedCrop(256, scale=(0.8, 1.0)),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])

    else:
        # Default: horizontal flip only
        return transforms.Compose([
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])

Resolution Handling

Diffusion models typically require fixed-size inputs. When your dataset has variable-size images, you need a strategy to standardize them:

Center Crop

Center cropping is the simplest approach. It preserves the central region, which often contains the main subject:

🐍python
def center_crop_and_resize(
    target_size: int,
) -> transforms.Compose:
    """
    First resize so smaller edge = target, then center crop.

    This preserves aspect ratio until the final crop.
    """
    return transforms.Compose([
        # Resize smaller edge to target_size
        transforms.Resize(target_size),
        # Center crop to square
        transforms.CenterCrop(target_size),
    ])


# Example: 1920x1080 image -> 256x256
# Step 1: Resize to 455x256 (smaller edge = 256)
# Step 2: Center crop to 256x256

transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(256),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

Random Crop for Training

Random cropping provides data augmentation and uses more of each image:

🐍python
def random_crop_transform(target_size: int) -> transforms.Compose:
    """
    Random crop for training: provides data augmentation.
    """
    return transforms.Compose([
        # Resize so smaller edge is slightly larger than target
        transforms.Resize(int(target_size * 1.1)),
        # Random crop to exact target size
        transforms.RandomCrop(target_size),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])


# For validation/testing, use center crop (deterministic)
def validation_transform(target_size: int) -> transforms.Compose:
    return transforms.Compose([
        transforms.Resize(int(target_size * 1.1)),
        transforms.CenterCrop(target_size),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

Handling Extreme Aspect Ratios

Some images have extreme aspect ratios that require special handling:

🐍python
from PIL import Image
from torchvision import transforms

def smart_resize(
    image: Image.Image,
    target_size: int,
    max_aspect_ratio: float = 2.0,
) -> Image.Image:
    """
    Resize image while handling extreme aspect ratios.

    If aspect ratio is too extreme, we pad to make it more square
    before resizing, avoiding severe distortion.
    """
    w, h = image.size
    aspect = max(w / h, h / w)

    if aspect > max_aspect_ratio:
        # Image is too elongated, pad to make more square
        if w > h:
            new_h = int(w / max_aspect_ratio)
            padding = (new_h - h) // 2
            padded = Image.new('RGB', (w, new_h), (0, 0, 0))
            padded.paste(image, (0, padding))
            image = padded
        else:
            new_w = int(h / max_aspect_ratio)
            padding = (new_w - w) // 2
            padded = Image.new('RGB', (new_w, h), (0, 0, 0))
            padded.paste(image, (padding, 0))
            image = padded

    # Now resize and center crop
    return transforms.Compose([
        transforms.Resize(target_size),
        transforms.CenterCrop(target_size),
    ])(image)

Color Space Considerations

RGB vs Other Color Spaces

While most diffusion models work in RGB, some researchers have explored other color spaces:

🐍python
import torch

def rgb_to_lab(image: torch.Tensor) -> torch.Tensor:
    """
    Convert RGB tensor to LAB color space.

    LAB separates luminance (L) from color (A, B), which can
    sometimes improve perceptual quality.

    Note: This is illustrative; for production, use kornia or skimage.
    """
    # Normalize RGB from [0, 1] to proper range
    # ... complex conversion ...
    pass


# In practice, stick with RGB unless you have a specific reason:
# - RGB is simpler and well-studied for diffusion
# - Pretrained models (Inception for FID) expect RGB
# - LAB conversion adds complexity and potential bugs
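If you do want to experiment with LAB, the conversion is easiest to delegate to scikit-image rather than hand-rolling it. A minimal sketch (assuming `scikit-image` is installed; the function name is illustrative):

```python
import torch
from skimage import color

def rgb_to_lab_skimage(x: torch.Tensor) -> torch.Tensor:
    """
    Convert an RGB tensor (C, H, W) in [0, 1] to LAB via scikit-image.

    Returned channels: L in [0, 100]; a/b are signed color axes.
    """
    hwc = x.permute(1, 2, 0).numpy()  # CHW -> HWC, as skimage expects
    lab = color.rgb2lab(hwc)          # float64 HWC LAB array
    return torch.from_numpy(lab).permute(2, 0, 1).float()

white = torch.ones(3, 2, 2)           # pure white image
lab = rgb_to_lab_skimage(white)
print(f"L channel of white: {lab[0, 0, 0]:.1f}")  # white has L = 100
```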

Handling Different Input Formats

🐍python
from PIL import Image
import torch
from torchvision import transforms

def load_and_preprocess(
    image_path: str,
    target_size: int = 256,
    target_channels: int = 3,
) -> torch.Tensor:
    """
    Load image from disk and preprocess for diffusion model.

    Handles various input formats: RGB, RGBA, grayscale, palette.
    """
    image = Image.open(image_path)

    # Handle different image modes
    if target_channels == 3:
        if image.mode == 'RGBA':
            # Blend alpha channel with white background
            background = Image.new('RGB', image.size, (255, 255, 255))
            background.paste(image, mask=image.split()[3])
            image = background
        elif image.mode == 'P':
            # Palette mode (GIF, some PNGs)
            image = image.convert('RGB')
        elif image.mode == 'L':
            # Grayscale -> RGB (replicate channels)
            image = image.convert('RGB')
        else:
            image = image.convert('RGB')

    elif target_channels == 1:
        image = image.convert('L')

    # Standard transform
    transform = transforms.Compose([
        transforms.Resize(target_size),
        transforms.CenterCrop(target_size),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.5] * target_channels,
            std=[0.5] * target_channels,
        ),
    ])

    return transform(image)


# Example usage
tensor = load_and_preprocess("image.png", target_size=64)
print(f"Output shape: {tensor.shape}")  # [3, 64, 64] or [1, 64, 64]

Complete Preprocessing Pipeline

Here's a complete, production-ready preprocessing pipeline that handles all the edge cases:

🐍python
import random
from pathlib import Path
from typing import List

import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image


class DiffusionImageDataset(Dataset):
    """
    Production-ready image dataset for diffusion model training.

    Features:
    - Handles various image formats (JPEG, PNG, WebP, etc.)
    - Proper RGB conversion from any input mode
    - Configurable resolution and augmentation
    - Reproducible validation split
    """

    VALID_EXTENSIONS = {'.jpg', '.jpeg', '.png', '.webp', '.bmp', '.tiff'}

    def __init__(
        self,
        root: str,
        image_size: int = 256,
        split: str = "train",  # "train" or "val"
        val_fraction: float = 0.1,
        seed: int = 42,
        augment: bool = True,
    ):
        """
        Args:
            root: Directory containing images
            image_size: Target image size (square)
            split: "train" or "val"
            val_fraction: Fraction of data for validation
            seed: Random seed for reproducible split
            augment: Whether to apply augmentation (only for train)
        """
        self.root = Path(root)
        self.image_size = image_size
        self.split = split

        # Find all valid images
        self.image_paths = self._find_images()

        # Create reproducible train/val split
        self._create_split(val_fraction, seed)

        # Build transforms
        self.transform = self._build_transform(augment and split == "train")

        print(f"Loaded {len(self)} images for {split} split")

    def _find_images(self) -> List[Path]:
        """Find all valid image files recursively."""
        paths = []
        for ext in self.VALID_EXTENSIONS:
            paths.extend(self.root.rglob(f"*{ext}"))
            paths.extend(self.root.rglob(f"*{ext.upper()}"))
        return sorted(paths)

    def _create_split(self, val_fraction: float, seed: int):
        """Create reproducible train/val split."""
        rng = random.Random(seed)

        indices = list(range(len(self.image_paths)))
        rng.shuffle(indices)

        val_size = int(len(indices) * val_fraction)

        if self.split == "val":
            indices = indices[:val_size]
        else:
            indices = indices[val_size:]

        self.image_paths = [self.image_paths[i] for i in sorted(indices)]

    def _build_transform(self, augment: bool) -> transforms.Compose:
        """Build preprocessing transform pipeline."""
        transform_list = []

        if augment:
            transform_list.extend([
                # Slight resize for random crop
                transforms.Resize(int(self.image_size * 1.1)),
                transforms.RandomCrop(self.image_size),
                transforms.RandomHorizontalFlip(p=0.5),
            ])
        else:
            transform_list.extend([
                transforms.Resize(int(self.image_size * 1.1)),
                transforms.CenterCrop(self.image_size),
            ])

        transform_list.extend([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])

        return transforms.Compose(transform_list)

    def _load_image(self, path: Path) -> Image.Image:
        """Load image and convert to RGB."""
        image = Image.open(path)

        # Handle all image modes
        if image.mode == 'RGBA':
            # Alpha blend with white background
            background = Image.new('RGB', image.size, (255, 255, 255))
            background.paste(image, mask=image.split()[3])
            return background
        elif image.mode != 'RGB':
            return image.convert('RGB')
        return image

    def __len__(self) -> int:
        return len(self.image_paths)

    def __getitem__(self, idx: int) -> torch.Tensor:
        image = self._load_image(self.image_paths[idx])
        return self.transform(image)


def create_dataloader(
    root: str,
    image_size: int = 256,
    batch_size: int = 32,
    split: str = "train",
    num_workers: int = 4,
) -> DataLoader:
    """
    Create a DataLoader for training or validation.

    Args:
        root: Directory containing images
        image_size: Target image size
        batch_size: Batch size
        split: "train" or "val"
        num_workers: Number of data loading workers

    Returns:
        PyTorch DataLoader
    """
    dataset = DiffusionImageDataset(
        root=root,
        image_size=image_size,
        split=split,
        augment=(split == "train"),
    )

    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=(split == "train"),
        num_workers=num_workers,
        pin_memory=True,  # Faster GPU transfer
        drop_last=(split == "train"),  # Avoid small final batch
        persistent_workers=(num_workers > 0),  # Keep workers alive
    )


# Usage example
if __name__ == "__main__":
    train_loader = create_dataloader(
        root="./my_images",
        image_size=64,
        batch_size=32,
        split="train",
    )

    # Get a batch
    batch = next(iter(train_loader))
    print(f"Batch shape: {batch.shape}")  # [32, 3, 64, 64]
    print(f"Value range: [{batch.min():.2f}, {batch.max():.2f}]")  # [-1, 1]

Performance Tip

Use `pin_memory=True` and multiple workers to maximize GPU utilization. The DataLoader will prefetch batches while the GPU processes the current batch.

Key Takeaways

  1. Use [-1, 1] normalization: This centers data at 0, matching the Gaussian noise distribution and improving training stability.
  2. Keep augmentation simple: Horizontal flip is almost always safe. Avoid aggressive augmentations that create artifacts or unrealistic images.
  3. Handle variable resolutions consistently: Use resize + center crop for validation, resize + random crop for training augmentation.
  4. Always convert to RGB: Handle RGBA (alpha blending), grayscale, and palette images properly to avoid channel mismatches.
  5. Verify your pipeline: After preprocessing, images should have values in [-1, 1] with shape [C, H, W]. Print statistics to confirm.
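Takeaway 5 can be turned into a small helper you run once after building a pipeline. A sketch (the function name is illustrative):

```python
import torch

def sanity_check_batch(batch: torch.Tensor) -> None:
    """Verify a preprocessed batch before training starts."""
    assert batch.ndim == 4, f"expected (N, C, H, W), got {tuple(batch.shape)}"
    assert batch.min() >= -1.0 - 1e-6, f"min {batch.min():.3f} is below -1"
    assert batch.max() <= 1.0 + 1e-6, f"max {batch.max():.3f} is above 1"
    # Mean should sit near 0 for [-1, 1]-normalized natural images
    print(
        f"shape={tuple(batch.shape)} "
        f"range=[{batch.min():.2f}, {batch.max():.2f}] "
        f"mean={batch.mean():.3f}"
    )

# Example with a fake [-1, 1] batch; raw [0, 255] data would fail loudly
sanity_check_batch(torch.rand(8, 3, 64, 64) * 2.0 - 1.0)
```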
Looking Ahead: In the next section, we'll build the complete training pipeline, including DataLoader configuration, distributed data loading, and efficient batching strategies.