Introduction
A/B testing allows you to compare agent variants in production with real users and real tasks. Unlike offline benchmarks, experiments capture the full complexity of real-world usage and reveal issues that only emerge at scale.
Why Experiment? Offline benchmarks tell you how an agent performs on test data. Experiments tell you how it performs with real users, real edge cases, and real business outcomes.
This section covers how to design experiments for AI agents, implement traffic splitting, analyze results statistically, and use advanced techniques like multi-arm bandits for continuous optimization.
Experiment Design
Well-designed experiments isolate the variable being tested and control for confounding factors. Here's a framework for designing agent experiments:
1"""
2Experiment design framework for AI agents.
3"""
4
5from dataclasses import dataclass, field
6from datetime import datetime, timedelta
7from enum import Enum
8from typing import Any, Callable, Dict, List, Optional, Set
9import uuid
10
11
class ExperimentStatus(Enum):
    """Lifecycle status of an experiment."""

    DRAFT = "draft"          # configured but not started
    RUNNING = "running"      # actively receiving traffic
    PAUSED = "paused"        # temporarily halted
    COMPLETED = "completed"  # finished
    ABORTED = "aborted"      # stopped early
19
20
@dataclass
class Variant:
    """A single arm (variant) of an experiment."""

    id: str
    name: str
    description: str
    config: Dict[str, Any]
    # Share of traffic (0-100) this variant should receive.
    traffic_percentage: float = 0.0

    def to_dict(self) -> Dict[str, Any]:
        """Return a JSON-serializable view of this variant."""
        keys = ("id", "name", "description", "config", "traffic_percentage")
        return {key: getattr(self, key) for key in keys}
38
39
@dataclass
class ExperimentMetric:
    """A metric tracked during an experiment."""

    name: str
    description: str
    # Experiment.validate requires at least one primary metric.
    is_primary: bool = False
    higher_is_better: bool = True
    # Smallest relative change worth detecting (0.05 == 5% change).
    minimum_detectable_effect: float = 0.05
48
49
@dataclass
class Experiment:
    """An A/B experiment for agent testing.

    Holds the variant set, the tracked metrics, lifecycle status, and
    guardrail settings used to stop a bad experiment early.
    """

    id: str
    name: str
    description: str
    variants: List[Variant]
    metrics: List[ExperimentMetric]
    status: ExperimentStatus = ExperimentStatus.DRAFT
    start_time: Optional[datetime] = None
    end_time: Optional[datetime] = None
    target_sample_size: int = 1000
    # Empty set means "no segment restriction".
    user_segments: Set[str] = field(default_factory=set)

    # Guardrails
    min_runtime_hours: int = 24
    max_runtime_days: int = 14
    auto_stop_on_degradation: bool = True
    degradation_threshold: float = 0.10  # 10% worse

    @property
    def control_variant(self) -> Optional[Variant]:
        """Return the control variant.

        The first variant whose name contains "control" (case-insensitive)
        wins; otherwise the first variant acts as control. Returns None
        when the experiment has no variants.
        """
        for variant in self.variants:
            if "control" in variant.name.lower():
                return variant
        return self.variants[0] if self.variants else None

    @property
    def treatment_variants(self) -> List[Variant]:
        """Return all variants except the control.

        Fix: compare by identity rather than ==. Dataclass equality
        compares every field, so a treatment that happened to be
        field-equal to the control would otherwise be silently dropped.
        """
        control = self.control_variant
        return [v for v in self.variants if v is not control]

    def validate(self) -> List[str]:
        """Validate the configuration; return error messages (empty when valid)."""
        errors = []

        # Variant traffic must sum to 100% (small float tolerance).
        total_traffic = sum(v.traffic_percentage for v in self.variants)
        if abs(total_traffic - 100.0) > 0.01:
            errors.append(f"Traffic percentages sum to {total_traffic}%, not 100%")

        # Analysis needs at least one primary metric.
        primary_metrics = [m for m in self.metrics if m.is_primary]
        if not primary_metrics:
            errors.append("No primary metric defined")

        # A comparison needs a control and at least one treatment.
        if len(self.variants) < 2:
            errors.append("At least 2 variants required")

        # Allocation and reporting key on variant ids.
        ids = [v.id for v in self.variants]
        if len(ids) != len(set(ids)):
            errors.append("Variant IDs must be unique")

        return errors

    def to_dict(self) -> Dict[str, Any]:
        """Return a JSON-serializable summary of the experiment."""
        return {
            "id": self.id,
            "name": self.name,
            "description": self.description,
            "status": self.status.value,
            "variants": [v.to_dict() for v in self.variants],
            "metrics": [
                {
                    "name": m.name,
                    "is_primary": m.is_primary,
                    "higher_is_better": m.higher_is_better
                }
                for m in self.metrics
            ],
            "start_time": self.start_time.isoformat() if self.start_time else None,
            "end_time": self.end_time.isoformat() if self.end_time else None,
            "target_sample_size": self.target_sample_size
        }
