Chapter 15

Advanced Fine-tuning Techniques

Pretrained Models

Introduction

This section covers advanced techniques to improve fine-tuning efficiency and performance, including parameter-efficient methods (LoRA, Adapters), data augmentation, and multi-task learning.


Parameter-Efficient Fine-tuning (PEFT)

Why Parameter-Efficient Methods?

Full fine-tuning of large models like mBART has significant drawbacks:

The Problem: Full fine-tuning of mBART trains all ~610 million parameters, stores ~2.4 GB of weights per task (610 million parameters × 4 bytes in fp32), needs a lot of GPU memory, and risks catastrophic forgetting. Storage grows linearly with the number of tasks: 3 tasks require about 7.2 GB.

The Solution: Parameter-efficient fine-tuning trains only a small subset of parameters (typically 0.5-2%), stores a small adapter (~10-50 MB) per task, shares one base model across all tasks, and often matches the quality of full fine-tuning.

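| Method | Trainable Params | Memory | Quality | Speed | Complexity |
| --- | --- | --- | --- | --- | --- |
| Full FT | 100% | High | Best | Slow | Low |
| Adapters | 2-5% | Medium | Good | Medium | Medium |
| LoRA | 0.5-2% | Low | Good | Fast | Low |
| Prefix Tuning | 0.1% | Low | Decent | Fast | Medium |
| Prompt Tuning | <0.1% | Very Low | Varies | Very Fast | Low |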

LoRA: Low-Rank Adaptation

LoRA Implementation

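LoRA (Low-Rank Adaptation) represents the weight update as a product of two low-rank matrices. Instead of learning a full update ΔW in W' = W + ΔW, where W is a d × k weight matrix, LoRA learns ΔW = BA with B of shape d × r and A of shape r × k, where the rank r << min(d, k). Only A and B are trained; W stays frozen. The update is additionally scaled by alpha / r.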

🐍python
import torch
import torch.nn as nn
from typing import Optional, List
import math


class LoRALayer(nn.Module):
    """
    Low-Rank Adaptation layer.

    Instead of W' = W + ΔW (full update),
    LoRA decomposes: ΔW = BA where B and A are low-rank.

    Forward: h = Wx + BAx
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        rank: int = 8,
        alpha: float = 16.0,
        dropout: float = 0.0
    ):
        """
        Initialize LoRA layer.

        Args:
            in_features: Input dimension
            out_features: Output dimension
            rank: LoRA rank (r)
            alpha: Scaling factor
            dropout: Dropout rate on LoRA path
        """
        super().__init__()

        self.rank = rank
        self.alpha = alpha
        self.scaling = alpha / rank

        # LoRA matrices: A is (rank, in_features), B is (out_features, rank)
        self.lora_A = nn.Parameter(torch.zeros(rank, in_features))
        self.lora_B = nn.Parameter(torch.zeros(out_features, rank))

        # Dropout on LoRA path
        self.dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()

        # Initialize A with Kaiming, B with zeros.
        # Because B starts at zero, BA = 0 and the adapted layer initially
        # behaves exactly like the frozen base layer.
        nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
        nn.init.zeros_(self.lora_B)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Compute LoRA update: BAx * scaling
        """
        return self.dropout(x) @ self.lora_A.T @ self.lora_B.T * self.scaling


class LinearWithLoRA(nn.Module):
    """
    Linear layer with LoRA adaptation.

    Combines frozen pre-trained weights with trainable LoRA update.
    """

    def __init__(
        self,
        linear: nn.Linear,
        rank: int = 8,
        alpha: float = 16.0,
        dropout: float = 0.0
    ):
        super().__init__()

        self.linear = linear
        self.lora = LoRALayer(
            in_features=linear.in_features,
            out_features=linear.out_features,
            rank=rank,
            alpha=alpha,
            dropout=dropout
        )

        # Freeze original weights
        self.linear.weight.requires_grad = False
        if self.linear.bias is not None:
            self.linear.bias.requires_grad = False

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward: original + LoRA update.
        """
        return self.linear(x) + self.lora(x)

    def merge_weights(self):
        """Merge LoRA weights into linear layer (for inference).

        Returns the plain nn.Linear; use it in place of this wrapper,
        otherwise the LoRA update would be applied twice.
        """
        with torch.no_grad():
            # (out, rank) @ (rank, in) -> (out, in), same shape as linear.weight
            delta_W = (self.lora.lora_B @ self.lora.lora_A) * self.lora.scaling
            self.linear.weight.add_(delta_W)
        return self.linear
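Merging works because the LoRA path is linear: applying W and the scaled BA separately gives the same result as applying the single matrix W + scaling · BA. The sketch below checks this on a tiny hand-written example in plain Python (no PyTorch, dropout disabled). The helper names `matvec`, `matmul`, `lora_forward`, and `merged_weight` are illustrative and not part of any library.

```python
from typing import List

Vector = List[float]
Matrix = List[List[float]]


def matvec(m: Matrix, v: Vector) -> Vector:
    """Multiply a matrix by a vector."""
    return [sum(w * x for w, x in zip(row, v)) for row in m]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """Multiply two matrices given as nested lists."""
    return [
        [sum(a[i][t] * b[t][j] for t in range(len(b))) for j in range(len(b[0]))]
        for i in range(len(a))
    ]


def lora_forward(W: Matrix, A: Matrix, B: Matrix, x: Vector, scaling: float) -> Vector:
    """Unmerged path: h = Wx + scaling * B(Ax)."""
    base = matvec(W, x)
    update = matvec(B, matvec(A, x))
    return [b + scaling * u for b, u in zip(base, update)]


def merged_weight(W: Matrix, A: Matrix, B: Matrix, scaling: float) -> Matrix:
    """Merged weight: W' = W + scaling * BA."""
    BA = matmul(B, A)
    return [[w + scaling * d for w, d in zip(w_row, d_row)] for w_row, d_row in zip(W, BA)]


# Toy sizes: d = 3 outputs, k = 2 inputs, rank r = 1
W = [[1.0, 0.0], [0.0, 1.0], [2.0, -1.0]]  # d x k, frozen
A = [[0.5, -1.0]]                           # r x k
B = [[1.0], [0.0], [2.0]]                   # d x r
scaling = 2.0                               # alpha / r with alpha = 2, r = 1
x = [3.0, 4.0]

separate = lora_forward(W, A, B, x, scaling)
merged = matvec(merged_weight(W, A, B, scaling), x)
print(separate, merged)  # both are [-2.0, 4.0, -8.0]
```

With dropout active during training the two paths differ, which is one reason merging is normally done only for inference.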

Parameter Savings: For a 512×512 linear layer with rank 8, full fine-tuning updates 512 × 512 = 262,144 weights, while LoRA trains only 8 × 512 + 512 × 8 = 8,192 parameters, a 96.9% reduction.
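The arithmetic behind these numbers can be reproduced with a few lines of Python. This is a minimal sketch; the function names are illustrative, and the bias of the linear layer is ignored because LoRA leaves it untouched.

```python
def full_param_count(in_features: int, out_features: int) -> int:
    """Weights updated by full fine-tuning of one linear layer (bias ignored)."""
    return in_features * out_features


def lora_param_count(in_features: int, out_features: int, rank: int) -> int:
    """Trainable LoRA weights: A is (rank x in), B is (out x rank)."""
    return rank * in_features + out_features * rank


full = full_param_count(512, 512)
lora = lora_param_count(512, 512, rank=8)
reduction = 100.0 * (1 - lora / full)
print(f"full={full:,} lora={lora:,} reduction={reduction:.1f}%")
# full=262,144 lora=8,192 reduction=96.9%

# The saving shrinks as the rank grows
for r in (4, 8, 16, 64):
    share = 100.0 * lora_param_count(512, 512, r) / full
    print(f"rank={r:>2}: {share:.2f}% of the full layer")
```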


Using PEFT Library

Easy LoRA with Hugging Face PEFT

The Hugging Face PEFT library makes it easy to apply LoRA to any transformer model:

🐍python
# Installation (quote the requirement so the shell does not treat ">" as a redirect)
# pip install "peft>=0.4.0"

from peft import (
    get_peft_model,
    LoraConfig,
    TaskType,
    PeftModel
)
from transformers import MBartForConditionalGeneration, MBart50TokenizerFast

# Load base model
model_name = "facebook/mbart-large-50-many-to-many-mmt"
model = MBartForConditionalGeneration.from_pretrained(model_name)
tokenizer = MBart50TokenizerFast.from_pretrained(model_name)

# Configure LoRA
lora_config = LoraConfig(
    task_type=TaskType.SEQ_2_SEQ_LM,
    r=8,                          # LoRA rank
    lora_alpha=16,                # Scaling factor
    lora_dropout=0.1,             # Dropout
    target_modules=[              # Which modules to adapt
        "q_proj",                 # Query projection
        "v_proj",                 # Value projection
        # Can also add:
        # "k_proj",               # Key projection
        # "out_proj",             # Output projection
        # "fc1", "fc2",           # FFN layers
    ],
    bias="none",                  # Don't train biases
    modules_to_save=None,         # Additional modules to train
)

# Apply LoRA
model = get_peft_model(model, lora_config)

# Check trainable parameters
model.print_trainable_parameters()
# Prints "trainable params: ... || all params: ... || trainable%: ...".
# The exact counts depend on r and target_modules; the trainable share
# should be well under 1% of the ~610M total.

# Training is the same as before!

# Save only LoRA weights (tiny!)
model.save_pretrained("./checkpoints/mbart-lora")
# Saves only the adapter (megabytes) instead of the ~2.4GB full model

# Load LoRA weights
base_model = MBartForConditionalGeneration.from_pretrained(model_name)
model = PeftModel.from_pretrained(base_model, "./checkpoints/mbart-lora")

# For inference, optionally merge weights
model = model.merge_and_unload()  # Now a regular model

PEFT Benefits:

1. Storage Efficiency: The full model takes ~2.4 GB while a LoRA adapter is only ~15 MB, a 99.4% saving.

2. Multiple Tasks: Share one base model (2.4 GB) with multiple small adapters (15 MB each). Three tasks take 2.4 GB + 45 MB in total, versus 7.2 GB for three full models.

3. Training Speed: Only the adapter weights need gradients and optimizer state, which lowers memory use and allows larger batches. Each step is usually faster as well, although the forward pass still runs through the full frozen model.

4. Quality: Often matches full fine-tuning, though it can be slightly lower on very complex tasks.
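The storage comparison in item 2 generalizes to any number of tasks. The sketch below uses the figures quoted above (2.4 GB base model, ~15 MB per adapter, with 1 GB taken as 1000 MB); the function name `storage_mb` is illustrative.

```python
from typing import Tuple


def storage_mb(
    num_tasks: int,
    base_mb: float = 2400.0,   # ~2.4 GB full model, as quoted above
    adapter_mb: float = 15.0   # ~15 MB per LoRA adapter, as quoted above
) -> Tuple[float, float]:
    """Return (full fine-tuning storage, base + adapters storage) in MB."""
    full = num_tasks * base_mb
    shared = base_mb + num_tasks * adapter_mb
    return full, shared


for tasks in (1, 3, 10):
    full, shared = storage_mb(tasks)
    print(f"{tasks:>2} tasks: full={full / 1000:.1f} GB, base+adapters={shared / 1000:.2f} GB")
```

For a single task the adapter approach stores slightly more than one full model (base plus adapter); the saving only appears once the base model is shared across several tasks.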


Adapter Layers

Adapter Implementation

Adapter layers insert small bottleneck modules into the transformer architecture:

🐍python
import torch
import torch.nn as nn


class AdapterLayer(nn.Module):
    """
    Bottleneck adapter layer in the style of Houlsby et al. (2019).

    Structure:
    Input → LayerNorm → Down-project → Activation → Up-project → Residual → Output
    """

    def __init__(
        self,
        hidden_size: int,
        adapter_size: int = 64,
        activation: str = "gelu"
    ):
        super().__init__()

        self.layer_norm = nn.LayerNorm(hidden_size)

        # Down-projection
        self.down_project = nn.Linear(hidden_size, adapter_size)

        # Activation (fail loudly on an unknown name instead of silently using GELU)
        if activation == "gelu":
            self.activation = nn.GELU()
        elif activation == "relu":
            self.activation = nn.ReLU()
        else:
            raise ValueError(f"Unsupported activation: {activation!r}")

        # Up-projection
        self.up_project = nn.Linear(adapter_size, hidden_size)

        # Initialize to near-identity: with small weights the adapter pathway
        # outputs values close to zero, so output ≈ residual at the start.
        nn.init.normal_(self.down_project.weight, std=0.01)
        nn.init.zeros_(self.down_project.bias)
        nn.init.normal_(self.up_project.weight, std=0.01)
        nn.init.zeros_(self.up_project.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass with residual connection.
        """
        residual = x

        # Adapter pathway
        x = self.layer_norm(x)
        x = self.down_project(x)
        x = self.activation(x)
        x = self.up_project(x)

        # Residual connection
        return residual + x

Adapters are inserted after attention and FFN layers in each transformer block. The original model parameters are frozen while only the adapter parameters are trained.

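| Config | Adapter Size | Params/Adapter | Total Params |
| --- | --- | --- | --- |
| Tiny | 16 | 33,808 | 1,622,784 |
| Small | 32 | 66,592 | 3,196,416 |
| Medium | 64 | 132,160 | 6,343,680 |
| Large | 128 | 263,296 | 12,638,208 |
| XL | 256 | 525,568 | 25,227,264 |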

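For mBART with 24 layers (12 encoder + 12 decoder) and hidden size 1024, two adapters per layer give 48 adapters in total. Even the "Large" configuration adds only ~12.6 million trainable parameters, about 2% of the base model's 610 million. The counts in the table cover the two projection layers (weights and biases) and exclude the adapter's LayerNorm.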
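The table values follow directly from the layer shapes. Here is a small sketch that recomputes them, assuming hidden size 1024 and 48 adapters (24 layers × 2 adapters per layer); `adapter_params` is an illustrative helper, and the optional LayerNorm term shows how little the normalization adds.

```python
def adapter_params(hidden_size: int, adapter_size: int, include_layer_norm: bool = False) -> int:
    """Parameters of one bottleneck adapter (weights + biases of both projections)."""
    down = hidden_size * adapter_size + adapter_size   # down-projection
    up = adapter_size * hidden_size + hidden_size      # up-projection
    layer_norm = 2 * hidden_size if include_layer_norm else 0  # scale + shift
    return down + up + layer_norm


HIDDEN_SIZE = 1024
NUM_ADAPTERS = 24 * 2  # 12 encoder + 12 decoder layers, two adapters each

for name, size in [("Tiny", 16), ("Small", 32), ("Medium", 64), ("Large", 128), ("XL", 256)]:
    per_adapter = adapter_params(HIDDEN_SIZE, size)
    total = per_adapter * NUM_ADAPTERS
    print(f"{name:<6} size={size:<3} per_adapter={per_adapter:>7,} total={total:>10,}")
```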


Data Augmentation for Translation

Back-Translation and Paraphrasing

Data augmentation can significantly improve translation quality, especially with limited training data:

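1. Back-Translation: Translate target-side sentences back into the source language to create synthetic training pairs. For example, with German as the source and English as the target, start from the English sentence "The dog runs quickly through the park.", translate it EN → DE to get "Der Hund rennt schnell durch den Park.", and train on the pair (synthetic DE, original EN). The source side is varied and possibly noisy, while the target side stays fluent, human-written text. This can yield 2x-3x more training data, and improvements of around +2-4 BLEU are typical, although the gain depends on the amount of original data and the quality of the back-translation model.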

2. Word Dropout: Randomly drop words to simulate missing information and create robustness to incomplete input.

3. Word Shuffle: Randomly shuffle words within a limited distance to simulate word order variations.

4. Random Insertion: Insert random words from the vocabulary to add noise.

🐍python
import random
from typing import List, Tuple


class TranslationDataAugmenter:
    """
    Data augmentation techniques for translation.
    """

    def __init__(
        self,
        forward_model=None,  # Source → Target model
        backward_model=None,  # Target → Source model
        device: str = "cuda"
    ):
        self.forward_model = forward_model
        self.backward_model = backward_model
        self.device = device

    def back_translation(
        self,
        target_sentences: List[str],
        num_samples: int = 1
    ) -> List[Tuple[str, str]]:
        """
        Generate synthetic source sentences via back-translation.

        Returns (synthetic source, original target) pairs.
        """
        if self.backward_model is None:
            raise ValueError("back_translation requires a backward_model")

        augmented = []

        for target in target_sentences:
            for _ in range(num_samples):
                # Translate target → source (with some randomness).
                # Assumes backward_model is a wrapper that exposes
                # translate(text, **generation_kwargs) -> str.
                synthetic_source = self.backward_model.translate(
                    target,
                    num_beams=1,
                    do_sample=True,
                    temperature=1.0
                )
                augmented.append((synthetic_source, target))

        return augmented

    @staticmethod
    def word_dropout(
        sentence: str,
        dropout_prob: float = 0.1
    ) -> str:
        """
        Randomly drop words from sentence.
        """
        words = sentence.split()
        kept_words = [w for w in words if random.random() > dropout_prob]

        if len(kept_words) == 0:
            return sentence  # Don't return empty

        return ' '.join(kept_words)

    @staticmethod
    def word_shuffle(
        sentence: str,
        max_distance: int = 3
    ) -> str:
        """
        Randomly shuffle words within a distance.
        """
        words = sentence.split()

        # Add noise to positions
        positions = list(range(len(words)))
        noisy_positions = [
            p + random.uniform(-max_distance, max_distance)
            for p in positions
        ]

        # Sort by noisy position only (ties keep the original word order)
        shuffled = [
            word for _, word in sorted(
                zip(noisy_positions, words), key=lambda pair: pair[0]
            )
        ]

        return ' '.join(shuffled)
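The class above covers back-translation, word dropout, and word shuffle; random insertion from the list is not implemented there. A minimal sketch of it could look as follows. The function name, the `vocabulary` argument, and the optional `rng` parameter (used to make the noise reproducible) are assumptions made for illustration.

```python
import random
from typing import List, Optional


def random_insertion(
    sentence: str,
    vocabulary: List[str],
    insert_prob: float = 0.1,
    rng: Optional[random.Random] = None
) -> str:
    """Insert a random vocabulary word after each word with probability insert_prob."""
    rng = rng or random.Random()
    noisy = []
    for word in sentence.split():
        noisy.append(word)
        # Original words keep their relative order; only extra words are added
        if vocabulary and rng.random() < insert_prob:
            noisy.append(rng.choice(vocabulary))
    return ' '.join(noisy)


vocab = ["Haus", "schnell", "und", "Park"]
noised = random_insertion(
    "Der Hund rennt durch den Park .", vocab, insert_prob=0.3, rng=random.Random(0)
)
print(noised)
```

As with word dropout and shuffle, this kind of noise is usually applied to the source side only, so the model still learns to produce clean target text.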

Augmentation Strategies: For the source language (German), use word dropout to simulate missing info, word shuffle for word order variations, and back-translation for diverse paraphrases. For the target language (English), usually keep unchanged to maintain fluent output, but can use forward-translation of augmented source.

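Multi30k Augmentation Example: The original dataset has 29,000 pairs. Adding one back-translated copy doubles it to 58,000 pairs; adding one noised copy of every source sentence doubles it again to 116,000 pairs. An improvement on the order of +3-5 BLEU can be expected, though the actual gain depends on the quality of the back-translation model and should be measured on a held-out set.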


Multi-Task Fine-tuning

Training on Multiple Language Pairs

Instead of training a separate model for each translation direction, train one model on all directions simultaneously.

Benefits:

1. Knowledge Transfer: Learning DE→EN helps FR→EN through shared representations across languages.

2. Better Low-Resource: High-resource pairs help low-resource pairs, making the model more robust to data scarcity.

3. Efficiency: One model vs N models, with shared inference infrastructure.

🐍python
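import torch
from torch.utils.data import Dataset
from typing import Dict, Tuple


class MultiTaskTranslationDataset(Dataset):
    """
    Dataset for multi-task translation fine-tuning.

    Supports multiple language pairs with task prefixes.
    """

    def __init__(
        self,
        data_files: Dict[str, Tuple[str, str]],
        tokenizer,
        src_langs: Dict[str, str],
        tgt_langs: Dict[str, str],
        max_length: int = 128,
        use_task_prefix: bool = False
    ):
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.use_task_prefix = use_task_prefix

        self.examples = []

        for lang_pair, (src_file, tgt_file) in data_files.items():
            src_lang = src_langs[lang_pair]
            tgt_lang = tgt_langs[lang_pair]

            with open(src_file, 'r', encoding='utf-8') as f:
                sources = [line.strip() for line in f]
            with open(tgt_file, 'r', encoding='utf-8') as f:
                targets = [line.strip() for line in f]

            # zip() would silently truncate misaligned files
            if len(sources) != len(targets):
                raise ValueError(
                    f"{lang_pair}: {len(sources)} source lines vs {len(targets)} target lines"
                )

            for src, tgt in zip(sources, targets):
                # Optionally add task prefix
                if use_task_prefix:
                    src = f"translate {src_lang} to {tgt_lang}: {src}"

                self.examples.append({
                    'source': src,
                    'target': tgt,
                    'src_lang': src_lang,
                    'tgt_lang': tgt_lang
                })

        print(f"Loaded {len(self.examples)} examples from {len(data_files)} language pairs")

    def __len__(self) -> int:
        return len(self.examples)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        # Assumes a Hugging Face mBART-style tokenizer that reads src_lang /
        # tgt_lang attributes and accepts the text_target argument.
        example = self.examples[idx]
        self.tokenizer.src_lang = example['src_lang']
        self.tokenizer.tgt_lang = example['tgt_lang']

        encoded = self.tokenizer(
            example['source'],
            text_target=example['target'],
            max_length=self.max_length,
            truncation=True,
            padding='max_length',
            return_tensors='pt'
        )

        # Ignore padding positions in the loss
        labels = encoded['labels'].squeeze(0).clone()
        labels[labels == self.tokenizer.pad_token_id] = -100

        return {
            'input_ids': encoded['input_ids'].squeeze(0),
            'attention_mask': encoded['attention_mask'].squeeze(0),
            'labels': labels
        }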

Task Identification Methods:

Method 1 (mBART style): Use language codes. Source: "Der Hund läuft." with src_lang="de_DE". The target language is selected by forcing its language code token ("en_XX") as the first generated token.

Method 2 (T5 style): Use task prefix. Source: "translate German to English: Der Hund läuft.". Target: "The dog runs."
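The two conventions can be put side by side in a short sketch. Everything here is illustrative: `format_example`, the `LANG_NAMES` mapping, and the returned dictionary keys are hypothetical helpers, not part of the mBART or T5 APIs. It only shows what information each style attaches to an example.

```python
from typing import Dict

# Hypothetical mapping from mBART language codes to the names used in T5-style prefixes
LANG_NAMES = {"de_DE": "German", "en_XX": "English", "fr_XX": "French"}


def format_example(source: str, src_lang: str, tgt_lang: str, style: str = "mbart") -> Dict[str, str]:
    """Attach task information to a source sentence in one of two styles."""
    if style == "mbart":
        # Task identity travels as language codes next to the unchanged text
        return {"text": source, "src_lang": src_lang, "forced_target_lang": tgt_lang}
    if style == "t5":
        # Task identity is written into the input text itself
        prefix = f"translate {LANG_NAMES[src_lang]} to {LANG_NAMES[tgt_lang]}: "
        return {"text": prefix + source}
    raise ValueError(f"Unknown style: {style!r}")


print(format_example("Der Hund läuft.", "de_DE", "en_XX", style="mbart"))
print(format_example("Der Hund läuft.", "de_DE", "en_XX", style="t5"))
```

Since mBART was pretrained with language codes, Method 1 is the natural choice for it; the prefix style mainly matters for models that were pretrained with textual task descriptions.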

🐍python
# Multi-task data configuration
data_config = {
    "de-en": {
        "train_src": "./data/multi30k/train.de",
        "train_tgt": "./data/multi30k/train.en",
        "src_lang": "de_DE",
        "tgt_lang": "en_XX"
    },
    "en-de": {  # Reverse direction
        "train_src": "./data/multi30k/train.en",
        "train_tgt": "./data/multi30k/train.de",
        "src_lang": "en_XX",
        "tgt_lang": "de_DE"
    },
    "fr-en": {  # Add more languages if available
        "train_src": "./data/wmt14/train.fr",
        "train_tgt": "./data/wmt14/train.en",
        "src_lang": "fr_XX",
        "tgt_lang": "en_XX"
    }
}

# Per-pair sampling settings for mixing the datasets during training.
# Assumption: each value is used as an exponent on the dataset size
# (p ∝ n ** value), so values below 1.0 upsample pairs with less data.
sampling_temperatures = {
    "de-en": 1.0,   # Sample in proportion to dataset size
    "en-de": 1.0,
    "fr-en": 0.5,   # Upsample if less data
}
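A common way to turn dataset sizes into sampling probabilities is temperature-based sampling with a single global temperature T, where p_i ∝ (n_i / N)^(1/T), which is the same as an exponent of 1/T on the dataset size. T = 1 samples in proportion to size, and larger T flattens the distribution so that smaller datasets are seen more often. The sketch below uses this convention, which differs from the per-pair dictionary above. The dataset sizes are hypothetical, and the function names are illustrative.

```python
import random
from typing import Dict, List


def sampling_probs(sizes: Dict[str, int], temperature: float = 1.0) -> Dict[str, float]:
    """p_i ∝ (n_i / N) ** (1 / T); T = 1 is proportional, larger T is flatter."""
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    total = sum(sizes.values())
    weights = {pair: (n / total) ** (1.0 / temperature) for pair, n in sizes.items()}
    norm = sum(weights.values())
    return {pair: w / norm for pair, w in weights.items()}


def sample_pairs(probs: Dict[str, float], num_batches: int, seed: int = 0) -> List[str]:
    """Pick which language pair each training batch is drawn from."""
    rng = random.Random(seed)
    pairs = list(probs)
    return rng.choices(pairs, weights=[probs[p] for p in pairs], k=num_batches)


# Hypothetical dataset sizes, for illustration only
sizes = {"de-en": 29_000, "en-de": 29_000, "fr-en": 5_000}

proportional = sampling_probs(sizes, temperature=1.0)
flattened = sampling_probs(sizes, temperature=5.0)
schedule = sample_pairs(flattened, num_batches=10)

for pair in sizes:
    print(f"{pair}: T=1 -> {proportional[pair]:.3f}, T=5 -> {flattened[pair]:.3f}")
print(schedule)
```

Upsampling a small dataset too aggressively repeats its examples often, so the temperature is worth tuning against validation scores for the low-resource pair.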

Summary

| Method | Trainable % | Storage | Quality | Best For |
| --- | --- | --- | --- | --- |
| Full Fine-tuning | 100% | 2.4 GB | Best | Single task, high quality |
| LoRA | 0.5-2% | 15-50 MB | Near-best | Multiple tasks, efficient |
| Adapters | 2-5% | 50-100 MB | Good | Modular, interpretable |
| Prefix Tuning | 0.1% | 2-5 MB | Decent | Very limited resources |

Best Practices:

1. Start with full fine-tuning as baseline

2. Try LoRA with r=8 or r=16

3. Use data augmentation for small datasets

4. Consider multi-task learning for multiple directions

5. Monitor for overfitting on small data


Exercises

1. Implement LoRA from scratch for a single linear layer.

2. Apply LoRA to mBART using PEFT library and compare to full fine-tuning.

3. Implement back-translation augmentation and measure improvement.

4. Train a multi-task model on DE↔EN (both directions).

5. Compare adapter sizes (32, 64, 128) on translation quality.
