-
Notifications
You must be signed in to change notification settings - Fork 12
Expand file tree
/
Copy pathdynamic_spawner.py
More file actions
205 lines (169 loc) · 6.64 KB
/
dynamic_spawner.py
File metadata and controls
205 lines (169 loc) · 6.64 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
"""Dynamic Spawner — model-level fallback for failed agents.
When an agent task fails, the standard remediation flow retries with the
same model. This module adds a **model cascade**: if the original model
fails, try a more capable (or different) Claude model before giving up.
This leverages the ``model`` parameter in ``ClaudeAgentOptions`` documented
in the Claude Agent SDK README.
Integration point:
dag_executor._handle_failure — call ``maybe_respawn(...)`` before
falling back to the existing remediation/retry logic.
"""
from __future__ import annotations
import logging
from dataclasses import dataclass, field
from typing import Any
from contracts import (
FailureCategory,
TaskInput,
TaskOutput,
TaskStatus,
classify_failure,
)
logger = logging.getLogger(__name__)

# ── Model Cascade ────────────────────────────────────────────────────────────
# Ordered from cheapest/fastest to most capable.
# When the current model fails, the spawner tries the next model in this list
# that has not yet been tried for the task.
DEFAULT_MODEL_CASCADE: list[str] = [
    "claude-sonnet-4-20250514",
    "claude-opus-4-20250514",
]

# Failure categories where switching models is likely to help.
# BUILD_ERROR, TEST_FAILURE, API_MISMATCH and UNCLEAR_GOAL may benefit from a
# more capable model. TIMEOUT is also eligible, but note that the cascade only
# moves forward, i.e. away from the cheapest/fastest end, so a switch will not
# pick a faster model and is not guaranteed to fix a timeout.
# Categories not listed here (e.g. PERMISSION, DEPENDENCY_MISSING) are
# infrastructure issues — switching models won't help.
_MODEL_SWITCH_ELIGIBLE: set[FailureCategory] = {
    FailureCategory.BUILD_ERROR,
    FailureCategory.TEST_FAILURE,
    FailureCategory.API_MISMATCH,
    FailureCategory.TIMEOUT,
    FailureCategory.UNCLEAR_GOAL,
}
@dataclass
class SpawnAttempt:
    """One model switch handed out for a failed task.

    ``succeeded`` stays ``None`` until the outcome of the retried task is
    reported back after execution.
    """

    # Task whose model was switched.
    task_id: str
    # Model that failed, or None when the caller did not name one.
    original_model: str | None
    # Model chosen for the retry.
    new_model: str
    # Short human-readable description of why the switch happened.
    reason: str
    # Outcome of the retry: True/False once known, None while still pending.
    succeeded: bool | None = field(default=None)