128
129
class ExperimentBuilder:
    """Fluent builder for Experiment objects.

    Each with_*/add_* method mutates the experiment under construction
    and returns the builder so calls can be chained; build() validates
    and returns the finished Experiment.
    """

    def __init__(self, name: str):
        self.experiment = Experiment(
            id=str(uuid.uuid4())[:8],  # short random id
            name=name,
            description="",
            variants=[],
            metrics=[],
        )

    def with_description(self, description: str) -> "ExperimentBuilder":
        """Set the experiment description."""
        self.experiment.description = description
        return self

    def add_control(
        self,
        name: str,
        config: Dict[str, Any],
        traffic: float = 50.0
    ) -> "ExperimentBuilder":
        """Add the control variant (fixed id "control")."""
        control = Variant(
            id="control",
            name=name,
            description="Control variant",
            config=config,
            traffic_percentage=traffic,
        )
        self.experiment.variants.append(control)
        return self

    def add_treatment(
        self,
        name: str,
        config: Dict[str, Any],
        traffic: float = 50.0
    ) -> "ExperimentBuilder":
        """Add a treatment variant with an auto-numbered id."""
        treatment = Variant(
            id=f"treatment_{len(self.experiment.treatment_variants)}",
            name=name,
            description=f"Treatment: {name}",
            config=config,
            traffic_percentage=traffic,
        )
        self.experiment.variants.append(treatment)
        return self

    def add_metric(
        self,
        name: str,
        description: str = "",
        is_primary: bool = False,
        higher_is_better: bool = True
    ) -> "ExperimentBuilder":
        """Register a metric to track."""
        metric = ExperimentMetric(
            name=name,
            description=description,
            is_primary=is_primary,
            higher_is_better=higher_is_better,
        )
        self.experiment.metrics.append(metric)
        return self

    def with_sample_size(self, size: int) -> "ExperimentBuilder":
        """Set the target sample size."""
        self.experiment.target_sample_size = size
        return self

    def with_segments(self, segments: List[str]) -> "ExperimentBuilder":
        """Restrict the experiment to the given user segments."""
        self.experiment.user_segments = set(segments)
        return self

    def build(self) -> Experiment:
        """Validate and return the experiment; raise ValueError when invalid."""
        problems = self.experiment.validate()
        if problems:
            raise ValueError(f"Invalid experiment: {problems}")
        return self.experiment
209
210
211# Example: Create an experiment comparing agent versions
212def create_agent_version_experiment() -> Experiment:
213 return (
214 ExperimentBuilder("Agent V2 Rollout")
215 .with_description("Compare new agent v2 against v1")
216 .add_control(
217 name="Agent V1 (Current)",
218 config={"agent_version": "1.0", "model": "gpt-4"},
219 traffic=50.0
220 )
221 .add_treatment(
222 name="Agent V2 (New)",
223 config={"agent_version": "2.0", "model": "gpt-4"},
224 traffic=50.0
225 )
226 .add_metric("task_success_rate", is_primary=True, higher_is_better=True)
227 .add_metric("avg_latency_ms", higher_is_better=False)
228 .add_metric("user_satisfaction", higher_is_better=True)
229 .add_metric("cost_per_task", higher_is_better=False)
230 .with_sample_size(2000)
231 .build()
    )

Traffic Splitting
Traffic splitting assigns users to experiment variants consistently. Here's how to implement robust traffic allocation:
1"""
2Traffic splitting for A/B experiments.
3"""
4
import hashlib
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple
8
9
@dataclass
class AllocationResult:
    """Outcome of allocating one user for an experiment.

    is_in_experiment is False for fallback allocations (unknown
    experiment, experiment not running, or ineligible segment).
    """

    experiment_id: str
    variant_id: str
    variant_name: str
    config: Dict[str, Any]
    is_in_experiment: bool
