Chapter 7

The Dataset Class

Data Loading and Processing

Learning Objectives

By the end of this section, you will be able to:

  1. Understand the Dataset abstraction as the foundation of PyTorch's data loading system
  2. Implement custom datasets using the __len__ and __getitem__ protocol
  3. Choose between map-style and iterable-style datasets based on data access patterns
  4. Apply transforms to preprocess data on-the-fly during training
  5. Handle memory efficiently by loading data lazily rather than all at once
  6. Use built-in datasets from torchvision, torchaudio, and torchtext
Why This Matters: In deep learning, data is everything. A well-designed data pipeline can make the difference between a model that trains efficiently and one that bottlenecks on I/O. The Dataset class is PyTorch's elegant solution for organizing, accessing, and transforming your training data.

The Data Bottleneck

Modern GPUs can perform trillions of floating-point operations per second. A single forward pass through a neural network takes milliseconds. Yet training often takes hours or days. Why?

The answer is data loading. While your GPU waits, data must be:

  1. Read from disk (slow SSD/HDD access)
  2. Decoded (JPEG decompression, audio decoding)
  3. Preprocessed (resize, normalize, augment)
  4. Batched (combine samples into tensors)
  5. Transferred to GPU (PCIe bandwidth limited)

A naive approach—load all data into memory, then iterate—fails for several reasons:

| Dataset | Size on Disk | Loaded in Memory | Verdict |
|---|---|---|---|
| MNIST | ~50 MB | ~200 MB | Fits in RAM ✓ |
| CIFAR-100 | ~180 MB | ~600 MB | Fits in RAM ✓ |
| ImageNet | ~150 GB | ~500 GB | Exceeds typical RAM ✗ |
| Common Crawl | ~250 TB | ~1 PB | Impossible to load ✗ |

PyTorch's solution is lazy loading: load and preprocess data only when needed, one sample (or batch) at a time. This is the job of the Dataset class.

The Key Insight

Instead of holding all data in memory, a Dataset holds metadata—just enough information to locate and load any single sample on demand. The actual pixels, audio samples, or text tokens stay on disk until the training loop requests them.

$\text{Memory} \propto \text{batch\_size} \times \text{sample\_size}$

Not $\text{dataset\_size} \times \text{sample\_size}$! This is the magic of lazy evaluation.
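A quick back-of-the-envelope calculation makes the difference concrete (the float32 format, 224×224 RGB sample shape, batch size, and ImageNet-1k sample count are illustrative assumptions):

```python
# Rough memory estimate for 224x224 RGB float32 samples (illustrative numbers)
bytes_per_sample = 224 * 224 * 3 * 4           # ~0.6 MB per decoded sample
batch_size = 64
n_samples = 1_281_167                          # ImageNet-1k training set

per_batch_gb = batch_size * bytes_per_sample / 1e9
full_dataset_gb = n_samples * bytes_per_sample / 1e9

print(f"One batch:     {per_batch_gb:.2f} GB")     # ~0.04 GB
print(f"Whole dataset: {full_dataset_gb:.0f} GB")  # ~771 GB
```

Holding one batch costs tens of megabytes; holding the decoded dataset would cost hundreds of gigabytes.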


The Dataset Abstraction

A Dataset represents a collection of samples, where each sample is typically a (features, label) pair. The abstraction requires answering just two questions:

  1. How many samples are there? → __len__()
  2. What is the i-th sample? → __getitem__(i)

That's it. With these two methods, PyTorch can:

  • Iterate through all samples in order or randomly
  • Create batches by grouping multiple samples
  • Shuffle data at the start of each epoch
  • Split data into train/validation/test sets
  • Load data in parallel using multiple worker processes
Design Philosophy: Python's "dunder methods" (__len__, __getitem__) make Datasets feel like native Python sequences. You can use len(dataset) and dataset[i] just like with lists.
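This sequence-like behavior comes straight from Python: any class with these two dunder methods supports len(), indexing, and even iteration, since Python falls back to __getitem__ when no __iter__ is defined. A torch-free toy example:

```python
# Any class with __len__ and __getitem__ behaves like a sequence.
class Squares:
    def __init__(self, n):
        self.n = n

    def __len__(self):
        return self.n

    def __getitem__(self, i):
        if i >= self.n:
            raise IndexError(i)  # signals the end of iteration
        return i * i

s = Squares(5)
print(len(s))    # 5
print(s[3])      # 9
print(list(s))   # [0, 1, 4, 9, 16] -- iteration works for free
```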

PyTorch's Dataset Class

PyTorch provides the torch.utils.data.Dataset abstract base class. Here's the minimal contract:

🐍dataset_contract.py
# Conceptual sketch of torch.utils.data.Dataset.
# In real code you import it: from torch.utils.data import Dataset

class Dataset:
    """Abstract base class for all datasets.

    All datasets must subclass this and implement:
    - __len__: Return the total number of samples
    - __getitem__: Return a sample given an index
    """

    def __len__(self) -> int:
        """Return the total number of samples in the dataset."""
        raise NotImplementedError

    def __getitem__(self, index: int):
        """Return the sample at the given index.

        Args:
            index: Index of sample to retrieve (0 to len-1)

        Returns:
            Typically a tuple (features, label), but can be any structure
        """
        raise NotImplementedError

The Two Essential Methods

__len__(self) → int

Returns the total number of samples. This is called by len(dataset) and is used by the DataLoader to:

  • Calculate the number of batches per epoch
  • Generate random indices for shuffling
  • Implement train/val/test splitting
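For example, the batch count per epoch follows directly from __len__ (the sample and batch counts below are arbitrary):

```python
import math

# With drop_last=False, the final short batch is kept
n_samples, batch_size = 1000, 64
n_batches = math.ceil(n_samples / batch_size)
last_batch = n_samples - (n_batches - 1) * batch_size

print(n_batches)   # 16
print(last_batch)  # 40 samples in the final, shorter batch
```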

__getitem__(self, index) → sample

Returns the sample at position index. This is called by dataset[i] and is where the actual data loading happens. The return value is typically a tuple:

🐍getitem_examples.py
# Classification: (features, label)
dataset[0]  # → (tensor([0.5, 0.3, ...]), 7)

# Detection: (image, bounding_boxes)
dataset[0]  # → (tensor(...), [{"bbox": [...], "class": 2}])

# Language modeling: (input_ids, target_ids)
dataset[0]  # → (tensor([101, 2023, ...]), tensor([2023, 2003, ...]))

Return Type Flexibility

PyTorch doesn't enforce a specific return type. You can return a single tensor, a tuple, a dict, or a namedtuple. The DataLoader will batch whatever structure you return. However, the convention is (features, label) for supervised learning.
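As a torch-free sketch of this flexibility, here is a toy collate function that batches dict-shaped samples into a dict of lists. PyTorch's real default_collate does the analogous thing while stacking tensors; this is an illustration of the idea, not its implementation:

```python
def toy_collate(samples):
    # Turn [{"x": ..., "y": ...}, ...] into {"x": [...], "y": [...]}
    keys = samples[0].keys()
    return {k: [s[k] for s in samples] for k in keys}

batch = toy_collate([
    {"x": [0.1, 0.2], "y": 0},
    {"x": [0.3, 0.4], "y": 1},
])
print(batch["y"])  # [0, 1]
```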

Worked Example: Dataset Methods

Consider a small map-style dataset with 8 samples, where each sample pairs a 5-element feature tensor with an integer label. Calling __len__ returns 8; calling __getitem__ with an index from 0 to 7 returns that sample as a (features, label) tuple; and an out-of-range index raises an IndexError, just like a Python list.
Python Implementation

class MyDataset(Dataset):
    def __init__(self, data, labels):
        self.data = data
        self.labels = labels

    def __len__(self):
        return len(self.data)  # → 8

    def __getitem__(self, idx):
        return self.data[idx], self.labels[idx]

Quick Check

If a dataset has 1000 samples and you call dataset[1000], what happens?


Map-style vs Iterable-style Datasets

PyTorch supports two dataset paradigms:

Map-style Datasets (Most Common)

Map-style datasets implement __len__ and __getitem__. They support random access—you can jump to any sample directly by index. This enables:

  • Shuffling: Randomly permute indices at epoch start
  • Sampling: Over/under-sample classes for imbalanced data
  • Parallel loading: Each worker loads samples at different indices
🐍map_style.py
from pathlib import Path
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms

transform = transforms.ToTensor()  # any preprocessing pipeline works here

class MapStyleDataset(Dataset):
    """Supports random access via __getitem__."""

    def __init__(self, data_dir: str):
        self.samples = sorted(Path(data_dir).glob("*.jpg"))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        img = Image.open(self.samples[idx])
        return transform(img)

Iterable-style Datasets

Some data sources don't support random access:

  • Streaming data from a network connection
  • Data generated on-the-fly (infinite datasets)
  • Data stored in formats that must be read sequentially

For these cases, PyTorch provides IterableDataset:

🐍iterable_style.py
import requests
from torch.utils.data import IterableDataset

def process(line: bytes):
    """Parse one raw line into a training sample (user-defined)."""
    ...

class StreamingDataset(IterableDataset):
    """Can only be iterated, not randomly accessed."""

    def __init__(self, url: str):
        self.url = url

    def __iter__(self):
        # Stream data from URL
        with requests.get(self.url, stream=True) as r:
            for line in r.iter_lines():
                yield process(line)

    # Note: No __len__ or __getitem__!
| Feature | Map-style (Dataset) | Iterable-style (IterableDataset) |
|---|---|---|
| Random access | ✓ dataset[i] | ✗ Sequential only |
| Shuffling | ✓ Built-in support | ✗ Must shuffle upstream |
| Length known | ✓ len(dataset) | ✗ Often unknown/infinite |
| Multi-worker | ✓ Easy (indices) | ⚠ Complex (data partitioning) |
| Use case | Files on disk | Streams, infinite data |

When to Use Each

Use map-style datasets (regular Dataset) for files on disk—this is 95% of use cases. Use iterable-style datasets only for streaming data or when you truly can't support random access.

Building Custom Datasets

Let's build several practical custom datasets, from simple to complex.

Example 1: In-Memory Dataset

For small datasets that fit in memory, loading all data upfront is fine:

In-Memory Dataset Implementation

Key points in the implementation below:

  • Inherit from Dataset: all custom datasets subclass torch.utils.data.Dataset to get the standard interface that DataLoader expects.
  • The constructor stores data: the pre-loaded tensors become instance attributes, keeping all data in memory for fast access.
  • Validate up front: always check that features and labels have matching lengths; mismatches cause confusing errors later.
  • __len__ returns the total sample count, which the DataLoader uses to calculate batches and generate shuffled indices.
  • __getitem__ returns one sample: the feature tensor and label at the given index, as a tuple. For example, dataset[0] returns (tensor([...]), tensor(2)).

🐍in_memory_dataset.py
import torch
from torch.utils.data import Dataset
from typing import Tuple

class InMemoryDataset(Dataset):
    """Dataset that stores all data in memory.

    Good for: Small datasets (< 1GB), synthetic data
    Bad for: Large image/video datasets
    """

    def __init__(self, features: torch.Tensor, labels: torch.Tensor):
        """Initialize with pre-loaded tensors.

        Args:
            features: Tensor of shape (N, *feature_dims)
            labels: Tensor of shape (N,) for classification
        """
        assert len(features) == len(labels), "Size mismatch!"
        self.features = features
        self.labels = labels

    def __len__(self) -> int:
        return len(self.features)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        return self.features[idx], self.labels[idx]


# Create synthetic dataset
N = 1000
X = torch.randn(N, 10)  # 1000 samples, 10 features each
y = torch.randint(0, 3, (N,))  # 3 classes

dataset = InMemoryDataset(X, y)

print(f"Dataset size: {len(dataset)}")
print(f"Sample shape: {dataset[0][0].shape}")
print(f"Sample label: {dataset[0][1]}")

Example 2: Lazy-Loading Image Dataset

For large image datasets, load images only when accessed:

Lazy-Loading Image Dataset

Key points in the implementation below:

  • Scan the directory structure: __init__ only collects file paths and builds the class mapping; no images are loaded yet.
  • Map class names to integers: networks predict integer indices, so folder names (e.g., 'cat', 'dog') map to 0, 1, and so on. Sorting ensures consistent ordering across runs.
  • Store paths, not images: the samples list contains (path, label) tuples, keeping __init__ fast and memory-efficient.
  • Load lazily in __getitem__: the image is read from disk only when requested. This is the key to handling datasets larger than RAM.
  • Apply the transform on-the-fly: resize, normalize, and augment per access, so the same sample can get different random augmentations in different epochs.

🐍image_folder_dataset.py
import torch
from torch.utils.data import Dataset
from PIL import Image
from pathlib import Path
from typing import Tuple, Optional, Callable

class ImageFolderDataset(Dataset):
    """Dataset that loads images lazily from a directory.

    Expected folder structure:
        root/
            class_0/
                img001.jpg
                img002.jpg
            class_1/
                img001.jpg
                ...
    """

    def __init__(
        self,
        root: str,
        transform: Optional[Callable] = None,
        extensions: Tuple[str, ...] = (".jpg", ".jpeg", ".png")
    ):
        self.root = Path(root)
        self.transform = transform

        # Build class mapping: folder name -> integer
        self.classes = sorted([d.name for d in self.root.iterdir() if d.is_dir()])
        self.class_to_idx = {cls: i for i, cls in enumerate(self.classes)}

        # Collect all image paths and their labels (sorted for determinism)
        self.samples = []
        for class_name in self.classes:
            class_dir = self.root / class_name
            for img_path in sorted(class_dir.iterdir()):
                if img_path.suffix.lower() in extensions:
                    label = self.class_to_idx[class_name]
                    self.samples.append((img_path, label))

        print(f"Found {len(self.samples)} images in {len(self.classes)} classes")

    def __len__(self) -> int:
        return len(self.samples)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, int]:
        img_path, label = self.samples[idx]

        # Lazy loading: only load image when requested
        image = Image.open(img_path).convert("RGB")

        if self.transform:
            image = self.transform(image)

        return image, label


# Usage with transforms
from torchvision import transforms

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

dataset = ImageFolderDataset("data/train", transform=transform)
image, label = dataset[0]  # Image loaded here, not in __init__!

Why This Works for Large Datasets

For the ImageNet-1k training set (about 1.28M images, ~150 GB on disk), storing only paths and labels takes a few hundred megabytes of RAM. The DataLoader loads images in parallel, keeping the GPU fed while never holding more than a few batches in memory.

Quick Check

Where does the actual disk I/O happen in a lazy-loading dataset?


The Data Pipeline Flow

To see how the pieces interact, imagine 20 samples loaded with a batch size of 4. At the start of each epoch, the DataLoader shuffles the index list 0 through 19 and requests samples in that shuffled order; the Dataset loads each one on demand. After 5 batches, every sample has been seen exactly once. Shuffling reorders the indices, never the actual data.
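A few lines of plain Python make this index-shuffling scheme concrete (the 20-sample, batch-of-4 setup and fixed seed are illustrative):

```python
import random

# Shuffle indices, not data: the dataset itself never moves
indices = list(range(20))
rng = random.Random(0)  # fixed seed for reproducibility
rng.shuffle(indices)

# Group shuffled indices into batches of 4, DataLoader-style
batches = [indices[i:i + 4] for i in range(0, len(indices), 4)]

print(len(batches))  # 5
# Every sample appears exactly once per epoch:
print(sorted(i for b in batches for i in b) == list(range(20)))  # True
```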


Transforms and Preprocessing

Transforms modify samples during loading. They're essential for:

  • Standardization: Resize images to consistent dimensions
  • Normalization: Scale pixel values to a standard range
  • Data Augmentation: Random transformations to increase data diversity
  • Format conversion: PIL Image → Tensor

Where to Apply Transforms

Transforms are typically passed to the Dataset and applied in __getitem__:

🐍transform_application.py
class MyDataset(Dataset):
    def __init__(self, paths, labels, transform=None):
        self.paths = paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        image = load_image(self.paths[idx])

        # Apply transform if provided
        if self.transform:
            image = self.transform(image)

        return image, self.labels[idx]

# Different transforms for train vs val
train_transform = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

val_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

train_dataset = MyDataset(train_paths, train_labels, transform=train_transform)
val_dataset = MyDataset(val_paths, val_labels, transform=val_transform)

Train vs Validation Transforms

Training transforms include random augmentations (RandomCrop, RandomFlip) to increase data diversity. Validation transforms should be deterministic (CenterCrop) for reproducible evaluation. Never apply random augmentations during validation!

Memory Considerations

Understanding memory usage is critical for efficient data loading:

Memory Usage Patterns

| Approach | Memory Usage | I/O Pattern | When to Use |
|---|---|---|---|
| All in RAM | O(N × size) | One-time load | Small datasets (< 1 GB) |
| Lazy loading | O(batch × size) | Per-sample I/O | Large datasets |
| Memory-mapped | O(batch × size) | On-demand paging | Preprocessed data |
| Cached loading | O(cache_size × size) | First-epoch I/O | Repeated epochs, fast SSD |

Memory-Mapped Files

For preprocessed data (e.g., extracted features), memory-mapped files offer the best of both worlds:

🐍memmap_dataset.py
import numpy as np
import torch
from torch.utils.data import Dataset

class MemmapDataset(Dataset):
    """Memory-mapped dataset for preprocessed features.

    Data appears to be in memory but is actually paged from disk
    by the OS as needed. Very efficient for sequential access.
    """

    def __init__(self, features_path: str, labels_path: str):
        # Memory-map the arrays (doesn't load into RAM)
        self.features = np.load(features_path, mmap_mode='r')
        self.labels = np.load(labels_path, mmap_mode='r')

    def __len__(self):
        return len(self.features)

    def __getitem__(self, idx):
        # OS handles paging from disk transparently
        return torch.from_numpy(self.features[idx].copy()), int(self.labels[idx])

Preprocessing Large Datasets

For large datasets that you'll train on repeatedly, consider a two-phase approach: (1) Preprocess all images once and save as .npy files. (2) Use memory-mapped loading during training. This avoids repeated JPEG decoding.
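Here is a minimal end-to-end sketch of that two-phase idea using synthetic NumPy data (the array shapes and temp-file location are arbitrary):

```python
import os
import tempfile
import numpy as np

# Phase 1: preprocess once and save as .npy (here: synthetic features)
features = np.random.rand(100, 16).astype(np.float32)
path = os.path.join(tempfile.mkdtemp(), "features.npy")
np.save(path, features)

# Phase 2: memory-map during training; nothing is loaded into RAM up front
mm = np.load(path, mmap_mode="r")
print(mm.shape)                               # (100, 16)
print(bool(np.allclose(mm[7], features[7])))  # True
```

In a real pipeline, phase 1 would decode and resize every image once; phase 2 then skips JPEG decoding entirely on every epoch.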

Common Dataset Patterns

Pattern 1: Train/Val/Test Splits

🐍split_dataset.py
import torch
from torch.utils.data import random_split

# Create full dataset
full_dataset = MyDataset(all_data)

# Split with fixed seed for reproducibility
generator = torch.Generator().manual_seed(42)
train_size = int(0.8 * len(full_dataset))
val_size = int(0.1 * len(full_dataset))
test_size = len(full_dataset) - train_size - val_size

train_dataset, val_dataset, test_dataset = random_split(
    full_dataset,
    [train_size, val_size, test_size],
    generator=generator
)

Pattern 2: Dataset Subsets

🐍subset_dataset.py
from torch.utils.data import Subset

# Create dataset with specific indices
train_indices = [i for i in range(len(dataset)) if i % 5 != 0]
val_indices = [i for i in range(len(dataset)) if i % 5 == 0]

train_subset = Subset(dataset, train_indices)
val_subset = Subset(dataset, val_indices)

# Subset wraps the original dataset, no data copying!
print(len(train_subset))  # 80% of original

Pattern 3: Concatenating Datasets

🐍concat_dataset.py
from torch.utils.data import ConcatDataset

# Combine multiple data sources
dataset_a = ImageDataset("data/source_a")
dataset_b = ImageDataset("data/source_b")
dataset_c = ImageDataset("data/source_c")

combined = ConcatDataset([dataset_a, dataset_b, dataset_c])
print(len(combined))  # len(a) + len(b) + len(c)

Pattern 4: Label Filtering

🐍filtered_dataset.py
from torch.utils.data import Dataset

class FilteredDataset(Dataset):
    """Dataset filtered to only include certain labels."""

    def __init__(self, base_dataset: Dataset, keep_labels: set):
        self.base = base_dataset
        # Caution: this touches every sample once, which forces a full
        # pass of (possibly lazy) loading. For large datasets, filter on
        # cheap metadata (e.g., a label list) instead.
        self.indices = [
            i for i in range(len(base_dataset))
            if base_dataset[i][1] in keep_labels
        ]

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        return self.base[self.indices[idx]]

# Keep only cats and dogs from a multi-class dataset
binary_dataset = FilteredDataset(full_dataset, keep_labels={0, 1})

PyTorch Built-in Datasets

PyTorch provides ready-to-use datasets for common benchmarks:

| Package | Domain | Examples |
|---|---|---|
| torchvision.datasets | Computer vision | MNIST, CIFAR-10/100, ImageNet, COCO |
| torchaudio.datasets | Audio processing | LibriSpeech, VCTK, CommonVoice |
| torchtext.datasets | NLP | IMDB, WikiText, AG_NEWS |
🐍builtin_datasets.py
from torchvision.datasets import MNIST, CIFAR10
from torchvision import transforms

# MNIST: Handwritten digits
mnist_train = MNIST(
    root="./data",
    train=True,
    download=True,  # Downloads if not present
    transform=transforms.ToTensor()
)

print(f"MNIST: {len(mnist_train)} training samples")
image, label = mnist_train[0]
print(f"Shape: {image.shape}, Label: {label}")

# CIFAR-10: 10-class image classification
cifar_train = CIFAR10(
    root="./data",
    train=True,
    download=True,
    transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
)

print(f"CIFAR-10: {len(cifar_train)} training samples")

Use Built-in Datasets for Prototyping

When developing a new model, start with a built-in dataset like CIFAR-10. Once your training pipeline works, swap in your custom dataset. This separates debugging the model from debugging the data loading.

Summary

The Dataset class is PyTorch's abstraction for organizing and accessing training data. It enables efficient, memory-conscious data loading for datasets of any size.

Key Concepts

| Concept | Key Insight | Implementation |
|---|---|---|
| Dataset protocol | Answer two questions: how many? and what is the i-th? | __len__ and __getitem__ |
| Lazy loading | Load data only when accessed | Store paths, load in __getitem__ |
| Map-style vs iterable | Random access vs streaming | Dataset vs IterableDataset |
| Transforms | Preprocess on-the-fly | Pass to __init__, apply in __getitem__ |
| Memory efficiency | Memory ∝ batch_size, not dataset_size | Lazy loading + parallel workers |

Best Practices

  1. Start with built-in datasets for prototyping, then swap in custom data
  2. Use lazy loading for datasets larger than available RAM
  3. Apply augmentation in transforms for on-the-fly data diversity
  4. Validate dataset length matches between features and labels
  5. Use different transforms for training (random) vs validation (deterministic)

Looking Ahead

The Dataset class provides the "what"—accessing individual samples. In the next section, we'll explore the DataLoader—the "how" of batching, shuffling, and parallel loading that turns a Dataset into an efficient data pipeline.


Exercises

Conceptual Questions

  1. Why does PyTorch use __len__ and __getitem__ instead of requiring datasets to inherit from a list? What benefits does this abstraction provide?
  2. Explain why lazy loading is essential for large datasets. What would happen if you tried to load all of ImageNet (150GB) into memory on a typical machine?
  3. When would you choose an IterableDataset over a regular Dataset? Give two real-world examples.
  4. Why should validation transforms be deterministic while training transforms can be random?

Coding Exercises

  1. CSV Dataset: Implement a Dataset that loads tabular data from a CSV file. The __init__ should read the file and store it in a pandas DataFrame. __getitem__ should return a (features, label) tuple where features is a tensor and label is an integer.
  2. Caching Dataset: Create a wrapper dataset that caches loaded samples. On the first access, load from the base dataset and store in a dict. On subsequent accesses, return from cache. Test with an image dataset.
  3. Balanced Sampler Dataset: Create a dataset wrapper that oversamples minority classes. Given a dataset with imbalanced classes, return samples such that each class has equal representation.

Solution Hints

  • CSV: Use pd.read_csv() in __init__, then index with self.df.iloc[idx] in __getitem__
  • Caching: Use a dict with index as key; check if idx in self.cache before loading
  • Balanced: Compute indices for each class, then oversample by repeating minority indices
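As a starting point for the balanced-sampler exercise, here is a pure-Python sketch of the index-oversampling hint (the toy label list is illustrative):

```python
import random
from collections import defaultdict

labels = [0] * 8 + [1] * 2  # imbalanced: 8 of class 0, 2 of class 1

# Group sample indices by class
by_class = defaultdict(list)
for i, y in enumerate(labels):
    by_class[y].append(i)

# Repeat minority indices until every class matches the majority count
target = max(len(ix) for ix in by_class.values())
balanced = []
for y, ix in by_class.items():
    balanced.extend(ix[i % len(ix)] for i in range(target))

random.shuffle(balanced)
print(len(balanced))  # 16 -- 8 per class
```

A wrapper dataset would then map its own index i to balanced[i] in __getitem__.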

Challenge Exercise

Multi-Modal Dataset: Build a dataset that loads synchronized image-text pairs for a vision-language model. Each sample should contain:

  • An image (as a tensor)
  • A text caption (as a list of token IDs)
  • A label indicating if the caption matches the image

Implement negative sampling: with 50% probability, return a mismatched image-caption pair with label=0.

Consider Edge Cases

What happens if your dataset has an odd number of samples and you try to create mismatched pairs? How do you ensure the same image-caption pair isn't accidentally matched?

In the next section, we'll dive deep into the DataLoader—the engine that turns your Dataset into batched, shuffled, parallelized training data.