Chapter 10

Dataset Selection


Learning Objectives

By the end of this section, you will be able to:

  1. Understand the standard benchmark datasets used for evaluating diffusion models (MNIST, CIFAR-10, CelebA, ImageNet)
  2. Evaluate dataset characteristics that affect training difficulty and model performance
  3. Choose the appropriate dataset based on your computational resources and learning goals
  4. Implement PyTorch dataset classes for both standard and custom image datasets

The Big Picture

Before you can train a diffusion model, you need data to train on. The choice of dataset profoundly affects not just training time, but also whether your model will converge at all. A model that trains beautifully on MNIST might completely fail on ImageNet if you don't adjust your architecture and hyperparameters accordingly.

Why This Matters: The dataset you choose determines the complexity of your learning problem. Starting with a simple dataset like MNIST allows you to verify your implementation before scaling up. Many failed training runs could have been avoided by first validating on an easier benchmark.

In this section, we'll explore the most commonly used datasets for diffusion model research and practice. Each dataset presents unique challenges and teaches different aspects of generative modeling.


Standard Benchmark Datasets

MNIST: The Classic Starting Point

MNIST contains 60,000 training images and 10,000 test images of handwritten digits (0-9). Each image is 28x28 pixels in grayscale. While considered "solved" for classification, MNIST remains valuable for diffusion models because:

  • Fast iteration: Training converges in minutes to hours, not days
  • Visual verification: You can immediately tell if generated digits look realistic
  • Low compute requirements: Works on modest GPUs or even CPUs
```python
import torch
from torchvision import datasets, transforms

def get_mnist_dataset(root: str = "./data", train: bool = True):
    """
    Load MNIST dataset with preprocessing for diffusion models.

    Note: We normalize to [-1, 1], which is standard for diffusion models.
    The original pixel values are [0, 255].
    """
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),  # Normalize to [-1, 1]
    ])

    dataset = datasets.MNIST(
        root=root,
        train=train,
        download=True,
        transform=transform,
    )

    return dataset

# Example usage
train_dataset = get_mnist_dataset()
print(f"Dataset size: {len(train_dataset)}")
print(f"Image shape: {train_dataset[0][0].shape}")  # torch.Size([1, 28, 28])
print(f"Value range: [{train_dataset[0][0].min():.2f}, {train_dataset[0][0].max():.2f}]")
```

CIFAR-10: The First Real Challenge

CIFAR-10 contains 50,000 training images and 10,000 test images across 10 classes (airplane, automobile, bird, cat, deer, dog, frog, horse, ship, truck). Each image is 32x32 pixels in RGB. This dataset is significantly harder than MNIST:

  • Color images: 3 channels instead of 1 (3x more parameters in first/last conv layers)
  • Natural images: Complex textures, backgrounds, and variations
  • Standard benchmark: Most diffusion papers report CIFAR-10 FID scores
```python
def get_cifar10_dataset(root: str = "./data", train: bool = True):
    """
    Load CIFAR-10 dataset with standard preprocessing.

    Key considerations:
    - Random horizontal flip is common for training (data augmentation)
    - Normalize each channel separately to [-1, 1]
    """
    if train:
        transform = transforms.Compose([
            transforms.RandomHorizontalFlip(),  # Data augmentation
            transforms.ToTensor(),
            transforms.Normalize(
                (0.5, 0.5, 0.5),  # Mean for each RGB channel
                (0.5, 0.5, 0.5),  # Std for each RGB channel
            ),
        ])
    else:
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])

    dataset = datasets.CIFAR10(
        root=root,
        train=train,
        download=True,
        transform=transform,
    )

    return dataset

# Memory consideration
train_dataset = get_cifar10_dataset()
image, label = train_dataset[0]
memory_per_image = image.numel() * 4  # 4 bytes per float32
print(f"Image shape: {image.shape}")  # [3, 32, 32]
print(f"Memory per image: {memory_per_image / 1024:.2f} KB")
```

CelebA: Human Faces

CelebA (Celebrity Faces Attributes) contains over 200,000 celebrity face images with 40 attribute annotations. Images are typically cropped and resized to 64x64 or 128x128 for diffusion model training. Key characteristics:

  • Highly aligned: All faces are roughly centered and similarly scaled, making the learning problem more tractable
  • Attribute labels: Useful for conditional generation (e.g., generate faces with glasses, blonde hair)
  • Larger resolution: Commonly used at 64x64 or 128x128, allowing more detail than CIFAR-10
```python
from torchvision.datasets import CelebA

def get_celeba_dataset(
    root: str = "./data",
    split: str = "train",
    image_size: int = 64,
):
    """
    Load CelebA dataset with center crop and resize.

    The original images are 178x218. We center crop to 178x178
    (square) then resize to the target resolution.
    """
    transform = transforms.Compose([
        transforms.CenterCrop(178),  # Crop to square
        transforms.Resize(image_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

    dataset = CelebA(
        root=root,
        split=split,  # "train", "valid", or "test"
        transform=transform,
        download=True,
    )

    return dataset

# Note: CelebA download can be slow and may require manual steps
# due to Google Drive rate limiting. Consider downloading manually.
```

ImageNet: The Ultimate Challenge

ImageNet (ILSVRC 2012) contains over 1.2 million training images across 1,000 classes. Images vary in size and are typically resized to 256x256 or 512x512. This is the gold standard for evaluating generative models at scale:

  • Diversity: 1,000 categories from dogs to vehicles to household objects
  • Scale: Requires significant compute (days to weeks of training on multiple GPUs)
  • State-of-the-art benchmark: The best diffusion models (like DiT, EDM) compete on ImageNet metrics

Resource Warning

Training a diffusion model on ImageNet at 256x256 resolution typically requires 8+ A100 GPUs and weeks of training time. Start with simpler datasets unless you have access to significant compute resources.
```python
from torchvision.datasets import ImageFolder
import os

def get_imagenet_dataset(
    root: str,
    split: str = "train",
    image_size: int = 256,
):
    """
    Load ImageNet dataset.

    Note: ImageNet must be downloaded separately (not available via
    torchvision's automatic download). The dataset should be organized as:

    root/
        train/
            n01440764/  # synset ID
                n01440764_10026.JPEG
                ...
            n01443537/
                ...
        val/
            ... (same structure)
    """
    data_path = os.path.join(root, split)

    if split == "train":
        transform = transforms.Compose([
            transforms.Resize(image_size + 32),  # Slightly larger for random crop
            transforms.RandomCrop(image_size),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            # Diffusion models conventionally normalize to [-1, 1];
            # classification-style ImageNet statistics are a different convention.
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])
    else:
        transform = transforms.Compose([
            transforms.Resize(image_size + 32),
            transforms.CenterCrop(image_size),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])

    dataset = ImageFolder(data_path, transform=transform)
    return dataset
```

Dataset Characteristics

Understanding the key characteristics of your dataset helps you set appropriate expectations and hyperparameters:

| Dataset        | Resolution | Channels      | Train Size | Complexity  | Typical FID |
|----------------|------------|---------------|------------|-------------|-------------|
| MNIST          | 28x28      | 1 (grayscale) | 60,000     | Low         | ~1-3        |
| Fashion-MNIST  | 28x28      | 1 (grayscale) | 60,000     | Low-Medium  | ~2-5        |
| CIFAR-10       | 32x32      | 3 (RGB)       | 50,000     | Medium      | ~2-4        |
| CelebA (64)    | 64x64      | 3 (RGB)       | 162,770    | Medium-High | ~2-5        |
| CelebA (128)   | 128x128    | 3 (RGB)       | 162,770    | High        | ~3-8        |
| ImageNet (256) | 256x256    | 3 (RGB)       | 1,281,167  | Very High   | ~2-5        |

What Makes a Dataset "Hard"?

Several factors contribute to dataset difficulty:

  1. Resolution: More pixels mean more parameters, more memory, and substantially longer training. Going from 32x32 to 64x64 quadruples the pixel count.
  2. Diversity: A dataset of centered faces is easier than one with varied backgrounds, poses, and object types.
  3. Fine details: High-frequency textures (hair, fur, fabric) are notoriously difficult to generate well.
  4. Semantic complexity: Generating coherent objects is harder than abstract patterns.
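The resolution point is easy to quantify. A quick sketch of how pixel count and raw batch memory grow with resolution (assuming float32 tensors, a batch of 128, and 3 channels):

```python
# How pixel count and raw batch memory scale with resolution
# (float32 tensors, batch of 128, 3 channels)
batch, channels, bytes_per_float = 128, 3, 4

for size in (32, 64, 128, 256):
    pixels = size * size
    batch_bytes = batch * channels * pixels * bytes_per_float
    print(f"{size}x{size}: {pixels:>6} pixels, "
          f"{batch_bytes / 2**20:.1f} MiB per batch")
```

Each doubling of resolution quadruples both numbers, and that's before counting activations inside the network, which usually dominate GPU memory.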

Choosing the Right Dataset

Your choice of dataset should depend on your goals:

For Learning and Debugging

Recommendation: Start with MNIST. If your implementation is correct, you should see recognizable digits after just a few epochs. This fast feedback loop is invaluable for debugging.
  • MNIST: First implementation, verify architecture works
  • Fashion-MNIST: Slightly harder, good for testing improvements
  • CIFAR-10: Standard benchmark, test real performance
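Whichever rung of this ladder you start on, wrap the dataset in a DataLoader and confirm batch shapes and value ranges before writing any training code. A minimal sketch using a synthetic stand-in for MNIST (so it runs without a download); swap in `get_mnist_dataset()` for the real thing:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Synthetic stand-in: 256 fake "MNIST" images already normalized to [-1, 1]
fake_images = torch.rand(256, 1, 28, 28) * 2 - 1
dataset = TensorDataset(fake_images)

loader = DataLoader(dataset, batch_size=64, shuffle=True, drop_last=True)

(batch,) = next(iter(loader))
print(batch.shape)  # torch.Size([64, 1, 28, 28])
assert batch.min() >= -1.0 and batch.max() <= 1.0
```

Catching a wrong shape or an unnormalized range here takes seconds; catching it after an overnight training run does not.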

For Research and Comparison

Academic papers typically report results on:

  • CIFAR-10: Unconditional generation, most commonly cited
  • ImageNet: Class-conditional generation, gold standard for scale
  • LSUN: Specific domains (bedrooms, churches, cats)

For Production Applications

Real applications often require custom datasets:

  • Art generation: Curated collections of artwork
  • Medical imaging: X-rays, MRIs, histology slides
  • Design: Fashion, architecture, product images

Working with Custom Datasets

For real applications, you'll often need to create custom datasets. PyTorch makes this straightforward with the Dataset class:

```python
import os
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms
from typing import Optional, Callable, List
import torch

class ImageFolderDataset(Dataset):
    """
    Custom dataset for a folder of images.

    Expected structure:
        root/
            image1.png
            image2.jpg
            image3.jpeg
            ...

    All common image formats are supported.
    """

    EXTENSIONS = {'.jpg', '.jpeg', '.png', '.bmp', '.webp', '.tiff'}

    def __init__(
        self,
        root: str,
        transform: Optional[Callable] = None,
        max_samples: Optional[int] = None,
    ):
        """
        Args:
            root: Path to folder containing images
            transform: Optional transform to apply to images
            max_samples: Optional limit on number of samples (for debugging)
        """
        self.root = root
        self.transform = transform

        # Find all valid image files
        self.image_paths: List[str] = []
        for filename in sorted(os.listdir(root)):
            ext = os.path.splitext(filename)[1].lower()
            if ext in self.EXTENSIONS:
                self.image_paths.append(os.path.join(root, filename))

        if max_samples is not None:
            self.image_paths = self.image_paths[:max_samples]

        if len(self.image_paths) == 0:
            raise ValueError(f"No images found in {root}")

        print(f"Found {len(self.image_paths)} images in {root}")

    def __len__(self) -> int:
        return len(self.image_paths)

    def __getitem__(self, idx: int) -> torch.Tensor:
        img_path = self.image_paths[idx]

        # Load image and convert to RGB (handles grayscale, RGBA, etc.)
        image = Image.open(img_path).convert("RGB")

        if self.transform is not None:
            image = self.transform(image)

        # For unconditional generation, we only return the image
        # For conditional, you might return (image, label) or (image, text)
        return image


# Usage example
transform = transforms.Compose([
    transforms.Resize(64),
    transforms.CenterCrop(64),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# Create dataset
dataset = ImageFolderDataset(
    root="./my_images",
    transform=transform,
)
```

Handling Variable-Size Images

Real-world datasets often contain images of different sizes. Here's how to handle them:

```python
class VariableSizeImageDataset(Dataset):
    """
    Dataset that handles images of different sizes by resizing
    while maintaining aspect ratio.
    """

    def __init__(
        self,
        root: str,
        target_size: int = 256,
        crop_mode: str = "center",  # "center" or "random"
    ):
        self.root = root
        self.target_size = target_size
        self.crop_mode = crop_mode

        # Find images (same as before)
        self.image_paths = self._find_images(root)

        # Build transform based on crop mode
        self.transform = self._build_transform()

    def _build_transform(self):
        """Build transform pipeline that handles variable sizes."""
        transform_list = [
            # First, resize so the smaller edge equals target_size
            transforms.Resize(
                self.target_size,
                interpolation=transforms.InterpolationMode.BILINEAR,
            ),
        ]

        # Then crop to square
        if self.crop_mode == "random":
            transform_list.append(transforms.RandomCrop(self.target_size))
        else:
            transform_list.append(transforms.CenterCrop(self.target_size))

        # Standard transforms
        transform_list.extend([
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])

        return transforms.Compose(transform_list)

    def _find_images(self, root: str) -> List[str]:
        extensions = {'.jpg', '.jpeg', '.png', '.bmp', '.webp'}
        paths = []
        for dirpath, _, filenames in os.walk(root):
            for filename in filenames:
                if os.path.splitext(filename)[1].lower() in extensions:
                    paths.append(os.path.join(dirpath, filename))
        return sorted(paths)

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image = Image.open(self.image_paths[idx]).convert("RGB")
        return self.transform(image)
```

Practical Implementation

Here's a unified function that handles all common datasets:

```python
from torchvision import datasets, transforms
from torch.utils.data import Dataset

def get_dataset(
    name: str,
    root: str = "./data",
    image_size: int = 32,
    train: bool = True,
) -> Dataset:
    """
    Unified dataset loader for common diffusion model benchmarks.

    Args:
        name: Dataset name ("mnist", "fashion_mnist", "cifar10", "cifar100", "celeba")
        root: Root directory for data
        image_size: Target image size (will resize/crop if needed)
        train: Whether to load training or validation split

    Returns:
        PyTorch Dataset object
    """
    name = name.lower()

    # Common transform for all color image datasets
    color_transform = transforms.Compose([
        transforms.Resize(image_size),
        transforms.CenterCrop(image_size),
        transforms.RandomHorizontalFlip() if train else transforms.Lambda(lambda x: x),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

    # Grayscale transform
    gray_transform = transforms.Compose([
        transforms.Resize(image_size),
        transforms.CenterCrop(image_size),
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ])

    if name == "mnist":
        return datasets.MNIST(
            root=root,
            train=train,
            download=True,
            transform=gray_transform,
        )

    elif name == "fashion_mnist":
        return datasets.FashionMNIST(
            root=root,
            train=train,
            download=True,
            transform=gray_transform,
        )

    elif name == "cifar10":
        return datasets.CIFAR10(
            root=root,
            train=train,
            download=True,
            transform=color_transform,
        )

    elif name == "cifar100":
        return datasets.CIFAR100(
            root=root,
            train=train,
            download=True,
            transform=color_transform,
        )

    elif name == "celeba":
        split = "train" if train else "valid"
        # CelebA needs special handling for the crop
        celeba_transform = transforms.Compose([
            transforms.CenterCrop(178),
            transforms.Resize(image_size),
            transforms.RandomHorizontalFlip() if train else transforms.Lambda(lambda x: x),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])
        return datasets.CelebA(
            root=root,
            split=split,
            download=True,
            transform=celeba_transform,
        )

    else:
        raise ValueError(f"Unknown dataset: {name}")


# Usage
dataset = get_dataset("cifar10", image_size=32, train=True)
print(f"Loaded {len(dataset)} samples")

# Get a sample
sample = dataset[0]
if isinstance(sample, tuple):
    image, label = sample
else:
    image = sample

print(f"Image shape: {image.shape}")
print(f"Value range: [{image.min():.2f}, {image.max():.2f}]")
```

Image Value Range

Diffusion models typically work with images normalized to [-1, 1]. This is because the denoising process adds Gaussian noise (centered at 0), and having the data centered at 0 makes the model's job easier. Always verify your normalization!
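That verification takes only a couple of lines. A sketch that simulates the `ToTensor` then `Normalize((0.5, ...), (0.5, ...))` pipeline on random data and asserts the result lands in [-1, 1]:

```python
import torch

# Simulate ToTensor output: pixel values in [0, 1]
fake_batch = torch.rand(8, 3, 32, 32)

# Equivalent of transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
normalized = (fake_batch - 0.5) / 0.5

assert normalized.min() >= -1.0 and normalized.max() <= 1.0, \
    "Data outside [-1, 1] -- check your Normalize parameters"
print(f"Range: [{normalized.min():.2f}, {normalized.max():.2f}]")
```

Running the same assertion on a real batch from your DataLoader is a cheap guard to keep at the top of every training script.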

Key Takeaways

  1. Start simple: Begin with MNIST to verify your implementation, then scale up to CIFAR-10 before attempting larger datasets.
  2. Understand complexity: Dataset difficulty scales with resolution, diversity, and semantic complexity. Choose datasets appropriate for your compute budget.
  3. Normalize to [-1, 1]: Standard practice for diffusion models. This centers data at 0, matching the Gaussian noise distribution.
  4. Use standard benchmarks: CIFAR-10 and ImageNet have established baselines, making it easy to compare your results with published work.
  5. Custom datasets are easy: PyTorch's Dataset class makes it straightforward to work with your own image collections.

Looking Ahead: In the next section, we'll dive deep into data preprocessing, covering normalization strategies, augmentation techniques, and how to handle different image formats and resolutions.