18
19
class TrafficAllocator:
    """Deterministically assigns users to experiment variants.

    Assignment is salted-hash based, so a given user always lands in
    the same variant for a given experiment. Explicit per-user
    overrides are supported for testing.
    """

    def __init__(self, salt: str = "experiment_salt"):
        self.salt = salt
        self.experiments: Dict[str, Experiment] = {}
        # user_id -> {experiment_id -> variant_id}
        self.overrides: Dict[str, Dict[str, str]] = {}

    def register_experiment(self, experiment: Experiment):
        """Make an experiment known to the allocator."""
        self.experiments[experiment.id] = experiment

    def add_override(
        self,
        user_id: str,
        experiment_id: str,
        variant_id: str
    ):
        """Pin a user to a specific variant (for testing)."""
        self.overrides.setdefault(user_id, {})[experiment_id] = variant_id

    def allocate(
        self,
        user_id: str,
        experiment_id: str,
        user_segments: Optional[List[str]] = None
    ) -> AllocationResult:
        """Allocate a user to a variant of the given experiment."""

        experiment = self.experiments.get(experiment_id)

        # Unknown experiment: empty, out-of-experiment result.
        if not experiment:
            return AllocationResult(
                experiment_id=experiment_id,
                variant_id="",
                variant_name="",
                config={},
                is_in_experiment=False,
            )

        # Only RUNNING experiments take traffic; otherwise serve control.
        if experiment.status != ExperimentStatus.RUNNING:
            return self._default_allocation(experiment)

        # Manual overrides win over hashing.
        pinned = self.overrides.get(user_id, {}).get(experiment_id)
        if pinned:
            match = next(
                (v for v in experiment.variants if v.id == pinned),
                None
            )
            if match:
                return AllocationResult(
                    experiment_id=experiment_id,
                    variant_id=match.id,
                    variant_name=match.name,
                    config=match.config,
                    is_in_experiment=True,
                )

        # Segment targeting: user must share at least one segment.
        if experiment.user_segments:
            if not (set(user_segments or []) & experiment.user_segments):
                return self._default_allocation(experiment)

        # Consistent hash-based bucketing.
        chosen = self._hash_allocate(user_id, experiment)
        return AllocationResult(
            experiment_id=experiment_id,
            variant_id=chosen.id,
            variant_name=chosen.name,
            config=chosen.config,
            is_in_experiment=True,
        )

    def _hash_allocate(
        self,
        user_id: str,
        experiment: Experiment
    ) -> Variant:
        """Map the user into a [0, 100) bucket and pick the matching variant."""

        digest = hashlib.sha256(
            f"{self.salt}:{experiment.id}:{user_id}".encode()
        ).digest()
        # First 4 hash bytes -> bucket with 0.01% resolution.
        bucket = (int.from_bytes(digest[:4], "big") % 10000) / 100.0

        # Walk cumulative traffic shares until the bucket falls inside one.
        upper = 0.0
        for variant in experiment.variants:
            upper += variant.traffic_percentage
            if bucket < upper:
                return variant

        # Float rounding can leave the bucket past the last boundary.
        return experiment.variants[-1]

    def _default_allocation(self, experiment: Experiment) -> AllocationResult:
        """Out-of-experiment result carrying the control config when present."""
        control = experiment.control_variant
        if control:
            return AllocationResult(
                experiment_id=experiment.id,
                variant_id=control.id,
                variant_name=control.name,
                config=control.config,
                is_in_experiment=False,
            )
        return AllocationResult(
            experiment_id=experiment.id,
            variant_id="",
            variant_name="",
            config={},
            is_in_experiment=False,
        )
142
143
class StickyAllocation:
    """Persists each user's variant so repeat sessions stay consistent.

    A stored allocation is reused as long as its variant still exists
    in the experiment; otherwise a fresh allocation is made and, when
    in-experiment, saved back to storage.
    """

    def __init__(self, storage, allocator: TrafficAllocator):
        # storage is any async key-value store (Redis, database, ...)
        # exposing get(key) and set(key, value).
        self.storage = storage
        self.allocator = allocator

    async def get_allocation(
        self,
        user_id: str,
        experiment_id: str,
        user_segments: Optional[List[str]] = None
    ) -> AllocationResult:
        """Return the sticky allocation for the user, creating one if needed."""

        key = f"allocation:{experiment_id}:{user_id}"
        stored_variant_id = await self.storage.get(key)

        # Reuse the stored variant when it is still part of the experiment.
        if stored_variant_id:
            experiment = self.allocator.experiments.get(experiment_id)
            if experiment:
                variant = next(
                    (v for v in experiment.variants if v.id == stored_variant_id),
                    None
                )
                if variant:
                    return AllocationResult(
                        experiment_id=experiment_id,
                        variant_id=variant.id,
                        variant_name=variant.name,
                        config=variant.config,
                        is_in_experiment=True
                    )

        # No usable stored allocation: allocate fresh.
        allocation = self.allocator.allocate(
            user_id, experiment_id, user_segments
        )

        # Persist only real in-experiment allocations for stickiness.
        if allocation.is_in_experiment:
            await self.storage.set(key, allocation.variant_id)
        return allocation

Statistical Analysis
Proper statistical analysis determines whether observed differences are significant or due to chance:
1"""
2Statistical analysis for A/B experiments.
3"""
4
5import math
6from dataclasses import dataclass
7from typing import Dict, List, Optional, Tuple
8import statistics
9
10
@dataclass
class VariantStats:
    """Descriptive statistics computed for one variant's observations."""

    variant_id: str
    sample_size: int
    mean: float
    std_dev: float
    median: float
    p25: float  # 25th percentile
    p75: float  # 75th percentile
    p95: float  # 95th percentile
