Chapter 17
Section 74 of 75

Model Optimization for Production

Production Deployment

Introduction

This section covers techniques to optimize transformer models for production deployment, including quantization, pruning, and knowledge distillation.


Why Optimization Matters

Production Constraints

```python
import torch
import torch.nn as nn
from typing import Dict, List, Tuple


def production_requirements():
    """
    Overview of production requirements.
    """
    print("=" * 70)
    print("PRODUCTION DEPLOYMENT REQUIREMENTS")
    print("=" * 70)

    print("""
    TYPICAL PRODUCTION CONSTRAINTS:
    ───────────────────────────────

    1. LATENCY
       ┌────────────────────────────────────────────────────────────┐
       │  User expectation: < 100ms response for interactive apps   │
       │  Translation service: < 500ms per sentence                 │
       │  Batch processing: throughput matters more than latency    │
       └────────────────────────────────────────────────────────────┘

    2. MEMORY
       ┌────────────────────────────────────────────────────────────┐
       │  Cloud GPU: 16-80GB, expensive ($1-30/hour)                │
       │  Edge device: 4-16GB, limited                              │
       │  Mobile: < 4GB, very constrained                           │
       └────────────────────────────────────────────────────────────┘

    3. COST
       ┌────────────────────────────────────────────────────────────┐
       │  Compute: $/1M tokens (API pricing model)                  │
       │  Infrastructure: GPU hours × $/hour                        │
       │  Scaling: Cost per user × number of users                  │
       └────────────────────────────────────────────────────────────┘

    4. THROUGHPUT
       ┌────────────────────────────────────────────────────────────┐
       │  Requests per second per GPU                               │
       │  Concurrent users supported                                │
       │  Batch efficiency                                          │
       └────────────────────────────────────────────────────────────┘


    OUR TRANSLATION MODEL:
    ──────────────────────

    Baseline (FP32, unoptimized):
    ├── Model size: ~260 MB (65M params × 4 bytes)
    ├── Memory usage: ~1.5 GB (with activations)
    ├── Latency: ~200ms per sentence
    └── Throughput: ~5 sentences/second

    Target (optimized):
    ├── Model size: ~65 MB (INT8, 65M params × 1 byte)
    ├── Memory usage: ~400 MB
    ├── Latency: ~50ms per sentence
    └── Throughput: ~20 sentences/second

    Target: roughly 4x lower latency and 4x higher throughput.


    OPTIMIZATION TECHNIQUES:
    ────────────────────────

    ┌─────────────────────┬─────────┬────────┬─────────┬────────────────┐
    │ Technique           │ Speedup │ Memory │ Quality │ Complexity     │
    ├─────────────────────┼─────────┼────────┼─────────┼────────────────┤
    │ FP16 (half prec.)   │ 1.5-2x  │ 50%    │ ~Same   │ Low            │
    │ INT8 Quantization   │ 2-4x    │ 75%    │ -0.5%   │ Medium         │
    │ INT4 Quantization   │ 4-8x    │ 87%    │ -2%     │ Medium         │
    │ Pruning             │ 1.5-3x  │ 50-90% │ -1-5%   │ High           │
    │ Distillation        │ 2-10x   │ 50-90% │ -2-5%   │ High           │
    │ ONNX Export         │ 1.2-2x  │ Same   │ Same    │ Low            │
    │ TensorRT            │ 2-5x    │ ~Same  │ Same    │ Medium         │
    │ Flash Attention     │ 2-4x    │ 50-90% │ Same    │ Low            │
    │ KV Caching          │ 2-10x   │ +Memory│ Same    │ Low            │
    └─────────────────────┴─────────┴────────┴─────────┴────────────────┘
    Memory = reduction vs. FP32 (KV caching spends extra memory to save compute).
    Flash Attention figures refer to the attention computation itself.
    """)


production_requirements()
```
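The model-size figures above follow from simple arithmetic: weight memory is the parameter count times the bytes per parameter. The sketch below is minimal and makes some assumptions. It counts weights only, leaving out activations, the KV cache, and quantization scales. It uses the 65M-parameter count quoted above, and `weight_memory_mb` and `savings_vs_fp32` are hypothetical helper names.

```python
# Bytes needed to store one weight at each precision.
BYTES_PER_PARAM = {
    "fp32": 4.0,
    "fp16": 2.0,
    "bf16": 2.0,
    "int8": 1.0,
    "int4": 0.5,  # two 4-bit values packed per byte
}


def weight_memory_mb(num_params: int, dtype: str) -> float:
    """Approximate weight storage in MB (1 MB = 1e6 bytes).

    Ignores activations, KV cache, and per-tensor/per-channel scales.
    """
    return num_params * BYTES_PER_PARAM[dtype] / 1e6


def savings_vs_fp32(dtype: str) -> float:
    """Fraction of FP32 weight memory saved by storing weights as `dtype`."""
    return 1.0 - BYTES_PER_PARAM[dtype] / BYTES_PER_PARAM["fp32"]


if __name__ == "__main__":
    num_params = 65_000_000  # parameter count quoted for the translation model
    for dtype in ("fp32", "fp16", "int8", "int4"):
        print(f"{dtype}: {weight_memory_mb(num_params, dtype):.1f} MB, "
              f"{savings_vs_fp32(dtype):.1%} less than FP32")
```

For 65M parameters this gives 260 MB (FP32), 130 MB (FP16), 65 MB (INT8) and 32.5 MB (INT4). Those are the 50%, 75% and 87.5% savings listed in the Memory column.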

Quantization

Reducing Precision for Faster Inference

```python
class QuantizationExplained:
    """
    Quantization techniques for model optimization.
    """

    def __init__(self):
        pass

    def demonstrate_quantization_concepts(self):
        """Show quantization basics."""
        print("=" * 60)
        print("QUANTIZATION BASICS")
        print("=" * 60)

        print("""
    WHAT IS QUANTIZATION?
    ─────────────────────

    Converting floating-point weights (and, depending on the method,
    activations) to lower-precision integers.

    FP32 (32-bit float):
    ┌──────┬─────────────────────────────┬───────────────────────┐
    │ Sign │      Exponent (8 bits)      │  Mantissa (23 bits)   │
    │  1   │ 0 1 0 0 0 0 0 0             │ 1 1 0 0 0 0 0 0 ...   │
    └──────┴─────────────────────────────┴───────────────────────┘
    Range: ±3.4 × 10³⁸, Precision: ~7 decimal digits

    INT8 (8-bit integer):
    ┌─────────────────────────────┐
    │  0  1  0  1  0  1  1  0     │   = 86
    └─────────────────────────────┘
    Range: -128 to 127 (signed) or 0 to 255 (unsigned), step size 1


    THE MAPPING:
    ────────────

    Asymmetric (affine) quantization onto the unsigned range [0, 255]:

    Scale = (max - min) / 255
    Zero_point = round(-min / scale)

    Quantize:   q = clamp(round(x / scale) + zero_point, 0, 255)
    Dequantize: x ≈ (q - zero_point) × scale

    Example:
    FP32 weights: [-0.5, 0.2, 0.8, -0.3, 0.5]
    Range: -0.5 to 0.8
    Scale = (0.8 - (-0.5)) / 255 ≈ 0.0051
    Zero_point = round(0.5 / 0.0051) = 98

    Quantized (unsigned 8-bit): [0, 137, 255, 39, 196]


    QUANTIZATION TYPES:
    ───────────────────

    1. DYNAMIC QUANTIZATION
       - Weights quantized once, when the model is converted
       - Activations quantized on the fly at inference time
       - Easy to apply, good results
       - Best for: CPU inference, memory-bound models

    2. STATIC QUANTIZATION
       - Both weights and activations pre-quantized
       - Requires calibration data to fix activation ranges
       - Faster inference
       - Best for: Edge devices, mobile

    3. QUANTIZATION-AWARE TRAINING (QAT)
       - Simulate quantization during training (fake quantization)
       - Model learns to be robust to quantization
       - Best quality
       - Best for: When quality is critical
        """)

    def pytorch_quantization_example(self):
        """Show PyTorch quantization."""
        print("\nPYTORCH QUANTIZATION")
        print("=" * 60)

        code = '''
import io

import torch
import torch.nn as nn
from torch.ao.quantization import (
    convert,
    get_default_qat_qconfig,
    get_default_qconfig,
    prepare,
    prepare_qat,
    quantize_dynamic,
)

# Placeholders from your project: TransformerModel, input_ids,
# calibration_loader, train_loader, criterion, num_epochs.


def model_size_mb(m: nn.Module) -> float:
    """Serialized state_dict size; includes packed quantized weights."""
    buffer = io.BytesIO()
    torch.save(m.state_dict(), buffer)
    return buffer.getbuffer().nbytes / 1e6


# ===================
# DYNAMIC QUANTIZATION
# ===================

# Original model
model = TransformerModel()
model.load_state_dict(torch.load("model.pt"))
model.eval()

# Quantize Linear layers to INT8 (returns a quantized copy).
# nn.Embedding is not handled by the default dynamic qconfig; embeddings
# need weight-only quantization (float_qparams_weight_only_qconfig).
quantized_model = quantize_dynamic(
    model,
    {nn.Linear},  # Layer types to quantize
    dtype=torch.qint8
)

# Compare sizes. Summing parameters() would miss the quantized weights,
# which are stored as packed params rather than nn.Parameters.
print(f"Original: {model_size_mb(model):.1f} MB")
print(f"Quantized: {model_size_mb(quantized_model):.1f} MB")

# Inference
with torch.no_grad():
    output = quantized_model(input_ids)


# ===================
# STATIC QUANTIZATION
# ===================
# Eager mode expects QuantStub/DeQuantStub around the quantized region;
# ops without quantized kernels must stay in floating point.

model.eval()
model.qconfig = get_default_qconfig('fbgemm')  # x86; 'qnnpack' for ARM
model_prepared = prepare(model)  # Inserts observers (on a copy)

# Calibrate with representative data
with torch.no_grad():
    for batch in calibration_loader:
        model_prepared(batch)

# Convert to quantized model
model_quantized = convert(model_prepared)


# =================================
# QUANTIZATION-AWARE TRAINING (QAT)
# =================================

# Prepare model for QAT (inserts fake-quantization modules into a copy)
model.train()
model.qconfig = get_default_qat_qconfig('fbgemm')
model_qat = prepare_qat(model)

# Create the optimizer after prepare_qat so it tracks model_qat's
# parameters (the learning rate here is illustrative)
optimizer = torch.optim.AdamW(model_qat.parameters(), lr=1e-5)

# Fine-tune with fake quantization
for epoch in range(num_epochs):
    for batch, target in train_loader:
        optimizer.zero_grad()
        output = model_qat(batch)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

# Convert to quantized model
model_qat.eval()
model_quantized = convert(model_qat)
'''
        print(code)


def demonstrate_quantization():
    """Demonstrate quantization impact."""
    print("\nQuantization Impact Demonstration")
    print("=" * 60)

    # Create a simple linear layer
    d_model = 512
    layer = nn.Linear(d_model, d_model)

    # Original weights
    original_weights = layer.weight.data.clone()

    # Simulate 8-bit asymmetric quantization onto [0, 255]
    w_min = original_weights.min()
    w_max = original_weights.max()
    scale = (w_max - w_min) / 255
    zero_point = torch.round(-w_min / scale).int()

    # Quantize
    w_quantized = torch.round(original_weights / scale) + zero_point
    w_quantized = w_quantized.clamp(0, 255).byte()

    # Dequantize
    w_dequantized = (w_quantized.float() - zero_point) * scale

    # Compute error (relative error is inflated by weights close to zero)
    error = (original_weights - w_dequantized).abs()
    relative_error = error / (original_weights.abs() + 1e-8)

    fp32_bytes = original_weights.numel() * original_weights.element_size()
    int8_bytes = w_quantized.numel() * w_quantized.element_size()

    print("Weight statistics:")
    print(f"  Range: [{w_min:.4f}, {w_max:.4f}]")
    print(f"  Scale: {scale:.6f}")
    print(f"  Zero point: {zero_point.item()}")
    print("\nQuantization error:")
    print(f"  Max absolute error: {error.max():.6f}")
    print(f"  Mean absolute error: {error.mean():.6f}")
    print(f"  Mean relative error: {relative_error.mean()*100:.2f}%")
    print("\nMemory:")
    print(f"  Original (FP32): {fp32_bytes:,} bytes")
    print(f"  Quantized (8-bit): {int8_bytes:,} bytes")
    print(f"  Compression: {fp32_bytes / int8_bytes:.0f}x")


demonstrate_quantization()
```
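The mapping and the worked example can be checked with a few lines of pure Python. This minimal sketch performs per-tensor asymmetric quantization onto the unsigned range [0, 255], which PyTorch calls `quint8`. It also includes `RunningMinMax`, a hypothetical, simplified stand-in for the observers that static quantization uses during calibration. It is not PyTorch's API.

```python
from typing import List, Tuple


def affine_qparams(x_min: float, x_max: float,
                   qmin: int = 0, qmax: int = 255) -> Tuple[float, int]:
    """Scale and zero point that map [x_min, x_max] onto the integers [qmin, qmax]."""
    scale = (x_max - x_min) / (qmax - qmin)
    if scale == 0.0:  # constant tensor: any positive scale works
        scale = 1.0
    zero_point = qmin + round(-x_min / scale)
    return scale, min(qmax, max(qmin, zero_point))


def quantize(xs: List[float], scale: float, zero_point: int,
             qmin: int = 0, qmax: int = 255) -> List[int]:
    """q = clamp(round(x / scale) + zero_point, qmin, qmax)."""
    return [min(qmax, max(qmin, round(x / scale) + zero_point)) for x in xs]


def dequantize(qs: List[int], scale: float, zero_point: int) -> List[float]:
    """x ≈ (q - zero_point) * scale."""
    return [(q - zero_point) * scale for q in qs]


class RunningMinMax:
    """Toy calibration observer (hypothetical): tracks the activation range seen
    on calibration batches, then turns it into fixed quantization parameters."""

    def __init__(self) -> None:
        self.min_val = float("inf")
        self.max_val = float("-inf")

    def observe(self, batch: List[float]) -> None:
        self.min_val = min(self.min_val, min(batch))
        self.max_val = max(self.max_val, max(batch))

    def qparams(self) -> Tuple[float, int]:
        return affine_qparams(self.min_val, self.max_val)


if __name__ == "__main__":
    weights = [-0.5, 0.2, 0.8, -0.3, 0.5]
    scale, zp = affine_qparams(min(weights), max(weights))
    q = quantize(weights, scale, zp)
    print(f"scale={scale:.4f}, zero_point={zp}, quantized={q}")
    print("dequantized:", [round(x, 4) for x in dequantize(q, scale, zp)])
```

Running it reproduces zero_point = 98 and [0, 137, 255, 39, 196] from the text. Each weight round-trips with an error of at most half a quantization step (scale / 2). The tensor's range sets the scale, so a wider range means coarser steps and larger error.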

Knowledge Distillation

Training Smaller Models from Larger Ones

```python
class KnowledgeDistillation:
    """
    Knowledge distillation for model compression.
    """

    def __init__(self):
        pass

    def explain_distillation(self):
        """Explain distillation concept."""
        print("=" * 60)
        print("KNOWLEDGE DISTILLATION")
        print("=" * 60)

        print("""
    THE IDEA:
    ─────────

    Train a small "student" model to mimic a large "teacher" model.
    Both see the same input; the student learns to match the teacher's
    output distribution.

                             Input
                               │
                 ┌─────────────┴─────────────┐
                 ▼                           ▼
        ┌──────────────────┐        ┌──────────────────┐
        │     Teacher      │        │     Student      │
        │   610M params    │        │    65M params    │
        │ (large, trained) │        │(small, training) │
        └────────┬─────────┘        └────────┬─────────┘
                 │ soft labels               │ student predictions
                 └─── KL divergence loss ────┘


    WHY IT WORKS:
    ─────────────

    Teacher's soft labels contain more information than hard labels:

    Hard label:  [0, 0, 1, 0, 0]  ("dog")

    Soft label:  [0.05, 0.01, 0.80, 0.12, 0.02]
                  cat   bird  dog   wolf  fox

    The soft labels show:
    - "dog" is the answer (0.80)
    - "wolf" is the closest alternative (0.12)
    - "cat" is a bit related (0.05)

    This "dark knowledge" helps the student learn better!


    LOSS FUNCTION:
    ──────────────

    L = α × L_distill + (1-α) × L_task

    Where:
    - L_distill = T² × KL(teacher_soft || student_soft) at temperature T
      (the T² factor keeps soft-label gradients on a comparable scale)
    - L_task = CrossEntropy(student, hard_labels)
    - α = weight (typically 0.5-0.9)
    - T = temperature (typically 2-10)

    Temperature T softens the distributions:
    - T=1: Original distribution
    - T>1: Softer (more uniform)
    - Higher T reveals more relationships


    DISTILLATION FOR TRANSLATION:
    ─────────────────────────────

    Teacher: mBART-large (610M params, BLEU ~45)
    Student: Our model (65M params)

    Options:
    1. Word-level distillation (soft labels per token; needs a shared vocabulary)
    2. Sequence-level distillation (translate with the teacher, then train
       the student on the teacher's outputs)
    3. Hidden-state distillation (match intermediate representations)

    Expected results:
    - Without distillation: BLEU ~30-35
    - With distillation: BLEU ~38-42

    9-10x smaller than the teacher, for only a 3-7 BLEU drop!
        """)


class DistillationTrainer:
    """
    Trainer for knowledge distillation.
    """

    def __init__(
        self,
        teacher_model: nn.Module,
        student_model: nn.Module,
        temperature: float = 4.0,
        alpha: float = 0.7
    ):
        """
        Initialize distillation trainer.

        Args:
            teacher_model: Large pre-trained model
            student_model: Small model to train
            temperature: Softmax temperature
            alpha: Weight for distillation loss
        """
        self.teacher = teacher_model
        self.student = student_model
        self.temperature = temperature
        self.alpha = alpha

        # Freeze teacher
        for param in self.teacher.parameters():
            param.requires_grad = False
        self.teacher.eval()

    def distillation_loss(
        self,
        student_logits: torch.Tensor,
        teacher_logits: torch.Tensor,
        labels: torch.Tensor
    ) -> Tuple[torch.Tensor, Dict[str, float]]:
        """
        Compute distillation loss.

        Args:
            student_logits: Student output logits (batch, seq_len, vocab)
            teacher_logits: Teacher output logits (same shape and vocabulary)
            labels: Ground truth labels (batch, seq_len); -100 marks padding

        Returns:
            Total loss, loss components dict
        """
        T = self.temperature
        vocab_size = student_logits.size(-1)

        # Flatten to (tokens, vocab) and keep only real (non-padding) tokens
        flat_labels = labels.reshape(-1)
        keep = flat_labels != -100
        s_logits = student_logits.reshape(-1, vocab_size)[keep]
        t_logits = teacher_logits.reshape(-1, vocab_size)[keep]

        # Soft labels with temperature
        student_soft = torch.log_softmax(s_logits / T, dim=-1)
        teacher_soft = torch.softmax(t_logits / T, dim=-1)

        # KL divergence loss (distillation); 'batchmean' on the flattened
        # tensors averages over real tokens
        distill_loss = nn.functional.kl_div(
            student_soft,
            teacher_soft,
            reduction='batchmean'
        ) * (T * T)  # Scale by T^2 as per Hinton et al.

        # Hard label loss (task)
        task_loss = nn.functional.cross_entropy(
            student_logits.reshape(-1, vocab_size),
            flat_labels,
            ignore_index=-100
        )

        # Combined loss
        total_loss = self.alpha * distill_loss + (1 - self.alpha) * task_loss

        return total_loss, {
            'distill_loss': distill_loss.item(),
            'task_loss': task_loss.item(),
            'total_loss': total_loss.item()
        }

    def train_step(
        self,
        batch: Dict[str, torch.Tensor]
    ) -> Tuple[torch.Tensor, Dict[str, float]]:
        """
        Forward passes and loss for one distillation step.

        The caller runs backward() and the optimizer step.

        Args:
            batch: Input batch (must contain 'labels')

        Returns:
            Loss tensor, loss values dict
        """
        self.student.train()

        # Get teacher predictions (no grad)
        with torch.no_grad():
            teacher_outputs = self.teacher(**batch)
            teacher_logits = teacher_outputs.logits

        # Get student predictions
        student_outputs = self.student(**batch)
        student_logits = student_outputs.logits

        # Compute loss
        loss, loss_dict = self.distillation_loss(
            student_logits,
            teacher_logits,
            batch['labels']
        )

        return loss, loss_dict


def distillation_code_example():
    """Show distillation code."""
    print("\nDistillation Code Example")
    print("=" * 60)

    code = '''
# Complete distillation training loop (sketch).
# Assumes DistillationTrainer (above) is in scope, and that train_loader
# yields dicts (input_ids, attention_mask, labels) built with the
# teacher's tokenizer.

import torch
from transformers import MBartForConditionalGeneration
from your_model import TransformerModel  # your own student implementation

# Load teacher (pre-trained mBART-50)
teacher = MBartForConditionalGeneration.from_pretrained(
    "facebook/mbart-large-50-many-to-many-mmt"
)
teacher.eval()
teacher.cuda()

# Initialize student (your smaller model). Word-level distillation
# compares distributions token by token, so the student must use the
# teacher's vocabulary. It must also accept the same batch keys and
# return an object with a .logits attribute.
student = TransformerModel(
    vocab_size=teacher.config.vocab_size,
    d_model=256,      # Smaller than teacher
    num_heads=4,
    num_layers=4,     # Fewer layers
    d_ff=512
)
student.cuda()

# Distillation trainer
trainer = DistillationTrainer(
    teacher_model=teacher,
    student_model=student,
    temperature=4.0,
    alpha=0.7
)

optimizer = torch.optim.AdamW(student.parameters(), lr=3e-4)

# Training loop
for epoch in range(10):
    for batch in train_loader:
        batch = {k: v.cuda() for k, v in batch.items()}

        loss, loss_dict = trainer.train_step(batch)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        print(f"Distill: {loss_dict['distill_loss']:.4f}, "
              f"Task: {loss_dict['task_loss']:.4f}")

# Save student
torch.save(student.state_dict(), "student_model.pt")
'''
    print(code)


distillation_code_example()
```
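The loss can be written out in pure Python for a single token, which makes the effect of the temperature easy to see. This is a minimal sketch: the teacher logits for the five-word example vocabulary (cat, bird, dog, wolf, fox) are hypothetical values. `kd_loss` mirrors `DistillationTrainer.distillation_loss` for one position, without batching or padding.

```python
import math
from typing import List


def softmax(logits: List[float], temperature: float = 1.0) -> List[float]:
    """Temperature-scaled softmax, shifted by the max logit for numerical stability."""
    scaled = [z / temperature for z in logits]
    top = max(scaled)
    exps = [math.exp(z - top) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def kl_divergence(p: List[float], q: List[float]) -> float:
    """KL(p || q) for two distributions over the same vocabulary."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


def kd_loss(student_logits: List[float], teacher_logits: List[float], label: int,
            temperature: float = 4.0, alpha: float = 0.7) -> float:
    """alpha * T^2 * KL(teacher_T || student_T) + (1 - alpha) * cross-entropy at T = 1."""
    p_teacher = softmax(teacher_logits, temperature)
    p_student = softmax(student_logits, temperature)
    distill = temperature ** 2 * kl_divergence(p_teacher, p_student)
    task = -math.log(softmax(student_logits)[label])  # hard-label loss
    return alpha * distill + (1 - alpha) * task


if __name__ == "__main__":
    vocab = ["cat", "bird", "dog", "wolf", "fox"]
    teacher_logits = [2.0, 0.5, 5.0, 3.5, 1.0]  # hypothetical values
    for T in (1.0, 4.0):
        probs = softmax(teacher_logits, T)
        print(f"T={T}: " + ", ".join(f"{w}={p:.2f}" for w, p in zip(vocab, probs)))
    print("loss for an untrained (uniform) student:",
          round(kd_loss([0.0] * 5, teacher_logits, label=vocab.index("dog")), 4))
```

At T = 1, "dog" takes most of the probability mass. At T = 4 the same logits give a much flatter distribution, so the relative similarity of "wolf" and "cat" to "dog" becomes a visible training signal for the student.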

Pruning

Removing Unnecessary Parameters

```python
def pruning_explained():
    """Explain pruning techniques."""
    print("=" * 60)
    print("MODEL PRUNING")
    print("=" * 60)

    print("""
    WHAT IS PRUNING?
    ────────────────

    Remove weights that contribute little to the output.

    Before pruning:
    ┌─────────────────────────────────────────┐
    │  0.12  0.85  0.03  0.91  0.02           │
    │  0.78  0.01  0.65  0.04  0.89           │
    │  0.02  0.94  0.01  0.87  0.03           │
    └─────────────────────────────────────────┘

    After pruning (~50%: every weight below 0.5 set to zero):
    ┌─────────────────────────────────────────┐
    │  0.00  0.85  0.00  0.91  0.00           │
    │  0.78  0.00  0.65  0.00  0.89           │
    │  0.00  0.94  0.00  0.87  0.00           │
    └─────────────────────────────────────────┘


    PRUNING STRATEGIES:
    ───────────────────

    1. MAGNITUDE PRUNING
       Remove weights with smallest absolute value.
       Simple and effective.

    2. STRUCTURED PRUNING
       Remove entire neurons, attention heads, or layers.
       Actually reduces computation (not just memory).

    3. MOVEMENT PRUNING
       Remove weights that move toward zero during fine-tuning.
       Better suited to fine-tuning pretrained models.


    UNSTRUCTURED VS STRUCTURED:
    ───────────────────────────

    Unstructured (sparse):
    ┌────────────────────┐
    │ x 0 x 0 0 x 0 x    │  Scattered zeros
    │ 0 x 0 x x 0 x 0    │  Saves memory only with sparse storage
    │ x 0 0 x 0 x 0 x    │  Speedup needs sparse kernels/hardware
    └────────────────────┘

    Structured (dense):
    ┌────────────────────┐
    │ x x x x x x x x    │  Remove entire rows/columns
    │ 0 0 0 0 0 0 0 0    │  Real speedup on any hardware
    │ x x x x x x x x    │  Smaller dense model
    └────────────────────┘


    PRUNING ATTENTION HEADS:
    ────────────────────────

    Not all attention heads are equally important!

    Head importance analysis:
    ┌─────────────────────────────────────────┐
    │ Layer 1: [0.85, 0.12, 0.76, 0.05, ...]  │
    │ Layer 2: [0.92, 0.88, 0.03, 0.79, ...]  │
    │ Layer 3: [0.45, 0.67, 0.89, 0.21, ...]  │
    └─────────────────────────────────────────┘

    Heads with low importance scores can often be removed
    with minimal quality loss!
    """)


class MagnitudePruner:
    """
    Simple global magnitude-based (unstructured) pruning.
    """

    def __init__(
        self,
        model: nn.Module,
        sparsity: float = 0.5
    ):
        """
        Initialize pruner.

        Args:
            model: Model to prune
            sparsity: Target sparsity (0.5 = 50% zeros)
        """
        self.model = model
        self.sparsity = sparsity

    def compute_threshold(self) -> float:
        """Compute a global magnitude threshold over all weight matrices."""
        if self.sparsity <= 0:
            return float('-inf')  # Nothing to prune

        # Weight matrices only (dim >= 2): biases and LayerNorm are skipped
        all_weights = []
        for name, param in self.model.named_parameters():
            if 'weight' in name and param.dim() >= 2:
                all_weights.append(param.detach().abs().flatten())

        all_weights = torch.cat(all_weights)
        # k-th smallest magnitude; torch.kthvalue also handles tensors
        # that are too large for torch.quantile
        k = max(1, int(self.sparsity * all_weights.numel()))
        return torch.kthvalue(all_weights, k).values.item()

    def prune(self) -> Dict[str, float]:
        """
        Apply magnitude pruning.

        Note: if training continues, zeroed weights can become non-zero
        again unless the mask is re-applied after each update.

        Returns:
            Pruning statistics
        """
        threshold = self.compute_threshold()

        total_params = 0
        pruned_params = 0

        with torch.no_grad():
            for name, param in self.model.named_parameters():
                if 'weight' in name and param.dim() >= 2:
                    mask = param.abs() > threshold  # Keep weights above the threshold
                    param.mul_(mask.to(param.dtype))

                    total_params += param.numel()
                    pruned_params += (~mask).sum().item()

        actual_sparsity = pruned_params / total_params

        return {
            'target_sparsity': self.sparsity,
            'actual_sparsity': actual_sparsity,
            'pruned_params': pruned_params,
            'total_params': total_params
        }


class HeadPruner:
    """
    Prune attention heads based on importance (skeleton).
    """

    def __init__(self, model: nn.Module):
        self.model = model

    def compute_head_importance(
        self,
        dataloader: torch.utils.data.DataLoader,
        num_batches: int = 100
    ) -> Dict[str, torch.Tensor]:
        """
        Compute importance scores for each attention head.

        Intended approach (gradient-based, as in Michel et al., 2019):
        multiply each head's output by a mask variable fixed at 1 and
        accumulate |d loss / d mask| over `num_batches` batches.
        Not implemented here: it depends on how your attention module
        exposes per-head outputs.
        """
        head_importance = {}

        self.model.eval()

        # Placeholder: iterate over batches/layers and fill head_importance

        return head_importance

    def prune_heads(
        self,
        heads_to_prune: Dict[int, List[int]]
    ):
        """
        Prune specific heads.

        Args:
            heads_to_prune: {layer_idx: [head_indices]}
        """
        # Placeholder: remove each head's slices from the Q/K/V and output
        # projections. Many Hugging Face models provide
        # model.prune_heads(...) taking this same {layer: [heads]} format.
        pass


def demonstrate_pruning():
    """Demonstrate pruning."""
    print("\nPruning Demonstration")
    print("=" * 60)

    # Create a model layer
    layer = nn.Linear(256, 256)
    print(f"Original parameters: {layer.weight.numel():,}")
    print(f"Original sparsity: {(layer.weight == 0).sum().item() / layer.weight.numel():.2%}")

    # Apply pruning
    model = nn.Sequential(layer)
    pruner = MagnitudePruner(model, sparsity=0.5)
    stats = pruner.prune()

    print("\nAfter 50% pruning:")
    print(f"  Target sparsity: {stats['target_sparsity']:.2%}")
    print(f"  Actual sparsity: {stats['actual_sparsity']:.2%}")
    print(f"  Pruned parameters: {stats['pruned_params']:,}")


demonstrate_pruning()
```
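Both pruning styles fit in a few lines of pure Python. This minimal sketch works on the 3×5 matrix from the diagram. It shows global unstructured magnitude pruning, as `MagnitudePruner` does, and structured pruning that keeps the rows with the largest L2 norm. The L2-norm criterion is one common choice and is an assumption here. In PyTorch itself, `torch.nn.utils.prune` offers `l1_unstructured`, `global_unstructured` and `ln_structured` for these purposes.

```python
import math
from typing import Dict, List

Matrix = List[List[float]]


def global_magnitude_prune(layers: Dict[str, Matrix], sparsity: float) -> Dict[str, Matrix]:
    """Unstructured pruning: zero the `sparsity` fraction of smallest-|w| weights
    across all layers. Weights tied with the threshold are pruned too."""
    magnitudes = sorted(abs(w) for m in layers.values() for row in m for w in row)
    k = int(sparsity * len(magnitudes))  # number of weights to remove
    if k == 0:
        return {name: [row[:] for row in m] for name, m in layers.items()}
    threshold = magnitudes[k - 1]  # k-th smallest magnitude
    return {
        name: [[0.0 if abs(w) <= threshold else w for w in row] for row in m]
        for name, m in layers.items()
    }


def measured_sparsity(layers: Dict[str, Matrix]) -> float:
    """Fraction of weights that are exactly zero."""
    total = sum(len(row) for m in layers.values() for row in m)
    zeros = sum(w == 0.0 for m in layers.values() for row in m for w in row)
    return zeros / total


def rows_to_keep(matrix: Matrix, keep: int) -> List[int]:
    """Structured pruning: indices of the `keep` rows (output neurons) with largest L2 norm."""
    norms = [math.sqrt(sum(w * w for w in row)) for row in matrix]
    ranked = sorted(range(len(matrix)), key=lambda i: norms[i], reverse=True)
    return sorted(ranked[:keep])


if __name__ == "__main__":
    weights = {"layer": [
        [0.12, 0.85, 0.03, 0.91, 0.02],
        [0.78, 0.01, 0.65, 0.04, 0.89],
        [0.02, 0.94, 0.01, 0.87, 0.03],
    ]}
    pruned = global_magnitude_prune(weights, sparsity=0.5)
    print(pruned["layer"], f"sparsity={measured_sparsity(pruned):.2f}")

    # Structured: drop whole rows to get a smaller *dense* matrix. In a real
    # network the next layer's matching input columns must be removed too.
    kept = rows_to_keep(weights["layer"], keep=2)
    smaller = [weights["layer"][i] for i in kept]
    print("kept rows:", kept, "new shape:", (len(smaller), len(smaller[0])))
```

With `sparsity=0.5` the sketch removes int(0.5 × 15) = 7 of the 15 weights. That is one fewer than the hand-drawn diagram, which zeroes every weight below 0.5.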

Optimization Best Practices

Production Checklist

```python
def optimization_best_practices():
    """Summary of optimization best practices."""
    print("=" * 70)
    print("OPTIMIZATION BEST PRACTICES")
    print("=" * 70)

    print("""
    RECOMMENDED OPTIMIZATION PIPELINE:
    ──────────────────────────────────

    1. START WITH EVALUATION
       ░ Measure baseline latency
       ░ Measure baseline memory
       ░ Record quality metrics (BLEU)

    2. EASY WINS (Do these first)
       ░ Use FP16/BF16 inference
       ░ Enable Flash Attention
       ░ Implement KV caching
       ░ Optimize batch size

    3. QUANTIZATION (Most impactful)
       ░ Try dynamic INT8 first
       ░ If quality drops, try QAT
       ░ Benchmark INT4 if size critical

    4. MODEL ARCHITECTURE (If needed)
       ░ Prune attention heads
       ░ Reduce layers if acceptable
       ░ Consider distillation

    5. INFERENCE ENGINE (For production)
       ░ Export to ONNX
       ░ Use TensorRT (NVIDIA)
       ░ Consider vLLM for serving


    QUALITY VS SPEED TRADE-OFFS:
    ────────────────────────────

    ┌────────────────────┬─────────┬──────────┬─────────────────────┐
    │  Optimization      │  Speed  │  Quality │  Recommendation     │
    ├────────────────────┼─────────┼──────────┼─────────────────────┤
    │  FP32 → FP16       │  +50%   │  ~Same   │  Always do          │
    │  FP16 → INT8       │  +100%  │  -0.5%   │  Usually do         │
    │  INT8 → INT4       │  +50%   │  -2%     │  If size critical   │
    │  6L → 4L           │  +50%   │  -3%     │  Test carefully     │
    │  Distillation      │  +200%  │  -5%     │  For major savings  │
    │  KV Cache          │  +300%  │  Same    │  Always for gen     │
    │  Flash Attn        │  +200%  │  Same    │  Always if possible │
    └────────────────────┴─────────┴──────────┴─────────────────────┘


    COMMON MISTAKES TO AVOID:
    ─────────────────────────

    ✗ Optimizing before measuring baseline
    ✗ Applying all optimizations at once
    ✗ Ignoring quality metrics
    ✗ Not testing on representative data
    ✗ Forgetting to benchmark throughput (not just latency)


    MONITORING IN PRODUCTION:
    ─────────────────────────

    Track these metrics:
    - P50, P95, P99 latency
    - Throughput (requests/second)
    - GPU utilization
    - Memory usage
    - Quality metrics (periodic)
    - Error rates
    """)


optimization_best_practices()
```
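Measuring the baseline, and the P50/P95/P99 latencies listed under monitoring, needs only the standard library. This is a minimal sketch. `benchmark` and `percentile` are hypothetical helpers that use nearest-rank percentiles, and the lambda stands in for a real inference call.

```python
import math
import statistics
import time
from typing import Callable, Dict, List


def percentile(samples: List[float], pct: float) -> float:
    """Nearest-rank percentile, pct in (0, 100]."""
    ordered = sorted(samples)
    rank = max(1, math.ceil(pct / 100 * len(ordered)))
    return ordered[rank - 1]


def benchmark(fn: Callable[[], object], warmup: int = 5, iters: int = 50) -> Dict[str, float]:
    """Time repeated calls of `fn`; report latency percentiles (ms) and calls per second.

    For GPU work, `fn` must synchronize before returning (e.g. call
    torch.cuda.synchronize()), otherwise only the kernel launch is timed.
    """
    for _ in range(warmup):  # let caches, lazy initialization and autotuning settle
        fn()
    latencies_ms = []
    for _ in range(iters):
        start = time.perf_counter()
        fn()
        latencies_ms.append((time.perf_counter() - start) * 1000.0)
    total_s = sum(latencies_ms) / 1000.0
    return {
        "p50_ms": percentile(latencies_ms, 50),
        "p95_ms": percentile(latencies_ms, 95),
        "p99_ms": percentile(latencies_ms, 99),
        "mean_ms": statistics.mean(latencies_ms),
        "calls_per_s": iters / total_s if total_s > 0 else float("inf"),
    }


if __name__ == "__main__":
    # Hypothetical stand-in for one translation request.
    stats = benchmark(lambda: sorted(range(100_000), reverse=True))
    for name, value in stats.items():
        print(f"{name}: {value:.2f}")
```

Requests run one at a time here, so `calls_per_s` is simply 1 / mean latency. To measure throughput with batching, time a call on a full batch and multiply the result by the batch size.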

Summary

Optimization Techniques Summary

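| Technique       | Speedup | Memory saved | Quality impact | Complexity |
|-----------------|---------|--------------|----------------|------------|
| FP16            | 1.5-2x  | 50%          | None           | Low        |
| INT8            | 2-4x    | 75%          | <1%            | Medium     |
| Distillation    | 2-10x   | 50-90%       | 2-5%           | High       |
| Pruning         | 1.5-3x  | 50-90%       | 1-5%           | High       |
| Flash Attention | 2-4x    | 50-90%       | None           | Low        |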

Quick Reference

Priority order for optimization:

  • FP16/BF16 (free speedup)
  • Flash Attention (if available)
  • KV Caching (for generation)
  • Dynamic INT8 quantization
  • Model pruning/distillation (if needed)
  • TensorRT/ONNX (for deployment)
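KV caching, listed above, speeds up autoregressive generation. The keys and values of tokens already processed are stored, so each decoding step projects only the newest token and attends over the cache instead of recomputing the whole prefix. The sketch below is a toy single-head version in pure Python. `project` and `to_qkv` are hypothetical stand-ins for learned projection matrices. The code checks that cached decoding matches full recomputation.

```python
import math
from typing import List, Tuple

Vector = List[float]


def project(x: Vector, scale: float, shift: float) -> Vector:
    """Toy stand-in for a learned linear projection (hypothetical weights)."""
    return [scale * xi + shift * i for i, xi in enumerate(x)]


def to_qkv(x: Vector) -> Tuple[Vector, Vector, Vector]:
    """Query, key and value for one token."""
    return project(x, 1.0, 0.0), project(x, 0.5, 0.1), project(x, 2.0, -0.1)


def attend(query: Vector, keys: List[Vector], values: List[Vector]) -> Vector:
    """Scaled dot-product attention of one query over the given positions."""
    d = len(query)
    scores = [sum(q * k for q, k in zip(query, key)) / math.sqrt(d) for key in keys]
    top = max(scores)
    weights = [math.exp(s - top) for s in scores]
    total = sum(weights)
    return [sum(w * v[j] for w, v in zip(weights, values)) / total
            for j in range(len(values[0]))]


class KVCache:
    """Keeps keys/values of already-decoded tokens; each step projects only the new token."""

    def __init__(self) -> None:
        self.keys: List[Vector] = []
        self.values: List[Vector] = []

    def step(self, x: Vector) -> Vector:
        q, k, v = to_qkv(x)
        self.keys.append(k)      # cache grows linearly with sequence length
        self.values.append(v)
        return attend(q, self.keys, self.values)


def decode_without_cache(tokens: List[Vector]) -> List[Vector]:
    """Reference: re-project the whole prefix at every step (O(n^2) projections)."""
    outputs = []
    for t in range(1, len(tokens) + 1):
        qkv = [to_qkv(x) for x in tokens[:t]]
        keys = [k for _, k, _ in qkv]
        values = [v for _, _, v in qkv]
        outputs.append(attend(qkv[-1][0], keys, values))
    return outputs


if __name__ == "__main__":
    tokens = [[1.0, 0.0], [0.5, -1.0], [0.2, 0.3], [-0.4, 0.8]]
    cache = KVCache()
    cached = [cache.step(x) for x in tokens]
    full = decode_without_cache(tokens)
    print("max difference:", max(abs(a - b) for u, w in zip(cached, full) for a, b in zip(u, w)))
```

The cache trades memory for speed. In a real model it grows with sequence length, layers and heads, which is why the first table lists "+Memory" for KV caching.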

Exercises

  • Apply dynamic quantization to your translation model and measure speedup.
  • Implement magnitude pruning and find the maximum sparsity with <1 BLEU drop.
  • Train a distilled 4-layer model from your 6-layer model.
  • Compare FP32, FP16, and INT8 inference latency.
  • Profile your model to identify the slowest operations.
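The sparsity exercise can be organized as a simple sweep. This is a minimal sketch. `evaluate` is a hypothetical callback you would write yourself: prune a fresh copy of the model (for example with `MagnitudePruner`) and compute BLEU on a held-out set. The loop also assumes that quality only decreases as sparsity grows.

```python
from typing import Callable, Iterable, Optional


def max_sparsity_within_budget(
    evaluate: Callable[[float], float],
    sparsities: Iterable[float],
    max_drop: float = 1.0,
) -> Optional[float]:
    """Highest sparsity whose score stays less than `max_drop` below the dense baseline.

    `evaluate(s)` should prune a fresh copy of the model to sparsity `s`
    and return a quality score such as BLEU.
    """
    baseline = evaluate(0.0)
    best: Optional[float] = None
    for s in sorted(sparsities):
        if baseline - evaluate(s) < max_drop:
            best = s
        else:
            break  # assumes quality only gets worse as sparsity grows
    return best


def fake_bleu(sparsity: float) -> float:
    """Hypothetical BLEU curve, only to exercise the search."""
    return 30.0 - 10.0 * sparsity * sparsity


if __name__ == "__main__":
    print(max_sparsity_within_budget(fake_bleu, [0.1, 0.2, 0.3, 0.4, 0.5]))
```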

In the next section, we'll cover model export (ONNX, TensorRT) and serving infrastructure.
