Skip to content

Commit 3b7aed7

Browse files
authored
Add graceful cancel mode for streaming runs (openai#1896)
1 parent d9a4144 commit 3b7aed7

File tree

4 files changed

+575
-11
lines changed

4 files changed

+575
-11
lines changed

src/agents/result.py

Lines changed: 48 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import asyncio
55
from collections.abc import AsyncIterator
66
from dataclasses import dataclass, field
7-
from typing import TYPE_CHECKING, Any, cast
7+
from typing import TYPE_CHECKING, Any, Literal, cast
88

99
from typing_extensions import TypeVar
1010

@@ -164,24 +164,61 @@ class RunResultStreaming(RunResultBase):
164164
_output_guardrails_task: asyncio.Task[Any] | None = field(default=None, repr=False)
165165
_stored_exception: Exception | None = field(default=None, repr=False)
166166

167+
# Soft cancel state
168+
_cancel_mode: Literal["none", "immediate", "after_turn"] = field(default="none", repr=False)
169+
167170
@property
168171
def last_agent(self) -> Agent[Any]:
169172
"""The last agent that was run. Updates as the agent run progresses, so the true last agent
170173
is only available after the agent run is complete.
171174
"""
172175
return self.current_agent
173176

174-
def cancel(self) -> None:
175-
"""Cancels the streaming run, stopping all background tasks and marking the run as
176-
complete."""
177-
self._cleanup_tasks() # Cancel all running tasks
178-
self.is_complete = True # Mark the run as complete to stop event streaming
177+
def cancel(self, mode: Literal["immediate", "after_turn"] = "immediate") -> None:
178+
"""Cancel the streaming run.
179179
180-
# Optionally, clear the event queue to prevent processing stale events
181-
while not self._event_queue.empty():
182-
self._event_queue.get_nowait()
183-
while not self._input_guardrail_queue.empty():
184-
self._input_guardrail_queue.get_nowait()
180+
Args:
181+
mode: Cancellation strategy:
182+
- "immediate": Stop immediately, cancel all tasks, clear queues (default)
183+
- "after_turn": Complete current turn gracefully before stopping
184+
* Allows LLM response to finish
185+
* Executes pending tool calls
186+
* Saves session state properly
187+
* Tracks usage accurately
188+
* Stops before next turn begins
189+
190+
Example:
191+
```python
192+
result = Runner.run_streamed(agent, "Task", session=session)
193+
194+
async for event in result.stream_events():
195+
if user_interrupted():
196+
result.cancel(mode="after_turn") # Graceful
197+
# result.cancel() # Immediate (default)
198+
```
199+
200+
Note: After calling cancel(), you should continue consuming stream_events()
201+
to allow the cancellation to complete properly.
202+
"""
203+
# Store the cancel mode for the background task to check
204+
self._cancel_mode = mode
205+
206+
if mode == "immediate":
207+
# Existing behavior - immediate shutdown
208+
self._cleanup_tasks() # Cancel all running tasks
209+
self.is_complete = True # Mark the run as complete to stop event streaming
210+
211+
# Optionally, clear the event queue to prevent processing stale events
212+
while not self._event_queue.empty():
213+
self._event_queue.get_nowait()
214+
while not self._input_guardrail_queue.empty():
215+
self._input_guardrail_queue.get_nowait()
216+
217+
elif mode == "after_turn":
218+
# Soft cancel - just set the flag
219+
# The streaming loop will check this and stop gracefully
220+
# Don't call _cleanup_tasks() or clear queues yet
221+
pass
185222

186223
async def stream_events(self) -> AsyncIterator[StreamEvent]:
187224
"""Stream deltas for new items as they are generated. We're using the types from the

src/agents/run.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -951,6 +951,12 @@ async def _start_streaming(
951951
await AgentRunner._save_result_to_session(session, starting_input, [])
952952

953953
while True:
954+
# Check for soft cancel before starting new turn
955+
if streamed_result._cancel_mode == "after_turn":
956+
streamed_result.is_complete = True
957+
streamed_result._event_queue.put_nowait(QueueCompleteSentinel())
958+
break
959+
954960
if streamed_result.is_complete:
955961
break
956962

@@ -1026,13 +1032,33 @@ async def _start_streaming(
10261032
server_conversation_tracker.track_server_items(turn_result.model_response)
10271033

10281034
if isinstance(turn_result.next_step, NextStepHandoff):
1035+
# Save the conversation to session if enabled (before handoff)
1036+
# Note: Non-streaming path doesn't save handoff turns immediately,
1037+
# but streaming needs to for graceful cancellation support
1038+
if session is not None:
1039+
should_skip_session_save = (
1040+
await AgentRunner._input_guardrail_tripwire_triggered_for_stream(
1041+
streamed_result
1042+
)
1043+
)
1044+
if should_skip_session_save is False:
1045+
await AgentRunner._save_result_to_session(
1046+
session, [], turn_result.new_step_items
1047+
)
1048+
10291049
current_agent = turn_result.next_step.new_agent
10301050
current_span.finish(reset_current=True)
10311051
current_span = None
10321052
should_run_agent_start_hooks = True
10331053
streamed_result._event_queue.put_nowait(
10341054
AgentUpdatedStreamEvent(new_agent=current_agent)
10351055
)
1056+
1057+
# Check for soft cancel after handoff
1058+
if streamed_result._cancel_mode == "after_turn": # type: ignore[comparison-overlap]
1059+
streamed_result.is_complete = True
1060+
streamed_result._event_queue.put_nowait(QueueCompleteSentinel())
1061+
break
10361062
elif isinstance(turn_result.next_step, NextStepFinalOutput):
10371063
streamed_result._output_guardrails_task = asyncio.create_task(
10381064
cls._run_output_guardrails(
@@ -1078,6 +1104,12 @@ async def _start_streaming(
10781104
await AgentRunner._save_result_to_session(
10791105
session, [], turn_result.new_step_items
10801106
)
1107+
1108+
# Check for soft cancel after turn completion
1109+
if streamed_result._cancel_mode == "after_turn": # type: ignore[comparison-overlap]
1110+
streamed_result.is_complete = True
1111+
streamed_result._event_queue.put_nowait(QueueCompleteSentinel())
1112+
break
10811113
except AgentsException as exc:
10821114
streamed_result.is_complete = True
10831115
streamed_result._event_queue.put_nowait(QueueCompleteSentinel())

tests/test_cancel_streaming.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,3 +114,20 @@ async def test_cancel_cleans_up_resources():
114114
assert result._input_guardrail_queue.empty(), (
115115
"Input guardrail queue should be empty after cancel."
116116
)
117+
118+
119+
@pytest.mark.asyncio
120+
async def test_cancel_immediate_mode_explicit():
121+
"""Test explicit immediate mode behaves same as default."""
122+
model = FakeModel()
123+
agent = Agent(name="Joker", model=model)
124+
125+
result = Runner.run_streamed(agent, input="Please tell me 5 jokes.")
126+
127+
async for _ in result.stream_events():
128+
result.cancel(mode="immediate")
129+
break
130+
131+
assert result.is_complete
132+
assert result._event_queue.empty()
133+
assert result._cancel_mode == "immediate"

0 commit comments

Comments
 (0)