Chapter 12

Test-Driven Agent Development

Building a Coding Agent

Introduction

Test-driven development (TDD) is a powerful methodology for coding agents. By writing or identifying tests first, then implementing code to pass them, agents have a clear, objective measure of success. This section covers how to make your coding agent test-aware and capable of TDD workflows.

Key Insight: Tests provide an objective verification mechanism for agents. Unlike subjective code quality assessments, a test either passes or fails—giving the agent clear feedback on whether its implementation is correct.

Why TDD for Agents

Test-driven development offers several advantages for coding agents:

| Advantage | Description |
| --- | --- |
| Objective Success Criteria | Tests provide binary pass/fail feedback, eliminating ambiguity |
| Iterative Improvement | Failed tests guide the agent toward the correct solution |
| Regression Prevention | Existing tests catch unintended side effects |
| Documentation | Tests describe expected behavior in executable form |
| Confidence | Passing tests confirm the implementation works |
| Bounded Problem | Tests define clear boundaries for what needs to be done |

The TDD Workflow for Agents

  1. Red: Identify or write failing tests that describe the desired behavior
  2. Green: Write the minimum code necessary to make tests pass
  3. Refactor: Improve the code while keeping tests green
  4. Repeat: Continue until all requirements are met
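This cycle can be sketched as a minimal driver. The `run_tests` and `improve` callables below are hypothetical stand-ins for the real components built later in this section:

```python
from typing import Callable


def tdd_loop(run_tests: Callable[[], int],
             improve: Callable[[int], None],
             max_iterations: int = 10) -> bool:
    """Red-green cycle: run the tests, improve the code while any fail."""
    for _ in range(max_iterations):
        failures = run_tests()      # red: count the failing tests
        if failures == 0:
            return True             # green: all tests pass
        improve(failures)           # refine the implementation and retry
    return False                    # gave up after max_iterations


# Simulated implementation that needs three improvement rounds to pass.
state = {"failures": 3}
success = tdd_loop(lambda: state["failures"],
                   lambda n: state.update(failures=state["failures"] - 1))
```

The key property is the bounded retry loop: the agent always terminates, either with green tests or with a clear failure signal it can surface to the user.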
```python
from dataclasses import dataclass, field
from enum import Enum
from typing import List, Optional


class TestStatus(Enum):
    """Status of a test."""
    PASSED = "passed"
    FAILED = "failed"
    ERROR = "error"
    SKIPPED = "skipped"
    PENDING = "pending"


@dataclass
class TestResult:
    """Result of a single test."""
    name: str
    status: TestStatus
    duration: float = 0.0
    error_message: Optional[str] = None
    stack_trace: Optional[str] = None
    file_path: Optional[str] = None
    line_number: Optional[int] = None


@dataclass
class TestSuiteResult:
    """Result of running a test suite."""
    total: int
    passed: int
    failed: int
    errors: int
    skipped: int
    duration: float
    tests: List[TestResult] = field(default_factory=list)

    @property
    def success(self) -> bool:
        return self.failed == 0 and self.errors == 0

    @property
    def pass_rate(self) -> float:
        if self.total == 0:
            return 0.0
        return self.passed / self.total

    def failed_tests(self) -> List[TestResult]:
        return [t for t in self.tests if t.status in (TestStatus.FAILED, TestStatus.ERROR)]

    def summary(self) -> str:
        return f"{self.passed}/{self.total} passed ({self.pass_rate:.1%}), {self.failed} failed, {self.errors} errors"
```
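A quick usage sketch of these result types, re-declaring condensed versions of the dataclasses so the snippet runs on its own:

```python
from dataclasses import dataclass, field
from enum import Enum
from typing import List


class TestStatus(Enum):
    PASSED = "passed"
    FAILED = "failed"


@dataclass
class TestResult:
    name: str
    status: TestStatus


@dataclass
class TestSuiteResult:
    total: int
    passed: int
    failed: int
    errors: int
    tests: List[TestResult] = field(default_factory=list)

    @property
    def pass_rate(self) -> float:
        return self.passed / self.total if self.total else 0.0

    def summary(self) -> str:
        return (f"{self.passed}/{self.total} passed ({self.pass_rate:.1%}), "
                f"{self.failed} failed, {self.errors} errors")


result = TestSuiteResult(
    total=3, passed=2, failed=1, errors=0,
    tests=[
        TestResult("test_add", TestStatus.PASSED),
        TestResult("test_mul", TestStatus.PASSED),
        TestResult("test_div", TestStatus.FAILED),
    ],
)
```

Between iterations, the agent reads compact `summary()` strings like this one to decide whether to keep refining.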

Test Detection and Analysis

Before running tests, agents need to detect what testing framework is used and find relevant test files:

```python
import fnmatch
import json
import os
import re
from pathlib import Path
from typing import Any, Dict, List, Optional
from dataclasses import dataclass


@dataclass
class TestFramework:
    """Information about a detected test framework."""
    name: str
    language: str
    run_command: str
    config_files: List[str]
    test_patterns: List[str]
    coverage_command: Optional[str] = None


class TestDetector:
    """
    Detect testing frameworks and test files in a project.
    """

    FRAMEWORKS = {
        "pytest": TestFramework(
            name="pytest",
            language="python",
            run_command="pytest",
            config_files=["pytest.ini", "pyproject.toml", "setup.cfg"],
            test_patterns=["test_*.py", "*_test.py", "tests/**/*.py"],
            coverage_command="pytest --cov"
        ),
        "jest": TestFramework(
            name="jest",
            language="javascript",
            run_command="npx jest",
            config_files=["jest.config.js", "jest.config.ts", "package.json"],
            test_patterns=["**/*.test.js", "**/*.test.ts", "**/*.spec.js", "**/*.spec.ts"],
            coverage_command="npx jest --coverage"
        ),
        "vitest": TestFramework(
            name="vitest",
            language="javascript",
            run_command="npx vitest run",
            config_files=["vitest.config.ts", "vitest.config.js", "vite.config.ts"],
            test_patterns=["**/*.test.ts", "**/*.spec.ts"],
            coverage_command="npx vitest run --coverage"
        ),
        "mocha": TestFramework(
            name="mocha",
            language="javascript",
            run_command="npx mocha",
            config_files=[".mocharc.js", ".mocharc.json", "package.json"],
            test_patterns=["test/**/*.js", "tests/**/*.js"],
            coverage_command="npx nyc mocha"
        ),
        "go_test": TestFramework(
            name="go test",
            language="go",
            run_command="go test",
            config_files=["go.mod"],
            test_patterns=["*_test.go"],
            coverage_command="go test -cover"
        ),
        "cargo_test": TestFramework(
            name="cargo test",
            language="rust",
            run_command="cargo test",
            config_files=["Cargo.toml"],
            test_patterns=["**/tests/*.rs", "src/**/*_test.rs"],
            coverage_command="cargo tarpaulin"
        ),
    }

    def __init__(self, workspace: Path):
        self.workspace = Path(workspace)

    def detect_frameworks(self) -> List[TestFramework]:
        """Detect all testing frameworks in the project."""
        detected = []

        for name, framework in self.FRAMEWORKS.items():
            for config_file in framework.config_files:
                if (self.workspace / config_file).exists():
                    # package.json alone proves nothing; confirm the dependency
                    if config_file == "package.json":
                        if self._has_test_dep(name):
                            detected.append(framework)
                    else:
                        detected.append(framework)
                    break

        return detected

    def _has_test_dep(self, framework_name: str) -> bool:
        """Check if package.json lists the testing framework as a dependency."""
        pkg_path = self.workspace / "package.json"
        if not pkg_path.exists():
            return False

        try:
            pkg = json.loads(pkg_path.read_text())
        except (OSError, json.JSONDecodeError):
            return False

        all_deps = {
            **pkg.get("dependencies", {}),
            **pkg.get("devDependencies", {}),
        }

        framework_packages = {
            "jest": ["jest", "@jest/core"],
            "vitest": ["vitest"],
            "mocha": ["mocha"],
        }

        return any(dep in all_deps for dep in framework_packages.get(framework_name, []))

    def find_test_files(
        self,
        framework: Optional[TestFramework] = None
    ) -> List[Path]:
        """Find all test files for a given framework."""
        patterns = []
        if framework:
            patterns = framework.test_patterns
        else:
            # Use all known patterns
            for fw in self.FRAMEWORKS.values():
                patterns.extend(fw.test_patterns)

        test_files = []
        ignore_dirs = {".git", "node_modules", "__pycache__", ".venv", "venv", "dist"}

        for root, dirs, files in os.walk(self.workspace):
            # Filter ignored directories in place so os.walk skips them
            dirs[:] = [d for d in dirs if d not in ignore_dirs]

            for file in files:
                file_path = Path(root) / file
                rel_path = file_path.relative_to(self.workspace)

                for pattern in patterns:
                    if fnmatch.fnmatch(str(rel_path), pattern):
                        test_files.append(file_path)
                        break
                    # Fall back to matching the bare filename
                    if fnmatch.fnmatch(file, pattern.split("/")[-1]):
                        test_files.append(file_path)
                        break

        return list(set(test_files))

    def find_related_tests(self, source_file: Path) -> List[Path]:
        """Find test files related to a source file."""
        related = []
        source_name = source_file.stem

        # Common test file naming patterns
        test_names = [
            f"test_{source_name}.py",
            f"{source_name}_test.py",
            f"{source_name}.test.js",
            f"{source_name}.test.ts",
            f"{source_name}.spec.js",
            f"{source_name}.spec.ts",
            f"{source_name}_test.go",
        ]

        for test_file in self.find_test_files():
            if test_file.name in test_names:
                related.append(test_file)
            # Also check files in a parallel test directory
            elif "test" in str(test_file) and source_name in test_file.stem:
                related.append(test_file)

        return related

    def analyze_test_file(self, test_file: Path) -> Dict[str, Any]:
        """Analyze a test file to extract test information."""
        content = test_file.read_text()

        analysis = {
            "file": str(test_file),
            "tests": [],
            "fixtures": [],
            "imports": [],
        }

        if test_file.suffix == ".py":
            analysis = self._analyze_pytest_file(content, analysis)
        elif test_file.suffix in (".js", ".ts", ".jsx", ".tsx"):
            analysis = self._analyze_jest_file(content, analysis)

        return analysis

    def _analyze_pytest_file(self, content: str, analysis: Dict) -> Dict:
        """Extract test information from a pytest file."""
        # Find test functions
        test_pattern = r"(?:async\s+)?def\s+(test_\w+)\s*\("
        for match in re.finditer(test_pattern, content):
            analysis["tests"].append({
                "name": match.group(1),
                "type": "function",
                "line": content[:match.start()].count("\n") + 1
            })

        # Find test classes
        class_pattern = r"class\s+(Test\w+)\s*[:\(]"
        for match in re.finditer(class_pattern, content):
            analysis["tests"].append({
                "name": match.group(1),
                "type": "class",
                "line": content[:match.start()].count("\n") + 1
            })

        # Find fixtures
        fixture_pattern = r"@pytest\.fixture.*\ndef\s+(\w+)"
        for match in re.finditer(fixture_pattern, content):
            analysis["fixtures"].append(match.group(1))

        return analysis

    def _analyze_jest_file(self, content: str, analysis: Dict) -> Dict:
        """Extract test information from a Jest file."""
        # Find test/it blocks
        test_pattern = r"\b(?:test|it)\s*\(\s*['\"](.+?)['\"]"
        for match in re.finditer(test_pattern, content):
            analysis["tests"].append({
                "name": match.group(1),
                "type": "test",
                "line": content[:match.start()].count("\n") + 1
            })

        # Find describe blocks
        describe_pattern = r"\bdescribe\s*\(\s*['\"](.+?)['\"]"
        for match in re.finditer(describe_pattern, content):
            analysis["tests"].append({
                "name": match.group(1),
                "type": "describe",
                "line": content[:match.start()].count("\n") + 1
            })

        return analysis
```
Always check for a project's existing test infrastructure before generating new tests. Agents should follow the project's established patterns.
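A note on the pattern matching: `fnmatch` compares glob patterns against plain strings, and its `*` happily crosses `/`, so path-style patterns behave loosely while a literal `/` in the pattern must still appear in the candidate. A standalone check of the behavior the detector relies on:

```python
import fnmatch

# Bare-filename pattern against a bare filename
matches_name = fnmatch.fnmatch("test_parser.py", "test_*.py")

# Path-style pattern: "*" crosses "/" so "**" adds nothing special,
# but the literal "/" in the pattern must appear in the candidate
matches_path = fnmatch.fnmatch("src/app.test.ts", "**/*.test.ts")
matches_bare = fnmatch.fnmatch("app.test.ts", "**/*.test.ts")
```

The missing-slash case is exactly why `find_test_files` also matches the bare filename against the last segment of each pattern.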

Running and Parsing Tests

After detecting the test framework, we need to run tests and parse the results into a structured format the agent can understand:

```python
import asyncio
import json
import re
from pathlib import Path
from typing import Optional


class TestRunner:
    """
    Run tests and parse results for the coding agent.
    """

    def __init__(self, workspace: Path, sandbox=None):
        self.workspace = Path(workspace)
        self.sandbox = sandbox
        self.detector = TestDetector(workspace)

    async def run_tests(
        self,
        framework: Optional[TestFramework] = None,
        test_path: Optional[str] = None,
        specific_test: Optional[str] = None,
        coverage: bool = False,
        timeout: int = 300
    ) -> TestSuiteResult:
        """Run tests and return parsed results."""
        if not framework:
            frameworks = self.detector.detect_frameworks()
            if not frameworks:
                raise RuntimeError("No test framework detected")
            framework = frameworks[0]

        # Build command
        cmd = self._build_test_command(
            framework, test_path, specific_test, coverage
        )

        # Run tests
        if self.sandbox:
            result = await self.sandbox.execute(cmd, timeout=timeout)
            stdout = result.stdout
            stderr = result.stderr
            success = result.success
        else:
            process = await asyncio.create_subprocess_shell(
                cmd,
                stdout=asyncio.subprocess.PIPE,
                stderr=asyncio.subprocess.PIPE,
                cwd=self.workspace
            )
            stdout_bytes, stderr_bytes = await asyncio.wait_for(
                process.communicate(),
                timeout=timeout
            )
            stdout = stdout_bytes.decode()
            stderr = stderr_bytes.decode()
            success = process.returncode == 0

        # Parse results based on framework
        return self._parse_results(framework, stdout, stderr, success)

    def _build_test_command(
        self,
        framework: TestFramework,
        test_path: Optional[str] = None,
        specific_test: Optional[str] = None,
        coverage: bool = False
    ) -> str:
        """Build the test command."""
        if coverage and framework.coverage_command:
            cmd = framework.coverage_command
        else:
            cmd = framework.run_command

        # Add output format for parsing
        if framework.name == "pytest":
            cmd += " -v --tb=short"
            if test_path:
                cmd += f" {test_path}"
            if specific_test:
                cmd += f" -k '{specific_test}'"

        elif framework.name in ("jest", "vitest"):
            # Jest emits its JSON report with --json; vitest uses --reporter=json
            cmd += " --json" if framework.name == "jest" else " --reporter=json"
            if test_path:
                cmd += f" {test_path}"
            if specific_test:
                cmd += f" -t '{specific_test}'"

        elif framework.name == "go test":
            cmd += " -v -json"
            if test_path:
                cmd += f" {test_path}"
            if specific_test:
                cmd += f" -run {specific_test}"

        return cmd

    def _parse_results(
        self,
        framework: TestFramework,
        stdout: str,
        stderr: str,
        success: bool
    ) -> TestSuiteResult:
        """Parse test output into structured results."""
        if framework.name == "pytest":
            return self._parse_pytest_output(stdout, stderr)
        elif framework.name in ("jest", "vitest"):
            return self._parse_jest_output(stdout, stderr)
        elif framework.name == "go test":
            return self._parse_go_test_output(stdout, stderr)
        else:
            return self._parse_generic_output(stdout, stderr, success)

    def _parse_pytest_output(self, stdout: str, stderr: str) -> TestSuiteResult:
        """Parse pytest verbose output."""
        tests = []
        output = stdout + "\n" + stderr

        # Parse individual test results
        # Format: path/test_file.py::test_name PASSED/FAILED
        test_pattern = r"(\S+::\S+)\s+(PASSED|FAILED|ERROR|SKIPPED)"

        status_map = {
            "PASSED": TestStatus.PASSED,
            "FAILED": TestStatus.FAILED,
            "ERROR": TestStatus.ERROR,
            "SKIPPED": TestStatus.SKIPPED,
        }

        for match in re.finditer(test_pattern, output):
            tests.append(TestResult(
                name=match.group(1),
                status=status_map.get(match.group(2), TestStatus.ERROR)
            ))

        # Extract failure details
        failure_pattern = r"FAILED (\S+) - (.+?)(?=\n(?:FAILED|=====|$))"
        for match in re.finditer(failure_pattern, output, re.DOTALL):
            test_name = match.group(1)
            error_msg = match.group(2).strip()

            for test in tests:
                if test_name in test.name:
                    test.error_message = error_msg
                    break

        # Parse the summary line; pytest prints counts in varying order
        # (e.g. "1 failed, 2 passed in 0.12s"), so match each label independently
        counts = {
            label.rstrip("s"): int(n)
            for n, label in re.findall(r"(\d+) (passed|failed|errors?|skipped)", output)
        }

        if counts:
            passed = counts.get("passed", 0)
            failed = counts.get("failed", 0)
            errors = counts.get("error", 0)
            skipped = counts.get("skipped", 0)
        else:
            passed = len([t for t in tests if t.status == TestStatus.PASSED])
            failed = len([t for t in tests if t.status == TestStatus.FAILED])
            errors = len([t for t in tests if t.status == TestStatus.ERROR])
            skipped = len([t for t in tests if t.status == TestStatus.SKIPPED])

        return TestSuiteResult(
            total=len(tests),
            passed=passed,
            failed=failed,
            errors=errors,
            skipped=skipped,
            duration=0,  # Could parse from output
            tests=tests
        )

    def _parse_jest_output(self, stdout: str, stderr: str) -> TestSuiteResult:
        """Parse Jest JSON output."""
        tests = []

        try:
            # Find JSON in output
            json_start = stdout.find("{")
            json_end = stdout.rfind("}") + 1
            if json_start >= 0 and json_end > json_start:
                data = json.loads(stdout[json_start:json_end])
            else:
                raise ValueError("No JSON found")

            for test_result in data.get("testResults", []):
                for assertion in test_result.get("assertionResults", []):
                    status_map = {
                        "passed": TestStatus.PASSED,
                        "failed": TestStatus.FAILED,
                        "pending": TestStatus.PENDING,
                    }

                    tests.append(TestResult(
                        name=assertion.get("fullName", assertion.get("title", "unknown")),
                        status=status_map.get(assertion.get("status"), TestStatus.ERROR),
                        duration=assertion.get("duration", 0) / 1000,
                        error_message="\n".join(assertion.get("failureMessages", [])),
                        file_path=test_result.get("name")
                    ))

            return TestSuiteResult(
                total=data.get("numTotalTests", len(tests)),
                passed=data.get("numPassedTests", 0),
                failed=data.get("numFailedTests", 0),
                errors=0,
                skipped=data.get("numPendingTests", 0),
                duration=sum(t.duration for t in tests),
                tests=tests
            )

        except (json.JSONDecodeError, ValueError):
            # Fall back to regex parsing
            return self._parse_generic_output(stdout, stderr, "fail" not in stdout.lower())

    def _parse_go_test_output(self, stdout: str, stderr: str) -> TestSuiteResult:
        """Parse go test JSON output (one event object per line)."""
        tests = []

        for line in stdout.split("\n"):
            if not line.strip():
                continue

            try:
                event = json.loads(line)
            except json.JSONDecodeError:
                continue

            action = event.get("Action")
            test_name = event.get("Test")

            if not test_name:
                continue

            if action == "pass":
                tests.append(TestResult(
                    name=test_name,
                    status=TestStatus.PASSED,
                    duration=event.get("Elapsed", 0)
                ))
            elif action == "fail":
                tests.append(TestResult(
                    name=test_name,
                    status=TestStatus.FAILED,
                    duration=event.get("Elapsed", 0),
                    error_message=event.get("Output", "")
                ))
            elif action == "skip":
                tests.append(TestResult(
                    name=test_name,
                    status=TestStatus.SKIPPED
                ))

        passed = len([t for t in tests if t.status == TestStatus.PASSED])
        failed = len([t for t in tests if t.status == TestStatus.FAILED])
        skipped = len([t for t in tests if t.status == TestStatus.SKIPPED])

        return TestSuiteResult(
            total=len(tests),
            passed=passed,
            failed=failed,
            errors=0,
            skipped=skipped,
            duration=sum(t.duration for t in tests),
            tests=tests
        )

    def _parse_generic_output(
        self,
        stdout: str,
        stderr: str,
        success: bool
    ) -> TestSuiteResult:
        """Generic test output parser when no specific parser is available."""
        output = stdout + stderr

        # Try to extract test counts from common patterns
        patterns = [
            r"(\d+) tests?,? (\d+) failures?",
            r"Passed: (\d+), Failed: (\d+)",
            r"(\d+) passing, (\d+) failing",
        ]

        for pattern in patterns:
            match = re.search(pattern, output, re.IGNORECASE)
            if match:
                passed = int(match.group(1))
                failed = int(match.group(2))
                return TestSuiteResult(
                    total=passed + failed,
                    passed=passed,
                    failed=failed,
                    errors=0,
                    skipped=0,
                    duration=0,
                    tests=[]
                )

        # Fallback: treat the whole run as a single pass/fail signal
        return TestSuiteResult(
            total=1,
            passed=1 if success else 0,
            failed=0 if success else 1,
            errors=0,
            skipped=0,
            duration=0,
            tests=[]
        )
```
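One subtlety with pytest's summary line: the counts appear in varying order ("1 failed, 2 passed in 0.12s"), so matching each label independently is safer than one positional pattern. A standalone sketch of that idea (the sample output string is fabricated for illustration):

```python
import re


def parse_summary_counts(output: str) -> dict:
    """Collect pytest-style summary counts regardless of their order."""
    counts = {}
    for n, label in re.findall(r"(\d+) (passed|failed|errors?|skipped)", output):
        counts[label.rstrip("s")] = int(n)  # normalize "errors" -> "error"
    return counts


counts = parse_summary_counts("=== 1 failed, 2 passed, 1 skipped in 0.12s ===")
```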

Test Generation

Agents can generate tests to verify their implementations. This requires understanding the code being tested and the project's testing conventions:

````python
import re
from pathlib import Path
from typing import Any, Dict, List, Optional


class TestGenerator:
    """
    Generate tests for code using an LLM.
    """

    def __init__(self, llm_client, workspace: Path):
        self.llm = llm_client
        self.workspace = Path(workspace)
        self.detector = TestDetector(workspace)

    async def generate_tests(
        self,
        source_file: Path,
        function_name: Optional[str] = None,
        test_type: str = "unit"
    ) -> str:
        """Generate tests for a source file or specific function."""
        # Read source file
        source_content = source_file.read_text()

        # Detect framework and get examples
        frameworks = self.detector.detect_frameworks()
        framework = frameworks[0] if frameworks else None

        # Find existing test files for patterns
        existing_tests = self.detector.find_test_files(framework)
        test_examples = self._get_test_examples(existing_tests)

        # Build generation prompt
        prompt = self._build_test_generation_prompt(
            source_content,
            source_file,
            function_name,
            test_type,
            framework,
            test_examples
        )

        # Generate tests
        response = await self.llm.generate(prompt)

        # Extract code from response
        return self._extract_code(response)

    def _build_test_generation_prompt(
        self,
        source_content: str,
        source_file: Path,
        function_name: Optional[str],
        test_type: str,
        framework: Optional[TestFramework],
        test_examples: str
    ) -> str:
        """Build prompt for test generation."""
        framework_name = framework.name if framework else "appropriate"

        focus = f"the function '{function_name}'" if function_name else "all public functions"

        return f"""Generate {test_type} tests for the following code.

Source file: {source_file.name}
Testing framework: {framework_name}
Focus: {focus}

SOURCE CODE:
{source_content}

{test_examples}

Requirements:
1. Follow the testing patterns used in this project
2. Test both success cases and edge cases
3. Include meaningful assertions
4. Use descriptive test names that explain what is being tested
5. Mock external dependencies appropriately

Generate ONLY the test code, no explanations."""

    def _get_test_examples(self, test_files: List[Path], limit: int = 2) -> str:
        """Get examples from existing test files."""
        examples = []

        for test_file in test_files[:limit]:
            try:
                content = test_file.read_text()
            except OSError:
                continue
            # Take the first 100 lines as an example
            lines = content.split("\n")[:100]
            examples.append(f"# From {test_file.name}\n" + "\n".join(lines))

        return "\n\n".join(examples)

    def _extract_code(self, response: str) -> str:
        """Extract fenced code blocks from an LLM response."""
        pattern = r"```(?:\w+)?\n(.*?)```"
        matches = re.findall(pattern, response, re.DOTALL)

        if matches:
            return "\n\n".join(matches)

        # Return as-is if no code blocks found
        return response.strip()

    async def generate_test_for_bug(
        self,
        bug_description: str,
        affected_file: Path,
        expected_behavior: str
    ) -> str:
        """Generate a regression test for a bug."""
        source_content = affected_file.read_text()

        prompt = f"""Generate a regression test that verifies a bug fix.

Bug description: {bug_description}
Expected behavior: {expected_behavior}
Affected file: {affected_file.name}

SOURCE CODE:
{source_content[:2000]}

Generate a test that:
1. Would FAIL if the bug exists
2. Would PASS after the bug is fixed
3. Clearly documents what the bug was
4. Can be used for regression testing

Generate ONLY the test code."""

        response = await self.llm.generate(prompt)
        return self._extract_code(response)

    async def improve_test_coverage(
        self,
        source_file: Path,
        existing_tests: Path,
        coverage_report: Optional[Dict[str, Any]] = None
    ) -> str:
        """Generate additional tests to improve coverage."""
        source_content = source_file.read_text()
        test_content = existing_tests.read_text()

        uncovered_info = ""
        if coverage_report:
            uncovered_lines = coverage_report.get("uncovered_lines", [])
            uncovered_info = f"\nUncovered lines: {uncovered_lines}"

        prompt = f"""Analyze the existing tests and generate additional tests to improve coverage.

SOURCE CODE:
{source_content}

EXISTING TESTS:
{test_content}
{uncovered_info}

Generate NEW tests that:
1. Cover code paths not tested by existing tests
2. Test edge cases and error conditions
3. Don't duplicate existing test coverage
4. Follow the same style as existing tests

Generate ONLY the additional test code to append."""

        response = await self.llm.generate(prompt)
        return self._extract_code(response)
````
When generating tests, always analyze existing test files first to match the project's testing style and conventions.
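The code-extraction step boils down to one fence-matching regex; here it is exercised standalone (the fence string is built with `chr(96)` only so this snippet nests cleanly inside the book's own code blocks):

```python
import re

FENCE = chr(96) * 3  # three backticks: the Markdown fence delimiter


def extract_code(response: str) -> str:
    """Pull the bodies of fenced code blocks out of an LLM response."""
    pattern = FENCE + r"(?:\w+)?\n(.*?)" + FENCE
    matches = re.findall(pattern, response, re.DOTALL)
    return "\n\n".join(matches) if matches else response.strip()


reply = f"Here is the test:\n{FENCE}python\nassert add(2, 2) == 4\n{FENCE}\nDone."
code = extract_code(reply)
```

Falling back to the raw response when no fence is found is deliberate: some models return bare code with no Markdown at all.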

Test-Driven Agent Loop

The TDD agent loop combines test detection, running, and implementation into an iterative workflow:

🐍python
from pathlib import Path
from typing import Any, AsyncGenerator, Dict, Optional


class TDDAgentLoop:
    """
    Test-driven development loop for coding agents.
    """

    MAX_ITERATIONS = 10

    def __init__(
        self,
        llm_client,
        workspace: Path,
        sandbox,
        file_tools
    ):
        self.llm = llm_client
        self.workspace = Path(workspace)
        self.sandbox = sandbox
        self.file_tools = file_tools

        self.test_runner = TestRunner(workspace, sandbox)
        self.test_generator = TestGenerator(llm_client, workspace)
        self.detector = TestDetector(workspace)

    async def implement_with_tests(
        self,
        task_description: str,
        target_file: Path,
        test_file: Optional[Path] = None
    ) -> AsyncGenerator[Dict[str, Any], None]:
        """
        Implement a feature using TDD methodology.
        """
        iteration = 0

        # Phase 1: Identify or create tests
        yield {"phase": "setup", "message": "Identifying tests"}

        if test_file and test_file.exists():
            tests_exist = True
        else:
            # Find or generate tests
            related_tests = self.detector.find_related_tests(target_file)
            if related_tests:
                test_file = related_tests[0]
                tests_exist = True
            else:
                # Generate new test file
                test_file = self._get_test_file_path(target_file)
                test_code = await self.test_generator.generate_tests(
                    target_file,
                    test_type="unit"
                )
                await self.file_tools.write(str(test_file), test_code)
                tests_exist = False

        yield {
            "phase": "setup",
            "message": f"Using test file: {test_file}",
            "tests_generated": not tests_exist
        }

        # Phase 2: Initial test run (should fail for new features)
        yield {"phase": "red", "message": "Running initial tests"}

        initial_results = await self.test_runner.run_tests(
            test_path=str(test_file)
        )

        yield {
            "phase": "red",
            "results": initial_results.summary(),
            "failed_tests": [t.name for t in initial_results.failed_tests()]
        }

        # Phase 3: Implementation loop
        test_results = initial_results
        while iteration < self.MAX_ITERATIONS:
            iteration += 1

            yield {
                "phase": "green",
                "iteration": iteration,
                "message": "Implementing to pass tests"
            }

            # Generate or improve implementation based on the latest results
            implementation = await self._generate_implementation(
                task_description,
                target_file,
                test_results
            )

            # Apply implementation
            await self.file_tools.write(str(target_file), implementation)

            yield {
                "phase": "green",
                "iteration": iteration,
                "message": "Running tests to verify"
            }

            # Run tests
            test_results = await self.test_runner.run_tests(
                test_path=str(test_file)
            )

            yield {
                "phase": "green",
                "iteration": iteration,
                "results": test_results.summary(),
                "success": test_results.success
            }

            if test_results.success:
                # All tests pass!
                break

            # Analyze failures and prepare for next iteration
            failure_analysis = await self._analyze_failures(
                test_results,
                target_file
            )

            yield {
                "phase": "debug",
                "iteration": iteration,
                "analysis": failure_analysis
            }

        # Phase 4: Refactor (optional)
        if test_results.success:
            yield {"phase": "refactor", "message": "Refactoring with test safety net"}

            refactored = await self._refactor_implementation(
                target_file,
                task_description
            )

            if refactored:
                await self.file_tools.write(str(target_file), refactored)

                # Verify tests still pass
                final_results = await self.test_runner.run_tests(
                    test_path=str(test_file)
                )

                if not final_results.success:
                    # Rollback refactoring
                    await self.file_tools.write(str(target_file), implementation)
                    yield {
                        "phase": "refactor",
                        "message": "Refactoring broke tests, rolled back"
                    }
                else:
                    yield {
                        "phase": "refactor",
                        "message": "Refactoring complete, tests still pass"
                    }

        # Final summary
        yield {
            "phase": "complete",
            "iterations": iteration,
            "success": test_results.success,
            "final_results": test_results.summary()
        }

    async def _generate_implementation(
        self,
        task: str,
        target_file: Path,
        test_results: TestSuiteResult
    ) -> str:
        """Generate implementation based on failing tests."""
        current_content = target_file.read_text() if target_file.exists() else ""

        failed_tests = test_results.failed_tests()
        failure_info = "\n".join([
            f"- {t.name}: {t.error_message or 'Failed'}"
            for t in failed_tests
        ])

        prompt = f"""Implement or fix the following code to pass the failing tests.

Task: {task}

Current implementation:
{current_content}

Failing tests:
{failure_info}

Requirements:
1. Make the failing tests pass
2. Don't break any passing tests
3. Write clean, maintainable code
4. Handle edge cases appropriately

Provide the complete updated file content."""

        response = await self.llm.generate(prompt)
        return self.test_generator._extract_code(response)

    async def _analyze_failures(
        self,
        test_results: TestSuiteResult,
        target_file: Path
    ) -> Dict[str, Any]:
        """Analyze test failures to guide next implementation attempt."""
        failed = test_results.failed_tests()

        analysis = {
            "failed_count": len(failed),
            "patterns": [],
            "suggestions": []
        }

        # Look for common patterns
        error_messages = [t.error_message for t in failed if t.error_message]

        if any("undefined" in str(e).lower() for e in error_messages):
            analysis["patterns"].append("missing_definition")
            analysis["suggestions"].append("Check that all functions/variables are defined")

        if any("type" in str(e).lower() for e in error_messages):
            analysis["patterns"].append("type_error")
            analysis["suggestions"].append("Check argument types and return types")

        if any("assert" in str(e).lower() for e in error_messages):
            analysis["patterns"].append("assertion_failure")
            analysis["suggestions"].append("Implementation logic may be incorrect")

        return analysis

    async def _refactor_implementation(
        self,
        target_file: Path,
        task: str
    ) -> Optional[str]:
        """Attempt to refactor the implementation for cleaner code."""
        current_content = target_file.read_text()

        prompt = f"""Refactor the following code for better readability and maintainability.
Keep the same functionality - tests should still pass.

Task context: {task}

Current code:
{current_content}

Improvements to consider:
1. Better naming
2. Reduced complexity
3. Clearer structure
4. Following language idioms

If the code is already clean, respond with "NO_REFACTOR_NEEDED".
Otherwise, provide the complete refactored file."""

        response = await self.llm.generate(prompt)

        if "NO_REFACTOR_NEEDED" in response:
            return None

        return self.test_generator._extract_code(response)

    def _get_test_file_path(self, source_file: Path) -> Path:
        """Determine the path for a new test file."""
        # Follow project conventions
        name = source_file.stem
        suffix = source_file.suffix

        # Check for tests directory
        tests_dir = self.workspace / "tests"
        if not tests_dir.exists():
            tests_dir = self.workspace / "test"
        if not tests_dir.exists():
            tests_dir = source_file.parent

        if suffix == ".py":
            return tests_dir / f"test_{name}.py"
        elif suffix in (".js", ".ts"):
            return tests_dir / f"{name}.test{suffix}"
        else:
            return tests_dir / f"test_{name}{suffix}"
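The failure-pattern heuristics in `_analyze_failures` can be exercised on their own. A minimal sketch that operates on bare error-message strings (the `classify_failures` name is illustrative):

```python
# Minimal, self-contained version of the failure-pattern heuristics from
# _analyze_failures: map raw error messages to coarse failure categories.
def classify_failures(error_messages: list[str]) -> list[str]:
    patterns = []
    lowered = [m.lower() for m in error_messages]
    if any("undefined" in m for m in lowered):
        patterns.append("missing_definition")
    if any("type" in m for m in lowered):
        patterns.append("type_error")
    if any("assert" in m for m in lowered):
        patterns.append("assertion_failure")
    return patterns
```

Each detected pattern maps to a concrete suggestion that is fed back into the next implementation prompt, so the LLM gets structured hints rather than raw stack traces alone.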
The TDD loop provides a self-correcting mechanism. Each iteration narrows the gap between the current implementation and the expected behavior defined by tests.
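Because `implement_with_tests` is an async generator, callers drive the loop by consuming its event stream. A sketch of that consumption pattern, using a stand-in generator (`fake_tdd_events` is hypothetical; real events come from `TDDAgentLoop.implement_with_tests`):

```python
import asyncio

# Stand-in for TDDAgentLoop.implement_with_tests, yielding the same shape of
# phase events so the consumption pattern can be shown in isolation.
async def fake_tdd_events():
    yield {"phase": "setup", "message": "Identifying tests"}
    yield {"phase": "red", "message": "Running initial tests"}
    yield {"phase": "green", "iteration": 1, "success": True}
    yield {"phase": "complete", "iterations": 1, "success": True}

async def consume(events):
    # Stream each phase to the user as it happens; keep the final event
    # so the caller can report overall success.
    final = None
    async for event in events:
        print(f"[{event['phase']}] {event.get('message', '')}")
        final = event
    return final

final = asyncio.run(consume(fake_tdd_events()))
```

Streaming events this way lets a UI show live red/green progress instead of blocking until the whole TDD cycle finishes.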

Summary

In this section, we built comprehensive TDD capabilities for our coding agent:

  • Test Detection: Automatically detecting test frameworks and finding relevant test files
  • Test Running: Executing tests and parsing results from multiple frameworks (pytest, Jest, vitest, go test)
  • Test Generation: LLM-powered test generation that follows project conventions
  • TDD Loop: The complete red-green-refactor cycle with iterative implementation
  • Failure Analysis: Understanding why tests fail to guide implementation improvements

In the next section, we'll implement the debugging and iteration loop that helps the agent diagnose and fix problems automatically.