diff --git a/python/restate/ext/adk/plugin.py b/python/restate/ext/adk/plugin.py index d9d9c82..65bdbc2 100644 --- a/python/restate/ext/adk/plugin.py +++ b/python/restate/ext/adk/plugin.py @@ -81,7 +81,9 @@ async def release_task(): await event.wait() finally: self._models.pop(id, None) - self._turnstiles.pop(id, None) + maybe_turnstile = self._turnstiles.pop(id, None) + if maybe_turnstile is not None: + maybe_turnstile.cancel_all() _ = asyncio.create_task(release_task()) return None @@ -161,7 +163,7 @@ async def on_tool_error_callback( assert turnstile is not None, "Turnstile not found for tool invocation." id = tool_context.function_call_id assert id is not None, "Function call ID is required for tool invocation." - turnstile.allow_next_after(id) + turnstile.cancel_all_after(id) return None async def close(self): diff --git a/python/restate/ext/openai/__init__.py b/python/restate/ext/openai/__init__.py index a13ac9f..94e6d0b 100644 --- a/python/restate/ext/openai/__init__.py +++ b/python/restate/ext/openai/__init__.py @@ -17,8 +17,9 @@ from restate import ObjectContext, Context from restate.server_context import current_context -from .runner_wrapper import DurableRunner, continue_on_terminal_errors, raise_terminal_errors +from .runner_wrapper import DurableRunner from .models import LlmRetryOpts +from .functions import propagate_cancellation, raise_terminal_errors, durable_function_tool def restate_object_context() -> ObjectContext: @@ -42,6 +43,7 @@ def restate_context() -> Context: "LlmRetryOpts", "restate_object_context", "restate_context", - "continue_on_terminal_errors", "raise_terminal_errors", + "propagate_cancellation", + "durable_function_tool", ] diff --git a/python/restate/ext/openai/functions.py b/python/restate/ext/openai/functions.py index 6fee0ce..621f5c6 100644 --- a/python/restate/ext/openai/functions.py +++ b/python/restate/ext/openai/functions.py @@ -9,20 +9,132 @@ # https://github.com/restatedev/sdk-typescript/blob/main/LICENSE # -from typing import List, Any +from typing import List, Any, overload, Callable, Union, TypeVar, Awaitable import dataclasses from agents import ( Handoff, TContext, Agent, + RunContextWrapper, + ModelBehaviorError, + AgentBase, ) -from agents.tool import FunctionTool, Tool +from agents.function_schema import DocstringStyle + +from agents.tool import ( + FunctionTool, + Tool, + ToolFunction, + ToolErrorFunction, + function_tool as oai_function_tool, + default_tool_error_function, +) from agents.tool_context import ToolContext from agents.items import TResponseOutputItem -from .models import State +from restate import TerminalError + +from .models import State, AgentsTerminalException + +T = TypeVar("T") + +MaybeAwaitable = Union[Awaitable[T], T] + + +def raise_terminal_errors(context: RunContextWrapper[Any], error: Exception) -> str: + """A custom function to provide a user-friendly error message.""" + # Raise terminal errors and cancellations + if isinstance(error, TerminalError): + # For the agent SDK it needs to be an AgentsException, for restate it needs to be a TerminalError + # so we create a new exception that inherits from both + raise AgentsTerminalException(error.message) + + if isinstance(error, ModelBehaviorError): + return f"An error occurred while calling the tool: {str(error)}" + + raise error + + +def propagate_cancellation(failure_error_function: ToolErrorFunction | None = None) -> ToolErrorFunction: + _fn = failure_error_function if failure_error_function is not None else default_tool_error_function + + def inner(context: RunContextWrapper[Any], error: Exception): + """Raise cancellations as exceptions.""" + if isinstance(error, TerminalError): + if error.status_code == 409: + raise error from None + return _fn(context, error) + + return inner + + +@overload +def durable_function_tool( + func: ToolFunction[...], + *, + name_override: str | None = None, + description_override: str | None = None, + docstring_style: DocstringStyle | None = None, + use_docstring_info: bool = True, + failure_error_function: ToolErrorFunction | None = None, + strict_mode: bool = True, + is_enabled: bool | Callable[[RunContextWrapper[Any], AgentBase], MaybeAwaitable[bool]] = True, +) -> FunctionTool: + """Overload for usage as @function_tool (no parentheses).""" + ... + + +@overload +def durable_function_tool( + *, + name_override: str | None = None, + description_override: str | None = None, + docstring_style: DocstringStyle | None = None, + use_docstring_info: bool = True, + failure_error_function: ToolErrorFunction | None = None, + strict_mode: bool = True, + is_enabled: bool | Callable[[RunContextWrapper[Any], AgentBase], MaybeAwaitable[bool]] = True, +) -> Callable[[ToolFunction[...]], FunctionTool]: + """Overload for usage as @function_tool(...).""" + ... + + +def durable_function_tool( + func: ToolFunction[...] | None = None, + *, + name_override: str | None = None, + description_override: str | None = None, + docstring_style: DocstringStyle | None = None, + use_docstring_info: bool = True, + failure_error_function: ToolErrorFunction | None = raise_terminal_errors, + strict_mode: bool = True, + is_enabled: bool | Callable[[RunContextWrapper[Any], AgentBase], MaybeAwaitable[bool]] = True, +) -> FunctionTool | Callable[[ToolFunction[...]], FunctionTool]: + failure_fn = propagate_cancellation(failure_error_function) + + if callable(func): + return oai_function_tool( + func=func, + name_override=name_override, + description_override=description_override, + docstring_style=docstring_style, + use_docstring_info=use_docstring_info, + failure_error_function=failure_fn, + strict_mode=strict_mode, + is_enabled=is_enabled, + ) + else: + return oai_function_tool( + name_override=name_override, + description_override=description_override, + docstring_style=docstring_style, + use_docstring_info=use_docstring_info, + failure_error_function=failure_fn, + strict_mode=strict_mode, + is_enabled=is_enabled, + ) def get_function_call_ids(response: list[TResponseOutputItem]) -> List[str]: @@ -35,11 +147,21 @@ def _create_wrapper(state, captured_tool): async def on_invoke_tool_wrapper(tool_context: ToolContext[Any], tool_input: Any) -> Any: turnstile = state.turnstile call_id = tool_context.tool_call_id + # wait for our turn + await turnstile.wait_for(call_id) try: - await turnstile.wait_for(call_id) - return await captured_tool.on_invoke_tool(tool_context, tool_input) - finally: + # invoke the original tool + res = await captured_tool.on_invoke_tool(tool_context, tool_input) + # allow the next tool to proceed turnstile.allow_next_after(call_id) + return res + except BaseException as ex: + # if there was an error, it will be propagated up, towards the handler + # but we need to make sure that all subsequent tools will not execute + # as they might interact with the restate context. + turnstile.cancel_all_after(call_id) + # re-raise the exception + raise ex from None return on_invoke_tool_wrapper diff --git a/python/restate/ext/openai/models.py b/python/restate/ext/openai/models.py index c48cf88..650c03a 100644 --- a/python/restate/ext/openai/models.py +++ b/python/restate/ext/openai/models.py @@ -16,6 +16,7 @@ from agents import ( Usage, + AgentsException, ) from agents.items import TResponseOutputItem from agents.items import TResponseInputItem @@ -24,6 +25,7 @@ from pydantic import BaseModel from restate.ext.turnstile import Turnstile +from restate import TerminalError class State: @@ -77,3 +79,10 @@ class RestateModelResponse(BaseModel): def to_input_items(self) -> list[TResponseInputItem]: return [it.model_dump(exclude_unset=True) for it in self.output] # type: ignore + + +class AgentsTerminalException(AgentsException, TerminalError): + """Exception that is both an AgentsException and a restate.TerminalError.""" + + def __init__(self, *args: object) -> None: + super().__init__(*args) diff --git a/python/restate/ext/openai/runner_wrapper.py b/python/restate/ext/openai/runner_wrapper.py index 6235125..d93f6db 100644 --- a/python/restate/ext/openai/runner_wrapper.py +++ b/python/restate/ext/openai/runner_wrapper.py @@ -16,21 +16,17 @@ from agents import ( Model, - RunContextWrapper, - AgentsException, RunConfig, TContext, RunResult, Agent, - ModelBehaviorError, Runner, ) from agents.models.multi_provider import MultiProvider from agents.items import TResponseStreamEvent, ModelResponse from agents.items import TResponseInputItem -from typing import Any, AsyncIterator +from typing import AsyncIterator -from restate.exceptions import SdkInternalBaseException from restate.ext.turnstile import Turnstile from restate.extensions import current_context from restate import RunOptions, TerminalError @@ -105,48 +101,6 @@ def stream_response(self, *args, **kwargs) -> AsyncIterator[TResponseStreamEvent raise TerminalError("Streaming is not supported in Restate. Use `get_response` instead.") -class AgentsTerminalException(AgentsException, TerminalError): - """Exception that is both an AgentsException and a restate.TerminalError.""" - - def __init__(self, *args: object) -> None: - super().__init__(*args) - - -class AgentsSuspension(AgentsException, SdkInternalBaseException): - """Exception that is both an AgentsException and a restate SdkInternalBaseException.""" - - def __init__(self, *args: object) -> None: - super().__init__(*args) - - -def raise_terminal_errors(context: RunContextWrapper[Any], error: Exception) -> str: - """A custom function to provide a user-friendly error message.""" - # Raise terminal errors and cancellations - if isinstance(error, TerminalError): - # For the agent SDK it needs to be an AgentsException, for restate it needs to be a TerminalError - # so we create a new exception that inherits from both - raise AgentsTerminalException(error.message) - - if isinstance(error, ModelBehaviorError): - return f"An error occurred while calling the tool: {str(error)}" - - raise error - - -def continue_on_terminal_errors(context: RunContextWrapper[Any], error: Exception) -> str: - """A custom function to provide a user-friendly error message.""" - # Raise terminal errors and cancellations - if isinstance(error, TerminalError): - # For the agent SDK it needs to be an AgentsException, for restate it needs to be a TerminalError - # so we create a new exception that inherits from both - return f"An error occurred while running the tool: {str(error)}" - - if isinstance(error, ModelBehaviorError): - return f"An error occurred while calling the tool: {str(error)}" - - raise error - - class DurableRunner: """ A wrapper around Runner.run that automatically configures RunConfig for Restate contexts. diff --git a/python/restate/ext/turnstile.py b/python/restate/ext/turnstile.py index a2f6104..08981c6 100644 --- a/python/restate/ext/turnstile.py +++ b/python/restate/ext/turnstile.py @@ -1,4 +1,5 @@ from asyncio import Event +from restate.exceptions import SdkInternalException class Turnstile: @@ -14,6 +15,7 @@ def __init__(self, ids: list[str]): self.turns = dict(zip(ids, ids[1:])) # mapping of id to event that signals when that id's turn is allowed self.events = {id: Event() for id in ids} + self.canceled = False if ids: # make sure that the first id can proceed immediately event = self.events[ids[0]] @@ -22,6 +24,16 @@ def __init__(self, ids: list[str]): async def wait_for(self, id: str) -> None: event = self.events[id] await event.wait() + if self.canceled: + raise SdkInternalException() from None + + def cancel_all_after(self, id: str) -> None: + self.canceled = True + next_id = self.turns.get(id) + while next_id is not None: + next_event = self.events[next_id] + next_event.set() + next_id = self.turns.get(next_id) def allow_next_after(self, id: str) -> None: next_id = self.turns.get(id) @@ -29,3 +41,8 @@ def allow_next_after(self, id: str) -> None: return next_event = self.events[next_id] next_event.set() + + def cancel_all(self) -> None: + self.canceled = True + for event in self.events.values(): + event.set()