Introduction
Efficient data loading is crucial for training. This section covers how to create PyTorch datasets and dataloaders for translation, including dynamic batching, padding, and handling variable-length sequences.
Translation Dataset Structure
Parallel Corpus Format
text
Translation data typically comes as parallel sentences:

Source (German):             Target (English):
─────────────────            ─────────────────
"Der Hund läuft"             "The dog runs"
"Die Katze schläft"          "The cat sleeps"
"Ich liebe Programmieren"    "I love programming"

Each line i in the source file corresponds to line i in the target file.
Multi30k Dataset
text
Multi30k: Standard dataset for German-English translation research
- ~30,000 sentence pairs
- Short sentences (average ~12 words)
- Image captions in multiple languages
- Commonly used for benchmarking

Files:
  train.de, train.en
  val.de, val.en
  test.de, test.en
Translation Dataset Class
Basic Implementation
python
import os
from typing import List, Tuple, Optional, Dict

import torch
from torch.utils.data import Dataset, DataLoader


class TranslationDataset(Dataset):
    """
    Dataset for machine translation.

    Loads parallel source-target sentence pairs (line i of the source file
    pairs with line i of the target file) and tokenizes them on access.

    Args:
        source_file: Path to source language file (read as UTF-8, one
            sentence per line)
        target_file: Path to target language file (read as UTF-8, one
            sentence per line)
        tokenizer: Tokenizer with encode_source and encode_target methods
        max_source_len: Maximum source sequence length
        max_target_len: Maximum target sequence length

    Raises:
        ValueError: If the two files contain a different number of lines.

    Example:
        >>> dataset = TranslationDataset(
        ...     'data/train.de', 'data/train.en',
        ...     tokenizer, max_source_len=100, max_target_len=100
        ... )
        >>> src, tgt = dataset[0]
    """

    def __init__(
        self,
        source_file: str,
        target_file: str,
        tokenizer,
        max_source_len: int = 128,
        max_target_len: int = 128
    ):
        self.tokenizer = tokenizer
        self.max_source_len = max_source_len
        self.max_target_len = max_target_len

        # Load sentence pairs
        self.source_sentences = self._load_file(source_file)
        self.target_sentences = self._load_file(target_file)

        # Explicit check instead of `assert`: asserts are stripped under
        # `python -O`, which would let a misaligned corpus load silently.
        if len(self.source_sentences) != len(self.target_sentences):
            raise ValueError(
                "Source and target must have same number of sentences "
                f"(got {len(self.source_sentences)} source and "
                f"{len(self.target_sentences)} target lines)"
            )

        print(f"Loaded {len(self)} sentence pairs")

    def _load_file(self, path: str) -> List[str]:
        """Load one sentence per line, stripping surrounding whitespace."""
        with open(path, 'r', encoding='utf-8') as f:
            sentences = [line.strip() for line in f]
        return sentences

    def __len__(self) -> int:
        return len(self.source_sentences)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Get a tokenized source-target pair.

        Returns:
            source_ids: Tokenized source [src_len]
            target_ids: Tokenized target with BOS/EOS [tgt_len]
        """
        source = self.source_sentences[idx]
        target = self.target_sentences[idx]

        # Tokenize source (no special tokens needed for encoder)
        source_ids = self.tokenizer.encode_source(
            source,
            max_length=self.max_source_len
        )

        # Tokenize target with BOS and EOS
        target_ids = self.tokenizer.encode_target(
            target,
            add_bos=True,
            add_eos=True,
            max_length=self.max_target_len
        )

        return (
            torch.tensor(source_ids, dtype=torch.long),
            torch.tensor(target_ids, dtype=torch.long)
        )

# Collation Function
Padding and Batching
python
class TranslationCollator:
    """
    Collate function for translation batches.

    Pads sequences to the same length within each batch.

    Args:
        pad_id: Padding token ID
        batch_first: Whether to return id tensors as [batch, seq] (True)
            or [seq, batch] (False)
    """

    def __init__(self, pad_id: int = 0, batch_first: bool = True):
        self.pad_id = pad_id
        self.batch_first = batch_first

    def __call__(
        self,
        batch: List[Tuple[torch.Tensor, torch.Tensor]]
    ) -> Dict[str, torch.Tensor]:
        """
        Collate a batch of source-target pairs.

        Args:
            batch: List of (source, target) tensor pairs

        Returns:
            Dictionary with:
                - source_ids: [batch, max_src_len] ([max_src_len, batch]
                  when batch_first is False)
                - target_ids: [batch, max_tgt_len] ([max_tgt_len, batch]
                  when batch_first is False)
                - source_lengths: [batch] unpadded source lengths
                - target_lengths: [batch] unpadded target lengths
        """
        sources, targets = zip(*batch)

        # Lengths are measured before padding so callers can build masks.
        source_lengths = torch.tensor([len(s) for s in sources])
        target_lengths = torch.tensor([len(t) for t in targets])

        # Pad sequences
        source_ids = self._pad_sequence(sources)
        target_ids = self._pad_sequence(targets)

        return {
            'source_ids': source_ids,
            'target_ids': target_ids,
            'source_lengths': source_lengths,
            'target_lengths': target_lengths
        }

    def _pad_sequence(
        self,
        sequences: Tuple[torch.Tensor, ...]
    ) -> torch.Tensor:
        """Pad sequences to the longest one, honoring ``batch_first``."""
        max_len = max(len(s) for s in sequences)
        batch_size = len(sequences)

        padded = torch.full(
            (batch_size, max_len),
            self.pad_id,
            dtype=sequences[0].dtype
        )

        for i, seq in enumerate(sequences):
            padded[i, :len(seq)] = seq

        if not self.batch_first:
            # Sequence-major layout [seq, batch]; contiguous() gives a
            # regular memory layout instead of a strided view.
            padded = padded.transpose(0, 1).contiguous()

        return padded

# Dynamic Batching by Tokens
Efficient Batching
python
class TokenBucketBatcher:
    """
    Create batches with similar total token count.

    Instead of a fixed batch size, groups length-sorted samples so that each
    batch's padded size, ``batch_len * (max_src_len + max_tgt_len)``, stays
    within ``max_tokens``. This is more memory-efficient than a fixed batch
    size. A single sample that alone exceeds ``max_tokens`` still gets a
    batch of its own.

    Batches are built once at construction. Iterating yields each batch as
    a list of dataset items (``(src, tgt)`` pairs), visiting the batches in
    a freshly shuffled order each time.

    Args:
        dataset: Translation dataset; ``dataset[i]`` must return a
            ``(source_ids, target_ids)`` pair of sized sequences
        max_tokens: Maximum padded tokens per batch
        drop_last: Whether to drop the incomplete last batch, i.e. the one
            left open when the data ran out rather than closed by the
            token limit
    """

    def __init__(
        self,
        dataset: 'TranslationDataset',
        max_tokens: int = 4096,
        drop_last: bool = False
    ):
        self.dataset = dataset
        self.max_tokens = max_tokens
        self.drop_last = drop_last

        # Pre-compute lengths for bucketing (fetches every sample once)
        self.lengths = self._compute_lengths()
        self.batches = self._create_batches()

    def _compute_lengths(self) -> List[Tuple[int, int, int]]:
        """Compute (idx, src_len, tgt_len) for all samples."""
        lengths = []
        for idx in range(len(self.dataset)):
            src, tgt = self.dataset[idx]
            lengths.append((idx, len(src), len(tgt)))
        return lengths

    def _create_batches(self) -> List[List[int]]:
        """Greedily pack length-sorted samples into token-bounded batches."""
        # Sort by src_len, then tgt_len, so neighbours need little padding.
        # sorted() is stable, so ties keep dataset order.
        ordered = sorted(self.lengths, key=lambda x: (x[1], x[2]))

        batches: List[List[int]] = []
        current_batch: List[int] = []
        # Running maxima of the open batch (0 while it is empty), so each
        # step is O(1) instead of rescanning the batch.
        max_src = max_tgt = 0

        for idx, src_len, tgt_len in ordered:
            new_max_src = max(max_src, src_len)
            new_max_tgt = max(max_tgt, tgt_len)
            # Padded size if this sample joined the open batch
            estimated_tokens = (len(current_batch) + 1) * (new_max_src + new_max_tgt)

            if current_batch and estimated_tokens > self.max_tokens:
                # Close the full batch; this sample starts a new one.
                batches.append(current_batch)
                current_batch = []
                new_max_src, new_max_tgt = src_len, tgt_len

            current_batch.append(idx)
            max_src, max_tgt = new_max_src, new_max_tgt

        # The still-open batch was never closed by the token limit, so it is
        # the incomplete one that drop_last refers to.
        if current_batch and not self.drop_last:
            batches.append(current_batch)

        return batches

    def __len__(self) -> int:
        return len(self.batches)

    def __iter__(self):
        """Yield batches of dataset items in a shuffled batch order."""
        import random
        indices = list(range(len(self.batches)))
        random.shuffle(indices)

        for batch_idx in indices:
            batch_indices = self.batches[batch_idx]
            yield [self.dataset[i] for i in batch_indices]

# Visualization
text
Fixed Batch Size (N=4):
───────────────────────
Batch 1: [Short, Short, Short, Short]  ← Mostly padding
Batch 2: [Long, Long, Long, Long]      ← Memory overflow!

┌────────────────────────────────────────┐
│██████████░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░│ Seq 1
│████████░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░│ Seq 2
│████████████░░░░░░░░░░░░░░░░░░░░░░░░░░░░│ Seq 3
│████████████████████████████████████████│ Seq 4 (long)
└────────────────────────────────────────┘
Wasted compute: ░░░ (lots of padding!)


Dynamic Batching (max_tokens=100):
──────────────────────────────────
Batch 1: [Short × 10]  → ~100 tokens total
Batch 2: [Medium × 5]  → ~100 tokens total
Batch 3: [Long × 2]    → ~100 tokens total

┌──────────────────┐
│████████  Seq 1   │
│████████  Seq 2   │
│████████  Seq 3   │
│████████  Seq 4   │ ~100 tokens
│████████  Seq 5   │
└──────────────────┘

┌──────────────────────────┐
│████████████████  Seq 6   │
│████████████████  Seq 7   │ ~100 tokens
│████████████████  Seq 8   │
└──────────────────────────┘

Much less wasted compute!
Complete DataLoader Setup
Production Implementation
python
def create_translation_dataloaders(
    train_source: str,
    train_target: str,
    val_source: str,
    val_target: str,
    tokenizer,
    batch_size: int = 32,
    max_source_len: int = 128,
    max_target_len: int = 128,
    num_workers: int = 4,
    pin_memory: bool = True
) -> Tuple[DataLoader, DataLoader]:
    """
    Build the training and validation dataloaders for a parallel corpus.

    Both splits share the tokenizer, length limits, padding collator and
    loader settings. The training loader shuffles and drops its final short
    batch; the validation loader keeps file order and every sample.

    Args:
        train_source: Path to training source file
        train_target: Path to training target file
        val_source: Path to validation source file
        val_target: Path to validation target file
        tokenizer: Tokenizer instance (must expose ``pad_token_id``)
        batch_size: Sentence pairs per batch
        max_source_len: Maximum source length
        max_target_len: Maximum target length
        num_workers: Number of data loading workers
        pin_memory: Whether to pin memory for GPU transfer

    Returns:
        train_loader, val_loader
    """
    def load_split(src_path, tgt_path):
        # Same tokenizer and length limits for every split.
        return TranslationDataset(
            src_path, tgt_path,
            tokenizer,
            max_source_len, max_target_len
        )

    train_dataset = load_split(train_source, train_target)
    val_dataset = load_split(val_source, val_target)

    # One collator instance is shared by both loaders.
    shared_options = dict(
        batch_size=batch_size,
        collate_fn=TranslationCollator(pad_id=tokenizer.pad_token_id),
        num_workers=num_workers,
        pin_memory=pin_memory,
    )

    # drop_last keeps every training step at exactly batch_size pairs.
    train_loader = DataLoader(
        train_dataset, shuffle=True, drop_last=True, **shared_options
    )
    val_loader = DataLoader(val_dataset, shuffle=False, **shared_options)

    return train_loader, val_loader


class TranslationDataModule:
    """
    Complete data module for translation training.

    Encapsulates all data loading logic. Expects ``data_dir`` to contain
    ``train.de``/``train.en`` and ``val.de``/``val.en``. The test split
    (``test.de``/``test.en``) is loaded as well when both files exist.

    Args:
        data_dir: Directory holding the ``{split}.de`` / ``{split}.en`` files
        tokenizer: Tokenizer passed to TranslationDataset; its
            ``pad_token_id`` is used as the padding value
        batch_size: Sentence pairs per batch
        max_source_len: Maximum source sequence length
        max_target_len: Maximum target sequence length
        num_workers: Number of DataLoader worker processes
        pin_memory: Whether the DataLoaders pin host memory (default True)
    """

    def __init__(
        self,
        data_dir: str,
        tokenizer,
        batch_size: int = 32,
        max_source_len: int = 128,
        max_target_len: int = 128,
        num_workers: int = 4,
        pin_memory: bool = True
    ):
        self.data_dir = data_dir
        self.tokenizer = tokenizer
        self.batch_size = batch_size
        self.max_source_len = max_source_len
        self.max_target_len = max_target_len
        self.num_workers = num_workers
        self.pin_memory = pin_memory

        # Populated by setup(); None until then.
        self.train_loader: Optional[DataLoader] = None
        self.val_loader: Optional[DataLoader] = None
        self.test_loader: Optional[DataLoader] = None

    def _split_paths(self, split: str) -> Tuple[str, str]:
        """Return the (.de source, .en target) file paths for a split."""
        return (
            os.path.join(self.data_dir, f'{split}.de'),
            os.path.join(self.data_dir, f'{split}.en'),
        )

    def _build_loader(self, split: str, collator, shuffle: bool) -> DataLoader:
        """Load one split from disk and wrap it in a DataLoader."""
        source_path, target_path = self._split_paths(split)
        dataset = TranslationDataset(
            source_path,
            target_path,
            self.tokenizer,
            self.max_source_len,
            self.max_target_len
        )
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=shuffle,
            collate_fn=collator,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory
        )

    def setup(self):
        """Set up datasets and dataloaders for every available split."""
        collator = TranslationCollator(
            pad_id=self.tokenizer.pad_token_id
        )

        self.train_loader = self._build_loader('train', collator, shuffle=True)
        self.val_loader = self._build_loader('val', collator, shuffle=False)

        # The test split is optional; test_loader stays None without it.
        test_source, test_target = self._split_paths('test')
        if os.path.exists(test_source) and os.path.exists(test_target):
            self.test_loader = self._build_loader('test', collator, shuffle=False)
        else:
            self.test_loader = None

        print(f"Train batches: {len(self.train_loader)}")
        print(f"Val batches: {len(self.val_loader)}")
        if self.test_loader is not None:
            print(f"Test batches: {len(self.test_loader)}")

    def train_dataloader(self) -> Optional[DataLoader]:
        """Return the training loader (None before setup())."""
        return self.train_loader

    def val_dataloader(self) -> Optional[DataLoader]:
        """Return the validation loader (None before setup())."""
        return self.val_loader

    def test_dataloader(self) -> Optional[DataLoader]:
        """Return the test loader, or None if setup() found no test split."""
        return self.test_loader

# Summary
Key Components
| Component | Purpose |
|---|---|
| TranslationDataset | Load and tokenize sentence pairs |
| TranslationCollator | Pad batches to same length |
| TokenBucketBatcher | Dynamic batching by tokens |
| TranslationDataModule | Encapsulate all data logic |
Best Practices
- Dynamic batching for memory efficiency
- Sort by length to minimize padding
- Pre-tokenize if possible for speed
- Pin memory for faster GPU transfer
- Multiple workers for parallel loading
Exercises
Implementation
- Add support for data augmentation (e.g., back-translation).
- Implement curriculum learning (start with short sentences).
- Add caching of tokenized data to disk.
Analysis
- Compare training speed with different batch sizes.
- Measure GPU utilization with different num_workers.
In the next section, we'll implement Label Smoothing Loss—a regularization technique that prevents the model from becoming overconfident and improves generalization.