3535
3636from restate .extensions import current_context
3737
38+ from restate .ext .turnstile import Turnstile
39+
40+
41+ def _create_turnstile (s : LlmResponse ) -> Turnstile :
42+ ids = _get_function_call_ids (s )
43+ turnstile = Turnstile (ids )
44+ return turnstile
45+
3846
3947class RestatePlugin (BasePlugin ):
4048 """A plugin to integrate Restate with the ADK framework."""
4149
4250 _models : dict [str , BaseLlm ]
43- _locks : dict [str , asyncio . Lock ]
51+ _turnstiles : dict [str , Turnstile | None ]
4452
4553 def __init__ (self , * , max_model_call_retries : int = 10 ):
4654 super ().__init__ (name = "restate_plugin" )
4755 self ._models = {}
48- self ._locks = {}
56+ self ._turnstiles = {}
4957 self ._max_model_call_retries = max_model_call_retries
5058
5159 async def before_agent_callback (
@@ -62,7 +70,7 @@ async def before_agent_callback(
6270 )
6371 model = agent .model if isinstance (agent .model , BaseLlm ) else LLMRegistry .new_llm (agent .model )
6472 self ._models [callback_context .invocation_id ] = model
65- self ._locks [callback_context .invocation_id ] = asyncio . Lock ()
73+ self ._turnstiles [callback_context .invocation_id ] = None
6674
6775 id = callback_context .invocation_id
6876 event = ctx .request ().attempt_finished_event
@@ -73,7 +81,7 @@ async def release_task():
7381 await event .wait ()
7482 finally :
7583 self ._models .pop (id , None )
76- self ._locks .pop (id , None )
84+ self ._turnstiles .pop (id , None )
7785
7886 _ = asyncio .create_task (release_task ())
7987 return None
@@ -82,7 +90,7 @@ async def after_agent_callback(
8290 self , * , agent : BaseAgent , callback_context : CallbackContext
8391 ) -> Optional [types .Content ]:
8492 self ._models .pop (callback_context .invocation_id , None )
85- self ._locks .pop (callback_context .invocation_id , None )
93+ self ._turnstiles .pop (callback_context .invocation_id , None )
8694 return None
8795
8896 async def after_run_callback (self , * , invocation_context : InvocationContext ) -> None :
@@ -100,6 +108,8 @@ async def before_model_callback(
100108 "No Restate context found, the restate plugin must be used from within a restate handler."
101109 )
102110 response = await _generate_content_async (ctx , self ._max_model_call_retries , model , llm_request )
111+ turnstile = _create_turnstile (response )
112+ self ._turnstiles [callback_context .invocation_id ] = turnstile
103113 return response
104114
105115 async def before_tool_callback (
@@ -109,11 +119,17 @@ async def before_tool_callback(
109119 tool_args : dict [str , Any ],
110120 tool_context : ToolContext ,
111121 ) -> Optional [dict ]:
112- lock = self ._locks [tool_context .invocation_id ]
122+ turnstile = self ._turnstiles [tool_context .invocation_id ]
123+ assert turnstile is not None , "Turnstile not found for tool invocation."
124+
125+ id = tool_context .function_call_id
126+ assert id is not None , "Function call ID is required for tool invocation."
127+
128+ await turnstile .wait_for (id )
129+
113130 ctx = current_context ()
114- await lock .acquire ()
115131 tool_context .session .state ["restate_context" ] = ctx
116- # TODO: if we want we can also automatically wrap tools with ctx.run_typed here
132+
117133 return None
118134
119135 async def after_tool_callback (
@@ -125,8 +141,11 @@ async def after_tool_callback(
125141 result : dict ,
126142 ) -> Optional [dict ]:
127143 tool_context .session .state .pop ("restate_context" , None )
128- lock = self ._locks [tool_context .invocation_id ]
129- lock .release ()
144+ turnstile = self ._turnstiles [tool_context .invocation_id ]
145+ assert turnstile is not None , "Turnstile not found for tool invocation."
146+ id = tool_context .function_call_id
147+ assert id is not None , "Function call ID is required for tool invocation."
148+ turnstile .allow_next_after (id )
130149 return None
131150
132151 async def on_tool_error_callback (
@@ -138,13 +157,26 @@ async def on_tool_error_callback(
138157 error : Exception ,
139158 ) -> Optional [dict ]:
140159 tool_context .session .state .pop ("restate_context" , None )
141- lock = self ._locks [tool_context .invocation_id ]
142- lock .release ()
160+ turnstile = self ._turnstiles [tool_context .invocation_id ]
161+ assert turnstile is not None , "Turnstile not found for tool invocation."
162+ id = tool_context .function_call_id
163+ assert id is not None , "Function call ID is required for tool invocation."
164+ turnstile .allow_next_after (id )
143165 return None
144166
145167 async def close (self ):
146168 self ._models .clear ()
147- self ._locks .clear ()
169+ self ._turnstiles .clear ()
170+
171+
172+ def _get_function_call_ids (s : LlmResponse ) -> list [str ]:
173+ ids = []
174+ if s .content and s .content .parts :
175+ for part in s .content .parts :
176+ if part .function_call :
177+ if part .function_call .id :
178+ ids .append (part .function_call .id )
179+ return ids
148180
149181
150182def _generate_client_function_call_id (s : LlmResponse ) -> None :
0 commit comments