Skip to content

Commit 2230b87

Browse files
authored
Cancel subsequent tools (#166)
* Preconfigure a function tool * use our own @durable_function_tool that makes sure that cancelation is propgoated * Make sure to cancel any subsequent tools on failure * Cancel all pending tools on attempt finish
1 parent f3e0382 commit 2230b87

File tree

6 files changed

+163
-57
lines changed

6 files changed

+163
-57
lines changed

python/restate/ext/adk/plugin.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,9 @@ async def release_task():
8181
await event.wait()
8282
finally:
8383
self._models.pop(id, None)
84-
self._turnstiles.pop(id, None)
84+
maybe_turnstile = self._turnstiles.pop(id, None)
85+
if maybe_turnstile is not None:
86+
maybe_turnstile.cancel_all()
8587

8688
_ = asyncio.create_task(release_task())
8789
return None
@@ -161,7 +163,7 @@ async def on_tool_error_callback(
161163
assert turnstile is not None, "Turnstile not found for tool invocation."
162164
id = tool_context.function_call_id
163165
assert id is not None, "Function call ID is required for tool invocation."
164-
turnstile.allow_next_after(id)
166+
turnstile.cancel_all_after(id)
165167
return None
166168

167169
async def close(self):

python/restate/ext/openai/__init__.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,9 @@
1717
from restate import ObjectContext, Context
1818
from restate.server_context import current_context
1919

20-
from .runner_wrapper import DurableRunner, continue_on_terminal_errors, raise_terminal_errors
20+
from .runner_wrapper import DurableRunner
2121
from .models import LlmRetryOpts
22+
from .functions import propagate_cancellation, raise_terminal_errors, durable_function_tool
2223

2324

2425
def restate_object_context() -> ObjectContext:
@@ -42,6 +43,7 @@ def restate_context() -> Context:
4243
"LlmRetryOpts",
4344
"restate_object_context",
4445
"restate_context",
45-
"continue_on_terminal_errors",
4646
"raise_terminal_errors",
47+
"propagate_cancellation",
48+
"durable_function_tool",
4749
]

python/restate/ext/openai/functions.py

Lines changed: 128 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,20 +9,132 @@
99
# https://github.com/restatedev/sdk-typescript/blob/main/LICENSE
1010
#
1111

12-
from typing import List, Any
12+
from typing import List, Any, overload, Callable, Union, TypeVar, Awaitable
1313
import dataclasses
1414

1515
from agents import (
1616
Handoff,
1717
TContext,
1818
Agent,
19+
RunContextWrapper,
20+
ModelBehaviorError,
21+
AgentBase,
1922
)
2023

21-
from agents.tool import FunctionTool, Tool
24+
from agents.function_schema import DocstringStyle
25+
26+
from agents.tool import (
27+
FunctionTool,
28+
Tool,
29+
ToolFunction,
30+
ToolErrorFunction,
31+
function_tool as oai_function_tool,
32+
default_tool_error_function,
33+
)
2234
from agents.tool_context import ToolContext
2335
from agents.items import TResponseOutputItem
2436

25-
from .models import State
37+
from restate import TerminalError
38+
39+
from .models import State, AgentsTerminalException
40+
41+
T = TypeVar("T")
42+
43+
MaybeAwaitable = Union[Awaitable[T], T]
44+
45+
46+
def raise_terminal_errors(context: RunContextWrapper[Any], error: Exception) -> str:
47+
"""A custom function to provide a user-friendly error message."""
48+
# Raise terminal errors and cancellations
49+
if isinstance(error, TerminalError):
50+
# For the agent SDK it needs to be an AgentsException, for restate it needs to be a TerminalError
51+
# so we create a new exception that inherits from both
52+
raise AgentsTerminalException(error.message)
53+
54+
if isinstance(error, ModelBehaviorError):
55+
return f"An error occurred while calling the tool: {str(error)}"
56+
57+
raise error
58+
59+
60+
def propagate_cancellation(failure_error_function: ToolErrorFunction | None = None) -> ToolErrorFunction:
61+
_fn = failure_error_function if failure_error_function is not None else default_tool_error_function
62+
63+
def inner(context: RunContextWrapper[Any], error: Exception):
64+
"""Raise cancellations as exceptions."""
65+
if isinstance(error, TerminalError):
66+
if error.status_code == 409:
67+
raise error from None
68+
return _fn(context, error)
69+
70+
return inner
71+
72+
73+
@overload
74+
def durable_function_tool(
75+
func: ToolFunction[...],
76+
*,
77+
name_override: str | None = None,
78+
description_override: str | None = None,
79+
docstring_style: DocstringStyle | None = None,
80+
use_docstring_info: bool = True,
81+
failure_error_function: ToolErrorFunction | None = None,
82+
strict_mode: bool = True,
83+
is_enabled: bool | Callable[[RunContextWrapper[Any], AgentBase], MaybeAwaitable[bool]] = True,
84+
) -> FunctionTool:
85+
"""Overload for usage as @function_tool (no parentheses)."""
86+
...
87+
88+
89+
@overload
90+
def durable_function_tool(
91+
*,
92+
name_override: str | None = None,
93+
description_override: str | None = None,
94+
docstring_style: DocstringStyle | None = None,
95+
use_docstring_info: bool = True,
96+
failure_error_function: ToolErrorFunction | None = None,
97+
strict_mode: bool = True,
98+
is_enabled: bool | Callable[[RunContextWrapper[Any], AgentBase], MaybeAwaitable[bool]] = True,
99+
) -> Callable[[ToolFunction[...]], FunctionTool]:
100+
"""Overload for usage as @function_tool(...)."""
101+
...
102+
103+
104+
def durable_function_tool(
105+
func: ToolFunction[...] | None = None,
106+
*,
107+
name_override: str | None = None,
108+
description_override: str | None = None,
109+
docstring_style: DocstringStyle | None = None,
110+
use_docstring_info: bool = True,
111+
failure_error_function: ToolErrorFunction | None = raise_terminal_errors,
112+
strict_mode: bool = True,
113+
is_enabled: bool | Callable[[RunContextWrapper[Any], AgentBase], MaybeAwaitable[bool]] = True,
114+
) -> FunctionTool | Callable[[ToolFunction[...]], FunctionTool]:
115+
failure_fn = propagate_cancellation(failure_error_function)
116+
117+
if callable(func):
118+
return oai_function_tool(
119+
func=func,
120+
name_override=name_override,
121+
description_override=description_override,
122+
docstring_style=docstring_style,
123+
use_docstring_info=use_docstring_info,
124+
failure_error_function=failure_fn,
125+
strict_mode=strict_mode,
126+
is_enabled=is_enabled,
127+
)
128+
else:
129+
return oai_function_tool(
130+
name_override=name_override,
131+
description_override=description_override,
132+
docstring_style=docstring_style,
133+
use_docstring_info=use_docstring_info,
134+
failure_error_function=failure_fn,
135+
strict_mode=strict_mode,
136+
is_enabled=is_enabled,
137+
)
26138

27139

28140
def get_function_call_ids(response: list[TResponseOutputItem]) -> List[str]:
@@ -35,11 +147,21 @@ def _create_wrapper(state, captured_tool):
35147
async def on_invoke_tool_wrapper(tool_context: ToolContext[Any], tool_input: Any) -> Any:
36148
turnstile = state.turnstile
37149
call_id = tool_context.tool_call_id
150+
# wait for our turn
151+
await turnstile.wait_for(call_id)
38152
try:
39-
await turnstile.wait_for(call_id)
40-
return await captured_tool.on_invoke_tool(tool_context, tool_input)
41-
finally:
153+
# invoke the original tool
154+
res = await captured_tool.on_invoke_tool(tool_context, tool_input)
155+
# allow the next tool to proceed
42156
turnstile.allow_next_after(call_id)
157+
return res
158+
except BaseException as ex:
159+
# if there was an error, it will be propagated up, towards the handler
160+
# but we need to make sure that all subsequent tools will not execute
161+
# as they might interact with the restate context.
162+
turnstile.cancel_all_after(call_id)
163+
# re-raise the exception
164+
raise ex from None
43165

44166
return on_invoke_tool_wrapper
45167

python/restate/ext/openai/models.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
from agents import (
1818
Usage,
19+
AgentsException,
1920
)
2021
from agents.items import TResponseOutputItem
2122
from agents.items import TResponseInputItem
@@ -24,6 +25,7 @@
2425
from pydantic import BaseModel
2526

2627
from restate.ext.turnstile import Turnstile
28+
from restate import TerminalError
2729

2830

2931
class State:
@@ -77,3 +79,10 @@ class RestateModelResponse(BaseModel):
7779

7880
def to_input_items(self) -> list[TResponseInputItem]:
7981
return [it.model_dump(exclude_unset=True) for it in self.output] # type: ignore
82+
83+
84+
class AgentsTerminalException(AgentsException, TerminalError):
85+
"""Exception that is both an AgentsException and a restate.TerminalError."""
86+
87+
def __init__(self, *args: object) -> None:
88+
super().__init__(*args)

python/restate/ext/openai/runner_wrapper.py

Lines changed: 1 addition & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -16,21 +16,17 @@
1616

1717
from agents import (
1818
Model,
19-
RunContextWrapper,
20-
AgentsException,
2119
RunConfig,
2220
TContext,
2321
RunResult,
2422
Agent,
25-
ModelBehaviorError,
2623
Runner,
2724
)
2825
from agents.models.multi_provider import MultiProvider
2926
from agents.items import TResponseStreamEvent, ModelResponse
3027
from agents.items import TResponseInputItem
31-
from typing import Any, AsyncIterator
28+
from typing import AsyncIterator
3229

33-
from restate.exceptions import SdkInternalBaseException
3430
from restate.ext.turnstile import Turnstile
3531
from restate.extensions import current_context
3632
from restate import RunOptions, TerminalError
@@ -105,48 +101,6 @@ def stream_response(self, *args, **kwargs) -> AsyncIterator[TResponseStreamEvent
105101
raise TerminalError("Streaming is not supported in Restate. Use `get_response` instead.")
106102

107103

108-
class AgentsTerminalException(AgentsException, TerminalError):
109-
"""Exception that is both an AgentsException and a restate.TerminalError."""
110-
111-
def __init__(self, *args: object) -> None:
112-
super().__init__(*args)
113-
114-
115-
class AgentsSuspension(AgentsException, SdkInternalBaseException):
116-
"""Exception that is both an AgentsException and a restate SdkInternalBaseException."""
117-
118-
def __init__(self, *args: object) -> None:
119-
super().__init__(*args)
120-
121-
122-
def raise_terminal_errors(context: RunContextWrapper[Any], error: Exception) -> str:
123-
"""A custom function to provide a user-friendly error message."""
124-
# Raise terminal errors and cancellations
125-
if isinstance(error, TerminalError):
126-
# For the agent SDK it needs to be an AgentsException, for restate it needs to be a TerminalError
127-
# so we create a new exception that inherits from both
128-
raise AgentsTerminalException(error.message)
129-
130-
if isinstance(error, ModelBehaviorError):
131-
return f"An error occurred while calling the tool: {str(error)}"
132-
133-
raise error
134-
135-
136-
def continue_on_terminal_errors(context: RunContextWrapper[Any], error: Exception) -> str:
137-
"""A custom function to provide a user-friendly error message."""
138-
# Raise terminal errors and cancellations
139-
if isinstance(error, TerminalError):
140-
# For the agent SDK it needs to be an AgentsException, for restate it needs to be a TerminalError
141-
# so we create a new exception that inherits from both
142-
return f"An error occurred while running the tool: {str(error)}"
143-
144-
if isinstance(error, ModelBehaviorError):
145-
return f"An error occurred while calling the tool: {str(error)}"
146-
147-
raise error
148-
149-
150104
class DurableRunner:
151105
"""
152106
A wrapper around Runner.run that automatically configures RunConfig for Restate contexts.

python/restate/ext/turnstile.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from asyncio import Event
2+
from restate.exceptions import SdkInternalException
23

34

45
class Turnstile:
@@ -14,6 +15,7 @@ def __init__(self, ids: list[str]):
1415
self.turns = dict(zip(ids, ids[1:]))
1516
# mapping of id to event that signals when that id's turn is allowed
1617
self.events = {id: Event() for id in ids}
18+
self.canceled = False
1719
if ids:
1820
# make sure that the first id can proceed immediately
1921
event = self.events[ids[0]]
@@ -22,10 +24,25 @@ def __init__(self, ids: list[str]):
2224
async def wait_for(self, id: str) -> None:
2325
event = self.events[id]
2426
await event.wait()
27+
if self.canceled:
28+
raise SdkInternalException() from None
29+
30+
def cancel_all_after(self, id: str) -> None:
31+
self.canceled = True
32+
next_id = self.turns.get(id)
33+
while next_id is not None:
34+
next_event = self.events[next_id]
35+
next_event.set()
36+
next_id = self.turns.get(next_id)
2537

2638
def allow_next_after(self, id: str) -> None:
2739
next_id = self.turns.get(id)
2840
if next_id is None:
2941
return
3042
next_event = self.events[next_id]
3143
next_event.set()
44+
45+
def cancel_all(self) -> None:
46+
self.canceled = True
47+
for event in self.events.values():
48+
event.set()

0 commit comments

Comments
 (0)