This section creates the complete data loading pipeline for training. We'll build PyTorch datasets and dataloaders that efficiently prepare batches for our translation model.
4.1 Translation Dataset
PyTorch Dataset Implementation
```python
import torch
from torch.utils.data import Dataset, DataLoader
from typing import List, Dict, Optional, Tuple, Union
from pathlib import Path


class TranslationDataset(Dataset):
    """
    PyTorch Dataset for machine translation.

    Handles:
        - Loading parallel text files
        - Tokenization with special tokens
        - Returning source-target pairs

    Args:
        source_path: Path to source language file
        target_path: Path to target language file
        tokenizer: Trained tokenizer instance
        max_length: Maximum sequence length

    Example:
        >>> dataset = TranslationDataset('train.de', 'train.en', tokenizer)
        >>> item = dataset[0]
        >>> src, tgt = item['source_ids'], item['target_ids']
    """

    def __init__(
        self,
        source_path: Union[str, Path],
        target_path: Union[str, Path],
        tokenizer,
        max_length: int = 128
    ):
        self.tokenizer = tokenizer
        self.max_length = max_length

        # Load data
        self.source_sentences = self._load_file(source_path)
        self.target_sentences = self._load_file(target_path)

        assert len(self.source_sentences) == len(self.target_sentences), \
            "Source and target must have same number of sentences"

        print(f"Loaded {len(self)} sentence pairs")

    def _load_file(self, path: Union[str, Path]) -> List[str]:
        """Load sentences from file, one sentence per line."""
        with open(path, 'r', encoding='utf-8') as f:
            return [line.strip() for line in f]

    def __len__(self) -> int:
        return len(self.source_sentences)

    def __getitem__(self, idx: int) -> Dict[str, Union[torch.Tensor, str]]:
        """
        Get a single example.

        Returns:
            Dictionary with:
                - source_ids: Source token IDs (no special tokens)
                - target_ids: Target token IDs (with BOS, EOS)
                - source_text: Original source text
                - target_text: Original target text
        """
        source_text = self.source_sentences[idx]
        target_text = self.target_sentences[idx]

        # Tokenize source (no special tokens)
        source_ids = self.tokenizer.encode(
            source_text,
            add_special_tokens=False
        )

        # Tokenize target (with BOS and EOS)
        target_ids = self.tokenizer.encode(
            target_text,
            add_special_tokens=True
        )

        # Truncate if needed
        if len(source_ids) > self.max_length:
            source_ids = source_ids[:self.max_length]

        # Keep EOS as the final token after truncating the target
        if len(target_ids) > self.max_length:
            target_ids = target_ids[:self.max_length - 1] + [self.tokenizer.eos_id]

        return {
            'source_ids': torch.tensor(source_ids, dtype=torch.long),
            'target_ids': torch.tensor(target_ids, dtype=torch.long),
            'source_text': source_text,
            'target_text': target_text,
        }
```
4.2 Collate Function
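Examples from the dataset differ in length, so they cannot be stacked into one tensor as they are. As a minimal sketch of what the collate step has to do, here is the padding and masking logic in plain Python lists. `pad_batch` is a hypothetical helper used only for illustration; the PyTorch collator that follows is the version the pipeline uses.

```python
from typing import List, Tuple


def pad_batch(
    seqs: List[List[int]], pad_id: int = 0
) -> Tuple[List[List[int]], List[List[bool]]]:
    """Right-pad every sequence to the longest one and build a padding mask."""
    max_len = max(len(s) for s in seqs)
    padded = [s + [pad_id] * (max_len - len(s)) for s in seqs]
    # True marks real tokens, False marks padding
    mask = [[i < len(s) for i in range(max_len)] for s in seqs]
    return padded, mask


padded, mask = pad_batch([[10, 20, 30], [10, 20, 30, 40, 50]])
print(padded)   # [[10, 20, 30, 0, 0], [10, 20, 30, 40, 50]]
print(mask[0])  # [True, True, True, False, False]
```

Building the mask from the sequence lengths, as this sketch does, stays correct even if a real token happens to share the padding ID; comparing against `pad_id` is simpler but relies on the padding ID never appearing inside a sentence.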
Batching with Padding
```python
class TranslationCollator:
    """
    Collate function for translation batches.

    Handles:
        - Padding sequences to same length
        - Creating attention masks
        - Separating decoder input and labels

    Args:
        pad_id: Padding token ID
    """

    def __init__(self, pad_id: int = 0):
        self.pad_id = pad_id

    def __call__(
        self,
        batch: List[Dict[str, torch.Tensor]]
    ) -> Dict[str, torch.Tensor]:
        """
        Collate a batch of examples.

        Returns:
            Dictionary with:
                - source: Padded source IDs [batch, src_len]
                - source_mask: Source attention mask [batch, src_len]
                - target_input: Decoder input [batch, tgt_len-1]
                - target_output: Labels [batch, tgt_len-1]
                - target_mask: Target attention mask [batch, tgt_len-1]
        """
        # Extract sequences
        source_ids = [item['source_ids'] for item in batch]
        target_ids = [item['target_ids'] for item in batch]

        # Pad source
        source_padded = self._pad_sequence(source_ids)
        source_mask = (source_padded != self.pad_id)

        # Pad target
        target_padded = self._pad_sequence(target_ids)

        # Split target into input and output
        # Input:  [BOS, tok1, tok2, ...] (exclude last)
        # Output: [tok1, tok2, ..., EOS] (exclude first)
        target_input = target_padded[:, :-1]
        target_output = target_padded[:, 1:]
        target_mask = (target_input != self.pad_id)

        return {
            'source': source_padded,
            'source_mask': source_mask,
            'target_input': target_input,
            'target_output': target_output,
            'target_mask': target_mask,
        }

    def _pad_sequence(
        self,
        sequences: List[torch.Tensor]
    ) -> torch.Tensor:
        """Right-pad sequences with pad_id to the longest length in the batch."""
        max_len = max(len(seq) for seq in sequences)

        padded = torch.full(
            (len(sequences), max_len),
            self.pad_id,
            dtype=torch.long
        )

        for i, seq in enumerate(sequences):
            padded[i, :len(seq)] = seq

        return padded
```
Example test showing batch structure:
```python
import torch

collator = TranslationCollator(pad_id=0)

# Create mock batch
batch = [
    {
        'source_ids': torch.tensor([10, 20, 30]),
        'target_ids': torch.tensor([2, 100, 200, 3]),  # BOS...EOS
    },
    {
        'source_ids': torch.tensor([10, 20, 30, 40, 50]),
        'target_ids': torch.tensor([2, 100, 200, 300, 400, 3]),
    },
]

result = collator(batch)

print("\nBatch structure:")
for key, value in result.items():
    print(f"  {key}: shape {value.shape}")

# Output:
#   source: shape torch.Size([2, 5])
#   source_mask: shape torch.Size([2, 5])
#   target_input: shape torch.Size([2, 5])
#   target_output: shape torch.Size([2, 5])
#   target_mask: shape torch.Size([2, 5])
```
4.3 Dynamic Batching
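To see why batch composition matters, the sketch below counts how many padded positions a fixed batch size wastes when sentence lengths are mixed, compared with batching after sorting by length. The lengths are made-up toy values, and `padding_waste` is a hypothetical helper, not part of the pipeline.

```python
from typing import List


def padding_waste(lengths: List[int], batch_size: int) -> float:
    """Fraction of positions that are padding when batching in the given order."""
    real = sum(lengths)
    padded = 0
    for start in range(0, len(lengths), batch_size):
        batch = lengths[start:start + batch_size]
        padded += len(batch) * max(batch)  # every row is padded to the longest
    return 1 - real / padded


# Toy sentence lengths (assumed values for illustration)
lengths = [5, 48, 7, 50, 6, 45, 8, 52]

mixed = padding_waste(lengths, batch_size=4)
sorted_first = padding_waste(sorted(lengths), batch_size=4)
print(f"mixed order:  {mixed:.0%} padding")         # 46% padding
print(f"sorted order: {sorted_first:.0%} padding")  # 8% padding
```

Grouping sentences of similar length removes most of the padding, which is the first ingredient of dynamic batching; the second is letting the number of sentences per batch vary with their length.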
Batching by Token Count
```python
from torch.utils.data import Sampler
from typing import Iterator, List, Tuple
import random


class TokenBucketSampler(Sampler):
    """
    Batch sampler that creates batches based on total tokens.

    Instead of a fixed batch size, targets a maximum number of padded
    tokens per batch, which keeps GPU memory use roughly constant.

    Each iteration step yields one list of dataset indices (a whole
    batch), so pass this object to DataLoader as `batch_sampler=`,
    not as `sampler=`.

    Args:
        dataset: Translation dataset
        max_tokens: Maximum padded tokens (source + target) per batch
        shuffle: Whether to shuffle the order of batches
        seed: Random seed for reproducibility
    """

    def __init__(
        self,
        dataset: TranslationDataset,
        max_tokens: int = 4096,
        shuffle: bool = True,
        seed: int = 42
    ):
        self.dataset = dataset
        self.max_tokens = max_tokens
        self.shuffle = shuffle
        self.seed = seed
        self.epoch = 0

        # Pre-compute lengths
        self.lengths = self._compute_lengths()

        # Create batches
        self.batches = self._create_batches()

    def _compute_lengths(self) -> List[Tuple[int, int, int]]:
        """Compute (index, src_len, tgt_len) for each example."""
        lengths = []

        for i in range(len(self.dataset)):
            # Whitespace word counts are a cheap estimate that avoids
            # tokenizing every example; subword counts are usually higher.
            src_len = len(self.dataset.source_sentences[i].split())
            tgt_len = len(self.dataset.target_sentences[i].split())
            lengths.append((i, src_len, tgt_len))

        return lengths

    def _create_batches(self) -> List[List[int]]:
        """Create batches based on token count."""
        # Sort by length so each batch holds similar-length examples
        sorted_lengths = sorted(
            self.lengths,
            key=lambda x: x[1] + x[2]
        )

        batches = []
        current_batch = []
        current_max_src = 0
        current_max_tgt = 0

        for idx, src_len, tgt_len in sorted_lengths:
            # Estimate padded tokens if this example joins the batch
            new_max_src = max(current_max_src, src_len)
            new_max_tgt = max(current_max_tgt, tgt_len)
            batch_tokens = (len(current_batch) + 1) * (new_max_src + new_max_tgt)

            if batch_tokens > self.max_tokens and current_batch:
                # Start new batch
                batches.append(current_batch)
                current_batch = [idx]
                current_max_src = src_len
                current_max_tgt = tgt_len
            else:
                # Add to current batch
                current_batch.append(idx)
                current_max_src = new_max_src
                current_max_tgt = new_max_tgt

        # Don't forget last batch
        if current_batch:
            batches.append(current_batch)

        return batches

    def __iter__(self) -> Iterator[List[int]]:
        """Yield one list of indices per batch."""
        order = list(range(len(self.batches)))

        if self.shuffle:
            # Shuffle the order of batches (not the examples within them).
            # A local RNG leaves the global random state untouched.
            rng = random.Random(self.seed + self.epoch)
            rng.shuffle(order)

        for i in order:
            yield self.batches[i]

    def __len__(self) -> int:
        # Number of batches per epoch
        return len(self.batches)

    def set_epoch(self, epoch: int):
        """Set epoch so each epoch uses a different batch order."""
        self.epoch = epoch
```
Benefits of Token Batching
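The token budget rule can be tried on toy data. This is a simplified sketch that uses a single length per example rather than separate source and target lengths, with `greedy_token_batches` as a hypothetical name and made-up sentence lengths:

```python
from typing import List


def greedy_token_batches(lengths: List[int], max_tokens: int) -> List[List[int]]:
    """Group indices so that batch_size * longest_length stays within max_tokens."""
    order = sorted(range(len(lengths)), key=lambda i: lengths[i])
    batches, current, current_max = [], [], 0
    for i in order:
        new_max = max(current_max, lengths[i])
        if current and (len(current) + 1) * new_max > max_tokens:
            # Adding this example would exceed the budget: close the batch
            batches.append(current)
            current, current_max = [i], lengths[i]
        else:
            current.append(i)
            current_max = new_max
    if current:
        batches.append(current)
    return batches


lengths = [5] * 10 + [50] * 4  # ten short and four long sentences
batches = greedy_token_batches(lengths, max_tokens=100)
for b in batches:
    longest = max(lengths[i] for i in b)
    print(f"{len(b)} sentences x {longest} tokens = {len(b) * longest} padded tokens")
```

With a budget of 100 tokens, the ten short sentences fit into one batch while the long ones are split into batches of two, so no batch exceeds the budget regardless of sentence length.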
```text
FIXED BATCH SIZE vs TOKEN BATCHING:
═══════════════════════════════════

Fixed batch size = 32:
  Batch 1: 32 sentences of 5 words each  = 160 tokens
  Batch 2: 32 sentences of 50 words each = 1600 tokens
  → Memory usage varies by 10x!

Token batching (max_tokens=1000):
  Batch 1: 200 short sentences = ~1000 tokens
  Batch 2: 20 long sentences   = ~1000 tokens
  → Consistent memory usage!


BENEFITS:
═════════

1. GPU memory efficiency
   - Avoid OOM on long sequences
   - Better utilization on short sequences

2. Training stability
   - Consistent gradient scale
   - More predictable training

3. Speed
   - Less padding waste
   - Better parallelism
```
4.4 Complete DataLoader Setup
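One detail of the `DataLoader` API matters for the setup in this section: an object passed as `sampler=` is expected to yield single indices, which the loader then groups using `batch_size` (default 1), while an object passed as `batch_sampler=` is expected to yield a whole list of indices per step. A small sketch with toy tensors, where the hand-written `batches` list stands in for a token-based batch sampler:

```python
import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, SequentialSampler

# Toy map-style dataset: six sequences of lengths 1..6
data = [torch.full((n,), n, dtype=torch.long) for n in range(1, 7)]


def collate(items):
    # Pad each batch to its own longest sequence
    return pad_sequence(items, batch_first=True, padding_value=0)


# batch_sampler: each element is one complete batch of indices
batches = [[0, 1, 2], [3, 4], [5]]
by_batch = DataLoader(data, batch_sampler=batches, collate_fn=collate)
batch_shapes = [tuple(b.shape) for b in by_batch]

# sampler: each element is a single index; batch_size defaults to 1
by_index = DataLoader(data, sampler=SequentialSampler(data), collate_fn=collate)
index_shapes = [tuple(b.shape) for b in by_index]

print(len(by_batch), batch_shapes)  # 3 [(3, 3), (2, 5), (1, 6)]
print(len(by_index), index_shapes[:2])  # 6 [(1, 1), (1, 2)]
```

Variable-size batches therefore need `batch_sampler=`; with `sampler=` alone, the loader would fall back to batches of one example each.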
DataLoader Factory
```python
from pathlib import Path

from torch.utils.data import DataLoader


class DataLoaderFactory:
    """
    Factory for creating translation dataloaders.

    Creates train, validation, and test dataloaders with
    appropriate settings for each.

    Args:
        tokenizer: Trained tokenizer
        data_dir: Directory containing data files
        max_tokens: Maximum tokens per batch
        max_length: Maximum sequence length
        num_workers: Number of data loading workers
    """

    def __init__(
        self,
        tokenizer,
        data_dir: str,
        max_tokens: int = 4096,
        max_length: int = 128,
        num_workers: int = 2
    ):
        self.tokenizer = tokenizer
        self.data_dir = Path(data_dir)
        self.max_tokens = max_tokens
        self.max_length = max_length
        self.num_workers = num_workers

        # Create collator
        self.collator = TranslationCollator(pad_id=tokenizer.pad_id)

    def create_train_loader(
        self,
        src_file: str = "train.de",
        tgt_file: str = "train.en"
    ) -> DataLoader:
        """Create training dataloader with dynamic batching."""
        dataset = TranslationDataset(
            source_path=self.data_dir / src_file,
            target_path=self.data_dir / tgt_file,
            tokenizer=self.tokenizer,
            max_length=self.max_length
        )

        sampler = TokenBucketSampler(
            dataset,
            max_tokens=self.max_tokens,
            shuffle=True
        )

        # Token-based batches vary in size, so the sampler must supply
        # whole batches of indices: pass it as batch_sampler and leave
        # batch_size, shuffle, and sampler at their defaults.
        return DataLoader(
            dataset,
            batch_sampler=sampler,
            collate_fn=self.collator,
            num_workers=self.num_workers,
            pin_memory=True,
        )

    def create_val_loader(
        self,
        src_file: str = "val.de",
        tgt_file: str = "val.en",
        batch_size: int = 32
    ) -> DataLoader:
        """Create validation dataloader with fixed batch size."""
        dataset = TranslationDataset(
            source_path=self.data_dir / src_file,
            target_path=self.data_dir / tgt_file,
            tokenizer=self.tokenizer,
            max_length=self.max_length
        )

        return DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=False,
            collate_fn=self.collator,
            num_workers=self.num_workers,
            pin_memory=True,
        )

    def create_test_loader(
        self,
        src_file: str = "test_2016_flickr.de",
        tgt_file: str = "test_2016_flickr.en",
        batch_size: int = 32
    ) -> DataLoader:
        """Create test dataloader (same settings as validation)."""
        return self.create_val_loader(src_file, tgt_file, batch_size)
```
Usage Example
```python
# data_setup.py
# JointBPETokenizer is the tokenizer class built in the tokenization
# section; import it from wherever it is defined in your project.

def setup_dataloaders(config):
    """
    Create all dataloaders for training.
    """
    # Load tokenizer
    tokenizer = JointBPETokenizer.load(config.tokenizer_path)

    # Create factory
    factory = DataLoaderFactory(
        tokenizer=tokenizer,
        data_dir=config.data_dir,
        max_tokens=config.max_tokens,
        max_length=config.max_length,
        num_workers=config.num_workers
    )

    # Create loaders
    train_loader = factory.create_train_loader()
    val_loader = factory.create_val_loader()
    test_loader = factory.create_test_loader()

    print(f"Train batches: {len(train_loader)}")
    print(f"Val batches: {len(val_loader)}")
    print(f"Test batches: {len(test_loader)}")

    return train_loader, val_loader, test_loader


# In training script (model, criterion, optimizer, device, and
# num_epochs are defined elsewhere):
train_loader, val_loader, test_loader = setup_dataloaders(config)

for epoch in range(num_epochs):
    # Set epoch so batches are visited in a new order each epoch
    train_loader.batch_sampler.set_epoch(epoch)

    for batch in train_loader:
        source = batch['source'].to(device)
        source_mask = batch['source_mask'].to(device)
        target_input = batch['target_input'].to(device)
        target_output = batch['target_output'].to(device)

        # Clear gradients left over from the previous step
        optimizer.zero_grad()

        # Forward pass
        logits = model(source, target_input, source_mask)

        # Compute loss. With a plain nn.CrossEntropyLoss, flatten first:
        # logits to [batch * tgt_len, vocab], labels to [batch * tgt_len].
        loss = criterion(logits, target_output)

        # Backward pass
        loss.backward()
        optimizer.step()
```
4.5 Data Module
Complete Data Module Class
```python
from dataclasses import dataclass
from pathlib import Path

from torch.utils.data import DataLoader


@dataclass
class DataConfig:
    """Data configuration."""
    data_dir: str = "data/multi30k"
    tokenizer_path: str = "data/tokenizer/tokenizer.json"
    max_tokens: int = 4096
    max_length: int = 128
    val_batch_size: int = 32
    num_workers: int = 2


class Multi30kDataModule:
    """
    Complete data module for Multi30k translation.

    Manages:
        - Tokenizer loading
        - Dataset creation
        - DataLoader configuration
        - Train/val/test splits

    Args:
        config: DataConfig instance

    Example:
        >>> dm = Multi30kDataModule(config)
        >>> dm.setup()
        >>> train_loader = dm.train_dataloader()
    """

    def __init__(self, config: DataConfig):
        self.config = config
        self.tokenizer = None
        self.train_dataset = None
        self.val_dataset = None
        self.test_dataset = None

    def setup(self):
        """Load tokenizer and create datasets."""
        # Load tokenizer
        self.tokenizer = JointBPETokenizer.load(self.config.tokenizer_path)

        print(f"Tokenizer vocab size: {len(self.tokenizer.token_to_id)}")

        # Create datasets
        data_dir = Path(self.config.data_dir)

        self.train_dataset = TranslationDataset(
            source_path=data_dir / "train.de",
            target_path=data_dir / "train.en",
            tokenizer=self.tokenizer,
            max_length=self.config.max_length
        )

        self.val_dataset = TranslationDataset(
            source_path=data_dir / "val.de",
            target_path=data_dir / "val.en",
            tokenizer=self.tokenizer,
            max_length=self.config.max_length
        )

        self.test_dataset = TranslationDataset(
            source_path=data_dir / "test_2016_flickr.de",
            target_path=data_dir / "test_2016_flickr.en",
            tokenizer=self.tokenizer,
            max_length=self.config.max_length
        )

        print(f"Train: {len(self.train_dataset)} pairs")
        print(f"Val: {len(self.val_dataset)} pairs")
        print(f"Test: {len(self.test_dataset)} pairs")

    def train_dataloader(self) -> DataLoader:
        """Get training dataloader (token-based dynamic batching)."""
        collator = TranslationCollator(pad_id=self.tokenizer.pad_id)

        sampler = TokenBucketSampler(
            self.train_dataset,
            max_tokens=self.config.max_tokens,
            shuffle=True
        )

        # The sampler supplies whole batches of indices, so it is
        # passed as batch_sampler rather than sampler.
        return DataLoader(
            self.train_dataset,
            batch_sampler=sampler,
            collate_fn=collator,
            num_workers=self.config.num_workers,
            pin_memory=True
        )

    def val_dataloader(self) -> DataLoader:
        """Get validation dataloader."""
        collator = TranslationCollator(pad_id=self.tokenizer.pad_id)

        return DataLoader(
            self.val_dataset,
            batch_size=self.config.val_batch_size,
            shuffle=False,
            collate_fn=collator,
            num_workers=self.config.num_workers,
            pin_memory=True
        )

    def test_dataloader(self) -> DataLoader:
        """Get test dataloader."""
        collator = TranslationCollator(pad_id=self.tokenizer.pad_id)

        return DataLoader(
            self.test_dataset,
            batch_size=self.config.val_batch_size,
            shuffle=False,
            collate_fn=collator,
            num_workers=self.config.num_workers,
            pin_memory=True
        )

    @property
    def vocab_size(self) -> int:
        """Get vocabulary size."""
        return len(self.tokenizer.token_to_id)

    @property
    def pad_id(self) -> int:
        """Get padding token ID."""
        return self.tokenizer.pad_id
```
Batch Contents
```text
BATCH CONTENTS:
═══════════════

source:         [batch_size, max_src_len]
                Source token IDs (no BOS/EOS)
                Padded with pad_id

source_mask:    [batch_size, max_src_len]
                True where not padding

target_input:   [batch_size, max_tgt_len - 1]
                Target for decoder input
                Starts with BOS, excludes last token

target_output:  [batch_size, max_tgt_len - 1]
                Target labels for loss
                Excludes BOS, ends with EOS

target_mask:    [batch_size, max_tgt_len - 1]
                True where not padding
```
4.6 Testing the Complete Pipeline
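Before running a test on real data, the shift between decoder input and labels can be checked on a hand-made sequence. This plain-Python sketch assumes the special token IDs PAD=0, BOS=2, and EOS=3, the values used in the mock batch earlier; a real tokenizer may assign different IDs.

```python
PAD, BOS, EOS = 0, 2, 3  # assumed IDs, matching the earlier mock batch

# One padded target row as the collator would produce it
target = [BOS, 100, 200, EOS, PAD, PAD]

target_input = target[:-1]   # [BOS, 100, 200, EOS, PAD]
target_output = target[1:]   # [100, 200, EOS, PAD, PAD]

# At every step the label is the token that follows the decoder input
for step, (inp, label) in enumerate(zip(target_input, target_output)):
    assert label == target[step + 1]

# Mask derived from the decoder input vs. mask derived from the labels
input_mask = [tok != PAD for tok in target_input]
label_mask = [tok != PAD for tok in target_output]
print(input_mask)  # [True, True, True, True, False]
print(label_mask)  # [True, True, True, False, False]
```

The two masks differ at the position where the decoder input is EOS, because its label is already padding. For rows shorter than the longest one in the batch, it is therefore safer to exclude loss positions by label (for example with `ignore_index=pad_id`) than by `target_mask`.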
End-to-End Test
```python
def test_data_pipeline():
    # Setup
    config = DataConfig(
        data_dir="data/multi30k",
        tokenizer_path="data/tokenizer/tokenizer.json"
    )

    dm = Multi30kDataModule(config)
    dm.setup()

    # Test train loader
    train_loader = dm.train_dataloader()

    print("\nTesting train loader:")
    batch = next(iter(train_loader))

    print(f"  Source shape: {batch['source'].shape}")
    print(f"  Target input shape: {batch['target_input'].shape}")
    print(f"  Target output shape: {batch['target_output'].shape}")

    # Verify shapes match
    assert batch['source'].shape[0] == batch['target_input'].shape[0]
    assert batch['target_input'].shape == batch['target_output'].shape

    # Verify masking
    src_mask = batch['source_mask']
    tgt_mask = batch['target_mask']

    assert src_mask.sum() > 0, "Source mask is all False!"
    assert tgt_mask.sum() > 0, "Target mask is all False!"

    # Verify special tokens
    pad_id = dm.pad_id
    bos_id = dm.tokenizer.bos_id
    eos_id = dm.tokenizer.eos_id

    # Target input should start with BOS
    first_tokens = batch['target_input'][:, 0]
    assert (first_tokens == bos_id).all(), "Target input should start with BOS"

    print("\n✓ All tests passed!")
```
Summary
Data Pipeline Components
| Component | Purpose |
|---|---|
| TranslationDataset | Load and tokenize parallel text |
| TranslationCollator | Batch and pad sequences |
| TokenBucketSampler | Dynamic batching by tokens |
| DataLoaderFactory | Create train, validation, and test dataloaders |
| Multi30kDataModule | Complete data management |
Batch Format
| Key | Shape | Description |
|---|---|---|
| source | [B, S] | Source token IDs |
| source_mask | [B, S] | True where not padding |
| target_input | [B, T-1] | Decoder input (starts with BOS) |
| target_output | [B, T-1] | Labels (ends with EOS) |
| target_mask | [B, T-1] | True where not padding |
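As a sketch of how these tensors are typically consumed, the snippet below computes a cross-entropy loss that skips padded label positions. The random logits stand in for a model output of shape [B, T-1, V], and `pad_id = 0` and the vocabulary size are assumed values; the actual model and criterion are defined elsewhere.

```python
import torch
import torch.nn.functional as F

pad_id, vocab_size = 0, 10  # assumed values for illustration
torch.manual_seed(0)

# Labels for two sequences; the first is padded at the end
target_output = torch.tensor([[4, 5, 3, 0, 0],
                              [6, 7, 8, 9, 3]])
logits = torch.randn(2, 5, vocab_size)  # stand-in for model output [B, T-1, V]

# cross_entropy expects [N, V] logits and [N] labels, so flatten batch and time
flat_logits = logits.reshape(-1, vocab_size)
flat_labels = target_output.reshape(-1)
loss = F.cross_entropy(
    flat_logits,
    flat_labels,
    ignore_index=pad_id,  # padded labels contribute nothing to the loss
)

# Same value computed by hand over the non-padding positions only
keep = flat_labels != pad_id
manual = F.cross_entropy(flat_logits[keep], flat_labels[keep])
print(torch.allclose(loss, manual))  # True
```

With the default mean reduction, `ignore_index` averages over the non-ignored positions only, so the loss does not shrink as batches contain more padding.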
Recommended Configuration
```python
config = DataConfig(
    data_dir="data/multi30k",
    tokenizer_path="data/tokenizer/tokenizer.json",
    max_tokens=4096,     # Padded source + target tokens per batch (e.g. 64 pairs at 64 tokens each)
    max_length=128,      # Enough for Multi30k
    val_batch_size=32,   # Fixed for validation
    num_workers=2        # Parallel loading
)
```
Chapter Summary
In this chapter, we set up everything needed to work with Multi30k:
- Dataset Overview: Understood Multi30k structure and statistics
- Preprocessing: Created German-English preprocessing pipeline
- Tokenization: Built joint BPE tokenizer with 8K vocabulary
- Data Loading: Created efficient PyTorch dataloaders
We're now ready to train our translation model!
Exercises
Implementation
- Add caching to avoid re-tokenizing on each epoch.
- Implement multi-GPU data distribution with DistributedSampler.
- Create a data visualization tool showing batch statistics.
Analysis
- Compare training speed with different max_tokens settings.
- Profile the data loading to identify bottlenecks.
Next Chapter Preview
In the next chapter, we'll Train the Translation Model, putting together everything we've built to train a German-English translation system.