Introduction
This section provides a complete guide to fine-tuning mBART-50 on the Multi30k dataset for German-English translation. We'll cover data preparation, training configuration, and optimization techniques.
Loading mBART Model and Tokenizer
Model Setup
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from transformers import (
    MBartForConditionalGeneration,
    MBart50TokenizerFast,
    get_scheduler
)
from typing import Dict, List, Optional, Tuple
import os


class MBartTranslator:
    """
    mBART-50 model wrapper for German-English translation.

    Loads the tokenizer and model once, moves the model to the selected
    device and exposes ``translate`` for single-sentence inference.
    """

    def __init__(
        self,
        model_name: str = "facebook/mbart-large-50-many-to-many-mmt",
        src_lang: str = "de_DE",
        tgt_lang: str = "en_XX",
        device: Optional[str] = None
    ):
        """
        Initialize mBART translator.

        Args:
            model_name: Hugging Face checkpoint name to load.
            src_lang: mBART-50 language code of the input text.
            tgt_lang: mBART-50 language code to generate.
            device: Torch device string. Defaults to "cuda" when CUDA is
                available, otherwise "cpu".
        """
        self.model_name = model_name
        self.src_lang = src_lang
        self.tgt_lang = tgt_lang
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")

        print(f"Loading model: {model_name}")
        print(f"Translation: {src_lang} → {tgt_lang}")
        print(f"Device: {self.device}")

        # Load tokenizer; src_lang controls the language code added to inputs
        self.tokenizer = MBart50TokenizerFast.from_pretrained(model_name)
        self.tokenizer.src_lang = src_lang

        # Load model
        self.model = MBartForConditionalGeneration.from_pretrained(model_name)
        self.model.to(self.device)

        # Target language token ID, forced as the first generated token
        self.forced_bos_token_id = self.tokenizer.lang_code_to_id[tgt_lang]

        print(f"Model parameters: {self.model.num_parameters():,}")
        print(f"Vocabulary size: {len(self.tokenizer)}")

    def translate(
        self,
        text: str,
        max_length: int = 128,
        num_beams: int = 5,
        **kwargs
    ) -> str:
        """
        Translate a single text.

        Args:
            text: Source-language sentence.
            max_length: Used both to truncate the tokenized input and as the
                maximum length of the generated sequence.
            num_beams: Beam size for beam search.
            **kwargs: Extra arguments forwarded to ``model.generate``. They
                take precedence over the defaults set here, so a caller may
                override ``forced_bos_token_id``.

        Returns:
            The decoded translation with special tokens removed.
        """
        # Tokenize
        encoded = self.tokenizer(
            text,
            return_tensors="pt",
            max_length=max_length,
            truncation=True
        )
        encoded = {k: v.to(self.device) for k, v in encoded.items()}

        # Merge defaults with caller overrides instead of passing both as
        # keywords, which raised TypeError on a duplicated argument.
        generation_kwargs = {
            "forced_bos_token_id": self.forced_bos_token_id,
            "max_length": max_length,
            "num_beams": num_beams,
        }
        generation_kwargs.update(kwargs)

        # Generate
        self.model.eval()
        with torch.no_grad():
            generated = self.model.generate(**encoded, **generation_kwargs)

        # Decode
        translation = self.tokenizer.decode(
            generated[0],
            skip_special_tokens=True
        )

        return translation


