Chapter 7

Custom Datasets

Data Loading and Processing

Learning Objectives

By the end of this section, you will be able to:

  1. Build custom datasets for any data format including CSV, JSON, HDF5, and multi-modal data
  2. Handle complex data structures with nested labels, variable-length sequences, and synchronized multi-modal samples
  3. Implement robust error handling to gracefully handle corrupted files and missing data
  4. Debug dataset issues systematically using logging, assertions, and validation techniques
  5. Optimize dataset performance with caching, memory mapping, and parallel preprocessing
  6. Access remote data from cloud storage (S3, GCS) and HTTP endpoints efficiently
Why This Matters: Real-world data rarely comes in the neat format of MNIST or CIFAR. You'll encounter CSVs with millions of rows, JSON logs with nested structures, image-caption pairs from different sources, and proprietary formats unique to your organization. Mastering custom datasets is what separates toy experiments from production-ready pipelines.

The Big Picture

The Real-World Data Landscape

In production machine learning, data comes in many forms:

| Data Type | Common Formats | Challenges |
|---|---|---|
| Tabular | CSV, Parquet, SQL databases | Mixed types, missing values, large files |
| Images | JPEG, PNG, DICOM, RAW | Variable sizes, metadata, corrupt files |
| Text | JSON, XML, plain text, tokenized | Variable length, encoding issues, vocabulary |
| Audio | WAV, MP3, FLAC | Sample rates, duration, stereo/mono |
| Video | MP4, AVI, frame sequences | Temporal alignment, massive size |
| Multi-modal | Image+text, video+audio | Synchronization, missing modalities |
| Structured | HDF5, NetCDF, Zarr | Hierarchical, lazy loading required |

PyTorch's Dataset class provides a unified interface to wrap all these formats. The key insight is that a Dataset is just a bridge between your data storage and the training loop.

The Dataset Design Philosophy

When designing a custom dataset, think in terms of three questions:

  1. Discovery: What files/records exist? (computed once in __init__)
  2. Access: How do I load sample i? (implemented in __getitem__)
  3. Count: How many samples are there? (implemented in __len__)
$$\text{Dataset} = \{ (x_i, y_i) \mid i \in \{0, 1, \ldots, N-1\} \}$$

The Dataset maps integer indices to samples. Everything else—batching, shuffling, parallel loading—is handled by the DataLoader.
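To make the division of labor concrete, here is a minimal toy sketch (the class name and values are illustrative): the Dataset only answers "how many?" and "give me sample i", while the DataLoader supplies batching and shuffling.

```python
import torch
from torch.utils.data import Dataset, DataLoader

class SquaresDataset(Dataset):
    """Toy dataset mapping index i to the pair (i, i**2)."""

    def __init__(self, n: int):
        self.n = n  # only metadata is stored; samples are built on demand

    def __len__(self) -> int:
        return self.n

    def __getitem__(self, idx: int):
        x = torch.tensor([float(idx)])
        y = torch.tensor([float(idx) ** 2])
        return x, y

# The DataLoader handles batching and shuffling; the Dataset never sees batches.
loader = DataLoader(SquaresDataset(10), batch_size=4, shuffle=True)
for xb, yb in loader:
    print(xb.shape, yb.shape)  # batches of up to 4 samples
```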


Why Custom Datasets?

PyTorch provides many built-in datasets (MNIST, CIFAR, ImageNet, etc.), but you'll need custom datasets when:

  • Your data format is unique: Proprietary sensors, custom logging formats, legacy systems
  • Your labels are complex: Bounding boxes, segmentation masks, multiple annotations per sample
  • Your data spans multiple sources: Images from S3, labels from PostgreSQL, metadata from JSON
  • You need custom preprocessing: Domain-specific normalization, on-the-fly augmentation, feature engineering
  • Memory constraints require lazy loading: Data too large to fit in RAM
Key Insight: Think of a Dataset as an adapter. It adapts your specific data format to the generic interface that PyTorch expects. Write the adapter once, and the entire PyTorch ecosystem (DataLoader, DistributedSampler, etc.) works with your data.

The Dataset Contract

Every custom dataset must satisfy a simple contract:

🐍dataset_contract.py
from torch.utils.data import Dataset
from typing import Any, Tuple

class CustomDataset(Dataset):
    """Base template for all custom datasets."""

    def __init__(self, data_source: Any, transform=None, **kwargs):
        """Initialize the dataset.

        This is where you:
        1. Parse metadata (file paths, indices, labels)
        2. Store configuration (transforms, options)
        3. Optionally validate the data source

        IMPORTANT: Do NOT load actual data here for large datasets!
        """
        self.data_source = data_source
        self.transform = transform
        # Discover samples: build list of (path, label) or similar
        self.samples = self._discover_samples()

    def _discover_samples(self) -> list:
        """Scan data source and return list of sample identifiers."""
        raise NotImplementedError

    def __len__(self) -> int:
        """Return total number of samples."""
        return len(self.samples)

    def __getitem__(self, idx: int) -> Tuple[Any, Any]:
        """Load and return sample at index idx.

        This is where you:
        1. Load data from disk/database/network
        2. Apply transforms
        3. Return (features, label) tuple
        """
        # Load sample
        sample = self._load_sample(self.samples[idx])

        # Apply transform if provided
        if self.transform:
            sample = self.transform(sample)

        return sample

    def _load_sample(self, sample_id: Any) -> Tuple[Any, Any]:
        """Load a single sample given its identifier."""
        raise NotImplementedError

The Lazy Loading Principle

The most important design principle for datasets is lazy loading:

| Method | What To Do | What NOT To Do |
|---|---|---|
| __init__ | Store paths, indices, metadata | Load images, videos, large arrays |
| __len__ | Return cached length | Scan filesystem, count records |
| __getitem__ | Load single sample on-demand | Load multiple samples, cache everything |

Lazy loading ensures memory usage scales with batch size, not dataset size:

$$\text{Memory} = O(\text{batch\_size} \times \text{sample\_size})$$

Not $O(N \times \text{sample\_size})$, where $N$ is the dataset size.


CSV and Tabular Data

Tabular data is ubiquitous in machine learning. Let's build a robust CSV dataset.

Basic CSV Dataset

CSV Dataset Implementation
🐍csv_dataset.py
Key implementation notes:

  • Flexible column selection: users specify exactly which columns become features, which matters when CSVs contain irrelevant columns like IDs or timestamps.
  • Categorical encoding: neural networks need numeric inputs, so specified categorical columns are one-hot encoded with pandas get_dummies().
  • Missing value handling: real data has missing values; we fill them with a configurable value. For production, consider more sophisticated imputation strategies.
  • Eager loading for small data: tabular data that fits in memory is loaded into tensors upfront, which is faster than lazy loading across repeated epochs.
  • Type conversion: features become float tensors and labels become long (integer) tensors, matching PyTorch's loss function expectations.
import pandas as pd
import torch
from torch.utils.data import Dataset
from pathlib import Path
from typing import List, Optional, Tuple, Union

class CSVDataset(Dataset):
    """Dataset for loading tabular data from CSV files.

    Supports:
    - Feature columns and label column specification
    - Automatic dtype handling
    - Optional categorical encoding
    - Missing value handling
    """

    def __init__(
        self,
        csv_path: Union[str, Path],
        feature_columns: Optional[List[str]] = None,
        label_column: Optional[str] = None,
        categorical_columns: Optional[List[str]] = None,
        fillna_value: float = 0.0,
        dtype: torch.dtype = torch.float32,
    ):
        """Initialize the CSV dataset.

        Args:
            csv_path: Path to the CSV file
            feature_columns: List of column names to use as features
                            (None = all columns except label)
            label_column: Column name for labels (None = no labels)
            categorical_columns: Columns to one-hot encode
            fillna_value: Value to fill missing data
            dtype: PyTorch dtype for feature tensor
        """
        self.csv_path = Path(csv_path)
        self.dtype = dtype

        # Load CSV into DataFrame
        self.df = pd.read_csv(self.csv_path)
        print(f"Loaded {len(self.df)} rows from {self.csv_path.name}")

        # Handle missing values
        self.df = self.df.fillna(fillna_value)

        # Encode categorical columns
        if categorical_columns:
            self.df = pd.get_dummies(
                self.df,
                columns=categorical_columns,
                drop_first=True
            )

        # Set up feature and label columns
        if label_column:
            self.labels = self.df[label_column].values
            self.df = self.df.drop(columns=[label_column])
        else:
            self.labels = None

        if feature_columns:
            self.features = self.df[feature_columns].values
        else:
            self.features = self.df.values

        # Convert to tensors for faster access
        # (the cast guards against the bool columns get_dummies produces)
        self.features = torch.tensor(
            self.features.astype('float32'), dtype=dtype
        )
        if self.labels is not None:
            self.labels = torch.tensor(self.labels, dtype=torch.long)

        print(f"Features shape: {self.features.shape}")

    def __len__(self) -> int:
        return len(self.features)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        if self.labels is not None:
            return self.features[idx], self.labels[idx]
        return self.features[idx]


# Example usage
dataset = CSVDataset(
    csv_path="data/train.csv",
    feature_columns=["age", "income", "credit_score"],
    label_column="approved",
    categorical_columns=["gender", "region"],
    fillna_value=-1.0
)

features, label = dataset[0]
print(f"Sample: features={features.shape}, label={label}")

Large CSV with Lazy Loading

For CSVs too large to fit in memory, use lazy loading:

🐍lazy_csv_dataset.py
import pandas as pd
import torch
from torch.utils.data import Dataset

class LargeCSVDataset(Dataset):
    """Memory-efficient dataset for large CSV files.

    Reads one row at a time on demand, so it can handle files
    larger than available RAM.
    """

    def __init__(self, csv_path: str):
        self.csv_path = csv_path

        # Count total rows without loading entire file
        with open(csv_path) as f:
            self.length = sum(1 for _ in f) - 1  # -1 for header

        # Read header to get column names
        self.columns = pd.read_csv(csv_path, nrows=0).columns.tolist()

        # For truly large files, use dask or vaex instead
        # self.df = dask.dataframe.read_csv(csv_path)

    def __len__(self) -> int:
        return self.length

    def __getitem__(self, idx: int) -> torch.Tensor:
        # Read only the row we need
        # Note: This is slow for random access!
        # Consider using parquet or HDF5 for large datasets
        row = pd.read_csv(
            self.csv_path,
            skiprows=idx + 1,  # Skip the header plus all previous rows
            nrows=1,
            header=None,
            names=self.columns
        )
        return torch.tensor(row.values[0], dtype=torch.float32)

Random Access Performance

Lazy loading from CSV is very slow for random access because CSV is a row-oriented format that must be scanned sequentially. For large datasets, convert to Parquet or HDF5 first, then use memory mapping.

Quick Check

Why do we convert features to tensors in __init__ for the basic CSVDataset, but not for the LargeCSVDataset?


JSON and Nested Data

JSON is common for web data, API responses, and annotation files. Its nested structure requires careful handling.

COCO-Style Annotation Dataset

The COCO dataset format is widely used for object detection and segmentation. Let's implement a dataset that loads it:

COCO-Style Object Detection Dataset
🐍coco_dataset.py
Key implementation notes:

  • Lookup tables: dictionaries give O(1) access to images and categories by ID, which is crucial because annotations reference images by ID.
  • Grouping annotations: defaultdict(list) groups all annotations by image_id, so __getitem__ can retrieve all boxes for an image in O(1) time.
  • Lazy image loading: images are loaded only when __getitem__ is called; __init__ parses only the lightweight JSON metadata.
  • Coordinate conversion: COCO uses [x, y, width, height], but most PyTorch detection models expect [x1, y1, x2, y2]. Always check your model's expected format!
  • Dictionary targets: for object detection, targets are typically dicts with 'boxes', 'labels', and 'image_id', matching torchvision's detection models.
  • Empty annotation handling: some images have no objects; we return empty tensors with the correct shape rather than None to avoid collation errors.
import json
import torch
from torch.utils.data import Dataset
from PIL import Image
from pathlib import Path
from typing import Dict, List
from collections import defaultdict

class COCOStyleDataset(Dataset):
    """Dataset for COCO-format annotations.

    Expected JSON structure:
    {
        "images": [{"id": 1, "file_name": "img.jpg", ...}, ...],
        "annotations": [{"id": 1, "image_id": 1, "bbox": [...], ...}, ...],
        "categories": [{"id": 1, "name": "cat"}, ...]
    }
    """

    def __init__(
        self,
        annotation_file: str,
        image_dir: str,
        transform=None,
        target_transform=None,
    ):
        self.image_dir = Path(image_dir)
        self.transform = transform
        self.target_transform = target_transform

        # Load and parse JSON
        with open(annotation_file, 'r') as f:
            self.coco = json.load(f)

        # Build lookup tables for efficient access
        self.images = {img['id']: img for img in self.coco['images']}
        self.categories = {cat['id']: cat['name']
                           for cat in self.coco['categories']}

        # Group annotations by image
        self.img_to_anns: Dict[int, List[Dict]] = defaultdict(list)
        for ann in self.coco['annotations']:
            self.img_to_anns[ann['image_id']].append(ann)

        # Create ordered list of image IDs for indexing
        self.image_ids = list(self.images.keys())

        print(f"Loaded {len(self.image_ids)} images with "
              f"{len(self.coco['annotations'])} annotations")

    def __len__(self) -> int:
        return len(self.image_ids)

    def __getitem__(self, idx: int) -> tuple:
        # Get image info
        img_id = self.image_ids[idx]
        img_info = self.images[img_id]

        # Load image
        img_path = self.image_dir / img_info['file_name']
        image = Image.open(img_path).convert('RGB')

        # Get all annotations for this image
        anns = self.img_to_anns[img_id]

        # Extract bounding boxes and labels
        boxes = []
        labels = []
        for ann in anns:
            # COCO format: [x, y, width, height]
            x, y, w, h = ann['bbox']
            # Convert to [x1, y1, x2, y2]
            boxes.append([x, y, x + w, y + h])
            labels.append(ann['category_id'])

        # Convert to tensors
        target = {
            'boxes': torch.tensor(boxes, dtype=torch.float32),
            'labels': torch.tensor(labels, dtype=torch.int64),
            'image_id': torch.tensor([img_id]),
        }

        # Handle images with no annotations
        if len(boxes) == 0:
            target['boxes'] = torch.zeros((0, 4), dtype=torch.float32)
            target['labels'] = torch.zeros((0,), dtype=torch.int64)

        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            target = self.target_transform(target)

        return image, target


# Usage
from torchvision import transforms

transform = transforms.Compose([
    transforms.Resize((640, 640)),
    transforms.ToTensor(),
])

dataset = COCOStyleDataset(
    annotation_file="data/annotations.json",
    image_dir="data/images",
    transform=transform
)

image, target = dataset[0]
print(f"Image: {image.shape}")
print(f"Boxes: {target['boxes'].shape}")
print(f"Labels: {target['labels']}")
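One practical wrinkle with detection targets: the number of boxes varies per image, so the default collate_fn cannot stack them into a batch tensor. A minimal sketch of the list-based collate that torchvision-style detection models expect (the fabricated batch below stands in for samples from the dataset):

```python
import torch

def detection_collate(batch):
    """Keep images and targets as lists instead of stacking:
    each image has a different number of boxes."""
    images, targets = zip(*batch)
    return list(images), list(targets)

# Fabricated (image, target) pairs mimicking the dataset's output
batch = [
    (torch.rand(3, 64, 64),
     {'boxes': torch.rand(2, 4), 'labels': torch.tensor([1, 3])}),
    (torch.rand(3, 64, 64),
     {'boxes': torch.zeros(0, 4), 'labels': torch.zeros(0, dtype=torch.int64)}),
]
images, targets = detection_collate(batch)
print(len(images), targets[0]['boxes'].shape)  # 2 torch.Size([2, 4])
```

Pass it as `collate_fn=detection_collate` when constructing the DataLoader.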

JSONL (JSON Lines) for Streaming Data

For large datasets, JSONL format (one JSON object per line) is more efficient:

🐍jsonl_dataset.py
import json
import torch
from torch.utils.data import Dataset

class JSONLDataset(Dataset):
    """Dataset for JSON Lines format (one JSON object per line).

    Efficient for large datasets as we can seek to specific lines
    without loading the entire file.
    """

    def __init__(self, jsonl_path: str):
        self.jsonl_path = jsonl_path

        # Build index of line byte offsets for random access
        self.line_offsets = [0]
        with open(jsonl_path, 'rb') as f:
            for line in f:
                self.line_offsets.append(f.tell())
        self.line_offsets.pop()  # Remove last offset (EOF)

        print(f"Indexed {len(self.line_offsets)} records")

    def __len__(self) -> int:
        return len(self.line_offsets)

    def __getitem__(self, idx: int) -> dict:
        with open(self.jsonl_path, 'rb') as f:
            f.seek(self.line_offsets[idx])
            line = f.readline()
            return json.loads(line)


# For text classification
class TextClassificationDataset(JSONLDataset):
    """JSONL dataset for text classification.

    Expected format: {"text": "...", "label": 0}
    """

    def __init__(self, jsonl_path: str, tokenizer, max_length: int = 512):
        super().__init__(jsonl_path)
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __getitem__(self, idx: int) -> dict:
        record = super().__getitem__(idx)

        # Tokenize text
        encoding = self.tokenizer(
            record['text'],
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )

        return {
            'input_ids': encoding['input_ids'].squeeze(0),
            'attention_mask': encoding['attention_mask'].squeeze(0),
            'labels': torch.tensor(record['label'], dtype=torch.long)
        }

JSONL Indexing

By storing byte offsets in __init__, we can seek() directly to any line in __getitem__. This gives O(1) random access instead of O(n) sequential scanning. This technique works for any line-oriented format.
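The offset-indexing trick is worth seeing in isolation. This stdlib-only sketch writes a throwaway JSONL file (synthetic records) and seeks straight to one of them:

```python
import json
import os
import tempfile

# Write a tiny synthetic JSONL file.
path = os.path.join(tempfile.mkdtemp(), "data.jsonl")
with open(path, "w") as f:
    for i in range(5):
        f.write(json.dumps({"text": f"sample {i}", "label": i % 2}) + "\n")

# Index the byte offset where each line starts.
offsets = [0]
with open(path, "rb") as f:
    for line in f:
        offsets.append(f.tell())
offsets.pop()  # drop the EOF offset

# Jump directly to record 3 without reading records 0-2.
with open(path, "rb") as f:
    f.seek(offsets[3])
    record = json.loads(f.readline())
print(record)  # {'text': 'sample 3', 'label': 1}
```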

Multi-Modal Datasets

Multi-modal learning combines different data types—images with captions, audio with transcripts, video with text descriptions. The key challenge is synchronization: ensuring corresponding modalities are correctly paired.

Image-Caption Dataset

Image-Caption Multi-Modal Dataset
🐍image_caption_dataset.py
Key implementation notes:

  • Multiple captions per image: many datasets have 5+ captions per image; creating (image_id, caption_idx) tuples makes each caption a separate sample, increasing dataset size.
  • Negative sampling setup: contrastive learning (as in CLIP) needs mismatched pairs; storing all image IDs enables efficient negative sampling.
  • Dictionary return format: multi-modal samples often return dicts rather than tuples for clarity; keys like 'image' and 'input_ids' make downstream code more readable.
  • Negative pair generation: with some probability, the caption is replaced with one from a different image, and the label changes from 1 (matched) to 0 (mismatched).
import torch
from torch.utils.data import Dataset
from PIL import Image
from pathlib import Path
import json
from typing import Optional, Callable, Dict
import random

class ImageCaptionDataset(Dataset):
    """Dataset for image-caption pairs.

    Supports:
    - Multiple captions per image
    - Negative sampling for contrastive learning
    - Synchronized image-text transforms
    """

    def __init__(
        self,
        image_dir: str,
        captions_file: str,
        tokenizer,
        max_caption_length: int = 77,
        image_transform: Optional[Callable] = None,
        return_negative: bool = False,
        negative_probability: float = 0.5,
    ):
        self.image_dir = Path(image_dir)
        self.tokenizer = tokenizer
        self.max_caption_length = max_caption_length
        self.image_transform = image_transform
        self.return_negative = return_negative
        self.negative_probability = negative_probability

        # Load captions: {"image_id": ["caption1", "caption2", ...]}
        with open(captions_file, 'r') as f:
            self.captions_data = json.load(f)

        # Build list of (image_id, caption_index) pairs
        self.samples = []
        for img_id, captions in self.captions_data.items():
            for cap_idx in range(len(captions)):
                self.samples.append((img_id, cap_idx))

        # For negative sampling
        self.all_image_ids = list(self.captions_data.keys())

        print(f"Loaded {len(self.samples)} image-caption pairs "
              f"from {len(self.all_image_ids)} images")

    def __len__(self) -> int:
        return len(self.samples)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        img_id, cap_idx = self.samples[idx]

        # Load image
        img_path = self.image_dir / f"{img_id}.jpg"
        image = Image.open(img_path).convert('RGB')

        if self.image_transform:
            image = self.image_transform(image)

        # Get caption
        caption = self.captions_data[img_id][cap_idx]

        # Tokenize caption
        tokens = self.tokenizer(
            caption,
            max_length=self.max_caption_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )

        result = {
            'image': image,
            'input_ids': tokens['input_ids'].squeeze(0),
            'attention_mask': tokens['attention_mask'].squeeze(0),
            'label': torch.tensor(1, dtype=torch.long),  # Matched pair
        }

        # Negative sampling for contrastive learning
        if self.return_negative:
            if random.random() < self.negative_probability:
                # Return mismatched pair
                neg_img_id = random.choice(self.all_image_ids)
                while neg_img_id == img_id:
                    neg_img_id = random.choice(self.all_image_ids)

                neg_caption = random.choice(self.captions_data[neg_img_id])
                neg_tokens = self.tokenizer(
                    neg_caption,
                    max_length=self.max_caption_length,
                    padding='max_length',
                    truncation=True,
                    return_tensors='pt'
                )
                result['input_ids'] = neg_tokens['input_ids'].squeeze(0)
                result['attention_mask'] = neg_tokens['attention_mask'].squeeze(0)
                result['label'] = torch.tensor(0, dtype=torch.long)

        return result


# Usage for CLIP-style training
from torchvision import transforms
from transformers import CLIPTokenizer

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")

image_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.48145466, 0.4578275, 0.40821073],
        std=[0.26862954, 0.26130258, 0.27577711]
    )
])

dataset = ImageCaptionDataset(
    image_dir="data/images",
    captions_file="data/captions.json",
    tokenizer=tokenizer,
    image_transform=image_transform,
    return_negative=True
)

Video-Text Dataset

Video datasets add temporal complexity. Here's a pattern for video-text pairs:

🐍video_text_dataset.py
import torch
from torch.utils.data import Dataset
import cv2
import numpy as np
from pathlib import Path

class VideoTextDataset(Dataset):
    """Dataset for video-text pairs (e.g., video captioning)."""

    def __init__(
        self,
        video_dir: str,
        annotations: dict,  # {video_id: text}
        num_frames: int = 16,
        frame_size: tuple = (224, 224),
        tokenizer=None,
    ):
        self.video_dir = Path(video_dir)
        self.annotations = annotations
        self.video_ids = list(annotations.keys())
        self.num_frames = num_frames
        self.frame_size = frame_size
        self.tokenizer = tokenizer

    def __len__(self) -> int:
        return len(self.video_ids)

    def __getitem__(self, idx: int):
        video_id = self.video_ids[idx]
        video_path = self.video_dir / f"{video_id}.mp4"

        # Load video frames
        frames = self._load_video_frames(video_path)

        # Get text annotation
        text = self.annotations[video_id]

        if self.tokenizer:
            tokens = self.tokenizer(
                text,
                padding='max_length',
                truncation=True,
                return_tensors='pt'
            )
            return frames, tokens['input_ids'].squeeze(0)

        return frames, text

    def _load_video_frames(self, video_path: Path) -> torch.Tensor:
        """Load evenly-spaced frames from a video."""
        cap = cv2.VideoCapture(str(video_path))
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

        # Sample frame indices evenly
        indices = np.linspace(0, total_frames - 1, self.num_frames, dtype=int)

        frames = []
        for frame_idx in indices:
            cap.set(cv2.CAP_PROP_POS_FRAMES, frame_idx)
            ret, frame = cap.read()
            if ret:
                # BGR to RGB, resize
                frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                frame = cv2.resize(frame, self.frame_size)
                frames.append(frame)
            else:
                # Pad with zeros if frame read fails
                frames.append(np.zeros((*self.frame_size, 3), dtype=np.uint8))

        cap.release()

        # Stack and convert: [T, H, W, C] -> [C, T, H, W]
        frames = np.stack(frames)
        frames = torch.from_numpy(frames).float() / 255.0
        frames = frames.permute(3, 0, 1, 2)  # [C, T, H, W]

        return frames

Video Loading Performance

Video loading is expensive. For training, consider: (1) Pre-extracting frames to disk, (2) Using optimized loaders like decord or torchvision.io, (3) Loading from SSD, not HDD.

Sequence and Time Series Data

Sequence data requires special handling for variable lengths and temporal structure.

Time Series Dataset

Time Series Forecasting Dataset
🐍time_series_dataset.py
Key implementation notes:

  • Normalization statistics: mean/std are computed from the training data and stored; at inference, use the same statistics to normalize new data consistently.
  • Sliding window counting: given T timesteps, we can create (T - input_length - forecast_horizon) / stride + 1 windows; this formula determines __len__.
  • Window extraction: each __getitem__ call extracts a contiguous window; the input window comes first, followed by the target window, and they don't overlap.
  • Inverse transform: predictions often need converting back to the original scale for interpretability; a helper method reverses the normalization.
import torch
from torch.utils.data import Dataset
import pandas as pd
import numpy as np
from typing import Tuple

class TimeSeriesDataset(Dataset):
    """Dataset for time series forecasting.

    Creates sliding window samples from continuous time series data.
    """

    def __init__(
        self,
        data: np.ndarray,  # Shape: [T, features]
        input_length: int,
        forecast_horizon: int,
        stride: int = 1,
        normalize: bool = True,
    ):
        """Initialize time series dataset.

        Args:
            data: Time series array [timesteps, features]
            input_length: Number of timesteps for input window
            forecast_horizon: Number of timesteps to predict
            stride: Step size between consecutive windows
            normalize: Whether to z-score normalize
        """
        self.input_length = input_length
        self.forecast_horizon = forecast_horizon
        self.stride = stride

        # Normalize if requested
        if normalize:
            self.mean = data.mean(axis=0)
            self.std = data.std(axis=0) + 1e-8  # Avoid division by zero
            data = (data - self.mean) / self.std
        else:
            self.mean = None
            self.std = None

        self.data = torch.tensor(data, dtype=torch.float32)

        # Calculate number of valid windows
        total_length = input_length + forecast_horizon
        self.num_samples = (len(data) - total_length) // stride + 1

        print(f"Created {self.num_samples} samples from "
              f"{len(data)} timesteps")

    def __len__(self) -> int:
        return self.num_samples

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        start = idx * self.stride
        input_end = start + self.input_length
        output_end = input_end + self.forecast_horizon

        x = self.data[start:input_end]       # [input_length, features]
        y = self.data[input_end:output_end]  # [forecast_horizon, features]

        return x, y

    def inverse_normalize(self, data: torch.Tensor) -> torch.Tensor:
        """Convert normalized predictions back to original scale."""
        if self.mean is not None:
            mean = torch.tensor(self.mean, dtype=data.dtype, device=data.device)
            std = torch.tensor(self.std, dtype=data.dtype, device=data.device)
            return data * std + mean
        return data


# Usage for stock price prediction
df = pd.read_csv("stock_prices.csv")
data = df[['open', 'high', 'low', 'close', 'volume']].values

dataset = TimeSeriesDataset(
    data=data,
    input_length=60,       # Use 60 days of history
    forecast_horizon=5,    # Predict next 5 days
    stride=1,              # Create overlapping windows
    normalize=True
)

x, y = dataset[0]
print(f"Input shape: {x.shape}")   # [60, 5]
print(f"Target shape: {y.shape}")  # [5, 5]
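One caution the normalization notes imply: when splitting a series for validation, split chronologically and compute statistics on the training portion only, then reuse them for the validation portion to avoid leaking future information. A sketch with synthetic data:

```python
import numpy as np

# Synthetic stand-in for a real series: 1000 timesteps, 5 features.
data = np.random.randn(1000, 5).astype(np.float32)

# Chronological split: never shuffle time series before splitting.
split = int(len(data) * 0.8)
train, val = data[:split], data[split:]

# Normalize BOTH splits with statistics from the training portion only.
mean = train.mean(axis=0)
std = train.std(axis=0) + 1e-8
train_n = (train - mean) / std
val_n = (val - mean) / std
print(val_n.shape)  # (200, 5)
```

Each normalized array can then be wrapped in its own TimeSeriesDataset with normalize=False.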

Variable-Length Sequences with Padding

🐍variable_length_dataset.py
import torch
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence

class VariableLengthDataset(Dataset):
    """Dataset for variable-length sequences.

    Returns sequences as-is; padding is done in collate_fn.
    """

    def __init__(self, sequences: list, labels: list):
        self.sequences = sequences  # List of tensors with different lengths
        self.labels = labels

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, idx):
        return self.sequences[idx], self.labels[idx]


def collate_variable_length(batch):
    """Custom collate function for variable-length sequences.

    Pads all sequences to the length of the longest in the batch.
    """
    sequences, labels = zip(*batch)

    # Get lengths before padding
    lengths = torch.tensor([len(seq) for seq in sequences])

    # Pad sequences to same length
    padded = pad_sequence(sequences, batch_first=True, padding_value=0)

    # Stack labels
    labels = torch.stack(labels)

    return padded, labels, lengths


# Usage
sequences = [
    torch.randn(10, 32),  # 10 timesteps
    torch.randn(25, 32),  # 25 timesteps
    torch.randn(15, 32),  # 15 timesteps
]
labels = [torch.tensor(0), torch.tensor(1), torch.tensor(0)]

dataset = VariableLengthDataset(sequences, labels)

# Must use custom collate_fn
loader = DataLoader(
    dataset,
    batch_size=3,
    collate_fn=collate_variable_length
)

for padded, labels, lengths in loader:
    print(f"Padded shape: {padded.shape}")  # [3, 25, 32] (padded to longest)
    print(f"Lengths: {lengths}")  # tensor([10, 25, 15])

Pack Padded Sequences

When using RNNs, pass lengths to pack_padded_sequence() so the RNN ignores padding tokens. This is more efficient and prevents the model from learning padding patterns.
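Continuing the collate example above, here is a minimal sketch of feeding a padded batch plus its lengths into an LSTM. The layer sizes are illustrative, chosen to match the 32-feature sequences above:

```python
import torch
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

# A padded batch like the one produced by collate_variable_length
padded = torch.randn(3, 25, 32)        # batch of 3, padded to 25 timesteps
lengths = torch.tensor([10, 25, 15])   # true lengths before padding

lstm = torch.nn.LSTM(input_size=32, hidden_size=64, batch_first=True)

# Pack so the LSTM skips padding; enforce_sorted=False sorts internally
packed = pack_padded_sequence(padded, lengths, batch_first=True,
                              enforce_sorted=False)
packed_out, (h_n, c_n) = lstm(packed)

# Unpack back to a padded tensor for downstream layers
output, out_lengths = pad_packed_sequence(packed_out, batch_first=True)
print(output.shape)   # torch.Size([3, 25, 64])
```

Note that `pack_padded_sequence` expects the lengths on CPU even when the data is on GPU.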

HDF5 and Large Files

HDF5 (Hierarchical Data Format) is designed for large numerical datasets. It supports lazy loading, compression, and efficient random access.

HDF5 Dataset for Large-Scale Data
🐍hdf5_dataset.py
Metadata Without Loading

HDF5 stores metadata (shape, dtype) separately from the data, so we can get dataset info without loading a single sample into memory.

Lazy File Handle

Opening the file once and reusing the handle is faster than opening and closing it in every __getitem__. But watch out for multiprocessing!

Efficient Indexing

HDF5's chunk-based storage enables O(1) random access: only the chunk containing the requested sample is read from disk.

Compression for Storage

gzip compression typically achieves a 2-5x reduction for float data. There's a small CPU overhead during reading, but the disk I/O savings usually dominate.

Chunking for Partial Reads

Chunking divides the data into blocks. Without it, HDF5 must read the entire dataset to access one sample. Always enable it for training datasets.
import h5py
import torch
from torch.utils.data import Dataset
import numpy as np

class HDF5Dataset(Dataset):
    """Memory-efficient dataset for HDF5 files.

    HDF5 advantages:
    - Lazy loading: data stays on disk until accessed
    - Compression: reduce storage by 2-10x
    - Random access: O(1) access to any sample
    - Hierarchical: organize related data together
    """

    def __init__(
        self,
        h5_path: str,
        features_key: str = 'features',
        labels_key: str = 'labels',
        transform=None,
    ):
        self.h5_path = h5_path
        self.features_key = features_key
        self.labels_key = labels_key
        self.transform = transform

        # Open file to get dataset info (don't load data)
        with h5py.File(h5_path, 'r') as f:
            self.length = len(f[features_key])
            self.feature_shape = f[features_key].shape[1:]
            print(f"HDF5 dataset: {self.length} samples, "
                  f"shape {self.feature_shape}")

        # Keep file handle open for efficiency
        # (but be careful with multiprocessing!)
        self._h5_file = None

    def _get_file(self):
        """Lazily open the file handle (worker-safe)."""
        if self._h5_file is None:
            self._h5_file = h5py.File(self.h5_path, 'r')
        return self._h5_file

    def __len__(self) -> int:
        return self.length

    def __getitem__(self, idx: int):
        f = self._get_file()

        # HDF5 supports efficient single-index access
        features = f[self.features_key][idx]
        labels = f[self.labels_key][idx]

        # Convert to tensors
        features = torch.from_numpy(features.astype(np.float32))
        labels = torch.tensor(labels, dtype=torch.long)

        if self.transform:
            features = self.transform(features)

        return features, labels

    def close(self):
        """Explicitly close the file handle."""
        if self._h5_file is not None:
            self._h5_file.close()
            self._h5_file = None


# Creating an HDF5 file from numpy arrays
def create_hdf5_dataset(
    output_path: str,
    features: np.ndarray,
    labels: np.ndarray,
    compression: str = 'gzip',
    compression_opts: int = 4,
):
    """Save numpy arrays to HDF5 with compression."""
    with h5py.File(output_path, 'w') as f:
        f.create_dataset(
            'features',
            data=features,
            compression=compression,
            compression_opts=compression_opts,
            chunks=True  # Enable chunking for partial reads
        )
        f.create_dataset(
            'labels',
            data=labels,
            compression=compression,
            compression_opts=compression_opts,
        )
    print(f"Saved {len(features)} samples to {output_path}")


# Usage
# First, create the HDF5 file (one-time preprocessing).
# Keep the in-memory array small here: 100,000 float32 images of
# 3x224x224 would need ~60 GB of RAM. Write large datasets in batches.
features = np.random.randn(1000, 3, 224, 224).astype(np.float32)
labels = np.random.randint(0, 1000, size=1000)
create_hdf5_dataset("large_dataset.h5", features, labels)

# Then use it for training (every run)
dataset = HDF5Dataset("large_dataset.h5")
print(f"Dataset size: {len(dataset)}")

HDF5 and DataLoader Workers

HDF5 file handles don't transfer across process boundaries. With num_workers>0, each worker must open its own file handle. The _get_file() pattern handles this by opening lazily in each worker process.
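To see why the lazy pattern works, here is a small stand-in sketch using plain files instead of h5py (the pickling behavior is the same): the DataLoader pickles the dataset to ship it to each worker process, and an already-open OS-level file handle cannot be pickled, while a dataset that stores only the path can.

```python
import pickle, tempfile, os

# Create a small file to stand in for the HDF5 file
path = os.path.join(tempfile.mkdtemp(), "data.bin")
with open(path, "wb") as f:
    f.write(b"\x00" * 16)

class EagerDataset:
    def __init__(self, path):
        self.fh = open(path, "rb")      # opened in __init__: not picklable

class LazyDataset:
    def __init__(self, path):
        self.path = path
        self.fh = None                  # opened on first access, per process
    def _get_file(self):
        if self.fh is None:
            self.fh = open(self.path, "rb")
        return self.fh

try:
    pickle.dumps(EagerDataset(path))
    eager_ok = True
except TypeError:
    eager_ok = False                    # pickling the open handle fails

lazy_ok = pickle.dumps(LazyDataset(path)) is not None
print(eager_ok, lazy_ok)                # False True
```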

Quick Check

Why does the HDF5Dataset open the file lazily in __getitem__ rather than in __init__?


Remote and Cloud Data

Training data often lives in cloud storage (S3, GCS, Azure Blob) or behind HTTP APIs. Here's how to build datasets that access remote data efficiently.

S3 Dataset with Smart Caching

🐍s3_dataset.py
import torch
from torch.utils.data import Dataset
import boto3
from PIL import Image
import io
from pathlib import Path
import hashlib
from typing import Optional

class S3ImageDataset(Dataset):
    """Dataset that loads images from Amazon S3 with local caching.

    Caching strategy:
    - First access: download from S3, save to local cache
    - Subsequent accesses: load from local cache
    - Cache persists across training runs
    """

    def __init__(
        self,
        bucket: str,
        prefix: str,
        labels: dict,  # {s3_key: label}
        cache_dir: Optional[str] = None,
        transform=None,
    ):
        self.bucket = bucket
        self.labels = labels
        self.keys = list(labels.keys())
        self.transform = transform

        # Set up S3 client
        self.s3 = boto3.client('s3')

        # Set up local cache
        if cache_dir:
            self.cache_dir = Path(cache_dir)
            self.cache_dir.mkdir(parents=True, exist_ok=True)
        else:
            self.cache_dir = None

        print(f"S3 dataset: {len(self.keys)} images from s3://{bucket}/{prefix}")

    def __len__(self) -> int:
        return len(self.keys)

    def _get_cache_path(self, s3_key: str) -> Path:
        """Generate a cache path from an S3 key."""
        # Hash the key to handle special characters and long paths
        key_hash = hashlib.md5(s3_key.encode()).hexdigest()
        extension = Path(s3_key).suffix
        return self.cache_dir / f"{key_hash}{extension}"

    def __getitem__(self, idx: int):
        s3_key = self.keys[idx]
        label = self.labels[s3_key]

        # Try the local cache first
        if self.cache_dir:
            cache_path = self._get_cache_path(s3_key)
            if cache_path.exists():
                image = Image.open(cache_path).convert('RGB')
            else:
                # Download from S3
                image = self._download_from_s3(s3_key)
                # Save to cache
                image.save(cache_path)
        else:
            # No caching, always download
            image = self._download_from_s3(s3_key)

        if self.transform:
            image = self.transform(image)

        return image, torch.tensor(label, dtype=torch.long)

    def _download_from_s3(self, s3_key: str) -> Image.Image:
        """Download an image from S3 into memory."""
        response = self.s3.get_object(Bucket=self.bucket, Key=s3_key)
        image_bytes = response['Body'].read()
        return Image.open(io.BytesIO(image_bytes)).convert('RGB')


# For truly streaming access (no local storage), use:
class S3StreamingDataset(Dataset):
    """Dataset that streams directly from S3 without caching.

    Use when: cache storage is limited, data changes frequently.
    """

    def __init__(self, bucket: str, keys: list, labels: list, transform=None):
        self.bucket = bucket
        self.keys = keys
        self.labels = labels
        self.transform = transform
        self.s3 = None  # Lazy initialization for multiprocessing

    def _get_s3_client(self):
        if self.s3 is None:
            self.s3 = boto3.client('s3')
        return self.s3

    def __len__(self):
        return len(self.keys)

    def __getitem__(self, idx):
        s3 = self._get_s3_client()
        response = s3.get_object(Bucket=self.bucket, Key=self.keys[idx])
        image_bytes = response['Body'].read()
        image = Image.open(io.BytesIO(image_bytes)).convert('RGB')

        if self.transform:
            image = self.transform(image)

        return image, torch.tensor(self.labels[idx])

HTTP-Based Dataset

🐍http_dataset.py
import torch
from torch.utils.data import Dataset
import requests
from PIL import Image
import io
from typing import List
import time

class HTTPImageDataset(Dataset):
    """Dataset that fetches images from HTTP URLs.

    Includes retry logic and rate limiting for robustness.
    """

    def __init__(
        self,
        urls: List[str],
        labels: List[int],
        transform=None,
        max_retries: int = 3,
        retry_delay: float = 1.0,
        timeout: float = 10.0,
    ):
        self.urls = urls
        self.labels = labels
        self.transform = transform
        self.max_retries = max_retries
        self.retry_delay = retry_delay
        self.timeout = timeout

        # Create a session for connection pooling
        self.session = requests.Session()

    def __len__(self):
        return len(self.urls)

    def __getitem__(self, idx):
        url = self.urls[idx]
        label = self.labels[idx]

        # Retry loop
        for attempt in range(self.max_retries):
            try:
                response = self.session.get(url, timeout=self.timeout)
                response.raise_for_status()

                image = Image.open(io.BytesIO(response.content)).convert('RGB')

                if self.transform:
                    image = self.transform(image)

                return image, torch.tensor(label, dtype=torch.long)

            except Exception as e:
                if attempt < self.max_retries - 1:
                    time.sleep(self.retry_delay * (attempt + 1))  # delay grows linearly
                else:
                    # Return a placeholder on final failure
                    print(f"Failed to load {url}: {e}")
                    placeholder = Image.new('RGB', (224, 224), color='gray')
                    if self.transform:
                        placeholder = self.transform(placeholder)
                    return placeholder, torch.tensor(label)

Network Dataset Best Practices

  • Always cache when possible—network I/O is 100-1000x slower than disk
  • Use connection pooling (requests.Session) to reuse TCP connections
  • Implement retries with exponential backoff for transient failures
  • Return placeholders on failure rather than crashing the training loop
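The retry loop above sleeps for a linearly growing delay; the exponential backoff recommended in the checklist can be sketched as follows (the function name and defaults are illustrative, not from any library):

```python
import random

def backoff_delays(base=1.0, factor=2.0, max_delay=30.0, retries=5, jitter=0.1):
    """Yield exponentially growing sleep times with a little random jitter.

    Sleeping for each yielded delay before the next retry spreads out
    retry storms when many workers hit a failing endpoint at once.
    """
    for attempt in range(retries):
        delay = min(base * (factor ** attempt), max_delay)
        # Apply up to +/- jitter fraction of randomness
        yield delay * (1 + random.uniform(-jitter, jitter))

delays = list(backoff_delays(jitter=0.0))
print(delays)   # [1.0, 2.0, 4.0, 8.0, 16.0]
```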

Interactive: Data Pipeline

The visualization below shows how custom datasets integrate into the PyTorch data pipeline. Watch how your Dataset provides samples to the DataLoader, which handles batching, shuffling, and parallel loading.

[Data Pipeline Visualizer: 20 raw samples (5 per class) are shuffled, then drawn into batches of 4; the epoch completes after 5 batches, once all 20 samples have been seen.]

Key observations:

  • Dataset provides indexed access to individual samples
  • DataLoader orchestrates batching and shuffling
  • Worker processes load data in parallel to keep the GPU busy
  • With num_workers > 0, your custom __getitem__ runs in worker processes, not the main process

Error Handling and Debugging

Real datasets have corrupted files, missing data, and edge cases. Robust error handling is essential.

Graceful Error Handling

Robust Dataset with Error Handling
🐍robust_dataset.py
Force Image Load

image.load() forces PIL to actually read the file. Without this, PIL may defer loading, and errors won't appear until later (e.g., in transforms).

Specific Exception Types

Catching specific exceptions allows different handling strategies: FileNotFoundError might warrant a warning, while PermissionError might need admin attention.

Transform Error Handling

Transforms can fail too (e.g., division by zero during normalization). Catching transform errors ensures the entire sample is handled gracefully.

Error Tracking

Storing error indices and messages enables post-training analysis. You can identify problematic files and fix them for future runs.

Configurable Strictness

return_placeholder_on_error=True keeps training running; False crashes immediately. Use True for training, False for debugging.
import torch
from torch.utils.data import Dataset
from PIL import Image
import logging
from pathlib import Path
from typing import Tuple

# Set up logging
logging.basicConfig(level=logging.WARNING)
logger = logging.getLogger(__name__)

class RobustImageDataset(Dataset):
    """Image dataset with comprehensive error handling.

    Handles:
    - Corrupted image files
    - Missing files
    - Unsupported formats
    - Permission errors
    """

    def __init__(
        self,
        image_paths: list,
        labels: list,
        transform=None,
        return_placeholder_on_error: bool = True,
        placeholder_size: Tuple[int, int] = (224, 224),
        log_errors: bool = True,
    ):
        self.image_paths = [Path(p) for p in image_paths]
        self.labels = labels
        self.transform = transform
        self.return_placeholder = return_placeholder_on_error
        self.placeholder_size = placeholder_size
        self.log_errors = log_errors

        # Track errors for later analysis
        self.error_indices = []
        self.error_messages = []

        # Validate on init (optional)
        self._validate_paths()

    def _validate_paths(self):
        """Check that all paths exist (optional pre-validation)."""
        missing = [p for p in self.image_paths if not p.exists()]
        if missing:
            logger.warning(f"{len(missing)} files not found")

    def __len__(self) -> int:
        return len(self.image_paths)

    def __getitem__(self, idx: int):
        path = self.image_paths[idx]
        label = self.labels[idx]

        try:
            # Attempt to load the image
            image = Image.open(path)

            # Force a full load to catch truncated images
            image.load()

            # Convert to RGB (handles grayscale, RGBA, etc.)
            image = image.convert('RGB')

        except FileNotFoundError:
            return self._handle_error(idx, f"File not found: {path}")

        except PermissionError:
            return self._handle_error(idx, f"Permission denied: {path}")

        except (IOError, OSError) as e:
            return self._handle_error(idx, f"Corrupted image: {path} ({e})")

        except Exception as e:
            return self._handle_error(idx, f"Unexpected error: {path} ({e})")

        # Apply the transform
        if self.transform:
            try:
                image = self.transform(image)
            except Exception as e:
                return self._handle_error(
                    idx, f"Transform failed: {path} ({e})"
                )

        return image, torch.tensor(label, dtype=torch.long)

    def _handle_error(self, idx: int, message: str):
        """Handle loading errors gracefully."""
        # Log the error
        if self.log_errors:
            logger.warning(message)

        # Track for later analysis
        self.error_indices.append(idx)
        self.error_messages.append(message)

        if self.return_placeholder:
            # Return a gray placeholder image
            placeholder = Image.new('RGB', self.placeholder_size, color='gray')
            if self.transform:
                placeholder = self.transform(placeholder)
            return placeholder, torch.tensor(self.labels[idx], dtype=torch.long)
        else:
            # Re-raise to crash training (strict mode)
            raise RuntimeError(message)

    def get_error_report(self) -> dict:
        """Return a summary of all errors encountered."""
        return {
            'total_errors': len(self.error_indices),
            'error_rate': len(self.error_indices) / len(self) * 100,
            'errors': list(zip(self.error_indices, self.error_messages))
        }

Debugging Tips

🐍debugging_datasets.py
# 1. Test a single sample first
dataset = MyDataset(...)
sample = dataset[0]
print(f"Sample type: {type(sample)}")
print(f"Sample shapes: {[x.shape if hasattr(x, 'shape') else x for x in sample]}")

# 2. Test boundary conditions
print(f"First sample: {dataset[0]}")
print(f"Last sample: {dataset[len(dataset) - 1]}")

# Try out-of-bounds (should raise IndexError)
try:
    dataset[len(dataset)]
except IndexError:
    print("IndexError raised correctly!")

# 3. Test with a small DataLoader
from torch.utils.data import DataLoader

loader = DataLoader(dataset, batch_size=4, num_workers=0)
batch = next(iter(loader))
print(f"Batch shapes: {[x.shape for x in batch]}")

# 4. Test with multiple workers (catches multiprocessing issues)
loader = DataLoader(dataset, batch_size=4, num_workers=2)
for i, batch in enumerate(loader):
    if i >= 3:
        break
    print(f"Batch {i} loaded successfully")

# 5. Profile loading time
import time

times = []
for i in range(100):
    start = time.time()
    _ = dataset[i]
    times.append(time.time() - start)

print(f"Mean load time: {sum(times)/len(times)*1000:.2f}ms")
print(f"Max load time: {max(times)*1000:.2f}ms")

Testing Your Datasets

Datasets are critical infrastructure. Test them thoroughly:

🐍test_dataset.py
import pytest
import torch
from my_dataset import MyCustomDataset

# Placeholders: test_data_path, expected_length, expected_feature_shape,
# and num_classes should be defined for your own dataset and fixtures.

class TestMyCustomDataset:
    """Unit tests for a custom dataset."""

    @pytest.fixture
    def dataset(self):
        """Create the dataset under test."""
        return MyCustomDataset(test_data_path, transform=None)

    def test_length(self, dataset):
        """Dataset reports the correct length."""
        assert len(dataset) > 0
        assert len(dataset) == expected_length

    def test_getitem_returns_tuple(self, dataset):
        """__getitem__ returns a (features, label) tuple."""
        sample = dataset[0]
        assert isinstance(sample, tuple)
        assert len(sample) == 2

    def test_feature_shape(self, dataset):
        """Features have the expected shape."""
        features, _ = dataset[0]
        assert features.shape == expected_feature_shape

    def test_feature_dtype(self, dataset):
        """Features are float tensors."""
        features, _ = dataset[0]
        assert features.dtype == torch.float32

    def test_label_dtype(self, dataset):
        """Labels are integer tensors."""
        _, label = dataset[0]
        assert label.dtype == torch.long

    def test_label_range(self, dataset):
        """Labels are within the valid range."""
        for i in range(min(100, len(dataset))):
            _, label = dataset[i]
            assert 0 <= label < num_classes

    def test_reproducibility(self, dataset):
        """The same index returns the same sample (no side effects)."""
        sample1 = dataset[0]
        sample2 = dataset[0]
        assert torch.equal(sample1[0], sample2[0])
        assert sample1[1] == sample2[1]

    def test_boundary_indices(self, dataset):
        """First and last indices work correctly."""
        _ = dataset[0]
        _ = dataset[len(dataset) - 1]

    def test_out_of_bounds(self, dataset):
        """An out-of-bounds index raises IndexError."""
        with pytest.raises(IndexError):
            _ = dataset[len(dataset)]

    def test_negative_index(self, dataset):
        """Negative indices work (Python convention)."""
        sample = dataset[-1]
        last_sample = dataset[len(dataset) - 1]
        assert torch.equal(sample[0], last_sample[0])

    def test_with_dataloader(self, dataset):
        """Dataset works with DataLoader."""
        from torch.utils.data import DataLoader
        loader = DataLoader(dataset, batch_size=4, num_workers=0)
        batch = next(iter(loader))
        assert batch[0].shape[0] == 4  # Batch size

    def test_with_multiprocessing(self, dataset):
        """Dataset works with multiple workers."""
        from torch.utils.data import DataLoader
        loader = DataLoader(dataset, batch_size=4, num_workers=2)
        # Load a few batches to verify multiprocessing works
        for i, batch in enumerate(loader):
            if i >= 3:
                break

Test Data Fixtures

Create small test fixtures (10-100 samples) with known properties. This makes tests fast and deterministic. Don't test against your full training set!
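One possible way to build such a fixture, assuming a simple (features, label) classification setup; TensorDataset stands in for your custom class, and all names and sizes here are illustrative:

```python
import torch
from torch.utils.data import TensorDataset

def make_fixture_dataset(n=16, num_classes=4, seed=0):
    """Tiny synthetic dataset with known properties for unit tests.

    Seeded so assertions on shapes, dtypes, and label ranges never flake.
    """
    g = torch.Generator().manual_seed(seed)
    features = torch.randn(n, 3, 8, 8, generator=g)           # small "images"
    labels = torch.randint(0, num_classes, (n,), generator=g)
    return TensorDataset(features, labels)

ds = make_fixture_dataset()
x, y = ds[0]
print(len(ds), x.shape)   # 16 torch.Size([3, 8, 8])
```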

Performance Optimization

Data loading can bottleneck training. Here are optimization strategies:

Optimization Strategies

Strategy              When to Use                      Speedup
num_workers > 0       Always (CPU not saturated)       2-8x
pin_memory=True       Using a GPU                      1.1-1.3x
prefetch_factor > 2   Slow __getitem__                 Variable
SSD storage           Lazy loading from disk           5-20x vs HDD
Memory mapping        Large preprocessed data          Near-RAM speed
Pre-caching           Small enough to cache            10-100x
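As a sketch of the memory-mapping strategy above, using NumPy's mmap_mode: the array stays on disk, and the OS pages in only the slices you index, so even huge preprocessed arrays open instantly. File names and shapes here are illustrative:

```python
import numpy as np
import torch
from torch.utils.data import Dataset
import tempfile, os

class MemmapDataset(Dataset):
    """Reads samples from a memory-mapped .npy file."""

    def __init__(self, features_path, labels_path):
        # mmap_mode='r' maps the file instead of reading it into RAM
        self.features = np.load(features_path, mmap_mode='r')
        self.labels = np.load(labels_path)   # labels are small: load fully

    def __len__(self):
        return len(self.features)

    def __getitem__(self, idx):
        # .copy() materializes just this one sample into RAM
        x = torch.from_numpy(self.features[idx].copy()).float()
        y = torch.tensor(int(self.labels[idx]), dtype=torch.long)
        return x, y

# One-time preprocessing: save arrays, then train against the memmap
tmp = tempfile.mkdtemp()
np.save(os.path.join(tmp, "features.npy"),
        np.random.randn(100, 16).astype(np.float32))
np.save(os.path.join(tmp, "labels.npy"), np.random.randint(0, 10, 100))

ds = MemmapDataset(os.path.join(tmp, "features.npy"),
                   os.path.join(tmp, "labels.npy"))
x, y = ds[0]
print(len(ds), x.shape)   # 100 torch.Size([16])
```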

Optimization Example

🐍optimized_loading.py
from torch.utils.data import DataLoader

# Baseline: single-threaded, no optimization
slow_loader = DataLoader(dataset, batch_size=32)

# Optimized: multi-worker, pinned memory, prefetching
fast_loader = DataLoader(
    dataset,
    batch_size=32,
    shuffle=True,
    num_workers=4,          # Parallel data loading
    pin_memory=True,        # Faster GPU transfer
    prefetch_factor=2,      # Prefetch 2 batches per worker
    persistent_workers=True # Don't respawn workers each epoch
)

# For GPU training, monitor utilization.
# If GPU utilization < 90%, data loading is likely the bottleneck.

# Profiling data loading
import time

def benchmark_loader(loader, num_batches=100):
    start = time.time()
    for i, batch in enumerate(loader):
        if i >= num_batches:
            break
        # Simulate a model forward pass
        _ = batch[0].cuda()
    elapsed = time.time() - start
    throughput = num_batches * loader.batch_size / elapsed
    print(f"Throughput: {throughput:.1f} samples/sec")

benchmark_loader(slow_loader)
benchmark_loader(fast_loader)

Memory-Efficient Caching

🐍cached_dataset.py
from torch.utils.data import Dataset
from functools import lru_cache
import torch

class CachedDataset(Dataset):
    """Dataset with LRU caching for frequently accessed samples."""

    def __init__(self, base_dataset, cache_size=10000):
        self.base = base_dataset
        # Note: lru_cache's maxsize is fixed when the class is defined,
        # so this argument is informational; keep the two values in sync.
        self.cache_size = cache_size

    def __len__(self):
        return len(self.base)

    @lru_cache(maxsize=10000)
    def __getitem__(self, idx):
        # Warning: this caches transformed data!
        # Don't use with random augmentations.
        return self.base[idx]

    def clear_cache(self):
        """Clear the cache between epochs if using augmentation."""
        self.__getitem__.cache_clear()


# For random augmentation, cache the raw data only
class SmartCachedDataset(Dataset):
    """Cache raw data, apply transforms on the fly."""

    def __init__(self, base_dataset, cache_size=10000):
        self.base = base_dataset
        self.transform = base_dataset.transform
        self.base.transform = None  # Disable the transform in the base dataset
        self.cache = {}
        self.cache_size = cache_size

    def __len__(self):
        return len(self.base)

    def __getitem__(self, idx):
        if idx not in self.cache:
            if len(self.cache) >= self.cache_size:
                # Simple FIFO eviction: dicts preserve insertion order,
                # so this removes the oldest cached sample
                self.cache.pop(next(iter(self.cache)))
            self.cache[idx] = self.base[idx]

        sample = self.cache[idx]

        # Apply the transform (with its randomness) on every access
        if self.transform:
            features, label = sample
            features = self.transform(features)
            return features, label
        return sample

Caching and Augmentation

Never cache after applying random augmentations! The model would see the same augmented version every epoch, defeating the purpose. Either cache raw data and augment on-the-fly, or clear the cache each epoch.

Common Patterns and Recipes

Pattern 1: Preprocessing Pipeline

🐍preprocessing_pipeline.py
# Step 1: Preprocess raw data to an optimized format (run once)
import numpy as np

def preprocess_dataset(raw_dir, output_path):
    """Convert raw images to preprocessed HDF5."""
    import h5py
    from PIL import Image
    from pathlib import Path

    images = list(Path(raw_dir).glob("**/*.jpg"))

    with h5py.File(output_path, 'w') as f:
        features = f.create_dataset(
            'features',
            shape=(len(images), 3, 224, 224),
            dtype='float32',
            compression='gzip'
        )
        # In practice, also create a 'labels' dataset here so that
        # HDF5Dataset can serve (features, label) pairs.

        for i, img_path in enumerate(images):
            img = Image.open(img_path).convert('RGB')
            img = img.resize((224, 224))
            arr = np.array(img).transpose(2, 0, 1) / 255.0
            features[i] = arr

            if i % 1000 == 0:
                print(f"Processed {i}/{len(images)}")

# Step 2: Fast training with the preprocessed data
dataset = HDF5Dataset("preprocessed.h5")

Pattern 2: Multi-GPU Dataset Sharding

🐍sharded_dataset.py
from torch.utils.data import Dataset, DataLoader

class ShardedDataset(Dataset):
    """Dataset that loads a specific shard for distributed training."""

    def __init__(self, shard_paths: list, shard_idx: int, num_shards: int):
        # Each GPU loads a different subset of shards
        self.my_shards = shard_paths[shard_idx::num_shards]
        self.samples = self._load_shards()

    def _load_shards(self):
        samples = []
        for shard_path in self.my_shards:
            # load_shard is a placeholder for your shard-reading logic
            samples.extend(load_shard(shard_path))
        return samples

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        return self.samples[idx]


# Alternative: use DistributedSampler (simpler)
from torch.utils.data.distributed import DistributedSampler

# world_size and rank come from your distributed setup
dataset = MyDataset(...)
sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
loader = DataLoader(dataset, sampler=sampler, batch_size=32)

Pattern 3: On-the-Fly Feature Engineering

🐍feature_engineering_dataset.py
import numpy as np
import torch
from torch.utils.data import Dataset

class FeatureEngineeringDataset(Dataset):
    """Dataset with on-the-fly feature engineering."""

    def __init__(self, df, feature_columns, label_column):
        self.df = df
        self.feature_columns = feature_columns
        self.label_column = label_column
        # Precompute normalization statistics once, up front
        self.mean_e = df['feature_e'].mean()
        self.std_e = df['feature_e'].std()

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]

        # Extract base features
        features = row[self.feature_columns].values.astype(np.float32)

        # Engineer new features on the fly
        engineered = [
            row['feature_a'] * row['feature_b'],  # Interaction
            np.log1p(row['feature_c']),           # Log transform
            row['feature_d'] ** 2,                # Polynomial
            (row['feature_e'] - self.mean_e) / self.std_e,  # Normalize
        ]

        all_features = np.concatenate([features, engineered])

        return (
            torch.tensor(all_features, dtype=torch.float32),
            torch.tensor(row[self.label_column], dtype=torch.long)
        )

Summary

Custom datasets are the bridge between your data and PyTorch's training infrastructure. They adapt any data format to the standard Dataset interface.

Key Concepts

Concept            Key Insight                             Implementation
Dataset Contract   Answer 'how many?' and 'what is i?'     __len__ and __getitem__
Lazy Loading       Load only when accessed                 Store paths in __init__, load in __getitem__
Multi-Modal        Synchronize modalities                  Return a dict with all modalities
Error Handling     Graceful degradation                    Try-except with placeholders
Performance        Data loading often bottlenecks          Workers, caching, preprocessing

Best Practices Checklist

  1. Keep __init__ lightweight: Only store paths and metadata
  2. Lazy load in __getitem__: Load data only when accessed
  3. Handle errors gracefully: Don't crash on one bad file
  4. Test with DataLoader: Verify multiprocessing works
  5. Profile before optimizing: Find the actual bottleneck
  6. Preprocess large datasets: Convert to efficient formats (HDF5, Zarr)
  7. Cache when appropriate: But never with random augmentation
  8. Document the data format: Future you will thank present you

Looking Ahead

In the next section, we'll tackle Handling Imbalanced Data—techniques like weighted sampling, oversampling, and class-balanced losses that ensure minority classes get fair treatment during training.


Exercises

Conceptual Questions

  1. Explain why opening a file handle in __init__ doesn't work with DataLoader's num_workers > 0. What is the correct pattern?
  2. A colleague proposes loading all images into RAM in __init__ "for speed." Under what conditions is this a good idea? When is it problematic?
  3. Why should random augmentations NOT be applied before caching? How would you design a caching strategy that still allows augmentation?
  4. Compare the tradeoffs between storing data as individual files vs. a single HDF5 file. When would you prefer each approach?

Coding Exercises

  1. Audio Dataset: Implement a dataset for audio classification that:
    • Loads .wav files lazily
    • Resamples to a common sample rate (16kHz)
    • Pads or truncates to fixed length (5 seconds)
    • Returns spectrogram as features
  2. Database Dataset: Create a dataset that:
    • Loads records from a SQLite database
    • Supports lazy loading with efficient indexing
    • Handles NULL values gracefully
  3. Streaming Dataset: Implement an IterableDataset that:
    • Reads from a Kafka topic or similar stream
    • Buffers messages for batching
    • Handles reconnection on failure

Solution Hints

  • Audio: Use torchaudio.load() and torchaudio.transforms.Resample()
  • Database: Use sqlite3 with SELECT ... LIMIT 1 OFFSET idx
  • Streaming: Inherit from IterableDataset, implement __iter__

Challenge Exercise

Self-Supervised Contrastive Dataset: Build a dataset for SimCLR-style contrastive learning:

  • Each __getitem__ returns TWO augmented views of the same image
  • The two views use different random augmentations
  • Include a special collate function that creates positive pairs
  • Ensure no data leakage between the two views

This exercise teaches you to think about how the dataset structure supports the learning objective (contrastive learning requires paired samples).

