Chapter 16

Serving Diffusion Models

Optimization and Deployment

Learning Objectives

By the end of this section, you will be able to:

  1. Design production-ready APIs for diffusion model inference with proper request handling
  2. Deploy models using TorchServe and Triton with optimized configurations
  3. Implement dynamic batching to maximize GPU utilization
  4. Scale horizontally using Kubernetes and GPU orchestration
  5. Monitor and optimize production diffusion model services

Deployment Challenges

Serving diffusion models in production presents unique challenges compared to traditional ML models:

| Challenge | Description | Impact |
|---|---|---|
| Long inference time | 2-30 seconds per image | Requires async processing |
| High GPU memory | 6-14 GB per model | Limited concurrent requests |
| Variable workloads | Bursty traffic patterns | Autoscaling complexity |
| Large model files | 2-10 GB per model | Slow cold starts |
| User expectations | Real-time feedback | Progress streaming needed |
Key Insight: Unlike classification models that return in milliseconds, diffusion models require fundamentally different serving patterns: async processing, progress updates, and careful resource management.

Serving Frameworks

TorchServe

TorchServe is PyTorch's native serving solution, offering easy integration with PyTorch models and the diffusers library.

```python
# diffusion_handler.py - TorchServe Handler
import base64
import io
import json

import torch
from diffusers import StableDiffusionXLPipeline, LCMScheduler
from ts.torch_handler.base_handler import BaseHandler

class DiffusionHandler(BaseHandler):
    def __init__(self):
        super().__init__()
        self.pipe = None

    def initialize(self, context):
        """Load model on startup."""
        self.manifest = context.manifest
        model_dir = context.system_properties.get("model_dir")

        # Load optimized pipeline
        self.pipe = StableDiffusionXLPipeline.from_pretrained(
            model_dir,
            torch_dtype=torch.float16,
            use_safetensors=True,
        )
        self.pipe.to("cuda")

        # Apply optimizations
        self.pipe.enable_xformers_memory_efficient_attention()

        # Optional: Use LCM for faster inference
        # self.pipe.scheduler = LCMScheduler.from_config(self.pipe.scheduler.config)

        # Warmup
        self.pipe("warmup", num_inference_steps=2, output_type="latent")
        torch.cuda.synchronize()

    def preprocess(self, requests):
        """Parse incoming requests."""
        inputs = []
        for req in requests:
            data = req.get("data") or req.get("body")
            if isinstance(data, (bytes, bytearray)):
                data = data.decode("utf-8")
            if isinstance(data, str):
                data = json.loads(data)
            inputs.append(data)
        return inputs

    def inference(self, inputs):
        """Run diffusion inference."""
        results = []

        for input_data in inputs:
            prompt = input_data.get("prompt", "")
            negative_prompt = input_data.get("negative_prompt", "")
            num_steps = input_data.get("num_inference_steps", 20)
            guidance_scale = input_data.get("guidance_scale", 7.5)
            seed = input_data.get("seed", None)

            generator = None
            if seed is not None:
                generator = torch.Generator("cuda").manual_seed(seed)

            with torch.inference_mode():
                image = self.pipe(
                    prompt=prompt,
                    negative_prompt=negative_prompt,
                    num_inference_steps=num_steps,
                    guidance_scale=guidance_scale,
                    generator=generator,
                ).images[0]

            results.append(image)

        return results

    def postprocess(self, outputs):
        """Convert images to base64."""
        results = []
        for image in outputs:
            buffer = io.BytesIO()
            image.save(buffer, format="PNG")
            img_base64 = base64.b64encode(buffer.getvalue()).decode()
            results.append({"image": img_base64})
        return results
```
```bash
# Package and deploy with TorchServe

# 1. Create model archive
torch-model-archiver \
    --model-name stable-diffusion-xl \
    --version 1.0 \
    --handler diffusion_handler.py \
    --extra-files "model_index.json,scheduler/*,tokenizer/*,unet/*,vae/*,text_encoder/*" \
    --export-path model_store

# 2. Start TorchServe
torchserve \
    --start \
    --model-store model_store \
    --models stable-diffusion-xl=stable-diffusion-xl.mar \
    --ts-config config.properties

# 3. Test endpoint
curl -X POST http://localhost:8080/predictions/stable-diffusion-xl \
    -H "Content-Type: application/json" \
    -d '{"prompt": "A beautiful sunset over mountains"}'
```
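On the client side, the response body from the handler's `postprocess` is a JSON object with a base64-encoded PNG. A minimal sketch of a Python client, assuming the endpoint and response shape above (adjust the model name and host for your deployment):

```python
# Minimal client for the TorchServe endpoint shown above.
import base64
import json
import urllib.request

def decode_image_b64(payload: dict) -> bytes:
    """Extract and decode the base64-encoded PNG from a prediction response."""
    return base64.b64decode(payload["image"])

def generate(prompt: str, host: str = "http://localhost:8080") -> bytes:
    """Call the TorchServe prediction API and return raw PNG bytes."""
    req = urllib.request.Request(
        f"{host}/predictions/stable-diffusion-xl",
        data=json.dumps({"prompt": prompt}).encode("utf-8"),
        headers={"Content-Type": "application/json"},
    )
    with urllib.request.urlopen(req) as resp:
        return decode_image_b64(json.loads(resp.read()))
```

The returned bytes can be written straight to a `.png` file or opened with `PIL.Image.open(io.BytesIO(...))`.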

NVIDIA Triton Inference Server

Triton offers advanced features like dynamic batching, model ensembles, and multi-framework support. It's ideal for high-throughput production deployments.

```python
# model.py - Triton Python Backend
import json

import numpy as np
import torch
import triton_python_backend_utils as pb_utils
from diffusers import StableDiffusionXLPipeline

class TritonPythonModel:
    def initialize(self, args):
        """Load model during server startup."""
        self.model_config = json.loads(args["model_config"])

        # Load pipeline
        self.pipe = StableDiffusionXLPipeline.from_pretrained(
            "/models/sdxl-base",
            torch_dtype=torch.float16,
        )
        self.pipe.to("cuda")
        self.pipe.enable_xformers_memory_efficient_attention()

        # Warmup
        self.pipe("warmup", num_inference_steps=1, output_type="latent")

    def execute(self, requests):
        """Process inference requests."""
        responses = []

        for request in requests:
            # Get input tensors
            prompt = pb_utils.get_input_tensor_by_name(request, "prompt")
            prompt = prompt.as_numpy()[0].decode("utf-8")

            steps = pb_utils.get_input_tensor_by_name(request, "steps")
            steps = int(steps.as_numpy()[0]) if steps is not None else 20

            # Generate image
            with torch.inference_mode():
                image = self.pipe(
                    prompt=prompt,
                    num_inference_steps=steps,
                ).images[0]

            # Convert to numpy
            image_np = np.array(image)

            # Create output tensor
            output_tensor = pb_utils.Tensor("image", image_np)

            # Create response
            response = pb_utils.InferenceResponse(output_tensors=[output_tensor])
            responses.append(response)

        return responses

    def finalize(self):
        """Cleanup on shutdown."""
        del self.pipe
        torch.cuda.empty_cache()
```
```text
# config.pbtxt - Triton Model Configuration
name: "stable_diffusion"
backend: "python"
max_batch_size: 4  # Must be > 0 for dynamic batching to be enabled

input [
  {
    name: "prompt"
    data_type: TYPE_STRING
    dims: [1]
  },
  {
    name: "steps"
    data_type: TYPE_INT32
    dims: [1]
    optional: true
  }
]

output [
  {
    name: "image"
    data_type: TYPE_UINT8
    dims: [1024, 1024, 3]
  }
]

instance_group [
  {
    count: 1
    kind: KIND_GPU
    gpus: [0]
  }
]

# Enable dynamic batching
dynamic_batching {
  preferred_batch_size: [1, 2, 4]
  max_queue_delay_microseconds: 100000
}
```
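A client must shape its inputs to match the `config.pbtxt` declarations: `prompt` is a `TYPE_STRING` tensor and `steps` a `TYPE_INT32` tensor, each with `dims: [1]`. A hedged sketch using the `tritonclient` HTTP client (the package, server URL, and model name are assumptions for illustration):

```python
# Sketch of a Triton client matching the config.pbtxt above.
import numpy as np

def make_input_arrays(prompt: str, steps: int = 20):
    """Build numpy arrays shaped for the model's declared inputs."""
    prompt_arr = np.array([prompt.encode("utf-8")], dtype=np.object_)  # TYPE_STRING
    steps_arr = np.array([steps], dtype=np.int32)                      # TYPE_INT32
    return prompt_arr, steps_arr

def infer(prompt: str, steps: int = 20, url: str = "localhost:8000"):
    """Send a request via Triton's HTTP client (import deferred)."""
    import tritonclient.http as httpclient  # assumed installed

    prompt_arr, steps_arr = make_input_arrays(prompt, steps)
    inputs = [
        httpclient.InferInput("prompt", list(prompt_arr.shape), "BYTES"),
        httpclient.InferInput("steps", list(steps_arr.shape), "INT32"),
    ]
    inputs[0].set_data_from_numpy(prompt_arr)
    inputs[1].set_data_from_numpy(steps_arr)

    client = httpclient.InferenceServerClient(url=url)
    result = client.infer("stable_diffusion", inputs)
    return result.as_numpy("image")  # uint8 array per the output config
```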

API Design

A production API for diffusion models needs to handle async processing, progress updates, and result retrieval. Here's a robust design:

```python
import asyncio
import json
import uuid
from enum import Enum
from typing import Optional

import redis
from fastapi import FastAPI, BackgroundTasks, HTTPException, WebSocket
from pydantic import BaseModel, Field

app = FastAPI(title="Diffusion API")
redis_client = redis.Redis(host="localhost", port=6379)

class JobStatus(str, Enum):
    PENDING = "pending"
    PROCESSING = "processing"
    COMPLETED = "completed"
    FAILED = "failed"

class GenerationRequest(BaseModel):
    prompt: str = Field(..., min_length=1, max_length=2000)
    negative_prompt: Optional[str] = ""
    num_inference_steps: int = Field(default=20, ge=1, le=100)
    guidance_scale: float = Field(default=7.5, ge=1.0, le=20.0)
    width: int = Field(default=1024, ge=512, le=2048)
    height: int = Field(default=1024, ge=512, le=2048)
    seed: Optional[int] = None

class GenerationResponse(BaseModel):
    job_id: str
    status: JobStatus
    queue_position: Optional[int] = None
    estimated_time: Optional[float] = None

class JobResult(BaseModel):
    job_id: str
    status: JobStatus
    progress: Optional[float] = None
    image_url: Optional[str] = None
    error: Optional[str] = None

# Endpoints
@app.post("/generate", response_model=GenerationResponse)
async def create_generation(
    request: GenerationRequest,
    background_tasks: BackgroundTasks,
):
    """Submit a new generation job."""
    job_id = str(uuid.uuid4())

    # Store job in Redis with a 1-hour TTL
    job_data = {
        "status": JobStatus.PENDING,
        "request": request.dict(),
        "progress": 0,
    }
    redis_client.setex(f"job:{job_id}", 3600, json.dumps(job_data))

    # Add to processing queue
    redis_client.rpush("generation_queue", job_id)
    queue_position = redis_client.llen("generation_queue")

    # Rough estimate: ~0.1 s per step per queued job; calibrate for your GPU
    estimated_time = request.num_inference_steps * 0.1 * queue_position

    return GenerationResponse(
        job_id=job_id,
        status=JobStatus.PENDING,
        queue_position=queue_position,
        estimated_time=estimated_time,
    )

@app.get("/jobs/{job_id}", response_model=JobResult)
async def get_job_status(job_id: str):
    """Check job status and get results."""
    job_data = redis_client.get(f"job:{job_id}")

    if not job_data:
        raise HTTPException(status_code=404, detail="Job not found")

    job = json.loads(job_data)

    return JobResult(
        job_id=job_id,
        status=job["status"],
        progress=job.get("progress"),
        image_url=job.get("image_url"),
        error=job.get("error"),
    )

@app.delete("/jobs/{job_id}")
async def cancel_job(job_id: str):
    """Cancel a pending job."""
    job_data = redis_client.get(f"job:{job_id}")

    if not job_data:
        raise HTTPException(status_code=404, detail="Job not found")

    job = json.loads(job_data)

    if job["status"] == JobStatus.PROCESSING:
        raise HTTPException(status_code=400, detail="Cannot cancel processing job")

    redis_client.delete(f"job:{job_id}")
    redis_client.lrem("generation_queue", 1, job_id)

    return {"message": "Job cancelled"}

# WebSocket for real-time progress
@app.websocket("/ws/jobs/{job_id}")
async def job_progress_websocket(websocket: WebSocket, job_id: str):
    """Stream job progress updates."""
    await websocket.accept()

    try:
        while True:
            job_data = redis_client.get(f"job:{job_id}")
            if not job_data:
                await websocket.close(code=1000, reason="Job not found")
                break

            job = json.loads(job_data)
            await websocket.send_json({
                "status": job["status"],
                "progress": job.get("progress", 0),
            })

            if job["status"] in [JobStatus.COMPLETED, JobStatus.FAILED]:
                break

            await asyncio.sleep(0.5)

    except Exception:
        pass
```
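For clients that don't use the WebSocket, the submit-then-poll flow against the endpoints above can be sketched as follows. The fetch function is injectable so the loop can be exercised without a running server; with `requests` installed it could be `lambda jid: requests.get(f"{API}/jobs/{jid}").json()`:

```python
# Client-side polling loop for the job API above.
import time
from typing import Callable, Dict

TERMINAL = {"completed", "failed"}

def wait_for_job(
    fetch: Callable[[str], Dict],
    job_id: str,
    interval: float = 1.0,
    timeout: float = 300.0,
) -> Dict:
    """Poll `fetch(job_id)` until the job reaches a terminal status."""
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        job = fetch(job_id)
        if job["status"] in TERMINAL:
            return job
        time.sleep(interval)
    raise TimeoutError(f"Job {job_id} did not finish within {timeout}s")
```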

Worker Process

```python
# worker.py - Background worker for processing jobs
import io
import json

import boto3  # For S3 upload
import redis
import torch
from diffusers import StableDiffusionXLPipeline

redis_client = redis.Redis(host="localhost", port=6379)
s3_client = boto3.client("s3")

class ProgressCallback:
    """Plain callable for diffusers' callback_on_step_end hook."""

    def __init__(self, job_id, redis_client, total_steps):
        self.job_id = job_id
        self.redis = redis_client
        self.total_steps = total_steps

    def __call__(self, pipe, step, timestep, callback_kwargs):
        progress = (step + 1) / self.total_steps

        # Update progress in Redis
        job_data = json.loads(self.redis.get(f"job:{self.job_id}"))
        job_data["progress"] = progress
        self.redis.setex(f"job:{self.job_id}", 3600, json.dumps(job_data))

        return callback_kwargs

def process_jobs():
    """Main worker loop."""
    # Load model once
    pipe = StableDiffusionXLPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0",
        torch_dtype=torch.float16,
    )
    pipe.to("cuda")
    pipe.enable_xformers_memory_efficient_attention()

    print("Worker ready, waiting for jobs...")

    while True:
        # Block until a job is available
        _, job_id = redis_client.blpop("generation_queue", timeout=0)
        job_id = job_id.decode("utf-8")

        try:
            # Get job data
            job_data = json.loads(redis_client.get(f"job:{job_id}"))
            request = job_data["request"]

            # Update status
            job_data["status"] = "processing"
            redis_client.setex(f"job:{job_id}", 3600, json.dumps(job_data))

            # Generate image (seed may legitimately be 0, so test against None)
            generator = None
            if request.get("seed") is not None:
                generator = torch.Generator("cuda").manual_seed(request["seed"])

            callback = ProgressCallback(
                job_id, redis_client, request["num_inference_steps"]
            )

            image = pipe(
                prompt=request["prompt"],
                negative_prompt=request.get("negative_prompt", ""),
                num_inference_steps=request["num_inference_steps"],
                guidance_scale=request["guidance_scale"],
                width=request["width"],
                height=request["height"],
                generator=generator,
                callback_on_step_end=callback,
            ).images[0]

            # Upload to S3
            buffer = io.BytesIO()
            image.save(buffer, format="PNG")
            buffer.seek(0)

            image_key = f"generations/{job_id}.png"
            s3_client.upload_fileobj(
                buffer,
                "diffusion-outputs",
                image_key,
                ExtraArgs={"ContentType": "image/png"},
            )

            # Update job with result (keep completed jobs for 24 hours)
            job_data["status"] = "completed"
            job_data["progress"] = 1.0
            job_data["image_url"] = f"https://cdn.example.com/{image_key}"
            redis_client.setex(f"job:{job_id}", 86400, json.dumps(job_data))

        except Exception as e:
            # Handle failure
            job_data["status"] = "failed"
            job_data["error"] = str(e)
            redis_client.setex(f"job:{job_id}", 3600, json.dumps(job_data))

if __name__ == "__main__":
    process_jobs()
```
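One weakness of the worker above: if it crashes between `BLPOP` and completion, the job id is gone. The Redis "reliable queue" pattern fixes this by atomically moving the id to a per-worker processing list and removing it only after success. A sketch, assuming Redis >= 6.2 (`LMOVE`/`BLMOVE`); with this pattern the producer should `LPUSH` and the worker pop from the right so the queue stays FIFO:

```python
# Reliable-queue sketch: claim, acknowledge, and recover jobs.

def claim_job(r, queue: str, processing: str, timeout: int = 0):
    """Atomically move the next job id from `queue` to `processing`."""
    job_id = r.blmove(queue, processing, timeout, "RIGHT", "LEFT")
    return job_id.decode() if isinstance(job_id, bytes) else job_id

def ack_job(r, processing: str, job_id: str):
    """Remove a finished job from the processing list."""
    r.lrem(processing, 1, job_id)

def recover_orphans(r, processing: str, queue: str):
    """On startup, push any jobs left in `processing` back onto the queue."""
    recovered = []
    while True:
        job_id = r.rpoplpush(processing, queue)
        if job_id is None:
            return recovered
        recovered.append(job_id)
```

The worker calls `claim_job` instead of `blpop`, `ack_job` after the S3 upload succeeds, and `recover_orphans` once at startup so a preempted worker's job is retried rather than lost.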

Batching Strategies

Batching multiple requests together significantly improves GPU utilization and throughput. However, diffusion models present unique batching challenges.

Dynamic Batching

```python
import asyncio
import time
import uuid
from collections import deque
from dataclasses import dataclass
from typing import Any, Deque, Dict, List

@dataclass
class BatchRequest:
    request_id: str
    prompt: str
    num_steps: int
    future: asyncio.Future

class DynamicBatcher:
    """
    Collects requests and processes them in batches.
    Optimizes for GPU utilization while maintaining reasonable latency.
    """

    def __init__(
        self,
        pipe,
        max_batch_size: int = 4,
        max_wait_time: float = 0.5,  # seconds
        same_steps_only: bool = True,  # Batch only same-step requests
    ):
        self.pipe = pipe
        self.max_batch_size = max_batch_size
        self.max_wait_time = max_wait_time
        self.same_steps_only = same_steps_only

        self.request_queue: Deque[BatchRequest] = deque()
        self.processing = False
        self._lock = asyncio.Lock()

    async def submit(self, prompt: str, num_steps: int = 20) -> Dict[str, Any]:
        """Submit a request and wait for the result."""
        future = asyncio.get_running_loop().create_future()
        request = BatchRequest(
            request_id=str(uuid.uuid4()),
            prompt=prompt,
            num_steps=num_steps,
            future=future,
        )

        async with self._lock:
            self.request_queue.append(request)

            # Start processing if not already running
            if not self.processing:
                self.processing = True
                asyncio.create_task(self._process_batches())

        return await future

    async def _process_batches(self):
        """Main processing loop."""
        while True:
            batch = await self._collect_batch()

            if not batch:
                # Stop only if the queue is really empty (avoids races
                # with requests that arrive while we are shutting down)
                async with self._lock:
                    if not self.request_queue:
                        self.processing = False
                        return
                continue

            # Process batch
            try:
                results = await self._run_batch(batch)

                # Deliver results
                for request, result in zip(batch, results):
                    request.future.set_result(result)

            except Exception as e:
                # Propagate the error to all requests in the batch
                for request in batch:
                    request.future.set_exception(e)

    async def _collect_batch(self) -> List[BatchRequest]:
        """Collect requests into a batch, waiting at most max_wait_time."""
        batch: List[BatchRequest] = []
        start_time = time.time()
        target_steps = None

        while len(batch) < self.max_batch_size:
            if time.time() - start_time >= self.max_wait_time:
                break

            if self.request_queue:
                request = self.request_queue.popleft()

                # Only batch requests with matching step counts
                if (
                    self.same_steps_only
                    and target_steps is not None
                    and request.num_steps != target_steps
                ):
                    # Put it back for the next batch
                    self.request_queue.appendleft(request)
                    break

                batch.append(request)
                target_steps = request.num_steps
            else:
                await asyncio.sleep(0.01)

        return batch

    async def _run_batch(self, batch: List[BatchRequest]) -> List[Dict]:
        """Run inference on a batch."""
        prompts = [r.prompt for r in batch]
        num_steps = batch[0].num_steps

        # Run in a thread pool so the event loop is not blocked
        loop = asyncio.get_running_loop()
        images = await loop.run_in_executor(
            None,
            lambda: self.pipe(
                prompt=prompts,
                num_inference_steps=num_steps,
            ).images,
        )

        return [{"image": img} for img in images]

# Usage
batcher = DynamicBatcher(pipe, max_batch_size=4, max_wait_time=0.5)

# In your API endpoint
async def generate(prompt: str):
    return await batcher.submit(prompt, num_steps=20)
```

Batching Considerations

| Factor | Impact | Recommendation |
|---|---|---|
| Different step counts | Cannot batch together efficiently | Normalize to standard step counts (4, 8, 20, 50) |
| Different resolutions | Requires padding or separate batches | Offer fixed resolution tiers |
| Memory constraints | Larger batches need more VRAM | Profile and set max based on GPU |
| Latency requirements | Waiting for a batch increases latency | Tune max_wait_time based on SLOs |
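The first row suggests normalizing step counts so more requests become batch-compatible. A small sketch, snapping each requested count to the nearest supported tier (the tier values follow the table; ties round up since more steps only costs quality-neutral compute):

```python
# Snap requested step counts to standard tiers for better batching.
STEP_TIERS = (4, 8, 20, 50)

def normalize_steps(requested: int, tiers=STEP_TIERS) -> int:
    """Return the tier nearest to `requested` (ties round up)."""
    return min(tiers, key=lambda t: (abs(t - requested), -t))
```

Applied in the API layer before enqueueing, this turns a long tail of arbitrary step counts into four batchable buckets.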

Scaling Infrastructure

For production workloads, you need to scale horizontally across multiple GPUs and nodes. Kubernetes is the standard platform for this.

Kubernetes Deployment

```yaml
# deployment.yaml
apiVersion: apps/v1
kind: Deployment
metadata:
  name: diffusion-worker
  labels:
    app: diffusion
spec:
  replicas: 3
  selector:
    matchLabels:
      app: diffusion
  template:
    metadata:
      labels:
        app: diffusion
    spec:
      containers:
      - name: worker
        image: your-registry/diffusion-worker:latest
        resources:
          limits:
            nvidia.com/gpu: 1
            memory: "32Gi"
            cpu: "8"
          requests:
            nvidia.com/gpu: 1
            memory: "24Gi"
            cpu: "4"
        env:
        - name: MODEL_PATH
          value: "/models/sdxl-base"
        - name: REDIS_URL
          value: "redis://redis-master:6379"
        volumeMounts:
        - name: model-cache
          mountPath: /models
        - name: shm
          mountPath: /dev/shm
      volumes:
      - name: model-cache
        persistentVolumeClaim:
          claimName: model-pvc
      - name: shm
        emptyDir:
          medium: Memory
          sizeLimit: "16Gi"
      nodeSelector:
        nvidia.com/gpu.product: "NVIDIA-A100-SXM4-80GB"
      tolerations:
      - key: "nvidia.com/gpu"
        operator: "Exists"
        effect: "NoSchedule"
---
# Horizontal Pod Autoscaler
apiVersion: autoscaling/v2
kind: HorizontalPodAutoscaler
metadata:
  name: diffusion-worker-hpa
spec:
  scaleTargetRef:
    apiVersion: apps/v1
    kind: Deployment
    name: diffusion-worker
  minReplicas: 2
  maxReplicas: 10
  metrics:
  - type: External
    external:
      metric:
        name: redis_queue_length
        selector:
          matchLabels:
            queue: generation_queue
      target:
        type: AverageValue
        averageValue: "5"
  behavior:
    scaleUp:
      stabilizationWindowSeconds: 60
      policies:
      - type: Pods
        value: 2
        periodSeconds: 60
    scaleDown:
      stabilizationWindowSeconds: 300
      policies:
      - type: Pods
        value: 1
        periodSeconds: 120
```

KEDA for Queue-Based Scaling

```yaml
# keda-scaler.yaml
apiVersion: keda.sh/v1alpha1
kind: ScaledObject
metadata:
  name: diffusion-worker-scaler
spec:
  scaleTargetRef:
    name: diffusion-worker
  minReplicaCount: 1
  maxReplicaCount: 20
  pollingInterval: 15
  cooldownPeriod: 300
  triggers:
  - type: redis
    metadata:
      address: redis-master:6379
      listName: generation_queue
      listLength: "3"  # Scale up when 3+ jobs per worker
  - type: prometheus
    metadata:
      serverAddress: http://prometheus:9090
      metricName: gpu_utilization
      threshold: "70"
      query: |
        avg(nvidia_gpu_utilization{job="diffusion-worker"})
```

GPU Scheduling Considerations

  • GPU Provisioning Time: New GPU nodes take 2-5 minutes to provision in cloud environments. Use predictive scaling.
  • Spot/Preemptible Instances: Can reduce costs by 60-70% but require handling interruptions gracefully.
  • Multi-GPU Nodes: A100 nodes often have 8 GPUs. Run multiple workers per node.
  • Model Caching: Use shared volumes or object storage to avoid downloading models on each scale-up.

Monitoring and Observability

Comprehensive monitoring is essential for maintaining SLOs and optimizing costs.

```python
# metrics.py - Prometheus metrics for diffusion service
from prometheus_client import Counter, Histogram, Gauge, start_http_server
import time

# Request metrics
REQUESTS_TOTAL = Counter(
    "diffusion_requests_total",
    "Total generation requests",
    ["status", "model"],
)

GENERATION_DURATION = Histogram(
    "diffusion_generation_duration_seconds",
    "Time to generate image",
    ["model", "steps"],
    buckets=[1, 2, 5, 10, 20, 30, 60, 120],
)

QUEUE_SIZE = Gauge(
    "diffusion_queue_size",
    "Number of jobs in queue",
)

# GPU metrics
GPU_MEMORY_USED = Gauge(
    "gpu_memory_used_bytes",
    "GPU memory in use",
    ["gpu_id"],
)

GPU_UTILIZATION = Gauge(
    "gpu_utilization_percent",
    "GPU compute utilization",
    ["gpu_id"],
)

# Batch metrics
BATCH_SIZE = Histogram(
    "diffusion_batch_size",
    "Number of requests per batch",
    buckets=[1, 2, 4, 8, 16],
)

def record_generation(model: str, steps: int, duration: float, success: bool):
    """Record generation metrics."""
    status = "success" if success else "error"
    REQUESTS_TOTAL.labels(status=status, model=model).inc()

    if success:
        GENERATION_DURATION.labels(model=model, steps=str(steps)).observe(duration)

def update_gpu_metrics():
    """Update GPU metrics (call periodically)."""
    import pynvml
    pynvml.nvmlInit()

    device_count = pynvml.nvmlDeviceGetCount()
    for i in range(device_count):
        handle = pynvml.nvmlDeviceGetHandleByIndex(i)

        # Memory
        mem_info = pynvml.nvmlDeviceGetMemoryInfo(handle)
        GPU_MEMORY_USED.labels(gpu_id=str(i)).set(mem_info.used)

        # Utilization
        util = pynvml.nvmlDeviceGetUtilizationRates(handle)
        GPU_UTILIZATION.labels(gpu_id=str(i)).set(util.gpu)

# Start metrics server
start_http_server(8000)
```
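`update_gpu_metrics()` above has to run periodically. A small sketch of a stoppable background loop; the `max_iterations` parameter is an addition so the loop can be bounded in tests, and the ten-second interval is an arbitrary example:

```python
# Stoppable periodic runner for metric collection.
import threading
from typing import Callable, Optional

def run_periodic(
    fn: Callable[[], None],
    interval: float,
    stop_event: threading.Event,
    max_iterations: Optional[int] = None,
) -> int:
    """Call `fn` every `interval` seconds until stopped; return call count."""
    calls = 0
    while max_iterations is None or calls < max_iterations:
        fn()
        calls += 1
        if stop_event.wait(interval):  # returns True once the event is set
            break
    return calls

# In the worker process:
# stop = threading.Event()
# threading.Thread(target=run_periodic,
#                  args=(update_gpu_metrics, 10.0, stop), daemon=True).start()
```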

Key Metrics to Monitor

| Metric | Target | Alert Threshold |
|---|---|---|
| p99 Latency | < 30 s | > 60 s for 5 min |
| GPU Utilization | 70-85% | < 50% or > 95% |
| Queue Depth | < 10 per worker | > 50 total |
| Error Rate | < 1% | > 5% |
| Memory Usage | < 90% | > 95% |
```promql
# Grafana dashboard query examples (PromQL)

# Average generation time by model
sum by (model) (rate(diffusion_generation_duration_seconds_sum[5m]))
  / sum by (model) (rate(diffusion_generation_duration_seconds_count[5m]))

# Queue depth over time
diffusion_queue_size

# GPU utilization heatmap
avg by (gpu_id) (gpu_utilization_percent)

# Request rate
rate(diffusion_requests_total[1m])

# Error rate percentage
sum(rate(diffusion_requests_total{status="error"}[5m]))
  / sum(rate(diffusion_requests_total[5m])) * 100
```

Cost Optimization

GPU costs dominate diffusion model serving expenses. Here are strategies to optimize:

Cost Breakdown

| GPU Type | On-Demand/hr | Spot/hr | Images/hr (20 steps) | Cost/1000 Images (on-demand) |
|---|---|---|---|---|
| A10G | $1.00 | $0.35 | ~600 | $1.67 |
| A100 40GB | $3.50 | $1.20 | ~1200 | $2.92 |
| A100 80GB | $4.50 | $1.50 | ~1400 | $3.21 |
| H100 | $8.00 | $2.80 | ~2500 | $3.20 |

Optimization Strategies

  1. Use Spot Instances: 60-70% cost reduction. Handle interruptions with checkpointing and queue-based processing.
  2. Right-size GPUs: Match GPU to workload. A10G is often sufficient for SD 1.5/2.x; A100 needed for SDXL.
  3. Maximize Utilization: Target 70-85% GPU utilization. Use batching and queue management.
  4. Reduce Steps: Use LCM/Turbo models for 4-step generation (10x throughput increase).
  5. Time-based Scaling: Scale down during off-peak hours based on historical patterns.
```python
# Cost-aware scheduler (sketch)
from collections import deque

class CostAwareScheduler:
    """Schedule jobs based on cost optimization."""

    def __init__(self):
        self.spot_workers = []       # Cheaper, may be interrupted
        self.on_demand_workers = []  # Reliable, more expensive
        self.queue = deque()         # Re-queue buffer for interrupted jobs
        self.worker_jobs = {}        # worker_id -> currently running job

    def schedule(self, job, priority: str = "normal"):
        """Route a job to the appropriate worker pool."""
        if priority == "high" or job.deadline_soon():
            # Use on-demand for reliability
            return self.on_demand_workers
        # Prefer spot for cost savings, fall back to on-demand
        return self.spot_workers or self.on_demand_workers

    def handle_spot_interruption(self, worker_id: str):
        """Handle spot instance preemption."""
        # Re-queue the job the interrupted worker was running
        current_job = self.worker_jobs.pop(worker_id, None)
        if current_job is not None:
            self.queue.appendleft(current_job)  # Front of queue: high priority

        # Remove the worker from the spot pool
        if worker_id in self.spot_workers:
            self.spot_workers.remove(worker_id)
```
Summary

Serving diffusion models in production requires careful architecture and operational practices:

  1. Async Processing: Use queue-based architectures for long-running generation tasks
  2. Framework Choice: TorchServe for simplicity, Triton for high-throughput production
  3. API Design: Implement job submission, status polling, and WebSocket progress updates
  4. Dynamic Batching: Maximize GPU utilization by batching compatible requests
  5. Kubernetes Scaling: Use HPA or KEDA for queue-based auto-scaling
  6. Comprehensive Monitoring: Track latency, GPU utilization, queue depth, and error rates
  7. Cost Optimization: Use spot instances, right-sized GPUs, and distilled models
Looking Ahead: In the next chapter, we'll explore the future of diffusion models, including video generation, 3D generation, and other cutting-edge applications.