Chapter 12

Data Loading Pipeline

Multi30k Dataset Setup

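This section builds the complete data loading pipeline for training. It has four parts: a PyTorch Dataset that tokenizes parallel German-English sentences, a collate function that pads them into batches, a sampler that groups sentences by token count, and a data module that ties everything together.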


4.1 Translation Dataset

PyTorch Dataset Implementation

🐍python
import torch
from torch.utils.data import Dataset, DataLoader
from typing import Any, Dict, List, Optional, Tuple, Union
from pathlib import Path


class TranslationDataset(Dataset):
    """
    PyTorch Dataset for machine translation.

    Handles:
    - Loading parallel text files
    - Tokenization with special tokens
    - Returning source-target pairs

    Args:
        source_path: Path to source language file
        target_path: Path to target language file
        tokenizer: Trained tokenizer instance (needs encode() and eos_id)
        max_length: Maximum sequence length in tokens

    Example:
        >>> dataset = TranslationDataset('train.de', 'train.en', tokenizer)
        >>> example = dataset[0]
        >>> src, tgt = example['source_ids'], example['target_ids']
    """

    def __init__(
        self,
        source_path: Union[str, Path],
        target_path: Union[str, Path],
        tokenizer,
        max_length: int = 128
    ):
        self.tokenizer = tokenizer
        self.max_length = max_length

        # Load data
        self.source_sentences = self._load_file(source_path)
        self.target_sentences = self._load_file(target_path)

        assert len(self.source_sentences) == len(self.target_sentences), \
            "Source and target must have same number of sentences"

        print(f"Loaded {len(self)} sentence pairs")

    def _load_file(self, path: Union[str, Path]) -> List[str]:
        """Load sentences from file, one per line (line i pairs with line i)."""
        with open(path, 'r', encoding='utf-8') as f:
            return [line.strip() for line in f]

    def __len__(self) -> int:
        return len(self.source_sentences)

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        """
        Get a single example.

        Returns:
            Dictionary with:
                - source_ids: Source token IDs (no special tokens)
                - target_ids: Target token IDs (with BOS, EOS)
                - source_text: Original source text
                - target_text: Original target text
        """
        source_text = self.source_sentences[idx]
        target_text = self.target_sentences[idx]

        # Tokenize source (no special tokens)
        source_ids = self.tokenizer.encode(
            source_text,
            add_special_tokens=False
        )

        # Tokenize target (with BOS and EOS)
        target_ids = self.tokenizer.encode(
            target_text,
            add_special_tokens=True
        )

        # Truncate if needed
        if len(source_ids) > self.max_length:
            source_ids = source_ids[:self.max_length]

        if len(target_ids) > self.max_length:
            # Keep BOS at the start and re-append EOS so the target still ends properly
            target_ids = target_ids[:self.max_length - 1] + [self.tokenizer.eos_id]

        return {
            'source_ids': torch.tensor(source_ids, dtype=torch.long),
            'target_ids': torch.tensor(target_ids, dtype=torch.long),
            'source_text': source_text,
            'target_text': target_text,
        }
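The target truncation rule keeps EOS as the final token even when a sentence is cut, so the decoder still sees where every sequence ends. The standalone sketch below isolates that rule. `truncate_target` is a hypothetical helper name, and the IDs BOS = 2 and EOS = 3 are assumptions chosen to match the mock batch used later in this section.

```python
from typing import List


def truncate_target(target_ids: List[int], max_length: int, eos_id: int) -> List[int]:
    """Cut a BOS ... EOS sequence to max_length tokens, keeping EOS last."""
    if len(target_ids) > max_length:
        # Keep the first max_length - 1 tokens (including BOS), then re-append EOS
        return target_ids[:max_length - 1] + [eos_id]
    return target_ids


# Hypothetical IDs: BOS = 2, EOS = 3
ids = [2, 10, 11, 12, 13, 3]
print(truncate_target(ids, max_length=4, eos_id=3))  # [2, 10, 11, 3]
print(truncate_target(ids, max_length=8, eos_id=3))  # [2, 10, 11, 12, 13, 3]
```

Source sequences need no such care because they carry no special tokens, so a plain slice is enough.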

4.2 Collate Function

Batching with Padding

🐍python
import torch
from typing import Any, Dict, List


class TranslationCollator:
    """
    Collate function for translation batches.

    Handles:
    - Padding sequences to same length
    - Creating attention masks
    - Splitting targets into decoder input and labels

    Args:
        pad_id: Padding token ID
    """

    def __init__(self, pad_id: int = 0):
        self.pad_id = pad_id

    def __call__(
        self,
        batch: List[Dict[str, Any]]
    ) -> Dict[str, torch.Tensor]:
        """
        Collate a batch of examples.

        Returns:
            Dictionary with:
                - source: Padded source IDs [batch, src_len]
                - source_mask: Source attention mask [batch, src_len]
                - target_input: Decoder input [batch, tgt_len-1]
                - target_output: Labels [batch, tgt_len-1]; padded
                  positions hold pad_id and should be ignored by the loss
                - target_mask: Target attention mask [batch, tgt_len-1]
        """
        # Extract sequences
        source_ids = [item['source_ids'] for item in batch]
        target_ids = [item['target_ids'] for item in batch]

        # Pad source
        source_padded = self._pad_sequence(source_ids)
        source_mask = (source_padded != self.pad_id)

        # Pad target
        target_padded = self._pad_sequence(target_ids)

        # Split target into input and output (teacher forcing)
        # Input: [BOS, tok1, tok2, ...] (exclude last)
        # Output: [tok1, tok2, ..., EOS] (exclude first)
        target_input = target_padded[:, :-1]
        target_output = target_padded[:, 1:]
        target_mask = (target_input != self.pad_id)

        return {
            'source': source_padded,
            'source_mask': source_mask,
            'target_input': target_input,
            'target_output': target_output,
            'target_mask': target_mask,
        }

    def _pad_sequence(
        self,
        sequences: List[torch.Tensor]
    ) -> torch.Tensor:
        """Right-pad sequences with pad_id to the longest length in the batch."""
        max_len = max(len(seq) for seq in sequences)

        padded = torch.full(
            (len(sequences), max_len),
            self.pad_id,
            dtype=torch.long
        )

        for i, seq in enumerate(sequences):
            padded[i, :len(seq)] = seq

        return padded
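The manual `_pad_sequence` loop produces the same right-padded tensor as PyTorch's built-in `torch.nn.utils.rnn.pad_sequence` with `batch_first=True`, which could be swapped in if preferred. The sketch below pads two short sequences with pad ID 0 and derives a mask using the same convention as `source_mask`.

```python
import torch
from torch.nn.utils.rnn import pad_sequence

sequences = [torch.tensor([10, 20, 30]), torch.tensor([10, 20, 30, 40, 50])]

# batch_first=True gives [batch, max_len], matching TranslationCollator._pad_sequence
padded = pad_sequence(sequences, batch_first=True, padding_value=0)
mask = padded != 0  # True where not padding

print(padded.tolist())           # [[10, 20, 30, 0, 0], [10, 20, 30, 40, 50]]
print(mask.sum(dim=1).tolist())  # [3, 5]
```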

Example test showing batch structure:

🐍python
import torch

collator = TranslationCollator(pad_id=0)

# Create mock batch (BOS = 2, EOS = 3, pad = 0)
batch = [
    {
        'source_ids': torch.tensor([10, 20, 30]),
        'target_ids': torch.tensor([2, 100, 200, 3]),  # BOS...EOS
    },
    {
        'source_ids': torch.tensor([10, 20, 30, 40, 50]),
        'target_ids': torch.tensor([2, 100, 200, 300, 400, 3]),
    },
]

result = collator(batch)

print("\nBatch structure:")
for key, value in result.items():
    print(f"  {key}: shape {value.shape}")

# Output:
# Batch structure:
#   source: shape torch.Size([2, 5])
#   source_mask: shape torch.Size([2, 5])
#   target_input: shape torch.Size([2, 5])
#   target_output: shape torch.Size([2, 5])
#   target_mask: shape torch.Size([2, 5])
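The shapes alone hide what the shift does to a padded row. Applying the same slicing to the mock targets (BOS = 2, EOS = 3, pad = 0) shows that the shorter row keeps its EOS in `target_input`, where the matching label is padding. This is why the loss has to ignore `pad_id`.

```python
import torch

pad_id = 0
# Mock targets after padding: row 0 is shorter and padded with pad_id
target_padded = torch.tensor([
    [2, 100, 200, 3, 0, 0],
    [2, 100, 200, 300, 400, 3],
])

target_input = target_padded[:, :-1]   # decoder input: drop the last position
target_output = target_padded[:, 1:]   # labels: drop BOS

print(target_input[0].tolist())   # [2, 100, 200, 3, 0]
print(target_output[0].tolist())  # [100, 200, 3, 0, 0]

# Only non-padding labels should count toward the loss
num_label_tokens = (target_output != pad_id).sum().item()
print(num_label_tokens)  # 8
```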

4.3 Dynamic Batching

Batching by Token Count

🐍python
import random
from typing import Iterator, List, Tuple

from torch.utils.data import Sampler


class TokenBucketSampler(Sampler):
    """
    Batch sampler that groups examples by total token count.

    Instead of a fixed batch size, each batch targets a maximum number
    of padded tokens, which keeps GPU memory use roughly constant.
    It yields whole batches (lists of indices), so pass it to
    DataLoader as batch_sampler=..., not sampler=....

    Args:
        dataset: Translation dataset
        max_tokens: Maximum padded source + target tokens per batch
        shuffle: Whether to shuffle batch order each epoch
        seed: Random seed for reproducibility
    """

    def __init__(
        self,
        dataset: TranslationDataset,
        max_tokens: int = 4096,
        shuffle: bool = True,
        seed: int = 42
    ):
        self.dataset = dataset
        self.max_tokens = max_tokens
        self.shuffle = shuffle
        self.seed = seed
        self.epoch = 0

        # Pre-compute lengths
        self.lengths = self._compute_lengths()

        # Create batches once; only their order changes between epochs
        self.batches = self._create_batches()

    def _compute_lengths(self) -> List[Tuple[int, int, int]]:
        """Compute (index, src_len, tgt_len) for each example."""
        lengths = []

        for i in range(len(self.dataset)):
            # Cheap approximation without tokenizing: whitespace word counts.
            # BPE usually yields more tokens than words, so real batches
            # can somewhat exceed max_tokens.
            src_len = len(self.dataset.source_sentences[i].split())
            tgt_len = len(self.dataset.target_sentences[i].split())
            lengths.append((i, src_len, tgt_len))

        return lengths

    def _create_batches(self) -> List[List[int]]:
        """Greedily fill batches up to max_tokens, in length order."""
        # Sort by length so similar lengths share a batch (less padding)
        sorted_lengths = sorted(
            self.lengths,
            key=lambda x: x[1] + x[2]
        )

        batches = []
        current_batch = []
        current_max_src = 0
        current_max_tgt = 0

        for idx, src_len, tgt_len in sorted_lengths:
            # Padded token count if this example joins the batch
            new_max_src = max(current_max_src, src_len)
            new_max_tgt = max(current_max_tgt, tgt_len)
            batch_tokens = (len(current_batch) + 1) * (new_max_src + new_max_tgt)

            if batch_tokens > self.max_tokens and current_batch:
                # Start new batch
                batches.append(current_batch)
                current_batch = [idx]
                current_max_src = src_len
                current_max_tgt = tgt_len
            else:
                # Add to current batch
                current_batch.append(idx)
                current_max_src = new_max_src
                current_max_tgt = new_max_tgt

        # Don't forget last batch
        if current_batch:
            batches.append(current_batch)

        return batches

    def __iter__(self) -> Iterator[List[int]]:
        """Yield one list of dataset indices per batch."""
        order = list(range(len(self.batches)))
        if self.shuffle:
            # Shuffle batch order (not within batches) with a local RNG,
            # so Python's global random state is left untouched
            rng = random.Random(self.seed + self.epoch)
            rng.shuffle(order)

        for i in order:
            yield self.batches[i]

    def __len__(self) -> int:
        # Number of batches, which DataLoader reports as len(loader)
        return len(self.batches)

    def set_epoch(self, epoch: int):
        """Set epoch for shuffling."""
        self.epoch = epoch
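The greedy rule in `_create_batches` is easiest to see on a few toy lengths. `make_token_batches` below is a hypothetical standalone copy of that loop, and the six (index, src_len, tgt_len) triples are made up. With a budget of 100 padded tokens, the four short pairs share one batch and the two long pairs form another.

```python
from typing import List, Tuple


def make_token_batches(lengths: List[Tuple[int, int, int]], max_tokens: int) -> List[List[int]]:
    """Greedy token-budget batching, mirroring TokenBucketSampler._create_batches."""
    batches, current = [], []
    max_src = max_tgt = 0
    for idx, src_len, tgt_len in sorted(lengths, key=lambda x: x[1] + x[2]):
        new_src, new_tgt = max(max_src, src_len), max(max_tgt, tgt_len)
        # Padded size of the batch if this example were added
        if (len(current) + 1) * (new_src + new_tgt) > max_tokens and current:
            batches.append(current)
            current, max_src, max_tgt = [idx], src_len, tgt_len
        else:
            current.append(idx)
            max_src, max_tgt = new_src, new_tgt
    if current:
        batches.append(current)
    return batches


# (index, src_len, tgt_len) for six hypothetical sentence pairs
lengths = [(0, 5, 5), (1, 20, 22), (2, 4, 6), (3, 21, 20), (4, 5, 4), (5, 10, 10)]
print(make_token_batches(lengths, max_tokens=100))  # [[4, 0, 2, 5], [3, 1]]
```

The first batch pads to 4 × (10 + 10) = 80 tokens and the second to 2 × (21 + 22) = 86, so both stay under the budget despite very different sentence counts.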

Benefits of Token Batching

πŸ“text
1FIXED BATCH SIZE vs TOKEN BATCHING:
2────────────────────────────────────
3
4Fixed batch size = 32:
5  Batch 1: 32 sentences of 5 words each  = 160 tokens
6  Batch 2: 32 sentences of 50 words each = 1600 tokens
7  β†’ Memory usage varies by 10x!
8
9Token batching (max_tokens=1000):
10  Batch 1: 200 short sentences = ~1000 tokens
11  Batch 2: 20 long sentences   = ~1000 tokens
12  β†’ Consistent memory usage!
13
14
15BENEFITS:
16─────────
17
181. GPU memory efficiency
19   - Avoid OOM on long sequences
20   - Better utilization on short sequences
21
222. Training stability
23   - Consistent gradient scale
24   - More predictable training
25
263. Speed
27   - Less padding waste
28   - Better parallelism
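The padding savings in point 3 come from sorting by length before batching, which `TokenBucketSampler` does in `_create_batches`. The quick sketch below uses hypothetical token lengths and a batch size of 2 to compare padding waste with and without sorting. `padding_fraction` is an illustrative helper, not part of the pipeline.

```python
from typing import List


def padding_fraction(batches: List[List[int]]) -> float:
    """Fraction of padded positions when each batch is padded to its longest sentence."""
    real = sum(sum(batch) for batch in batches)
    padded = sum(len(batch) * max(batch) for batch in batches)
    return 1 - real / padded


# Hypothetical sentence lengths in tokens, mixing short and long sentences
lengths = [5, 50, 6, 48, 7, 45, 5, 52]

in_order = [lengths[i:i + 2] for i in range(0, len(lengths), 2)]
ordered = sorted(lengths)
by_length = [ordered[i:i + 2] for i in range(0, len(ordered), 2)]

print(round(padding_fraction(in_order), 3))   # 0.441
print(round(padding_fraction(by_length), 3))  # 0.027
```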

4.4 Complete DataLoader Setup

DataLoader Factory

🐍python
from pathlib import Path

from torch.utils.data import DataLoader


class DataLoaderFactory:
    """
    Factory for creating translation dataloaders.

    Creates train, validation, and test dataloaders with
    appropriate settings for each.

    Args:
        tokenizer: Trained tokenizer
        data_dir: Directory containing data files
        max_tokens: Maximum tokens per training batch
        max_length: Maximum sequence length
        num_workers: Number of data loading workers
    """

    def __init__(
        self,
        tokenizer,
        data_dir: str,
        max_tokens: int = 4096,
        max_length: int = 128,
        num_workers: int = 2
    ):
        self.tokenizer = tokenizer
        self.data_dir = Path(data_dir)
        self.max_tokens = max_tokens
        self.max_length = max_length
        self.num_workers = num_workers

        # Create collator
        self.collator = TranslationCollator(pad_id=tokenizer.pad_id)

    def create_train_loader(
        self,
        src_file: str = "train.de",
        tgt_file: str = "train.en"
    ) -> DataLoader:
        """Create training dataloader with dynamic batching."""
        dataset = TranslationDataset(
            source_path=self.data_dir / src_file,
            target_path=self.data_dir / tgt_file,
            tokenizer=self.tokenizer,
            max_length=self.max_length
        )

        sampler = TokenBucketSampler(
            dataset,
            max_tokens=self.max_tokens,
            shuffle=True
        )

        # The sampler yields whole batches, so it goes in batch_sampler;
        # batch_size, shuffle, sampler and drop_last must stay at defaults.
        return DataLoader(
            dataset,
            batch_sampler=sampler,
            collate_fn=self.collator,
            num_workers=self.num_workers,
            pin_memory=True,
        )

    def create_val_loader(
        self,
        src_file: str = "val.de",
        tgt_file: str = "val.en",
        batch_size: int = 32
    ) -> DataLoader:
        """Create validation dataloader with fixed batch size."""
        dataset = TranslationDataset(
            source_path=self.data_dir / src_file,
            target_path=self.data_dir / tgt_file,
            tokenizer=self.tokenizer,
            max_length=self.max_length
        )

        return DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=False,
            collate_fn=self.collator,
            num_workers=self.num_workers,
            pin_memory=True,
        )

    def create_test_loader(
        self,
        src_file: str = "test_2016_flickr.de",
        tgt_file: str = "test_2016_flickr.en",
        batch_size: int = 32
    ) -> DataLoader:
        """Create test dataloader (same settings as validation)."""
        return self.create_val_loader(src_file, tgt_file, batch_size)
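DataLoader's `batch_sampler` argument takes an iterable of index lists and builds exactly one batch per list. The `sampler` argument, by contrast, yields single indices that DataLoader then groups by `batch_size` (default 1). The toy example below uses a plain list as the dataset and hand-written index lists to show the `batch_sampler` behavior that dynamic batching relies on.

```python
from torch.utils.data import DataLoader

# Any object with __len__ and __getitem__ works as a map-style dataset
data = list(range(10, 20))

# Precomputed index lists, like the batches TokenBucketSampler builds
batches = [[0, 1, 2], [3, 4], [5, 6, 7, 8, 9]]

# batch_sampler yields whole batches; batch_size, shuffle, sampler and
# drop_last must be left at their defaults when it is given
loader = DataLoader(data, batch_sampler=batches, collate_fn=list)

for batch in loader:
    print(batch)
# [10, 11, 12]
# [13, 14]
# [15, 16, 17, 18, 19]

print(len(loader))  # 3: one entry per batch, not per example
```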

Usage Example

🐍python
# data_setup.py
# Assumes JointBPETokenizer (from the tokenization section) and the classes
# above are importable, and that config provides the fields used below.

def setup_dataloaders(config):
    """
    Create all dataloaders for training.
    """
    # Load tokenizer
    tokenizer = JointBPETokenizer.load(config.tokenizer_path)

    # Create factory
    factory = DataLoaderFactory(
        tokenizer=tokenizer,
        data_dir=config.data_dir,
        max_tokens=config.max_tokens,
        max_length=config.max_length,
        num_workers=config.num_workers
    )

    # Create loaders
    train_loader = factory.create_train_loader()
    val_loader = factory.create_val_loader()
    test_loader = factory.create_test_loader()

    print(f"Train batches: {len(train_loader)}")
    print(f"Val batches: {len(val_loader)}")
    print(f"Test batches: {len(test_loader)}")

    return train_loader, val_loader, test_loader


# In training script (model, criterion, optimizer, device and
# num_epochs are assumed to be defined elsewhere):
train_loader, val_loader, test_loader = setup_dataloaders(config)

for epoch in range(num_epochs):
    model.train()
    # TokenBucketSampler was passed as batch_sampler, so reach it there
    train_loader.batch_sampler.set_epoch(epoch)

    for batch in train_loader:
        source = batch['source'].to(device)
        source_mask = batch['source_mask'].to(device)
        target_input = batch['target_input'].to(device)
        target_output = batch['target_output'].to(device)

        # Forward pass (assumed output: [batch, tgt_len-1, vocab_size])
        logits = model(source, target_input, source_mask)

        # Flatten for CrossEntropyLoss: [N, vocab_size] logits, [N] labels
        loss = criterion(
            logits.reshape(-1, logits.size(-1)),
            target_output.reshape(-1)
        )

        # Backward pass (clear old gradients first)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
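The loop leaves `criterion` undefined. Because `target_output` holds `pad_id` at padded positions, a natural choice is `nn.CrossEntropyLoss(ignore_index=pad_id)` applied to flattened logits and labels, assuming the model returns logits of shape [batch, tgt_len - 1, vocab_size]. The sketch uses random logits and a toy vocabulary of 8 purely for illustration. It checks the result against a manual average over the non-padding positions.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
pad_id, vocab_size = 0, 8

# Stand-in for model output: [batch, tgt_len - 1, vocab_size]
logits = torch.randn(2, 5, vocab_size)
# Labels from the collator; row 0 ends with two padding positions
target_output = torch.tensor([[4, 5, 3, 0, 0],
                              [4, 5, 6, 7, 3]])

# ignore_index drops padded labels from both the sum and the mean's denominator
criterion = nn.CrossEntropyLoss(ignore_index=pad_id)
loss = criterion(logits.reshape(-1, vocab_size), target_output.reshape(-1))

# Same value computed by hand over non-padding positions only
per_token = nn.functional.cross_entropy(
    logits.reshape(-1, vocab_size), target_output.reshape(-1), reduction="none"
)
keep = target_output.reshape(-1) != pad_id
manual = per_token[keep].mean()

print(torch.allclose(loss, manual))  # True
```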

4.5 Data Module

Complete Data Module Class

🐍python
from dataclasses import dataclass
from pathlib import Path

from torch.utils.data import DataLoader


@dataclass
class DataConfig:
    """Data configuration."""
    data_dir: str = "data/multi30k"
    tokenizer_path: str = "data/tokenizer/tokenizer.json"
    max_tokens: int = 4096
    max_length: int = 128
    val_batch_size: int = 32
    num_workers: int = 2


class Multi30kDataModule:
    """
    Complete data module for Multi30k translation.

    Manages:
    - Tokenizer loading
    - Dataset creation
    - DataLoader configuration
    - Train/val/test splits

    Args:
        config: DataConfig instance

    Example:
        >>> dm = Multi30kDataModule(config)
        >>> dm.setup()
        >>> train_loader = dm.train_dataloader()
    """

    def __init__(self, config: DataConfig):
        self.config = config
        self.tokenizer = None
        self.train_dataset = None
        self.val_dataset = None
        self.test_dataset = None

    def setup(self):
        """Load tokenizer and create datasets."""
        # Load tokenizer (JointBPETokenizer from the tokenization section)
        self.tokenizer = JointBPETokenizer.load(self.config.tokenizer_path)

        print(f"Tokenizer vocab size: {len(self.tokenizer.token_to_id)}")

        # Create datasets
        data_dir = Path(self.config.data_dir)

        self.train_dataset = TranslationDataset(
            source_path=data_dir / "train.de",
            target_path=data_dir / "train.en",
            tokenizer=self.tokenizer,
            max_length=self.config.max_length
        )

        self.val_dataset = TranslationDataset(
            source_path=data_dir / "val.de",
            target_path=data_dir / "val.en",
            tokenizer=self.tokenizer,
            max_length=self.config.max_length
        )

        self.test_dataset = TranslationDataset(
            source_path=data_dir / "test_2016_flickr.de",
            target_path=data_dir / "test_2016_flickr.en",
            tokenizer=self.tokenizer,
            max_length=self.config.max_length
        )

        print(f"Train: {len(self.train_dataset)} pairs")
        print(f"Val: {len(self.val_dataset)} pairs")
        print(f"Test: {len(self.test_dataset)} pairs")

    def train_dataloader(self) -> DataLoader:
        """Get training dataloader (call setup() first)."""
        collator = TranslationCollator(pad_id=self.tokenizer.pad_id)

        sampler = TokenBucketSampler(
            self.train_dataset,
            max_tokens=self.config.max_tokens,
            shuffle=True
        )

        # TokenBucketSampler yields lists of indices: pass it as batch_sampler
        return DataLoader(
            self.train_dataset,
            batch_sampler=sampler,
            collate_fn=collator,
            num_workers=self.config.num_workers,
            pin_memory=True
        )

    def val_dataloader(self) -> DataLoader:
        """Get validation dataloader."""
        collator = TranslationCollator(pad_id=self.tokenizer.pad_id)

        return DataLoader(
            self.val_dataset,
            batch_size=self.config.val_batch_size,
            shuffle=False,
            collate_fn=collator,
            num_workers=self.config.num_workers,
            pin_memory=True
        )

    def test_dataloader(self) -> DataLoader:
        """Get test dataloader."""
        collator = TranslationCollator(pad_id=self.tokenizer.pad_id)

        return DataLoader(
            self.test_dataset,
            batch_size=self.config.val_batch_size,
            shuffle=False,
            collate_fn=collator,
            num_workers=self.config.num_workers,
            pin_memory=True
        )

    @property
    def vocab_size(self) -> int:
        """Get vocabulary size."""
        return len(self.tokenizer.token_to_id)

    @property
    def pad_id(self) -> int:
        """Get padding token ID."""
        return self.tokenizer.pad_id
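All loaders above set `pin_memory=True`. Pinned host memory lets `Tensor.to(device, non_blocking=True)` overlap the host-to-GPU copy with other work. A small helper such as the hypothetical `move_batch` below moves every tensor in a batch dictionary in one call and passes non-tensor values through unchanged, such as raw text kept by a custom collator.

```python
from typing import Any, Dict

import torch


def move_batch(batch: Dict[str, Any], device: torch.device) -> Dict[str, Any]:
    """Move every tensor in a batch dict to device; leave other values as-is."""
    return {
        key: value.to(device, non_blocking=True) if isinstance(value, torch.Tensor) else value
        for key, value in batch.items()
    }


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
batch = {
    "source": torch.tensor([[10, 20, 30]]),
    "source_mask": torch.tensor([[True, True, True]]),
    "source_text": ["ein Hund"],  # non-tensor entries pass through untouched
}
moved = move_batch(batch, device)
print(moved["source"].device.type, moved["source_text"])
```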

Batch Contents

πŸ“text
1BATCH CONTENTS:
2───────────────
3
4source:        [batch_size, max_src_len]
5               Source token IDs (no BOS/EOS)
6               Padded with pad_id
7
8source_mask:   [batch_size, max_src_len]
9               True where not padding
10
11target_input:  [batch_size, max_tgt_len - 1]
12               Target for decoder input
13               Starts with BOS, excludes last token
14
15target_output: [batch_size, max_tgt_len - 1]
16               Target labels for loss
17               Excludes BOS, ends with EOS
18
19target_mask:   [batch_size, max_tgt_len - 1]
20               True where not padding
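How the masks are consumed depends on the model, which is not part of this section. The sketch below assumes an attention layout of [batch, heads, query_len, key_len] and masks where True means "may attend". It broadcasts `source_mask` over heads and queries for cross-attention, and combines `target_mask` with a causal mask for decoder self-attention.

```python
import torch

# source_mask [B, S]: True where not padding (row 0 has two pad positions)
source_mask = torch.tensor([[True, True, True, False, False],
                            [True, True, True, True, True]])
# target_mask [B, T]: True where target_input is not padding
target_mask = torch.tensor([[True, True, False],
                            [True, True, True]])
B, S = source_mask.shape
T = target_mask.shape[1]
heads = 2

# Cross-attention: broadcast [B, S] -> [B, 1, 1, S] over heads and query positions
scores = torch.zeros(B, heads, T, S)  # stand-in attention logits
scores = scores.masked_fill(~source_mask[:, None, None, :], float("-inf"))
weights = torch.softmax(scores, dim=-1)
print(weights[0, 0, 0])  # tensor([0.3333, 0.3333, 0.3333, 0.0000, 0.0000])

# Decoder self-attention: padding mask AND causal (no peeking ahead) mask
causal = torch.tril(torch.ones(T, T, dtype=torch.bool))         # [T, T]
self_mask = target_mask[:, None, None, :] & causal[None, None]  # [B, 1, T, T]
print(self_mask.shape)  # torch.Size([2, 1, 3, 3])
```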

4.6 Testing the Complete Pipeline

End-to-End Test

🐍python
def test_data_pipeline():
    # Setup
    config = DataConfig(
        data_dir="data/multi30k",
        tokenizer_path="data/tokenizer/tokenizer.json"
    )

    dm = Multi30kDataModule(config)
    dm.setup()

    # Test train loader
    train_loader = dm.train_dataloader()

    print("\nTesting train loader:")
    batch = next(iter(train_loader))

    print(f"  Source shape: {batch['source'].shape}")
    print(f"  Target input shape: {batch['target_input'].shape}")
    print(f"  Target output shape: {batch['target_output'].shape}")

    # Verify shapes match
    assert batch['source'].shape[0] == batch['target_input'].shape[0]
    assert batch['target_input'].shape == batch['target_output'].shape

    # Verify masking
    src_mask = batch['source_mask']
    tgt_mask = batch['target_mask']

    assert src_mask.sum() > 0, "Source mask is all False!"
    assert tgt_mask.sum() > 0, "Target mask is all False!"

    # Verify special tokens
    pad_id = dm.pad_id
    bos_id = dm.tokenizer.bos_id
    eos_id = dm.tokenizer.eos_id

    # Source mask should mark exactly the non-padding positions
    assert (src_mask == (batch['source'] != pad_id)).all()

    # Source has no special tokens; target input should start with BOS
    assert not (batch['source'] == bos_id).any(), "Source should not contain BOS"
    first_tokens = batch['target_input'][:, 0]
    assert (first_tokens == bos_id).all(), "Target input should start with BOS"

    print("\n✓ All tests passed!")


if __name__ == "__main__":
    test_data_pipeline()
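The test checks that every decoder input starts with BOS but never checks that labels end with EOS. Assuming right padding and no `pad_id` inside real sequences, one way to express that check is to index each row's last non-padding label. `last_non_pad` is a hypothetical helper, shown here on mock labels (EOS = 3, pad = 0).

```python
import torch


def last_non_pad(labels: torch.Tensor, pad_id: int) -> torch.Tensor:
    """Return the last non-padding token in each row of a right-padded [B, T] tensor."""
    lengths = (labels != pad_id).sum(dim=1)  # assumes padding only at the end of each row
    rows = torch.arange(labels.size(0))
    return labels[rows, lengths - 1]  # assumes every row has at least one real token


target_output = torch.tensor([[100, 200, 3, 0, 0],
                              [100, 200, 300, 400, 3]])
eos_id, pad_id = 3, 0

print(last_non_pad(target_output, pad_id).tolist())  # [3, 3]
assert (last_non_pad(target_output, pad_id) == eos_id).all(), "Labels should end with EOS"
```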

Summary

Data Pipeline Components

Component            Purpose
TranslationDataset   Load and tokenize parallel text
TranslationCollator  Batch and pad sequences
TokenBucketSampler   Dynamic batching by tokens
Multi30kDataModule   Complete data management

Batch Format

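Key            Shape     Description
source         [B, S]    Source token IDs
source_mask    [B, S]    True where not padding
target_input   [B, T-1]  Decoder input (starts with BOS)
target_output  [B, T-1]  Labels (end with EOS; padding ignored in the loss)
target_mask    [B, T-1]  True where not padding

B = batch size, S = longest source in the batch, T = longest target in the batch.

Example configuration: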
🐍python
config = DataConfig(
    data_dir="data/multi30k",
    tokenizer_path="data/tokenizer/tokenizer.json",
    max_tokens=4096,    # Padded src+tgt budget; sentences per batch vary with length
    max_length=128,     # Ample for Multi30k's short sentences
    val_batch_size=32,  # Fixed batch size for validation
    num_workers=2       # Parallel loading workers
)
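For the analysis exercise on `max_tokens`, variants of one base configuration can be derived with the standard library's `dataclasses.replace`, which copies a dataclass instance with selected fields changed. The sketch repeats `DataConfig` so that it runs on its own. The budget values are arbitrary.

```python
from dataclasses import asdict, dataclass, replace


@dataclass
class DataConfig:
    """Data configuration (same fields as the section's DataConfig)."""
    data_dir: str = "data/multi30k"
    tokenizer_path: str = "data/tokenizer/tokenizer.json"
    max_tokens: int = 4096
    max_length: int = 128
    val_batch_size: int = 32
    num_workers: int = 2


base = DataConfig()

# One config per token budget; every other field is copied from base
sweep = [replace(base, max_tokens=budget) for budget in (1024, 2048, 4096, 8192)]

for cfg in sweep:
    print(asdict(cfg)["max_tokens"], cfg.max_length)
```

Because `replace` returns new instances, `base` stays unchanged and each run can log its full settings via `asdict`.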

Chapter Summary

In this chapter, we set up everything needed to work with Multi30k:

  1. Dataset Overview: Understood Multi30k structure and statistics
  2. Preprocessing: Created German-English preprocessing pipeline
  3. Tokenization: Built joint BPE tokenizer with 8K vocabulary
  4. Data Loading: Created efficient PyTorch dataloaders

We're now ready to train our translation model!


Exercises

Implementation

  1. Add caching to avoid re-tokenizing on each epoch.
  2. Implement multi-GPU data distribution with DistributedSampler.
  3. Create a data visualization tool showing batch statistics.

Analysis

  1. Compare training speed with different max_tokens settings.
  2. Profile the data loading to identify bottlenecks.

Next Chapter Preview

In the next chapter, Train the Translation Model, we'll put together everything we've built to train a German-English translation system.
