Learning Objectives
By the end of this section, you will be able to:
- Compare normalization strategies ([-1, 1] vs [0, 1]) and understand why [-1, 1] is standard for diffusion
- Apply appropriate data augmentation techniques that improve model generalization without corrupting the learning signal
- Handle different image resolutions and aspect ratios consistently across your dataset
- Implement a complete, production-ready preprocessing pipeline in PyTorch
The Big Picture
Data preprocessing is often underestimated, but it directly impacts training stability and generation quality. A seemingly minor choice like normalizing to [0, 1] instead of [-1, 1] can cause training to fail or produce washed-out images. Understanding the "why" behind preprocessing choices helps you debug issues and adapt to new datasets.
The Core Principle: Preprocessing should make the model's job easier. For diffusion models, this means centering data at zero (matching the Gaussian noise distribution) and ensuring consistent scales across channels.
In this section, we'll build a principled preprocessing pipeline, explaining each choice and its impact on training.
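To make the centering point concrete, here is a standalone sketch (not part of the pipeline built below) comparing the mean of [0, 1]- and [-1, 1]-normalized data against zero-mean Gaussian noise:

```python
import torch

torch.manual_seed(0)
x01 = torch.rand(10_000)     # pixels as ToTensor() yields them: [0, 1]
x11 = x01 * 2.0 - 1.0        # the same values shifted to [-1, 1]
noise = torch.randn(10_000)  # the Gaussian noise diffusion adds

print(f"[0, 1] mean:  {x01.mean():.2f}")    # ~0.50, offset from the noise
print(f"[-1, 1] mean: {x11.mean():.2f}")    # ~0.00, matches the noise
print(f"noise mean:   {noise.mean():.2f}")  # ~0.00
```

The constant 0.5 offset of [0, 1] data is exactly the kind of systematic bias the model would otherwise have to learn away.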
Normalization Strategies
Why [-1, 1] is the Standard
Raw image pixels are typically in [0, 255]. We need to normalize them, but the choice of target range matters significantly for diffusion models:
| Range | Formula | Pros | Cons |
|---|---|---|---|
| [0, 1] | x / 255 | Simple, intuitive | Mean ≈ 0.5, not centered at 0 |
| [-1, 1] | (x / 127.5) - 1 | Centered at 0; standard for diffusion | Needs denormalization for display |
| ImageNet | (x - μ) / σ | Matches pretrained feature extractors | Asymmetric range; statistics vary by dataset |
The key reason [-1, 1] works best for diffusion models is the Gaussian noise process:
- The noise is centered at 0
- The final noisy state should be approximately standard Gaussian, N(0, I)
- Data centered at 0 mixes naturally with zero-centered noise
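These points can be checked numerically. The sketch below uses the standard forward-process mixing formula x_t = √ᾱ_t·x₀ + √(1−ᾱ_t)·ε with an illustrative ᾱ value of 0.5; it is a demonstration, not a schedule used elsewhere in this section:

```python
import torch

torch.manual_seed(0)
alpha_bar = 0.5  # cumulative signal level at an intermediate timestep (illustrative)

x01 = torch.rand(100_000)   # data normalized to [0, 1]
x11 = x01 * 2.0 - 1.0       # the same data shifted to [-1, 1]
eps = torch.randn(100_000)  # Gaussian noise, mean 0

def noisy(x0: torch.Tensor) -> torch.Tensor:
    # Forward-process mixing: sqrt(a_bar) * x0 + sqrt(1 - a_bar) * eps
    return (alpha_bar ** 0.5) * x0 + ((1 - alpha_bar) ** 0.5) * eps

print(f"[0, 1] data  -> noisy mean {noisy(x01).mean():+.3f}")  # ~+0.354, off-center
print(f"[-1, 1] data -> noisy mean {noisy(x11).mean():+.3f}")  # ~0.000, centered
```

With [0, 1] data, every intermediate noisy state inherits a positive offset; with [-1, 1] data, the mixture stays centered at 0 like the noise itself.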
```python
import torch
from torchvision import transforms


def normalize_to_minus_one_one(x: torch.Tensor) -> torch.Tensor:
    """
    Normalize tensor from [0, 1] to [-1, 1].

    ToTensor() already converts [0, 255] -> [0, 1], so we just shift:
        y = 2x - 1
    """
    return x * 2.0 - 1.0


def denormalize_from_minus_one_one(x: torch.Tensor) -> torch.Tensor:
    """
    Reverse normalization: [-1, 1] -> [0, 1].

    For display/saving: x_display = (x + 1) / 2
    Then scale to [0, 255] if needed for saving.
    """
    return (x + 1.0) / 2.0


# Standard approach using transforms.Normalize
# Normalize(mean, std) computes: (x - mean) / std
# For [0, 1] -> [-1, 1]: mean=0.5, std=0.5
# Because: (x - 0.5) / 0.5 = 2x - 1

standard_normalize = transforms.Normalize(
    mean=(0.5, 0.5, 0.5),  # RGB means
    std=(0.5, 0.5, 0.5),   # RGB stds
)

# Verify the math
test_tensor = torch.tensor([0.0, 0.5, 1.0])  # [0, 1] range
normalized = (test_tensor - 0.5) / 0.5
print(f"Input: {test_tensor.tolist()}")
print(f"Output: {normalized.tolist()}")  # [-1.0, 0.0, 1.0]
```

Handling Grayscale Images
For single-channel images like MNIST, use a single mean and std value:
```python
from torchvision import transforms

# Grayscale normalization
grayscale_normalize = transforms.Normalize(
    mean=(0.5,),
    std=(0.5,),
)

# Complete transform for MNIST
mnist_transform = transforms.Compose([
    transforms.ToTensor(),                 # [0, 255] -> [0, 1], HWC -> CHW
    transforms.Normalize((0.5,), (0.5,)),  # [0, 1] -> [-1, 1]
])

# For RGB images
rgb_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
```
ImageNet Normalization: When to Use It
Some practitioners use ImageNet statistics for normalization:
```python
from torchvision import transforms

# ImageNet normalization (per-channel statistics)
imagenet_normalize = transforms.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225],
)

# This produces values roughly in [-2.5, 2.5] for most images
# Pros: Matches pretrained feature extractors (useful for FID calculation)
# Cons: Asymmetric range, not centered at 0
```

Our recommendation: Use [-1, 1] normalization for training, but apply ImageNet normalization when computing FID (since the Inception network was trained on ImageNet-normalized data).
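Following that recommendation, a small helper (a sketch; `to_imagenet_stats` is a name invented here, not a library function) can map [-1, 1] model outputs to ImageNet statistics before they reach an Inception-based FID metric:

```python
import torch

# ImageNet per-channel statistics, shaped for (N, C, H, W) broadcasting
IMAGENET_MEAN = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
IMAGENET_STD = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)

def to_imagenet_stats(x: torch.Tensor) -> torch.Tensor:
    """Map a [-1, 1] batch (N, 3, H, W) to ImageNet-normalized values."""
    x01 = (x + 1.0) / 2.0  # [-1, 1] -> [0, 1] first
    return (x01 - IMAGENET_MEAN) / IMAGENET_STD

batch = torch.rand(4, 3, 64, 64) * 2.0 - 1.0  # fake generated samples in [-1, 1]
out = to_imagenet_stats(batch)
print(out.shape)  # torch.Size([4, 3, 64, 64])
```

Denormalizing to [0, 1] first is the key step; applying ImageNet statistics directly to [-1, 1] data would silently produce out-of-range inputs.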
Data Augmentation
Data augmentation increases effective dataset size and improves generalization. However, not all augmentations are appropriate for diffusion models.
Safe Augmentations
These augmentations preserve semantic content and are widely used:
```python
from torchvision import transforms

# Safe augmentations for most image datasets
safe_augmentations = transforms.Compose([
    # Horizontal flip: almost always safe for natural images
    # (exception: text, directional content)
    transforms.RandomHorizontalFlip(p=0.5),

    # Random crop: creates variety while keeping content
    transforms.RandomResizedCrop(
        size=64,
        scale=(0.8, 1.0),  # use 80-100% of the image
        ratio=(0.9, 1.1),  # keep aspect ratio close to original
    ),

    # Color jitter: subtle changes are generally safe
    transforms.ColorJitter(
        brightness=0.1,
        contrast=0.1,
        saturation=0.1,
        hue=0.02,  # hue changes should be minimal
    ),
])

# Complete training transform
training_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
```

Augmentations to Avoid
Some augmentations can hurt diffusion model training:
| Augmentation | Problem | When It Might Be OK |
|---|---|---|
| Vertical flip | Unnatural for most objects | Aerial/satellite imagery |
| Rotation (large) | Creates border artifacts | Medical imaging, symmetry |
| Heavy color jitter | Produces unrealistic colors | Style transfer datasets |
| Cutout/Erasing | Model learns to generate blanks | Never for unconditional |
| Mixup/CutMix | Blends images unnaturally | Classification only |
```python
from torchvision import transforms

# DON'T do this for diffusion models:
bad_augmentations = transforms.Compose([
    transforms.RandomVerticalFlip(p=0.5),  # Unnatural
    transforms.RandomRotation(45),         # Border artifacts
    transforms.ColorJitter(hue=0.5),       # Unrealistic colors
    transforms.ToTensor(),
    transforms.RandomErasing(p=0.5),       # Model learns to generate holes
])

# These can corrupt your training signal!
```

Dataset-Specific Considerations
```python
from torchvision import transforms


def get_augmentation_for_dataset(dataset_name: str) -> transforms.Compose:
    """
    Return appropriate augmentations for each dataset type.
    """
    if dataset_name == "mnist":
        # MNIST: no flip (handwriting has orientation)
        return transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5,), (0.5,)),
        ])

    elif dataset_name == "cifar10":
        # CIFAR-10: horizontal flip is standard
        return transforms.Compose([
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])

    elif dataset_name == "celeba":
        # CelebA: flip is OK for faces
        return transforms.Compose([
            transforms.CenterCrop(178),
            transforms.Resize(64),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])

    elif dataset_name == "imagenet":
        # ImageNet: random crop + flip is standard
        return transforms.Compose([
            transforms.RandomResizedCrop(256, scale=(0.8, 1.0)),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])

    else:
        # Default: horizontal flip only
        return transforms.Compose([
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])
```

Resolution Handling
Diffusion models typically require fixed-size inputs. When your dataset has variable-size images, you need a strategy to standardize them:
Center Crop
Center cropping is the simplest approach. It preserves the central region, which often contains the main subject:
```python
from torchvision import transforms


def center_crop_and_resize(target_size: int) -> transforms.Compose:
    """
    First resize so the smaller edge equals target_size, then center crop.

    This preserves aspect ratio until the final crop.
    """
    return transforms.Compose([
        transforms.Resize(target_size),      # resize smaller edge to target_size
        transforms.CenterCrop(target_size),  # center crop to square
    ])


# Example: 1920x1080 image -> 256x256
# Step 1: Resize to 455x256 (smaller edge = 256)
# Step 2: Center crop to 256x256

transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(256),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
```

Random Crop for Training
Random cropping provides data augmentation and uses more of each image:
```python
from torchvision import transforms


def random_crop_transform(target_size: int) -> transforms.Compose:
    """
    Random crop for training: provides data augmentation.
    """
    return transforms.Compose([
        # Resize so the smaller edge is slightly larger than the target
        transforms.Resize(int(target_size * 1.1)),
        # Random crop to the exact target size
        transforms.RandomCrop(target_size),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])


# For validation/testing, use center crop (deterministic)
def validation_transform(target_size: int) -> transforms.Compose:
    return transforms.Compose([
        transforms.Resize(int(target_size * 1.1)),
        transforms.CenterCrop(target_size),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
```

Handling Extreme Aspect Ratios
Some images have extreme aspect ratios that require special handling:
```python
from PIL import Image
from torchvision import transforms


def smart_resize(
    image: Image.Image,
    target_size: int,
    max_aspect_ratio: float = 2.0,
) -> Image.Image:
    """
    Resize an image while handling extreme aspect ratios.

    If the aspect ratio is too extreme, we pad to make the image more
    square before resizing, avoiding severe distortion.
    """
    w, h = image.size
    aspect = max(w / h, h / w)

    if aspect > max_aspect_ratio:
        # Image is too elongated; pad to make it more square
        if w > h:
            new_h = int(w / max_aspect_ratio)
            padding = (new_h - h) // 2
            padded = Image.new('RGB', (w, new_h), (0, 0, 0))
            padded.paste(image, (0, padding))
            image = padded
        else:
            new_w = int(h / max_aspect_ratio)
            padding = (new_w - w) // 2
            padded = Image.new('RGB', (new_w, h), (0, 0, 0))
            padded.paste(image, (padding, 0))
            image = padded

    # Now resize and center crop
    return transforms.Compose([
        transforms.Resize(target_size),
        transforms.CenterCrop(target_size),
    ])(image)
```

Color Space Considerations
RGB vs Other Color Spaces
While most diffusion models work in RGB, some researchers have explored other color spaces:
```python
import torch


def rgb_to_lab(image: torch.Tensor) -> torch.Tensor:
    """
    Convert an RGB tensor to the LAB color space.

    LAB separates luminance (L) from color (A, B), which can
    sometimes improve perceptual quality.

    Note: this is illustrative; for production, use kornia or skimage.
    """
    # Normalize RGB from [0, 1] to the proper range
    # ... complex conversion ...
    raise NotImplementedError


# In practice, stick with RGB unless you have a specific reason:
# - RGB is simpler and well-studied for diffusion
# - Pretrained models (Inception for FID) expect RGB
# - LAB conversion adds complexity and potential bugs
```

Handling Different Input Formats
```python
import torch
from PIL import Image
from torchvision import transforms


def load_and_preprocess(
    image_path: str,
    target_size: int = 256,
    target_channels: int = 3,
) -> torch.Tensor:
    """
    Load an image from disk and preprocess it for a diffusion model.

    Handles various input formats: RGB, RGBA, grayscale, palette.
    """
    image = Image.open(image_path)

    # Handle different image modes
    if target_channels == 3:
        if image.mode == 'RGBA':
            # Blend the alpha channel with a white background
            background = Image.new('RGB', image.size, (255, 255, 255))
            background.paste(image, mask=image.split()[3])
            image = background
        elif image.mode == 'P':
            # Palette mode (GIF, some PNGs)
            image = image.convert('RGB')
        elif image.mode == 'L':
            # Grayscale -> RGB (replicate channels)
            image = image.convert('RGB')
        else:
            image = image.convert('RGB')

    elif target_channels == 1:
        image = image.convert('L')

    # Standard transform
    transform = transforms.Compose([
        transforms.Resize(target_size),
        transforms.CenterCrop(target_size),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.5] * target_channels,
            std=[0.5] * target_channels,
        ),
    ])

    return transform(image)


# Example usage
tensor = load_and_preprocess("image.png", target_size=64)
print(f"Output shape: {tensor.shape}")  # [3, 64, 64] or [1, 64, 64]
```

Complete Preprocessing Pipeline
Here's a complete, production-ready preprocessing pipeline that handles all the edge cases:
```python
import random
from pathlib import Path
from typing import List

import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image


class DiffusionImageDataset(Dataset):
    """
    Production-ready image dataset for diffusion model training.

    Features:
    - Handles various image formats (JPEG, PNG, WebP, etc.)
    - Proper RGB conversion from any input mode
    - Configurable resolution and augmentation
    - Reproducible validation split
    """

    VALID_EXTENSIONS = {'.jpg', '.jpeg', '.png', '.webp', '.bmp', '.tiff'}

    def __init__(
        self,
        root: str,
        image_size: int = 256,
        split: str = "train",  # "train" or "val"
        val_fraction: float = 0.1,
        seed: int = 42,
        augment: bool = True,
    ):
        """
        Args:
            root: Directory containing images
            image_size: Target image size (square)
            split: "train" or "val"
            val_fraction: Fraction of data for validation
            seed: Random seed for reproducible split
            augment: Whether to apply augmentation (only for train)
        """
        self.root = Path(root)
        self.image_size = image_size
        self.split = split

        # Find all valid images
        self.image_paths = self._find_images()

        # Create reproducible train/val split
        self._create_split(val_fraction, seed)

        # Build transforms
        self.transform = self._build_transform(augment and split == "train")

        print(f"Loaded {len(self)} images for {split} split")

    def _find_images(self) -> List[Path]:
        """Find all valid image files recursively."""
        paths = []
        for ext in self.VALID_EXTENSIONS:
            paths.extend(self.root.rglob(f"*{ext}"))
            paths.extend(self.root.rglob(f"*{ext.upper()}"))
        # De-duplicate in case the filesystem is case-insensitive
        return sorted(set(paths))

    def _create_split(self, val_fraction: float, seed: int):
        """Create a reproducible train/val split."""
        rng = random.Random(seed)

        indices = list(range(len(self.image_paths)))
        rng.shuffle(indices)

        val_size = int(len(indices) * val_fraction)

        if self.split == "val":
            indices = indices[:val_size]
        else:
            indices = indices[val_size:]

        self.image_paths = [self.image_paths[i] for i in sorted(indices)]

    def _build_transform(self, augment: bool) -> transforms.Compose:
        """Build the preprocessing transform pipeline."""
        transform_list = []

        if augment:
            transform_list.extend([
                # Slight resize for random crop
                transforms.Resize(int(self.image_size * 1.1)),
                transforms.RandomCrop(self.image_size),
                transforms.RandomHorizontalFlip(p=0.5),
            ])
        else:
            transform_list.extend([
                transforms.Resize(int(self.image_size * 1.1)),
                transforms.CenterCrop(self.image_size),
            ])

        transform_list.extend([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])

        return transforms.Compose(transform_list)

    def _load_image(self, path: Path) -> Image.Image:
        """Load an image and convert it to RGB."""
        image = Image.open(path)

        # Handle all image modes
        if image.mode == 'RGBA':
            # Alpha-blend with a white background
            background = Image.new('RGB', image.size, (255, 255, 255))
            background.paste(image, mask=image.split()[3])
            return background
        elif image.mode != 'RGB':
            return image.convert('RGB')
        return image

    def __len__(self) -> int:
        return len(self.image_paths)

    def __getitem__(self, idx: int) -> torch.Tensor:
        image = self._load_image(self.image_paths[idx])
        return self.transform(image)


def create_dataloader(
    root: str,
    image_size: int = 256,
    batch_size: int = 32,
    split: str = "train",
    num_workers: int = 4,
) -> DataLoader:
    """
    Create a DataLoader for training or validation.

    Args:
        root: Directory containing images
        image_size: Target image size
        batch_size: Batch size
        split: "train" or "val"
        num_workers: Number of data loading workers

    Returns:
        PyTorch DataLoader
    """
    dataset = DiffusionImageDataset(
        root=root,
        image_size=image_size,
        split=split,
        augment=(split == "train"),
    )

    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=(split == "train"),
        num_workers=num_workers,
        pin_memory=True,                       # faster GPU transfer
        drop_last=(split == "train"),          # avoid a small final batch
        persistent_workers=(num_workers > 0),  # keep workers alive between epochs
    )


# Usage example
if __name__ == "__main__":
    train_loader = create_dataloader(
        root="./my_images",
        image_size=64,
        batch_size=32,
        split="train",
    )

    # Get a batch
    batch = next(iter(train_loader))
    print(f"Batch shape: {batch.shape}")  # [32, 3, 64, 64]
    print(f"Value range: [{batch.min():.2f}, {batch.max():.2f}]")  # [-1, 1]
```
Key Takeaways
- Use [-1, 1] normalization: This centers data at 0, matching the Gaussian noise distribution and improving training stability.
- Keep augmentation simple: Horizontal flip is almost always safe. Avoid aggressive augmentations that create artifacts or unrealistic images.
- Handle variable resolutions consistently: Use resize + center crop for validation, resize + random crop for training augmentation.
- Always convert to RGB: Handle RGBA (alpha blending), grayscale, and palette images properly to avoid channel mismatches.
- Verify your pipeline: After preprocessing, images should have values in [-1, 1] with shape [C, H, W]. Print statistics to confirm.
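That last check is easy to automate. Here is a minimal sketch (`check_batch` is a hypothetical helper, not part of the pipeline above) to run on the first batch of any new dataset:

```python
import torch

def check_batch(batch: torch.Tensor, expected_channels: int = 3) -> None:
    """Sanity-check a preprocessed batch before training starts."""
    assert batch.ndim == 4, f"expected (N, C, H, W), got {tuple(batch.shape)}"
    assert batch.shape[1] == expected_channels, "channel mismatch (check RGB conversion)"
    lo, hi = batch.min().item(), batch.max().item()
    assert -1.001 <= lo and hi <= 1.001, f"values outside [-1, 1]: [{lo:.2f}, {hi:.2f}]"
    if lo > -0.5:
        # [0, 1]-normalized data also passes the range check, so flag it separately
        print("warning: min is far from -1; data may still be in [0, 1]")
    print(f"shape={tuple(batch.shape)}, range=[{lo:.2f}, {hi:.2f}], mean={batch.mean():.2f}")

# Example with a correctly normalized dummy batch
check_batch(torch.rand(8, 3, 64, 64) * 2.0 - 1.0)
```

Running this once on `next(iter(train_loader))` catches the most common pipeline mistakes (missing Normalize, wrong channel count) before any GPU time is spent.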
Looking Ahead: In the next section, we'll build the complete training pipeline, including DataLoader configuration, distributed data loading, and efficient batching strategies.