Learning Objectives
By the end of this section, you will be able to:
- Understand the Dataset abstraction as the foundation of PyTorch's data loading system
- Implement custom datasets using the `__len__` and `__getitem__` protocol
- Choose between map-style and iterable-style datasets based on data access patterns
- Apply transforms to preprocess data on-the-fly during training
- Handle memory efficiently by loading data lazily rather than all at once
- Use built-in datasets from torchvision, torchaudio, and torchtext
Why This Matters: In deep learning, data is everything. A well-designed data pipeline can make the difference between a model that trains efficiently and one that bottlenecks on I/O. The Dataset class is PyTorch's elegant solution for organizing, accessing, and transforming your training data.
The Data Bottleneck
Modern GPUs can perform trillions of floating-point operations per second. A single forward pass through a neural network takes milliseconds. Yet training often takes hours or days. Why?
The answer is data loading. While your GPU waits, data must be:
- Read from disk (slow SSD/HDD access)
- Decoded (JPEG decompression, audio decoding)
- Preprocessed (resize, normalize, augment)
- Batched (combine samples into tensors)
- Transferred to GPU (PCIe bandwidth limited)
A naive approach—load all data into memory, then iterate—fails for several reasons:
| Dataset | Size on Disk | Loaded in Memory | Problem |
|---|---|---|---|
| MNIST | ~50 MB | ~200 MB | Fits in RAM ✓ |
| CIFAR-100 | ~180 MB | ~600 MB | Fits in RAM ✓ |
| ImageNet | ~150 GB | ~500 GB | Exceeds typical RAM ✗ |
| Common Crawl | ~250 TB | ~1 PB | Impossible to load ✗ |
PyTorch's solution is lazy loading: load and preprocess data only when needed, one sample (or batch) at a time. This is the job of the Dataset class.
The Key Insight
Instead of holding all data in memory, a Dataset holds metadata—just enough information to locate and load any single sample on demand. The actual pixels, audio samples, or text tokens stay on disk until the training loop requests them.
Metadata, not data! This is the magic of lazy evaluation.
The Dataset Abstraction
A Dataset represents a collection of samples, where each sample is typically a (features, label) pair. The abstraction requires answering just two questions:
- How many samples are there? → `__len__()`
- What is the i-th sample? → `__getitem__(i)`
That's it. With these two methods, PyTorch can:
- Iterate through all samples in order or randomly
- Create batches by grouping multiple samples
- Shuffle data at the start of each epoch
- Split data into train/validation/test sets
- Load data in parallel using multiple worker processes
Design Philosophy: Python's "dunder methods" (`__len__`, `__getitem__`) make Datasets feel like native Python sequences. You can use `len(dataset)` and `dataset[i]` just like with lists.
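All of the capabilities listed above reduce to index arithmetic over these two methods. A pure-Python sketch (the toy `Squares` dataset is illustrative, not from PyTorch) shows shuffling and batching built from nothing but `__len__` and `__getitem__`:

```python
import random

class Squares:
    """Toy dataset: sample i is the pair (i, i*i)."""

    def __len__(self):
        return 10

    def __getitem__(self, idx):
        if not 0 <= idx < len(self):
            raise IndexError(idx)  # out-of-range behaves like a list
        return idx, idx * idx

ds = Squares()
indices = list(range(len(ds)))       # all the machinery works on indices
random.shuffle(indices)              # "shuffle the data" = shuffle indices
batches = [indices[i:i + 4] for i in range(0, len(indices), 4)]  # batching
first_batch = [ds[i] for i in batches[0]]  # workers would split these lookups
print(len(batches), first_batch)
```

This is essentially what the DataLoader does under the hood, which the next section covers in detail.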
PyTorch's Dataset Class
PyTorch provides the torch.utils.data.Dataset abstract base class. Here's the minimal contract:
```python
from torch.utils.data import Dataset

class Dataset:
    """Abstract base class for all datasets.

    All datasets must subclass this and implement:
    - __len__: Return the total number of samples
    - __getitem__: Return a sample given an index
    """

    def __len__(self) -> int:
        """Return the total number of samples in the dataset."""
        raise NotImplementedError

    def __getitem__(self, index: int):
        """Return the sample at the given index.

        Args:
            index: Index of sample to retrieve (0 to len-1)

        Returns:
            Typically a tuple (features, label), but can be any structure
        """
        raise NotImplementedError
```
The Two Essential Methods
__len__(self) → int
Returns the total number of samples. This is called by len(dataset) and is used by the DataLoader to:
- Calculate the number of batches per epoch
- Generate random indices for shuffling
- Implement train/val/test splitting
__getitem__(self, index) → sample
Returns the sample at position index. This is called by dataset[i] and is where the actual data loading happens. The return value is typically a tuple:
```python
# Classification: (features, label)
dataset[0]  # → (tensor([0.5, 0.3, ...]), 7)

# Detection: (image, bounding_boxes)
dataset[0]  # → (tensor(...), [{"bbox": [...], "class": 2}])

# Language modeling: (input_ids, target_ids)
dataset[0]  # → (tensor([101, 2023, ...]), tensor([2023, 2003, ...]))
```
Return Type Flexibility
A `__getitem__` can return any structure your training loop expects, but the widely followed convention is (features, label) for supervised learning.

Interactive: Dataset Methods
The visualization below shows a dataset with 8 samples. Try calling __len__ to get the size, and __getitem__ with different indices to retrieve samples. Notice how invalid indices raise an error just like Python lists.
Python Implementation
```python
class MyDataset(Dataset):
    def __init__(self, data, labels):
        self.data = data
        self.labels = labels

    def __len__(self):
        return len(self.data)  # → 8

    def __getitem__(self, idx):
        return self.data[idx], self.labels[idx]
```
Quick Check
If a dataset has 1000 samples and you call dataset[1000], what happens?
Map-style vs Iterable-style Datasets
PyTorch supports two dataset paradigms:
Map-style Datasets (Most Common)
Map-style datasets implement __len__ and __getitem__. They support random access—you can jump to any sample directly by index. This enables:
- Shuffling: Randomly permute indices at epoch start
- Sampling: Over/under-sample classes for imbalanced data
- Parallel loading: Each worker loads samples at different indices
```python
from pathlib import Path

from PIL import Image
from torch.utils.data import Dataset

class MapStyleDataset(Dataset):
    """Supports random access via __getitem__."""

    def __init__(self, data_dir: str):
        self.samples = list(Path(data_dir).glob("*.jpg"))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        img = Image.open(self.samples[idx])
        return transform(img)  # `transform` defined elsewhere (e.g., torchvision)
```
Iterable-style Datasets
Some data sources don't support random access:
- Streaming data from a network connection
- Data generated on-the-fly (infinite datasets)
- Data stored in formats that must be read sequentially
For these cases, PyTorch provides IterableDataset:
```python
import requests
from torch.utils.data import IterableDataset

class StreamingDataset(IterableDataset):
    """Can only be iterated, not randomly accessed."""

    def __init__(self, url: str):
        self.url = url

    def __iter__(self):
        # Stream data from URL
        with requests.get(self.url, stream=True) as r:
            for line in r.iter_lines():
                yield process(line)  # `process` defined elsewhere

    # Note: No __len__ or __getitem__!
```

| Feature | Map-style (Dataset) | Iterable-style (IterableDataset) |
|---|---|---|
| Random access | ✓ dataset[i] | ✗ Sequential only |
| Shuffling | ✓ Built-in support | ✗ Must shuffle upstream |
| Length known | ✓ len(dataset) | ✗ Often unknown/infinite |
| Multi-worker | ✓ Easy (indices) | ⚠ Complex (data partitioning) |
| Use case | Files on disk | Streams, infinite data |
When to Use Each
Prefer map-style datasets (Dataset) for files on disk—this is 95% of use cases. Use iterable-style datasets only for streaming data or when you truly can't support random access.

Building Custom Datasets
Let's build several practical custom datasets, from simple to complex.
Example 1: In-Memory Dataset
For small datasets that fit in memory, loading all data upfront is fine:
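A minimal sketch of such a dataset, assuming features and labels arrive as preloaded tensors (the class name `InMemoryDataset` is illustrative):

```python
import torch
from torch.utils.data import Dataset

class InMemoryDataset(Dataset):
    """All samples live in RAM; __getitem__ is just tensor indexing."""

    def __init__(self, features: torch.Tensor, labels: torch.Tensor):
        assert len(features) == len(labels), "features/labels must align"
        self.features = features
        self.labels = labels

    def __len__(self) -> int:
        return len(self.features)

    def __getitem__(self, idx: int):
        return self.features[idx], self.labels[idx]

# Usage: 100 samples of 16 features each, labels in [0, 10)
dataset = InMemoryDataset(torch.randn(100, 16), torch.randint(0, 10, (100,)))
print(len(dataset))       # 100
features, label = dataset[0]
print(features.shape)     # torch.Size([16])
```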
Example 2: Lazy-Loading Image Dataset
For large image datasets, load images only when accessed:
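A sketch of the lazy-loading pattern, assuming PNG files on disk and a filename-to-class mapping (`LazyImageDataset` and the `labels` dict are illustrative names):

```python
from pathlib import Path

from PIL import Image
from torch.utils.data import Dataset

class LazyImageDataset(Dataset):
    """Holds only file paths; pixels are decoded on access."""

    def __init__(self, image_dir: str, labels: dict, transform=None):
        # Cheap to store: paths and labels only, no pixel data yet
        self.paths = sorted(Path(image_dir).glob("*.png"))
        self.labels = labels              # filename -> class index
        self.transform = transform

    def __len__(self) -> int:
        return len(self.paths)

    def __getitem__(self, idx: int):
        path = self.paths[idx]
        image = Image.open(path).convert("RGB")   # disk I/O happens here
        if self.transform:
            image = self.transform(image)
        return image, self.labels[path.name]
```

Swap the glob pattern or add torchvision transforms as needed; the key property is that `__init__` never touches pixel data.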
Why This Works for Large Datasets: only the file paths (a few dozen bytes each) live in memory. Pixels are read, transformed, and released one batch at a time, so memory usage scales with batch size, not dataset size.
Quick Check
Where does the actual disk I/O happen in a lazy-loading dataset?
Interactive: Data Pipeline Flow
Watch how data flows through the loading pipeline. The Dataset provides indexed access, and the DataLoader orchestrates shuffling, batching, and parallel loading.
Notice how shuffling reorders the indices at the start of each epoch, but doesn't move the actual data. The DataLoader requests samples in the shuffled order, and the Dataset loads each one on demand.
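The index-shuffling idea can be sketched in plain Python (the `samples` list stands in for files on disk):

```python
import random

samples = ["img0", "img1", "img2", "img3", "img4"]  # stand-ins for files on disk
indices = list(range(len(samples)))                 # what actually gets shuffled

random.seed(0)  # fixed seed so runs are reproducible
for epoch in range(2):
    random.shuffle(indices)                     # cheap: permutes integers only
    order = [samples[i] for i in indices]       # samples fetched on demand
    print(f"epoch {epoch}: {order}")
```

Note that `samples` itself is never reordered or copied; only the small index list moves.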
Transforms and Preprocessing
Transforms modify samples during loading. They're essential for:
- Standardization: Resize images to consistent dimensions
- Normalization: Scale pixel values to a standard range
- Data Augmentation: Random transformations to increase data diversity
- Format conversion: PIL Image → Tensor
Where to Apply Transforms
Transforms are typically passed to the Dataset and applied in __getitem__:
1class MyDataset(Dataset):
2 def __init__(self, paths, transform=None):
3 self.paths = paths
4 self.transform = transform
5
6 def __getitem__(self, idx):
7 image = load_image(self.paths[idx])
8
9 # Apply transform if provided
10 if self.transform:
11 image = self.transform(image)
12
13 return image, self.labels[idx]
14
15# Different transforms for train vs val
16train_transform = transforms.Compose([
17 transforms.RandomResizedCrop(224),
18 transforms.RandomHorizontalFlip(),
19 transforms.ToTensor(),
20 transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
21])
22
23val_transform = transforms.Compose([
24 transforms.Resize(256),
25 transforms.CenterCrop(224),
26 transforms.ToTensor(),
27 transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
28])
29
30train_dataset = MyDataset(train_paths, transform=train_transform)
31val_dataset = MyDataset(val_paths, transform=val_transform)Train vs Validation Transforms
Memory Considerations
Understanding memory usage is critical for efficient data loading:
Memory Usage Patterns
| Approach | Memory Usage | I/O Pattern | When to Use |
|---|---|---|---|
| All in RAM | O(N × size) | One-time load | Small datasets (< 1GB) |
| Lazy loading | O(batch × size) | Per-sample I/O | Large datasets |
| Memory-mapped | O(batch × size) | On-demand paging | Preprocessed data |
| Cached loading | O(cache_size × size) | First-epoch I/O | Repeated epochs, fast SSD |
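To make the "memory ∝ batch" point concrete, a quick back-of-envelope calculation (assuming 224×224 RGB images stored as float32 tensors and a batch size of 256):

```python
# Per-batch memory for lazily loaded images
# (assumed: 224x224 RGB, float32 = 4 bytes per value, batch of 256)
batch, h, w, c, bytes_per_val = 256, 224, 224, 3, 4
per_sample = h * w * c * bytes_per_val        # 602,112 bytes per image
per_batch_mib = batch * per_sample / 2**20    # convert bytes to MiB
print(f"{per_batch_mib:.0f} MiB per batch")   # 147 MiB, independent of dataset size
```

Whether the dataset holds ten thousand images or ten million, lazy loading keeps resident memory at this per-batch figure (times the number of prefetched batches).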
Memory-Mapped Files
For preprocessed data (e.g., extracted features), memory-mapped files offer the best of both worlds:
```python
import numpy as np
import torch
from torch.utils.data import Dataset

class MemmapDataset(Dataset):
    """Memory-mapped dataset for preprocessed features.

    Data appears to be in memory but is actually paged from disk
    by the OS as needed. Very efficient for sequential access.
    """

    def __init__(self, features_path: str, labels_path: str):
        # Memory-map the arrays (doesn't load into RAM)
        self.features = np.load(features_path, mmap_mode='r')
        self.labels = np.load(labels_path, mmap_mode='r')

    def __len__(self):
        return len(self.features)

    def __getitem__(self, idx):
        # OS handles paging from disk transparently
        return torch.from_numpy(self.features[idx].copy()), int(self.labels[idx])
```
Preprocessing Large Datasets
Common Dataset Patterns
Pattern 1: Train/Val/Test Splits
```python
import torch
from torch.utils.data import Dataset, random_split

# Create full dataset
full_dataset = MyDataset(all_data)

# Split with fixed seed for reproducibility
generator = torch.Generator().manual_seed(42)
train_size = int(0.8 * len(full_dataset))
val_size = int(0.1 * len(full_dataset))
test_size = len(full_dataset) - train_size - val_size

train_dataset, val_dataset, test_dataset = random_split(
    full_dataset,
    [train_size, val_size, test_size],
    generator=generator
)
```
Pattern 2: Dataset Subsets
```python
from torch.utils.data import Subset

# Create dataset with specific indices
train_indices = [i for i in range(len(dataset)) if i % 5 != 0]
val_indices = [i for i in range(len(dataset)) if i % 5 == 0]

train_subset = Subset(dataset, train_indices)
val_subset = Subset(dataset, val_indices)

# Subset wraps original dataset, no data copying!
print(len(train_subset))  # 80% of original
```
Pattern 3: Concatenating Datasets
```python
from torch.utils.data import ConcatDataset

# Combine multiple data sources
dataset_a = ImageDataset("data/source_a")
dataset_b = ImageDataset("data/source_b")
dataset_c = ImageDataset("data/source_c")

combined = ConcatDataset([dataset_a, dataset_b, dataset_c])
print(len(combined))  # len(a) + len(b) + len(c)
```
Pattern 4: Label Filtering
```python
from torch.utils.data import Dataset

class FilteredDataset(Dataset):
    """Dataset filtered to only include certain labels."""

    def __init__(self, base_dataset: Dataset, keep_labels: set):
        self.base = base_dataset
        self.indices = [
            i for i in range(len(base_dataset))
            if base_dataset[i][1] in keep_labels
        ]

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        return self.base[self.indices[idx]]

# Keep only cats and dogs from a multi-class dataset
binary_dataset = FilteredDataset(full_dataset, keep_labels={0, 1})
```
PyTorch Built-in Datasets
PyTorch provides ready-to-use datasets for common benchmarks:
| Package | Datasets | Examples |
|---|---|---|
| torchvision.datasets | Computer vision | MNIST, CIFAR-10/100, ImageNet, COCO |
| torchaudio.datasets | Audio processing | LibriSpeech, VCTK, CommonVoice |
| torchtext.datasets | NLP | IMDB, WikiText, AG_NEWS |
```python
from torchvision.datasets import MNIST, CIFAR10
from torchvision import transforms

# MNIST: Handwritten digits
mnist_train = MNIST(
    root="./data",
    train=True,
    download=True,  # Downloads if not present
    transform=transforms.ToTensor()
)

print(f"MNIST: {len(mnist_train)} training samples")
image, label = mnist_train[0]
print(f"Shape: {image.shape}, Label: {label}")

# CIFAR-10: 10-class image classification
cifar_train = CIFAR10(
    root="./data",
    train=True,
    download=True,
    transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
)

print(f"CIFAR-10: {len(cifar_train)} training samples")
```
Use Built-in Datasets for Prototyping
Summary
The Dataset class is PyTorch's abstraction for organizing and accessing training data. It enables efficient, memory-conscious data loading for datasets of any size.
Key Concepts
| Concept | Key Insight | Implementation |
|---|---|---|
| Dataset Protocol | Answer two questions: how many? and what is i? | __len__ and __getitem__ |
| Lazy Loading | Load data only when accessed | Store paths, load in __getitem__ |
| Map-style vs Iterable | Random access vs streaming | Dataset vs IterableDataset |
| Transforms | Preprocess on-the-fly | Pass to __init__, apply in __getitem__ |
| Memory Efficiency | Memory ∝ batch_size, not dataset_size | Lazy loading + parallel workers |
Best Practices
- Start with built-in datasets for prototyping, then swap in custom data
- Use lazy loading for datasets larger than available RAM
- Apply augmentation in transforms for on-the-fly data diversity
- Validate dataset length matches between features and labels
- Use different transforms for training (random) vs validation (deterministic)
Looking Ahead
The Dataset class provides the "what"—accessing individual samples. In the next section, we'll explore the DataLoader—the "how" of batching, shuffling, and parallel loading that turns a Dataset into an efficient data pipeline.
Exercises
Conceptual Questions
- Why does PyTorch use `__len__` and `__getitem__` instead of requiring datasets to inherit from a list? What benefits does this abstraction provide?
- Explain why lazy loading is essential for large datasets. What would happen if you tried to load all of ImageNet (150 GB) into memory on a typical machine?
- When would you choose an IterableDataset over a regular Dataset? Give two real-world examples.
- Why should validation transforms be deterministic while training transforms can be random?
Coding Exercises
- CSV Dataset: Implement a Dataset that loads tabular data from a CSV file. The `__init__` should read the file and store it in a pandas DataFrame. `__getitem__` should return a (features, label) tuple where features is a tensor and label is an integer.
- Caching Dataset: Create a wrapper dataset that caches loaded samples. On the first access, load from the base dataset and store in a dict. On subsequent accesses, return from cache. Test with an image dataset.
- Balanced Sampler Dataset: Create a dataset wrapper that oversamples minority classes. Given a dataset with imbalanced classes, return samples such that each class has equal representation.
Solution Hints
- CSV: Use `pd.read_csv()` in `__init__`, then index with `self.df.iloc[idx]` in `__getitem__`
- Caching: Use a dict with index as key; check `if idx in self.cache` before loading
- Balanced: Compute indices for each class, then oversample by repeating minority indices
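Following the caching hint, here is one possible sketch of the wrapper. Plain classes are used so it runs standalone (in practice you would subclass `torch.utils.data.Dataset`); the `CountingBase` class exists only to demonstrate that each index is loaded once:

```python
class CountingBase:
    """Toy base dataset that counts how many loads actually happen."""

    def __init__(self, data):
        self.data = data
        self.loads = 0

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        self.loads += 1              # pretend this is expensive disk I/O
        return self.data[idx]

class CachedDataset:
    """Memoizes samples: load from base once, then serve from a dict."""

    def __init__(self, base):
        self.base = base
        self.cache = {}

    def __len__(self):
        return len(self.base)

    def __getitem__(self, idx):
        if idx not in self.cache:    # only the first access hits the base
            self.cache[idx] = self.base[idx]
        return self.cache[idx]

base = CountingBase([10, 20, 30])
cached = CachedDataset(base)
print(cached[1], cached[1], base.loads)  # 20 20 1 (second access hits the cache)
```

Note this caches whole decoded samples, so it only pays off when the cache fits in RAM; for augmented training data, cache the raw sample and apply random transforms after retrieval.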
Challenge Exercise
Multi-Modal Dataset: Build a dataset that loads synchronized image-text pairs for a vision-language model. Each sample should contain:
- An image (as a tensor)
- A text caption (as a list of token IDs)
- A label indicating if the caption matches the image
Implement negative sampling: with 50% probability, return a mismatched image-caption pair with label=0.
Consider Edge Cases
In the next section, we'll dive deep into the DataLoader—the engine that turns your Dataset into batched, shuffled, parallelized training data.