Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 62 additions & 47 deletions python/restate/ext/adk/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
ADK plugin implementation for restate.
"""

import asyncio
import restate

from datetime import timedelta
Expand All @@ -36,6 +35,7 @@
from restate.extensions import current_context

from restate.ext.turnstile import Turnstile
from restate.server_context import clear_extension_data, get_extension_data, set_extension_data


def _create_turnstile(s: LlmResponse) -> Turnstile:
Expand All @@ -44,16 +44,57 @@ def _create_turnstile(s: LlmResponse) -> Turnstile:
return turnstile


def _invocation_extension_key(invocation_id: str) -> str:
return "adk_" + invocation_id


class PluginState:
    """Per-invocation state of the plugin: the model in use and the current turnstile.

    Instances are stored in, and looked up from, the Restate context's
    extension data under the key produced by ``_invocation_extension_key``.
    """

    def __init__(self, model: BaseLlm):
        self.model = model
        # Starts with an empty turnstile.
        self.turnstile: Turnstile = Turnstile([])

    def __close__(self):
        """Release resources by cancelling everything on the turnstile."""
        self.turnstile.cancel_all()

    @staticmethod
    def create(invocation_id: str, model: BaseLlm, ctx: restate.Context) -> None:
        """Build a fresh state for ``invocation_id`` and store it on ``ctx``."""
        set_extension_data(ctx, _invocation_extension_key(invocation_id), PluginState(model))

    @staticmethod
    def clear(invocation_id: str) -> None:
        """Remove the state of ``invocation_id`` from the current Restate context.

        Raises:
            RuntimeError: if there is no current Restate context.
        """
        active_ctx = current_context()
        if active_ctx is None:
            raise RuntimeError(
                "No Restate context found, the restate plugin must be used from within a restate handler."
            )
        clear_extension_data(active_ctx, _invocation_extension_key(invocation_id))

    @staticmethod
    def from_context(invocation_id: str, ctx: restate.Context | None = None) -> "PluginState":
        """Fetch the state stored for ``invocation_id``.

        Args:
            invocation_id: the ADK invocation whose state is wanted.
            ctx: context to read from; when omitted, the current context is used.

        Raises:
            RuntimeError: if no context is available, or no state is stored
                under the invocation's key.
        """
        active_ctx = ctx if ctx is not None else current_context()
        if active_ctx is None:
            raise RuntimeError(
                "No Restate context found, the restate plugin must be used from within a restate handler."
            )
        found: PluginState | None = get_extension_data(active_ctx, _invocation_extension_key(invocation_id))
        if found is None:
            raise RuntimeError(
                "No RestatePlugin state found, the restate plugin must be used from within a restate handler."
            )
        return found


class RestatePlugin(BasePlugin):
"""A plugin to integrate Restate with the ADK framework."""

_models: dict[str, BaseLlm]
_turnstiles: dict[str, Turnstile | None]

def __init__(self, *, max_model_call_retries: int = 10):
super().__init__(name="restate_plugin")
self._models = {}
self._turnstiles = {}
self._max_model_call_retries = max_model_call_retries

async def before_agent_callback(
Expand All @@ -68,31 +109,15 @@ async def before_agent_callback(
Ensure that the agent is invoked within a restate handler and,
using a ```with restate_overrides(ctx):``` block. around your agent use."""
)

model = agent.model if isinstance(agent.model, BaseLlm) else LLMRegistry.new_llm(agent.model)
self._models[callback_context.invocation_id] = model
self._turnstiles[callback_context.invocation_id] = None

id = callback_context.invocation_id
event = ctx.request().attempt_finished_event

async def release_task():
"""make sure to release resources when the agent finishes"""
try:
await event.wait()
finally:
self._models.pop(id, None)
maybe_turnstile = self._turnstiles.pop(id, None)
if maybe_turnstile is not None:
maybe_turnstile.cancel_all()

_ = asyncio.create_task(release_task())
PluginState.create(callback_context.invocation_id, model, ctx)
return None