22
23
@dataclass
class ExperimentResult:
    """Result of statistically comparing a treatment against control."""

    experiment_id: str
    metric_name: str
    control_stats: VariantStats
    treatment_stats: VariantStats
    # Treatment mean minus control mean, in metric units.
    absolute_difference: float
    # Absolute difference expressed as a percentage of the control mean.
    relative_difference: float
    p_value: float
    confidence_interval: Tuple[float, float]
    is_significant: bool
    # Estimated statistical power of the comparison.
    power: float
    recommendation: str
38
39
class ExperimentAnalyzer:
    """Statistical analysis of A/B experiment data.

    Uses Welch's t-test (unequal variances) with a normal approximation
    of the p-value, a z-based confidence interval on the difference of
    means, and an approximate post-hoc power estimate. Project types in
    signatures are forward-reference strings so this class does not
    depend on definition order.
    """

    def __init__(self, significance_level: float = 0.05):
        # Two-sided significance threshold (alpha).
        self.significance_level = significance_level

    def analyze(
        self,
        experiment: "Experiment",
        control_values: List[float],
        treatment_values: List[float],
        metric_name: str
    ) -> "ExperimentResult":
        """Compare treatment vs control observations for one metric.

        Args:
            experiment: Experiment the data belongs to (used to look up
                the metric's direction).
            control_values: Raw observations from the control variant.
            treatment_values: Raw observations from the treatment variant.
            metric_name: Name of the metric the values measure.
        """

        control_stats = self._calculate_stats("control", control_values)
        treatment_stats = self._calculate_stats("treatment", treatment_values)

        # Differences; relative is % of the control mean, reported as 0
        # when the control mean is 0 to avoid division by zero.
        absolute_diff = treatment_stats.mean - control_stats.mean
        relative_diff = (
            absolute_diff / control_stats.mean * 100
            if control_stats.mean != 0 else 0
        )

        p_value = self._welch_t_test(control_values, treatment_values)

        ci = self._confidence_interval(
            control_values, treatment_values,
            1 - self.significance_level
        )

        is_significant = p_value < self.significance_level

        power = self._calculate_power(
            control_stats, treatment_stats,
            self.significance_level
        )

        # Metric direction drives the recommendation wording; default to
        # "higher is better" when the metric is not found.
        metric = next(
            (m for m in experiment.metrics if m.name == metric_name),
            None
        )
        recommendation = self._generate_recommendation(
            is_significant, relative_diff, power,
            metric.higher_is_better if metric else True
        )

        return ExperimentResult(
            experiment_id=experiment.id,
            metric_name=metric_name,
            control_stats=control_stats,
            treatment_stats=treatment_stats,
            absolute_difference=absolute_diff,
            relative_difference=relative_diff,
            p_value=p_value,
            confidence_interval=ci,
            is_significant=is_significant,
            power=power,
            recommendation=recommendation
        )

    def _calculate_stats(
        self,
        variant_id: str,
        values: List[float]
    ) -> "VariantStats":
        """Compute descriptive statistics (all zeros for an empty sample)."""

        if not values:
            return VariantStats(
                variant_id=variant_id,
                sample_size=0,
                mean=0, std_dev=0, median=0,
                p25=0, p75=0, p95=0
            )

        sorted_values = sorted(values)
        n = len(values)

        # Percentiles use the simple nearest-rank index int(n * q).
        return VariantStats(
            variant_id=variant_id,
            sample_size=n,
            mean=statistics.mean(values),
            std_dev=statistics.stdev(values) if n > 1 else 0,
            median=statistics.median(values),
            p25=sorted_values[int(n * 0.25)],
            p75=sorted_values[int(n * 0.75)],
            p95=sorted_values[int(n * 0.95)]
        )

    def _welch_t_test(
        self,
        control: List[float],
        treatment: List[float]
    ) -> float:
        """Two-sided Welch's t-test; returns the (approximate) p-value.

        Returns 1.0 (no evidence) for degenerate inputs: fewer than two
        observations in either group, or zero pooled standard error.
        """

        n1, n2 = len(control), len(treatment)
        if n1 < 2 or n2 < 2:
            return 1.0

        mean1 = statistics.mean(control)
        mean2 = statistics.mean(treatment)
        var1 = statistics.variance(control)
        var2 = statistics.variance(treatment)

        se = math.sqrt(var1 / n1 + var2 / n2)
        if se == 0:
            return 1.0

        t_stat = abs(mean2 - mean1) / se

        # Welch-Satterthwaite degrees of freedom.
        num = (var1 / n1 + var2 / n2) ** 2
        denom = (var1 / n1) ** 2 / (n1 - 1) + (var2 / n2) ** 2 / (n2 - 1)
        df = num / denom if denom > 0 else 1

        # Two-sided p-value; both branches currently reduce to the
        # normal approximation (see _t_cdf).
        if df > 30:
            p_value = 2 * (1 - self._normal_cdf(t_stat))
        else:
            p_value = 2 * (1 - self._t_cdf(t_stat, df))

        return p_value

    def _normal_cdf(self, x: float) -> float:
        """Standard normal CDF via the error function."""
        return 0.5 * (1 + math.erf(x / math.sqrt(2)))

    def _t_cdf(self, t: float, df: float) -> float:
        """Student's t CDF approximation.

        Deliberately approximated by the normal CDF (df is ignored);
        this understates tail mass for small df, so small-sample
        p-values come out optimistic.
        """
        return self._normal_cdf(t)

    def _confidence_interval(
        self,
        control: List[float],
        treatment: List[float],
        confidence: float
    ) -> Tuple[float, float]:
        """z-based CI for the difference of means (treatment - control)."""

        n1, n2 = len(control), len(treatment)
        if n1 < 2 or n2 < 2:
            return (0, 0)

        mean_diff = statistics.mean(treatment) - statistics.mean(control)
        var1 = statistics.variance(control)
        var2 = statistics.variance(treatment)

        se = math.sqrt(var1 / n1 + var2 / n2)
        margin = self._z_score(confidence) * se

        return (mean_diff - margin, mean_diff + margin)

    def _z_score(self, confidence: float) -> float:
        """Two-sided critical z for a confidence level (1.96 fallback)."""
        z_scores = {
            0.90: 1.645,
            0.95: 1.96,
            0.99: 2.576
        }
        return z_scores.get(confidence, 1.96)

    def _calculate_power(
        self,
        control: "VariantStats",
        treatment: "VariantStats",
        alpha: float
    ) -> float:
        """Approximate post-hoc power of the comparison."""

        if control.sample_size < 2 or treatment.sample_size < 2:
            return 0.0

        # Pooled standard deviation.
        pooled_std = math.sqrt(
            (control.std_dev ** 2 + treatment.std_dev ** 2) / 2
        )
        if pooled_std == 0:
            return 1.0

        # Effect size (Cohen's d).
        effect_size = abs(treatment.mean - control.mean) / pooled_std

        # Non-centrality parameter at the smaller group size.
        n = min(control.sample_size, treatment.sample_size)
        ncp = effect_size * math.sqrt(n / 2)

        # Fix: _z_score expects a two-sided confidence level, so the
        # critical value for significance alpha is _z_score(1 - alpha).
        # The previous argument 1 - alpha / 2 (e.g. 0.975) never matched
        # the lookup table and silently fell back to 1.96 for every alpha.
        z_alpha = self._z_score(1 - alpha)
        power = self._normal_cdf(ncp - z_alpha)

        return min(1.0, power)

    def _generate_recommendation(
        self,
        is_significant: bool,
        relative_diff: float,
        power: float,
        higher_is_better: bool
    ) -> str:
        """Turn the analysis into a human-readable recommendation."""

        if not is_significant:
            if power < 0.8:
                return "Inconclusive - need more data for statistical power"
            else:
                return "No significant difference detected"

        if higher_is_better:
            if relative_diff > 0:
                return f"Treatment is {relative_diff:.1f}% better - recommend rollout"
            else:
                return f"Treatment is {abs(relative_diff):.1f}% worse - recommend keeping control"
        else:
            if relative_diff < 0:
                return f"Treatment is {abs(relative_diff):.1f}% better - recommend rollout"
            else:
                return f"Treatment is {relative_diff:.1f}% worse - recommend keeping control"
                return f"Treatment is {relative_diff:.1f}% worse - recommend keeping control"

