diff --git a/python/restate/ext/adk/plugin.py b/python/restate/ext/adk/plugin.py
index 65bdbc2..8fda68d 100644
--- a/python/restate/ext/adk/plugin.py
+++ b/python/restate/ext/adk/plugin.py
@@ -12,7 +12,6 @@
 ADK plugin implementation for restate.
 """
 
-import asyncio
 import restate
 
 from datetime import timedelta
@@ -36,6 +35,7 @@
 
 from restate.extensions import current_context
 from restate.ext.turnstile import Turnstile
+from restate.server_context import clear_extension_data, get_extension_data, set_extension_data
 
 
 def _create_turnstile(s: LlmResponse) -> Turnstile:
@@ -44,16 +44,68 @@ def _create_turnstile(s: LlmResponse) -> Turnstile:
     return turnstile
 
 
+def _invocation_extension_key(invocation_id: str) -> str:
+    """Return the extension-data key under which the state of an ADK invocation is stored."""
+    return "adk_" + invocation_id
+
+
+class PluginState:
+    """Per-invocation plugin state: the resolved model and the turnstile used by the tool callbacks."""
+
+    def __init__(self, model: BaseLlm):
+        self.model = model
+        # Placeholder with no ids; before_model_callback installs the real one per LLM response.
+        self.turnstile: Turnstile = Turnstile([])
+
+    def __close__(self):
+        """Cancel the turnstile; auto_close_extension_data runs this for state still registered."""
+        self.turnstile.cancel_all()
+
+    @staticmethod
+    def create(invocation_id: str, model: BaseLlm, ctx: restate.Context) -> None:
+        """Store a fresh state for ``invocation_id`` in the extension data of ``ctx``."""
+        extension_key = _invocation_extension_key(invocation_id)
+        state = PluginState(model)
+        set_extension_data(ctx, extension_key, state)
+
+    @staticmethod
+    def clear(invocation_id: str) -> None:
+        """Drop the state of ``invocation_id``; raises RuntimeError outside a Restate handler."""
+        ctx = current_context()
+        if ctx is None:
+            raise RuntimeError(
+                "No Restate context found, the restate plugin must be used from within a restate handler."
+            )
+        extension_key = _invocation_extension_key(invocation_id)
+        clear_extension_data(ctx, extension_key)
+
+    @staticmethod
+    def from_context(invocation_id: str, ctx: restate.Context | None = None) -> "PluginState":
+        """Return the state stored for ``invocation_id``.
+
+        Looks in ``ctx`` when given, otherwise in the current Restate context. Raises
+        ``RuntimeError`` when no context is available or no state is stored for the id.
+        """
+        if ctx is None:
+            ctx = current_context()
+        if ctx is None:
+            raise RuntimeError(
+                "No Restate context found, the restate plugin must be used from within a restate handler."
+            )
+        extension_key = _invocation_extension_key(invocation_id)
+        state: PluginState | None = get_extension_data(ctx, extension_key)
+        if state is None:
+            raise RuntimeError(
+                "No RestatePlugin state found, the restate plugin must be used from within a restate handler."
+            )
+        return state
+
+
 class RestatePlugin(BasePlugin):
     """A plugin to integrate Restate with the ADK framework."""
 
-    _models: dict[str, BaseLlm]
-    _turnstiles: dict[str, Turnstile | None]
-
     def __init__(self, *, max_model_call_retries: int = 10):
         super().__init__(name="restate_plugin")
-        self._models = {}
-        self._turnstiles = {}
         self._max_model_call_retries = max_model_call_retries
 
     async def before_agent_callback(
@@ -68,31 +120,15 @@ async def before_agent_callback(
                 Ensure that the agent is invoked within a restate handler and,
                 using a ```with restate_overrides(ctx):``` block. around your agent use."""
             )
+
         model = agent.model if isinstance(agent.model, BaseLlm) else LLMRegistry.new_llm(agent.model)
