Learning Objectives
By the end of this section, you will be able to:
- Understand the standard benchmark datasets used for evaluating diffusion models (MNIST, CIFAR-10, CelebA, ImageNet)
- Evaluate dataset characteristics that affect training difficulty and model performance
- Choose the appropriate dataset based on your computational resources and learning goals
- Implement PyTorch dataset classes for both standard and custom image datasets
The Big Picture
Before you can train a diffusion model, you need data to train on. The choice of dataset profoundly affects not just training time, but also whether your model will converge at all. A model that trains beautifully on MNIST might completely fail on ImageNet if you don't adjust your architecture and hyperparameters accordingly.
Why This Matters: The dataset you choose determines the complexity of your learning problem. Starting with a simple dataset like MNIST allows you to verify your implementation before scaling up. Many failed training runs could have been avoided by first validating on an easier benchmark.
In this section, we'll explore the most commonly used datasets for diffusion model research and practice. Each dataset presents unique challenges and teaches different aspects of generative modeling.
Standard Benchmark Datasets
MNIST: The Classic Starting Point
MNIST contains 60,000 training images and 10,000 test images of handwritten digits (0-9). Each image is 28x28 pixels in grayscale. While considered "solved" for classification, MNIST remains valuable for diffusion models because:
- Fast iteration: Training converges in minutes to hours, not days
- Visual verification: You can immediately tell if generated digits look realistic
- Low compute requirements: Works on modest GPUs or even CPUs
```python
import torch
from torchvision import datasets, transforms

def get_mnist_dataset(root: str = "./data", train: bool = True):
    """
    Load MNIST dataset with preprocessing for diffusion models.

    Note: We normalize to [-1, 1], which is standard for diffusion models.
    The original pixel values are [0, 255].
    """
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),  # Normalize to [-1, 1]
    ])

    dataset = datasets.MNIST(
        root=root,
        train=train,
        download=True,
        transform=transform,
    )

    return dataset

# Example usage
train_dataset = get_mnist_dataset()
print(f"Dataset size: {len(train_dataset)}")
print(f"Image shape: {train_dataset[0][0].shape}")  # torch.Size([1, 28, 28])
print(f"Value range: [{train_dataset[0][0].min():.2f}, {train_dataset[0][0].max():.2f}]")
```

CIFAR-10: The First Real Challenge
CIFAR-10 contains 50,000 training images and 10,000 test images across 10 classes (airplane, automobile, bird, cat, deer, dog, frog, horse, ship, truck). Each image is 32x32 pixels in RGB. This dataset is significantly harder than MNIST:
- Color images: 3 channels instead of 1 (3x more parameters in first/last conv layers)
- Natural images: Complex textures, backgrounds, and variations
- Standard benchmark: Most diffusion papers report CIFAR-10 FID scores
```python
def get_cifar10_dataset(root: str = "./data", train: bool = True):
    """
    Load CIFAR-10 dataset with standard preprocessing.

    Key considerations:
    - Random horizontal flip is common for training (data augmentation)
    - Normalize each channel separately to [-1, 1]
    """
    if train:
        transform = transforms.Compose([
            transforms.RandomHorizontalFlip(),  # Data augmentation
            transforms.ToTensor(),
            transforms.Normalize(
                (0.5, 0.5, 0.5),  # Mean for each RGB channel
                (0.5, 0.5, 0.5),  # Std for each RGB channel
            ),
        ])
    else:
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])

    dataset = datasets.CIFAR10(
        root=root,
        train=train,
        download=True,
        transform=transform,
    )

    return dataset

# Memory consideration
train_dataset = get_cifar10_dataset()
image, label = train_dataset[0]
memory_per_image = image.numel() * 4  # 4 bytes per float32
print(f"Image shape: {image.shape}")  # [3, 32, 32]
print(f"Memory per image: {memory_per_image / 1024:.2f} KB")
```

CelebA: Human Faces
CelebA (CelebFaces Attributes) contains over 200,000 celebrity face images with 40 attribute annotations. Images are typically cropped and resized to 64x64 or 128x128 for diffusion model training. Key characteristics:
- Highly aligned: All faces are roughly centered and similarly scaled, making the learning problem more tractable
- Attribute labels: Useful for conditional generation (e.g., generate faces with glasses, blonde hair)
- Larger resolution: Commonly used at 64x64 or 128x128, allowing more detail than CIFAR-10
```python
from torchvision.datasets import CelebA

def get_celeba_dataset(
    root: str = "./data",
    split: str = "train",
    image_size: int = 64,
):
    """
    Load CelebA dataset with center crop and resize.

    The original aligned images are 178x218. We center crop to 178x178
    (square), then resize to the target resolution.
    """
    transform = transforms.Compose([
        transforms.CenterCrop(178),  # Crop to square
        transforms.Resize(image_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

    dataset = CelebA(
        root=root,
        split=split,  # "train", "valid", or "test"
        transform=transform,
        download=True,
    )

    return dataset

# Note: CelebA download can be slow and may require manual steps
# due to Google Drive rate limiting. Consider downloading manually.
```

ImageNet: The Ultimate Challenge
ImageNet (ILSVRC 2012) contains over 1.2 million training images across 1,000 classes. Images vary in size and are typically resized to 256x256 or 512x512. This is the gold standard for evaluating generative models at scale:
- Diversity: 1,000 categories from dogs to vehicles to household objects
- Scale: Requires significant compute (days to weeks of training on multiple GPUs)
- State-of-the-art benchmark: The best diffusion models (like DiT, EDM) compete on ImageNet metrics
Resource Warning: Training on ImageNet at 256x256 typically requires multiple high-end GPUs running for days to weeks. Do not start here; validate your implementation on a smaller dataset first.
```python
from torchvision.datasets import ImageFolder
import os

def get_imagenet_dataset(
    root: str,
    split: str = "train",
    image_size: int = 256,
):
    """
    Load ImageNet dataset.

    Note: ImageNet must be downloaded separately (not available via
    torchvision's automatic download). The dataset should be organized as:

    root/
        train/
            n01440764/          # synset ID
                n01440764_10026.JPEG
                ...
            n01443537/
                ...
        val/
            ... (same structure)
    """
    data_path = os.path.join(root, split)

    if split == "train":
        transform = transforms.Compose([
            transforms.Resize(image_size + 32),  # Slightly larger for random crop
            transforms.RandomCrop(image_size),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            # Normalize to [-1, 1], the diffusion convention; classification
            # pipelines use ImageNet channel statistics instead
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])
    else:
        transform = transforms.Compose([
            transforms.Resize(image_size + 32),
            transforms.CenterCrop(image_size),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])

    dataset = ImageFolder(data_path, transform=transform)
    return dataset
```

Dataset Characteristics
Understanding the key characteristics of your dataset helps you set appropriate expectations and hyperparameters:
| Dataset | Resolution | Channels | Train Size | Complexity | Typical FID |
|---|---|---|---|---|---|
| MNIST | 28x28 | 1 (grayscale) | 60,000 | Low | ~1-3 |
| Fashion-MNIST | 28x28 | 1 (grayscale) | 60,000 | Low-Medium | ~2-5 |
| CIFAR-10 | 32x32 | 3 (RGB) | 50,000 | Medium | ~2-4 |
| CelebA (64) | 64x64 | 3 (RGB) | 162,770 | Medium-High | ~2-5 |
| CelebA (128) | 128x128 | 3 (RGB) | 162,770 | High | ~3-8 |
| ImageNet (256) | 256x256 | 3 (RGB) | 1,281,167 | Very High | ~2-5 |
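The train sizes in the table translate directly into raw storage footprints. A quick back-of-the-envelope calculation (uint8 pixel storage, labels ignored; these are illustrative raw-tensor figures, and on-disk compressed storage is smaller while in-memory float32 tensors are 4x larger):

```python
def dataset_bytes(n_images: int, resolution: int, channels: int) -> int:
    # One byte per uint8 pixel value
    return n_images * resolution * resolution * channels

sizes = {
    "MNIST": (60_000, 28, 1),
    "CIFAR-10": (50_000, 32, 3),
    "CelebA (64)": (162_770, 64, 3),
    "ImageNet (256)": (1_281_167, 256, 3),
}
for name, args in sizes.items():
    print(f"{name}: ~{dataset_bytes(*args) / 1024**2:.0f} MB")
```

The jump from CIFAR-10 (~150 MB raw) to ImageNet at 256x256 (hundreds of GB raw) is one reason the latter demands a fundamentally different engineering setup.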
What Makes a Dataset "Hard"?
Several factors contribute to dataset difficulty:
- Resolution: More pixels mean quadratically more compute per image and longer training. Going from 32x32 to 64x64 quadruples the pixel count.
- Diversity: A dataset of centered faces is easier than one with varied backgrounds, poses, and object types.
- Fine details: High-frequency textures (hair, fur, fabric) are notoriously difficult to generate well.
- Semantic complexity: Generating coherent objects is harder than abstract patterns.
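The resolution factor is easy to quantify for the input tensors alone (activations inside the network are far larger, so treat these as lower bounds): each doubling of the side length quadruples per-batch memory, which is why batch sizes shrink as resolution grows. A quick illustration:

```python
def batch_mb(batch_size: int, side: int, channels: int = 3) -> float:
    # float32 input tensor memory for one batch, in MB
    return batch_size * channels * side * side * 4 / 1024**2

for side in (32, 64, 128, 256):
    print(f"{side}x{side}, batch of 128: {batch_mb(128, side):.1f} MB")
```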
Choosing the Right Dataset
Your choice of dataset should depend on your goals:
For Learning and Debugging
Recommendation: Start with MNIST. If your implementation is correct, you should see recognizable digits after just a few epochs. This fast feedback loop is invaluable for debugging.
- MNIST: First implementation, verify architecture works
- Fashion-MNIST: Slightly harder, good for testing improvements
- CIFAR-10: Standard benchmark, test real performance
For Research and Comparison
Academic papers typically report results on:
- CIFAR-10: Unconditional generation, most commonly cited
- ImageNet: Class-conditional generation, gold standard for scale
- LSUN: Specific domains (bedrooms, churches, cats)
For Production Applications
Real applications often require custom datasets:
- Art generation: Curated collections of artwork
- Medical imaging: X-rays, MRIs, histology slides
- Design: Fashion, architecture, product images
Working with Custom Datasets
For real applications, you'll often need to create custom datasets. PyTorch makes this straightforward with the Dataset class:
```python
import os
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms
from typing import Optional, Callable, List
import torch

class ImageFolderDataset(Dataset):
    """
    Custom dataset for a flat folder of images.

    Expected structure:
    root/
        image1.png
        image2.jpg
        image3.jpeg
        ...

    All common image formats are supported.
    """

    EXTENSIONS = {'.jpg', '.jpeg', '.png', '.bmp', '.webp', '.tiff'}

    def __init__(
        self,
        root: str,
        transform: Optional[Callable] = None,
        max_samples: Optional[int] = None,
    ):
        """
        Args:
            root: Path to folder containing images
            transform: Optional transform to apply to images
            max_samples: Optional limit on number of samples (for debugging)
        """
        self.root = root
        self.transform = transform

        # Find all valid image files
        self.image_paths: List[str] = []
        for filename in sorted(os.listdir(root)):
            ext = os.path.splitext(filename)[1].lower()
            if ext in self.EXTENSIONS:
                self.image_paths.append(os.path.join(root, filename))

        if max_samples is not None:
            self.image_paths = self.image_paths[:max_samples]

        if len(self.image_paths) == 0:
            raise ValueError(f"No images found in {root}")

        print(f"Found {len(self.image_paths)} images in {root}")

    def __len__(self) -> int:
        return len(self.image_paths)

    def __getitem__(self, idx: int) -> torch.Tensor:
        img_path = self.image_paths[idx]

        # Load image and convert to RGB (handles grayscale, RGBA, etc.)
        image = Image.open(img_path).convert("RGB")

        if self.transform is not None:
            image = self.transform(image)

        # For unconditional generation, we only return the image
        # For conditional, you might return (image, label) or (image, text)
        return image


# Usage example
transform = transforms.Compose([
    transforms.Resize(64),
    transforms.CenterCrop(64),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# Create dataset
dataset = ImageFolderDataset(
    root="./my_images",
    transform=transform,
)
```

Handling Variable-Size Images
Real-world datasets often contain images of different sizes. Here's how to handle them:
```python
class VariableSizeImageDataset(Dataset):
    """
    Dataset that handles images of different sizes by resizing
    while maintaining aspect ratio.
    """

    def __init__(
        self,
        root: str,
        target_size: int = 256,
        crop_mode: str = "center",  # "center" or "random"
    ):
        self.root = root
        self.target_size = target_size
        self.crop_mode = crop_mode

        # Find images (same as before)
        self.image_paths = self._find_images(root)

        # Build transform based on crop mode
        self.transform = self._build_transform()

    def _build_transform(self):
        """Build transform pipeline that handles variable sizes."""
        transform_list = [
            # First, resize so the smaller edge equals target_size
            transforms.Resize(
                self.target_size,
                interpolation=transforms.InterpolationMode.BILINEAR,
            ),
        ]

        # Then crop to square
        if self.crop_mode == "random":
            transform_list.append(transforms.RandomCrop(self.target_size))
        else:
            transform_list.append(transforms.CenterCrop(self.target_size))

        # Standard transforms
        transform_list.extend([
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])

        return transforms.Compose(transform_list)

    def _find_images(self, root: str) -> List[str]:
        extensions = {'.jpg', '.jpeg', '.png', '.bmp', '.webp'}
        paths = []
        for dirpath, _, filenames in os.walk(root):
            for filename in filenames:
                if os.path.splitext(filename)[1].lower() in extensions:
                    paths.append(os.path.join(dirpath, filename))
        return sorted(paths)

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image = Image.open(self.image_paths[idx]).convert("RGB")
        return self.transform(image)
```

Practical Implementation
Here's a unified function that handles all common datasets:
```python
from torchvision import datasets, transforms
from torch.utils.data import Dataset

def get_dataset(
    name: str,
    root: str = "./data",
    image_size: int = 32,
    train: bool = True,
) -> Dataset:
    """
    Unified dataset loader for common diffusion model benchmarks.

    Args:
        name: Dataset name ("mnist", "fashion_mnist", "cifar10",
            "cifar100", "celeba")
        root: Root directory for data
        image_size: Target image size (will resize/crop if needed)
        train: Whether to load training or validation split

    Returns:
        PyTorch Dataset object
    """
    name = name.lower()

    # Common transform for all color image datasets
    color_transform = transforms.Compose([
        transforms.Resize(image_size),
        transforms.CenterCrop(image_size),
        transforms.RandomHorizontalFlip() if train else transforms.Lambda(lambda x: x),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

    # Grayscale transform
    gray_transform = transforms.Compose([
        transforms.Resize(image_size),
        transforms.CenterCrop(image_size),
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ])

    if name == "mnist":
        return datasets.MNIST(
            root=root,
            train=train,
            download=True,
            transform=gray_transform,
        )

    elif name == "fashion_mnist":
        return datasets.FashionMNIST(
            root=root,
            train=train,
            download=True,
            transform=gray_transform,
        )

    elif name == "cifar10":
        return datasets.CIFAR10(
            root=root,
            train=train,
            download=True,
            transform=color_transform,
        )

    elif name == "cifar100":
        return datasets.CIFAR100(
            root=root,
            train=train,
            download=True,
            transform=color_transform,
        )

    elif name == "celeba":
        split = "train" if train else "valid"
        # CelebA needs special handling for the crop
        celeba_transform = transforms.Compose([
            transforms.CenterCrop(178),
            transforms.Resize(image_size),
            transforms.RandomHorizontalFlip() if train else transforms.Lambda(lambda x: x),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])
        return datasets.CelebA(
            root=root,
            split=split,
            download=True,
            transform=celeba_transform,
        )

    else:
        raise ValueError(f"Unknown dataset: {name}")


# Usage
dataset = get_dataset("cifar10", image_size=32, train=True)
print(f"Loaded {len(dataset)} samples")

# Get a sample
sample = dataset[0]
if isinstance(sample, tuple):
    image, label = sample
else:
    image = sample

print(f"Image shape: {image.shape}")
print(f"Value range: [{image.min():.2f}, {image.max():.2f}]")
```

Image Value Range: After ToTensor followed by Normalize with mean 0.5 and std 0.5, every pixel value lies in [-1, 1]. Verify this range before training; a mismatch between the range your data uses and the range your sampling code assumes is a common, silent source of poor results.
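The [-1, 1] convention used by every loader in this section reduces to simple per-pixel arithmetic. A toy helper (my own illustration, equivalent to ToTensor followed by Normalize with mean 0.5 and std 0.5) makes the boundary values easy to check:

```python
def normalize_pixel(p: int) -> float:
    # ToTensor maps [0, 255] -> [0, 1]; Normalize((0.5,), (0.5,)) maps that to [-1, 1]
    return (p / 255.0 - 0.5) / 0.5

print(normalize_pixel(0))    # -1.0
print(normalize_pixel(255))  # 1.0
```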
Key Takeaways
- Start simple: Begin with MNIST to verify your implementation, then scale up to CIFAR-10 before attempting larger datasets.
- Understand complexity: Dataset difficulty scales with resolution, diversity, and semantic complexity. Choose datasets appropriate for your compute budget.
- Normalize to [-1, 1]: Standard practice for diffusion models. This centers data at 0, matching the Gaussian noise distribution.
- Use standard benchmarks: CIFAR-10 and ImageNet have established baselines, making it easy to compare your results with published work.
- Custom datasets are easy: PyTorch's Dataset class makes it straightforward to work with your own image collections.
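Whichever dataset you settle on, training code ultimately consumes it through a DataLoader. A minimal sketch with stand-in data (the fake tensor, batch size, and worker count here are illustrative, not recommendations):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Stand-in for a real image dataset: 96 fake "images" already normalized to [-1, 1]
fake_images = torch.rand(96, 3, 32, 32) * 2 - 1
dataset = TensorDataset(fake_images)

# shuffle=True is standard for training; num_workers > 0 parallelizes loading
loader = DataLoader(dataset, batch_size=16, shuffle=True, num_workers=0)

for (batch,) in loader:
    print(batch.shape)  # torch.Size([16, 3, 32, 32])
    break
```

Any of the Dataset objects built in this section can be dropped in place of the TensorDataset without changing the loop.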
Looking Ahead: In the next section, we'll dive deep into data preprocessing, covering normalization strategies, augmentation techniques, and how to handle different image formats and resolutions.