Feature Flags
Feature flags enable gradual rollouts and instant rollbacks. Here's how to integrate them with experiments:
1"""
2Feature flag integration for experiments.
3"""
4
5from dataclasses import dataclass, field
6from datetime import datetime
7from enum import Enum
8from typing import Any, Callable, Dict, List, Optional
9
10
class RolloutStage(Enum):
    """Progressive stages of a feature rollout."""

    OFF = "off"            # feature disabled for everyone
    INTERNAL = "internal"  # internal testing
    CANARY = "canary"      # small percentage
    BETA = "beta"          # beta users
    GRADUAL = "gradual"    # percentage rollout
    FULL = "full"          # 100% rollout
19
20
@dataclass
class FeatureFlag:
    """A feature flag configuration.

    A flag can enable a feature per-user, per-segment, by percentage,
    or via a linked experiment's allocation.
    """

    id: str
    name: str
    description: str
    stage: RolloutStage = RolloutStage.OFF
    # Share of users (0-100) receiving the feature in percentage stages.
    percentage: float = 0.0
    allowed_users: List[str] = field(default_factory=list)
    allowed_segments: List[str] = field(default_factory=list)
    config: Dict[str, Any] = field(default_factory=dict)
    experiment_id: Optional[str] = None  # link to an A/B experiment

    def is_enabled_for(
        self,
        user_id: str,
        user_segments: Optional[List[str]] = None,  # fix: was List[str] = None
        allocator: Optional["TrafficAllocator"] = None
    ) -> bool:
        """Check if the flag is enabled for a user.

        Evaluation order: OFF (always false) -> FULL (always true) ->
        user allow-list -> segment allow-list -> percentage-based stages.
        Note INTERNAL and BETA only enable via the allow-lists; they do
        not fall through to the percentage check.
        """

        if self.stage == RolloutStage.OFF:
            return False

        if self.stage == RolloutStage.FULL:
            return True

        # Explicitly allow-listed users always get the feature.
        if user_id in self.allowed_users:
            return True

        # Segment allow-list: any overlap enables the flag.
        if user_segments:
            if set(user_segments) & set(self.allowed_segments):
                return True

        # Percentage rollout. Prefer the linked experiment's allocation
        # so flag state and experiment assignment stay consistent.
        if self.stage in (RolloutStage.CANARY, RolloutStage.GRADUAL):
            if allocator and self.experiment_id:
                allocation = allocator.allocate(user_id, self.experiment_id)
                return allocation.is_in_experiment and allocation.variant_id != "control"
            else:
                # Hash bucketing keeps the per-user decision stable.
                bucket = self._get_bucket(user_id)
                return bucket < self.percentage

        return False

    def _get_bucket(self, user_id: str) -> float:
        """Deterministically map a user to a bucket in [0, 100)."""
        import hashlib
        digest = hashlib.sha256(
            f"{self.id}:{user_id}".encode()
        ).digest()
        return (int.from_bytes(digest[:4], "big") % 10000) / 100.0
