Learning Objectives
By the end of this section, you will be able to:
- Build custom datasets for any data format including CSV, JSON, HDF5, and multi-modal data
- Handle complex data structures with nested labels, variable-length sequences, and synchronized multi-modal samples
- Implement robust error handling to gracefully handle corrupted files and missing data
- Debug dataset issues systematically using logging, assertions, and validation techniques
- Optimize dataset performance with caching, memory mapping, and parallel preprocessing
- Access remote data from cloud storage (S3, GCS) and HTTP endpoints efficiently
Why This Matters: Real-world data rarely comes in the neat format of MNIST or CIFAR. You'll encounter CSVs with millions of rows, JSON logs with nested structures, image-caption pairs from different sources, and proprietary formats unique to your organization. Mastering custom datasets is what separates toy experiments from production-ready pipelines.
The Big Picture
The Real-World Data Landscape
In production machine learning, data comes in many forms:
| Data Type | Common Formats | Challenges |
|---|---|---|
| Tabular | CSV, Parquet, SQL databases | Mixed types, missing values, large files |
| Images | JPEG, PNG, DICOM, RAW | Variable sizes, metadata, corrupt files |
| Text | JSON, XML, plain text, tokenized | Variable length, encoding issues, vocabulary |
| Audio | WAV, MP3, FLAC | Sample rates, duration, stereo/mono |
| Video | MP4, AVI, frame sequences | Temporal alignment, massive size |
| Multi-modal | Image+text, video+audio | Synchronization, missing modalities |
| Structured | HDF5, NetCDF, Zarr | Hierarchical, lazy loading required |
PyTorch's Dataset class provides a unified interface to wrap all these formats. The key insight is that a Dataset is just a bridge between your data storage and the training loop.
The Dataset Design Philosophy
When designing a custom dataset, think in terms of three questions:
- Discovery: What files/records exist? (computed once in __init__)
- Access: How do I load sample i? (implemented in __getitem__)
- Count: How many samples are there? (implemented in __len__)
The Dataset maps integer indices to samples. Everything else—batching, shuffling, parallel loading—is handled by the DataLoader.
Why Custom Datasets?
PyTorch provides many built-in datasets (MNIST, CIFAR, ImageNet, etc.), but you'll need custom datasets when:
- Your data format is unique: Proprietary sensors, custom logging formats, legacy systems
- Your labels are complex: Bounding boxes, segmentation masks, multiple annotations per sample
- Your data spans multiple sources: Images from S3, labels from PostgreSQL, metadata from JSON
- You need custom preprocessing: Domain-specific normalization, on-the-fly augmentation, feature engineering
- Memory constraints require lazy loading: Data too large to fit in RAM
Key Insight: Think of a Dataset as an adapter. It adapts your specific data format to the generic interface that PyTorch expects. Write the adapter once, and the entire PyTorch ecosystem (DataLoader, DistributedSampler, etc.) works with your data.
The Dataset Contract
Every custom dataset must satisfy a simple contract:
from torch.utils.data import Dataset
from typing import Any, Tuple

class CustomDataset(Dataset):
    """Base template for all custom datasets."""

    def __init__(self, data_source: Any, transform=None, **kwargs):
        """Initialize the dataset.

        This is where you:
        1. Parse metadata (file paths, indices, labels)
        2. Store configuration (transforms, options)
        3. Optionally validate the data source

        IMPORTANT: Do NOT load actual data here for large datasets!
        """
        self.data_source = data_source
        self.transform = transform
        # Discover samples: build list of (path, label) or similar
        self.samples = self._discover_samples()

    def _discover_samples(self) -> list:
        """Scan data source and return list of sample identifiers."""
        raise NotImplementedError

    def __len__(self) -> int:
        """Return total number of samples."""
        return len(self.samples)

    def __getitem__(self, idx: int) -> Tuple[Any, Any]:
        """Load and return sample at index idx.

        This is where you:
        1. Load data from disk/database/network
        2. Apply transforms
        3. Return (features, label) tuple
        """
        # Load sample
        sample = self._load_sample(self.samples[idx])

        # Apply transform if provided
        if self.transform:
            sample = self.transform(sample)

        return sample

    def _load_sample(self, sample_id: Any) -> Tuple[Any, Any]:
        """Load a single sample given its identifier."""
        raise NotImplementedError

The Lazy Loading Principle
The most important design principle for datasets is lazy loading:
| Method | What To Do | What NOT To Do |
|---|---|---|
| __init__ | Store paths, indices, metadata | Load images, videos, large arrays |
| __len__ | Return cached length | Scan filesystem, count records |
| __getitem__ | Load single sample on-demand | Load multiple samples, cache everything |
Lazy loading ensures memory usage scales with batch size, not dataset size:
Memory ∝ batch_size × sample_size, not N × sample_size, where N is the dataset size.
CSV and Tabular Data
Tabular data is ubiquitous in machine learning. Let's build a robust CSV dataset.
Basic CSV Dataset
Large CSV with Lazy Loading
For CSVs too large to fit in memory, use lazy loading:
import pandas as pd
import torch
from torch.utils.data import Dataset

class LargeCSVDataset(Dataset):
    """Memory-efficient dataset for large CSV files.

    Uses chunked reading and lazy loading to handle files
    larger than available RAM.
    """

    def __init__(self, csv_path: str, chunksize: int = 10000):
        self.csv_path = csv_path

        # Count total rows without loading entire file
        with open(csv_path) as f:
            self.length = sum(1 for _ in f) - 1  # -1 for header

        # Read header to get column names
        self.columns = pd.read_csv(csv_path, nrows=0).columns.tolist()

        # For truly large files, use dask or vaex instead
        # self.df = dask.dataframe.read_csv(csv_path)

    def __len__(self) -> int:
        return self.length

    def __getitem__(self, idx: int) -> torch.Tensor:
        # Read only the row we need
        # Note: This is slow for random access!
        # Consider using parquet or HDF5 for large datasets
        row = pd.read_csv(
            self.csv_path,
            skiprows=range(1, idx + 1),  # Skip previous data rows, keep header
            nrows=1,
            header=0,
            names=self.columns
        )
        return torch.tensor(row.values[0], dtype=torch.float32)

Random Access Performance
Because each __getitem__ re-scans the file from the top, random access is O(n) per sample. For large datasets, prefer a format with true random access, such as Parquet or HDF5.
Quick Check
Why do we convert features to tensors in __init__ for the basic CSVDataset, but not for the LargeCSVDataset?
JSON and Nested Data
JSON is common for web data, API responses, and annotation files. Its nested structure requires careful handling.
COCO-Style Annotation Dataset
The COCO dataset format is widely used for object detection and segmentation. Let's implement a dataset that loads it:
JSONL (JSON Lines) for Streaming Data
For large datasets, JSONL format (one JSON object per line) is more efficient:
import json
import torch
from torch.utils.data import Dataset

class JSONLDataset(Dataset):
    """Dataset for JSON Lines format (one JSON object per line).

    Efficient for large datasets as we can seek to specific lines
    without loading the entire file.
    """

    def __init__(self, jsonl_path: str):
        self.jsonl_path = jsonl_path

        # Build index of line byte offsets for random access
        self.line_offsets = [0]
        with open(jsonl_path, 'rb') as f:
            for line in f:
                self.line_offsets.append(f.tell())
        self.line_offsets.pop()  # Remove last offset (EOF)

        print(f"Indexed {len(self.line_offsets)} records")

    def __len__(self) -> int:
        return len(self.line_offsets)

    def __getitem__(self, idx: int) -> dict:
        with open(self.jsonl_path, 'rb') as f:
            f.seek(self.line_offsets[idx])
            line = f.readline()
        return json.loads(line)


# For text classification
class TextClassificationDataset(JSONLDataset):
    """JSONL dataset for text classification.

    Expected format: {"text": "...", "label": 0}
    """

    def __init__(self, jsonl_path: str, tokenizer, max_length: int = 512):
        super().__init__(jsonl_path)
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __getitem__(self, idx: int) -> dict:
        record = super().__getitem__(idx)

        # Tokenize text
        encoding = self.tokenizer(
            record['text'],
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )

        return {
            'input_ids': encoding['input_ids'].squeeze(0),
            'attention_mask': encoding['attention_mask'].squeeze(0),
            'labels': torch.tensor(record['label'], dtype=torch.long)
        }

JSONL Indexing
Because we build an index of line byte offsets in __init__, we can seek() directly to any line in __getitem__. This gives O(1) random access instead of O(n) sequential scanning. This technique works for any line-oriented format.
Multi-Modal Datasets
Multi-modal learning combines different data types—images with captions, audio with transcripts, video with text descriptions. The key challenge is synchronization: ensuring corresponding modalities are correctly paired.
Image-Caption Dataset
Video-Text Dataset
Video datasets add temporal complexity. Here's a pattern for video-text pairs:
import torch
from torch.utils.data import Dataset
import cv2
import numpy as np
from pathlib import Path

class VideoTextDataset(Dataset):
    """Dataset for video-text pairs (e.g., video captioning)."""

    def __init__(
        self,
        video_dir: str,
        annotations: dict,  # {video_id: text}
        num_frames: int = 16,
        frame_size: tuple = (224, 224),
        tokenizer=None,
    ):
        self.video_dir = Path(video_dir)
        self.annotations = annotations
        self.video_ids = list(annotations.keys())
        self.num_frames = num_frames
        self.frame_size = frame_size
        self.tokenizer = tokenizer

    def __len__(self) -> int:
        return len(self.video_ids)

    def __getitem__(self, idx: int):
        video_id = self.video_ids[idx]
        video_path = self.video_dir / f"{video_id}.mp4"

        # Load video frames
        frames = self._load_video_frames(video_path)

        # Get text annotation
        text = self.annotations[video_id]

        if self.tokenizer:
            tokens = self.tokenizer(
                text,
                padding='max_length',
                truncation=True,
                return_tensors='pt'
            )
            return frames, tokens['input_ids'].squeeze(0)

        return frames, text

    def _load_video_frames(self, video_path: Path) -> torch.Tensor:
        """Load evenly-spaced frames from a video."""
        cap = cv2.VideoCapture(str(video_path))
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

        # Sample frame indices evenly
        indices = np.linspace(0, total_frames - 1, self.num_frames, dtype=int)

        frames = []
        for frame_idx in indices:
            cap.set(cv2.CAP_PROP_POS_FRAMES, int(frame_idx))
            ret, frame = cap.read()
            if ret:
                # BGR to RGB, resize
                frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                frame = cv2.resize(frame, self.frame_size)
                frames.append(frame)
            else:
                # Pad with zeros if frame read fails
                frames.append(np.zeros((*self.frame_size, 3), dtype=np.uint8))

        cap.release()

        # Stack and convert: [T, H, W, C] -> [C, T, H, W]
        frames = np.stack(frames)
        frames = torch.from_numpy(frames).float() / 255.0
        frames = frames.permute(3, 0, 1, 2)  # [C, T, H, W]

        return frames

Video Loading Performance
Decoding video on the fly is expensive. To speed it up, consider: (1) Pre-extracting frames to disk offline, (2) Using a faster video reader such as decord or torchvision.io, (3) Loading from SSD, not HDD.
Sequence and Time Series Data
Sequence data requires special handling for variable lengths and temporal structure.
Time Series Dataset
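A common pattern is to slice sliding windows from one long series, with each sample pairing an input window with the steps that follow it. A sketch (window sizes are illustrative defaults):

```python
import torch
from torch.utils.data import Dataset

class TimeSeriesDataset(Dataset):
    """Sliding-window dataset over a single multivariate series.

    Each sample is (window of `input_length` steps, next `horizon` steps).
    """

    def __init__(self, series, input_length=24, horizon=1):
        # series: [T, num_features] array or tensor
        self.series = torch.as_tensor(series, dtype=torch.float32)
        self.input_length = input_length
        self.horizon = horizon

    def __len__(self):
        # Number of complete (input, target) windows
        return len(self.series) - self.input_length - self.horizon + 1

    def __getitem__(self, idx):
        end = idx + self.input_length
        x = self.series[idx:end]                 # Input window
        y = self.series[end:end + self.horizon]  # Forecast target
        return x, y
```

Because windows are computed from the index, nothing is duplicated in memory: a series of T steps yields T - input_length - horizon + 1 overlapping samples from a single stored tensor.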
Variable-Length Sequences with Padding
import torch
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence

class VariableLengthDataset(Dataset):
    """Dataset for variable-length sequences.

    Returns sequences as-is; padding is done in collate_fn.
    """

    def __init__(self, sequences: list, labels: list):
        self.sequences = sequences  # List of tensors with different lengths
        self.labels = labels

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, idx):
        return self.sequences[idx], self.labels[idx]


def collate_variable_length(batch):
    """Custom collate function for variable-length sequences.

    Pads all sequences to the length of the longest in the batch.
    """
    sequences, labels = zip(*batch)

    # Get lengths before padding
    lengths = torch.tensor([len(seq) for seq in sequences])

    # Pad sequences to same length
    padded = pad_sequence(sequences, batch_first=True, padding_value=0)

    # Stack labels
    labels = torch.stack(labels)

    return padded, labels, lengths


# Usage
sequences = [
    torch.randn(10, 32),  # 10 timesteps
    torch.randn(25, 32),  # 25 timesteps
    torch.randn(15, 32),  # 15 timesteps
]
labels = [torch.tensor(0), torch.tensor(1), torch.tensor(0)]

dataset = VariableLengthDataset(sequences, labels)

# Must use custom collate_fn
loader = DataLoader(
    dataset,
    batch_size=3,
    collate_fn=collate_variable_length
)

for padded, labels, lengths in loader:
    print(f"Padded shape: {padded.shape}")  # [3, 25, 32] (padded to longest)
    print(f"Lengths: {lengths}")            # tensor([10, 25, 15])

Pack Padded Sequences
When feeding padded batches to an RNN, pass lengths to pack_padded_sequence() so the RNN ignores padding tokens. This is more efficient and prevents the model from learning padding patterns.
HDF5 and Large Files
HDF5 (Hierarchical Data Format) is designed for large numerical datasets. It supports lazy loading, compression, and efficient random access.
HDF5 and DataLoader Workers
HDF5 file handles cannot be shared across processes. With num_workers > 0, each worker must open its own file handle. The _get_file() pattern handles this by opening the file lazily in each worker process.
Quick Check
Why does the HDF5Dataset open the file lazily in __getitem__ rather than in __init__?
Remote and Cloud Data
Training data often lives in cloud storage (S3, GCS, Azure Blob) or behind HTTP APIs. Here's how to build datasets that access remote data efficiently.
S3 Dataset with Smart Caching
import torch
from torch.utils.data import Dataset
import boto3
from PIL import Image
import io
from pathlib import Path
import hashlib
from typing import Optional

class S3ImageDataset(Dataset):
    """Dataset that loads images from Amazon S3 with local caching.

    Caching strategy:
    - First access: download from S3, save to local cache
    - Subsequent accesses: load from local cache
    - Cache persists across training runs
    """

    def __init__(
        self,
        bucket: str,
        prefix: str,
        labels: dict,  # {s3_key: label}
        cache_dir: Optional[str] = None,
        transform=None,
    ):
        self.bucket = bucket
        self.labels = labels
        self.keys = list(labels.keys())
        self.transform = transform

        # Set up S3 client
        # NOTE: With num_workers > 0, prefer lazy client creation as in
        # S3StreamingDataset below -- boto3 clients aren't fork-safe
        self.s3 = boto3.client('s3')

        # Set up local cache
        if cache_dir:
            self.cache_dir = Path(cache_dir)
            self.cache_dir.mkdir(parents=True, exist_ok=True)
        else:
            self.cache_dir = None

        print(f"S3 dataset: {len(self.keys)} images from s3://{bucket}/{prefix}")

    def __len__(self) -> int:
        return len(self.keys)

    def _get_cache_path(self, s3_key: str) -> Path:
        """Generate cache path from S3 key."""
        # Hash the key to handle special characters and long paths
        key_hash = hashlib.md5(s3_key.encode()).hexdigest()
        extension = Path(s3_key).suffix
        return self.cache_dir / f"{key_hash}{extension}"

    def __getitem__(self, idx: int):
        s3_key = self.keys[idx]
        label = self.labels[s3_key]

        # Try local cache first
        if self.cache_dir:
            cache_path = self._get_cache_path(s3_key)
            if cache_path.exists():
                image = Image.open(cache_path).convert('RGB')
            else:
                # Download from S3
                image = self._download_from_s3(s3_key)
                # Save to cache
                image.save(cache_path)
        else:
            # No caching, always download
            image = self._download_from_s3(s3_key)

        if self.transform:
            image = self.transform(image)

        return image, torch.tensor(label, dtype=torch.long)

    def _download_from_s3(self, s3_key: str) -> Image.Image:
        """Download image from S3 into memory."""
        response = self.s3.get_object(Bucket=self.bucket, Key=s3_key)
        image_bytes = response['Body'].read()
        return Image.open(io.BytesIO(image_bytes)).convert('RGB')


# For truly streaming (no local storage), use:
class S3StreamingDataset(Dataset):
    """Dataset that streams directly from S3 without caching.

    Use when: cache storage is limited, data changes frequently.
    """

    def __init__(self, bucket: str, keys: list, labels: list, transform=None):
        self.bucket = bucket
        self.keys = keys
        self.labels = labels
        self.transform = transform
        self.s3 = None  # Lazy initialization for multiprocessing

    def _get_s3_client(self):
        if self.s3 is None:
            self.s3 = boto3.client('s3')
        return self.s3

    def __len__(self):
        return len(self.keys)

    def __getitem__(self, idx):
        s3 = self._get_s3_client()
        response = s3.get_object(Bucket=self.bucket, Key=self.keys[idx])
        image_bytes = response['Body'].read()
        image = Image.open(io.BytesIO(image_bytes)).convert('RGB')

        if self.transform:
            image = self.transform(image)

        return image, torch.tensor(self.labels[idx])

HTTP-Based Dataset
import torch
from torch.utils.data import Dataset
import requests
from PIL import Image
import io
from typing import List
import time

class HTTPImageDataset(Dataset):
    """Dataset that fetches images from HTTP URLs.

    Includes retry logic with exponential backoff for robustness.
    """

    def __init__(
        self,
        urls: List[str],
        labels: List[int],
        transform=None,
        max_retries: int = 3,
        retry_delay: float = 1.0,
        timeout: float = 10.0,
    ):
        self.urls = urls
        self.labels = labels
        self.transform = transform
        self.max_retries = max_retries
        self.retry_delay = retry_delay
        self.timeout = timeout

        # Create session for connection pooling
        self.session = requests.Session()

    def __len__(self):
        return len(self.urls)

    def __getitem__(self, idx):
        url = self.urls[idx]
        label = self.labels[idx]

        # Retry loop
        for attempt in range(self.max_retries):
            try:
                response = self.session.get(url, timeout=self.timeout)
                response.raise_for_status()

                image = Image.open(io.BytesIO(response.content)).convert('RGB')

                if self.transform:
                    image = self.transform(image)

                return image, torch.tensor(label, dtype=torch.long)

            except Exception as e:
                if attempt < self.max_retries - 1:
                    # Exponential backoff between retries
                    time.sleep(self.retry_delay * (2 ** attempt))
                else:
                    # Return placeholder on final failure
                    print(f"Failed to load {url}: {e}")
                    placeholder = Image.new('RGB', (224, 224), color='gray')
                    if self.transform:
                        placeholder = self.transform(placeholder)
                    return placeholder, torch.tensor(label)

Network Dataset Best Practices
- Always cache when possible—network I/O is 100-1000x slower than disk
- Use connection pooling (requests.Session) to reuse TCP connections
- Implement retries with exponential backoff for transient failures
- Return placeholders on failure rather than crashing the training loop
The Data Pipeline
Here is how custom datasets integrate into the PyTorch data pipeline: your Dataset provides samples to the DataLoader, which handles batching, shuffling, and parallel loading.
Key observations:
- Dataset provides indexed access to individual samples
- DataLoader orchestrates batching and shuffling
- Worker processes load data in parallel to keep the GPU busy
- Your custom __getitem__ runs in worker processes, not the main process
Error Handling and Debugging
Real datasets have corrupted files, missing data, and edge cases. Robust error handling is essential.
Graceful Error Handling
Debugging Tips
# 1. Test a single sample first
dataset = MyDataset(...)
sample = dataset[0]
print(f"Sample type: {type(sample)}")
print(f"Sample shapes: {[x.shape if hasattr(x, 'shape') else x for x in sample]}")

# 2. Test boundary conditions
print(f"First sample: {dataset[0]}")
print(f"Last sample: {dataset[len(dataset) - 1]}")

# Try out-of-bounds (should raise IndexError)
try:
    dataset[len(dataset)]
except IndexError:
    print("IndexError raised correctly!")

# 3. Test with small DataLoader
from torch.utils.data import DataLoader

loader = DataLoader(dataset, batch_size=4, num_workers=0)
batch = next(iter(loader))
print(f"Batch shapes: {[x.shape for x in batch]}")

# 4. Test with multiple workers (catches multiprocessing issues)
loader = DataLoader(dataset, batch_size=4, num_workers=2)
for i, batch in enumerate(loader):
    if i >= 3:
        break
    print(f"Batch {i} loaded successfully")

# 5. Profile loading time
import time
times = []
for i in range(100):
    start = time.time()
    _ = dataset[i]
    times.append(time.time() - start)

print(f"Mean load time: {sum(times)/len(times)*1000:.2f}ms")
print(f"Max load time: {max(times)*1000:.2f}ms")

Testing Your Datasets
Datasets are critical infrastructure. Test them thoroughly:
import pytest
import torch
from my_dataset import MyCustomDataset

# Placeholders: test_data_path, expected_length, expected_feature_shape,
# and num_classes would come from your test configuration.

class TestMyCustomDataset:
    """Unit tests for custom dataset."""

    @pytest.fixture
    def dataset(self):
        """Create dataset for testing."""
        return MyCustomDataset(test_data_path, transform=None)

    def test_length(self, dataset):
        """Dataset reports correct length."""
        assert len(dataset) > 0
        assert len(dataset) == expected_length

    def test_getitem_returns_tuple(self, dataset):
        """__getitem__ returns (features, label) tuple."""
        sample = dataset[0]
        assert isinstance(sample, tuple)
        assert len(sample) == 2

    def test_feature_shape(self, dataset):
        """Features have expected shape."""
        features, _ = dataset[0]
        assert features.shape == expected_feature_shape

    def test_feature_dtype(self, dataset):
        """Features are float tensors."""
        features, _ = dataset[0]
        assert features.dtype == torch.float32

    def test_label_dtype(self, dataset):
        """Labels are integer tensors."""
        _, label = dataset[0]
        assert label.dtype == torch.long

    def test_label_range(self, dataset):
        """Labels are within valid range."""
        for i in range(min(100, len(dataset))):
            _, label = dataset[i]
            assert 0 <= label < num_classes

    def test_reproducibility(self, dataset):
        """Same index returns same sample (no side effects)."""
        sample1 = dataset[0]
        sample2 = dataset[0]
        assert torch.equal(sample1[0], sample2[0])
        assert sample1[1] == sample2[1]

    def test_boundary_indices(self, dataset):
        """First and last indices work correctly."""
        _ = dataset[0]
        _ = dataset[len(dataset) - 1]

    def test_out_of_bounds(self, dataset):
        """Out-of-bounds index raises IndexError."""
        with pytest.raises(IndexError):
            _ = dataset[len(dataset)]

    def test_negative_index(self, dataset):
        """Negative indices work (Python convention)."""
        sample = dataset[-1]
        last_sample = dataset[len(dataset) - 1]
        assert torch.equal(sample[0], last_sample[0])

    def test_with_dataloader(self, dataset):
        """Dataset works with DataLoader."""
        from torch.utils.data import DataLoader
        loader = DataLoader(dataset, batch_size=4, num_workers=0)
        batch = next(iter(loader))
        assert batch[0].shape[0] == 4  # Batch size

    def test_with_multiprocessing(self, dataset):
        """Dataset works with multiple workers."""
        from torch.utils.data import DataLoader
        loader = DataLoader(dataset, batch_size=4, num_workers=2)
        # Load a few batches to verify multiprocessing works
        for i, batch in enumerate(loader):
            if i >= 3:
                break

Test Data Fixtures
Rather than testing against production data, generate small synthetic fixtures (a handful of samples in a temporary directory) so tests run fast and deterministically.
Performance Optimization
Data loading can bottleneck training. Here are optimization strategies:
Optimization Strategies
| Strategy | When to Use | Speedup |
|---|---|---|
| num_workers > 0 | Always (CPU not saturated) | 2-8x |
| pin_memory=True | Using GPU | 1.1-1.3x |
| prefetch_factor > 2 | Slow __getitem__ | Variable |
| SSD storage | Lazy loading from disk | 5-20x vs HDD |
| Memory mapping | Large preprocessed data | Near-RAM speed |
| Pre-caching | Small enough to cache | 10-100x |
Optimization Example
from torch.utils.data import DataLoader

# Baseline: Single-threaded, no optimization
slow_loader = DataLoader(dataset, batch_size=32)

# Optimized: Multi-worker, pinned memory, prefetching
fast_loader = DataLoader(
    dataset,
    batch_size=32,
    shuffle=True,
    num_workers=4,            # Parallel data loading
    pin_memory=True,          # Faster GPU transfer
    prefetch_factor=2,        # Prefetch 2 batches per worker
    persistent_workers=True   # Don't respawn workers each epoch
)

# For GPU training, monitor utilization
# If GPU utilization < 90%, data loading is likely the bottleneck

# Profiling data loading
import time

def benchmark_loader(loader, num_batches=100):
    start = time.time()
    for i, batch in enumerate(loader):
        if i >= num_batches:
            break
        # Simulate model forward pass
        _ = batch[0].cuda()
    elapsed = time.time() - start
    throughput = num_batches * loader.batch_size / elapsed
    print(f"Throughput: {throughput:.1f} samples/sec")

benchmark_loader(slow_loader)
benchmark_loader(fast_loader)

Memory-Efficient Caching
from torch.utils.data import Dataset
from functools import lru_cache

class CachedDataset(Dataset):
    """Dataset with LRU caching for frequently accessed samples."""

    def __init__(self, base_dataset, cache_size=10000):
        self.base = base_dataset
        # NOTE: lru_cache's maxsize is fixed at class-definition time,
        # so cache_size here is informational only
        self.cache_size = cache_size

    def __len__(self):
        return len(self.base)

    @lru_cache(maxsize=10000)
    def __getitem__(self, idx):
        # Warning: This caches transformed data!
        # Don't use with random augmentations
        return self.base[idx]

    def clear_cache(self):
        """Clear cache between epochs if using augmentation."""
        self.__getitem__.cache_clear()


# For random augmentation, cache raw data only
class SmartCachedDataset(Dataset):
    """Cache raw data, apply transforms on-the-fly."""

    def __init__(self, base_dataset, cache_size=10000):
        self.base = base_dataset
        self.transform = base_dataset.transform
        self.base.transform = None  # Disable transform in base
        self.cache = {}
        self.cache_size = cache_size

    def __len__(self):
        return len(self.base)

    def __getitem__(self, idx):
        if idx not in self.cache:
            if len(self.cache) >= self.cache_size:
                # Simple eviction: drop the oldest-inserted entry
                self.cache.pop(next(iter(self.cache)))
            self.cache[idx] = self.base[idx]

        sample = self.cache[idx]

        # Apply transform (with randomness) each time
        if self.transform:
            features, label = sample
            features = self.transform(features)
            return features, label
        return sample

Caching and Augmentation
Never cache samples after random augmentation, or every epoch will see identical "augmented" data. Cache raw samples and apply random transforms after retrieval, as SmartCachedDataset does.
Common Patterns and Recipes
Pattern 1: Preprocessing Pipeline
# Step 1: Preprocess raw data to optimized format (run once)
def preprocess_dataset(raw_dir, output_path):
    """Convert raw images to preprocessed HDF5."""
    import h5py
    import numpy as np
    from PIL import Image
    from pathlib import Path

    images = list(Path(raw_dir).glob("**/*.jpg"))

    with h5py.File(output_path, 'w') as f:
        features = f.create_dataset(
            'features',
            shape=(len(images), 3, 224, 224),
            dtype='float32',
            compression='gzip'
        )

        for i, img_path in enumerate(images):
            img = Image.open(img_path).convert('RGB')
            img = img.resize((224, 224))
            arr = np.array(img).transpose(2, 0, 1) / 255.0
            features[i] = arr

            if i % 1000 == 0:
                print(f"Processed {i}/{len(images)}")

# Step 2: Fast training with preprocessed data
dataset = HDF5Dataset("preprocessed.h5")

Pattern 2: Multi-GPU Dataset Sharding
from torch.utils.data import Dataset, DataLoader

class ShardedDataset(Dataset):
    """Dataset that loads a specific shard for distributed training."""

    def __init__(self, shard_paths: list, shard_idx: int, num_shards: int):
        # Each GPU loads different shards
        self.my_shards = shard_paths[shard_idx::num_shards]
        self.samples = self._load_shards()

    def _load_shards(self):
        samples = []
        for shard_path in self.my_shards:
            # load_shard is a placeholder for your shard-reading logic
            samples.extend(load_shard(shard_path))
        return samples

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        return self.samples[idx]


# Alternative: Use DistributedSampler (simpler)
from torch.utils.data.distributed import DistributedSampler

dataset = MyDataset(...)
sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
loader = DataLoader(dataset, sampler=sampler, batch_size=32)

Pattern 3: On-the-Fly Feature Engineering
import numpy as np
import torch
from torch.utils.data import Dataset

class FeatureEngineeringDataset(Dataset):
    """Dataset with on-the-fly feature engineering."""

    def __init__(self, df, feature_columns, label_column):
        self.df = df
        self.feature_columns = feature_columns
        self.label_column = label_column
        # Precompute normalization statistics once
        self.mean_e = df['feature_e'].mean()
        self.std_e = df['feature_e'].std()

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]

        # Extract base features
        features = row[self.feature_columns].values

        # Engineer new features on-the-fly
        engineered = [
            row['feature_a'] * row['feature_b'],            # Interaction
            np.log1p(row['feature_c']),                     # Log transform
            row['feature_d'] ** 2,                          # Polynomial
            (row['feature_e'] - self.mean_e) / self.std_e,  # Normalize
        ]

        all_features = np.concatenate([features, engineered])

        return (
            torch.tensor(all_features, dtype=torch.float32),
            torch.tensor(row[self.label_column], dtype=torch.long)
        )

Summary
Custom datasets are the bridge between your data and PyTorch's training infrastructure. They adapt any data format to the standard Dataset interface.
Key Concepts
| Concept | Key Insight | Implementation |
|---|---|---|
| Dataset Contract | Answer 'how many?' and 'what is sample i?' | __len__ and __getitem__ |
| Lazy Loading | Load only when accessed | Store paths in __init__, load in __getitem__ |
| Multi-Modal | Synchronize modalities | Return dict with all modalities |
| Error Handling | Graceful degradation | Try-except with placeholders |
| Performance | Data loading often bottlenecks | Workers, caching, preprocessing |
Best Practices Checklist
- Keep __init__ lightweight: Only store paths and metadata
- Lazy load in __getitem__: Load data only when accessed
- Handle errors gracefully: Don't crash on one bad file
- Test with DataLoader: Verify multiprocessing works
- Profile before optimizing: Find the actual bottleneck
- Preprocess large datasets: Convert to efficient formats (HDF5, Zarr)
- Cache when appropriate: But never with random augmentation
- Document the data format: Future you will thank present you
Looking Ahead
In the next section, we'll tackle Handling Imbalanced Data—techniques like weighted sampling, oversampling, and class-balanced losses that ensure minority classes get fair treatment during training.
Exercises
Conceptual Questions
- Explain why opening a file handle in __init__ doesn't work with DataLoader's num_workers > 0. What is the correct pattern?
- A colleague proposes loading all images into RAM in __init__ "for speed." Under what conditions is this a good idea? When is it problematic?
- Why should random augmentations NOT be applied before caching? How would you design a caching strategy that still allows augmentation?
- Compare the tradeoffs between storing data as individual files vs. a single HDF5 file. When would you prefer each approach?
Coding Exercises
- Audio Dataset: Implement a dataset for audio classification that:
- Loads .wav files lazily
- Resamples to a common sample rate (16kHz)
- Pads or truncates to fixed length (5 seconds)
- Returns spectrogram as features
- Database Dataset: Create a dataset that:
- Loads records from a SQLite database
- Supports lazy loading with efficient indexing
- Handles NULL values gracefully
- Streaming Dataset: Implement an IterableDataset that:
- Reads from a Kafka topic or similar stream
- Buffers messages for batching
- Handles reconnection on failure
Solution Hints
- Audio: Use torchaudio.load() and torchaudio.transforms.Resample()
- Database: Use sqlite3 with SELECT ... LIMIT 1 OFFSET idx
- Streaming: Inherit from IterableDataset, implement __iter__
Challenge Exercise
Self-Supervised Contrastive Dataset: Build a dataset for SimCLR-style contrastive learning:
- Each __getitem__ returns TWO augmented views of the same image
- The two views use different random augmentations
- Include a special collate function that creates positive pairs
- Ensure no data leakage between the two views
This exercise teaches you to think about how the dataset structure supports the learning objective (contrastive learning requires paired samples).
Next, we'll learn to handle Imbalanced Data—ensuring your model doesn't ignore rare but important classes.