|
4 | 4 | import asyncio |
5 | 5 | from collections.abc import AsyncIterator |
6 | 6 | from dataclasses import dataclass, field |
7 | | -from typing import TYPE_CHECKING, Any, cast |
| 7 | +from typing import TYPE_CHECKING, Any, Literal, cast |
8 | 8 |
|
9 | 9 | from typing_extensions import TypeVar |
10 | 10 |
|
@@ -164,24 +164,61 @@ class RunResultStreaming(RunResultBase): |
164 | 164 | _output_guardrails_task: asyncio.Task[Any] | None = field(default=None, repr=False) |
165 | 165 | _stored_exception: Exception | None = field(default=None, repr=False) |
166 | 166 |
|
| 167 | + # Soft cancel state |
| 168 | + _cancel_mode: Literal["none", "immediate", "after_turn"] = field(default="none", repr=False) |
| 169 | + |
167 | 170 | @property |
168 | 171 | def last_agent(self) -> Agent[Any]: |
169 | 172 | """The last agent that was run. Updates as the agent run progresses, so the true last agent |
170 | 173 | is only available after the agent run is complete. |
171 | 174 | """ |
172 | 175 | return self.current_agent |
173 | 176 |
|
174 | | - def cancel(self) -> None: |
175 | | - """Cancels the streaming run, stopping all background tasks and marking the run as |
176 | | - complete.""" |
177 | | - self._cleanup_tasks() # Cancel all running tasks |
178 | | - self.is_complete = True # Mark the run as complete to stop event streaming |
| 177 | + def cancel(self, mode: Literal["immediate", "after_turn"] = "immediate") -> None: |
| 178 | + """Cancel the streaming run. |
179 | 179 |
|
180 | | - # Optionally, clear the event queue to prevent processing stale events |
181 | | - while not self._event_queue.empty(): |
182 | | - self._event_queue.get_nowait() |
183 | | - while not self._input_guardrail_queue.empty(): |
184 | | - self._input_guardrail_queue.get_nowait() |
| 180 | + Args: |
| 181 | + mode: Cancellation strategy: |
| 182 | + - "immediate": Stop immediately, cancel all tasks, clear queues (default) |
| 183 | + - "after_turn": Complete current turn gracefully before stopping |
| 184 | + * Allows LLM response to finish |
| 185 | + * Executes pending tool calls |
| 186 | + * Saves session state properly |
| 187 | + * Tracks usage accurately |
| 188 | + * Stops before next turn begins |
| 189 | +
|
| 190 | + Example: |
| 191 | + ```python |
| 192 | + result = Runner.run_streamed(agent, "Task", session=session) |
| 193 | +
|
| 194 | + async for event in result.stream_events(): |
| 195 | + if user_interrupted(): |
| 196 | + result.cancel(mode="after_turn") # Graceful |
| 197 | + # result.cancel() # Immediate (default) |
| 198 | + ``` |
| 199 | +
|
| 200 | + Note: After calling cancel(), you should continue consuming stream_events() |
| 201 | + to allow the cancellation to complete properly. |
| 202 | + """ |
| 203 | + # Store the cancel mode for the background task to check |
| 204 | + self._cancel_mode = mode |
| 205 | + |
| 206 | + if mode == "immediate": |
| 207 | + # Existing behavior - immediate shutdown |
| 208 | + self._cleanup_tasks() # Cancel all running tasks |
| 209 | + self.is_complete = True # Mark the run as complete to stop event streaming |
| 210 | + |
| 211 | + # Optionally, clear the event queue to prevent processing stale events |
| 212 | + while not self._event_queue.empty(): |
| 213 | + self._event_queue.get_nowait() |
| 214 | + while not self._input_guardrail_queue.empty(): |
| 215 | + self._input_guardrail_queue.get_nowait() |
| 216 | + |
| 217 | + elif mode == "after_turn": |
| 218 | + # Soft cancel - just set the flag |
| 219 | + # The streaming loop will check this and stop gracefully |
| 220 | + # Don't call _cleanup_tasks() or clear queues yet |
| 221 | + pass |
185 | 222 |
|
186 | 223 | async def stream_events(self) -> AsyncIterator[StreamEvent]: |
187 | 224 | """Stream deltas for new items as they are generated. We're using the types from the |
|
0 commit comments