Introduction
Test-driven development (TDD) is a powerful methodology for coding agents. By writing or identifying tests first, then implementing code to pass them, agents have a clear, objective measure of success. This section covers how to make your coding agent test-aware and capable of TDD workflows.
Key Insight: Tests provide an objective verification mechanism for agents. Unlike subjective code quality assessments, a test either passes or fails—giving the agent clear feedback on whether its implementation is correct.
Why TDD for Agents
Test-driven development offers several advantages for coding agents:
| Advantage | Description |
|---|---|
| Objective Success Criteria | Tests provide binary pass/fail feedback, eliminating ambiguity |
| Iterative Improvement | Failed tests guide the agent toward the correct solution |
| Regression Prevention | Existing tests catch unintended side effects |
| Documentation | Tests describe expected behavior in executable form |
| Confidence | Passing tests confirm the implementation works |
| Bounded Problem | Tests define clear boundaries for what needs to be done |
The TDD Workflow for Agents
- Red: Identify or write failing tests that describe the desired behavior
- Green: Write the minimum code necessary to make tests pass
- Refactor: Improve the code while keeping tests green
- Repeat: Continue until all requirements are met
1from dataclasses import dataclass, field
2from typing import List, Dict, Any, Optional
3from enum import Enum
4
5
class TestStatus(Enum):
    """Outcome of a single test, as reported by the test framework."""
    PASSED = "passed"    # test ran and all assertions held
    FAILED = "failed"    # test ran but an assertion failed
    ERROR = "error"      # test could not run to completion (exception outside assertions)
    SKIPPED = "skipped"  # test was deliberately not executed
    PENDING = "pending"  # registered but not yet run (e.g. Jest "pending" assertions)
13
14
@dataclass
class TestResult:
    """Result of a single test."""
    name: str                            # test identifier (pytest uses "path::test_name" form)
    status: TestStatus                   # outcome of the run
    duration: float = 0.0                # wall-clock seconds; 0.0 when the framework reports none
    error_message: Optional[str] = None  # failure/error message text, if captured
    stack_trace: Optional[str] = None    # traceback text, if captured
    file_path: Optional[str] = None      # file containing the test, when known
    line_number: Optional[int] = None    # line where the test is defined, when known
25
26
27@dataclass
28class TestSuiteResult:
29 """Result of running a test suite."""
30 total: int
31 passed: int
32 failed: int
33 errors: int
34 skipped: int
35 duration: float
36 tests: List[TestResult] = field(default_factory=list)
37
38 @property
39 def success(self) -> bool:
40 return self.failed == 0 and self.errors == 0
41
42 @property
43 def pass_rate(self) -> float:
44 if self.total == 0:
45 return 0.0
46 return self.passed / self.total
47
48 def failed_tests(self) -> List[TestResult]:
49 return [t for t in self.tests if t.status in (TestStatus.FAILED, TestStatus.ERROR)]
50
51 def summary(self) -> str:
52 return f"{self.passed}/{self.total} passed ({self.pass_rate:.1%}), {self.failed} failed, {self.errors} errors"Test Detection and Analysis
Before running tests, agents need to detect what testing framework is used and find relevant test files:
1import os
2import re
3from pathlib import Path
4from typing import List, Dict, Any, Optional, Set
5from dataclasses import dataclass
6
7
@dataclass
class TestFramework:
    """Information about a detected test framework."""
    name: str                               # framework identifier, e.g. "pytest"
    language: str                           # primary language the framework targets
    run_command: str                        # shell command that runs the suite
    config_files: List[str]                 # files whose presence signals this framework
    test_patterns: List[str]                # glob-style patterns locating test files
    coverage_command: Optional[str] = None  # variant command that also collects coverage