async def after_agent_callback(
self, *, agent: BaseAgent, callback_context: CallbackContext
) -> Optional[types.Content]:
self._models.pop(callback_context.invocation_id, None)
self._turnstiles.pop(callback_context.invocation_id, None)
PluginState.clear(callback_context.invocation_id)
return None

async def after_run_callback(self, *, invocation_context: InvocationContext) -> None:
Expand All @@ -103,15 +128,18 @@ async def after_run_callback(self, *, invocation_context: InvocationContext) ->
async def before_model_callback(
self, *, callback_context: CallbackContext, llm_request: LlmRequest
) -> Optional[LlmResponse]:
model = self._models[callback_context.invocation_id]
ctx = current_context()
ctx = current_context() # Ensure we have a Restate context
if ctx is None:
raise RuntimeError(
"No Restate context found, the restate plugin must be used from within a restate handler."
raise restate.TerminalError(
"""No Restate context found for RestatePlugin.
Ensure that the agent is invoked within a restate handler and,
using a ```with restate_overrides(ctx):``` block. around your agent use."""
)
state = PluginState.from_context(callback_context.invocation_id, ctx)
model = state.model
response = await _generate_content_async(ctx, self._max_model_call_retries, model, llm_request)
turnstile = _create_turnstile(response)
self._turnstiles[callback_context.invocation_id] = turnstile
state.turnstile = turnstile
return response

async def before_tool_callback(
Expand All @@ -121,17 +149,11 @@ async def before_tool_callback(
tool_args: dict[str, Any],
tool_context: ToolContext,
) -> Optional[dict]:
turnstile = self._turnstiles[tool_context.invocation_id]
assert turnstile is not None, "Turnstile not found for tool invocation."

id = tool_context.function_call_id
assert id is not None, "Function call ID is required for tool invocation."

turnstile = PluginState.from_context(tool_context.invocation_id).turnstile
await turnstile.wait_for(id)

ctx = current_context()
tool_context.session.state["restate_context"] = ctx

return None

async def after_tool_callback(
Expand All @@ -142,12 +164,11 @@ async def after_tool_callback(
tool_context: ToolContext,
result: dict,
) -> Optional[dict]:
tool_context.session.state.pop("restate_context", None)
turnstile = self._turnstiles[tool_context.invocation_id]
assert turnstile is not None, "Turnstile not found for tool invocation."
id = tool_context.function_call_id
assert id is not None, "Function call ID is required for tool invocation."
turnstile = PluginState.from_context(tool_context.invocation_id).turnstile
turnstile.allow_next_after(id)

return None

async def on_tool_error_callback(
Expand All @@ -158,18 +179,12 @@ async def on_tool_error_callback(
tool_context: ToolContext,
error: Exception,
) -> Optional[dict]:
tool_context.session.state.pop("restate_context", None)
turnstile = self._turnstiles[tool_context.invocation_id]
assert turnstile is not None, "Turnstile not found for tool invocation."
id = tool_context.function_call_id
assert id is not None, "Function call ID is required for tool invocation."
turnstile = PluginState.from_context(tool_context.invocation_id).turnstile
turnstile.cancel_all_after(id)
return None

async def close(self):
self._models.clear()
self._turnstiles.clear()


def _get_function_call_ids(s: LlmResponse) -> list[str]:
ids = []
Expand Down
56 changes: 55 additions & 1 deletion python/restate/server_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
"""This module contains the restate context implementation based on the server"""

import asyncio
from contextlib import AsyncExitStack
from contextlib import AsyncExitStack, asynccontextmanager
import contextvars
import copy
from random import Random
Expand Down Expand Up @@ -315,6 +315,49 @@ def current_context() -> Context | None:
return _restate_context_var.get()


def set_extension_data(ctx: Context, key: str, value: Any) -> None:
    """Store ``value`` under ``key`` in the extension data of ``ctx``.

    Raises:
        RuntimeError: if ``ctx`` is not a ``ServerInvocationContext``.
    """
    if isinstance(ctx, ServerInvocationContext):
        ctx.extension_data[key] = value
        return
    raise RuntimeError("Current context is not a ServerInvocationContext")


def get_extension_data(ctx: Context, key: str) -> Any:
    """Return the extension data stored under ``key`` on ``ctx``, or ``None`` if absent.

    Raises:
        RuntimeError: if ``ctx`` is not a ``ServerInvocationContext``.
    """
    if isinstance(ctx, ServerInvocationContext):
        return ctx.extension_data.get(key)
    raise RuntimeError("Current context is not a ServerInvocationContext")


def clear_extension_data(ctx: Context, key: str) -> None:
    """Drop the extension data stored under ``key`` on ``ctx``; a missing key is a no-op.

    Raises:
        RuntimeError: if ``ctx`` is not a ``ServerInvocationContext``.
    """
    if not isinstance(ctx, ServerInvocationContext):
        raise RuntimeError("Current context is not a ServerInvocationContext")
    # pop with a default never raises for an absent key
    ctx.extension_data.pop(key, None)


@asynccontextmanager
async def auto_close_extension_data(data: Dict[str, Any]):
    """Context manager that closes and clears extension data on exit.

    On exit (normal or exceptional), every value in ``data`` that exposes a
    callable ``__close__`` attribute has it invoked; if the call returns an
    awaitable, it is awaited. Afterwards ``data`` is emptied.

    Args:
        data: the mapping of extension key -> extension value to clean up.
            It is cleared in place.

    A failure raised by one ``__close__`` hook is logged and does not prevent
    the remaining values from being closed.
    """
    # Local import so this block does not depend on a module-level import.
    import logging

    try:
        yield
    finally:
        try:
            # Iterate over a snapshot: a close hook (or code running while we
            # await one) may add or remove entries, which would otherwise raise
            # "dictionary changed size during iteration" and abort the cleanup.
            for key, value in list(data.items()):
                close_method = getattr(value, "__close__", None)
                if not callable(close_method):
                    continue
                try:
                    outcome = close_method()
                    # Covers async functions as well as any callable that
                    # hands back an awaitable.
                    if inspect.isawaitable(outcome):
                        await outcome
                except Exception:  # pylint: disable=W0718
                    # extension data close failure should not block further processing
                    logging.getLogger(__name__).exception("Failed to close extension data %r", key)
        finally:
            # Always drop the references, even if the loop above is interrupted
            # by a BaseException such as a cancellation.
            data.clear()


# pylint: disable=R0902
class ServerInvocationContext(ObjectContext):
"""This class implements the context for the restate framework based on the server."""
Expand All @@ -339,6 +382,7 @@ def __init__(
self.run_coros_to_execute: dict[int, Callable[[], Awaitable[None]]] = {}
self.request_finished_event = asyncio.Event()
self.tasks = Tasks()
self.extension_data: Dict[str, Any] = {}

async def enter(self):
"""Invoke the user code."""
Expand All @@ -349,6 +393,8 @@ async def enter(self):
async with AsyncExitStack() as stack:
for manager in self.handler.context_managers or []:
await stack.enter_async_context(manager())
await stack.enter_async_context(auto_close_extension_data(self.extension_data))

out_buffer = await invoke_handler(handler=self.handler, ctx=self, in_buffer=in_buffer)
restate_context_is_replaying.set(False)
self.vm.sys_write_output_success(bytes(out_buffer))
Expand Down Expand Up @@ -971,3 +1017,11 @@ def attach_invocation(
handle = self.vm.attach_invocation(invocation_id)
update_restate_context_is_replaying(self.vm)
return self.create_future(handle, serde)

def get_extension_data(self, key: str) -> Any:
    """Look up ``key`` in this context's extension data; ``None`` when it is not set."""
    stored = self.extension_data
    return stored.get(key, None)

def set_extension_data(self, key: str, value: Any) -> None:
    """Record ``value`` under ``key`` in this context's extension data, replacing any previous entry."""
    self.extension_data.update({key: value})