Introduction
This section covers techniques to optimize transformer models for production deployment, including quantization, pruning, and knowledge distillation.
Why Optimization Matters
Production Constraints
```python
import torch
import torch.nn as nn
from typing import Dict, List, Optional, Tuple
import time


def production_requirements():
    """
    Overview of production requirements.
    """
    print("=" * 70)
    print("PRODUCTION DEPLOYMENT REQUIREMENTS")
    print("=" * 70)

    print("""
    TYPICAL PRODUCTION CONSTRAINTS:
    ═══════════════════════════════

    1. LATENCY
    ┌───────────────────────────────────────────────────────────┐
    │ User expectation: < 100ms response for interactive apps   │
    │ Translation service: < 500ms per sentence                 │
    │ Batch processing: Throughput > cost                       │
    └───────────────────────────────────────────────────────────┘

    2. MEMORY
    ┌───────────────────────────────────────────────────────────┐
    │ Cloud GPU: 16-80GB, expensive ($1-30/hour)                │
    │ Edge device: 4-16GB, limited                              │
    │ Mobile: < 4GB, very constrained                           │
    └───────────────────────────────────────────────────────────┘

    3. COST
    ┌───────────────────────────────────────────────────────────┐
    │ Compute: $/1M tokens (API pricing model)                  │
    │ Infrastructure: GPU hours × $/hour                        │
    │ Scaling: Cost per user × number of users                  │
    └───────────────────────────────────────────────────────────┘

    4. THROUGHPUT
    ┌───────────────────────────────────────────────────────────┐
    │ Requests per second per GPU                               │
    │ Concurrent users supported                                │
    │ Batch efficiency                                          │
    └───────────────────────────────────────────────────────────┘


    OUR TRANSLATION MODEL:
    ══════════════════════

    Baseline (FP32, unoptimized):
    ├── Model size: ~260 MB (65M params × 4 bytes)
    ├── Memory usage: ~1.5 GB (with activations)
    ├── Latency: ~200ms per sentence
    └── Throughput: ~5 sentences/second

    Target (optimized):
    ├── Model size: ~65 MB (INT8)
    ├── Memory usage: ~400 MB
    ├── Latency: ~50ms per sentence
    └── Throughput: ~20 sentences/second

    4x improvement possible!


    OPTIMIZATION TECHNIQUES:
    ════════════════════════

    ┌───────────────────┬─────────┬─────────┬───────────┬────────────┐
    │ Technique         │ Speedup │ Memory  │ Quality   │ Complexity │
    ├───────────────────┼─────────┼─────────┼───────────┼────────────┤
    │ FP16 (half prec.) │ 1.5-2x  │ 50%     │ Same      │ Low        │
    │ INT8 Quantization │ 2-4x    │ 75%     │ -0.5%     │ Medium     │
    │ INT4 Quantization │ 4-8x    │ 87%     │ -2%       │ Medium     │
    │ Pruning           │ 1.5-3x  │ 50-90%  │ -1 to -5% │ High       │
    │ Distillation      │ 2-10x   │ 50-90%  │ -2 to -5% │ High       │
    │ ONNX Export       │ 1.2-2x  │ Same    │ Same      │ Low        │
    │ TensorRT          │ 2-5x    │ ~Same   │ Same      │ Medium     │
    │ Flash Attention   │ 2-4x    │ 50-90%  │ Same      │ Low        │
    │ KV Caching        │ 2-10x   │ +Memory │ Same      │ Low        │
    └───────────────────┴─────────┴─────────┴───────────┴────────────┘
    """)


production_requirements()
```
Quantization
Reducing Precision for Faster Inference
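The size figures quoted earlier for the translation model (about 260 MB in FP32, about 65 MB in INT8) follow from parameter count times bytes per parameter. A minimal sketch of that arithmetic, assuming the 65M-parameter figure from the text and ignoring small overheads such as per-tensor scales and zero points; `weight_storage_mb` is a hypothetical helper name, not part of any library:

```python
# Bytes needed to store one parameter at each precision.
BYTES_PER_PARAM = {"fp32": 4.0, "fp16": 2.0, "int8": 1.0, "int4": 0.5}


def weight_storage_mb(num_params: int, precision: str) -> float:
    """Approximate weight storage in megabytes (1 MB = 1e6 bytes)."""
    return num_params * BYTES_PER_PARAM[precision] / 1e6


num_params = 65_000_000  # parameter count quoted for the translation model
for precision in BYTES_PER_PARAM:
    print(f"{precision}: {weight_storage_mb(num_params, precision):.1f} MB")
```

This covers weights only. Activation memory, which the baseline figures put well above the weight size, depends on batch size and sequence length and has to be measured.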
```python
class QuantizationExplained:
    """
    Quantization techniques for model optimization.
    """

    def __init__(self):
        pass

    def demonstrate_quantization_concepts(self):
        """Show quantization basics."""
        print("=" * 60)
        print("QUANTIZATION BASICS")
        print("=" * 60)

        print("""
        WHAT IS QUANTIZATION?
        ═════════════════════

        Converting floating point weights to lower precision integers.

        FP32 (32-bit float):
        ┌──────┬───────────────────┬──────────────────────────┐
        │ Sign │ Exponent (8 bits) │ Mantissa (23 bits)       │
        │ 1    │ 0 1 0 0 0 0 0 0   │ 1 1 0 0 0 0 0 0 ...      │
        └──────┴───────────────────┴──────────────────────────┘
        Range: ±3.4 × 10³⁸, Precision: ~7 decimal digits

        INT8 (8-bit integer):
        ┌─────────────────┐
        │ 0 1 0 1 0 1 1 0 │ = 86
        └─────────────────┘
        Range: -128 to 127 (signed) or 0 to 255 (unsigned), Step size: 1


        THE MAPPING:
        ════════════

        The formulas below use the unsigned range (256 levels, 0 to 255).

        Scale = (max - min) / 255
        Zero_point = round(-min / scale)

        Quantize:   q = round(x / scale) + zero_point
        Dequantize: x = (q - zero_point) × scale

        Example:
        FP32 weights: [-0.5, 0.2, 0.8, -0.3, 0.5]
        Range: -0.5 to 0.8
        Scale = (0.8 - (-0.5)) / 255 = 0.0051
        Zero_point = round(0.5 / 0.0051) = 98

        Quantized (INT8): [0, 137, 255, 39, 196]


        QUANTIZATION TYPES:
        ═══════════════════

        1. DYNAMIC QUANTIZATION
           - Weights quantized ahead of time, when the model is converted
           - Activations quantized on-the-fly
           - Easy to apply, good results
           - Best for: CPU inference, memory-bound

        2. STATIC QUANTIZATION
           - Both weights and activations pre-quantized
           - Requires calibration data
           - Faster inference
           - Best for: Edge devices, mobile

        3. QUANTIZATION-AWARE TRAINING (QAT)
           - Simulate quantization during training
           - Model learns to be robust to quantization
           - Best quality
           - Best for: When quality is critical
        """)

    def pytorch_quantization_example(self):
        """Show PyTorch quantization."""
        print("\nPYTORCH QUANTIZATION")
        print("=" * 60)

        code = '''
import io

import torch
from torch.quantization import quantize_dynamic

# ===================
# DYNAMIC QUANTIZATION
# ===================

# Original model (TransformerModel is your own model class)
model = TransformerModel()
model.load_state_dict(torch.load("model.pt"))
model.eval()

# Quantize (just one line!)
# nn.Embedding is left out of the set: in eager mode it expects a
# weight-only qconfig rather than the default dynamic one.
quantized_model = quantize_dynamic(
    model,
    {torch.nn.Linear},  # Layers to quantize
    dtype=torch.qint8
)

# Compare serialized sizes. Quantized Linear layers keep their weights in
# packed buffers, so summing over model.parameters() would miss them.
def serialized_size_mb(m):
    buffer = io.BytesIO()
    torch.save(m.state_dict(), buffer)
    return buffer.getbuffer().nbytes / 1e6

print(f"Original: {serialized_size_mb(model):.1f} MB")
print(f"Quantized: {serialized_size_mb(quantized_model):.1f} MB")

# Inference
with torch.no_grad():
    output = quantized_model(input_ids)


# ===================
# STATIC QUANTIZATION
# ===================

# Eager-mode static quantization expects the model to wrap its forward
# pass in QuantStub/DeQuantStub so that activations can be observed.

# Prepare model with observers
model.eval()
model.qconfig = torch.quantization.get_default_qconfig('fbgemm')
model_prepared = torch.quantization.prepare(model)

# Calibrate with representative data
with torch.no_grad():
    for batch in calibration_loader:
        model_prepared(batch)

# Convert to quantized model
model_quantized = torch.quantization.convert(model_prepared)


# =================================
# QUANTIZATION-AWARE TRAINING (QAT)
# =================================

# Prepare model for QAT
model.train()
model.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
model_qat = torch.quantization.prepare_qat(model)

# prepare_qat returns a copy, so the optimizer must track model_qat
optimizer = torch.optim.AdamW(model_qat.parameters())

# Train with fake quantization
for epoch in range(num_epochs):
    for batch, target in train_loader:
        optimizer.zero_grad()
        output = model_qat(batch)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

# Convert to quantized model
model_qat.eval()
model_quantized = torch.quantization.convert(model_qat)
'''
        print(code)
```
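The worked example in the mapping description (weights `[-0.5, 0.2, 0.8, -0.3, 0.5]` mapped to `[0, 137, 255, 39, 196]`) can be reproduced in a few lines of plain Python. This is a sketch of the same scale and zero-point formulas, not a library routine; `quantize_affine` and `dequantize_affine` are hypothetical names, and the code assumes the input contains at least two distinct values so the scale is non-zero:

```python
def quantize_affine(values, num_levels=256):
    """Map floats onto integers 0..num_levels-1 using a scale and zero point."""
    lo, hi = min(values), max(values)
    scale = (hi - lo) / (num_levels - 1)
    zero_point = round(-lo / scale)
    quantized = [
        min(max(round(v / scale) + zero_point, 0), num_levels - 1)
        for v in values
    ]
    return quantized, scale, zero_point


def dequantize_affine(quantized, scale, zero_point):
    """Recover approximate float values from the integer codes."""
    return [(q - zero_point) * scale for q in quantized]


weights = [-0.5, 0.2, 0.8, -0.3, 0.5]
q, scale, zero_point = quantize_affine(weights)
restored = dequantize_affine(q, scale, zero_point)
max_error = max(abs(w - r) for w, r in zip(weights, restored))

print(q)  # [0, 137, 255, 39, 196]
print(f"scale={scale:.4f}, zero_point={zero_point}, max_error={max_error:.5f}")
```

For values that are not clamped, the round-trip error is at most half a quantization step (`scale / 2`). This is why a narrow weight range quantizes well, while a single outlier that stretches the range costs precision for every other weight.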
```python
def demonstrate_quantization():
    """Demonstrate quantization impact."""
    print("\nQuantization Impact Demonstration")
    print("=" * 60)

    # Create a simple linear layer
    d_model = 512
    layer = nn.Linear(d_model, d_model)

    # Original weights
    original_weights = layer.weight.data.clone()

    # Simulate unsigned 8-bit (0..255) affine quantization
    w_min = original_weights.min()
    w_max = original_weights.max()
    scale = (w_max - w_min) / 255
    zero_point = torch.round(-w_min / scale).int()

    # Quantize
    w_quantized = torch.round(original_weights / scale) + zero_point
    w_quantized = w_quantized.clamp(0, 255).byte()

    # Dequantize
    w_dequantized = (w_quantized.float() - zero_point) * scale

    # Compute error
    error = (original_weights - w_dequantized).abs()
    relative_error = error / (original_weights.abs() + 1e-8)

    # Storage in bytes, derived from the tensors themselves
    fp32_bytes = original_weights.numel() * original_weights.element_size()
    int8_bytes = w_quantized.numel() * w_quantized.element_size()

    print("Weight statistics:")
    print(f"  Range: [{w_min.item():.4f}, {w_max.item():.4f}]")
    print(f"  Scale: {scale.item():.6f}")
    print(f"  Zero point: {zero_point.item()}")
    print("\nQuantization error:")
    print(f"  Max absolute error: {error.max().item():.6f}")
    print(f"  Mean absolute error: {error.mean().item():.6f}")
    print(f"  Mean relative error: {relative_error.mean().item() * 100:.2f}%")
    print("\nMemory:")
    print(f"  Original (FP32): {fp32_bytes:,} bytes")
    print(f"  Quantized (INT8): {int8_bytes:,} bytes")
    print(f"  Compression: {fp32_bytes / int8_bytes:.0f}x")


demonstrate_quantization()
```
Knowledge Distillation
Training Smaller Models from Larger Ones
```python
class KnowledgeDistillation:
    """
    Knowledge distillation for model compression.
    """

    def __init__(self):
        pass

    def explain_distillation(self):
        """Explain distillation concept."""
        print("=" * 60)
        print("KNOWLEDGE DISTILLATION")
        print("=" * 60)

        print("""
        THE IDEA:
        ═════════

        Train a small "student" model to mimic a large "teacher" model.

        Input (fed to both teacher and student)
                │
                ▼
        ┌────────────────┐
        │    Teacher     │  (Large, trained)
        │  610M params   │
        └───────┬────────┘
                │
                ▼
        Soft Labels (probabilities over vocab)
                │
                │  KL Divergence Loss
                │
                ▼
        ┌────────────────┐
        │    Student     │  (Small, training)
        │   65M params   │
        └───────┬────────┘
                │
                ▼
        Student Predictions


        WHY IT WORKS:
        ═════════════

        Teacher's soft labels contain more information than hard labels:

        Hard label: [0, 0, 1, 0, 0]  ("dog")

        Soft label: [0.01, 0.05, 0.80, 0.12, 0.02]
                     cat   bird  dog   wolf  fox

        The soft labels show:
        - "dog" is the answer (0.80)
        - "wolf" is somewhat similar (0.12)
        - "cat" is the least similar (0.01)

        This "dark knowledge" helps the student learn better!


        LOSS FUNCTION:
        ══════════════

        L = α × L_distill + (1-α) × L_task

        Where:
        - L_distill = KL(teacher_soft || student_soft) at temperature T
        - L_task = CrossEntropy(student, hard_labels)
        - α = weight (typically 0.5-0.9)
        - T = temperature (typically 2-10)

        Temperature T softens the distributions:
        - T=1: Original distribution
        - T>1: Softer (more uniform)
        - Higher T reveals more relationships


        DISTILLATION FOR TRANSLATION:
        ═════════════════════════════

        Teacher: mBART-large (610M params, BLEU ~45)
        Student: Our model (65M params)

        Options:
        1. Word-level distillation (soft labels per token)
        2. Sequence-level distillation (translate with teacher, train student)
        3. Hidden-state distillation (match intermediate representations)

        Expected results:
        - Without distillation: BLEU ~30-35
        - With distillation: BLEU ~38-42

        9-10x smaller, only 3-7 BLEU drop!
        """)
```
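To make the temperature and loss definitions concrete, here is a small pure-Python sketch that softens a teacher distribution at two temperatures and evaluates the combined loss for a single token. The logits are invented for illustration, and `softmax_with_temperature`, `kl_divergence`, and `distillation_loss` are hypothetical helper names chosen for this sketch:

```python
import math


def softmax_with_temperature(logits, temperature=1.0):
    """Softmax of logits / temperature, computed in a numerically stable way."""
    scaled = [x / temperature for x in logits]
    peak = max(scaled)
    exps = [math.exp(s - peak) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def kl_divergence(p, q):
    """KL(p || q) for two discrete distributions over the same classes."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


def distillation_loss(student_logits, teacher_logits, label,
                      temperature=4.0, alpha=0.7):
    """L = alpha * T^2 * KL(teacher || student) + (1 - alpha) * CE(student, label)."""
    teacher_soft = softmax_with_temperature(teacher_logits, temperature)
    student_soft = softmax_with_temperature(student_logits, temperature)
    distill = kl_divergence(teacher_soft, student_soft) * temperature ** 2
    task = -math.log(softmax_with_temperature(student_logits)[label])
    return alpha * distill + (1 - alpha) * task, distill, task


# Invented logits over five classes: cat, bird, dog, wolf, fox
teacher_logits = [0.5, 2.1, 4.9, 3.0, 1.2]
student_logits = [1.0, 1.5, 3.0, 2.5, 1.0]

for T in (1.0, 4.0):
    probs = softmax_with_temperature(teacher_logits, T)
    print(f"T={T}: {[round(p, 3) for p in probs]}")

total, distill, task = distillation_loss(student_logits, teacher_logits, label=2)
print(f"distill={distill:.4f}, task={task:.4f}, total={total:.4f}")
```

At T=4 the probability mass spreads away from the top class, which exposes the teacher's ranking of the wrong answers. The T² factor compensates for the soft-target gradients shrinking roughly as 1/T², so the two loss terms stay on a comparable scale.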
```python
class DistillationTrainer:
    """
    Trainer for knowledge distillation.
    """

    def __init__(
        self,
        teacher_model: nn.Module,
        student_model: nn.Module,
        temperature: float = 4.0,
        alpha: float = 0.7
    ):
        """
        Initialize distillation trainer.

        Args:
            teacher_model: Large pre-trained model
            student_model: Small model to train
            temperature: Softmax temperature
            alpha: Weight for distillation loss
        """
        self.teacher = teacher_model
        self.student = student_model
        self.temperature = temperature
        self.alpha = alpha

        # Freeze teacher
        for param in self.teacher.parameters():
            param.requires_grad = False
        self.teacher.eval()

    def distillation_loss(
        self,
        student_logits: torch.Tensor,
        teacher_logits: torch.Tensor,
        labels: torch.Tensor
    ) -> Tuple[torch.Tensor, Dict[str, float]]:
        """
        Compute distillation loss.

        Args:
            student_logits: Student output logits, shape (batch, seq, vocab)
            teacher_logits: Teacher output logits, same shape as the student's
            labels: Ground truth labels, with -100 marking ignored positions

        Returns:
            Total loss, loss components dict
        """
        T = self.temperature
        vocab_size = student_logits.size(-1)

        # Flatten to (num_tokens, vocab) and drop ignored positions so the
        # KL term is averaged over real tokens, like the cross-entropy below.
        student_flat = student_logits.reshape(-1, vocab_size)
        teacher_flat = teacher_logits.reshape(-1, vocab_size)
        flat_labels = labels.reshape(-1)
        keep = flat_labels != -100

        # Soft labels with temperature
        student_soft = torch.log_softmax(student_flat[keep] / T, dim=-1)
        teacher_soft = torch.softmax(teacher_flat[keep] / T, dim=-1)

        # KL divergence loss (distillation)
        distill_loss = nn.functional.kl_div(
            student_soft,
            teacher_soft,
            reduction='batchmean'
        ) * (T * T)  # Scale by T^2 as per Hinton et al.

        # Hard label loss (task)
        task_loss = nn.functional.cross_entropy(
            student_flat,
            flat_labels,
            ignore_index=-100
        )

        # Combined loss
        total_loss = self.alpha * distill_loss + (1 - self.alpha) * task_loss

        return total_loss, {
            'distill_loss': distill_loss.item(),
            'task_loss': task_loss.item(),
            'total_loss': total_loss.item()
        }

    def train_step(
        self,
        batch: Dict[str, torch.Tensor]
    ) -> Tuple[torch.Tensor, Dict[str, float]]:
        """
        Single training step with distillation.

        Args:
            batch: Input batch (must include 'labels')

        Returns:
            Loss tensor to backpropagate, loss values dict
        """
        self.student.train()

        # Get teacher predictions (no grad)
        with torch.no_grad():
            teacher_outputs = self.teacher(**batch)
            teacher_logits = teacher_outputs.logits

        # Get student predictions
        student_outputs = self.student(**batch)
        student_logits = student_outputs.logits

        # Compute loss
        loss, loss_dict = self.distillation_loss(
            student_logits,
            teacher_logits,
            batch['labels']
        )

        return loss, loss_dict
```
```python
def distillation_code_example():
    """Show distillation code."""
    print("\nDistillation Code Example")
    print("=" * 60)

    code = '''
# Complete distillation training loop

import torch
from transformers import MBartForConditionalGeneration
from your_model import TransformerModel  # placeholder for your own student model

# Load teacher (pre-trained mBART)
teacher = MBartForConditionalGeneration.from_pretrained(
    "facebook/mbart-large-50-many-to-many-mmt"
)
teacher.eval()
teacher.cuda()

# Initialize student (your smaller model).
# The KL term compares logits position by position, so the student must
# share the teacher's tokenizer and vocabulary size. Its forward pass must
# also accept the same batch keys and return an object with `.logits`.
student = TransformerModel(
    vocab_size=teacher.config.vocab_size,
    d_model=256,      # Smaller than teacher
    num_heads=4,
    num_layers=4,     # Fewer layers
    d_ff=512
)
student.cuda()

# Distillation trainer
trainer = DistillationTrainer(
    teacher_model=teacher,
    student_model=student,
    temperature=4.0,
    alpha=0.7
)

optimizer = torch.optim.AdamW(student.parameters(), lr=3e-4)

# Training loop
for epoch in range(10):
    for batch in train_loader:
        batch = {k: v.cuda() for k, v in batch.items()}

        loss, loss_dict = trainer.train_step(batch)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    print(f"Distill: {loss_dict['distill_loss']:.4f}, "
          f"Task: {loss_dict['task_loss']:.4f}")

# Save student
torch.save(student.state_dict(), "student_model.pt")
'''
    print(code)


distillation_code_example()
```
Pruning
Removing Unnecessary Parameters
```python
def pruning_explained():
    """Explain pruning techniques."""
    print("=" * 60)
    print("MODEL PRUNING")
    print("=" * 60)

    print("""
    WHAT IS PRUNING?
    ════════════════

    Remove weights that contribute little to the output.

    Before pruning:
    ┌────────────────────────────────────┐
    │  0.12   0.85   0.03   0.91   0.02  │
    │  0.78   0.01   0.65   0.04   0.89  │
    │  0.02   0.94   0.01   0.87   0.03  │
    └────────────────────────────────────┘

    After pruning (~50%):
    ┌────────────────────────────────────┐
    │  0.00   0.85   0.00   0.91   0.00  │
    │  0.78   0.00   0.65   0.00   0.89  │
    │  0.00   0.94   0.00   0.87   0.00  │
    └────────────────────────────────────┘


    PRUNING STRATEGIES:
    ═══════════════════

    1. MAGNITUDE PRUNING
       Remove weights with smallest absolute value.
       Simple and effective.

    2. STRUCTURED PRUNING
       Remove entire neurons, attention heads, or layers.
       Actually reduces computation (not just memory).

    3. MOVEMENT PRUNING
       Remove weights that move toward zero during training.
       Better for fine-tuning.


    UNSTRUCTURED VS STRUCTURED:
    ═══════════════════════════

    Unstructured (sparse):
    ┌─────────────────┐
    │ x 0 x 0 0 x 0 x │  Zeros scattered anywhere
    │ 0 x 0 x x 0 x 0 │  Memory savings only (with sparse storage)
    │ x 0 0 x 0 x 0 x │  Speedup needs sparse kernels or hardware
    └─────────────────┘

    Structured (dense):
    ┌─────────────────┐
    │ x x x x x x x x │  Remove entire rows/columns
    │ 0 0 0 0 0 0 0 0 │  Actual speedup on any hardware
    │ x x x x x x x x │  Smaller model
    └─────────────────┘


    PRUNING ATTENTION HEADS:
    ════════════════════════

    Not all attention heads are equally important!

    Head importance analysis:
    ┌────────────────────────────────────────┐
    │ Layer 1: [0.85, 0.12, 0.76, 0.05, ...] │
    │ Layer 2: [0.92, 0.88, 0.03, 0.79, ...] │
    │ Layer 3: [0.45, 0.67, 0.89, 0.21, ...] │
    └────────────────────────────────────────┘

    Can remove heads with low importance scores
    with minimal quality loss!
    """)
```
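The difference between zeroing weights and removing structure is easiest to see on a single layer. The following sketch performs a simple form of structured pruning on an `nn.Linear`: it ranks output neurons by the L2 norm of their weight rows and builds a smaller dense layer from the survivors. The ranking criterion and the helper name `drop_low_norm_neurons` are assumptions made for this illustration, not something prescribed by the text:

```python
from typing import Tuple

import torch
import torch.nn as nn


def drop_low_norm_neurons(layer: nn.Linear, keep_ratio: float) -> Tuple[nn.Linear, torch.Tensor]:
    """Return a smaller dense layer that keeps the output neurons with the largest L2 norm."""
    norms = layer.weight.detach().norm(p=2, dim=1)   # one norm per output neuron (row)
    n_keep = max(1, round(keep_ratio * layer.out_features))
    keep = torch.topk(norms, n_keep).indices.sort().values
    smaller = nn.Linear(layer.in_features, n_keep, bias=layer.bias is not None)
    with torch.no_grad():
        smaller.weight.copy_(layer.weight[keep])
        if layer.bias is not None:
            smaller.bias.copy_(layer.bias[keep])
    return smaller, keep


torch.manual_seed(0)
layer = nn.Linear(256, 256)
pruned, kept = drop_low_norm_neurons(layer, keep_ratio=0.5)

x = torch.randn(4, 256)
print(tuple(layer.weight.shape), "->", tuple(pruned.weight.shape))
print("outputs match on kept neurons:",
      torch.allclose(pruned(x), layer(x)[:, kept], atol=1e-5))
```

Because the result is an ordinary dense layer with half the rows, the matrix multiply does half the work on any hardware. In a full network, the layer that consumes this output would also need its input columns sliced with the same `kept` indices.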
```python
class MagnitudePruner:
    """
    Simple magnitude-based pruning.
    """

    def __init__(
        self,
        model: nn.Module,
        sparsity: float = 0.5
    ):
        """
        Initialize pruner.

        Args:
            model: Model to prune
            sparsity: Target sparsity (0.5 = 50% zeros)
        """
        self.model = model
        self.sparsity = sparsity

    def compute_threshold(self) -> float:
        """Compute the global magnitude threshold; weights at or below it are pruned."""
        all_weights = []
        for name, param in self.model.named_parameters():
            if 'weight' in name and param.dim() >= 2:
                all_weights.append(param.detach().abs().flatten())

        all_weights = torch.cat(all_weights)
        k = int(self.sparsity * all_weights.numel())
        if k == 0:
            return -1.0  # Nothing to prune: every magnitude exceeds -1

        # The k-th smallest magnitude. kthvalue is used instead of
        # torch.quantile, which can reject inputs with many millions
        # of elements.
        return torch.kthvalue(all_weights, k).values.item()

    def prune(self) -> Dict[str, float]:
        """
        Apply magnitude pruning.

        Returns:
            Pruning statistics
        """
        threshold = self.compute_threshold()

        total_params = 0
        pruned_params = 0

        with torch.no_grad():
            for name, param in self.model.named_parameters():
                if 'weight' in name and param.dim() >= 2:
                    # Keep only weights strictly above the threshold
                    mask = param.abs() > threshold
                    param.mul_(mask.to(param.dtype))

                    total_params += param.numel()
                    pruned_params += (~mask).sum().item()

        actual_sparsity = pruned_params / total_params

        return {
            'target_sparsity': self.sparsity,
            'actual_sparsity': actual_sparsity,
            'pruned_params': pruned_params,
            'total_params': total_params
        }
```
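PyTorch also ships pruning utilities in `torch.nn.utils.prune` that cover the same global magnitude criterion. The sketch below applies them to a small stand-in model (the layer sizes are arbitrary and chosen only for the demonstration); it is an alternative to the hand-written pruner, not a description of how that class works internally:

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

torch.manual_seed(0)
model = nn.Sequential(nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 32))

# Collect (module, parameter_name) pairs for every Linear weight.
to_prune = [(m, "weight") for m in model.modules() if isinstance(m, nn.Linear)]

# Zero the 50% of weights with the smallest magnitude across all listed layers.
prune.global_unstructured(to_prune, pruning_method=prune.L1Unstructured, amount=0.5)

# While pruning is active, each module computes weight = weight_orig * weight_mask.
# remove() makes the zeros permanent and restores a plain `weight` parameter.
for module, name in to_prune:
    prune.remove(module, name)

zeros = sum((m.weight == 0).sum().item() for m, _ in to_prune)
total = sum(m.weight.numel() for m, _ in to_prune)
print(f"sparsity: {zeros / total:.2%}")
```

Since the threshold is global, individual layers can end up more or less sparse than the overall target, which is usually what you want: layers with many small weights absorb more of the pruning.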
```python
class HeadPruner:
    """
    Prune attention heads based on importance.
    """

    def __init__(self, model: nn.Module):
        self.model = model

    def compute_head_importance(
        self,
        dataloader: torch.utils.data.DataLoader,
        num_batches: int = 100
    ) -> Dict[str, torch.Tensor]:
        """
        Compute importance scores for each attention head.

        Uses gradient-based importance.
        """
        head_importance = {}

        self.model.eval()

        # Would iterate through layers and compute importance
        # Simplified for demonstration

        return head_importance

    def prune_heads(
        self,
        heads_to_prune: Dict[int, List[int]]
    ):
        """
        Prune specific heads.

        Args:
            heads_to_prune: {layer_idx: [head_indices]}
        """
        # Would modify model to remove heads
        pass
```
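Once importance scores are available, choosing which heads to drop is a small bookkeeping step. The sketch below turns per-layer scores into the `{layer_idx: [head_indices]}` mapping that `prune_heads` expects. The threshold rule, the `min_heads_per_layer` safeguard, and the function name `select_heads_to_prune` are assumptions for illustration; the scores are the ones from the head importance table, with layers indexed from 0:

```python
from typing import Dict, List


def select_heads_to_prune(
    importance: Dict[int, List[float]],
    threshold: float,
    min_heads_per_layer: int = 1,
) -> Dict[int, List[int]]:
    """Pick heads scoring below `threshold`, always leaving a minimum number per layer."""
    heads_to_prune: Dict[int, List[int]] = {}
    for layer_idx, scores in importance.items():
        # Rank heads from least to most important.
        ranked = sorted(range(len(scores)), key=lambda h: scores[h])
        max_removable = max(0, len(scores) - min_heads_per_layer)
        candidates = [h for h in ranked if scores[h] < threshold][:max_removable]
        if candidates:
            heads_to_prune[layer_idx] = sorted(candidates)
    return heads_to_prune


importance = {
    0: [0.85, 0.12, 0.76, 0.05],
    1: [0.92, 0.88, 0.03, 0.79],
    2: [0.45, 0.67, 0.89, 0.21],
}
print(select_heads_to_prune(importance, threshold=0.25))
# {0: [1, 3], 1: [2], 2: [3]}
```

Keeping at least one head per layer avoids removing a layer's attention entirely by accident; dropping a whole layer is better treated as a separate, deliberate decision.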
```python
def demonstrate_pruning():
    """Demonstrate pruning."""
    print("\nPruning Demonstration")
    print("=" * 60)

    # Create a model layer
    layer = nn.Linear(256, 256)
    print(f"Original parameters: {layer.weight.numel():,}")
    print(f"Original sparsity: {(layer.weight == 0).sum().item() / layer.weight.numel():.2%}")

    # Apply pruning
    model = nn.Sequential(layer)
    pruner = MagnitudePruner(model, sparsity=0.5)
    stats = pruner.prune()

    print("\nAfter 50% pruning:")
    print(f"  Target sparsity: {stats['target_sparsity']:.2%}")
    print(f"  Actual sparsity: {stats['actual_sparsity']:.2%}")
    print(f"  Pruned parameters: {stats['pruned_params']:,}")


demonstrate_pruning()
```
Optimization Best Practices
Production Checklist
```python
def optimization_best_practices():
    """Summary of optimization best practices."""
    print("=" * 70)
    print("OPTIMIZATION BEST PRACTICES")
    print("=" * 70)

    print("""
    RECOMMENDED OPTIMIZATION PIPELINE:
    ══════════════════════════════════

    1. START WITH EVALUATION
       □ Measure baseline latency
       □ Measure baseline memory
       □ Record quality metrics (BLEU)

    2. EASY WINS (Do these first)
       □ Use FP16/BF16 inference
       □ Enable Flash Attention
       □ Implement KV caching
       □ Optimize batch size

    3. QUANTIZATION (Most impactful)
       □ Try dynamic INT8 first
       □ If quality drops, try QAT
       □ Benchmark INT4 if size critical

    4. MODEL ARCHITECTURE (If needed)
       □ Prune attention heads
       □ Reduce layers if acceptable
       □ Consider distillation

    5. INFERENCE ENGINE (For production)
       □ Export to ONNX
       □ Use TensorRT (NVIDIA)
       □ Consider vLLM for serving


    QUALITY VS SPEED TRADE-OFFS:
    ════════════════════════════

    ┌────────────────┬─────────┬─────────┬───────────────────────┐
    │ Optimization   │ Speed   │ Quality │ Recommendation        │
    ├────────────────┼─────────┼─────────┼───────────────────────┤
    │ FP32 → FP16    │ +50%    │ ~Same   │ Always do             │
    │ FP16 → INT8    │ +100%   │ -0.5%   │ Usually do            │
    │ INT8 → INT4    │ +50%    │ -2%     │ If size critical      │
    │ 6 → 4 layers   │ +50%    │ -3%     │ Test carefully        │
    │ Distillation   │ +200%   │ -5%     │ For major savings     │
    │ KV Cache       │ +300%   │ Same    │ Always for generation │
    │ Flash Attn     │ +200%   │ Same    │ Always if possible    │
    └────────────────┴─────────┴─────────┴───────────────────────┘


    COMMON MISTAKES TO AVOID:
    ═════════════════════════

    ✗ Optimizing before measuring baseline
    ✗ Applying all optimizations at once
    ✗ Ignoring quality metrics
    ✗ Not testing on representative data
    ✗ Forgetting to benchmark throughput (not just latency)


    MONITORING IN PRODUCTION:
    ═════════════════════════

    Track these metrics:
    - P50, P95, P99 latency
    - Throughput (requests/second)
    - GPU utilization
    - Memory usage
    - Quality metrics (periodic)
    - Error rates
    """)


optimization_best_practices()
```
Summary
Optimization Techniques Summary
| Technique | Speedup | Memory Reduction | Quality Impact | Complexity |
|---|---|---|---|---|
| FP16 | 1.5-2x | 50% | None | Low |
| INT8 | 2-4x | 75% | <1% | Medium |
| Distillation | 2-10x | 50-90% | 2-5% | High |
| Pruning | 1.5-3x | 50-90% | 1-5% | High |
| Flash Attention | 2-4x | 50-90% | None | Low |
Quick Reference
Priority order for optimization:
- FP16/BF16 (free speedup)
- Flash Attention (if available)
- KV Caching (for generation)
- Dynamic INT8 quantization
- Model pruning/distillation (if needed)
- TensorRT/ONNX (for deployment)
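The first item on the list usually needs no change to the model itself. As a rough sketch, PyTorch's autocast context runs eligible operations (such as linear layers) in a lower precision while the stored weights stay in FP32. The example uses a toy two-layer model and BF16 on CPU so that it runs without a GPU; the model and tensor sizes are arbitrary choices for the demonstration:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Sequential(nn.Linear(512, 512), nn.ReLU(), nn.Linear(512, 512)).eval()
x = torch.randn(8, 512)

with torch.no_grad():
    reference = model(x)  # plain FP32 forward pass
    # Inside autocast, eligible ops such as linear layers run in bfloat16.
    with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
        low_precision = model(x)

max_diff = (reference - low_precision.float()).abs().max().item()
print(low_precision.dtype, f"max abs difference vs FP32: {max_diff:.4f}")
```

On an NVIDIA GPU the same pattern applies with `device_type="cuda"` and either `torch.float16` or `torch.bfloat16`. Comparing outputs against the FP32 reference, as done here, is a cheap first check before running the full quality evaluation.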
Exercises
- Apply dynamic quantization to your translation model and measure speedup.
- Implement magnitude pruning and find the maximum sparsity with <1 BLEU drop.
- Train a distilled 4-layer model from your 6-layer model.
- Compare FP32, FP16, and INT8 inference latency.
- Profile your model to identify the slowest operations.
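For the latency comparison and profiling exercises, a small timing harness helps keep measurements consistent between model variants. The sketch below uses only the standard library and reports the percentiles recommended for monitoring; `benchmark_latency` is a hypothetical helper, and `fake_inference` is a stand-in workload to be replaced by a call to your model:

```python
import statistics
import time
from typing import Callable, Dict


def benchmark_latency(fn: Callable[[], object], warmup: int = 5, runs: int = 50) -> Dict[str, float]:
    """Time repeated calls to `fn` and summarize latency in milliseconds."""
    for _ in range(warmup):  # warm-up calls are not measured
        fn()
    samples_ms = []
    for _ in range(runs):
        start = time.perf_counter()
        fn()
        samples_ms.append((time.perf_counter() - start) * 1000.0)
    # 99 cut points: index 49 is the median, 94 is P95, 98 is P99.
    cuts = statistics.quantiles(samples_ms, n=100, method="inclusive")
    return {
        "p50_ms": cuts[49],
        "p95_ms": cuts[94],
        "p99_ms": cuts[98],
        "mean_ms": statistics.fmean(samples_ms),
        "throughput_per_s": 1000.0 * runs / sum(samples_ms),
    }


def fake_inference() -> None:
    # Stand-in workload; replace with a forward pass of the model under test.
    sum(i * i for i in range(20_000))


stats = benchmark_latency(fake_inference)
for name, value in stats.items():
    print(f"{name}: {value:.3f}")
```

When timing GPU inference, call `torch.cuda.synchronize()` inside `fn` after the forward pass. CUDA kernels launch asynchronously, so without the synchronization the timer stops before the work has finished.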
In the next section, we'll cover model export (ONNX, TensorRT) and serving infrastructure.