17
18
19class TestDetector:
20 """
21 Detect testing frameworks and test files in a project.
22 """
23
24 FRAMEWORKS = {
25 "pytest": TestFramework(
26 name="pytest",
27 language="python",
28 run_command="pytest",
29 config_files=["pytest.ini", "pyproject.toml", "setup.cfg"],
30 test_patterns=["test_*.py", "*_test.py", "tests/**/*.py"],
31 coverage_command="pytest --cov"
32 ),
33 "jest": TestFramework(
34 name="jest",
35 language="javascript",
36 run_command="npx jest",
37 config_files=["jest.config.js", "jest.config.ts", "package.json"],
38 test_patterns=["**/*.test.js", "**/*.test.ts", "**/*.spec.js", "**/*.spec.ts"],
39 coverage_command="npx jest --coverage"
40 ),
41 "vitest": TestFramework(
42 name="vitest",
43 language="javascript",
44 run_command="npx vitest run",
45 config_files=["vitest.config.ts", "vitest.config.js", "vite.config.ts"],
46 test_patterns=["**/*.test.ts", "**/*.spec.ts"],
47 coverage_command="npx vitest run --coverage"
48 ),
49 "mocha": TestFramework(
50 name="mocha",
51 language="javascript",
52 run_command="npx mocha",
53 config_files=[".mocharc.js", ".mocharc.json", "package.json"],
54 test_patterns=["test/**/*.js", "tests/**/*.js"],
55 coverage_command="npx nyc mocha"
56 ),
57 "go_test": TestFramework(
58 name="go test",
59 language="go",
60 run_command="go test",
61 config_files=["go.mod"],
62 test_patterns=["*_test.go"],
63 coverage_command="go test -cover"
64 ),
65 "cargo_test": TestFramework(
66 name="cargo test",
67 language="rust",
68 run_command="cargo test",
69 config_files=["Cargo.toml"],
70 test_patterns=["**/tests/*.rs", "src/**/*_test.rs"],
71 coverage_command="cargo tarpaulin"
72 ),
73 }
74
75 def __init__(self, workspace: Path):
76 self.workspace = Path(workspace)
77
78 def detect_frameworks(self) -> List[TestFramework]:
79 """Detect all testing frameworks in the project."""
80 detected = []
81
82 for name, framework in self.FRAMEWORKS.items():
83 for config_file in framework.config_files:
84 if (self.workspace / config_file).exists():
85 # Additional validation for package.json
86 if config_file == "package.json":
87 if self._has_test_dep(name):
88 detected.append(framework)
89 else:
90 detected.append(framework)
91 break
92
93 return detected
94
95 def _has_test_dep(self, framework_name: str) -> bool:
96 """Check if package.json has the testing framework as dependency."""
97 import json
98
99 pkg_path = self.workspace / "package.json"
100 if not pkg_path.exists():
101 return False
102
103 try:
104 pkg = json.loads(pkg_path.read_text())
105 all_deps = {
106 **pkg.get("dependencies", {}),
107 **pkg.get("devDependencies", {}),
108 }
109
110 framework_packages = {
111 "jest": ["jest", "@jest/core"],
112 "vitest": ["vitest"],
113 "mocha": ["mocha"],
114 }
115
116 for dep in framework_packages.get(framework_name, []):
117 if dep in all_deps:
118 return True
119 except:
120 pass
121
122 return False
123
124 def find_test_files(
125 self,
126 framework: TestFramework = None
127 ) -> List[Path]:
128 """Find all test files for a given framework."""
129 import fnmatch
130
131 patterns = []
132 if framework:
133 patterns = framework.test_patterns
134 else:
135 # Use all known patterns
136 for fw in self.FRAMEWORKS.values():
137 patterns.extend(fw.test_patterns)
138
139 test_files = []
140 ignore_dirs = {".git", "node_modules", "__pycache__", ".venv", "venv", "dist"}
141
142 for root, dirs, files in os.walk(self.workspace):
143 # Filter ignored directories
144 dirs[:] = [d for d in dirs if d not in ignore_dirs]
145
146 for file in files:
147 file_path = Path(root) / file
148 rel_path = file_path.relative_to(self.workspace)
149
150 for pattern in patterns:
151 if fnmatch.fnmatch(str(rel_path), pattern):
152 test_files.append(file_path)
153 break
154 if fnmatch.fnmatch(file, pattern.split("/")[-1]):
155 test_files.append(file_path)
156 break
157
158 return list(set(test_files))
159
160 def find_related_tests(self, source_file: Path) -> List[Path]:
161 """Find test files related to a source file."""
162 related = []
163 source_name = source_file.stem
164
165 # Common test file naming patterns
166 test_patterns = [
167 f"test_{source_name}.py",
168 f"{source_name}_test.py",
169 f"{source_name}.test.js",
170 f"{source_name}.test.ts",
171 f"{source_name}.spec.js",
172 f"{source_name}.spec.ts",
173 f"{source_name}_test.go",
174 ]
175
176 test_files = self.find_test_files()
177
178 for test_file in test_files:
179 if test_file.name in test_patterns:
180 related.append(test_file)
181 # Also check if test file is in a parallel test directory
182 if "test" in str(test_file):
183 if source_name in test_file.stem:
184 related.append(test_file)
185
186 return related
187
188 def analyze_test_file(self, test_file: Path) -> Dict[str, Any]:
189 """Analyze a test file to extract test information."""
190 content = test_file.read_text()
191
192 analysis = {
193 "file": str(test_file),
194 "tests": [],
195 "fixtures": [],
196 "imports": [],
197 }
198
199 if test_file.suffix == ".py":
200 analysis = self._analyze_pytest_file(content, analysis)
201 elif test_file.suffix in (".js", ".ts", ".jsx", ".tsx"):
202 analysis = self._analyze_jest_file(content, analysis)
203
204 return analysis
205
206 def _analyze_pytest_file(self, content: str, analysis: Dict) -> Dict:
207 """Extract test information from a pytest file."""
208 # Find test functions
209 test_pattern = r"(?:async\s+)?def\s+(test_\w+)\s*\("
210 for match in re.finditer(test_pattern, content):
211 analysis["tests"].append({
212 "name": match.group(1),
213 "type": "function",
214 "line": content[:match.start()].count("\n") + 1
215 })
216
217 # Find test classes
218 class_pattern = r"class\s+(Test\w+)\s*[:\(]"
219 for match in re.finditer(class_pattern, content):
220 analysis["tests"].append({
221 "name": match.group(1),
222 "type": "class",
223 "line": content[:match.start()].count("\n") + 1
224 })
225
226 # Find fixtures
227 fixture_pattern = r"@pytest\.fixture.*\ndef\s+(\w+)"
228 for match in re.finditer(fixture_pattern, content):
229 analysis["fixtures"].append(match.group(1))
230
231 return analysis
232
233 def _analyze_jest_file(self, content: str, analysis: Dict) -> Dict:
234 """Extract test information from a Jest file."""
235 # Find test/it blocks
236 test_pattern = r"(?:test|it)\s*\(['"](.+?)['"]"
237 for match in re.finditer(test_pattern, content):
238 analysis["tests"].append({
239 "name": match.group(1),
240 "type": "test",
241 "line": content[:match.start()].count("\n") + 1
242 })
243
244 # Find describe blocks
245 describe_pattern = r"describe\s*\(['"](.+?)['"]"
246 for match in re.finditer(describe_pattern, content):
247 analysis["tests"].append({
248 "name": match.group(1),
249 "type": "describe",
250 "line": content[:match.start()].count("\n") + 1
251 })
252
253 return analysisRunning and Parsing Tests
After detecting the test framework, we need to run tests and parse the results into a structured format the agent can understand:
1import asyncio
2import json
3import re
4import xml.etree.ElementTree as ET
5from typing import Optional
6from pathlib import Path
7
8
9class TestRunner:
10 """
11 Run tests and parse results for the coding agent.
12 """
13
14 def __init__(self, workspace: Path, sandbox = None):
15 self.workspace = Path(workspace)
16 self.sandbox = sandbox
17 self.detector = TestDetector(workspace)
18
19 async def run_tests(
20 self,
21 framework: TestFramework = None,
22 test_path: str = None,
23 specific_test: str = None,
24 coverage: bool = False,
25 timeout: int = 300
26 ) -> TestSuiteResult:
27 """Run tests and return parsed results."""
28 if not framework:
29 frameworks = self.detector.detect_frameworks()
30 if not frameworks:
31 raise Exception("No test framework detected")
32 framework = frameworks[0]
33
34 # Build command
35 cmd = self._build_test_command(
36 framework, test_path, specific_test, coverage
37 )
38
39 # Run tests
40 if self.sandbox:
41 result = await self.sandbox.execute(cmd, timeout=timeout)
42 stdout = result.stdout
43 stderr = result.stderr
44 success = result.success
45 else:
46 process = await asyncio.create_subprocess_shell(
47 cmd,
48 stdout=asyncio.subprocess.PIPE,
49 stderr=asyncio.subprocess.PIPE,
50 cwd=self.workspace
51 )
52 stdout_bytes, stderr_bytes = await asyncio.wait_for(
53 process.communicate(),
54 timeout=timeout
55 )
56 stdout = stdout_bytes.decode()
57 stderr = stderr_bytes.decode()
58 success = process.returncode == 0
59
60 # Parse results based on framework
61 return self._parse_results(framework, stdout, stderr, success)
62
63 def _build_test_command(
64 self,
65 framework: TestFramework,
66 test_path: str = None,
67 specific_test: str = None,
68 coverage: bool = False
69 ) -> str:
70 """Build the test command."""
71 if coverage and framework.coverage_command:
72 cmd = framework.coverage_command
73 else:
74 cmd = framework.run_command
75
76 # Add output format for parsing
77 if framework.name == "pytest":
78 cmd += " -v --tb=short"
79 if test_path:
80 cmd += f" {test_path}"
81 if specific_test:
82 cmd += f" -k '{specific_test}'"
83
84 elif framework.name in ("jest", "vitest"):
85 cmd += " --reporter=json"
86 if test_path:
87 cmd += f" {test_path}"
88 if specific_test:
89 cmd += f" -t '{specific_test}'"
90
91 elif framework.name == "go test":
92 cmd += " -v -json"
93 if test_path:
94 cmd += f" {test_path}"
95 if specific_test:
96 cmd += f" -run {specific_test}"
97
98 return cmd
99
100 def _parse_results(
101 self,
102 framework: TestFramework,
103 stdout: str,
104 stderr: str,
105 success: bool
106 ) -> TestSuiteResult:
107 """Parse test output into structured results."""
108 if framework.name == "pytest":
109 return self._parse_pytest_output(stdout, stderr)
110 elif framework.name in ("jest", "vitest"):
111 return self._parse_jest_output(stdout, stderr)
112 elif framework.name == "go test":
113 return self._parse_go_test_output(stdout, stderr)
114 else:
115 return self._parse_generic_output(stdout, stderr, success)
116
117 def _parse_pytest_output(self, stdout: str, stderr: str) -> TestSuiteResult:
118 """Parse pytest verbose output."""
119 tests = []
120 output = stdout + "\n" + stderr
121
122 # Parse individual test results
123 # Format: path/test_file.py::test_name PASSED/FAILED
124 test_pattern = r"(\S+::\S+)\s+(PASSED|FAILED|ERROR|SKIPPED)"
125
126 for match in re.finditer(test_pattern, output):
127 name = match.group(1)
128 status_str = match.group(2)
129
130 status_map = {
131 "PASSED": TestStatus.PASSED,
132 "FAILED": TestStatus.FAILED,
133 "ERROR": TestStatus.ERROR,
134 "SKIPPED": TestStatus.SKIPPED,
135 }
136
137 tests.append(TestResult(
138 name=name,
139 status=status_map.get(status_str, TestStatus.ERROR)
140 ))
141
142 # Extract failure details
143 failure_pattern = r"FAILED (\S+) - (.+?)(?=\n(?:FAILED|=====|$))"
144 for match in re.finditer(failure_pattern, output, re.DOTALL):
145 test_name = match.group(1)
146 error_msg = match.group(2).strip()
147
148 for test in tests:
149 if test_name in test.name:
150 test.error_message = error_msg
151 break
152
153 # Parse summary line
154 summary_pattern = r"(\d+) passed(?:, (\d+) failed)?(?:, (\d+) error)?(?:, (\d+) skipped)?"
155 summary_match = re.search(summary_pattern, output)
156
157 if summary_match:
158 passed = int(summary_match.group(1) or 0)
159 failed = int(summary_match.group(2) or 0)
160 errors = int(summary_match.group(3) or 0)
161 skipped = int(summary_match.group(4) or 0)
162 else:
163 passed = len([t for t in tests if t.status == TestStatus.PASSED])
164 failed = len([t for t in tests if t.status == TestStatus.FAILED])
165 errors = len([t for t in tests if t.status == TestStatus.ERROR])
166 skipped = len([t for t in tests if t.status == TestStatus.SKIPPED])
167
168 return TestSuiteResult(
169 total=len(tests),
170 passed=passed,
171 failed=failed,
172 errors=errors,
173 skipped=skipped,
174 duration=0, # Could parse from output
175 tests=tests
176 )
177
178 def _parse_jest_output(self, stdout: str, stderr: str) -> TestSuiteResult:
179 """Parse Jest JSON output."""
180 tests = []
181
182 try:
183 # Find JSON in output
184 json_start = stdout.find("{")
185 json_end = stdout.rfind("}") + 1
186 if json_start >= 0 and json_end > json_start:
187 data = json.loads(stdout[json_start:json_end])
188 else:
189 raise ValueError("No JSON found")
190
191 for test_result in data.get("testResults", []):
192 for assertion in test_result.get("assertionResults", []):
193 status_map = {
194 "passed": TestStatus.PASSED,
195 "failed": TestStatus.FAILED,
196 "pending": TestStatus.PENDING,
197 }
198
199 tests.append(TestResult(
200 name=assertion.get("fullName", assertion.get("title", "unknown")),
201 status=status_map.get(assertion.get("status"), TestStatus.ERROR),
202 duration=assertion.get("duration", 0) / 1000,
203 error_message="\n".join(assertion.get("failureMessages", [])),
204 file_path=test_result.get("name")
205 ))
206
207 return TestSuiteResult(
208 total=data.get("numTotalTests", len(tests)),
209 passed=data.get("numPassedTests", 0),
210 failed=data.get("numFailedTests", 0),
211 errors=0,
212 skipped=data.get("numPendingTests", 0),
213 duration=data.get("testRuntime", 0) / 1000,
214 tests=tests
215 )
216
217 except (json.JSONDecodeError, ValueError):
218 # Fall back to regex parsing
219 return self._parse_generic_output(stdout, stderr, "fail" not in stdout.lower())
220
221 def _parse_go_test_output(self, stdout: str, stderr: str) -> TestSuiteResult:
222 """Parse go test JSON output."""
223 tests = []
224 output = stdout
225
226 for line in output.split("\n"):
227 if not line.strip():
228 continue
229
230 try:
231 event = json.loads(line)
232 action = event.get("Action")
233 test_name = event.get("Test")
234
235 if not test_name:
236 continue
237
238 if action == "pass":
239 tests.append(TestResult(
240 name=test_name,
241 status=TestStatus.PASSED,
242 duration=event.get("Elapsed", 0)
243 ))
244 elif action == "fail":
245 tests.append(TestResult(
246 name=test_name,
247 status=TestStatus.FAILED,
248 duration=event.get("Elapsed", 0),
249 error_message=event.get("Output", "")
250 ))
251 elif action == "skip":
252 tests.append(TestResult(
253 name=test_name,
254 status=TestStatus.SKIPPED
255 ))
256 except json.JSONDecodeError:
257 continue
258
259 passed = len([t for t in tests if t.status == TestStatus.PASSED])
260 failed = len([t for t in tests if t.status == TestStatus.FAILED])
261 skipped = len([t for t in tests if t.status == TestStatus.SKIPPED])
262
263 return TestSuiteResult(
264 total=len(tests),
265 passed=passed,
266 failed=failed,
267 errors=0,
268 skipped=skipped,
269 duration=sum(t.duration for t in tests),
270 tests=tests
271 )
272
273 def _parse_generic_output(
274 self,
275 stdout: str,
276 stderr: str,
277 success: bool
278 ) -> TestSuiteResult:
279 """Generic test output parser when specific parser not available."""
280 output = stdout + stderr
281
282 # Try to extract test counts from common patterns
283 patterns = [
284 r"(\d+) tests?,? (\d+) failures?",
285 r"Passed: (\d+), Failed: (\d+)",
286 r"(\d+) passing, (\d+) failing",
287 ]
288
289 for pattern in patterns:
290 match = re.search(pattern, output, re.IGNORECASE)
291 if match:
292 passed = int(match.group(1))
293 failed = int(match.group(2))
294 return TestSuiteResult(
295 total=passed + failed,
296 passed=passed,
297 failed=failed,
298 errors=0,
299 skipped=0,
300 duration=0,
301 tests=[]
302 )
303
304 # Fallback
305 return TestSuiteResult(
306 total=1,
307 passed=1 if success else 0,
308 failed=0 if success else 1,
309 errors=0,
310 skipped=0,
311 duration=0,
312 tests=[]
313 )Test Generation
Agents can generate tests to verify their implementations. This requires understanding the code being tested and the project's testing conventions:
1class TestGenerator:
2 """
3 Generate tests for code using LLM.
4 """
5
6 def __init__(self, llm_client, workspace: Path):
7 self.llm = llm_client
8 self.workspace = Path(workspace)
9 self.detector = TestDetector(workspace)
10
11 async def generate_tests(
12 self,
13 source_file: Path,
14 function_name: str = None,
15 test_type: str = "unit"
16 ) -> str:
17 """Generate tests for a source file or specific function."""
18 # Read source file
19 source_content = source_file.read_text()
20
21 # Detect framework and get examples
22 frameworks = self.detector.detect_frameworks()
23 framework = frameworks[0] if frameworks else None
24
25 # Find existing test files for patterns
26 existing_tests = self.detector.find_test_files(framework)
27 test_examples = self._get_test_examples(existing_tests)
28
29 # Build generation prompt
30 prompt = self._build_test_generation_prompt(
31 source_content,
32 source_file,
33 function_name,
34 test_type,
35 framework,
36 test_examples
37 )
38
39 # Generate tests
40 response = await self.llm.generate(prompt)
41
42 # Extract code from response
43 return self._extract_code(response)
44
45 def _build_test_generation_prompt(
46 self,
47 source_content: str,
48 source_file: Path,
49 function_name: str,
50 test_type: str,
51 framework: TestFramework,
52 test_examples: str
53 ) -> str:
54 """Build prompt for test generation."""
55 framework_name = framework.name if framework else "appropriate"
56 language = source_file.suffix[1:]
57
58 focus = f"the function '{function_name}'" if function_name else "all public functions"
59
60 return f"""Generate [test_type] tests for the following code.
61
62Source file: [source_file.name]
63Testing framework: [framework_name]
64Focus: [focus]
65
66SOURCE CODE:
67[source_content]
68
69[test_examples if provided]
70
71Requirements:
721. Follow the testing patterns used in this project
732. Test both success cases and edge cases
743. Include meaningful assertions
754. Use descriptive test names that explain what is being tested
765. Mock external dependencies appropriately
77
78Generate ONLY the test code, no explanations."""
79
80 def _get_test_examples(self, test_files: List[Path], limit: int = 2) -> str:
81 """Get examples from existing test files."""
82 examples = []
83
84 for test_file in test_files[:limit]:
85 try:
86 content = test_file.read_text()
87 # Take first 100 lines as example
88 lines = content.split("\n")[:100]
89 examples.append(f"# From {test_file.name}\n" + "\n".join(lines))
90 except:
91 continue
92
93 return "\n\n".join(examples)
94
95 def _extract_code(self, response: str) -> str:
96 """Extract code from LLM response."""
97 # Look for code blocks (pattern: triple-backtick + optional lang + newline + content + triple-backtick)
98 # Using raw string pattern for markdown code blocks
99 pattern = r'[BACKTICK][BACKTICK][BACKTICK](?:\w+)?\n(.*?)[BACKTICK][BACKTICK][BACKTICK]'
100 matches = re.findall(pattern.replace('[BACKTICK]', chr(96)), response, re.DOTALL)
101
102 if matches:
103 return "\n\n".join(matches)
104
105 # Return as-is if no code blocks found
106 return response.strip()
107
108 async def generate_test_for_bug(
109 self,
110 bug_description: str,
111 affected_file: Path,
112 expected_behavior: str
113 ) -> str:
114 """Generate a regression test for a bug."""
115 source_content = affected_file.read_text()
116
117 prompt = f"""Generate a regression test that verifies a bug fix.
118
119Bug description: [bug_description]
120Expected behavior: [expected_behavior]
121Affected file: [affected_file.name]
122
123SOURCE CODE:
124[source_content[:2000]]
125
126Generate a test that:
1271. Would FAIL if the bug exists
1282. Would PASS after the bug is fixed
1293. Clearly documents what the bug was
1304. Can be used for regression testing
131
132Generate ONLY the test code."""
133
134 response = await self.llm.generate(prompt)
135 return self._extract_code(response)
136
137 async def improve_test_coverage(
138 self,
139 source_file: Path,
140 existing_tests: Path,
141 coverage_report: Dict[str, Any] = None
142 ) -> str:
143 """Generate additional tests to improve coverage."""
144 source_content = source_file.read_text()
145 test_content = existing_tests.read_text()
146
147 uncovered_info = ""
148 if coverage_report:
149 uncovered_lines = coverage_report.get("uncovered_lines", [])
150 uncovered_info = f"\nUncovered lines: [uncovered_lines]"
151
152 prompt = f"""Analyze the existing tests and generate additional tests to improve coverage.
153
154SOURCE CODE:
155[source_content]
156
157EXISTING TESTS:
158[test_content]
159[uncovered_info]
160
161Generate NEW tests that:
1621. Cover code paths not tested by existing tests
1632. Test edge cases and error conditions
1643. Don't duplicate existing test coverage
1654. Follow the same style as existing tests
166
167Generate ONLY the additional test code to append."""
168
169 response = await self.llm.generate(prompt)
170 return self._extract_code(response)Test-Driven Agent Loop
The TDD agent loop combines test detection, running, and implementation into an iterative workflow:
1from typing import AsyncGenerator
2
3
class TDDAgentLoop:
    """
    Test-driven development loop for coding agents.

    Drives the red-green-refactor cycle: find or generate tests, run them,
    iterate on the implementation until they pass, then optionally refactor
    under the safety net of the passing suite.
    """

    # Upper bound on implement-then-test iterations before giving up.
    MAX_ITERATIONS = 10

    def __init__(
        self,
        llm_client,
        workspace: Path,
        sandbox,
        file_tools
    ):
        self.llm = llm_client
        self.workspace = Path(workspace)
        self.sandbox = sandbox
        self.file_tools = file_tools

        self.test_runner = TestRunner(workspace, sandbox)
        self.test_generator = TestGenerator(llm_client, workspace)
        self.detector = TestDetector(workspace)

    async def implement_with_tests(
        self,
        task_description: str,
        target_file: Path,
        test_file: Optional[Path] = None
    ) -> AsyncGenerator[Dict[str, Any], None]:
        """
        Implement a feature using TDD methodology.

        Yields progress events (dicts with at least a "phase" key) so the
        caller can stream status to a UI or log.
        """
        iteration = 0

        # Phase 1: Identify or create tests
        yield {"phase": "setup", "message": "Identifying tests"}

        if test_file and test_file.exists():
            tests_exist = True
        else:
            # Find or generate tests
            related_tests = self.detector.find_related_tests(target_file)
            if related_tests:
                test_file = related_tests[0]
                tests_exist = True
            else:
                # No tests anywhere: generate a new test file first (the
                # "write the test before the code" half of TDD).
                test_file = self._get_test_file_path(target_file)
                test_code = await self.test_generator.generate_tests(
                    target_file,
                    test_type="unit"
                )
                await self.file_tools.write(str(test_file), test_code)
                tests_exist = False

        yield {
            "phase": "setup",
            "message": f"Using test file: {test_file}",
            "tests_generated": not tests_exist
        }

        # Phase 2: Initial test run (should fail for new features)
        yield {"phase": "red", "message": "Running initial tests"}

        initial_results = await self.test_runner.run_tests(
            test_path=str(test_file)
        )

        yield {
            "phase": "red",
            "results": initial_results.summary(),
            "failed_tests": [t.name for t in initial_results.failed_tests()]
        }

        # Phase 3: Implementation loop.  Seed with the initial run so failure
        # context is available on the first iteration (and test_results is
        # always bound after the loop).
        test_results = initial_results

        while iteration < self.MAX_ITERATIONS:
            iteration += 1

            yield {
                "phase": "green",
                "iteration": iteration,
                "message": "Implementing to pass tests"
            }

            # Generate or improve the implementation from the latest failures.
            implementation = await self._generate_implementation(
                task_description,
                target_file,
                test_results
            )

            # Apply implementation
            await self.file_tools.write(str(target_file), implementation)

            yield {
                "phase": "green",
                "iteration": iteration,
                "message": "Running tests to verify"
            }

            # Run tests
            test_results = await self.test_runner.run_tests(
                test_path=str(test_file)
            )

            yield {
                "phase": "green",
                "iteration": iteration,
                "results": test_results.summary(),
                "success": test_results.success
            }

            if test_results.success:
                # All tests pass!
                break

            # Analyze failures to steer the next iteration.
            failure_analysis = await self._analyze_failures(
                test_results,
                target_file
            )

            yield {
                "phase": "debug",
                "iteration": iteration,
                "analysis": failure_analysis
            }

        # Phase 4: Refactor (optional), only when the suite is green.
        if test_results.success:
            yield {"phase": "refactor", "message": "Refactoring with test safety net"}

            refactored = await self._refactor_implementation(
                target_file,
                task_description
            )

            if refactored:
                await self.file_tools.write(str(target_file), refactored)

                # Verify tests still pass
                final_results = await self.test_runner.run_tests(
                    test_path=str(test_file)
                )

                if not final_results.success:
                    # Refactoring regressed the suite: restore the last
                    # known-green implementation.
                    await self.file_tools.write(str(target_file), implementation)
                    yield {
                        "phase": "refactor",
                        "message": "Refactoring broke tests, rolled back"
                    }
                else:
                    yield {
                        "phase": "refactor",
                        "message": "Refactoring complete, tests still pass"
                    }

        # Final summary
        yield {
            "phase": "complete",
            "iterations": iteration,
            "success": test_results.success,
            "final_results": test_results.summary()
        }

    async def _generate_implementation(
        self,
        task: str,
        target_file: Path,
        test_results: TestSuiteResult
    ) -> str:
        """Ask the LLM for an implementation guided by the failing tests."""
        current_content = target_file.read_text() if target_file.exists() else ""

        failed_tests = test_results.failed_tests()
        failure_info = "\n".join([
            f"- {t.name}: {t.error_message or 'Failed'}"
            for t in failed_tests
        ])

        prompt = f"""Implement or fix the following code to pass the failing tests.

Task: {task}

Current implementation:
{current_content}

Failing tests:
{failure_info}

Requirements:
1. Make the failing tests pass
2. Don't break any passing tests
3. Write clean, maintainable code
4. Handle edge cases appropriately

Provide the complete updated file content."""

        response = await self.llm.generate(prompt)
        return self.test_generator._extract_code(response)

    async def _analyze_failures(
        self,
        test_results: TestSuiteResult,
        target_file: Path
    ) -> Dict[str, Any]:
        """Analyze test failures to guide the next implementation attempt."""
        failed = test_results.failed_tests()

        analysis = {
            "failed_count": len(failed),
            "patterns": [],
            "suggestions": []
        }

        # Look for common failure classes in the error messages.
        error_messages = [t.error_message for t in failed if t.error_message]

        if any("undefined" in str(e).lower() for e in error_messages):
            analysis["patterns"].append("missing_definition")
            analysis["suggestions"].append("Check that all functions/variables are defined")

        if any("type" in str(e).lower() for e in error_messages):
            analysis["patterns"].append("type_error")
            analysis["suggestions"].append("Check argument types and return types")

        if any("assert" in str(e).lower() for e in error_messages):
            analysis["patterns"].append("assertion_failure")
            analysis["suggestions"].append("Implementation logic may be incorrect")

        return analysis

    async def _refactor_implementation(
        self,
        target_file: Path,
        task: str
    ) -> Optional[str]:
        """Attempt to refactor the implementation; returns None when no change is needed."""
        current_content = target_file.read_text()

        prompt = f"""Refactor the following code for better readability and maintainability.
Keep the same functionality - tests should still pass.

Task context: {task}

Current code:
{current_content}

Improvements to consider:
1. Better naming
2. Reduced complexity
3. Clearer structure
4. Following language idioms

If the code is already clean, respond with "NO_REFACTOR_NEEDED".
Otherwise, provide the complete refactored file."""

        response = await self.llm.generate(prompt)

        if "NO_REFACTOR_NEEDED" in response:
            return None

        return self.test_generator._extract_code(response)

    def _get_test_file_path(self, source_file: Path) -> Path:
        """Determine where a newly generated test file should live."""
        name = source_file.stem
        suffix = source_file.suffix

        # Prefer an existing tests/ (or test/) directory; otherwise place the
        # test next to the source file.
        tests_dir = self.workspace / "tests"
        if not tests_dir.exists():
            tests_dir = self.workspace / "test"
            if not tests_dir.exists():
                tests_dir = source_file.parent

        if suffix == ".py":
            return tests_dir / f"test_{name}.py"
        elif suffix in (".js", ".ts"):
            return tests_dir / f"{name}.test{suffix}"
        else:
            return tests_dir / f"test_{name}{suffix}"

Summary
In this section, we built comprehensive TDD capabilities for our coding agent:
- Test Detection: Automatically detecting test frameworks and finding relevant test files
- Test Running: Executing tests and parsing results from multiple frameworks (pytest, Jest, vitest, go test)
- Test Generation: LLM-powered test generation that follows project conventions
- TDD Loop: The complete red-green-refactor cycle with iterative implementation
- Failure Analysis: Understanding why tests fail to guide implementation improvements
In the next section, we'll implement the debugging and iteration loop that helps the agent diagnose and fix problems automatically.