Chapter 10
Section 50 of 75

Data Loading and Batching for Translation

Training Pipeline

Introduction

An efficient input pipeline keeps the GPU busy instead of waiting on data, and careful batching avoids spending compute on padding. This section shows how to build PyTorch datasets and dataloaders for translation, covering padding, dynamic (token-based) batching, and variable-length sequences.


Translation Dataset Structure

Parallel Corpus Format

```text
Translation data typically comes as parallel sentences:

Source (German):            Target (English):
─────────────────           ─────────────────
"Der Hund läuft"            "The dog runs"
"Die Katze schläft"         "The cat sleeps"
"Ich liebe Programmieren"   "I love programming"

Each line i in the source file corresponds to line i in the target file.
```
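Because the alignment is purely positional, any filtering must be applied to pairs, never to one file on its own. A minimal sketch, using a hypothetical helper load_parallel (not part of any library), that reads both files, refuses a misaligned corpus, and drops pairs where either side is empty:

```python
from pathlib import Path
from typing import List, Tuple


def load_parallel(source_path: str, target_path: str) -> List[Tuple[str, str]]:
    """Read two line-aligned files into (source, target) pairs.

    Hypothetical helper: line i of source_path must translate line i of
    target_path. Pairs with an empty side are skipped as a unit, so the
    remaining pairs stay aligned.
    """
    src_lines = Path(source_path).read_text(encoding="utf-8").splitlines()
    tgt_lines = Path(target_path).read_text(encoding="utf-8").splitlines()
    if len(src_lines) != len(tgt_lines):
        raise ValueError(
            f"Misaligned corpus: {len(src_lines)} source vs "
            f"{len(tgt_lines)} target lines"
        )

    pairs = []
    for src, tgt in zip(src_lines, tgt_lines):
        src, tgt = src.strip(), tgt.strip()
        if src and tgt:  # filter the pair, never one side alone
            pairs.append((src, tgt))
    return pairs
```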

Multi30k Dataset

```text
Multi30k: Standard dataset for German-English translation research
- ~30,000 sentence pairs
- Short sentences (average ~12 words)
- Image captions in multiple languages
- Commonly used for benchmarking

Files:
  train.de, train.en
  val.de, val.en
  test.de, test.en
```
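Figures such as the average sentence length are easy to check on your own copy of the data, and they help when choosing max_source_len and max_target_len. A rough sketch with a hypothetical corpus_stats helper; it counts whitespace-separated words, so a subword tokenizer will report more tokens per sentence than this:

```python
from statistics import mean
from typing import Dict, List


def corpus_stats(lines: List[str]) -> Dict[str, float]:
    """Summarize sentence lengths in whitespace-separated words (rough proxy)."""
    lengths = [len(line.split()) for line in lines if line.strip()]
    if not lengths:
        raise ValueError("no non-empty lines")
    return {
        "sentences": len(lengths),
        "avg_words": mean(lengths),
        "max_words": max(lengths),
    }
```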

Translation Dataset Class

Basic Implementation

```python
import torch
from torch.utils.data import Dataset, DataLoader
from typing import List, Tuple, Optional, Dict
import os


class TranslationDataset(Dataset):
    """
    Dataset for machine translation.

    Loads parallel source-target sentence pairs and tokenizes them.

    Args:
        source_file: Path to source language file
        target_file: Path to target language file
        tokenizer: Tokenizer with encode_source and encode_target methods
        max_source_len: Maximum source sequence length
        max_target_len: Maximum target sequence length

    Example:
        >>> dataset = TranslationDataset(
        ...     'data/train.de', 'data/train.en',
        ...     tokenizer, max_source_len=100, max_target_len=100
        ... )
        >>> src, tgt = dataset[0]
    """

    def __init__(
        self,
        source_file: str,
        target_file: str,
        tokenizer,
        max_source_len: int = 128,
        max_target_len: int = 128
    ):
        self.tokenizer = tokenizer
        self.max_source_len = max_source_len
        self.max_target_len = max_target_len

        # Load sentence pairs
        self.source_sentences = self._load_file(source_file)
        self.target_sentences = self._load_file(target_file)

        assert len(self.source_sentences) == len(self.target_sentences), \
            "Source and target must have same number of sentences"

        print(f"Loaded {len(self)} sentence pairs")

    def _load_file(self, path: str) -> List[str]:
        """Load sentences from file."""
        with open(path, 'r', encoding='utf-8') as f:
            sentences = [line.strip() for line in f]
        return sentences

    def __len__(self) -> int:
        return len(self.source_sentences)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Get a tokenized source-target pair.

        Returns:
            source_ids: Tokenized source [src_len]
            target_ids: Tokenized target with BOS/EOS [tgt_len]
        """
        source = self.source_sentences[idx]
        target = self.target_sentences[idx]

        # Tokenize source (no special tokens needed for encoder)
        source_ids = self.tokenizer.encode_source(
            source,
            max_length=self.max_source_len
        )

        # Tokenize target with BOS and EOS
        target_ids = self.tokenizer.encode_target(
            target,
            add_bos=True,
            add_eos=True,
            max_length=self.max_target_len
        )

        return (
            torch.tensor(source_ids, dtype=torch.long),
            torch.tensor(target_ids, dtype=torch.long)
        )
```
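The dataset assumes a tokenizer exposing encode_source, encode_target and pad_token_id but does not define one. The class below is a hypothetical, word-level stand-in that satisfies that interface with a single vocabulary shared by both languages; it is only meant for smoke-testing the pipeline, and a real system would use a subword tokenizer (for example BPE or SentencePiece) instead.

```python
from typing import Dict, Iterable, List, Optional


class WhitespaceTokenizer:
    """Hypothetical word-level tokenizer with the interface TranslationDataset expects.

    One vocabulary is shared by source and target; unknown words map to <unk>.
    """

    SPECIALS = ("<pad>", "<unk>", "<bos>", "<eos>")

    def __init__(self, sentences: Iterable[str]):
        # Special tokens get ids 0..3, then words in first-seen order
        self.vocab: Dict[str, int] = {tok: i for i, tok in enumerate(self.SPECIALS)}
        for sentence in sentences:
            for word in sentence.lower().split():
                self.vocab.setdefault(word, len(self.vocab))
        self.pad_token_id = self.vocab["<pad>"]
        self.unk_token_id = self.vocab["<unk>"]
        self.bos_token_id = self.vocab["<bos>"]
        self.eos_token_id = self.vocab["<eos>"]

    def _word_ids(self, text: str) -> List[int]:
        return [self.vocab.get(w, self.unk_token_id) for w in text.lower().split()]

    def encode_source(self, text: str, max_length: Optional[int] = None) -> List[int]:
        ids = self._word_ids(text)
        return ids if max_length is None else ids[:max_length]

    def encode_target(
        self,
        text: str,
        add_bos: bool = True,
        add_eos: bool = True,
        max_length: Optional[int] = None,
    ) -> List[int]:
        ids = self._word_ids(text)
        if max_length is not None:
            # Truncate the words first so BOS/EOS always survive
            ids = ids[: max(max_length - int(add_bos) - int(add_eos), 0)]
        if add_bos:
            ids = [self.bos_token_id] + ids
        if add_eos:
            ids = ids + [self.eos_token_id]
        return ids
```

Reserving room for the special tokens before truncating is a deliberate choice: a target that loses its EOS would teach the decoder never to stop.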

Collation Function

Padding and Batching

```python
from typing import Dict, List, Tuple

import torch


class TranslationCollator:
    """
    Collate function for translation batches.

    Pads sequences to the same length within each batch.

    Args:
        pad_id: Padding token ID
        batch_first: Whether to return [batch, seq] (True) or [seq, batch] (False)
    """

    def __init__(self, pad_id: int = 0, batch_first: bool = True):
        self.pad_id = pad_id
        self.batch_first = batch_first

    def __call__(
        self,
        batch: List[Tuple[torch.Tensor, torch.Tensor]]
    ) -> Dict[str, torch.Tensor]:
        """
        Collate a batch of source-target pairs.

        Args:
            batch: List of (source, target) tensor pairs

        Returns:
            Dictionary with:
                - source_ids: [batch, max_src_len] ([max_src_len, batch] if not batch_first)
                - target_ids: [batch, max_tgt_len] ([max_tgt_len, batch] if not batch_first)
                - source_lengths: [batch]
                - target_lengths: [batch]
        """
        sources, targets = zip(*batch)

        # Unpadded lengths (useful for building padding masks)
        source_lengths = torch.tensor([len(s) for s in sources], dtype=torch.long)
        target_lengths = torch.tensor([len(t) for t in targets], dtype=torch.long)

        # Pad sequences to [batch, max_len]
        source_ids = self._pad_sequence(sources)
        target_ids = self._pad_sequence(targets)

        if not self.batch_first:
            # Honour batch_first=False by transposing to [seq, batch]
            source_ids = source_ids.t().contiguous()
            target_ids = target_ids.t().contiguous()

        return {
            'source_ids': source_ids,
            'target_ids': target_ids,
            'source_lengths': source_lengths,
            'target_lengths': target_lengths
        }

    def _pad_sequence(
        self,
        sequences: Tuple[torch.Tensor, ...]
    ) -> torch.Tensor:
        """Pad sequences to the same length, returning [batch, max_len]."""
        max_len = max(len(s) for s in sequences)
        batch_size = len(sequences)

        padded = torch.full(
            (batch_size, max_len),
            self.pad_id,
            dtype=sequences[0].dtype
        )

        for i, seq in enumerate(sequences):
            padded[i, :len(seq)] = seq

        return padded
```
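The collator returns unpadded lengths rather than masks. A minimal sketch of turning them into the boolean padding mask that attention layers consume: in PyTorch's nn.MultiheadAttention and nn.Transformer, a True entry in a key padding mask marks a position to ignore. It also shows that torch.nn.utils.rnn.pad_sequence produces the same padding as the hand-written _pad_sequence. The helper name lengths_to_padding_mask is illustrative, not a library function.

```python
import torch
from torch.nn.utils.rnn import pad_sequence


def lengths_to_padding_mask(lengths: torch.Tensor, max_len: int) -> torch.Tensor:
    """Return a [batch, max_len] bool mask that is True at padding positions."""
    positions = torch.arange(max_len, device=lengths.device)  # [max_len]
    return positions.unsqueeze(0) >= lengths.unsqueeze(1)      # broadcast to [batch, max_len]


# Two sequences of different length, padded with id 0
seqs = [torch.tensor([5, 6, 7]), torch.tensor([8, 9])]
padded = pad_sequence(seqs, batch_first=True, padding_value=0)
lengths = torch.tensor([len(s) for s in seqs])
mask = lengths_to_padding_mask(lengths, padded.size(1))
# padded -> [[5, 6, 7], [8, 9, 0]]; mask -> [[False, False, False], [False, False, True]]
```

If the collator is used with batch_first=False, keep this mask as [batch, seq]: that is the shape key_padding_mask expects regardless of the input layout.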

Dynamic Batching by Tokens

Efficient Batching

```python
import random
from typing import Iterator, List, Tuple


class TokenBucketBatcher:
    """
    Create batches with a similar total token count.

    Instead of a fixed batch size, groups length-sorted sequences so that
    each batch holds at most about max_tokens padded tokens. Per-batch
    memory therefore stays roughly constant whether a batch contains many
    short sentences or a few long ones.

    Yields lists of dataset indices, so it can be passed to a DataLoader
    as batch_sampler together with a collate_fn such as TranslationCollator.

    Args:
        dataset: Translation dataset
        max_tokens: Maximum padded tokens (source + target) per batch
        drop_last: Whether to drop the final, possibly under-filled batch
    """

    def __init__(
        self,
        dataset: TranslationDataset,
        max_tokens: int = 4096,
        drop_last: bool = False
    ):
        self.dataset = dataset
        self.max_tokens = max_tokens
        self.drop_last = drop_last

        # Pre-compute lengths for bucketing
        self.lengths = self._compute_lengths()
        self.batches = self._create_batches()

    def _compute_lengths(self) -> List[Tuple[int, int, int]]:
        """Compute (idx, src_len, tgt_len) for all samples."""
        # Tokenizes every sample once; cache these lengths for large corpora
        lengths = []
        for idx in range(len(self.dataset)):
            src, tgt = self.dataset[idx]
            lengths.append((idx, len(src), len(tgt)))
        return lengths

    def _create_batches(self) -> List[List[int]]:
        """Create batches whose padded token count stays under max_tokens."""
        # Sort by source length, then target length, so neighbours pad little
        sorted_lengths = sorted(self.lengths, key=lambda x: (x[1], x[2]))

        batches: List[List[int]] = []
        current_batch: List[int] = []
        max_src = max_tgt = 0  # longest source/target in current_batch

        for idx, src_len, tgt_len in sorted_lengths:
            new_max_src = max(max_src, src_len)
            new_max_tgt = max(max_tgt, tgt_len)
            # Every row is padded to the batch maximum, so the padded size
            # is rows * (longest source + longest target)
            estimated_tokens = (len(current_batch) + 1) * (new_max_src + new_max_tgt)

            if current_batch and estimated_tokens > self.max_tokens:
                # Close the current batch; this sample starts a new one.
                # (A single sample longer than max_tokens still gets its own batch.)
                batches.append(current_batch)
                current_batch = []
                new_max_src, new_max_tgt = src_len, tgt_len

            current_batch.append(idx)
            max_src, max_tgt = new_max_src, new_max_tgt

        # The last batch may be under-filled; keep it unless drop_last is set
        if current_batch and not self.drop_last:
            batches.append(current_batch)

        return batches

    def __len__(self) -> int:
        return len(self.batches)

    def __iter__(self) -> Iterator[List[int]]:
        """Yield batches of indices in a new random order each epoch."""
        order = list(range(len(self.batches)))
        random.shuffle(order)

        for batch_idx in order:
            yield self.batches[batch_idx]
```
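DataLoader's batch_sampler argument accepts any iterable that yields lists of indices; the loader fetches those samples and passes them to collate_fn. A small self-contained sketch of that mechanism, in which a hard-coded list of index lists stands in for a token-bucketing batcher and a simplified collate function pads with zeros (toy data, not Multi30k):

```python
import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader

# Toy map-style dataset of (source_ids, target_ids) pairs of varying length
data = [
    (torch.tensor([1, 2]), torch.tensor([1, 2, 3])),
    (torch.tensor([4, 5, 6, 7]), torch.tensor([4, 5])),
    (torch.tensor([8]), torch.tensor([8, 9])),
]

# Stand-in for a token-bucketing batcher: an iterable of index lists
index_batches = [[2, 0], [1]]


def collate(batch):
    """Simplified collate: pad sources and targets with 0, batch-first."""
    sources, targets = zip(*batch)
    return (
        pad_sequence(list(sources), batch_first=True),
        pad_sequence(list(targets), batch_first=True),
    )


loader = DataLoader(data, batch_sampler=index_batches, collate_fn=collate)
shapes = [(tuple(src.shape), tuple(tgt.shape)) for src, tgt in loader]
```

When batch_sampler is set, batch_size, shuffle, sampler and drop_last must stay at their defaults; DataLoader raises a ValueError otherwise, so shuffling and dropping are the batcher's job.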

Visualization

```text
Fixed Batch Size (N=4), sequences in random order:
──────────────────────────────────────────────────
Batch 1: [Short, Short, Short, Long]  → mostly padding
Batch 2: [Long, Long, Long, Long]     → may exceed GPU memory

┌────────────────────────────────────────┐
│████████░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░│ Seq 1
│██████░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░│ Seq 2
│████░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░│ Seq 3
│████████████████████████████████████████│ Seq 4 (long)
└────────────────────────────────────────┘
Wasted compute: ░░░ (lots of padding!)


Dynamic Batching (max_tokens=100):
──────────────────────────────────
Batch 1: [Short × 10]              → ~100 tokens total
Batch 2: [Medium × 5]              → ~100 tokens total
Batch 3: [Long × 2]                → ~100 tokens total

┌─────────────────────┐
│█████░░│ Seq 1       │
│████░░░│ Seq 2       │
│██████░│ Seq 3       │
│████░░░│ Seq 4       │  ~100 tokens
│███░░░░│ Seq 5       │
└─────────────────────┘

┌────────────────────────────────┐
│██████████████░░│ Seq 6         │
│████████████░░░░│ Seq 7         │  ~100 tokens
│███████████████░│ Seq 8         │
└────────────────────────────────┘

Much less wasted compute!
```
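The saving in the diagram can be quantified as padding efficiency: real tokens divided by padded slots. A small sketch with made-up lengths, comparing batches of four taken in arrival order with batches of four taken after sorting by length; sorting alone, before any token budget is applied, already removes most of the waste:

```python
from typing import List


def padding_efficiency(batches: List[List[int]]) -> float:
    """Fraction of padded slots holding real tokens.

    Each batch is a list of sequence lengths; every row is padded to the
    longest sequence in its batch.
    """
    real = sum(sum(batch) for batch in batches)
    padded = sum(len(batch) * max(batch) for batch in batches)
    return real / padded


lengths = [8, 6, 4, 40, 7, 38, 36, 39]  # illustrative lengths, not real data

# Fixed batches of 4 in arrival order: short and long sequences mixed
fixed = [lengths[0:4], lengths[4:8]]
# Same batch size after sorting by length: similar lengths grouped together
ordered = sorted(lengths)
bucketed = [ordered[0:4], ordered[4:8]]
```

For these lengths, about 56% of the slots hold real tokens in arrival order (178 of 316) versus about 93% after sorting (178 of 192).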

Complete DataLoader Setup

Production Implementation

```python
def create_translation_dataloaders(
    train_source: str,
    train_target: str,
    val_source: str,
    val_target: str,
    tokenizer,
    batch_size: int = 32,
    max_source_len: int = 128,
    max_target_len: int = 128,
    num_workers: int = 4,
    pin_memory: bool = True
) -> Tuple[DataLoader, DataLoader]:
    """
    Create training and validation dataloaders.

    Args:
        train_source: Path to training source file
        train_target: Path to training target file
        val_source: Path to validation source file
        val_target: Path to validation target file
        tokenizer: Tokenizer instance
        batch_size: Batch size
        max_source_len: Maximum source length
        max_target_len: Maximum target length
        num_workers: Number of data loading workers
        pin_memory: Whether to pin memory for GPU transfer

    Returns:
        train_loader, val_loader
    """
    # Create datasets
    train_dataset = TranslationDataset(
        train_source, train_target,
        tokenizer,
        max_source_len, max_target_len
    )

    val_dataset = TranslationDataset(
        val_source, val_target,
        tokenizer,
        max_source_len, max_target_len
    )

    # Create collator
    collator = TranslationCollator(
        pad_id=tokenizer.pad_token_id
    )

    # Create loaders
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        collate_fn=collator,
        num_workers=num_workers,
        pin_memory=pin_memory,
        drop_last=True  # For consistent batch sizes during training
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        collate_fn=collator,
        num_workers=num_workers,
        pin_memory=pin_memory
    )

    return train_loader, val_loader


class TranslationDataModule:
    """
    Complete data module for translation training.

    Encapsulates all data loading logic.
    """

    def __init__(
        self,
        data_dir: str,
        tokenizer,
        batch_size: int = 32,
        max_source_len: int = 128,
        max_target_len: int = 128,
        num_workers: int = 4
    ):
        self.data_dir = data_dir
        self.tokenizer = tokenizer
        self.batch_size = batch_size
        self.max_source_len = max_source_len
        self.max_target_len = max_target_len
        self.num_workers = num_workers

        self.train_loader = None
        self.val_loader = None
        self.test_loader = None

    def setup(self):
        """Set up datasets and dataloaders."""
        collator = TranslationCollator(
            pad_id=self.tokenizer.pad_token_id
        )

        # Training
        train_dataset = TranslationDataset(
            os.path.join(self.data_dir, 'train.de'),
            os.path.join(self.data_dir, 'train.en'),
            self.tokenizer,
            self.max_source_len,
            self.max_target_len
        )

        self.train_loader = DataLoader(
            train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            collate_fn=collator,
            num_workers=self.num_workers,
            pin_memory=True
        )

        # Validation
        val_dataset = TranslationDataset(
            os.path.join(self.data_dir, 'val.de'),
            os.path.join(self.data_dir, 'val.en'),
            self.tokenizer,
            self.max_source_len,
            self.max_target_len
        )

        self.val_loader = DataLoader(
            val_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            collate_fn=collator,
            num_workers=self.num_workers,
            pin_memory=True
        )

        print(f"Train batches: {len(self.train_loader)}")
        print(f"Val batches: {len(self.val_loader)}")

    def train_dataloader(self) -> DataLoader:
        return self.train_loader

    def val_dataloader(self) -> DataLoader:
        return self.val_loader
```
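In the training loop, each collated batch is typically moved to the GPU and the BOS/EOS-wrapped target is split for teacher forcing: the decoder reads the target without its last position and is trained to predict the target shifted left by one. A sketch under those assumptions (prepare_batch is a hypothetical helper, and the batch dict follows the collator's keys); non_blocking=True only lets the copy overlap with compute when the DataLoader uses pin_memory=True:

```python
from typing import Dict, Tuple

import torch


def prepare_batch(
    batch: Dict[str, torch.Tensor], device: torch.device
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Move a collated batch to device and split the target for teacher forcing.

    Assumes batch-first target_ids rows of the form [BOS, y1, ..., yn, EOS, PAD...].
    """
    batch = {k: v.to(device, non_blocking=True) for k, v in batch.items()}
    decoder_input = batch["target_ids"][:, :-1]  # drop the last position
    labels = batch["target_ids"][:, 1:]          # shift left by one
    return batch["source_ids"], decoder_input, labels
```

The labels still contain padding, so the loss should skip it, for example with nn.CrossEntropyLoss(ignore_index=pad_id).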

Summary

Key Components

| Component | Purpose |
|---|---|
| TranslationDataset | Load and tokenize sentence pairs |
| TranslationCollator | Pad batches to the same length |
| TokenBucketBatcher | Dynamic batching by token count |
| TranslationDataModule | Encapsulate all data-loading logic |

Best Practices

  • Dynamic batching for memory efficiency
  • Sort by length to minimize padding
  • Pre-tokenize if possible for speed
  • Pin memory for faster GPU transfer
  • Multiple workers for parallel loading

Exercises

Implementation

  • Add support for data augmentation (e.g., back-translation).
  • Implement curriculum learning (start with short sentences).
  • Add caching of tokenized data to disk.

Analysis

  • Compare training speed with different batch sizes.
  • Measure GPU utilization with different num_workers.

In the next section, we'll implement Label Smoothing Loss, a regularization technique that prevents the model from becoming overconfident and improves generalization.
