Chapter 20: Evaluation and Benchmarking

A/B Testing and Experiments

Introduction

A/B testing allows you to compare agent variants in production with real users and real tasks. Unlike offline benchmarks, experiments capture the full complexity of real-world usage and reveal issues that only emerge at scale.

Why Experiment? Offline benchmarks tell you how an agent performs on test data. Experiments tell you how it performs with real users, real edge cases, and real business outcomes.

This section covers how to design experiments for AI agents, implement traffic splitting, analyze results statistically, and use advanced techniques like multi-arm bandits for continuous optimization.


Experiment Design

Well-designed experiments isolate the variable being tested and control for confounding factors. Here's a framework for designing agent experiments:

🐍python
"""
Experiment design framework for AI agents.
"""

from dataclasses import dataclass, field
from datetime import datetime
from enum import Enum
from typing import Any, Dict, List, Optional, Set
import uuid


class ExperimentStatus(Enum):
    """Status of an experiment."""
    DRAFT = "draft"
    RUNNING = "running"
    PAUSED = "paused"
    COMPLETED = "completed"
    ABORTED = "aborted"


@dataclass
class Variant:
    """A variant in an experiment."""
    id: str
    name: str
    description: str
    config: Dict[str, Any]
    traffic_percentage: float = 0.0

    def to_dict(self) -> Dict[str, Any]:
        return {
            "id": self.id,
            "name": self.name,
            "description": self.description,
            "config": self.config,
            "traffic_percentage": self.traffic_percentage
        }


@dataclass
class ExperimentMetric:
    """Metric to track in an experiment."""
    name: str
    description: str
    is_primary: bool = False
    higher_is_better: bool = True
    minimum_detectable_effect: float = 0.05  # 5% relative change


@dataclass
class Experiment:
    """An A/B experiment for agent testing."""
    id: str
    name: str
    description: str
    variants: List[Variant]
    metrics: List[ExperimentMetric]
    status: ExperimentStatus = ExperimentStatus.DRAFT
    start_time: Optional[datetime] = None
    end_time: Optional[datetime] = None
    target_sample_size: int = 1000
    user_segments: Set[str] = field(default_factory=set)

    # Guardrails
    min_runtime_hours: int = 24
    max_runtime_days: int = 14
    auto_stop_on_degradation: bool = True
    degradation_threshold: float = 0.10  # 10% worse

    @property
    def control_variant(self) -> Optional[Variant]:
        """Get the control variant (by id or name, falling back to the first variant)."""
        for variant in self.variants:
            if variant.id == "control" or "control" in variant.name.lower():
                return variant
        return self.variants[0] if self.variants else None

    @property
    def treatment_variants(self) -> List[Variant]:
        """Get treatment variants."""
        control = self.control_variant
        return [v for v in self.variants if v != control]

    def validate(self) -> List[str]:
        """Validate experiment configuration."""
        errors = []

        # Check variant traffic sums to 100%
        total_traffic = sum(v.traffic_percentage for v in self.variants)
        if abs(total_traffic - 100.0) > 0.01:
            errors.append(f"Traffic percentages sum to {total_traffic}%, not 100%")

        # Check for primary metric
        primary_metrics = [m for m in self.metrics if m.is_primary]
        if not primary_metrics:
            errors.append("No primary metric defined")

        # Check variants
        if len(self.variants) < 2:
            errors.append("At least 2 variants required")

        # Check unique IDs
        ids = [v.id for v in self.variants]
        if len(ids) != len(set(ids)):
            errors.append("Variant IDs must be unique")

        return errors

    def to_dict(self) -> Dict[str, Any]:
        return {
            "id": self.id,
            "name": self.name,
            "description": self.description,
            "status": self.status.value,
            "variants": [v.to_dict() for v in self.variants],
            "metrics": [
                {
                    "name": m.name,
                    "is_primary": m.is_primary,
                    "higher_is_better": m.higher_is_better
                }
                for m in self.metrics
            ],
            "start_time": self.start_time.isoformat() if self.start_time else None,
            "end_time": self.end_time.isoformat() if self.end_time else None,
            "target_sample_size": self.target_sample_size
        }


class ExperimentBuilder:
    """Builder for creating experiments."""

    def __init__(self, name: str):
        self.experiment = Experiment(
            id=str(uuid.uuid4())[:8],
            name=name,
            description="",
            variants=[],
            metrics=[]
        )

    def with_description(self, description: str) -> "ExperimentBuilder":
        self.experiment.description = description
        return self

    def add_control(
        self,
        name: str,
        config: Dict[str, Any],
        traffic: float = 50.0
    ) -> "ExperimentBuilder":
        """Add control variant."""
        self.experiment.variants.append(Variant(
            id="control",
            name=name,
            description="Control variant",
            config=config,
            traffic_percentage=traffic
        ))
        return self

    def add_treatment(
        self,
        name: str,
        config: Dict[str, Any],
        traffic: float = 50.0
    ) -> "ExperimentBuilder":
        """Add treatment variant."""
        variant_id = f"treatment_{len(self.experiment.treatment_variants)}"
        self.experiment.variants.append(Variant(
            id=variant_id,
            name=name,
            description=f"Treatment: {name}",
            config=config,
            traffic_percentage=traffic
        ))
        return self

    def add_metric(
        self,
        name: str,
        description: str = "",
        is_primary: bool = False,
        higher_is_better: bool = True
    ) -> "ExperimentBuilder":
        """Add a metric to track."""
        self.experiment.metrics.append(ExperimentMetric(
            name=name,
            description=description,
            is_primary=is_primary,
            higher_is_better=higher_is_better
        ))
        return self

    def with_sample_size(self, size: int) -> "ExperimentBuilder":
        self.experiment.target_sample_size = size
        return self

    def with_segments(self, segments: List[str]) -> "ExperimentBuilder":
        self.experiment.user_segments = set(segments)
        return self

    def build(self) -> Experiment:
        """Build and validate the experiment."""
        errors = self.experiment.validate()
        if errors:
            raise ValueError(f"Invalid experiment: {errors}")
        return self.experiment


# Example: Create an experiment comparing agent versions
def create_agent_version_experiment() -> Experiment:
    return (
        ExperimentBuilder("Agent V2 Rollout")
        .with_description("Compare new agent v2 against v1")
        .add_control(
            name="Agent V1 (Current)",
            config={"agent_version": "1.0", "model": "gpt-4"},
            traffic=50.0
        )
        .add_treatment(
            name="Agent V2 (New)",
            config={"agent_version": "2.0", "model": "gpt-4"},
            traffic=50.0
        )
        .add_metric("task_success_rate", is_primary=True, higher_is_better=True)
        .add_metric("avg_latency_ms", higher_is_better=False)
        .add_metric("user_satisfaction", higher_is_better=True)
        .add_metric("cost_per_task", higher_is_better=False)
        .with_sample_size(2000)
        .build()
    )
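
The `target_sample_size` and `minimum_detectable_effect` values above should not be picked arbitrarily. For a binary metric such as `task_success_rate`, the standard two-proportion power calculation gives a rough per-variant sample size. This `required_sample_size` helper is an illustrative sketch, not part of the framework; its defaults of 5% two-sided significance and 80% power are assumptions:

```python
import math


def required_sample_size(
    baseline_rate: float,
    relative_mde: float,
    z_alpha: float = 1.96,   # two-sided alpha = 0.05
    z_beta: float = 0.8416,  # power = 0.80
) -> int:
    """Approximate per-variant n to detect a relative lift in a proportion."""
    p1 = baseline_rate
    p2 = baseline_rate * (1 + relative_mde)
    p_bar = (p1 + p2) / 2
    numerator = (
        z_alpha * math.sqrt(2 * p_bar * (1 - p_bar))
        + z_beta * math.sqrt(p1 * (1 - p1) + p2 * (1 - p2))
    ) ** 2
    return math.ceil(numerator / (p1 - p2) ** 2)


# A 5% relative lift on a 70% success rate needs a few thousand users per
# variant; a 10% lift needs far fewer.
n_small = required_sample_size(0.70, 0.05)
n_large = required_sample_size(0.70, 0.10)
```

The required sample size grows quadratically as the detectable effect shrinks, which is why the metric default is a relatively coarse 5%.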

Traffic Splitting

Traffic splitting assigns users to experiment variants consistently. Here's how to implement robust traffic allocation:

🐍python
"""
Traffic splitting for A/B experiments.

Uses Experiment and ExperimentStatus from the experiment design code above.
"""

import hashlib
from dataclasses import dataclass
from typing import Any, Dict, List, Optional


@dataclass
class AllocationResult:
    """Result of traffic allocation."""
    experiment_id: str
    variant_id: str
    variant_name: str
    config: Dict[str, Any]
    is_in_experiment: bool


class TrafficAllocator:
    """Allocates traffic to experiment variants."""

    def __init__(self, salt: str = "experiment_salt"):
        self.salt = salt
        self.experiments: Dict[str, Experiment] = {}
        self.overrides: Dict[str, Dict[str, str]] = {}  # user_id -> experiment -> variant

    def register_experiment(self, experiment: Experiment):
        """Register an experiment."""
        self.experiments[experiment.id] = experiment

    def add_override(
        self,
        user_id: str,
        experiment_id: str,
        variant_id: str
    ):
        """Add a user override for testing."""
        if user_id not in self.overrides:
            self.overrides[user_id] = {}
        self.overrides[user_id][experiment_id] = variant_id

    def allocate(
        self,
        user_id: str,
        experiment_id: str,
        user_segments: Optional[List[str]] = None
    ) -> AllocationResult:
        """Allocate a user to a variant."""

        experiment = self.experiments.get(experiment_id)

        if not experiment:
            return AllocationResult(
                experiment_id=experiment_id,
                variant_id="",
                variant_name="",
                config={},
                is_in_experiment=False
            )

        # Check if experiment is running
        if experiment.status != ExperimentStatus.RUNNING:
            return self._default_allocation(experiment)

        # Check for override
        if user_id in self.overrides:
            override_variant = self.overrides[user_id].get(experiment_id)
            if override_variant:
                variant = next(
                    (v for v in experiment.variants if v.id == override_variant),
                    None
                )
                if variant:
                    return AllocationResult(
                        experiment_id=experiment_id,
                        variant_id=variant.id,
                        variant_name=variant.name,
                        config=variant.config,
                        is_in_experiment=True
                    )

        # Check segment eligibility
        if experiment.user_segments:
            user_segment_set = set(user_segments or [])
            if not (user_segment_set & experiment.user_segments):
                return self._default_allocation(experiment)

        # Hash-based allocation for consistency
        variant = self._hash_allocate(user_id, experiment)

        return AllocationResult(
            experiment_id=experiment_id,
            variant_id=variant.id,
            variant_name=variant.name,
            config=variant.config,
            is_in_experiment=True
        )

    def _hash_allocate(
        self,
        user_id: str,
        experiment: Experiment
    ) -> Variant:
        """Allocate using consistent hashing."""

        # Create hash from salt + experiment_id + user_id
        hash_input = f"{self.salt}:{experiment.id}:{user_id}"
        hash_bytes = hashlib.sha256(hash_input.encode()).digest()

        # Convert to bucket (0-99.99)
        bucket = (int.from_bytes(hash_bytes[:4], "big") % 10000) / 100.0

        # Find variant for bucket
        cumulative = 0.0
        for variant in experiment.variants:
            cumulative += variant.traffic_percentage
            if bucket < cumulative:
                return variant

        # Fallback to last variant
        return experiment.variants[-1]

    def _default_allocation(self, experiment: Experiment) -> AllocationResult:
        """Return default (control) allocation."""
        control = experiment.control_variant
        if control:
            return AllocationResult(
                experiment_id=experiment.id,
                variant_id=control.id,
                variant_name=control.name,
                config=control.config,
                is_in_experiment=False
            )
        return AllocationResult(
            experiment_id=experiment.id,
            variant_id="",
            variant_name="",
            config={},
            is_in_experiment=False
        )


class StickyAllocation:
    """Maintains sticky allocation across sessions."""

    def __init__(self, storage, allocator: TrafficAllocator):
        self.storage = storage  # Redis, database, etc.
        self.allocator = allocator

    async def get_allocation(
        self,
        user_id: str,
        experiment_id: str,
        user_segments: Optional[List[str]] = None
    ) -> AllocationResult:
        """Get or create sticky allocation."""

        # Check for existing allocation
        key = f"allocation:{experiment_id}:{user_id}"
        existing = await self.storage.get(key)

        if existing:
            experiment = self.allocator.experiments.get(experiment_id)
            if experiment:
                variant = next(
                    (v for v in experiment.variants if v.id == existing),
                    None
                )
                if variant:
                    return AllocationResult(
                        experiment_id=experiment_id,
                        variant_id=variant.id,
                        variant_name=variant.name,
                        config=variant.config,
                        is_in_experiment=True
                    )

        # Create new allocation
        allocation = self.allocator.allocate(
            user_id, experiment_id, user_segments
        )

        # Store for stickiness
        if allocation.is_in_experiment:
            await self.storage.set(key, allocation.variant_id)

        return allocation
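
The deterministic bucketing behind `_hash_allocate` can be checked in isolation. This standalone sketch (the salt and experiment id are arbitrary) shows that a user always lands in the same bucket and that a 50% cutoff comes out close to even across many users:

```python
import hashlib


def bucket(salt: str, experiment_id: str, user_id: str) -> float:
    """Map a user to a stable bucket in [0, 100), as TrafficAllocator does."""
    digest = hashlib.sha256(f"{salt}:{experiment_id}:{user_id}".encode()).digest()
    return (int.from_bytes(digest[:4], "big") % 10000) / 100.0


# Same inputs always produce the same bucket, so allocation is sticky
# even without shared state.
assert bucket("experiment_salt", "exp1", "user_42") == bucket("experiment_salt", "exp1", "user_42")

# Across many users the buckets are roughly uniform, so a 50% cutoff
# yields close to a 50/50 split.
in_treatment = sum(
    1 for i in range(10_000)
    if bucket("experiment_salt", "exp1", f"user_{i}") >= 50.0
)
```

Because the experiment id is part of the hash input, the same user lands in independent buckets in different experiments, which avoids correlated assignments across concurrent tests.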

Statistical Analysis

Proper statistical analysis determines whether observed differences are significant or due to chance:

🐍python
"""
Statistical analysis for A/B experiments.

Uses Experiment from the experiment design code above.
"""

import math
import statistics
from dataclasses import dataclass
from typing import List, Tuple


@dataclass
class VariantStats:
    """Statistics for a variant."""
    variant_id: str
    sample_size: int
    mean: float
    std_dev: float
    median: float
    p25: float
    p75: float
    p95: float


@dataclass
class ExperimentResult:
    """Statistical analysis result."""
    experiment_id: str
    metric_name: str
    control_stats: VariantStats
    treatment_stats: VariantStats
    absolute_difference: float
    relative_difference: float
    p_value: float
    confidence_interval: Tuple[float, float]
    is_significant: bool
    power: float
    recommendation: str


class ExperimentAnalyzer:
    """Analyzes experiment results statistically."""

    def __init__(self, significance_level: float = 0.05):
        self.significance_level = significance_level

    def analyze(
        self,
        experiment: Experiment,
        control_values: List[float],
        treatment_values: List[float],
        metric_name: str
    ) -> ExperimentResult:
        """Perform statistical analysis on experiment data."""

        # Calculate basic statistics
        control_stats = self._calculate_stats("control", control_values)
        treatment_stats = self._calculate_stats("treatment", treatment_values)

        # Calculate differences
        absolute_diff = treatment_stats.mean - control_stats.mean
        relative_diff = (
            absolute_diff / control_stats.mean * 100
            if control_stats.mean != 0 else 0
        )

        # Perform t-test
        p_value = self._welch_t_test(control_values, treatment_values)

        # Calculate confidence interval
        ci = self._confidence_interval(
            control_values, treatment_values,
            1 - self.significance_level
        )

        # Determine significance
        is_significant = p_value < self.significance_level

        # Calculate power
        power = self._calculate_power(
            control_stats, treatment_stats,
            self.significance_level
        )

        # Generate recommendation
        metric = next(
            (m for m in experiment.metrics if m.name == metric_name),
            None
        )
        recommendation = self._generate_recommendation(
            is_significant, relative_diff, power,
            metric.higher_is_better if metric else True
        )

        return ExperimentResult(
            experiment_id=experiment.id,
            metric_name=metric_name,
            control_stats=control_stats,
            treatment_stats=treatment_stats,
            absolute_difference=absolute_diff,
            relative_difference=relative_diff,
            p_value=p_value,
            confidence_interval=ci,
            is_significant=is_significant,
            power=power,
            recommendation=recommendation
        )

    def _calculate_stats(
        self,
        variant_id: str,
        values: List[float]
    ) -> VariantStats:
        """Calculate descriptive statistics."""

        if not values:
            return VariantStats(
                variant_id=variant_id,
                sample_size=0,
                mean=0, std_dev=0, median=0,
                p25=0, p75=0, p95=0
            )

        sorted_values = sorted(values)
        n = len(values)

        return VariantStats(
            variant_id=variant_id,
            sample_size=n,
            mean=statistics.mean(values),
            std_dev=statistics.stdev(values) if n > 1 else 0,
            median=statistics.median(values),
            p25=sorted_values[int(n * 0.25)],
            p75=sorted_values[int(n * 0.75)],
            p95=sorted_values[int(n * 0.95)]
        )

    def _welch_t_test(
        self,
        control: List[float],
        treatment: List[float]
    ) -> float:
        """Perform Welch's t-test and return a two-sided p-value."""

        n1, n2 = len(control), len(treatment)
        if n1 < 2 or n2 < 2:
            return 1.0

        mean1 = statistics.mean(control)
        mean2 = statistics.mean(treatment)
        var1 = statistics.variance(control)
        var2 = statistics.variance(treatment)

        # Welch's t-statistic
        se = math.sqrt(var1 / n1 + var2 / n2)
        if se == 0:
            return 1.0

        t_stat = abs(mean2 - mean1) / se

        # Welch-Satterthwaite degrees of freedom
        num = (var1 / n1 + var2 / n2) ** 2
        denom = (var1 / n1) ** 2 / (n1 - 1) + (var2 / n2) ** 2 / (n2 - 1)
        df = num / denom if denom > 0 else 1

        # Approximate p-value; the normal approximation is accurate for large df
        if df > 30:
            p_value = 2 * (1 - self._normal_cdf(t_stat))
        else:
            p_value = 2 * (1 - self._t_cdf(t_stat, df))

        return p_value

    def _normal_cdf(self, x: float) -> float:
        """Standard normal CDF."""
        return 0.5 * (1 + math.erf(x / math.sqrt(2)))

    def _t_cdf(self, t: float, df: float) -> float:
        """Student's t CDF approximation.

        Uses the normal CDF for simplicity; this understates the tails
        (and so the p-value) for small df, so treat small-sample results
        as optimistic.
        """
        return self._normal_cdf(t)

    def _confidence_interval(
        self,
        control: List[float],
        treatment: List[float],
        confidence: float
    ) -> Tuple[float, float]:
        """Calculate confidence interval for the difference in means."""

        n1, n2 = len(control), len(treatment)
        if n1 < 2 or n2 < 2:
            return (0, 0)

        mean_diff = statistics.mean(treatment) - statistics.mean(control)
        var1 = statistics.variance(control)
        var2 = statistics.variance(treatment)

        se = math.sqrt(var1 / n1 + var2 / n2)

        # Z-score for confidence level
        z = self._z_score(confidence)

        margin = z * se

        return (mean_diff - margin, mean_diff + margin)

    def _z_score(self, confidence: float) -> float:
        """Get z-score for confidence level."""
        # Common values
        z_scores = {
            0.90: 1.645,
            0.95: 1.96,
            0.99: 2.576
        }
        return z_scores.get(confidence, 1.96)

    def _calculate_power(
        self,
        control: VariantStats,
        treatment: VariantStats,
        alpha: float
    ) -> float:
        """Calculate statistical power."""

        if control.sample_size < 2 or treatment.sample_size < 2:
            return 0.0

        # Pooled standard deviation
        pooled_std = math.sqrt(
            (control.std_dev ** 2 + treatment.std_dev ** 2) / 2
        )

        if pooled_std == 0:
            return 1.0

        # Effect size (Cohen's d)
        effect_size = abs(treatment.mean - control.mean) / pooled_std

        # Approximate power calculation
        n = min(control.sample_size, treatment.sample_size)
        ncp = effect_size * math.sqrt(n / 2)  # Non-centrality parameter

        z_alpha = self._z_score(1 - alpha / 2)
        power = self._normal_cdf(ncp - z_alpha)

        return min(1.0, power)

    def _generate_recommendation(
        self,
        is_significant: bool,
        relative_diff: float,
        power: float,
        higher_is_better: bool
    ) -> str:
        """Generate recommendation based on analysis."""

        if not is_significant:
            if power < 0.8:
                return "Inconclusive - need more data for statistical power"
            else:
                return "No significant difference detected"

        if higher_is_better:
            if relative_diff > 0:
                return f"Treatment is {relative_diff:.1f}% better - recommend rollout"
            else:
                return f"Treatment is {abs(relative_diff):.1f}% worse - recommend keeping control"
        else:
            if relative_diff < 0:
                return f"Treatment is {abs(relative_diff):.1f}% better - recommend rollout"
            else:
                return f"Treatment is {relative_diff:.1f}% worse - recommend keeping control"
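
To see the analyzer's core math in action without the full `Experiment` machinery, the Welch p-value computation can be run on synthetic data. The 0.70 vs 0.74 scores, the 0.10 noise level, and the sample size of 500 below are made up for illustration; as in `_welch_t_test`, the normal approximation to the CDF is used for large samples:

```python
import math
import random
import statistics


def welch_p_value(control: list, treatment: list) -> float:
    """Two-sided Welch p-value via the normal approximation (large samples)."""
    n1, n2 = len(control), len(treatment)
    se = math.sqrt(
        statistics.variance(control) / n1 + statistics.variance(treatment) / n2
    )
    t_stat = abs(statistics.mean(treatment) - statistics.mean(control)) / se
    return 2 * (1 - 0.5 * (1 + math.erf(t_stat / math.sqrt(2))))


random.seed(0)
control = [random.gauss(0.70, 0.10) for _ in range(500)]
treatment = [random.gauss(0.74, 0.10) for _ in range(500)]

# A 4-point lift at this noise level and sample size is highly significant.
p = welch_p_value(control, treatment)
```

Comparing a sample against itself drives the t-statistic to zero and the p-value to 1, a quick sanity check that the two-sided formula is wired up correctly.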

Feature Flags

Feature flags enable gradual rollouts and instant rollbacks. Here's how to integrate them with experiments:

🐍python
"""
Feature flag integration for experiments.

Uses TrafficAllocator and ExperimentAnalyzer from the code above.
"""

import hashlib
from dataclasses import dataclass, field
from datetime import datetime, timedelta
from enum import Enum
from typing import Any, Dict, List, Optional, Tuple


class RolloutStage(Enum):
    """Stages of feature rollout."""
    OFF = "off"
    INTERNAL = "internal"  # Internal testing
    CANARY = "canary"  # Small percentage
    BETA = "beta"  # Beta users
    GRADUAL = "gradual"  # Percentage rollout
    FULL = "full"  # 100% rollout


@dataclass
class FeatureFlag:
    """A feature flag configuration."""
    id: str
    name: str
    description: str
    stage: RolloutStage = RolloutStage.OFF
    percentage: float = 0.0
    allowed_users: List[str] = field(default_factory=list)
    allowed_segments: List[str] = field(default_factory=list)
    config: Dict[str, Any] = field(default_factory=dict)
    experiment_id: Optional[str] = None  # Link to experiment

    def is_enabled_for(
        self,
        user_id: str,
        user_segments: Optional[List[str]] = None,
        allocator: Optional["TrafficAllocator"] = None
    ) -> bool:
        """Check if flag is enabled for a user."""

        if self.stage == RolloutStage.OFF:
            return False

        if self.stage == RolloutStage.FULL:
            return True

        # Check allowed users
        if user_id in self.allowed_users:
            return True

        # Check segments
        if user_segments:
            if set(user_segments) & set(self.allowed_segments):
                return True

        # Check percentage rollout
        if self.stage in (RolloutStage.CANARY, RolloutStage.GRADUAL):
            if allocator and self.experiment_id:
                allocation = allocator.allocate(user_id, self.experiment_id)
                return allocation.is_in_experiment and allocation.variant_id != "control"
            else:
                # Simple hash-based check
                bucket = self._get_bucket(user_id)
                return bucket < self.percentage

        return False

    def _get_bucket(self, user_id: str) -> float:
        """Get bucket for user (0-100)."""
        hash_bytes = hashlib.sha256(
            f"{self.id}:{user_id}".encode()
        ).digest()
        return (int.from_bytes(hash_bytes[:4], "big") % 10000) / 100.0


class FeatureFlagManager:
    """Manages feature flags."""

    def __init__(self, allocator: Optional[TrafficAllocator] = None):
        self.flags: Dict[str, FeatureFlag] = {}
        self.allocator = allocator

    def register(self, flag: FeatureFlag):
        """Register a feature flag."""
        self.flags[flag.id] = flag

    def is_enabled(
        self,
        flag_id: str,
        user_id: str,
        user_segments: Optional[List[str]] = None,
        default: bool = False
    ) -> bool:
        """Check if a flag is enabled."""

        flag = self.flags.get(flag_id)
        if not flag:
            return default

        return flag.is_enabled_for(
            user_id, user_segments, self.allocator
        )

    def get_config(
        self,
        flag_id: str,
        user_id: str,
        user_segments: Optional[List[str]] = None
    ) -> Dict[str, Any]:
        """Get config for a flag if enabled."""

        flag = self.flags.get(flag_id)
        if not flag:
            return {}

        if flag.is_enabled_for(user_id, user_segments, self.allocator):
            # If linked to experiment, get variant config
            if flag.experiment_id and self.allocator:
                allocation = self.allocator.allocate(
                    user_id, flag.experiment_id
                )
                return {**flag.config, **allocation.config}
            return flag.config

        return {}

    def set_stage(self, flag_id: str, stage: RolloutStage, percentage: float = 0.0):
        """Change rollout stage."""

        flag = self.flags.get(flag_id)
        if flag:
            flag.stage = stage
            flag.percentage = percentage

    def rollback(self, flag_id: str):
        """Instantly disable a flag."""
        self.set_stage(flag_id, RolloutStage.OFF)


class GradualRollout:
    """Manages gradual feature rollout."""

    def __init__(
        self,
        flag_manager: FeatureFlagManager,
        experiment_analyzer: ExperimentAnalyzer
    ):
        self.flag_manager = flag_manager
        self.analyzer = experiment_analyzer
        self.rollout_schedule: Dict[str, List[Tuple[datetime, float]]] = {}

    def start_rollout(
        self,
        flag_id: str,
        schedule: List[Tuple[int, float]]  # (hours_from_start, percentage)
    ):
        """Start a scheduled gradual rollout."""

        start = datetime.utcnow()
        self.rollout_schedule[flag_id] = [
            (start + timedelta(hours=hours), pct)
            for hours, pct in schedule
        ]

        # Set initial stage
        self.flag_manager.set_stage(
            flag_id,
            RolloutStage.GRADUAL,
            schedule[0][1] if schedule else 0
        )

    async def check_and_advance(
        self,
        flag_id: str,
        control_data: List[float],
        treatment_data: List[float]
    ) -> Dict[str, Any]:
        """Check metrics and potentially advance rollout."""

        flag = self.flag_manager.flags.get(flag_id)
        if not flag:
            return {"status": "flag_not_found"}

        # Check current schedule
        schedule = self.rollout_schedule.get(flag_id, [])
        now = datetime.utcnow()

        # Find the latest due stage above the current percentage
        next_stage = None
        for scheduled_time, percentage in schedule:
            if scheduled_time <= now and percentage > flag.percentage:
                next_stage = percentage

        if next_stage is None:
            return {"status": "no_stage_change"}

        # Analyze current metrics
        if flag.experiment_id and self.flag_manager.allocator:
            experiment = self.flag_manager.allocator.experiments.get(
                flag.experiment_id
            )
            if experiment:
                result = self.analyzer.analyze(
                    experiment, control_data, treatment_data,
                    experiment.metrics[0].name
                )

                # Check for degradation
                if result.is_significant and result.relative_difference < -10:
                    # Rollback!
                    self.flag_manager.rollback(flag_id)
                    return {
                        "status": "rollback",
                        "reason": f"Significant degradation: {result.relative_difference:.1f}%"
                    }

        # Advance rollout
        self.flag_manager.set_stage(flag_id, RolloutStage.GRADUAL, next_stage)

        return {
            "status": "advanced",
            "new_percentage": next_stage
        }
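
One property of `_get_bucket` is worth calling out: because a user's bucket is fixed, raising `percentage` only ever adds users to the rollout and never drops anyone who was already enabled. A standalone sketch of that monotonicity (the flag id and user names are arbitrary):

```python
import hashlib


def in_rollout(flag_id: str, user_id: str, percentage: float) -> bool:
    """Hash-based percentage check, mirroring FeatureFlag._get_bucket."""
    digest = hashlib.sha256(f"{flag_id}:{user_id}".encode()).digest()
    bucket = (int.from_bytes(digest[:4], "big") % 10000) / 100.0
    return bucket < percentage


users = [f"user_{i}" for i in range(5_000)]
at_10 = {u for u in users if in_rollout("new_agent", u, 10.0)}
at_25 = {u for u in users if in_rollout("new_agent", u, 25.0)}

# Everyone enabled at 10% stays enabled at 25%.
assert at_10 <= at_25
```

This is what makes `GradualRollout` safe to advance in place: each schedule step widens the enabled set rather than reshuffling it.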

Multi-Arm Bandits

Multi-arm bandits automatically optimize traffic allocation by directing more users to better-performing variants:

🐍python
1"""
2Multi-arm bandit algorithms for adaptive experiments.
3"""
4
5import math
6import random
7from abc import ABC, abstractmethod
8from dataclasses import dataclass, field
9from typing import Dict, List, Optional, Tuple
10
11
12@dataclass
13class ArmStats:
14    """Statistics for a bandit arm."""
15    arm_id: str
16    pulls: int = 0
17    total_reward: float = 0.0
18    successes: int = 0
19    failures: int = 0
20
21    @property
22    def mean_reward(self) -> float:
23        return self.total_reward / self.pulls if self.pulls > 0 else 0.0
24
25    @property
26    def success_rate(self) -> float:
27        total = self.successes + self.failures
28        return self.successes / total if total > 0 else 0.5
29
30
31class BanditAlgorithm(ABC):
32    """Abstract base for bandit algorithms."""
33
34    @abstractmethod
35    def select_arm(self, arms: Optional[Dict[str, ArmStats]] = None) -> str:
36        """Select an arm to pull."""
37        pass
38
39    @abstractmethod
40    def update(self, arm_id: str, reward: float):
41        """Update arm statistics after observation."""
42        pass
43
44
45class EpsilonGreedy(BanditAlgorithm):
46    """Epsilon-greedy algorithm."""
47
48    def __init__(self, epsilon: float = 0.1):
49        self.epsilon = epsilon
50        self.arms: Dict[str, ArmStats] = {}
51
52    def add_arm(self, arm_id: str):
53        """Add a new arm."""
54        self.arms[arm_id] = ArmStats(arm_id=arm_id)
55
56    def select_arm(self, arms: Optional[Dict[str, ArmStats]] = None) -> str:
57        """Select arm: random with probability epsilon, best otherwise."""
58        arms = arms or self.arms
59
60        if not arms:
61            raise ValueError("No arms available")
62
63        # Explore with probability epsilon
64        if random.random() < self.epsilon:
65            return random.choice(list(arms.keys()))
66
67        # Exploit: select arm with highest mean reward
68        best_arm = max(arms, key=lambda a: arms[a].mean_reward)
69        return best_arm
70
71    def update(self, arm_id: str, reward: float):
72        """Update arm statistics."""
73        if arm_id in self.arms:
74            self.arms[arm_id].pulls += 1
75            self.arms[arm_id].total_reward += reward
76
77
78class UCB1(BanditAlgorithm):
79    """Upper Confidence Bound algorithm."""
80
81    def __init__(self, exploration_factor: float = 2.0):
82        self.exploration_factor = exploration_factor
83        self.arms: Dict[str, ArmStats] = {}
84        self.total_pulls = 0
85
86    def add_arm(self, arm_id: str):
87        """Add a new arm."""
88        self.arms[arm_id] = ArmStats(arm_id=arm_id)
89
90    def select_arm(self, arms: Optional[Dict[str, ArmStats]] = None) -> str:
91        """Select arm with highest upper confidence bound."""
92        arms = arms or self.arms
93
94        if not arms:
95            raise ValueError("No arms available")
96
97        # Pull each arm at least once
98        for arm_id, stats in arms.items():
99            if stats.pulls == 0:
100                return arm_id
101
102        # Calculate UCB for each arm
103        ucb_values = {}
104        for arm_id, stats in arms.items():
105            exploitation = stats.mean_reward
106            exploration = math.sqrt(
107                self.exploration_factor * math.log(self.total_pulls) / stats.pulls
108            )
109            ucb_values[arm_id] = exploitation + exploration
110
111        return max(ucb_values, key=ucb_values.get)
112
113    def update(self, arm_id: str, reward: float):
114        """Update arm statistics."""
115        if arm_id in self.arms:
116            self.arms[arm_id].pulls += 1
117            self.arms[arm_id].total_reward += reward
118            self.total_pulls += 1
119
120
121class ThompsonSampling(BanditAlgorithm):
122    """Thompson Sampling for Bernoulli bandits."""
123
124    def __init__(self, prior_alpha: float = 1.0, prior_beta: float = 1.0):
125        self.prior_alpha = prior_alpha
126        self.prior_beta = prior_beta
127        self.arms: Dict[str, ArmStats] = {}
128
129    def add_arm(self, arm_id: str):
130        """Add a new arm."""
131        self.arms[arm_id] = ArmStats(arm_id=arm_id)
132
133    def select_arm(self, arms: Optional[Dict[str, ArmStats]] = None) -> str:
134        """Select arm by sampling from posterior distributions."""
135        arms = arms or self.arms
136
137        if not arms:
138            raise ValueError("No arms available")
139
140        # Sample from Beta distribution for each arm
141        samples = {}
142        for arm_id, stats in arms.items():
143            alpha = self.prior_alpha + stats.successes
144            beta = self.prior_beta + stats.failures
145            samples[arm_id] = random.betavariate(alpha, beta)
146
147        return max(samples, key=samples.get)
148
149    def update(self, arm_id: str, reward: float):
150        """Update arm statistics (reward should be 0 or 1)."""
151        if arm_id in self.arms:
152            self.arms[arm_id].pulls += 1
153            if reward > 0.5:  # Success threshold
154                self.arms[arm_id].successes += 1
155            else:
156                self.arms[arm_id].failures += 1
157            self.arms[arm_id].total_reward += reward
158
159
160class AdaptiveExperiment:
161    """Experiment using multi-arm bandit for adaptive allocation."""
162
163    def __init__(
164        self,
165        experiment: Experiment,
166        algorithm: BanditAlgorithm
167    ):
168        self.experiment = experiment
169        self.algorithm = algorithm
170
171        # Initialize arms
172        for variant in experiment.variants:
173            self.algorithm.add_arm(variant.id)
174
175    def allocate(self, user_id: str) -> Variant:
176        """Allocate user using bandit algorithm."""
177        selected_arm = self.algorithm.select_arm(self.algorithm.arms)
178
179        variant = next(
180            (v for v in self.experiment.variants if v.id == selected_arm),
181            self.experiment.variants[0]
182        )
183
184        return variant
185
186    def record_outcome(self, variant_id: str, success: bool, reward: Optional[float] = None):
187        """Record experiment outcome."""
188        if reward is None:
189            reward = 1.0 if success else 0.0
190
191        self.algorithm.update(variant_id, reward)
192
193    def get_allocation_percentages(self) -> Dict[str, float]:
194        """Get current implied allocation percentages."""
195        total_pulls = sum(
196            self.algorithm.arms[arm_id].pulls
197            for arm_id in self.algorithm.arms
198        )
199
200        if total_pulls == 0:
201            n = len(self.algorithm.arms)
202            return {arm_id: 100.0/n for arm_id in self.algorithm.arms}
203
204        return {
205            arm_id: (stats.pulls / total_pulls) * 100
206            for arm_id, stats in self.algorithm.arms.items()
207        }
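
The exploration bonus in the `UCB1` class can be checked by hand. The numbers below are invented for illustration: an often-pulled arm with a higher observed mean versus a rarely pulled arm whose confidence bound is still wide.

```python
"""Worked UCB1 scores, matching the formula in the UCB1 class:
ucb = mean_reward + sqrt(exploration_factor * ln(total_pulls) / pulls)."""
import math

exploration_factor = 2.0
total_pulls = 100

# (pulls, mean_reward): arm A is well-explored, arm B barely tried
arms = {"A": (90, 0.62), "B": (10, 0.55)}

ucb = {
    arm: mean + math.sqrt(exploration_factor * math.log(total_pulls) / pulls)
    for arm, (pulls, mean) in arms.items()
}
# A: 0.62 + sqrt(2 * ln(100) / 90) ≈ 0.94
# B: 0.55 + sqrt(2 * ln(100) / 10) ≈ 1.51
```

Despite A's higher observed mean, B's wide confidence bound gives it the larger UCB score, so UCB1 pulls B next; as B accumulates pulls its bonus shrinks and exploitation takes over.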

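Thompson Sampling's traffic shift can likewise be seen in a short simulation. This condensed sketch (the arm names and true success rates are invented for illustration) re-states the Beta-posterior sampling from the `ThompsonSampling` class and runs it against two arms with known success rates:

```python
"""Condensed Thompson Sampling simulation: two arms, Beta(1, 1) priors."""
import random

random.seed(7)

arms = {"control": [0, 0], "treatment": [0, 0]}  # [successes, failures]
true_rates = {"control": 0.60, "treatment": 0.70}
rounds = 5000

for _ in range(rounds):
    # Sample each arm's Beta posterior and pull the argmax, as in select_arm()
    samples = {a: random.betavariate(1 + s, 1 + f) for a, (s, f) in arms.items()}
    arm = max(samples, key=samples.get)
    if random.random() < true_rates[arm]:
        arms[arm][0] += 1
    else:
        arms[arm][1] += 1

pulls = {a: s + f for a, (s, f) in arms.items()}
share = {a: pulls[a] / rounds for a in pulls}
# After enough rounds, most traffic flows to the 0.70 arm
```

Because allocation adapts as evidence accrues, the weaker variant's exposure shrinks automatically, which is exactly the regret reduction that motivates bandits over fixed-split A/B tests.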
Summary

This section covered the key components of A/B testing for AI agents:

| Topic | Key Concepts |
| --- | --- |
| Experiment Design | Variants, metrics, sample size, guardrails |
| Traffic Splitting | Hash-based allocation, stickiness, overrides |
| Statistical Analysis | T-tests, confidence intervals, power analysis |
| Feature Flags | Gradual rollout, instant rollback, segments |
| Multi-Arm Bandits | Epsilon-greedy, UCB1, Thompson Sampling |
Key Takeaways: A/B testing provides real-world validation of agent improvements. Use proper statistical analysis to avoid false conclusions, and consider multi-arm bandits for continuous optimization.

In the next section, we'll explore continuous evaluation systems that automatically track agent performance over time.