76
77
class FeatureFlagManager:
    """Registry and evaluation entry point for feature flags.

    Fixes: parameters defaulting to None are annotated Optional, and
    project types are forward-reference strings so this class does not
    depend on sibling definition order.
    """

    def __init__(self, allocator: Optional["TrafficAllocator"] = None):
        self.flags: "Dict[str, FeatureFlag]" = {}
        # Optional experiment allocator for flags linked to experiments.
        self.allocator = allocator

    def register(self, flag: "FeatureFlag"):
        """Register a feature flag."""
        self.flags[flag.id] = flag

    def is_enabled(
        self,
        flag_id: str,
        user_id: str,
        user_segments: Optional[List[str]] = None,
        default: bool = False
    ) -> bool:
        """Check if a flag is enabled; `default` is returned for unknown flags."""

        flag = self.flags.get(flag_id)
        if not flag:
            return default

        return flag.is_enabled_for(
            user_id, user_segments, self.allocator
        )

    def get_config(
        self,
        flag_id: str,
        user_id: str,
        user_segments: Optional[List[str]] = None
    ) -> Dict[str, Any]:
        """Return the flag's config for a user, or {} when disabled/unknown."""

        flag = self.flags.get(flag_id)
        if not flag:
            return {}

        if flag.is_enabled_for(user_id, user_segments, self.allocator):
            # When linked to an experiment, overlay the variant config
            # on top of the flag's own config (variant keys win).
            if flag.experiment_id and self.allocator:
                allocation = self.allocator.allocate(
                    user_id, flag.experiment_id
                )
                return {**flag.config, **allocation.config}
            return flag.config

        return {}

    def set_stage(self, flag_id: str, stage: "RolloutStage", percentage: float = 0.0):
        """Change a flag's rollout stage; unknown flag ids are ignored."""

        flag = self.flags.get(flag_id)
        if flag:
            flag.stage = stage
            flag.percentage = percentage

    def rollback(self, flag_id: str):
        """Instantly disable a flag."""
        self.set_stage(flag_id, RolloutStage.OFF)
