Chapter 16

Quantization and Efficiency

Optimization and Deployment

Learning Objectives

By the end of this section, you will be able to:

  1. Understand numerical precision formats (FP32, FP16, BF16, INT8, INT4) and their trade-offs
  2. Apply FP16/BF16 inference to halve memory usage with minimal quality loss
  3. Implement INT8 quantization using post-training quantization (PTQ) and quantization-aware training (QAT)
  4. Apply model pruning techniques to reduce model size and computation
  5. Export models to ONNX for deployment on various hardware platforms

The Memory Bottleneck

Diffusion models like Stable Diffusion XL have billions of parameters, creating significant memory challenges. A single SDXL model in full precision requires:

| Component | Parameters | FP32 Size | FP16 Size |
|---|---|---|---|
| U-Net | 2.6B | 10.4 GB | 5.2 GB |
| VAE | 84M | 336 MB | 168 MB |
| Text Encoder (CLIP) | 123M | 492 MB | 246 MB |
| Text Encoder (OpenCLIP) | 694M | 2.8 GB | 1.4 GB |
| Total | ~3.5B | ~14 GB | ~7 GB |

The Challenge: Running SDXL at full precision requires 14+ GB of VRAM just for model weights, before accounting for activations, gradients, or batch processing. This exceeds the capacity of most consumer GPUs.

Memory usage during inference includes:

  • Model weights: The parameters stored in GPU memory
  • Activations: Intermediate tensors during forward pass (scales with batch size and resolution)
  • Optimizer states: Not needed for inference, but 2-3x weight size for training
  • KV cache: For attention computation (grows with sequence length)
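The weight sizes in the table above are just parameter count times bytes per element; a quick sketch of the arithmetic (using 1 GB = 10^9 bytes, as the table does):

```python
def weight_memory_gb(num_params: float, bytes_per_param: int) -> float:
    """Memory for model weights alone, in GB (1 GB = 1e9 bytes)."""
    return num_params * bytes_per_param / 1e9

# SDXL U-Net: 2.6B parameters
unet_params = 2.6e9
print(weight_memory_gb(unet_params, 4))  # FP32: 4 bytes/param -> 10.4 GB
print(weight_memory_gb(unet_params, 2))  # FP16/BF16: 2 bytes/param -> 5.2 GB
print(weight_memory_gb(unet_params, 1))  # INT8: 1 byte/param -> 2.6 GB
```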

Precision Formats

Modern deep learning uses various numerical formats, each with different trade-offs between precision, memory, and computational efficiency:

| Format | Bits | Range | Precision | Use Case |
|---|---|---|---|---|
| FP32 | 32 | ~10^38 | ~7 decimal digits | Training (default) |
| TF32 | 19 | ~10^38 | ~3 decimal digits | Ampere+ training |
| FP16 | 16 | ~65504 | ~3 decimal digits | Inference, mixed training |
| BF16 | 16 | ~10^38 | ~2 decimal digits | Training, inference |
| INT8 | 8 | -128 to 127 | Integer only | Quantized inference |
| INT4 | 4 | -8 to 7 | Integer only | Extreme compression |
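A concrete way to see how the 16-bit formats divide their bits: BF16 is literally the top 16 bits of the FP32 pattern (1 sign, 8 exponent, 7 mantissa bits), while FP16 re-divides its 16 bits into 1/5/10. A small sketch using only the standard library:

```python
import struct

def float_to_bits(x: float) -> str:
    """IEEE-754 single-precision (FP32) bit pattern of x."""
    (i,) = struct.unpack(">I", struct.pack(">f", x))
    return f"{i:032b}"

bits = float_to_bits(3.14159265)
sign, exponent, mantissa = bits[0], bits[1:9], bits[9:]
print(sign, exponent, mantissa)

# BF16 keeps the sign, all 8 exponent bits, and the top 7 mantissa bits:
# truncating FP32 to its top 16 bits gives the BF16 representation.
bf16_bits = bits[:16]
print(bf16_bits)
# FP16 instead uses 1 sign / 5 exponent / 10 mantissa bits,
# trading range (fewer exponent bits) for a little more precision.
```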

Understanding the Formats

```python
import torch

# FP32: Standard single precision
fp32_val = torch.tensor(3.14159265358979, dtype=torch.float32)
print(f"FP32: {fp32_val.item():.15f}")  # 3.141592741012573

# FP16: Half precision - smaller range, fewer mantissa bits
fp16_val = torch.tensor(3.14159265358979, dtype=torch.float16)
print(f"FP16: {fp16_val.item():.15f}")  # 3.140625000000000

# BF16: Brain Float - FP32 range, reduced precision
bf16_val = torch.tensor(3.14159265358979, dtype=torch.bfloat16)
print(f"BF16: {bf16_val.item():.15f}")  # 3.140625000000000

# Check memory reduction
x_fp32 = torch.randn(1000, 1000, dtype=torch.float32)
x_fp16 = x_fp32.half()
x_bf16 = x_fp32.bfloat16()

print("\nMemory usage:")
print(f"FP32: {x_fp32.element_size() * x_fp32.numel() / 1e6:.1f} MB")
print(f"FP16: {x_fp16.element_size() * x_fp16.numel() / 1e6:.1f} MB")
print(f"BF16: {x_bf16.element_size() * x_bf16.numel() / 1e6:.1f} MB")
```

FP16 vs BF16

FP16 has more precision bits than BF16 but a much narrower range, so it can overflow or underflow with large values. BF16 has the same range as FP32 but lower precision, making it more stable for training. For inference, both work well in most cases.
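The range difference is easy to demonstrate directly: FP16's largest finite value is 65504, so anything beyond that overflows to infinity, while BF16 keeps FP32's exponent range:

```python
import torch

big = 70000.0  # above FP16's maximum finite value of 65504

fp16 = torch.tensor(big, dtype=torch.float16)
bf16 = torch.tensor(big, dtype=torch.bfloat16)

print(fp16)  # inf (overflow)
print(bf16)  # finite, though coarsely rounded

# The flip side: BF16 rounds more coarsely than FP16 near 1.0
print(torch.tensor(1.001, dtype=torch.float16))   # ~1.0010
print(torch.tensor(1.001, dtype=torch.bfloat16))  # 1.0 (nearest BF16 value)
```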

FP16 and BF16 Inference

The simplest efficiency improvement is switching to half-precision inference. This cuts memory usage in half with typically negligible quality loss.

```python
from diffusers import StableDiffusionXLPipeline
import torch

def get_gpu_memory():
    return torch.cuda.memory_allocated() / 1e9

# Load in FP16 (recommended for most cases)
pipe_fp16 = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float16,
    variant="fp16",  # Load pre-converted FP16 weights
    use_safetensors=True,
)
pipe_fp16.to("cuda")
print(f"FP16 memory: {get_gpu_memory():.2f} GB")

# For BF16 (better on newer GPUs like A100, H100)
pipe_bf16 = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.bfloat16,
)
pipe_bf16.to("cuda")

# Enable memory-efficient attention (works with FP16/BF16).
# PyTorch 2.0+ already uses scaled_dot_product_attention by default;
# xformers is an alternative, and attention slicing is a lower-memory
# fallback for older setups:
pipe_fp16.enable_xformers_memory_efficient_attention()
# pipe_fp16.enable_attention_slicing()

# Generate image
image = pipe_fp16(
    "A majestic eagle soaring over mountains",
    num_inference_steps=30,
).images[0]
```

Automatic Mixed Precision (AMP)

For custom models, use PyTorch's AMP to automatically manage precision:

```python
import torch
from torch.amp import autocast  # torch.cuda.amp.autocast is deprecated

class DiffusionInference:
    def __init__(self, model, scheduler):
        self.model = model.half().cuda()  # Convert model weights to FP16
        self.scheduler = scheduler

    @torch.no_grad()
    def sample(self, noise, timesteps, condition=None):
        x = noise

        for t in timesteps:
            # Autocast selects per-op precision automatically
            with autocast("cuda", dtype=torch.float16):
                # Model forward pass in FP16
                noise_pred = self.model(x, t, condition)

            # Scheduler step in FP32 for numerical stability
            with autocast("cuda", enabled=False):
                x = self.scheduler.step(
                    noise_pred.float(), t, x.float()
                ).prev_sample.half()

        return x

# For critical operations, disable autocast
def compute_loss(pred, target):
    with autocast("cuda", enabled=False):
        # Loss computation in FP32 for numerical stability
        return torch.nn.functional.mse_loss(pred.float(), target.float())
```

INT8 Quantization

INT8 quantization represents weights and activations using 8-bit integers instead of floating point, reducing memory by 4x compared to FP32 (2x compared to FP16).

Quantization Basics

The quantization formula maps floating-point values to integers:

$$q = \mathrm{round}\left(\frac{x}{s}\right) + z$$

where $s$ is the scale factor and $z$ is the zero-point. Dequantization reverses this: $x = s(q - z)$.

```python
import torch

def quantize_tensor(x, num_bits=8):
    """Symmetric quantization to INT8."""
    # Compute scale based on max absolute value
    max_val = x.abs().max()
    scale = max_val / (2**(num_bits-1) - 1)

    # Quantize
    q = torch.round(x / scale).clamp(
        -(2**(num_bits-1)), 2**(num_bits-1) - 1
    ).to(torch.int8)

    return q, scale

def dequantize_tensor(q, scale):
    """Dequantize INT8 back to float."""
    return q.float() * scale

# Example
x = torch.randn(1000, 1000)
q, scale = quantize_tensor(x)
x_reconstructed = dequantize_tensor(q, scale)

# Measure error
mse = torch.nn.functional.mse_loss(x, x_reconstructed)
print(f"Quantization MSE: {mse.item():.6f}")
print(f"Memory reduction: {x.element_size() / q.element_size()}x")
```
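The code above is symmetric quantization (zero-point $z = 0$). The full affine scheme from the formula, useful for skewed distributions such as post-ReLU activations, can be sketched as:

```python
import torch

def quantize_affine(x, num_bits=8):
    """Asymmetric (affine) quantization: q = round(x/s) + z."""
    qmin, qmax = -(2 ** (num_bits - 1)), 2 ** (num_bits - 1) - 1
    x_min, x_max = x.min(), x.max()
    # Scale maps the full observed range onto the integer range
    scale = (x_max - x_min) / (qmax - qmin)
    # Zero-point shifts so that x_min maps to qmin
    zero_point = (qmin - torch.round(x_min / scale)).clamp(qmin, qmax)
    q = (torch.round(x / scale) + zero_point).clamp(qmin, qmax).to(torch.int8)
    return q, scale, zero_point

def dequantize_affine(q, scale, zero_point):
    return scale * (q.float() - zero_point)

# Skewed, all-positive data (like ReLU activations)
torch.manual_seed(0)
x = torch.rand(1000) * 6.0
q, s, z = quantize_affine(x)
x_hat = dequantize_affine(q, s, z)
print(torch.nn.functional.mse_loss(x, x_hat))  # small reconstruction error
```

For all-positive data, symmetric quantization would waste the entire negative half of the integer range; the zero-point recovers it.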

Post-Training Quantization (PTQ)

PTQ quantizes a pre-trained model without retraining, using calibration data to determine optimal scale factors:

```python
from optimum.quanto import quantize, freeze, qint8, Calibration
from diffusers import StableDiffusionPipeline
import torch

# Load the model
pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    torch_dtype=torch.float16,
)
pipe.to("cuda")

# Quantize the U-Net to INT8
quantize(pipe.unet, weights=qint8, activations=qint8)

# Calibrate with sample inputs (required for activation quantization;
# quanto records activation ranges inside the Calibration context)
calibration_prompts = [
    "A photo of a cat",
    "A beautiful landscape",
    "A portrait of a person",
    "Abstract art with vibrant colors",
]

print("Calibrating quantized model...")
with Calibration(), torch.no_grad():
    for prompt in calibration_prompts:
        _ = pipe(prompt, num_inference_steps=10)

# Freeze the quantization parameters
freeze(pipe.unet)

# Now inference uses INT8
image = pipe(
    "A cyberpunk city at night",
    num_inference_steps=30,
).images[0]

# Check memory usage
print(f"Quantized model memory: {torch.cuda.memory_allocated()/1e9:.2f} GB")
```

Using bitsandbytes for 8-bit

```python
from diffusers import (
    StableDiffusionXLPipeline,
    UNet2DConditionModel,
    BitsAndBytesConfig,
)
import torch

model_id = "stabilityai/stable-diffusion-xl-base-1.0"

# diffusers applies bitsandbytes quantization per model, not per pipeline:
# quantize the U-Net (the largest component), then build the pipeline around it
quantization_config = BitsAndBytesConfig(load_in_8bit=True)

unet = UNet2DConditionModel.from_pretrained(
    model_id,
    subfolder="unet",
    quantization_config=quantization_config,
    torch_dtype=torch.float16,
)

pipe = StableDiffusionXLPipeline.from_pretrained(
    model_id, unet=unet, torch_dtype=torch.float16
)
pipe.enable_model_cpu_offload()  # quantized modules cannot be moved with .to()

# Memory footprint is significantly reduced
print(f"8-bit memory usage: {torch.cuda.memory_allocated()/1e9:.2f} GB")

# Generate normally
image = pipe(
    "A serene Japanese garden in autumn",
    num_inference_steps=30,
).images[0]
```

INT4 and Extreme Quantization

INT4 quantization pushes compression further, using only 4 bits per weight. This requires careful handling to maintain quality.

```python
from diffusers import (
    StableDiffusionXLPipeline,
    UNet2DConditionModel,
    BitsAndBytesConfig,
)
import torch

model_id = "stabilityai/stable-diffusion-xl-base-1.0"

# 4-bit quantization with NF4 (NormalFloat 4)
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",  # NF4 is optimized for normally distributed weights
    bnb_4bit_use_double_quant=True,  # Double quantization for better compression
    bnb_4bit_compute_dtype=torch.float16,  # Compute in FP16
)

# As with 8-bit, quantization is applied to the model, not the pipeline
unet = UNet2DConditionModel.from_pretrained(
    model_id,
    subfolder="unet",
    quantization_config=quantization_config,
    torch_dtype=torch.float16,
)

pipe = StableDiffusionXLPipeline.from_pretrained(
    model_id, unet=unet, torch_dtype=torch.float16
)
pipe.enable_model_cpu_offload()

print(f"4-bit memory usage: {torch.cuda.memory_allocated()/1e9:.2f} GB")

# Quality may be slightly reduced
image = pipe(
    "A detailed oil painting of a sailing ship",
    num_inference_steps=30,
).images[0]
```

Comparison of Quantization Levels

| Precision | Memory | Speed | Quality | Best For |
|---|---|---|---|---|
| FP32 | 14 GB | 1.0x | 100% | Training, debugging |
| FP16/BF16 | 7 GB | 1.2-1.5x | ~99.9% | Production default |
| INT8 | 3.5 GB | 1.5-2x | ~99% | Memory-constrained |
| INT4 | 1.8 GB | 2-3x | ~95-98% | Consumer GPUs |
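The quality column's downward trend with bit width can be reproduced on synthetic weights with the symmetric quantizer from earlier (redefined here so the snippet is self-contained):

```python
import torch

def quantize_roundtrip(x, num_bits):
    """Symmetric quantize/dequantize; returns the reconstruction."""
    scale = x.abs().max() / (2 ** (num_bits - 1) - 1)
    q = torch.round(x / scale).clamp(
        -(2 ** (num_bits - 1)), 2 ** (num_bits - 1) - 1
    )
    return q * scale

torch.manual_seed(0)
w = torch.randn(512, 512)  # stand-in for a weight matrix

for bits in (8, 6, 4, 2):
    err = torch.nn.functional.mse_loss(w, quantize_roundtrip(w, bits))
    print(f"INT{bits}: MSE = {err.item():.6f}")
# Error grows roughly 16x per 2 bits removed: each bit dropped
# doubles the quantization step, and MSE scales with step squared.
```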

Model Pruning

Pruning removes unnecessary weights or entire network components to reduce model size and computation. There are several approaches:

Unstructured Pruning

Removes individual weights below a threshold, creating sparse matrices:

```python
import torch
import torch.nn.utils.prune as prune

def prune_model(model, amount=0.3):
    """Apply unstructured L1 pruning to all linear layers."""
    for name, module in model.named_modules():
        if isinstance(module, torch.nn.Linear):
            prune.l1_unstructured(module, name='weight', amount=amount)
            prune.remove(module, 'weight')  # Make pruning permanent

    return model

def count_parameters(model):
    """Count non-zero parameters."""
    total = sum(p.numel() for p in model.parameters())
    nonzero = sum((p != 0).sum().item() for p in model.parameters())
    return total, nonzero, nonzero/total

# Example: prune a U-Net
unet = ...  # Your U-Net model
total, nonzero, ratio = count_parameters(unet)
print(f"Before pruning: {total:,} params, {ratio:.1%} non-zero")

pruned_unet = prune_model(unet, amount=0.3)
total, nonzero, ratio = count_parameters(pruned_unet)
print(f"After pruning: {total:,} params, {ratio:.1%} non-zero")
```

Structured Pruning

Removes entire channels or attention heads, enabling actual speedups without special sparse hardware:

```python
import torch
import torch.nn.utils.prune as prune

def structured_prune_conv(module, amount=0.2):
    """Prune entire output channels based on L2 norm."""
    prune.ln_structured(
        module,
        name='weight',
        amount=amount,
        n=2,  # L2 norm
        dim=0,  # Prune output channels
    )
    return module

def prune_attention_heads(attention_layer, heads_to_prune):
    """Zero out specific attention heads.

    Attribute names (num_heads, head_dim, q_proj/k_proj/v_proj) vary by
    implementation; adapt them to your architecture. Masking zeroes the
    heads' weights but does not shrink the layer - actually removing
    heads requires resizing the projection matrices.
    """
    num_heads = attention_layer.num_heads
    head_dim = attention_layer.head_dim

    # Create mask for heads to keep
    mask = torch.ones(num_heads)
    mask[heads_to_prune] = 0

    # Apply to Q, K, V projections
    with torch.no_grad():
        for proj in [attention_layer.q_proj, attention_layer.k_proj, attention_layer.v_proj]:
            weight = proj.weight.view(num_heads, head_dim, -1)
            weight = weight * mask.view(-1, 1, 1)
            proj.weight.data = weight.view(-1, weight.shape[-1])

    return attention_layer
```

Knowledge Distillation for Pruned Models

After pruning, fine-tune with knowledge distillation to recover quality:

```python
import torch

class PrunedModelDistillation:
    """Recover quality of pruned model through distillation."""

    def __init__(self, teacher_model, pruned_student):
        self.teacher = teacher_model.eval()
        self.student = pruned_student

    def distillation_step(self, x, t, condition):
        # Teacher output
        with torch.no_grad():
            teacher_output = self.teacher(x, t, condition)

        # Student output
        student_output = self.student(x, t, condition)

        # Distillation loss: match the teacher's predictions
        loss = torch.nn.functional.mse_loss(student_output, teacher_output)

        # Optional: add original denoising loss
        noise = ...  # Original noise target
        noise_loss = torch.nn.functional.mse_loss(student_output, noise)

        return 0.5 * loss + 0.5 * noise_loss
```
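The class above assumes full diffusion models, but the training loop itself can be exercised end to end with toy stand-ins (`TinyNet` here is purely illustrative, and the denoising-loss term is omitted for brevity):

```python
import torch
import torch.nn as nn

class TinyNet(nn.Module):
    """Toy stand-in for a denoiser (ignores t and condition)."""
    def __init__(self, dim=16):
        super().__init__()
        self.net = nn.Linear(dim, dim)

    def forward(self, x, t, condition=None):
        return self.net(x)

torch.manual_seed(0)
teacher = TinyNet().eval()   # stands in for the original model
student = TinyNet()          # stands in for the pruned copy
optimizer = torch.optim.Adam(student.parameters(), lr=1e-3)

# One fixed batch keeps the sketch deterministic
x = torch.randn(64, 16)
t = torch.zeros(64)
with torch.no_grad():
    target = teacher(x, t)
    initial_loss = nn.functional.mse_loss(student(x, t), target).item()

for step in range(200):
    loss = nn.functional.mse_loss(student(x, t), target)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

print(f"distillation loss: {initial_loss:.4f} -> {loss.item():.4f}")
```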

ONNX Export and Optimization

ONNX (Open Neural Network Exchange) enables deploying models on various hardware platforms and inference engines like TensorRT, OpenVINO, and Core ML.

```python
from optimum.exporters.onnx import main_export
from optimum.onnxruntime import ORTStableDiffusionPipeline

# Export Stable Diffusion to ONNX
main_export(
    model_name_or_path="runwayml/stable-diffusion-v1-5",
    output="sd-onnx/",
    task="stable-diffusion",
    opset=17,  # ONNX opset version
    fp16=True,  # Export in FP16 (requires a CUDA device)
    device="cuda",
)

# Load and run with ONNX Runtime
pipe = ORTStableDiffusionPipeline.from_pretrained(
    "sd-onnx/",
    provider="CUDAExecutionProvider",  # Or "TensorrtExecutionProvider"
)

image = pipe(
    "A beautiful sunset over the ocean",
    num_inference_steps=30,
).images[0]
```

TensorRT Optimization

For NVIDIA GPUs, TensorRT provides the best inference performance:

```python
# Install: pip install tensorrt torch-tensorrt

import torch
import torch_tensorrt
from diffusers import StableDiffusionPipeline

# Convert a PyTorch model to TensorRT
def optimize_for_tensorrt(model, sample_input):
    """Compile model with TensorRT."""
    optimized_model = torch_tensorrt.compile(
        model,
        inputs=[sample_input],
        enabled_precisions={torch.half},  # FP16 inference
        workspace_size=1 << 30,  # 1GB workspace
        min_block_size=3,  # Minimum ops per TensorRT subgraph
        truncate_long_and_double=True,
    )
    return optimized_model

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    torch_dtype=torch.float16,
)
pipe.to("cuda")

# Diffusers has no built-in TensorRT switch; NVIDIA ships
# TensorRT-accelerated Stable Diffusion pipelines separately.
# A simpler alternative is torch.compile with the inductor backend:
pipe.unet = torch.compile(
    pipe.unet,
    mode="reduce-overhead",  # Optimize for latency
    backend="inductor",
)
```

Platform-Specific Exports

| Platform | Format | Optimization | Use Case |
|---|---|---|---|
| NVIDIA GPU | TensorRT | Kernel fusion, FP16 | Servers, cloud |
| Intel CPU/GPU | OpenVINO | INT8, threading | Edge, on-prem |
| Apple Silicon | Core ML | Neural Engine | macOS, iOS |
| Mobile | ONNX + QNN | INT8, NPU | Android, iOS |
| Web | ONNX.js/WebGPU | WebGL shaders | Browser apps |

Practical Workflows

Complete Optimization Pipeline

```python
import torch
from diffusers import (
    StableDiffusionXLPipeline,
    UNet2DConditionModel,
    BitsAndBytesConfig,
)

class OptimizedDiffusionPipeline:
    """Production-ready optimized diffusion pipeline."""

    def __init__(
        self,
        model_id: str,
        precision: str = "fp16",  # fp32, fp16, bf16, int8, int4
        use_xformers: bool = True,
        compile_mode: str = None,  # None, "reduce-overhead", "max-autotune"
    ):
        self.precision = precision

        # Load with appropriate precision. bitsandbytes quantization is
        # applied per model, so for int8/int4 we quantize the U-Net and
        # build the pipeline around it.
        if precision in ("int8", "int4"):
            if precision == "int4":
                config = BitsAndBytesConfig(
                    load_in_4bit=True,
                    bnb_4bit_quant_type="nf4",
                    bnb_4bit_compute_dtype=torch.float16,
                )
            else:
                config = BitsAndBytesConfig(load_in_8bit=True)
            unet = UNet2DConditionModel.from_pretrained(
                model_id,
                subfolder="unet",
                quantization_config=config,
                torch_dtype=torch.float16,
            )
            self.pipe = StableDiffusionXLPipeline.from_pretrained(
                model_id, unet=unet, torch_dtype=torch.float16
            )
            self.pipe.enable_model_cpu_offload()
        else:
            dtype = {
                "fp32": torch.float32,
                "fp16": torch.float16,
                "bf16": torch.bfloat16,
            }[precision]
            self.pipe = StableDiffusionXLPipeline.from_pretrained(
                model_id, torch_dtype=dtype, variant="fp16"
            )
            self.pipe.to("cuda")

        # Enable memory-efficient attention
        if use_xformers:
            try:
                self.pipe.enable_xformers_memory_efficient_attention()
            except Exception:
                self.pipe.enable_attention_slicing()

        # Compile for extra speed (PyTorch 2.0+)
        if compile_mode:
            self.pipe.unet = torch.compile(
                self.pipe.unet, mode=compile_mode
            )

        # Warmup
        self._warmup()

    def _warmup(self):
        """Warmup the pipeline to trigger compilation."""
        _ = self.pipe(
            "warmup", num_inference_steps=2, output_type="latent"
        )
        torch.cuda.synchronize()

    def generate(self, prompt: str, **kwargs):
        """Generate image with all optimizations."""
        return self.pipe(prompt, **kwargs).images[0]

    def benchmark(self, prompt: str, num_runs: int = 10):
        """Benchmark inference latency."""
        import time

        # Warmup
        for _ in range(3):
            self.generate(prompt, num_inference_steps=4)

        # Benchmark
        torch.cuda.synchronize()
        start = time.time()

        for _ in range(num_runs):
            self.generate(prompt, num_inference_steps=20)
            torch.cuda.synchronize()

        elapsed = (time.time() - start) / num_runs
        return {
            "latency_ms": elapsed * 1000,
            "memory_gb": torch.cuda.max_memory_allocated() / 1e9,
        }

# Usage
pipeline = OptimizedDiffusionPipeline(
    "stabilityai/stable-diffusion-xl-base-1.0",
    precision="fp16",
    use_xformers=True,
    compile_mode="reduce-overhead",
)

# Benchmark
results = pipeline.benchmark("A cat sitting on a couch")
print(f"Latency: {results['latency_ms']:.1f}ms")
print(f"Memory: {results['memory_gb']:.2f}GB")
```

Memory Optimization Tips

```python
from diffusers import StableDiffusionXLPipeline
import torch

pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float16,
    variant="fp16",
)

# 1. Enable sequential CPU offloading (saves most memory)
pipe.enable_sequential_cpu_offload()

# 2. Or model CPU offloading (less memory savings, faster)
# pipe.enable_model_cpu_offload()

# 3. Enable VAE slicing for high-res images
pipe.enable_vae_slicing()

# 4. Enable VAE tiling for very high-res
pipe.enable_vae_tiling()

# 5. Use attention slicing
pipe.enable_attention_slicing(slice_size="auto")

# Now can run on 4GB VRAM (slowly due to offloading)
image = pipe(
    "A beautiful landscape",
    num_inference_steps=30,
).images[0]
```

Summary

Quantization and efficiency techniques are essential for deploying diffusion models in production:

  1. FP16/BF16: Simple 2x memory reduction with minimal quality loss - use as baseline for all deployments
  2. INT8 Quantization: 4x compression from FP32, requires calibration but maintains good quality
  3. INT4 Quantization: Extreme compression for memory-constrained devices, some quality trade-off
  4. Model Pruning: Remove unnecessary weights/heads, requires fine-tuning to recover quality
  5. ONNX/TensorRT: Platform-specific optimization for production deployment

Production Recommendation: Start with FP16 + xformers + torch.compile for the best latency-memory trade-off. Move to INT8 or INT4 only if memory constraints require it.

Looking Ahead: In the next section, we'll explore how to serve these optimized models at scale using TorchServe, Triton, and cloud infrastructure.