Skip to content

Commit 18f10ce

Browse files
committed
Add explicit tool execution synchronization
1 parent a4765b0 commit 18f10ce

File tree

2 files changed

+76
-13
lines changed

2 files changed

+76
-13
lines changed

python/restate/ext/adk/plugin.py

Lines changed: 45 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -35,17 +35,25 @@
3535

3636
from restate.extensions import current_context
3737

38+
from restate.ext.turnstile import Turnstile
39+
40+
41+
def _create_turnstile(s: LlmResponse) -> Turnstile:
42+
ids = _get_function_call_ids(s)
43+
turnstile = Turnstile(ids)
44+
return turnstile
45+
3846

3947
class RestatePlugin(BasePlugin):
4048
"""A plugin to integrate Restate with the ADK framework."""
4149

4250
_models: dict[str, BaseLlm]
43-
_locks: dict[str, asyncio.Lock]
51+
_turnstiles: dict[str, Turnstile | None]
4452

4553
def __init__(self, *, max_model_call_retries: int = 10):
4654
super().__init__(name="restate_plugin")
4755
self._models = {}
48-
self._locks = {}
56+
self._turnstiles = {}
4957
self._max_model_call_retries = max_model_call_retries
5058

5159
async def before_agent_callback(
@@ -62,7 +70,7 @@ async def before_agent_callback(
6270
)
6371
model = agent.model if isinstance(agent.model, BaseLlm) else LLMRegistry.new_llm(agent.model)
6472
self._models[callback_context.invocation_id] = model
65-
self._locks[callback_context.invocation_id] = asyncio.Lock()
73+
self._turnstiles[callback_context.invocation_id] = None
6674

6775
id = callback_context.invocation_id
6876
event = ctx.request().attempt_finished_event
@@ -73,7 +81,7 @@ async def release_task():
7381
await event.wait()
7482
finally:
7583
self._models.pop(id, None)
76-
self._locks.pop(id, None)
84+
self._turnstiles.pop(id, None)
7785

7886
_ = asyncio.create_task(release_task())
7987
return None
@@ -82,7 +90,7 @@ async def after_agent_callback(
8290
self, *, agent: BaseAgent, callback_context: CallbackContext
8391
) -> Optional[types.Content]:
8492
self._models.pop(callback_context.invocation_id, None)
85-
self._locks.pop(callback_context.invocation_id, None)
93+
self._turnstiles.pop(callback_context.invocation_id, None)
8694
return None
8795

8896
async def after_run_callback(self, *, invocation_context: InvocationContext) -> None:
@@ -100,6 +108,8 @@ async def before_model_callback(
100108
"No Restate context found, the restate plugin must be used from within a restate handler."
101109
)
102110
response = await _generate_content_async(ctx, self._max_model_call_retries, model, llm_request)
111+
turnstile = _create_turnstile(response)
112+
self._turnstiles[callback_context.invocation_id] = turnstile
103113
return response
104114

105115
async def before_tool_callback(
@@ -109,11 +119,17 @@ async def before_tool_callback(
109119
tool_args: dict[str, Any],
110120
tool_context: ToolContext,
111121
) -> Optional[dict]:
112-
lock = self._locks[tool_context.invocation_id]
122+
turnstile = self._turnstiles[tool_context.invocation_id]
123+
assert turnstile is not None, "Turnstile not found for tool invocation."
124+
125+
id = tool_context.function_call_id
126+
assert id is not None, "Function call ID is required for tool invocation."
127+
128+
await turnstile.wait_for(id)
129+
113130
ctx = current_context()
114-
await lock.acquire()
115131
tool_context.session.state["restate_context"] = ctx
116-
# TODO: if we want we can also automatically wrap tools with ctx.run_typed here
132+
117133
return None
118134

119135
async def after_tool_callback(
@@ -125,8 +141,11 @@ async def after_tool_callback(
125141
result: dict,
126142
) -> Optional[dict]:
127143
tool_context.session.state.pop("restate_context", None)
128-
lock = self._locks[tool_context.invocation_id]
129-
lock.release()
144+
turnstile = self._turnstiles[tool_context.invocation_id]
145+
assert turnstile is not None, "Turnstile not found for tool invocation."
146+
id = tool_context.function_call_id
147+
assert id is not None, "Function call ID is required for tool invocation."
148+
turnstile.allow_next_after(id)
130149
return None
131150

132151
async def on_tool_error_callback(
@@ -138,13 +157,26 @@ async def on_tool_error_callback(
138157
error: Exception,
139158
) -> Optional[dict]:
140159
tool_context.session.state.pop("restate_context", None)
141-
lock = self._locks[tool_context.invocation_id]
142-
lock.release()
160+
turnstile = self._turnstiles[tool_context.invocation_id]
161+
assert turnstile is not None, "Turnstile not found for tool invocation."
162+
id = tool_context.function_call_id
163+
assert id is not None, "Function call ID is required for tool invocation."
164+
turnstile.allow_next_after(id)
143165
return None
144166

145167
async def close(self):
146168
self._models.clear()
147-
self._locks.clear()
169+
self._turnstiles.clear()
170+
171+
172+
def _get_function_call_ids(s: LlmResponse) -> list[str]:
173+
ids = []
174+
if s.content and s.content.parts:
175+
for part in s.content.parts:
176+
if part.function_call:
177+
if part.function_call.id:
178+
ids.append(part.function_call.id)
179+
return ids
148180

149181

150182
def _generate_client_function_call_id(s: LlmResponse) -> None:

python/restate/ext/turnstile.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
from asyncio import Event
2+
3+
4+
class Turnstile:
5+
"""A turnstile to manage ordered access based on IDs."""
6+
7+
def __init__(self, ids: list[str]):
8+
# ordered mapping of id to next id in the sequence
9+
# for example:
10+
# {'id1': 'id2', 'id2': 'id3'} <-- id3 is the last.
11+
# {} <-- no ids, no turns.
12+
# {} <-- single id, no next turn.
13+
#
14+
self.turns = dict(zip(ids, ids[1:]))
15+
# mapping of id to event that signals when that id's turn is allowed
16+
self.events = {id: Event() for id in ids}
17+
if ids:
18+
# make sure that the first id can proceed immediately
19+
event = self.events[ids[0]]
20+
event.set()
21+
22+
async def wait_for(self, id: str) -> None:
23+
event = self.events[id]
24+
await event.wait()
25+
26+
def allow_next_after(self, id: str) -> None:
27+
next_id = self.turns.get(id)
28+
if next_id is None:
29+
return
30+
next_event = self.events[next_id]
31+
next_event.set()

0 commit comments

Comments
 (0)