Skip to content

Commit 73a79f2

Browse files
committed
Integrate turnstile into OpenAI ext
1 parent 18f10ce commit 73a79f2

File tree

1 file changed

+30
-14
lines changed

1 file changed

+30
-14
lines changed

python/restate/ext/openai/runner_wrapper.py

Lines changed: 30 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
This module contains the optional OpenAI integration for Restate.
1313
"""
1414

15-
import asyncio
1615
import dataclasses
1716
import typing
1817

@@ -41,11 +40,19 @@
4140
from agents.tool_context import ToolContext
4241
from pydantic import BaseModel
4342
from restate.exceptions import SdkInternalBaseException
43+
from restate.ext.turnstile import Turnstile
4444
from restate.extensions import current_context
4545

4646
from restate import RunOptions, ObjectContext, TerminalError
4747

4848

49+
class State:
50+
__slots__ = ("turnstile",)
51+
52+
def __init__(self) -> None:
53+
self.turnstile = Turnstile([])
54+
55+
4956
# The OpenAI ModelResponse class is a dataclass with Pydantic fields.
5057
# The Restate SDK cannot serialize this. So we turn the ModelResponse int a Pydantic model.
5158
class RestateModelResponse(BaseModel):
@@ -71,23 +78,25 @@ class DurableModelCalls(MultiProvider):
7178
A Restate model provider that wraps the OpenAI SDK's default MultiProvider.
7279
"""
7380

74-
def __init__(self, max_retries: int | None = 3):
81+
def __init__(self, state: State, max_retries: int | None = 3):
7582
super().__init__()
7683
self.max_retries = max_retries
84+
self.state = state
7785

7886
def get_model(self, model_name: str | None) -> Model:
79-
return RestateModelWrapper(super().get_model(model_name or None), self.max_retries)
87+
return RestateModelWrapper(super().get_model(model_name or None), self.state, self.max_retries)
8088

8189

8290
class RestateModelWrapper(Model):
8391
"""
8492
A wrapper around the OpenAI SDK's Model that persists LLM calls in the Restate journal.
8593
"""
8694

87-
def __init__(self, model: Model, max_retries: int | None = 3):
95+
def __init__(self, model: Model, state: State, max_retries: int | None = 3):
8896
self.model = model
8997
self.model_name = "RestateModelWrapper"
9098
self.max_retries = max_retries
99+
self.state = state
91100

92101
async def get_response(self, *args, **kwargs) -> ModelResponse:
93102
async def call_llm() -> RestateModelResponse:
@@ -104,6 +113,15 @@ async def call_llm() -> RestateModelResponse:
104113
raise RuntimeError("No current Restate context found, make sure to run inside a Restate handler")
105114
result = await ctx.run_typed("call LLM", call_llm, RunOptions(max_attempts=self.max_retries))
106115
# convert back to original ModelResponse
116+
# find all the tool calls
117+
ids = []
118+
for item in result.output:
119+
if item.type == "function_call":
120+
ids.append(item.call_id)
121+
# TODO: handle other types of tool calls if any are added in the future
122+
123+
self.state.turnstile = Turnstile(ids)
124+
107125
return ModelResponse(
108126
output=result.output,
109127
usage=result.usage,
@@ -212,38 +230,33 @@ async def run(
212230
Returns:
213231
The result from Runner.run
214232
"""
215-
233+
state = State()
216234
current_run_config = run_config or RunConfig()
217235
new_run_config = dataclasses.replace(
218236
current_run_config,
219-
model_provider=DurableModelCalls(),
237+
model_provider=DurableModelCalls(state),
220238
)
221-
restate_agent = sequentialize_and_wrap_tools(starting_agent, disable_tool_autowrapping)
239+
restate_agent = sequentialize_and_wrap_tools(starting_agent, disable_tool_autowrapping, state)
222240
return await OpenAIRunner.run(restate_agent, *args, run_config=new_run_config, **kwargs)
223241

224242

225243
def sequentialize_and_wrap_tools(
226244
agent: Agent[TContext],
227245
disable_tool_autowrapping: bool,
246+
state: State,
228247
) -> Agent[TContext]:
229248
"""
230249
Wrap the tools of an agent to use the Restate error handling.
231250
232251
Returns:
233252
A new agent with wrapped tools.
234253
"""
235-
236-
# Restate does not allow parallel tool calls, so we use a lock to ensure sequential execution.
237-
# This lock only affects tools for this agent; handoff agents are wrapped recursively.
238-
sequential_tools_lock = asyncio.Lock()
239254
wrapped_tools: list[Tool] = []
240255
for tool in agent.tools:
241256
if isinstance(tool, FunctionTool):
242257

243258
def create_wrapper(captured_tool):
244259
async def on_invoke_tool_wrapper(tool_context: ToolContext[Any], tool_input: Any) -> Any:
245-
await sequential_tools_lock.acquire()
246-
247260
async def invoke():
248261
result = await captured_tool.on_invoke_tool(tool_context, tool_input)
249262
# Ensure Pydantic objects are serialized to dict for LLM compatibility
@@ -253,7 +266,10 @@ async def invoke():
253266
return result.dict()
254267
return result
255268

269+
turnstile = state.turnstile
270+
call_id = tool_context.tool_call_id
256271
try:
272+
await turnstile.wait_for(call_id)
257273
if disable_tool_autowrapping:
258274
return await invoke()
259275

@@ -264,7 +280,7 @@ async def invoke():
264280
)
265281
return await ctx.run_typed(captured_tool.name, invoke)
266282
finally:
267-
sequential_tools_lock.release()
283+
turnstile.allow_next_after(call_id)
268284

269285
return on_invoke_tool_wrapper
270286

0 commit comments

Comments
 (0)