# Preparing Multi30k Data for Fine-tuning
Dataset Class
class Multi30kDatasetForMbart(Dataset):
    """
    Multi30k dataset formatted for mBART fine-tuning.

    Reads two line-aligned text files (one sentence per line) and returns
    un-padded ``input_ids``, ``attention_mask`` and ``labels`` tensors per
    example; padding is left to the collator.
    """

    def __init__(
        self,
        src_file: str,
        tgt_file: str,
        tokenizer: "MBart50TokenizerFast",
        src_lang: str = "de_DE",
        tgt_lang: str = "en_XX",
        max_length: int = 128
    ):
        """
        Initialize dataset.

        Args:
            src_file: Path to the UTF-8 source-language file.
            tgt_file: Path to the UTF-8 target-language file.
            tokenizer: mBART-50 tokenizer used to encode both sides.
            src_lang: mBART-50 language code of the source sentences.
            tgt_lang: mBART-50 language code of the target sentences.
            max_length: Truncation length applied to source and target.

        Raises:
            ValueError: If the two files do not contain the same number
                of lines.
        """
        self.tokenizer = tokenizer
        self.src_lang = src_lang
        self.tgt_lang = tgt_lang
        self.max_length = max_length

        # Load data
        with open(src_file, 'r', encoding='utf-8') as f:
            self.src_texts = [line.strip() for line in f]

        with open(tgt_file, 'r', encoding='utf-8') as f:
            self.tgt_texts = [line.strip() for line in f]

        # Explicit check: an assert would be stripped under `python -O`
        if len(self.src_texts) != len(self.tgt_texts):
            raise ValueError(
                f"Source/target line count mismatch: "
                f"{len(self.src_texts)} vs {len(self.tgt_texts)}"
            )
        print(f"Loaded {len(self.src_texts)} sentence pairs")

    def __len__(self) -> int:
        return len(self.src_texts)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """
        Get a single example.

        Returns:
            Dict with ``input_ids``, ``attention_mask`` and ``labels`` as
            1-D tensors (no padding).
        """
        # Set both language codes on the shared tokenizer. The target code
        # was previously stored on the dataset but never applied, so the
        # labels were not encoded with the intended target language token.
        self.tokenizer.src_lang = self.src_lang
        self.tokenizer.tgt_lang = self.tgt_lang

        # `text_target` encodes the target side into "labels" in one call
        # and replaces the deprecated `as_target_tokenizer()` context.
        model_inputs = self.tokenizer(
            self.src_texts[idx],
            text_target=self.tgt_texts[idx],
            max_length=self.max_length,
            truncation=True,
            padding=False
        )

        return {
            k: torch.tensor(v) for k, v in model_inputs.items()
        }


class MBartCollator:
    """
    Data collator for mBART that handles padding.

    Inputs are padded by the tokenizer; labels are padded with -100 so the
    padded positions are ignored by the loss.
    """

    def __init__(
        self,
        tokenizer: "MBart50TokenizerFast",
        pad_to_multiple_of: Optional[int] = 8  # For efficient GPU
    ):
        """
        Args:
            tokenizer: Tokenizer whose ``pad`` method pads the model inputs.
            pad_to_multiple_of: Round padded lengths up to a multiple of
                this value; ``None`` (or 0) disables the rounding.
        """
        self.tokenizer = tokenizer
        self.pad_to_multiple_of = pad_to_multiple_of

    def __call__(
        self,
        features: List[Dict[str, torch.Tensor]]
    ) -> Dict[str, torch.Tensor]:
        """
        Collate features into a batch.

        Args:
            features: Per-example dicts containing a ``labels`` entry plus
                the tokenizer's input fields.

        Returns:
            The padded batch, with ``labels`` as a (batch, length) long
            tensor padded with -100.

        Raises:
            ValueError: If ``features`` is empty.
        """
        if not features:
            raise ValueError("Cannot collate an empty list of features")

        # Separate labels; as_tensor also accepts plain Python lists
        labels = [
            torch.as_tensor(f["labels"], dtype=torch.long) for f in features
        ]
        input_features = [{k: v for k, v in f.items() if k != "labels"}
                          for f in features]

        # Pad inputs
        batch = self.tokenizer.pad(
            input_features,
            padding=True,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors="pt"
        )

        # Pad labels
        max_label_length = max(len(label) for label in labels)
        if self.pad_to_multiple_of:
            # Round up to the next multiple
            max_label_length = (
                (max_label_length + self.pad_to_multiple_of - 1)
                // self.pad_to_multiple_of * self.pad_to_multiple_of
            )

        # Start from an all -100 matrix (ignored in loss) and copy each
        # label sequence into the left part of its row.
        padded_labels = torch.full(
            (len(labels), max_label_length), -100, dtype=torch.long
        )
        for row, label in enumerate(labels):
            padded_labels[row, :len(label)] = label

        batch["labels"] = padded_labels

        return batch