@dataclass
class DynamicSpawner:
    """Manages model-level fallback for failed tasks.

    For each task it remembers which models have already been tried, so the
    same model is not proposed twice for one task, and it keeps a
    chronological ``history`` of every model switch it has handed out.
    """

    # Models to try, in order; defaults to a fresh copy of DEFAULT_MODEL_CASCADE.
    model_cascade: list[str] = field(default_factory=lambda: list(DEFAULT_MODEL_CASCADE))
    # task_id -> set of models already tried
    _tried: dict[str, set[str]] = field(default_factory=dict)
    # History of all spawn attempts, oldest first
    history: list[SpawnAttempt] = field(default_factory=list)

    def should_respawn(
        self,
        task: TaskInput,
        output: TaskOutput,
        current_model: str | None = None,
    ) -> bool:
        """Decide whether a model switch is worth trying.

        This only inspects state; it does not mark anything as tried. Use
        ``get_respawn_model`` to actually claim the next model.

        Returns True only if all of the following hold:
          1. ``output.status`` is ``TaskStatus.FAILED``.
          2. The classified failure category is eligible for model switching.
          3. At least one cascade model has not been tried for ``task.id``
             (``current_model``, if given, counts as tried).
        """
        if output.status != TaskStatus.FAILED:
            return False

        category = self._classify_failure(output)
        if category not in _MODEL_SWITCH_ELIGIBLE:
            logger.debug(
                "[DynamicSpawner] task %s failed with %s — not eligible for model switch",
                task.id,
                category,
            )
            return False

        if self._next_model(task.id, current_model) is None:
            logger.debug(
                "[DynamicSpawner] task %s — all models exhausted",
                task.id,
            )
            return False

        return True

    def get_respawn_model(
        self,
        task: TaskInput,
        output: TaskOutput,
        current_model: str | None = None,
    ) -> str | None:
        """Return the next model to try, or None if no switch should happen.

        When a model is returned, both ``current_model`` (if given) and the
        returned model are marked as tried for ``task.id``. A ``SpawnAttempt``
        with ``succeeded=None`` is appended to ``history``; call
        ``record_result`` once the retried task finishes.
        """
        if not self.should_respawn(task, output, current_model):
            return None
        next_model = self._next_model(task.id, current_model)
        if next_model is None:  # defensive; should_respawn already checked this
            return None

        # Record both models so neither is proposed again for this task.
        tried = self._tried.setdefault(task.id, set())
        if current_model:
            tried.add(current_model)
        tried.add(next_model)

        reason = self._build_reason(output)
        self.history.append(
            SpawnAttempt(
                task_id=task.id,
                original_model=current_model,
                new_model=next_model,
                reason=reason,
            )
        )
        logger.info(
            "[DynamicSpawner] task %s: switching from %s -> %s (reason: %s)",
            task.id,
            current_model or "default",
            next_model,
            reason,
        )
        return next_model

    def record_result(self, task_id: str, model: str, succeeded: bool) -> None:
        """Record the outcome of a model switch.

        Updates the most recent ``history`` entry whose ``task_id`` and
        ``new_model`` match. If there is no such entry, nothing is changed
        and the miss is logged at debug level instead of being silently ignored.
        """
        for attempt in reversed(self.history):
            if attempt.task_id == task_id and attempt.new_model == model:
                attempt.succeeded = succeeded
                return
        logger.debug(
            "[DynamicSpawner] task %s: no spawn attempt for model %s to record",
            task_id,
            model,
        )

    def get_summary(self) -> dict[str, Any]:
        """Return aggregate counts plus a per-attempt listing of ``history``.

        ``pending`` counts attempts whose ``succeeded`` is still None.
        """
        total = len(self.history)
        succeeded = sum(1 for a in self.history if a.succeeded is True)
        failed = sum(1 for a in self.history if a.succeeded is False)
        return {
            "total_attempts": total,
            "succeeded": succeeded,
            "failed": failed,
            "pending": total - succeeded - failed,
            "attempts": [
                {
                    "task_id": a.task_id,
                    "from": a.original_model,
                    "to": a.new_model,
                    "reason": a.reason,
                    "succeeded": a.succeeded,
                }
                for a in self.history
            ],
        }

    # ── Internal ─────────────────────────────────────────────────────────

    def _next_model(self, task_id: str, current_model: str | None) -> str | None:
        """Return the first cascade model not yet tried for ``task_id``, or None.

        ``current_model`` (if truthy) is treated as tried even when it was never
        recorded, so the model that just failed is not handed back.
        """
        tried = self._tried.get(task_id, set())
        if current_model:
            tried = tried | {current_model}  # new set; self._tried is not mutated
        return next((m for m in self.model_cascade if m not in tried), None)

    @staticmethod
    def _classify_failure(output: TaskOutput) -> FailureCategory:
        """Classify failure using the canonical contracts.classify_failure."""
        return classify_failure(output)

    @staticmethod
    def _build_reason(output: TaskOutput) -> str:
        """Build a short human-readable reason for the model switch.

        Uses the first entry of ``output.issues`` when there is one, otherwise
        ``output.summary``, truncated to 100 characters. The ``str()`` /
        ``or ""`` guards keep a missing (None) summary or a non-string issue
        from raising a TypeError inside the failure-handling path.
        """
        if output.issues:
            return f"Failed: {str(output.issues[0])[:100]}"
        return f"Failed: {str(output.summary or '')[:100]}"