diff --git a/python/restate/ext/adk/plugin.py b/python/restate/ext/adk/plugin.py index 0911190..d9d9c82 100644 --- a/python/restate/ext/adk/plugin.py +++ b/python/restate/ext/adk/plugin.py @@ -35,17 +35,25 @@ from restate.extensions import current_context +from restate.ext.turnstile import Turnstile + + +def _create_turnstile(s: LlmResponse) -> Turnstile: + ids = _get_function_call_ids(s) + turnstile = Turnstile(ids) + return turnstile + class RestatePlugin(BasePlugin): """A plugin to integrate Restate with the ADK framework.""" _models: dict[str, BaseLlm] - _locks: dict[str, asyncio.Lock] + _turnstiles: dict[str, Turnstile | None] def __init__(self, *, max_model_call_retries: int = 10): super().__init__(name="restate_plugin") self._models = {} - self._locks = {} + self._turnstiles = {} self._max_model_call_retries = max_model_call_retries async def before_agent_callback( @@ -62,7 +70,7 @@ async def before_agent_callback( ) model = agent.model if isinstance(agent.model, BaseLlm) else LLMRegistry.new_llm(agent.model) self._models[callback_context.invocation_id] = model - self._locks[callback_context.invocation_id] = asyncio.Lock() + self._turnstiles[callback_context.invocation_id] = None id = callback_context.invocation_id event = ctx.request().attempt_finished_event @@ -73,7 +81,7 @@ async def release_task(): await event.wait() finally: self._models.pop(id, None) - self._locks.pop(id, None) + self._turnstiles.pop(id, None) _ = asyncio.create_task(release_task()) return None @@ -82,7 +90,7 @@ async def after_agent_callback( self, *, agent: BaseAgent, callback_context: CallbackContext ) -> Optional[types.Content]: self._models.pop(callback_context.invocation_id, None) - self._locks.pop(callback_context.invocation_id, None) + self._turnstiles.pop(callback_context.invocation_id, None) return None async def after_run_callback(self, *, invocation_context: InvocationContext) -> None: @@ -100,6 +108,8 @@ async def before_model_callback( "No Restate context found, the 
restate plugin must be used from within a restate handler." ) response = await _generate_content_async(ctx, self._max_model_call_retries, model, llm_request) + turnstile = _create_turnstile(response) + self._turnstiles[callback_context.invocation_id] = turnstile return response async def before_tool_callback( @@ -109,11 +119,17 @@ async def before_tool_callback( tool_args: dict[str, Any], tool_context: ToolContext, ) -> Optional[dict]: - lock = self._locks[tool_context.invocation_id] + turnstile = self._turnstiles[tool_context.invocation_id] + assert turnstile is not None, "Turnstile not found for tool invocation." + + id = tool_context.function_call_id + assert id is not None, "Function call ID is required for tool invocation." + + await turnstile.wait_for(id) + ctx = current_context() - await lock.acquire() tool_context.session.state["restate_context"] = ctx - # TODO: if we want we can also automatically wrap tools with ctx.run_typed here + return None async def after_tool_callback( @@ -125,8 +141,11 @@ async def after_tool_callback( result: dict, ) -> Optional[dict]: tool_context.session.state.pop("restate_context", None) - lock = self._locks[tool_context.invocation_id] - lock.release() + turnstile = self._turnstiles[tool_context.invocation_id] + assert turnstile is not None, "Turnstile not found for tool invocation." + id = tool_context.function_call_id + assert id is not None, "Function call ID is required for tool invocation." + turnstile.allow_next_after(id) return None async def on_tool_error_callback( @@ -138,13 +157,26 @@ async def on_tool_error_callback( error: Exception, ) -> Optional[dict]: tool_context.session.state.pop("restate_context", None) - lock = self._locks[tool_context.invocation_id] - lock.release() + turnstile = self._turnstiles[tool_context.invocation_id] + assert turnstile is not None, "Turnstile not found for tool invocation." 
+ id = tool_context.function_call_id + assert id is not None, "Function call ID is required for tool invocation." + turnstile.allow_next_after(id) return None async def close(self): self._models.clear() - self._locks.clear() + self._turnstiles.clear() + + +def _get_function_call_ids(s: LlmResponse) -> list[str]: + ids = [] + if s.content and s.content.parts: + for part in s.content.parts: + if part.function_call: + if part.function_call.id: + ids.append(part.function_call.id) + return ids def _generate_client_function_call_id(s: LlmResponse) -> None: diff --git a/python/restate/ext/openai/__init__.py b/python/restate/ext/openai/__init__.py index 969d42f..a13ac9f 100644 --- a/python/restate/ext/openai/__init__.py +++ b/python/restate/ext/openai/__init__.py @@ -14,10 +14,12 @@ import typing -from .runner_wrapper import DurableRunner, continue_on_terminal_errors, raise_terminal_errors, LlmRetryOpts from restate import ObjectContext, Context from restate.server_context import current_context +from .runner_wrapper import DurableRunner, continue_on_terminal_errors, raise_terminal_errors +from .models import LlmRetryOpts + def restate_object_context() -> ObjectContext: """Get the current Restate ObjectContext.""" diff --git a/python/restate/ext/openai/functions.py b/python/restate/ext/openai/functions.py new file mode 100644 index 0000000..6fee0ce --- /dev/null +++ b/python/restate/ext/openai/functions.py @@ -0,0 +1,95 @@ +# +# Copyright (c) 2023-2025 - Restate Software, Inc., Restate GmbH +# +# This file is part of the Restate SDK for Python, +# which is released under the MIT license. 
+# +# You can find a copy of the license in file LICENSE in the root +# directory of this repository or package, or at +# https://github.com/restatedev/sdk-typescript/blob/main/LICENSE +# + +from typing import List, Any +import dataclasses + +from agents import ( + Handoff, + TContext, + Agent, +) + +from agents.tool import FunctionTool, Tool +from agents.tool_context import ToolContext +from agents.items import TResponseOutputItem + +from .models import State + + +def get_function_call_ids(response: list[TResponseOutputItem]) -> List[str]: + """Extract function call IDs from the model response.""" + # TODO: support function calls in other response types + return [item.call_id for item in response if item.type == "function_call"] + + +def _create_wrapper(state, captured_tool): + async def on_invoke_tool_wrapper(tool_context: ToolContext[Any], tool_input: Any) -> Any: + turnstile = state.turnstile + call_id = tool_context.tool_call_id + try: + await turnstile.wait_for(call_id) + return await captured_tool.on_invoke_tool(tool_context, tool_input) + finally: + turnstile.allow_next_after(call_id) + + return on_invoke_tool_wrapper + + +def wrap_agent_tools( + agent: Agent[TContext], + state: State, +) -> Agent[TContext]: + """ + Wrap the function tools of an agent so each tool call waits for its turn on the shared turnstile. + + Returns: + A new agent with wrapped tools. 
+ """ + wrapped_tools: list[Tool] = [] + for tool in agent.tools: + if isinstance(tool, FunctionTool): + wrapped = _create_wrapper(state, tool) + wrapped_tools.append(dataclasses.replace(tool, on_invoke_tool=wrapped)) + else: + wrapped_tools.append(tool) + + wrapped_handoffs: list[Agent[Any] | Handoff[Any]] = [] + for handoff in agent.handoffs: + if isinstance(handoff, Agent): + wrapped_handoff = wrap_agent_tools(handoff, state) + wrapped_handoffs.append(wrapped_handoff) + elif isinstance(handoff, Handoff): + wrapped_handoffs.append(wrap_agent_handoff_tools(handoff, state)) + else: + raise TypeError(f"Unsupported handoff type: {type(handoff)}") + + return agent.clone(tools=wrapped_tools, handoffs=wrapped_handoffs) + + +def wrap_agent_handoff_tools( + handoff: Handoff[TContext], + state: State, +) -> Handoff[TContext]: + """ + Wrap a handoff so the tools of the agent it resolves to are gated by the shared turnstile. + + Returns: + A new handoff with wrapped tools. + """ + + original_on_invoke_handoff = handoff.on_invoke_handoff + + async def wrapped(*args, **kwargs) -> Any: + agent = await original_on_invoke_handoff(*args, **kwargs) + return wrap_agent_tools(agent, state) + + return dataclasses.replace(handoff, on_invoke_handoff=wrapped) diff --git a/python/restate/ext/openai/models.py b/python/restate/ext/openai/models.py new file mode 100644 index 0000000..c48cf88 --- /dev/null +++ b/python/restate/ext/openai/models.py @@ -0,0 +1,79 @@ +# +# Copyright (c) 2023-2024 - Restate Software, Inc., Restate GmbH +# +# This file is part of the Restate SDK for Python, +# which is released under the MIT license. +# +# You can find a copy of the license in file LICENSE in the root +# directory of this repository or package, or at +# https://github.com/restatedev/sdk-typescript/blob/main/LICENSE +# +""" +This module contains the optional OpenAI integration for Restate. 
+""" + +import dataclasses + +from agents import ( + Usage, +) +from agents.items import TResponseOutputItem +from agents.items import TResponseInputItem +from datetime import timedelta +from typing import Optional +from pydantic import BaseModel + +from restate.ext.turnstile import Turnstile + + +class State: + __slots__ = ("turnstile",) + + def __init__(self) -> None: + self.turnstile = Turnstile([]) + + +@dataclasses.dataclass +class LlmRetryOpts: + max_attempts: Optional[int] = 10 + """Max number of attempts (including the initial), before giving up. + + When giving up, the LLM call will throw a `TerminalError` wrapping the original error message.""" + max_duration: Optional[timedelta] = None + """Max duration of retries, before giving up. + + When giving up, the LLM call will throw a `TerminalError` wrapping the original error message.""" + initial_retry_interval: Optional[timedelta] = timedelta(seconds=1) + """Initial interval for the first retry attempt. + Retry interval will grow by a factor specified in `retry_interval_factor`. + + If any of the other retry related fields is specified, the default for this field is 50 milliseconds, otherwise restate will fallback to the overall invocation retry policy.""" + max_retry_interval: Optional[timedelta] = None + """Max interval between retries. + Retry interval will grow by a factor specified in `retry_interval_factor`. + + The default is 10 seconds.""" + retry_interval_factor: Optional[float] = None + """Exponentiation factor to use when computing the next retry delay. + + If any of the other retry related fields is specified, the default for this field is `2`, meaning retry interval will double at each attempt, otherwise restate will fallback to the overall invocation retry policy.""" + + +# The OpenAI ModelResponse class is a dataclass with Pydantic fields. +# The Restate SDK cannot serialize this. So we turn the ModelResponse into a Pydantic model. 
+class RestateModelResponse(BaseModel): + output: list[TResponseOutputItem] + """A list of outputs (messages, tool calls, etc) generated by the model""" + + usage: Usage + """The usage information for the response.""" + + response_id: str | None + """An ID for the response which can be used to refer to the response in subsequent calls to the + model. Not supported by all model providers. + If using OpenAI models via the Responses API, this is the `response_id` parameter, and it can + be passed to `Runner.run`. + """ + + def to_input_items(self) -> list[TResponseInputItem]: + return [it.model_dump(exclude_unset=True) for it in self.output] # type: ignore diff --git a/python/restate/ext/openai/runner_wrapper.py b/python/restate/ext/openai/runner_wrapper.py index 572b819..6235125 100644 --- a/python/restate/ext/openai/runner_wrapper.py +++ b/python/restate/ext/openai/runner_wrapper.py @@ -1,5 +1,5 @@ # -# Copyright (c) 2023-2024 - Restate Software, Inc., Restate GmbH +# Copyright (c) 2023-2025 - Restate Software, Inc., Restate GmbH # # This file is part of the Restate SDK for Python, # which is released under the MIT license. 
@@ -15,7 +15,6 @@ import dataclasses from agents import ( - Usage, Model, RunContextWrapper, AgentsException, @@ -24,66 +23,21 @@ RunResult, Agent, ModelBehaviorError, - ModelSettings, Runner, ) from agents.models.multi_provider import MultiProvider -from agents.items import TResponseStreamEvent, TResponseOutputItem, ModelResponse -from agents.memory.session import SessionABC +from agents.items import TResponseStreamEvent, ModelResponse from agents.items import TResponseInputItem -from datetime import timedelta -from typing import List, Any, AsyncIterator, Optional, cast -from pydantic import BaseModel +from typing import Any, AsyncIterator from restate.exceptions import SdkInternalBaseException +from restate.ext.turnstile import Turnstile from restate.extensions import current_context -from restate import RunOptions, ObjectContext, TerminalError +from restate import RunOptions, TerminalError - -@dataclasses.dataclass -class LlmRetryOpts: - max_attempts: Optional[int] = 10 - """Max number of attempts (including the initial), before giving up. - - When giving up, the LLM call will throw a `TerminalError` wrapping the original error message.""" - max_duration: Optional[timedelta] = None - """Max duration of retries, before giving up. - - When giving up, the LLM call will throw a `TerminalError` wrapping the original error message.""" - initial_retry_interval: Optional[timedelta] = timedelta(seconds=1) - """Initial interval for the first retry attempt. - Retry interval will grow by a factor specified in `retry_interval_factor`. - - If any of the other retry related fields is specified, the default for this field is 50 milliseconds, otherwise restate will fallback to the overall invocation retry policy.""" - max_retry_interval: Optional[timedelta] = None - """Max interval between retries. - Retry interval will grow by a factor specified in `retry_interval_factor`. 
- - The default is 10 seconds.""" - retry_interval_factor: Optional[float] = None - """Exponentiation factor to use when computing the next retry delay. - - If any of the other retry related fields is specified, the default for this field is `2`, meaning retry interval will double at each attempt, otherwise restate will fallback to the overall invocation retry policy.""" - - -# The OpenAI ModelResponse class is a dataclass with Pydantic fields. -# The Restate SDK cannot serialize this. So we turn the ModelResponse int a Pydantic model. -class RestateModelResponse(BaseModel): - output: list[TResponseOutputItem] - """A list of outputs (messages, tool calls, etc) generated by the model""" - - usage: Usage - """The usage information for the response.""" - - response_id: str | None - """An ID for the response which can be used to refer to the response in subsequent calls to the - model. Not supported by all model providers. - If using OpenAI models via the Responses API, this is the `response_id` parameter, and it can - be passed to `Runner.run`. - """ - - def to_input_items(self) -> list[TResponseInputItem]: - return [it.model_dump(exclude_unset=True) for it in self.output] # type: ignore +from .functions import get_function_call_ids, wrap_agent_tools +from .models import LlmRetryOpts, RestateModelResponse, State +from .session import RestateSession class DurableModelCalls(MultiProvider): @@ -91,12 +45,14 @@ class DurableModelCalls(MultiProvider): A Restate model provider that wraps the OpenAI SDK's default MultiProvider. 
""" - def __init__(self, llm_retry_opts: LlmRetryOpts | None = None): + def __init__(self, state: State, llm_retry_opts: LlmRetryOpts | None = None): super().__init__() self.llm_retry_opts = llm_retry_opts + self.state = state def get_model(self, model_name: str | None) -> Model: - return RestateModelWrapper(super().get_model(model_name or None), self.llm_retry_opts) + model = super().get_model(model_name or None) + return RestateModelWrapper(model, self.state, self.llm_retry_opts) class RestateModelWrapper(Model): @@ -104,8 +60,9 @@ class RestateModelWrapper(Model): A wrapper around the OpenAI SDK's Model that persists LLM calls in the Restate journal. """ - def __init__(self, model: Model, llm_retry_opts: LlmRetryOpts | None = None): + def __init__(self, model: Model, state: State, llm_retry_opts: LlmRetryOpts | None = None): self.model = model + self.state = state self.model_name = "RestateModelWrapper" self.llm_retry_opts = llm_retry_opts if llm_retry_opts is not None else LlmRetryOpts() @@ -133,6 +90,10 @@ async def call_llm() -> RestateModelResponse: retry_interval_factor=self.llm_retry_opts.retry_interval_factor, ), ) + # collect function call IDs, to order the upcoming tool calls through a new turnstile + ids = get_function_call_ids(result.output) + self.state.turnstile = Turnstile(ids) + # convert back to original ModelResponse return ModelResponse( output=result.output, @@ -144,47 +105,6 @@ def stream_response(self, *args, **kwargs) -> AsyncIterator[TResponseStreamEvent raise TerminalError("Streaming is not supported in Restate. 
Use `get_response` instead.") -class RestateSession(SessionABC): - """Restate session implementation following the Session protocol.""" - - def __init__(self): - self._items: List[TResponseInputItem] | None = None - - def _ctx(self) -> ObjectContext: - return cast(ObjectContext, current_context()) - - async def get_items(self, limit: int | None = None) -> List[TResponseInputItem]: - """Retrieve conversation history for this session.""" - if self._items is None: - self._items = await self._ctx().get("items") or [] - if limit is not None: - return self._items[-limit:] - return self._items.copy() - - async def add_items(self, items: List[TResponseInputItem]) -> None: - """Store new items for this session.""" - if self._items is None: - self._items = await self._ctx().get("items") or [] - self._items.extend(items) - - async def pop_item(self) -> TResponseInputItem | None: - """Remove and return the most recent item from this session.""" - if self._items is None: - self._items = await self._ctx().get("items") or [] - if self._items: - return self._items.pop() - return None - - def flush(self) -> None: - """Flush the session items to the context.""" - self._ctx().set("items", self._items) - - async def clear_session(self) -> None: - """Clear all items for this session.""" - self._items = [] - self._ctx().clear("items") - - class AgentsTerminalException(AgentsException, TerminalError): """Exception that is both an AgentsException and a restate.TerminalError.""" @@ -255,24 +175,13 @@ async def run( The result from Runner.run """ + # execution state + state = State() + # Set persisting model calls llm_retry_opts = kwargs.pop("llm_retry_opts", None) run_config = kwargs.pop("run_config", RunConfig()) - run_config = dataclasses.replace(run_config, model_provider=DurableModelCalls(llm_retry_opts)) - - # Disable parallel tool calls - model_settings = run_config.model_settings - if model_settings is None: - model_settings = ModelSettings(parallel_tool_calls=False) - else: - 
model_settings = dataclasses.replace( - model_settings, - parallel_tool_calls=False, - ) - run_config = dataclasses.replace( - run_config, - model_settings=model_settings, - ) + run_config = dataclasses.replace(run_config, model_provider=DurableModelCalls(state, llm_retry_opts)) # Use Restate session if requested, otherwise use provided session session = kwargs.pop("session", None) @@ -281,9 +190,10 @@ async def run( raise TerminalError("When use_restate_session is True, session config cannot be provided.") session = RestateSession() + agent = wrap_agent_tools(starting_agent, state) try: result = await Runner.run( - starting_agent=starting_agent, input=input, run_config=run_config, session=session, **kwargs + starting_agent=agent, input=input, run_config=run_config, session=session, **kwargs ) finally: # Flush session items to Restate diff --git a/python/restate/ext/openai/session.py b/python/restate/ext/openai/session.py new file mode 100644 index 0000000..f4aefd7 --- /dev/null +++ b/python/restate/ext/openai/session.py @@ -0,0 +1,61 @@ +# +# Copyright (c) 2023-2025 - Restate Software, Inc., Restate GmbH +# +# This file is part of the Restate SDK for Python, +# which is released under the MIT license. +# +# You can find a copy of the license in file LICENSE in the root +# directory of this repository or package, or at +# https://github.com/restatedev/sdk-typescript/blob/main/LICENSE +# +""" +This module contains the optional OpenAI integration for Restate. 
+""" + +from agents.memory.session import SessionABC +from agents.items import TResponseInputItem +from typing import List, cast + +from restate.extensions import current_context +from restate import ObjectContext + + +class RestateSession(SessionABC): + """Restate session implementation following the Session protocol.""" + + def __init__(self): + self._items: List[TResponseInputItem] | None = None + + def _ctx(self) -> ObjectContext: + return cast(ObjectContext, current_context()) + + async def get_items(self, limit: int | None = None) -> List[TResponseInputItem]: + """Retrieve conversation history for this session.""" + if self._items is None: + self._items = await self._ctx().get("items") or [] + if limit is not None: + return self._items[-limit:] + return self._items.copy() + + async def add_items(self, items: List[TResponseInputItem]) -> None: + """Store new items for this session.""" + if self._items is None: + self._items = await self._ctx().get("items") or [] + self._items.extend(items) + + async def pop_item(self) -> TResponseInputItem | None: + """Remove and return the most recent item from this session.""" + if self._items is None: + self._items = await self._ctx().get("items") or [] + if self._items: + return self._items.pop() + return None + + def flush(self) -> None: + """Flush the session items to the context.""" + self._ctx().set("items", self._items) + + async def clear_session(self) -> None: + """Clear all items for this session.""" + self._items = [] + self._ctx().clear("items") diff --git a/python/restate/ext/turnstile.py b/python/restate/ext/turnstile.py new file mode 100644 index 0000000..a2f6104 --- /dev/null +++ b/python/restate/ext/turnstile.py @@ -0,0 +1,31 @@ +from asyncio import Event + + +class Turnstile: + """A turnstile to manage ordered access based on IDs.""" + + def __init__(self, ids: list[str]): + # ordered mapping of id to next id in the sequence + # for example: + # {'id1': 'id2', 'id2': 'id3'} <-- id3 is the last. 
+ # {} <-- no ids, no turns. + # {} <-- single id, no next turn. + # + self.turns = dict(zip(ids, ids[1:])) + # mapping of id to event that signals when that id's turn is allowed + self.events = {id: Event() for id in ids} + if ids: + # make sure that the first id can proceed immediately + event = self.events[ids[0]] + event.set() + + async def wait_for(self, id: str) -> None: + event = self.events[id] + await event.wait() + + def allow_next_after(self, id: str) -> None: + next_id = self.turns.get(id) + if next_id is None: + return + next_event = self.events[next_id] + next_event.set()