# Fine-tuning Configuration
Training Parameters
Pre-trained models require different hyperparameters than training from scratch:
Learning Rate: Use much smaller learning rates (1e-5 to 5e-5) than when training from scratch (1e-4 to 3e-4). A learning rate that is too high causes catastrophic forgetting of pre-trained knowledge. Recommended: 3e-5 with linear decay.
Number of Epochs: Pre-trained models converge much faster, typically needing only 3-10 epochs vs 30-50 from scratch. Watch for overfitting on small datasets. Recommended: 5 epochs for Multi30k.
Batch Size: Use gradient accumulation for larger effective batch sizes if memory is limited. Recommended: 16 with 2 accumulation steps = 32 effective batch size.
from dataclasses import dataclass


@dataclass
class MBartFinetuneConfig:
    """
    Configuration for mBART fine-tuning.

    Raises:
        ValueError: On construction, if ``batch_size`` or
            ``gradient_accumulation_steps`` is smaller than 1, or if
            ``warmup_ratio`` is outside [0, 1].
    """
    # Model
    model_name: str = "facebook/mbart-large-50-many-to-many-mmt"
    src_lang: str = "de_DE"
    tgt_lang: str = "en_XX"

    # Data
    data_dir: str = "./data/multi30k"
    max_length: int = 128

    # Training
    num_epochs: int = 5  # Much fewer than training from scratch
    batch_size: int = 16
    gradient_accumulation_steps: int = 2  # Effective batch = 32
    learning_rate: float = 3e-5  # Smaller than from scratch
    weight_decay: float = 0.01
    warmup_ratio: float = 0.1  # 10% of training steps

    # Learning rate schedule
    lr_scheduler_type: str = "linear"  # or "cosine"

    # Optimization
    max_grad_norm: float = 1.0
    fp16: bool = True  # Mixed precision

    # Regularization
    label_smoothing: float = 0.1
    dropout: float = 0.1

    # Evaluation
    eval_steps: int = 500  # Evaluate every N steps
    save_steps: int = 500  # Save every N steps
    num_beams: int = 5

    # Freezing strategy
    freeze_encoder: bool = False  # Option to freeze encoder
    freeze_embeddings: bool = False  # Option to freeze embeddings

    # Output
    output_dir: str = "./checkpoints/mbart-finetuned"

    def __post_init__(self) -> None:
        """Reject values that would break the training loop later on."""
        if self.batch_size < 1:
            raise ValueError("batch_size must be at least 1")
        # The loop divides the loss by, and takes a modulo of, this value
        if self.gradient_accumulation_steps < 1:
            raise ValueError("gradient_accumulation_steps must be at least 1")
        if not 0.0 <= self.warmup_ratio <= 1.0:
            raise ValueError("warmup_ratio must be between 0 and 1")

    @property
    def effective_batch_size(self) -> int:
        """Number of examples contributing to one optimizer update."""
        return self.batch_size * self.gradient_accumulation_steps


# Freezing Strategies:
Option 1: Full Fine-tuning (Recommended) - Train all parameters. Best quality but needs more memory. Risk of overfitting on small data.
Option 2: Freeze Embeddings - Fix token embeddings. Useful when vocabulary is adequate. Slightly less overfitting.
Option 3: Freeze Encoder - Only train decoder. Faster training. Good when source language is well-represented in pre-training.
Option 4: Adapter Layers (Advanced) - Insert small trainable layers and freeze original model. Most efficient but requires implementation.
| Method | BLEU (test) | Training time |
|---|---|---|
| Scratch (ours) | 30-35 | 2-3 hours |
| mBART zero-shot | ~15-25 | 0 |
| mBART fine-tuned | 40-45 | 30 min |
| + data augmentation | 45-50 | 1 hour |
Complete Fine-tuning Script
Training Loop
from tqdm import tqdm
import json
from pathlib import Path
import sacrebleu