140
141
142class GradualRollout:
143 """Manages gradual feature rollout."""
144
145 def __init__(
146 self,
147 flag_manager: FeatureFlagManager,
148 experiment_analyzer: ExperimentAnalyzer
149 ):
150 self.flag_manager = flag_manager
151 self.analyzer = experiment_analyzer
152 self.rollout_schedule: Dict[str, List[Tuple[datetime, float]]] = {}
153
154 def start_rollout(
155 self,
156 flag_id: str,
157 schedule: List[Tuple[int, float]] # (hours_from_start, percentage)
158 ):
159 """Start a scheduled gradual rollout."""
160
161 start = datetime.utcnow()
162 self.rollout_schedule[flag_id] = [
163 (start + timedelta(hours=hours), pct)
164 for hours, pct in schedule
165 ]
166
167 # Set initial stage
168 self.flag_manager.set_stage(
169 flag_id,
170 RolloutStage.GRADUAL,
171 schedule[0][1] if schedule else 0
172 )
173
174 async def check_and_advance(
175 self,
176 flag_id: str,
177 control_data: List[float],
178 treatment_data: List[float]
179 ) -> Dict[str, Any]:
180 """Check metrics and potentially advance rollout."""
181
182 flag = self.flag_manager.flags.get(flag_id)
183 if not flag:
184 return {"status": "flag_not_found"}
185
186 # Check current schedule
187 schedule = self.rollout_schedule.get(flag_id, [])
188 now = datetime.utcnow()
189
190 # Find next stage
191 next_stage = None
192 for scheduled_time, percentage in schedule:
193 if scheduled_time <= now and percentage > flag.percentage:
194 next_stage = percentage
195
196 if not next_stage:
197 return {"status": "no_stage_change"}
198
199 # Analyze current metrics
200 if flag.experiment_id:
201 experiment = self.flag_manager.allocator.experiments.get(
202 flag.experiment_id
203 )
204 if experiment:
205 result = self.analyzer.analyze(
206 experiment, control_data, treatment_data,
207 experiment.metrics[0].name
208 )
209
210 # Check for degradation
211 if result.is_significant and result.relative_difference < -10:
212 # Rollback!
213 self.flag_manager.rollback(flag_id)
214 return {
215 "status": "rollback",
216 "reason": f"Significant degradation: {result.relative_difference:.1f}%"
217 }
218
219 # Advance rollout
220 self.flag_manager.set_stage(flag_id, RolloutStage.GRADUAL, next_stage)
221
222 return {
223 "status": "advanced",
224 "new_percentage": next_stage
        }

Multi-Arm Bandits
Multi-arm bandits automatically optimize traffic allocation by directing more users to better-performing variants:
1"""
2Multi-arm bandit algorithms for adaptive experiments.
3"""
4
5import math
6import random
7from abc import ABC, abstractmethod
8from dataclasses import dataclass, field
9from typing import Dict, List, Optional, Tuple
10
11
@dataclass
class ArmStats:
    """Running statistics for a single bandit arm."""

    arm_id: str
    pulls: int = 0
    total_reward: float = 0.0
    successes: int = 0
    failures: int = 0

    @property
    def mean_reward(self) -> float:
        """Average reward per pull (0.0 before the first pull)."""
        if self.pulls > 0:
            return self.total_reward / self.pulls
        return 0.0

    @property
    def success_rate(self) -> float:
        """Fraction of successful outcomes (0.5 prior when unobserved)."""
        observed = self.successes + self.failures
        if observed > 0:
            return self.successes / observed
        return 0.5
29
30
class BanditAlgorithm(ABC):
    """Interface shared by all bandit allocation strategies."""

    @abstractmethod
    def select_arm(self, arms: Dict[str, ArmStats]) -> str:
        """Select an arm to pull."""
        pass

    @abstractmethod
    def update(self, arm_id: str, reward: float):
        """Update arm statistics after an observation."""
        pass
43
44
class EpsilonGreedy(BanditAlgorithm):
    """Epsilon-greedy bandit: explore uniformly at random with
    probability epsilon, otherwise exploit the best-known arm."""

    def __init__(self, epsilon: float = 0.1):
        self.epsilon = epsilon
        self.arms: Dict[str, ArmStats] = {}

    def add_arm(self, arm_id: str):
        """Register a new arm with empty statistics."""
        self.arms[arm_id] = ArmStats(arm_id=arm_id)

    def select_arm(self, arms: Dict[str, ArmStats] = None) -> str:
        """Pick an arm: random with probability epsilon, best mean otherwise."""
        candidates = arms or self.arms

        if not candidates:
            raise ValueError("No arms available")

        # Exploration branch.
        if random.random() < self.epsilon:
            return random.choice(list(candidates.keys()))

        # Exploitation branch: highest observed mean reward.
        return max(candidates, key=lambda arm: candidates[arm].mean_reward)

    def update(self, arm_id: str, reward: float):
        """Record one observed reward for an arm."""
        stats = self.arms.get(arm_id)
        if stats is not None:
            stats.pulls += 1
            stats.total_reward += reward
