Commit f3e0382
Add explicit tool execution synchronization (#165)
* Add explicit tool execution synchronization
* Integrate turnstile into OpenAI ext
* Do not auto-wrap tools with ctx.run
* rename utils to functions

1 parent a32145b

File tree: 7 files changed, +338 −128 lines
python/restate/ext/adk/plugin.py
Lines changed: 45 additions & 13 deletions

@@ -35,17 +35,25 @@
 
 from restate.extensions import current_context
 
+from restate.ext.turnstile import Turnstile
+
+
+def _create_turnstile(s: LlmResponse) -> Turnstile:
+    ids = _get_function_call_ids(s)
+    turnstile = Turnstile(ids)
+    return turnstile
+
 
 class RestatePlugin(BasePlugin):
     """A plugin to integrate Restate with the ADK framework."""
 
     _models: dict[str, BaseLlm]
-    _locks: dict[str, asyncio.Lock]
+    _turnstiles: dict[str, Turnstile | None]
 
     def __init__(self, *, max_model_call_retries: int = 10):
         super().__init__(name="restate_plugin")
         self._models = {}
-        self._locks = {}
+        self._turnstiles = {}
         self._max_model_call_retries = max_model_call_retries
 
     async def before_agent_callback(
@@ -62,7 +70,7 @@ async def before_agent_callback(
         )
         model = agent.model if isinstance(agent.model, BaseLlm) else LLMRegistry.new_llm(agent.model)
         self._models[callback_context.invocation_id] = model
-        self._locks[callback_context.invocation_id] = asyncio.Lock()
+        self._turnstiles[callback_context.invocation_id] = None
 
         id = callback_context.invocation_id
         event = ctx.request().attempt_finished_event
@@ -73,7 +81,7 @@ async def release_task():
                 await event.wait()
             finally:
                 self._models.pop(id, None)
-                self._locks.pop(id, None)
+                self._turnstiles.pop(id, None)
 
         _ = asyncio.create_task(release_task())
         return None
@@ -82,7 +90,7 @@ async def after_agent_callback(
         self, *, agent: BaseAgent, callback_context: CallbackContext
     ) -> Optional[types.Content]:
         self._models.pop(callback_context.invocation_id, None)
-        self._locks.pop(callback_context.invocation_id, None)
+        self._turnstiles.pop(callback_context.invocation_id, None)
         return None
 
     async def after_run_callback(self, *, invocation_context: InvocationContext) -> None:
@@ -100,6 +108,8 @@ async def before_model_callback(
                 "No Restate context found, the restate plugin must be used from within a restate handler."
             )
         response = await _generate_content_async(ctx, self._max_model_call_retries, model, llm_request)
+        turnstile = _create_turnstile(response)
+        self._turnstiles[callback_context.invocation_id] = turnstile
         return response
 
     async def before_tool_callback(
@@ -109,11 +119,17 @@ async def before_tool_callback(
         tool_args: dict[str, Any],
         tool_context: ToolContext,
     ) -> Optional[dict]:
-        lock = self._locks[tool_context.invocation_id]
+        turnstile = self._turnstiles[tool_context.invocation_id]
+        assert turnstile is not None, "Turnstile not found for tool invocation."
+
+        id = tool_context.function_call_id
+        assert id is not None, "Function call ID is required for tool invocation."
+
+        await turnstile.wait_for(id)
+
         ctx = current_context()
-        await lock.acquire()
         tool_context.session.state["restate_context"] = ctx
-        # TODO: if we want we can also automatically wrap tools with ctx.run_typed here
+
         return None
 
     async def after_tool_callback(
@@ -125,8 +141,11 @@ async def after_tool_callback(
         result: dict,
     ) -> Optional[dict]:
         tool_context.session.state.pop("restate_context", None)
-        lock = self._locks[tool_context.invocation_id]
-        lock.release()
+        turnstile = self._turnstiles[tool_context.invocation_id]
+        assert turnstile is not None, "Turnstile not found for tool invocation."
+        id = tool_context.function_call_id
+        assert id is not None, "Function call ID is required for tool invocation."
+        turnstile.allow_next_after(id)
         return None
 
     async def on_tool_error_callback(
@@ -138,13 +157,26 @@ async def on_tool_error_callback(
         error: Exception,
     ) -> Optional[dict]:
         tool_context.session.state.pop("restate_context", None)
-        lock = self._locks[tool_context.invocation_id]
-        lock.release()
+        turnstile = self._turnstiles[tool_context.invocation_id]
+        assert turnstile is not None, "Turnstile not found for tool invocation."
+        id = tool_context.function_call_id
+        assert id is not None, "Function call ID is required for tool invocation."
+        turnstile.allow_next_after(id)
         return None
 
     async def close(self):
         self._models.clear()
-        self._locks.clear()
+        self._turnstiles.clear()
+
+
+def _get_function_call_ids(s: LlmResponse) -> list[str]:
+    ids = []
+    if s.content and s.content.parts:
+        for part in s.content.parts:
+            if part.function_call:
+                if part.function_call.id:
+                    ids.append(part.function_call.id)
+    return ids
 
 
 def _generate_client_function_call_id(s: LlmResponse) -> None:
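
Note on the synchronization pattern: the plugin extracts the function-call IDs from each model response, builds a Turnstile over them in before_model_callback, awaits wait_for(id) in before_tool_callback, and calls allow_next_after(id) in after_tool_callback / on_tool_error_callback. The effect is that parallel tool calls from one model turn execute one at a time, in the order they appear in the response, replacing the previous per-invocation asyncio.Lock. The restate.ext.turnstile implementation itself is not part of this diff; the following is only a rough sketch of the assumed ordering behaviour, with a made-up class name.

import asyncio


class TurnstileSketch:
    """Hypothetical stand-in for restate.ext.turnstile.Turnstile (not shown in this commit)."""

    def __init__(self, ids: list[str]) -> None:
        self._order = list(ids)
        self._gates = {call_id: asyncio.Event() for call_id in ids}
        if self._order:
            # The first function call of the response may run immediately.
            self._gates[self._order[0]].set()

    async def wait_for(self, call_id: str) -> None:
        # Block until this call ID has been admitted.
        await self._gates[call_id].wait()

    def allow_next_after(self, call_id: str) -> None:
        # Admit the call that follows `call_id` in the response order.
        idx = self._order.index(call_id)
        if idx + 1 < len(self._order):
            self._gates[self._order[idx + 1]].set()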

python/restate/ext/openai/__init__.py
Lines changed: 3 additions & 1 deletion

@@ -14,10 +14,12 @@
 
 import typing
 
-from .runner_wrapper import DurableRunner, continue_on_terminal_errors, raise_terminal_errors, LlmRetryOpts
 from restate import ObjectContext, Context
 from restate.server_context import current_context
 
+from .runner_wrapper import DurableRunner, continue_on_terminal_errors, raise_terminal_errors
+from .models import LlmRetryOpts
+
 
 def restate_object_context() -> ObjectContext:
     """Get the current Restate ObjectContext."""
Lines changed: 95 additions & 0 deletions

@@ -0,0 +1,95 @@
+#
+# Copyright (c) 2023-2025 - Restate Software, Inc., Restate GmbH
+#
+# This file is part of the Restate SDK for Python,
+# which is released under the MIT license.
+#
+# You can find a copy of the license in file LICENSE in the root
+# directory of this repository or package, or at
+# https://github.com/restatedev/sdk-typescript/blob/main/LICENSE
+#
+
+from typing import List, Any
+import dataclasses
+
+from agents import (
+    Handoff,
+    TContext,
+    Agent,
+)
+
+from agents.tool import FunctionTool, Tool
+from agents.tool_context import ToolContext
+from agents.items import TResponseOutputItem
+
+from .models import State
+
+
+def get_function_call_ids(response: list[TResponseOutputItem]) -> List[str]:
+    """Extract function call IDs from the model response."""
+    # TODO: support function calls in other response types
+    return [item.call_id for item in response if item.type == "function_call"]
+
+
+def _create_wrapper(state, captured_tool):
+    async def on_invoke_tool_wrapper(tool_context: ToolContext[Any], tool_input: Any) -> Any:
+        turnstile = state.turnstile
+        call_id = tool_context.tool_call_id
+        try:
+            await turnstile.wait_for(call_id)
+            return await captured_tool.on_invoke_tool(tool_context, tool_input)
+        finally:
+            turnstile.allow_next_after(call_id)
+
+    return on_invoke_tool_wrapper
+
+
+def wrap_agent_tools(
+    agent: Agent[TContext],
+    state: State,
+) -> Agent[TContext]:
+    """
+    Wrap the tools of an agent to use the Restate error handling.
+
+    Returns:
+        A new agent with wrapped tools.
+    """
+    wrapped_tools: list[Tool] = []
+    for tool in agent.tools:
+        if isinstance(tool, FunctionTool):
+            wrapped = _create_wrapper(state, tool)
+            wrapped_tools.append(dataclasses.replace(tool, on_invoke_tool=wrapped))
+        else:
+            wrapped_tools.append(tool)
+
+    wrapped_handoffs: list[Agent[Any] | Handoff[Any]] = []
+    for handoff in agent.handoffs:
+        if isinstance(handoff, Agent):
+            wrapped_handoff = wrap_agent_tools(handoff, state)
+            wrapped_handoffs.append(wrapped_handoff)
+        elif isinstance(handoff, Handoff):
+            wrapped_handoffs.append(wrap_agent_handoff_tools(handoff, state))
+        else:
+            raise TypeError(f"Unsupported handoff type: {type(handoff)}")
+
+    return agent.clone(tools=wrapped_tools, handoffs=wrapped_handoffs)
+
+
+def wrap_agent_handoff_tools(
+    handoff: Handoff[TContext],
+    state: State,
+) -> Handoff[TContext]:
+    """
+    Wrap the tools of a handoff to use the Restate error handling.
+
+    Returns:
+        A new handoff with wrapped tools.
+    """
+
+    original_on_invoke_handoff = handoff.on_invoke_handoff
+
+    async def wrapped(*args, **kwargs) -> Any:
+        agent = await original_on_invoke_handoff(*args, **kwargs)
+        return wrap_agent_tools(agent, state)
+
+    return dataclasses.replace(handoff, on_invoke_handoff=wrapped)
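
For orientation, a minimal usage sketch of the wrapper above. The agent and tool names are hypothetical, and the module path of wrap_agent_tools is assumed from the "rename utils to functions" note in the commit message; only restate.ext.openai.models (where State is defined) is confirmed by the imports in this diff.

from agents import Agent, function_tool

from restate.ext.openai.models import State  # defined later in this commit
from restate.ext.openai.functions import wrap_agent_tools  # module name assumed


@function_tool
def get_weather(city: str) -> str:
    """Hypothetical tool used only for illustration."""
    return f"The weather in {city} is sunny."


weather_agent = Agent(name="weather_agent", tools=[get_weather])

state = State()  # carries the shared Turnstile for this run
synced_agent = wrap_agent_tools(weather_agent, state)
# Every FunctionTool on `synced_agent` (and on any handed-off agent) now waits on
# state.turnstile for its call ID before invoking, and releases the next call when done.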
Lines changed: 79 additions & 0 deletions

@@ -0,0 +1,79 @@
+#
+# Copyright (c) 2023-2024 - Restate Software, Inc., Restate GmbH
+#
+# This file is part of the Restate SDK for Python,
+# which is released under the MIT license.
+#
+# You can find a copy of the license in file LICENSE in the root
+# directory of this repository or package, or at
+# https://github.com/restatedev/sdk-typescript/blob/main/LICENSE
+#
+"""
+This module contains the optional OpenAI integration for Restate.
+"""
+
+import dataclasses
+
+from agents import (
+    Usage,
+)
+from agents.items import TResponseOutputItem
+from agents.items import TResponseInputItem
+from datetime import timedelta
+from typing import Optional
+from pydantic import BaseModel
+
+from restate.ext.turnstile import Turnstile
+
+
+class State:
+    __slots__ = ("turnstile",)
+
+    def __init__(self) -> None:
+        self.turnstile = Turnstile([])
+
+
+@dataclasses.dataclass
+class LlmRetryOpts:
+    max_attempts: Optional[int] = 10
+    """Max number of attempts (including the initial one) before giving up.
+
+    When giving up, the LLM call will throw a `TerminalError` wrapping the original error message."""
+    max_duration: Optional[timedelta] = None
+    """Max duration of retries before giving up.
+
+    When giving up, the LLM call will throw a `TerminalError` wrapping the original error message."""
+    initial_retry_interval: Optional[timedelta] = timedelta(seconds=1)
+    """Initial interval for the first retry attempt.
+    The retry interval grows by the factor specified in `retry_interval_factor`.
+
+    If any of the other retry-related fields is specified, the default for this field is 50 milliseconds; otherwise Restate falls back to the overall invocation retry policy."""
+    max_retry_interval: Optional[timedelta] = None
+    """Max interval between retries.
+    The retry interval grows by the factor specified in `retry_interval_factor`.
+
+    The default is 10 seconds."""
+    retry_interval_factor: Optional[float] = None
+    """Exponentiation factor to use when computing the next retry delay.
+
+    If any of the other retry-related fields is specified, the default for this field is `2`, meaning the retry interval doubles at each attempt; otherwise Restate falls back to the overall invocation retry policy."""
+
+
+# The OpenAI ModelResponse class is a dataclass with Pydantic fields.
+# The Restate SDK cannot serialize this, so we turn the ModelResponse into a Pydantic model.
+class RestateModelResponse(BaseModel):
+    output: list[TResponseOutputItem]
+    """A list of outputs (messages, tool calls, etc.) generated by the model."""
+
+    usage: Usage
+    """The usage information for the response."""
+
+    response_id: str | None
+    """An ID for the response which can be used to refer to the response in subsequent calls to the
+    model. Not supported by all model providers.
+    If using OpenAI models via the Responses API, this is the `response_id` parameter, and it can
+    be passed to `Runner.run`.
+    """
+
+    def to_input_items(self) -> list[TResponseInputItem]:
+        return [it.model_dump(exclude_unset=True) for it in self.output]  # type: ignore
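
A short, hedged sketch of these models in use. LlmRetryOpts is importable via restate.ext.openai (per the __init__.py change above); the retry values are illustrative only, and the response round-trip is shown as a comment because building a full RestateModelResponse requires real model output.

from datetime import timedelta

from restate.ext.openai import LlmRetryOpts

retry_opts = LlmRetryOpts(
    max_attempts=5,                                     # give up after 5 attempts (TerminalError)
    initial_retry_interval=timedelta(milliseconds=50),
    retry_interval_factor=2.0,                          # 50ms, 100ms, 200ms, ...
)

# A stored RestateModelResponse (being a Pydantic model, it serializes cleanly) can be
# fed back to the model on the next turn, e.g.:
#   next_input = previous_items + restate_response.to_input_items()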