class MBartFinetuner:
    """
    Fine-tune mBART for translation.

    Typical use::

        finetuner = MBartFinetuner(config)
        finetuner.setup_training(train_loader, val_loader)
        finetuner.train()

    Evaluation and checkpointing in ``train`` happen once per epoch.
    ``config.label_smoothing``, ``config.dropout``, ``config.eval_steps``
    and ``config.save_steps`` are not consumed by this class.
    """

    def __init__(self, config: MBartFinetuneConfig):
        """
        Initialize fine-tuner.

        Loads the tokenizer and model, applies the configured freezing
        strategy and creates the output directory. Dataloaders, optimizer
        and scheduler are attached later by ``setup_training``.
        """
        self.config = config
        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu"
        )

        print(f"Device: {self.device}")
        print(f"Model: {config.model_name}")

        # Load tokenizer and model
        self.tokenizer = MBart50TokenizerFast.from_pretrained(config.model_name)
        self.tokenizer.src_lang = config.src_lang
        # Target code is needed when the tokenizer encodes labels
        self.tokenizer.tgt_lang = config.tgt_lang

        self.model = MBartForConditionalGeneration.from_pretrained(
            config.model_name
        )
        self.model.to(self.device)

        # Get target language token ID
        self.forced_bos_token_id = self.tokenizer.lang_code_to_id[config.tgt_lang]

        # Apply freezing strategy
        self._apply_freezing()

        # Count parameters
        total_params = sum(p.numel() for p in self.model.parameters())
        trainable_params = sum(
            p.numel() for p in self.model.parameters() if p.requires_grad
        )
        print(f"Total parameters: {total_params:,}")
        print(f"Trainable parameters: {trainable_params:,}")
        print(f"Frozen parameters: {total_params - trainable_params:,}")

        # Setup output directory
        self.output_dir = Path(config.output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)

        # Mixed precision is only used when a CUDA device is present; a
        # disabled GradScaler/autocast behaves as a plain pass-through, so
        # the training loop needs a single code path.
        self.use_amp = bool(config.fp16) and self.device.type == "cuda"
        self.scaler = torch.cuda.amp.GradScaler(enabled=self.use_amp)

        # Training state, populated by setup_training()
        self.train_loader: Optional[DataLoader] = None
        self.val_loader: Optional[DataLoader] = None
        self.optimizer = None
        self.scheduler = None
        self.global_step = 0
        self.best_bleu = float("-inf")

    def _apply_freezing(self):
        """Apply freezing strategy based on config."""
        if self.config.freeze_embeddings:
            print("Freezing embeddings...")
            for param in self.model.model.shared.parameters():
                param.requires_grad = False

        if self.config.freeze_encoder:
            print("Freezing encoder...")
            for param in self.model.model.encoder.parameters():
                param.requires_grad = False

    def setup_training(
        self,
        train_loader: DataLoader,
        val_loader: DataLoader
    ) -> None:
        """
        Attach dataloaders and build the optimizer and LR scheduler.

        Args:
            train_loader: Batches with ``input_ids``, ``attention_mask``
                and ``labels`` used for training.
            val_loader: Batches of the same form used by ``evaluate``.
        """
        self.train_loader = train_loader
        self.val_loader = val_loader

        # One optimizer update per accumulation window; the last window of
        # an epoch may be partial, hence the ceiling division.
        accumulation = self.config.gradient_accumulation_steps
        updates_per_epoch = (len(train_loader) + accumulation - 1) // accumulation
        total_updates = updates_per_epoch * self.config.num_epochs

        # Only optimize parameters left trainable by the freezing strategy
        trainable = [p for p in self.model.parameters() if p.requires_grad]
        self.optimizer = torch.optim.AdamW(
            trainable,
            lr=self.config.learning_rate,
            weight_decay=self.config.weight_decay
        )
        self.scheduler = get_scheduler(
            self.config.lr_scheduler_type,
            optimizer=self.optimizer,
            num_warmup_steps=int(total_updates * self.config.warmup_ratio),
            num_training_steps=total_updates
        )
        self.global_step = 0

    def _optimizer_step(self) -> None:
        """Clip gradients, apply one optimizer/scheduler update and reset grads."""
        # Unscale first so clipping sees the true gradient magnitudes
        self.scaler.unscale_(self.optimizer)

        # Gradient clipping
        torch.nn.utils.clip_grad_norm_(
            self.model.parameters(),
            self.config.max_grad_norm
        )

        self.scaler.step(self.optimizer)
        self.scaler.update()

        self.scheduler.step()
        self.optimizer.zero_grad()
        self.global_step += 1

    def train_epoch(self, epoch: int) -> float:
        """
        Train for one epoch.

        Args:
            epoch: Zero-based epoch index (used for the progress bar).

        Returns:
            Mean training loss over the batches of this epoch.

        Raises:
            RuntimeError: If ``setup_training`` has not been called.
            ValueError: If the training dataloader yields no batches.
        """
        if self.train_loader is None or self.optimizer is None:
            raise RuntimeError(
                "Call setup_training(train_loader, val_loader) before train_epoch()"
            )

        self.model.train()
        accumulation = self.config.gradient_accumulation_steps
        batches_in_epoch = len(self.train_loader)
        total_loss = 0.0
        num_batches = 0

        progress_bar = tqdm(
            self.train_loader,
            desc=f"Epoch {epoch + 1}/{self.config.num_epochs}"
        )

        self.optimizer.zero_grad()

        for step, batch in enumerate(progress_bar):
            # Move to device
            batch = {k: v.to(self.device) for k, v in batch.items()}

            # Forward pass (autocast is a no-op when AMP is disabled)
            with torch.cuda.amp.autocast(enabled=self.use_amp):
                outputs = self.model(
                    **batch,
                    use_cache=False
                )
                # Scale so accumulated gradients average over the window
                loss = outputs.loss / accumulation

            # Backward pass
            self.scaler.scale(loss).backward()

            batch_loss = loss.item() * accumulation
            total_loss += batch_loss
            num_batches += 1

            # Update weights at the end of each accumulation window, and
            # on the final batch so a partial window is not thrown away by
            # the next epoch's zero_grad().
            end_of_window = (step + 1) % accumulation == 0
            last_batch = (step + 1) == batches_in_epoch
            if end_of_window or last_batch:
                self._optimizer_step()

            # Update progress bar
            progress_bar.set_postfix({
                'loss': f'{batch_loss:.4f}',
                'lr': f'{self.scheduler.get_last_lr()[0]:.2e}'
            })

        if num_batches == 0:
            raise ValueError("Training dataloader produced no batches")

        return total_loss / num_batches

    def evaluate(self) -> Tuple[float, float]:
        """
        Evaluate on validation set.

        Returns:
            validation loss, BLEU score

        Raises:
            RuntimeError: If ``setup_training`` has not been called.
            ValueError: If the validation dataloader yields no batches.
        """
        if self.val_loader is None:
            raise RuntimeError(
                "Call setup_training(train_loader, val_loader) before evaluate()"
            )

        self.model.eval()
        total_loss = 0.0
        num_batches = 0
        all_predictions = []
        all_references = []

        with torch.no_grad():
            for batch in tqdm(self.val_loader, desc="Evaluating"):
                batch = {k: v.to(self.device) for k, v in batch.items()}

                # Calculate loss
                outputs = self.model(**batch)
                total_loss += outputs.loss.item()
                num_batches += 1

                # Generate translations for BLEU
                generated = self.model.generate(
                    input_ids=batch["input_ids"],
                    attention_mask=batch["attention_mask"],
                    forced_bos_token_id=self.forced_bos_token_id,
                    max_length=self.config.max_length,
                    num_beams=self.config.num_beams
                )

                predictions = self.tokenizer.batch_decode(
                    generated,
                    skip_special_tokens=True
                )

                # Get references: -100 is not a valid token id, so swap it
                # for the pad token before decoding
                labels = batch["labels"]
                labels = labels.masked_fill(labels == -100, self.tokenizer.pad_token_id)
                references = self.tokenizer.batch_decode(
                    labels,
                    skip_special_tokens=True
                )

                all_predictions.extend(predictions)
                all_references.extend(references)

        if num_batches == 0:
            raise ValueError("Validation dataloader produced no batches")

        # Calculate BLEU (single reference stream)
        bleu = sacrebleu.corpus_bleu(
            all_predictions,
            [all_references]
        ).score

        return total_loss / num_batches, bleu

    def save_model(self, path) -> None:
        """Save model and tokenizer to ``path`` (created if missing)."""
        path = Path(path)
        path.mkdir(parents=True, exist_ok=True)
        self.model.save_pretrained(path)
        self.tokenizer.save_pretrained(path)

    def train(self) -> List[Dict[str, float]]:
        """
        Run the full fine-tuning loop.

        Trains for ``config.num_epochs`` epochs, evaluates after each one,
        saves the model with the best validation BLEU to
        ``<output_dir>/best`` and writes the per-epoch metrics to
        ``<output_dir>/history.json``.

        Returns:
            One dict per epoch with ``epoch``, ``train_loss``, ``val_loss``
            and ``bleu``.
        """
        history = []
        for epoch in range(self.config.num_epochs):
            train_loss = self.train_epoch(epoch)
            val_loss, bleu = self.evaluate()
            print(
                f"Epoch {epoch + 1}: train_loss={train_loss:.4f} "
                f"val_loss={val_loss:.4f} BLEU={bleu:.2f}"
            )
            history.append({
                "epoch": epoch + 1,
                "train_loss": train_loss,
                "val_loss": val_loss,
                "bleu": bleu,
            })

            if bleu > self.best_bleu:
                self.best_bleu = bleu
                self.save_model(self.output_dir / "best")

        with open(self.output_dir / "history.json", "w", encoding="utf-8") as f:
            json.dump(history, f, indent=2)

        return history