76
77
class UCB1(BanditAlgorithm):
    """Upper Confidence Bound (UCB1) bandit.

    Scores each arm by mean reward plus an exploration bonus that
    shrinks as the arm accumulates pulls.
    """

    def __init__(self, exploration_factor: float = 2.0):
        self.exploration_factor = exploration_factor
        self.arms: Dict[str, ArmStats] = {}
        self.total_pulls = 0

    def add_arm(self, arm_id: str):
        """Register a new arm with empty statistics."""
        self.arms[arm_id] = ArmStats(arm_id=arm_id)

    def select_arm(self, arms: Dict[str, ArmStats] = None) -> str:
        """Select the arm with the highest upper confidence bound.

        Raises:
            ValueError: if no arms are available.
        """
        arms = arms or self.arms

        if not arms:
            raise ValueError("No arms available")

        # Every arm must be sampled once before UCB scores are defined.
        for arm_id, stats in arms.items():
            if stats.pulls == 0:
                return arm_id

        # Fix: derive the total from the arms actually being scored.
        # self.total_pulls only counts updates made through this
        # instance, so math.log(self.total_pulls) could be log(0) when
        # the caller supplies externally maintained stats. For the
        # normal internal-arms path the two counts are identical.
        total_pulls = sum(stats.pulls for stats in arms.values())

        # UCB score = exploitation term + exploration bonus.
        ucb_values = {}
        for arm_id, stats in arms.items():
            exploitation = stats.mean_reward
            exploration = math.sqrt(
                self.exploration_factor * math.log(total_pulls) / stats.pulls
            )
            ucb_values[arm_id] = exploitation + exploration

        return max(ucb_values, key=ucb_values.get)

    def update(self, arm_id: str, reward: float):
        """Record one observed reward for an arm."""
        if arm_id in self.arms:
            self.arms[arm_id].pulls += 1
            self.arms[arm_id].total_reward += reward
            self.total_pulls += 1
119
120
class ThompsonSampling(BanditAlgorithm):
    """Thompson Sampling for Bernoulli (success/failure) rewards."""

    def __init__(self, prior_alpha: float = 1.0, prior_beta: float = 1.0):
        # Beta(alpha, beta) prior over each arm's success probability.
        self.prior_alpha = prior_alpha
        self.prior_beta = prior_beta
        self.arms: Dict[str, ArmStats] = {}

    def add_arm(self, arm_id: str):
        """Register a new arm with empty statistics."""
        self.arms[arm_id] = ArmStats(arm_id=arm_id)

    def select_arm(self, arms: Dict[str, ArmStats] = None) -> str:
        """Sample each arm's posterior and pick the largest draw."""
        candidates = arms or self.arms

        if not candidates:
            raise ValueError("No arms available")

        # Posterior per arm: Beta(prior_alpha + successes, prior_beta + failures).
        draws = {
            arm_id: random.betavariate(
                self.prior_alpha + stats.successes,
                self.prior_beta + stats.failures,
            )
            for arm_id, stats in candidates.items()
        }
        return max(draws, key=draws.get)

    def update(self, arm_id: str, reward: float):
        """Record an outcome; rewards above 0.5 count as successes."""
        stats = self.arms.get(arm_id)
        if stats is None:
            return
        stats.pulls += 1
        if reward > 0.5:
            stats.successes += 1
        else:
            stats.failures += 1
        stats.total_reward += reward
158
159
160class AdaptiveExperiment:
161 """Experiment using multi-arm bandit for adaptive allocation."""
162
163 def __init__(
164 self,
165 experiment: Experiment,
166 algorithm: BanditAlgorithm
167 ):
168 self.experiment = experiment
169 self.algorithm = algorithm
170
171 # Initialize arms
172 for variant in experiment.variants:
173 self.algorithm.add_arm(variant.id)
174
175 def allocate(self, user_id: str) -> Variant:
176 """Allocate user using bandit algorithm."""
177 selected_arm = self.algorithm.select_arm(self.algorithm.arms)
178
179 variant = next(
180 (v for v in self.experiment.variants if v.id == selected_arm),
181 self.experiment.variants[0]
182 )
183
184 return variant
185
186 def record_outcome(self, variant_id: str, success: bool, reward: float = None):
187 """Record experiment outcome."""
188 if reward is None:
189 reward = 1.0 if success else 0.0
190
191 self.algorithm.update(variant_id, reward)
192
193 def get_allocation_percentages(self) -> Dict[str, float]:
194 """Get current implied allocation percentages."""
195 total_pulls = sum(
196 self.algorithm.arms[arm_id].pulls
197 for arm_id in self.algorithm.arms
198 )
199
200 if total_pulls == 0:
201 n = len(self.algorithm.arms)
202 return {arm_id: 100.0/n for arm_id in self.algorithm.arms}
203
204 return {
205 arm_id: (stats.pulls / total_pulls) * 100
206 for arm_id, stats in self.algorithm.arms.items()
        }

Summary
This section covered the key components of A/B testing for AI agents:
| Topic | Key Concepts |
|---|---|
| Experiment Design | Variants, metrics, sample size, guardrails |
| Traffic Splitting | Hash-based allocation, stickiness, overrides |
| Statistical Analysis | T-tests, confidence intervals, power analysis |
| Feature Flags | Gradual rollout, instant rollback, segments |
| Multi-Arm Bandits | Epsilon-greedy, UCB1, Thompson Sampling |
Key Takeaways: A/B testing provides real-world validation of agent improvements. Use proper statistical analysis to avoid false conclusions, and consider multi-arm bandits for continuous optimization.
In the next section, we'll explore continuous evaluation systems that automatically track agent performance over time.