-        self._models[callback_context.invocation_id] = model
-        self._turnstiles[callback_context.invocation_id] = None
-
-        id = callback_context.invocation_id
-        event = ctx.request().attempt_finished_event
-
-        async def release_task():
-            """make sure to release resources when the agent finishes"""
-            try:
-                await event.wait()
-            finally:
-                self._models.pop(id, None)
-                maybe_turnstile = self._turnstiles.pop(id, None)
-                if maybe_turnstile is not None:
-                    maybe_turnstile.cancel_all()
-
-        _ = asyncio.create_task(release_task())
+        PluginState.create(callback_context.invocation_id, model, ctx)
         return None
 
     async def after_agent_callback(
         self, *, agent: BaseAgent, callback_context: CallbackContext
     ) -> Optional[types.Content]:
-        self._models.pop(callback_context.invocation_id, None)
-        self._turnstiles.pop(callback_context.invocation_id, None)
+        PluginState.clear(callback_context.invocation_id)
         return None
 
     async def after_run_callback(self, *, invocation_context: InvocationContext) -> None:
@@ -103,15 +139,18 @@ async def after_run_callback(self, *, invocation_context: InvocationContext) ->
     async def before_model_callback(
         self, *, callback_context: CallbackContext, llm_request: LlmRequest
     ) -> Optional[LlmResponse]:
-        model = self._models[callback_context.invocation_id]
-        ctx = current_context()
+        ctx = current_context()  # Ensure we have a Restate context
         if ctx is None:
-            raise RuntimeError(
-                "No Restate context found, the restate plugin must be used from within a restate handler."
+            raise restate.TerminalError(
+                """No Restate context found for RestatePlugin.
+                Ensure that the agent is invoked within a restate handler and,
+                using a ```with restate_overrides(ctx):``` block. around your agent use."""
             )
+        state = PluginState.from_context(callback_context.invocation_id, ctx)
+        model = state.model
         response = await _generate_content_async(ctx, self._max_model_call_retries, model, llm_request)
         turnstile = _create_turnstile(response)
-        self._turnstiles[callback_context.invocation_id] = turnstile
+        state.turnstile = turnstile
         return response
 
     async def before_tool_callback(
@@ -121,17 +160,11 @@ async def before_tool_callback(
         tool_args: dict[str, Any],
         tool_context: ToolContext,
     ) -> Optional[dict]:
-        turnstile = self._turnstiles[tool_context.invocation_id]
-        assert turnstile is not None, "Turnstile not found for tool invocation."
-
         id = tool_context.function_call_id
         assert id is not None, "Function call ID is required for tool invocation."
+        turnstile = PluginState.from_context(tool_context.invocation_id).turnstile
 
         await turnstile.wait_for(id)
-
-        ctx = current_context()
-        tool_context.session.state["restate_context"] = ctx
-
         return None
 
     async def after_tool_callback(
@@ -142,12 +175,11 @@ async def after_tool_callback(
         tool_context: ToolContext,
         result: dict,
     ) -> Optional[dict]:
-        tool_context.session.state.pop("restate_context", None)
-        turnstile = self._turnstiles[tool_context.invocation_id]
-        assert turnstile is not None, "Turnstile not found for tool invocation."
         id = tool_context.function_call_id
         assert id is not None, "Function call ID is required for tool invocation."
+        turnstile = PluginState.from_context(tool_context.invocation_id).turnstile
         turnstile.allow_next_after(id)
+
         return None
 
     async def on_tool_error_callback(
@@ -158,18 +190,12 @@ async def on_tool_error_callback(
         tool_context: ToolContext,
         error: Exception,
     ) -> Optional[dict]:
-        tool_context.session.state.pop("restate_context", None)
-        turnstile = self._turnstiles[tool_context.invocation_id]
-        assert turnstile is not None, "Turnstile not found for tool invocation."
         id = tool_context.function_call_id
         assert id is not None, "Function call ID is required for tool invocation."
+        turnstile = PluginState.from_context(tool_context.invocation_id).turnstile
         turnstile.cancel_all_after(id)
         return None
 
-    async def close(self):
-        self._models.clear()
-        self._turnstiles.clear()
-
 
 def _get_function_call_ids(s: LlmResponse) -> list[str]:
     ids = []
diff --git a/python/restate/server_context.py b/python/restate/server_context.py
index 31ae4be..d10c032 100644
--- a/python/restate/server_context.py
+++ b/python/restate/server_context.py
@@ -17,7 +17,7 @@
 """This module contains the restate context implementation based on the server"""
 
 import asyncio
-from contextlib import AsyncExitStack
+from contextlib import AsyncExitStack, asynccontextmanager
 import contextvars
 import copy
 from random import Random
@@ -315,6 +315,49 @@ def current_context() -> Context | None:
     return _restate_context_var.get()
 
 
+def set_extension_data(ctx: Context, key: str, value: Any) -> None:
+    """Store ``value`` under ``key`` in the extension data of ``ctx``."""
+    if not isinstance(ctx, ServerInvocationContext):
+        raise RuntimeError("Current context is not a ServerInvocationContext")
+    ctx.extension_data[key] = value
+
+
+def get_extension_data(ctx: Context, key: str) -> Any:
+    """Return the extension data stored under ``key`` in ``ctx``, or ``None`` when absent."""
+    if not isinstance(ctx, ServerInvocationContext):
+        raise RuntimeError("Current context is not a ServerInvocationContext")
+    return ctx.extension_data.get(key, None)
+
+
+def clear_extension_data(ctx: Context, key: str) -> None:
+    """Remove ``key`` from the extension data of ``ctx``; a missing key is a no-op."""
+    if not isinstance(ctx, ServerInvocationContext):
+        raise RuntimeError("Current context is not a ServerInvocationContext")
+    # pop() with a default drops the key in a single lookup and tolerates a missing key.
+    ctx.extension_data.pop(key, None)
+
+
+@asynccontextmanager
+async def auto_close_extension_data(data: Dict[str, Any]):
+    """Context manager that runs every value's ``__close__`` hook on exit, then empties ``data``."""
+    try:
+        yield
+    finally:
+        for value in list(data.values()):  # snapshot: a close hook may mutate ``data``
+            close_method = getattr(value, "__close__", None)
+            if not callable(close_method):
+                continue
+            try:
+                result = close_method()
+                if inspect.isawaitable(result):  # any awaitable, not only ``async def`` hooks
+                    await result
+            except Exception:
+                # extension data close failure should not block further processing
+                # TODO: add logging here
+                pass
+        data.clear()
+
+
 # pylint: disable=R0902
 class ServerInvocationContext(ObjectContext):
     """This class implements the context for the restate framework based on the server."""
@@ -339,6 +382,7 @@ def __init__(
         self.run_coros_to_execute: dict[int, Callable[[], Awaitable[None]]] = {}
         self.request_finished_event = asyncio.Event()
         self.tasks = Tasks()
+        self.extension_data: Dict[str, Any] = {}
 
     async def enter(self):
         """Invoke the user code."""
@@ -349,6 +393,8 @@ async def enter(self):
             async with AsyncExitStack() as stack:
                 for manager in self.handler.context_managers or []:
                     await stack.enter_async_context(manager())
+                await stack.enter_async_context(auto_close_extension_data(self.extension_data))
+
                 out_buffer = await invoke_handler(handler=self.handler, ctx=self, in_buffer=in_buffer)
                 restate_context_is_replaying.set(False)
                 self.vm.sys_write_output_success(bytes(out_buffer))
@@ -971,3 +1017,11 @@ def attach_invocation(
         handle = self.vm.attach_invocation(invocation_id)
         update_restate_context_is_replaying(self.vm)
         return self.create_future(handle, serde)
+
+    def get_extension_data(self, key: str) -> Any:
+        """Return the extension data stored under ``key``, or ``None`` when absent."""
+        return self.extension_data.get(key)
+
+    def set_extension_data(self, key: str, value: Any) -> None:
+        """Store ``value`` under ``key`` in this context's extension data."""
+        self.extension_data[key] = value