1212This module contains the optional OpenAI integration for Restate.
1313"""
1414
15- import asyncio
1615import dataclasses
1716import typing
1817
4140from agents .tool_context import ToolContext
4241from pydantic import BaseModel
4342from restate .exceptions import SdkInternalBaseException
43+ from restate .ext .turnstile import Turnstile
4444from restate .extensions import current_context
4545
4646from restate import RunOptions , ObjectContext , TerminalError
4747
4848
49+ class State :
50+ __slots__ = ("turnstile" ,)
51+
52+ def __init__ (self ) -> None :
53+ self .turnstile = Turnstile ([])
54+
55+
4956# The OpenAI ModelResponse class is a dataclass with Pydantic fields.
5057# The Restate SDK cannot serialize this. So we turn the ModelResponse int a Pydantic model.
5158class RestateModelResponse (BaseModel ):
@@ -71,23 +78,25 @@ class DurableModelCalls(MultiProvider):
7178 A Restate model provider that wraps the OpenAI SDK's default MultiProvider.
7279 """
7380
74- def __init__ (self , max_retries : int | None = 3 ):
81+ def __init__ (self , state : State , max_retries : int | None = 3 ):
7582 super ().__init__ ()
7683 self .max_retries = max_retries
84+ self .state = state
7785
7886 def get_model (self , model_name : str | None ) -> Model :
79- return RestateModelWrapper (super ().get_model (model_name or None ), self .max_retries )
87+ return RestateModelWrapper (super ().get_model (model_name or None ), self .state , self . max_retries )
8088
8189
8290class RestateModelWrapper (Model ):
8391 """
8492 A wrapper around the OpenAI SDK's Model that persists LLM calls in the Restate journal.
8593 """
8694
87- def __init__ (self , model : Model , max_retries : int | None = 3 ):
95+ def __init__ (self , model : Model , state : State , max_retries : int | None = 3 ):
8896 self .model = model
8997 self .model_name = "RestateModelWrapper"
9098 self .max_retries = max_retries
99+ self .state = state
91100
92101 async def get_response (self , * args , ** kwargs ) -> ModelResponse :
93102 async def call_llm () -> RestateModelResponse :
@@ -104,6 +113,15 @@ async def call_llm() -> RestateModelResponse:
104113 raise RuntimeError ("No current Restate context found, make sure to run inside a Restate handler" )
105114 result = await ctx .run_typed ("call LLM" , call_llm , RunOptions (max_attempts = self .max_retries ))
106115 # convert back to original ModelResponse
116+ # find all the tool calls
117+ ids = []
118+ for item in result .output :
119+ if item .type == "function_call" :
120+ ids .append (item .call_id )
121+ # TODO: handle other types of tool calls if any are added in the future
122+
123+ self .state .turnstile = Turnstile (ids )
124+
107125 return ModelResponse (
108126 output = result .output ,
109127 usage = result .usage ,
@@ -212,38 +230,33 @@ async def run(
212230 Returns:
213231 The result from Runner.run
214232 """
215-
233+ state = State ()
216234 current_run_config = run_config or RunConfig ()
217235 new_run_config = dataclasses .replace (
218236 current_run_config ,
219- model_provider = DurableModelCalls (),
237+ model_provider = DurableModelCalls (state ),
220238 )
221- restate_agent = sequentialize_and_wrap_tools (starting_agent , disable_tool_autowrapping )
239+ restate_agent = sequentialize_and_wrap_tools (starting_agent , disable_tool_autowrapping , state )
222240 return await OpenAIRunner .run (restate_agent , * args , run_config = new_run_config , ** kwargs )
223241
224242
225243def sequentialize_and_wrap_tools (
226244 agent : Agent [TContext ],
227245 disable_tool_autowrapping : bool ,
246+ state : State ,
228247) -> Agent [TContext ]:
229248 """
230249 Wrap the tools of an agent to use the Restate error handling.
231250
232251 Returns:
233252 A new agent with wrapped tools.
234253 """
235-
236- # Restate does not allow parallel tool calls, so we use a lock to ensure sequential execution.
237- # This lock only affects tools for this agent; handoff agents are wrapped recursively.
238- sequential_tools_lock = asyncio .Lock ()
239254 wrapped_tools : list [Tool ] = []
240255 for tool in agent .tools :
241256 if isinstance (tool , FunctionTool ):
242257
243258 def create_wrapper (captured_tool ):
244259 async def on_invoke_tool_wrapper (tool_context : ToolContext [Any ], tool_input : Any ) -> Any :
245- await sequential_tools_lock .acquire ()
246-
247260 async def invoke ():
248261 result = await captured_tool .on_invoke_tool (tool_context , tool_input )
249262 # Ensure Pydantic objects are serialized to dict for LLM compatibility
@@ -253,7 +266,10 @@ async def invoke():
253266 return result .dict ()
254267 return result
255268
269+ turnstile = state .turnstile
270+ call_id = tool_context .tool_call_id
256271 try :
272+ await turnstile .wait_for (call_id )
257273 if disable_tool_autowrapping :
258274 return await invoke ()
259275
@@ -264,7 +280,7 @@ async def invoke():
264280 )
265281 return await ctx .run_typed (captured_tool .name , invoke )
266282 finally :
267- sequential_tools_lock . release ( )
283+ turnstile . allow_next_after ( call_id )
268284
269285 return on_invoke_tool_wrapper
270286
0 commit comments