Chapter 15
20 min read
Section 69 of 75

Fine-tuning mBART for Translation

Pretrained Models

Introduction

This section provides a complete guide to fine-tuning mBART-50 on the Multi30k dataset for German-English translation. We'll cover data preparation, training configuration, and optimization techniques.


Loading mBART Model and Tokenizer

Model Setup

🐍python
1import torch
2import torch.nn as nn
3from torch.utils.data import Dataset, DataLoader
4from transformers import (
5    MBartForConditionalGeneration,
6    MBart50TokenizerFast,
7    get_scheduler
8)
9from typing import Dict, List, Optional, Tuple
10import os
11
12
class MBartTranslator:
    """
    mBART-50 model wrapper for German-English translation.

    Loads a many-to-many mBART-50 checkpoint and its fast tokenizer, tags the
    input with the source-language code, and forces generation to start with
    the target-language code token.
    """

    def __init__(
        self,
        model_name: str = "facebook/mbart-large-50-many-to-many-mmt",
        src_lang: str = "de_DE",
        tgt_lang: str = "en_XX",
        device: Optional[str] = None
    ):
        """
        Initialize mBART translator.

        Args:
            model_name: Hugging Face model id or local path of an mBART-50
                checkpoint.
            src_lang: mBART-50 language code of the input text.
            tgt_lang: mBART-50 language code to translate into.
            device: torch device string; defaults to "cuda" when available,
                otherwise "cpu".

        Raises:
            KeyError: if ``src_lang`` or ``tgt_lang`` is not a token in the
                tokenizer's vocabulary.
        """
        self.model_name = model_name
        self.src_lang = src_lang
        self.tgt_lang = tgt_lang
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")

        print(f"Loading model: {model_name}")
        print(f"Translation: {src_lang} -> {tgt_lang}")
        print(f"Device: {self.device}")

        # Load tokenizer
        self.tokenizer = MBart50TokenizerFast.from_pretrained(model_name)

        # Check both language codes before the (slow) model load so that a
        # typo fails immediately with a clear message.
        self._language_token_id(src_lang)
        self.tokenizer.src_lang = src_lang
        # Target language token id, forced as the first generated token.
        self.forced_bos_token_id = self._language_token_id(tgt_lang)

        # Load model
        self.model = MBartForConditionalGeneration.from_pretrained(model_name)
        self.model.to(self.device)

        print(f"Model parameters: {self.model.num_parameters():,}")
        print(f"Vocabulary size: {len(self.tokenizer)}")

    def _language_token_id(self, lang_code: str) -> int:
        """
        Return the vocabulary id of an mBART-50 language-code token.

        Raises:
            KeyError: if the tokenizer maps ``lang_code`` to nothing or to the
                unknown token.
        """
        token_id = self.tokenizer.convert_tokens_to_ids(lang_code)
        if token_id is None or token_id == self.tokenizer.unk_token_id:
            raise KeyError(f"Unknown mBART-50 language code: {lang_code!r}")
        return token_id

    def translate(
        self,
        text: str,
        max_length: int = 128,
        num_beams: int = 5,
        **kwargs
    ) -> str:
        """
        Translate a single text from ``src_lang`` to ``tgt_lang``.

        Args:
            text: source sentence.
            max_length: truncation length for the input and maximum length of
                the generated sequence.
            num_beams: beam width for beam search.
            **kwargs: extra keyword arguments forwarded to ``model.generate``.

        Returns:
            The decoded translation with special tokens removed.
        """
        encoded = self.tokenizer(
            text,
            return_tensors="pt",
            max_length=max_length,
            truncation=True
        )
        encoded = {k: v.to(self.device) for k, v in encoded.items()}

        # Inference only: disable dropout and gradient tracking.
        self.model.eval()
        with torch.no_grad():
            generated = self.model.generate(
                **encoded,
                forced_bos_token_id=self.forced_bos_token_id,
                max_length=max_length,
                num_beams=num_beams,
                **kwargs
            )

        # A single input was encoded, so only the first sequence is returned.
        return self.tokenizer.decode(generated[0], skip_special_tokens=True)

Preparing Multi30k Data for Fine-tuning

Dataset Class

🐍python
class Multi30kDatasetForMbart(Dataset):
    """
    Multi30k dataset formatted for mBART-50 fine-tuning.

    ``src_file`` and ``tgt_file`` are line-aligned UTF-8 text files with one
    sentence per line. Each item is a dict of 1-D tensors: ``input_ids`` and
    ``attention_mask`` for the source sentence, and ``labels`` holding the
    target sentence's token ids. Sequences are truncated to ``max_length``
    but not padded; padding is left to the batch collator.
    """

    def __init__(
        self,
        src_file: str,
        tgt_file: str,
        tokenizer: MBart50TokenizerFast,
        src_lang: str = "de_DE",
        tgt_lang: str = "en_XX",
        max_length: int = 128
    ):
        """
        Load the parallel sentence pairs.

        Raises:
            ValueError: if the two files contain a different number of lines.
        """
        self.tokenizer = tokenizer
        self.src_lang = src_lang
        self.tgt_lang = tgt_lang
        self.max_length = max_length

        with open(src_file, 'r', encoding='utf-8') as f:
            self.src_texts = [line.strip() for line in f]

        with open(tgt_file, 'r', encoding='utf-8') as f:
            self.tgt_texts = [line.strip() for line in f]

        # Use an explicit check rather than `assert`, which is stripped
        # under `python -O`.
        if len(self.src_texts) != len(self.tgt_texts):
            raise ValueError(
                f"Source and target files differ in length: "
                f"{len(self.src_texts)} vs {len(self.tgt_texts)} lines"
            )
        print(f"Loaded {len(self.src_texts)} sentence pairs")

    def __len__(self) -> int:
        return len(self.src_texts)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """
        Tokenize the ``idx``-th sentence pair.
        """
        # The tokenizer object may be shared with code that switches
        # languages, so set both codes on every call. tgt_lang must be set
        # explicitly: target-side tokenization uses it to pick the language
        # code token placed in the labels.
        self.tokenizer.src_lang = self.src_lang
        self.tokenizer.tgt_lang = self.tgt_lang

        # `text_target=` encodes the target in target mode (replacing the
        # deprecated `as_target_tokenizer()` context manager) and returns its
        # token ids under the "labels" key.
        model_inputs = self.tokenizer(
            self.src_texts[idx],
            text_target=self.tgt_texts[idx],
            max_length=self.max_length,
            truncation=True,
            padding=False
        )

        return {key: torch.tensor(value) for key, value in model_inputs.items()}
68
69
class MBartCollator:
    """
    Data collator for mBART that pads a list of examples into a batch.

    Source-side fields (``input_ids``, ``attention_mask``) are padded by the
    tokenizer. Labels are padded separately with ``label_pad_token_id``
    (default -100, the default ``ignore_index`` of PyTorch's cross-entropy),
    so padded positions do not contribute to the loss.
    """

    def __init__(
        self,
        tokenizer: MBart50TokenizerFast,
        pad_to_multiple_of: Optional[int] = 8,  # For efficient GPU
        label_pad_token_id: int = -100
    ):
        self.tokenizer = tokenizer
        self.pad_to_multiple_of = pad_to_multiple_of
        self.label_pad_token_id = label_pad_token_id

    def __call__(
        self,
        features: List[Dict[str, torch.Tensor]]
    ) -> Dict[str, torch.Tensor]:
        """
        Collate features into a batch.

        Args:
            features: examples each holding ``labels`` (tensor or list of
                ints) plus the source-side fields understood by
                ``tokenizer.pad``.

        Returns:
            The padded batch; ``labels`` is a ``(batch, length)`` long tensor.

        Raises:
            ValueError: if ``features`` is empty.
        """
        if not features:
            raise ValueError("Cannot collate an empty list of features")

        # Labels are handled separately: tokenizer.pad would pad them with
        # the pad token id instead of the loss-ignored label pad id.
        labels = [torch.as_tensor(f["labels"], dtype=torch.long) for f in features]
        input_features = [
            {k: v for k, v in f.items() if k != "labels"} for f in features
        ]

        batch = self.tokenizer.pad(
            input_features,
            padding=True,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors="pt"
        )

        target_length = max(label.numel() for label in labels)
        if self.pad_to_multiple_of:
            multiple = self.pad_to_multiple_of
            # Round up to the next multiple (ceil division).
            target_length = -(-target_length // multiple) * multiple

        padded_labels = torch.full(
            (len(labels), target_length),
            self.label_pad_token_id,
            dtype=torch.long
        )
        for row, label in enumerate(labels):
            padded_labels[row, :label.numel()] = label

        batch["labels"] = padded_labels
        return batch

Fine-tuning Configuration

Training Parameters

Fine-tuning a pre-trained model calls for different hyperparameters than training from scratch:

Learning Rate: Use much smaller learning rates (1e-5 to 5e-5) than when training from scratch (1e-4 to 3e-4). A learning rate that is too high can cause catastrophic forgetting of the pre-trained knowledge. Recommended: 3e-5 with linear warmup and decay.

Number of Epochs: Pre-trained models converge much faster, typically needing only 3-10 epochs vs 30-50 from scratch. Watch for overfitting on small datasets. Recommended: 5 epochs for Multi30k.

Batch Size: If memory is limited, use gradient accumulation to reach a larger effective batch size. Recommended: a batch size of 16 with 2 accumulation steps, giving an effective batch size of 32.

🐍python
1from dataclasses import dataclass
2
@dataclass
class MBartFinetuneConfig:
    """
    Configuration for mBART fine-tuning.

    Defaults follow this section's recommendations: a small learning rate,
    few epochs, and an effective batch size of 32 via gradient accumulation.
    Out-of-range values are rejected in ``__post_init__``, so mistakes such
    as zero accumulation steps surface at construction time instead of
    mid-training.
    """
    # Model
    model_name: str = "facebook/mbart-large-50-many-to-many-mmt"
    src_lang: str = "de_DE"
    tgt_lang: str = "en_XX"

    # Data
    data_dir: str = "./data/multi30k"
    max_length: int = 128

    # Training
    num_epochs: int = 5  # Much fewer than training from scratch
    batch_size: int = 16
    gradient_accumulation_steps: int = 2  # Effective batch = 32
    learning_rate: float = 3e-5  # Smaller than from scratch
    weight_decay: float = 0.01
    warmup_ratio: float = 0.1  # 10% of training steps

    # Learning rate schedule
    lr_scheduler_type: str = "linear"  # or "cosine"

    # Optimization
    max_grad_norm: float = 1.0
    fp16: bool = True  # Mixed precision

    # Regularization
    label_smoothing: float = 0.1
    dropout: float = 0.1

    # Evaluation
    eval_steps: int = 500  # Evaluate every N steps
    save_steps: int = 500  # Save every N steps
    num_beams: int = 5

    # Freezing strategy
    freeze_encoder: bool = False  # Option to freeze encoder
    freeze_embeddings: bool = False  # Option to freeze embeddings

    # Output
    output_dir: str = "./checkpoints/mbart-finetuned"

    def __post_init__(self) -> None:
        """
        Validate numeric hyperparameters.

        Raises:
            ValueError: if any value is outside its valid range.
        """
        for name in (
            "max_length",
            "num_epochs",
            "batch_size",
            "gradient_accumulation_steps",
            "eval_steps",
            "save_steps",
            "num_beams",
        ):
            value = getattr(self, name)
            if value < 1:
                raise ValueError(f"{name} must be >= 1, got {value}")

        if self.learning_rate <= 0:
            raise ValueError(f"learning_rate must be > 0, got {self.learning_rate}")
        if self.weight_decay < 0:
            raise ValueError(f"weight_decay must be >= 0, got {self.weight_decay}")
        if not 0.0 <= self.warmup_ratio <= 1.0:
            raise ValueError(f"warmup_ratio must be in [0, 1], got {self.warmup_ratio}")
        if self.max_grad_norm <= 0:
            raise ValueError(f"max_grad_norm must be > 0, got {self.max_grad_norm}")

        for name in ("label_smoothing", "dropout"):
            value = getattr(self, name)
            if not 0.0 <= value < 1.0:
                raise ValueError(f"{name} must be in [0, 1), got {value}")

    @property
    def effective_batch_size(self) -> int:
        """Examples contributing to each optimizer step (batch x accumulation)."""
        return self.batch_size * self.gradient_accumulation_steps

Freezing Strategies:

Option 1: Full Fine-tuning (Recommended) - Train all parameters. Best quality but needs more memory. Risk of overfitting on small data.

Option 2: Freeze Embeddings - Keep the token embeddings fixed. Useful when the pre-trained vocabulary already covers your data well; slightly reduces overfitting.

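Option 3: Freeze Encoder - Train only the decoder. Faster training. Good when the source language is well represented in pre-training. Note that in mBART the encoder's token embeddings are shared with the decoder (and with the output projection when weights are tied), so freezing the encoder freezes those embeddings as well.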

Option 4: Adapter Layers (Advanced) - Insert small trainable layers and freeze the original model. Most parameter-efficient, but requires extra implementation work.

| Method | BLEU (test) | Training time |
|---|---|---|
| Scratch (ours) | 30-35 | 2-3 hours |
| mBART zero-shot | ~15-25 | 0 (no training) |
| mBART fine-tuned | 40-45 | 30 min |
| + data augmentation | 45-50 | 1 hour |

Complete Fine-tuning Script

Training Loop

🐍python
1from tqdm import tqdm
2import json
3from pathlib import Path
4import sacrebleu
5
6
class MBartFinetuner:
    """
    Fine-tune mBART for translation with a hand-written PyTorch loop.

    Usage::

        finetuner = MBartFinetuner(config)
        finetuner.setup_training(train_dataset, val_dataset)
        history = finetuner.train()

    ``setup_training`` (or manual assignment of ``train_loader``,
    ``val_loader``, ``optimizer`` and ``scheduler``) must happen before
    ``train_epoch`` / ``evaluate`` are called.
    """

    def __init__(self, config: MBartFinetuneConfig):
        """
        Load the tokenizer and model, apply the freezing strategy, and create
        ``config.output_dir``.
        """
        self.config = config
        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu"
        )
        # fp16 autocast and loss scaling are used only on CUDA.
        self.use_amp = config.fp16 and self.device.type == "cuda"

        print(f"Device: {self.device}")
        print(f"Model: {config.model_name}")
        if config.fp16 and not self.use_amp:
            print("fp16 requested but CUDA is unavailable; training in fp32")

        # Load tokenizer and model
        self.tokenizer = MBart50TokenizerFast.from_pretrained(config.model_name)
        self.tokenizer.src_lang = config.src_lang
        # Record the target language too, so target-side tokenization with
        # this tokenizer emits the correct language code.
        self.tokenizer.tgt_lang = config.tgt_lang

        # Extra from_pretrained kwargs override the checkpoint's config, so
        # config.dropout actually takes effect.
        self.model = MBartForConditionalGeneration.from_pretrained(
            config.model_name,
            dropout=config.dropout,
        )
        self.model.to(self.device)

        # Target language token id, forced as the first generated token.
        self.forced_bos_token_id = self.tokenizer.convert_tokens_to_ids(
            config.tgt_lang
        )

        # Apply freezing strategy
        self._apply_freezing()

        # parameters() yields each shared/tied tensor once.
        total_params = sum(p.numel() for p in self.model.parameters())
        trainable_params = sum(
            p.numel() for p in self.model.parameters() if p.requires_grad
        )
        print(f"Total parameters: {total_params:,}")
        print(f"Trainable parameters: {trainable_params:,}")
        print(f"Frozen parameters: {total_params - trainable_params:,}")

        # Setup output directory
        self.output_dir = Path(config.output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)

        # Training state, populated by setup_training().
        self.train_loader: Optional[DataLoader] = None
        self.val_loader: Optional[DataLoader] = None
        self.optimizer: Optional[torch.optim.Optimizer] = None
        self.scheduler = None
        self.scaler = None
        self.global_step = 0

    def _apply_freezing(self):
        """
        Freeze parameter groups according to ``config.freeze_embeddings``
        and ``config.freeze_encoder``.

        Note: in the Hugging Face mBART implementation the encoder's token
        embedding shares its weight with ``model.shared``. That weight is
        also used by the decoder (and by ``lm_head`` when word embeddings are
        tied), so ``freeze_encoder`` freezes those shared embeddings as well.
        """
        if self.config.freeze_embeddings:
            print("Freezing embeddings...")
            for param in self.model.model.shared.parameters():
                param.requires_grad = False

        if self.config.freeze_encoder:
            print("Freezing encoder...")
            for param in self.model.model.encoder.parameters():
                param.requires_grad = False

    def _make_grad_scaler(self):
        """Create a GradScaler that is active only when AMP is in use."""
        amp = getattr(torch, "amp", None)
        if amp is not None and hasattr(amp, "GradScaler"):
            return amp.GradScaler("cuda", enabled=self.use_amp)
        return torch.cuda.amp.GradScaler(enabled=self.use_amp)

    def _require(self, *names: str) -> None:
        """Raise RuntimeError if any of the named training attributes is unset."""
        missing = [name for name in names if getattr(self, name, None) is None]
        if missing:
            raise RuntimeError(
                "Call setup_training() first; missing: " + ", ".join(missing)
            )

    def setup_training(
        self,
        train_dataset: Dataset,
        val_dataset: Dataset,
        num_workers: int = 0,
    ) -> None:
        """
        Build the data loaders, AdamW optimizer, LR scheduler and grad scaler.

        Args:
            train_dataset: training examples (e.g. ``Multi30kDatasetForMbart``),
                shuffled every epoch.
            val_dataset: validation examples, iterated in order.
            num_workers: number of ``DataLoader`` worker processes.
        """
        collator = MBartCollator(self.tokenizer)
        self.train_loader = DataLoader(
            train_dataset,
            batch_size=self.config.batch_size,
            shuffle=True,
            collate_fn=collator,
            num_workers=num_workers,
        )
        self.val_loader = DataLoader(
            val_dataset,
            batch_size=self.config.batch_size,
            shuffle=False,
            collate_fn=collator,
            num_workers=num_workers,
        )

        # Only trainable parameters go to the optimizer. 1-D tensors
        # (biases, LayerNorm weights) are conventionally excluded from
        # weight decay.
        decay, no_decay = [], []
        for param in self.model.parameters():
            if param.requires_grad:
                (no_decay if param.ndim < 2 else decay).append(param)
        param_groups = [
            {"params": decay, "weight_decay": self.config.weight_decay},
            {"params": no_decay, "weight_decay": 0.0},
        ]
        self.optimizer = torch.optim.AdamW(
            [group for group in param_groups if group["params"]],
            lr=self.config.learning_rate,
        )

        # One optimizer step per `gradient_accumulation_steps` batches, plus
        # one for a trailing partial group (hence ceil division).
        accum = self.config.gradient_accumulation_steps
        steps_per_epoch = -(-len(self.train_loader) // accum)
        total_steps = steps_per_epoch * self.config.num_epochs
        self.scheduler = get_scheduler(
            self.config.lr_scheduler_type,
            optimizer=self.optimizer,
            num_warmup_steps=int(self.config.warmup_ratio * total_steps),
            num_training_steps=total_steps,
        )
        self.scaler = self._make_grad_scaler()
        self.global_step = 0

    def _compute_loss(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
        """
        Compute the training loss, applying ``config.label_smoothing``.

        Label positions equal to -100 (the collator's padding) are ignored.
        When label smoothing is 0, the model's own loss is returned unchanged.
        """
        outputs = self.model(**batch, use_cache=False)
        if self.config.label_smoothing <= 0:
            return outputs.loss
        logits = outputs.logits
        # Compute the loss in fp32 for numerical stability under autocast.
        return torch.nn.functional.cross_entropy(
            logits.float().reshape(-1, logits.size(-1)),
            batch["labels"].reshape(-1),
            ignore_index=-100,
            label_smoothing=self.config.label_smoothing,
        )

    def _training_loss(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
        """Forward pass for one batch, under fp16 autocast when AMP is on."""
        if self.use_amp:
            with torch.autocast(device_type="cuda", dtype=torch.float16):
                return self._compute_loss(batch)
        return self._compute_loss(batch)

    def _optimizer_step(self) -> None:
        """Clip gradients, step optimizer and scheduler, and reset gradients."""
        # Unscale first so clipping sees the true gradient magnitudes; every
        # scaler call is a pass-through when AMP is disabled.
        self.scaler.unscale_(self.optimizer)
        torch.nn.utils.clip_grad_norm_(
            self.model.parameters(),
            self.config.max_grad_norm
        )
        self.scaler.step(self.optimizer)
        self.scaler.update()
        self.scheduler.step()
        self.optimizer.zero_grad()
        self.global_step += 1

    def train_epoch(self, epoch: int) -> float:
        """
        Train for one epoch with gradient accumulation, clipping and optional AMP.

        Args:
            epoch: zero-based epoch index (used for the progress-bar label).

        Returns:
            Mean per-batch training loss. It includes label smoothing when
            ``config.label_smoothing > 0``, so it is not directly comparable
            to the loss returned by ``evaluate``.

        Raises:
            RuntimeError: if ``setup_training`` has not been called.
            ValueError: if the training loader yields no batches.
        """
        self._require("train_loader", "optimizer", "scheduler")
        if self.scaler is None:
            self.scaler = self._make_grad_scaler()

        self.model.train()
        accum = self.config.gradient_accumulation_steps
        total_loss = 0.0
        num_batches = 0
        pending_update = False

        progress_bar = tqdm(
            self.train_loader,
            desc=f"Epoch {epoch + 1}/{self.config.num_epochs}"
        )

        self.optimizer.zero_grad()

        for step, batch in enumerate(progress_bar):
            batch = {k: v.to(self.device) for k, v in batch.items()}

            loss = self._training_loss(batch)
            # Divide so gradients accumulated over `accum` micro-batches are
            # averaged rather than summed.
            self.scaler.scale(loss / accum).backward()
            pending_update = True

            total_loss += loss.item()
            num_batches += 1

            if (step + 1) % accum == 0:
                self._optimizer_step()
                pending_update = False
                progress_bar.set_postfix({
                    'loss': f'{loss.item():.4f}',
                    'lr': f'{self.scheduler.get_last_lr()[0]:.2e}'
                })

        # Apply gradients from a trailing partial accumulation group instead
        # of discarding them at the next epoch's zero_grad().
        if pending_update:
            self._optimizer_step()

        if num_batches == 0:
            raise ValueError("Training loader produced no batches")
        return total_loss / num_batches

    def evaluate(self) -> Tuple[float, float]:
        """
        Evaluate on the validation set.

        Returns:
            (mean validation loss without label smoothing, corpus BLEU from
            sacrebleu computed on beam-search outputs).

        Raises:
            RuntimeError: if no validation loader has been set up.
            ValueError: if the validation loader yields no batches.
        """
        self._require("val_loader")
        self.model.eval()
        total_loss = 0.0
        num_batches = 0
        all_predictions = []
        all_references = []

        with torch.no_grad():
            for batch in tqdm(self.val_loader, desc="Evaluating"):
                batch = {k: v.to(self.device) for k, v in batch.items()}

                # Teacher-forced loss with the model's built-in cross-entropy.
                outputs = self.model(**batch)
                total_loss += outputs.loss.item()
                num_batches += 1

                # Free-running generation for BLEU.
                generated = self.model.generate(
                    input_ids=batch["input_ids"],
                    attention_mask=batch["attention_mask"],
                    forced_bos_token_id=self.forced_bos_token_id,
                    max_length=self.config.max_length,
                    num_beams=self.config.num_beams
                )
                all_predictions.extend(
                    self.tokenizer.batch_decode(generated, skip_special_tokens=True)
                )

                # -100 is not a valid token id; map it to pad, which
                # skip_special_tokens then drops.
                labels = batch["labels"]
                labels = labels.masked_fill(labels == -100, self.tokenizer.pad_token_id)
                all_references.extend(
                    self.tokenizer.batch_decode(labels, skip_special_tokens=True)
                )

        if num_batches == 0:
            raise ValueError("Validation loader produced no batches")

        bleu = sacrebleu.corpus_bleu(
            all_predictions,
            [all_references]
        ).score

        return total_loss / num_batches, bleu

    def train(self) -> List[Dict[str, float]]:
        """
        Run ``config.num_epochs`` epochs, evaluating after each one.

        The checkpoint with the highest validation BLEU (model and tokenizer)
        is saved to ``output_dir / "best"``. Per-epoch metrics are written to
        ``output_dir / "history.json"``. Evaluation happens once per epoch;
        ``config.eval_steps`` and ``config.save_steps`` are not used here.

        Returns:
            One dict per epoch with ``epoch``, ``train_loss``, ``val_loss``
            and ``bleu``.
        """
        self._require("train_loader", "val_loader", "optimizer", "scheduler")
        history = []
        best_bleu = float("-inf")
        best_dir = self.output_dir / "best"

        for epoch in range(self.config.num_epochs):
            train_loss = self.train_epoch(epoch)
            val_loss, bleu = self.evaluate()
            history.append({
                "epoch": epoch + 1,
                "train_loss": train_loss,
                "val_loss": val_loss,
                "bleu": bleu,
            })
            print(
                f"Epoch {epoch + 1}: train_loss={train_loss:.4f} "
                f"val_loss={val_loss:.4f} BLEU={bleu:.2f}"
            )

            if bleu > best_bleu:
                best_bleu = bleu
                self.model.save_pretrained(best_dir)
                self.tokenizer.save_pretrained(best_dir)
                print(f"New best BLEU {bleu:.2f}; saved to {best_dir}")

        with open(self.output_dir / "history.json", "w", encoding="utf-8") as f:
            json.dump(history, f, indent=2)

        return history

Using Hugging Face Trainer (Alternative)

Simplified Training

For a more automated approach, you can use the Hugging Face Trainer API:

🐍python
1from transformers import (
2    Seq2SeqTrainer,
3    Seq2SeqTrainingArguments,
4    DataCollatorForSeq2Seq
5)
6from datasets import Dataset as HFDataset
7import evaluate
8
9
def finetune_with_hf_trainer(train_dataset=None, val_dataset=None):
    """
    Fine-tune mBART using Hugging Face Trainer (de_DE -> en_XX).

    Args:
        train_dataset: tokenized training set; each item must provide
            ``input_ids``, ``attention_mask`` and ``labels``.
        val_dataset: tokenized validation set in the same format. It is used
            for per-epoch evaluation and for best-checkpoint selection by BLEU.

    Returns:
        The ``Seq2SeqTrainer`` after training. Because
        ``load_best_model_at_end=True``, the model saved to
        ``./checkpoints/mbart-final`` is the best checkpoint by BLEU.

    Raises:
        ValueError: if either dataset is not supplied.
    """
    import inspect

    if train_dataset is None or val_dataset is None:
        raise ValueError(
            "finetune_with_hf_trainer() requires both train_dataset and val_dataset"
        )

    # Load model and tokenizer
    model_name = "facebook/mbart-large-50-many-to-many-mmt"
    tokenizer = MBart50TokenizerFast.from_pretrained(model_name)
    model = MBartForConditionalGeneration.from_pretrained(model_name)

    tokenizer.src_lang = "de_DE"
    tokenizer.tgt_lang = "en_XX"

    # `evaluation_strategy` was renamed to `eval_strategy` in newer
    # transformers releases, and the old name has since been removed. Use
    # whichever keyword the installed version accepts.
    arg_names = inspect.signature(Seq2SeqTrainingArguments.__init__).parameters
    eval_strategy_key = (
        "eval_strategy" if "eval_strategy" in arg_names else "evaluation_strategy"
    )

    # Training arguments
    training_args = Seq2SeqTrainingArguments(
        output_dir="./checkpoints/mbart-hf-trainer",
        save_strategy="epoch",
        learning_rate=3e-5,
        per_device_train_batch_size=16,
        per_device_eval_batch_size=16,
        gradient_accumulation_steps=2,
        weight_decay=0.01,
        save_total_limit=3,
        num_train_epochs=5,
        predict_with_generate=True,
        generation_max_length=128,
        generation_num_beams=5,
        # fp16 mixed precision requires a GPU; fall back to fp32 on CPU.
        fp16=torch.cuda.is_available(),
        load_best_model_at_end=True,
        metric_for_best_model="bleu",
        greater_is_better=True,
        warmup_ratio=0.1,
        logging_steps=100,
        report_to="none",
        **{eval_strategy_key: "epoch"},
    )

    # Data collator: pads labels with -100 so padding is ignored by the loss.
    data_collator = DataCollatorForSeq2Seq(
        tokenizer=tokenizer,
        model=model,
        label_pad_token_id=-100
    )

    # Metric
    sacrebleu_metric = evaluate.load("sacrebleu")
    pad_id = tokenizer.pad_token_id

    def compute_metrics(eval_preds):
        """Corpus BLEU of generated predictions against the reference labels."""
        preds, labels = eval_preds
        if isinstance(preds, tuple):
            preds = preds[0]
        # Generated sequences of different lengths are padded with -100,
        # which the tokenizer cannot decode. Map it to the pad token, which
        # skip_special_tokens then removes.
        preds = [[int(t) if t != -100 else pad_id for t in seq] for seq in preds]
        labels = [[int(t) if t != -100 else pad_id for t in seq] for seq in labels]

        decoded_preds = [
            p.strip() for p in tokenizer.batch_decode(preds, skip_special_tokens=True)
        ]
        decoded_labels = [
            r.strip() for r in tokenizer.batch_decode(labels, skip_special_tokens=True)
        ]

        result = sacrebleu_metric.compute(
            predictions=decoded_preds,
            references=[[ref] for ref in decoded_labels]
        )
        return {"bleu": result["score"]}

    # Newer transformers versions take the tokenizer as `processing_class`;
    # older ones only accept `tokenizer`.
    trainer_params = inspect.signature(Seq2SeqTrainer.__init__).parameters
    tokenizer_key = (
        "processing_class" if "processing_class" in trainer_params else "tokenizer"
    )

    # Initialize trainer
    trainer = Seq2SeqTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        data_collator=data_collator,
        compute_metrics=compute_metrics,
        **{tokenizer_key: tokenizer},
    )

    # Train
    trainer.train()

    # Save final model
    trainer.save_model("./checkpoints/mbart-final")
    return trainer

Summary

| Step | Status | Notes |
|---|---|---|
| Load pre-trained model | ✓ | mBART-50 many-to-many |
| Prepare Multi30k data | ✓ | Custom dataset class |
| Configure training | ✓ | Small LR, few epochs |
| Setup optimizer/scheduler | ✓ | AdamW + linear warmup |
| Training loop | ✓ | With mixed precision |
| Evaluation with BLEU | ✓ | SacreBLEU |
| Save best model | ✓ | Based on validation BLEU |

Expected Results: Training with 5 epochs, batch size 32 (16 × 2 accumulation), learning rate 3e-5 takes ~30-45 minutes on GPU. Expected validation BLEU: ~42-45, test BLEU: ~40-44. This represents a +10-15 BLEU improvement compared to training from scratch!


Exercises

1. Fine-tune mBART on Multi30k and report your test BLEU score.

2. Compare fine-tuning all layers vs. freezing the encoder.

3. Experiment with different learning rates (1e-5, 3e-5, 5e-5).

4. Try fine-tuning for different numbers of epochs and plot the BLEU curve.

5. Use the HF Trainer implementation and compare results with custom training.

Loading comments...