99# https://github.com/restatedev/sdk-typescript/blob/main/LICENSE
1010#
1111
12- from typing import List , Any
12+ from typing import List , Any , overload , Callable , Union , TypeVar , Awaitable
1313import dataclasses
1414
1515from agents import (
1616 Handoff ,
1717 TContext ,
1818 Agent ,
19+ RunContextWrapper ,
20+ ModelBehaviorError ,
21+ AgentBase ,
1922)
2023
21- from agents .tool import FunctionTool , Tool
24+ from agents .function_schema import DocstringStyle
25+
26+ from agents .tool import (
27+ FunctionTool ,
28+ Tool ,
29+ ToolFunction ,
30+ ToolErrorFunction ,
31+ function_tool as oai_function_tool ,
32+ default_tool_error_function ,
33+ )
2234from agents .tool_context import ToolContext
2335from agents .items import TResponseOutputItem
2436
25- from .models import State
37+ from restate import TerminalError
38+
39+ from .models import State , AgentsTerminalException
40+
41+ T = TypeVar ("T" )
42+
43+ MaybeAwaitable = Union [Awaitable [T ], T ]
44+
45+
46+ def raise_terminal_errors (context : RunContextWrapper [Any ], error : Exception ) -> str :
47+ """A custom function to provide a user-friendly error message."""
48+ # Raise terminal errors and cancellations
49+ if isinstance (error , TerminalError ):
50+ # For the agent SDK it needs to be an AgentsException, for restate it needs to be a TerminalError
51+ # so we create a new exception that inherits from both
52+ raise AgentsTerminalException (error .message )
53+
54+ if isinstance (error , ModelBehaviorError ):
55+ return f"An error occurred while calling the tool: { str (error )} "
56+
57+ raise error
58+
59+
60+ def propagate_cancellation (failure_error_function : ToolErrorFunction | None = None ) -> ToolErrorFunction :
61+ _fn = failure_error_function if failure_error_function is not None else default_tool_error_function
62+
63+ def inner (context : RunContextWrapper [Any ], error : Exception ):
64+ """Raise cancellations as exceptions."""
65+ if isinstance (error , TerminalError ):
66+ if error .status_code == 409 :
67+ raise error from None
68+ return _fn (context , error )
69+
70+ return inner
71+
72+
73+ @overload
74+ def durable_function_tool (
75+ func : ToolFunction [...],
76+ * ,
77+ name_override : str | None = None ,
78+ description_override : str | None = None ,
79+ docstring_style : DocstringStyle | None = None ,
80+ use_docstring_info : bool = True ,
81+ failure_error_function : ToolErrorFunction | None = None ,
82+ strict_mode : bool = True ,
83+ is_enabled : bool | Callable [[RunContextWrapper [Any ], AgentBase ], MaybeAwaitable [bool ]] = True ,
84+ ) -> FunctionTool :
85+ """Overload for usage as @function_tool (no parentheses)."""
86+ ...
87+
88+
89+ @overload
90+ def durable_function_tool (
91+ * ,
92+ name_override : str | None = None ,
93+ description_override : str | None = None ,
94+ docstring_style : DocstringStyle | None = None ,
95+ use_docstring_info : bool = True ,
96+ failure_error_function : ToolErrorFunction | None = None ,
97+ strict_mode : bool = True ,
98+ is_enabled : bool | Callable [[RunContextWrapper [Any ], AgentBase ], MaybeAwaitable [bool ]] = True ,
99+ ) -> Callable [[ToolFunction [...]], FunctionTool ]:
100+ """Overload for usage as @function_tool(...)."""
101+ ...
102+
103+
104+ def durable_function_tool (
105+ func : ToolFunction [...] | None = None ,
106+ * ,
107+ name_override : str | None = None ,
108+ description_override : str | None = None ,
109+ docstring_style : DocstringStyle | None = None ,
110+ use_docstring_info : bool = True ,
111+ failure_error_function : ToolErrorFunction | None = raise_terminal_errors ,
112+ strict_mode : bool = True ,
113+ is_enabled : bool | Callable [[RunContextWrapper [Any ], AgentBase ], MaybeAwaitable [bool ]] = True ,
114+ ) -> FunctionTool | Callable [[ToolFunction [...]], FunctionTool ]:
115+ failure_fn = propagate_cancellation (failure_error_function )
116+
117+ if callable (func ):
118+ return oai_function_tool (
119+ func = func ,
120+ name_override = name_override ,
121+ description_override = description_override ,
122+ docstring_style = docstring_style ,
123+ use_docstring_info = use_docstring_info ,
124+ failure_error_function = failure_fn ,
125+ strict_mode = strict_mode ,
126+ is_enabled = is_enabled ,
127+ )
128+ else :
129+ return oai_function_tool (
130+ name_override = name_override ,
131+ description_override = description_override ,
132+ docstring_style = docstring_style ,
133+ use_docstring_info = use_docstring_info ,
134+ failure_error_function = failure_fn ,
135+ strict_mode = strict_mode ,
136+ is_enabled = is_enabled ,
137+ )
26138
27139
28140def get_function_call_ids (response : list [TResponseOutputItem ]) -> List [str ]:
@@ -35,11 +147,21 @@ def _create_wrapper(state, captured_tool):
35147 async def on_invoke_tool_wrapper (tool_context : ToolContext [Any ], tool_input : Any ) -> Any :
36148 turnstile = state .turnstile
37149 call_id = tool_context .tool_call_id
150+ # wait for our turn
151+ await turnstile .wait_for (call_id )
38152 try :
39- await turnstile . wait_for ( call_id )
40- return await captured_tool .on_invoke_tool (tool_context , tool_input )
41- finally :
153+ # invoke the original tool
154+ res = await captured_tool .on_invoke_tool (tool_context , tool_input )
155+ # allow the next tool to proceed
42156 turnstile .allow_next_after (call_id )
157+ return res
158+ except BaseException as ex :
159+ # if there was an error, it will be propagated up, towards the handler
160+ # but we need to make sure that all subsequent tools will not execute
161+ # as they might interact with the restate context.
162+ turnstile .cancel_all_after (call_id )
163+ # re-raise the exception
164+ raise ex from None
43165
44166 return on_invoke_tool_wrapper
45167
0 commit comments