# Using Hugging Face Trainer (Alternative)
Simplified Training
For a more automated approach, you can use the Hugging Face Trainer API:
from transformers import (
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    DataCollatorForSeq2Seq
)
from datasets import Dataset as HFDataset
import evaluate


def finetune_with_hf_trainer(train_dataset=None, val_dataset=None):
    """
    Fine-tune mBART using Hugging Face Trainer.

    Args:
        train_dataset: Tokenized training examples providing ``input_ids``,
            ``attention_mask`` and ``labels``.
        val_dataset: Tokenized validation examples of the same form.

    Returns:
        The trainer after training has finished.

    Raises:
        ValueError: If either dataset is missing.
    """
    # Fail before loading the model if there is nothing to train on
    if train_dataset is None or val_dataset is None:
        raise ValueError(
            "finetune_with_hf_trainer requires both train_dataset and val_dataset"
        )

    # Load model and tokenizer
    model_name = "facebook/mbart-large-50-many-to-many-mmt"
    tokenizer = MBart50TokenizerFast.from_pretrained(model_name)
    model = MBartForConditionalGeneration.from_pretrained(model_name)

    tokenizer.src_lang = "de_DE"
    tokenizer.tgt_lang = "en_XX"

    # Force the target language code as the first generated token during
    # evaluation (predict_with_generate); the trainer calls generate()
    # without it otherwise.
    forced_bos_token_id = tokenizer.lang_code_to_id["en_XX"]
    generation_config = getattr(model, "generation_config", None)
    if generation_config is not None:
        generation_config.forced_bos_token_id = forced_bos_token_id
    else:
        model.config.forced_bos_token_id = forced_bos_token_id

    # Training arguments
    # NOTE(review): newer transformers releases rename `evaluation_strategy`
    # to `eval_strategy` (and the trainer's `tokenizer` argument to
    # `processing_class`); adjust if your installed version rejects them.
    training_args = Seq2SeqTrainingArguments(
        output_dir="./checkpoints/mbart-hf-trainer",
        evaluation_strategy="epoch",
        save_strategy="epoch",
        learning_rate=3e-5,
        per_device_train_batch_size=16,
        per_device_eval_batch_size=16,
        gradient_accumulation_steps=2,
        weight_decay=0.01,
        save_total_limit=3,
        num_train_epochs=5,
        predict_with_generate=True,
        generation_max_length=128,
        generation_num_beams=5,
        fp16=True,
        load_best_model_at_end=True,
        metric_for_best_model="bleu",
        greater_is_better=True,
        warmup_ratio=0.1,
        logging_steps=100,
        report_to="none"
    )

    # Data collator
    data_collator = DataCollatorForSeq2Seq(
        tokenizer=tokenizer,
        model=model,
        label_pad_token_id=-100
    )

    # Metric
    sacrebleu_metric = evaluate.load("sacrebleu")

    def compute_metrics(eval_preds):
        """Decode predictions/labels and return corpus BLEU as {"bleu": score}."""
        preds, labels = eval_preds
        # Some model outputs arrive as a tuple; the token ids come first
        if isinstance(preds, tuple):
            preds = preds[0]

        # -100 marks padding and is not a decodable token id, so drop it
        # from both predictions and labels before decoding.
        preds = [[int(t) for t in seq if t != -100] for seq in preds]
        labels = [[int(t) for t in seq if t != -100] for seq in labels]

        decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
        decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

        result = sacrebleu_metric.compute(
            predictions=[pred.strip() for pred in decoded_preds],
            references=[[ref.strip()] for ref in decoded_labels]
        )
        return {"bleu": result["score"]}

    # Initialize trainer
    trainer = Seq2SeqTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        tokenizer=tokenizer,
        data_collator=data_collator,
        compute_metrics=compute_metrics,
    )

    # Train
    trainer.train()

    # Save final model
    trainer.save_model("./checkpoints/mbart-final")
    return trainer


# Summary
| Step | Status | Notes |
|---|---|---|
| Load pre-trained model | ✓ | mBART-50 many-to-many |
| Prepare Multi30k data | ✓ | Custom dataset class |
| Configure training | ✓ | Small LR, few epochs |
| Setup optimizer/scheduler | ✓ | AdamW + linear warmup |
| Training loop | ✓ | With mixed precision |
| Evaluation with BLEU | ✓ | SacreBLEU |
| Save best model | ✓ | Based on validation BLEU |
Expected Results: Training for 5 epochs with an effective batch size of 32 (16 × 2 accumulation steps) and a learning rate of 3e-5 takes ~30-45 minutes on a GPU. Expected validation BLEU: ~42-45; test BLEU: ~40-44. That is roughly a +10 BLEU improvement over training from scratch (30-35 BLEU)!
Exercises
1. Fine-tune mBART on Multi30k and report your test BLEU score.
2. Compare fine-tuning all layers vs. freezing the encoder.
3. Experiment with different learning rates (1e-5, 3e-5, 5e-5).
4. Try fine-tuning for different numbers of epochs and plot the BLEU curve.
5. Use the HF Trainer implementation and compare results with custom training.