Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions python/restate/ext/adk/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,9 @@ async def release_task():
await event.wait()
finally:
self._models.pop(id, None)
self._turnstiles.pop(id, None)
maybe_turnstile = self._turnstiles.pop(id, None)
if maybe_turnstile is not None:
maybe_turnstile.cancel_all()

_ = asyncio.create_task(release_task())
return None
Expand Down Expand Up @@ -161,7 +163,7 @@ async def on_tool_error_callback(
assert turnstile is not None, "Turnstile not found for tool invocation."
id = tool_context.function_call_id
assert id is not None, "Function call ID is required for tool invocation."
turnstile.allow_next_after(id)
turnstile.cancel_all_after(id)
return None

async def close(self):
Expand Down
6 changes: 4 additions & 2 deletions python/restate/ext/openai/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,9 @@
from restate import ObjectContext, Context
from restate.server_context import current_context

from .runner_wrapper import DurableRunner, continue_on_terminal_errors, raise_terminal_errors
from .runner_wrapper import DurableRunner
from .models import LlmRetryOpts
from .functions import propagate_cancellation, raise_terminal_errors, durable_function_tool


def restate_object_context() -> ObjectContext:
Expand All @@ -42,6 +43,7 @@ def restate_context() -> Context:
"LlmRetryOpts",
"restate_object_context",
"restate_context",
"continue_on_terminal_errors",
"raise_terminal_errors",
"propagate_cancellation",
"durable_function_tool",
]
134 changes: 128 additions & 6 deletions python/restate/ext/openai/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,20 +9,132 @@
# https://github.com/restatedev/sdk-typescript/blob/main/LICENSE
#

from typing import List, Any
from typing import List, Any, overload, Callable, Union, TypeVar, Awaitable
import dataclasses

from agents import (
Handoff,
TContext,
Agent,
RunContextWrapper,
ModelBehaviorError,
AgentBase,
)

from agents.tool import FunctionTool, Tool
from agents.function_schema import DocstringStyle

from agents.tool import (
FunctionTool,
Tool,
ToolFunction,
ToolErrorFunction,
function_tool as oai_function_tool,
default_tool_error_function,
)
from agents.tool_context import ToolContext
from agents.items import TResponseOutputItem

from .models import State
from restate import TerminalError

from .models import State, AgentsTerminalException

T = TypeVar("T")

MaybeAwaitable = Union[Awaitable[T], T]


def raise_terminal_errors(context: RunContextWrapper[Any], error: Exception) -> str:
"""A custom function to provide a user-friendly error message."""
# Raise terminal errors and cancellations
if isinstance(error, TerminalError):
# For the agent SDK it needs to be an AgentsException, for restate it needs to be a TerminalError
# so we create a new exception that inherits from both
raise AgentsTerminalException(error.message)

if isinstance(error, ModelBehaviorError):
return f"An error occurred while calling the tool: {str(error)}"

raise error


def propagate_cancellation(failure_error_function: ToolErrorFunction | None = None) -> ToolErrorFunction:
_fn = failure_error_function if failure_error_function is not None else default_tool_error_function

def inner(context: RunContextWrapper[Any], error: Exception):
"""Raise cancellations as exceptions."""
if isinstance(error, TerminalError):
if error.status_code == 409:
raise error from None
return _fn(context, error)

return inner


@overload
def durable_function_tool(
func: ToolFunction[...],
*,
name_override: str | None = None,
description_override: str | None = None,
docstring_style: DocstringStyle | None = None,
use_docstring_info: bool = True,
failure_error_function: ToolErrorFunction | None = None,
strict_mode: bool = True,
is_enabled: bool | Callable[[RunContextWrapper[Any], AgentBase], MaybeAwaitable[bool]] = True,
) -> FunctionTool:
"""Overload for usage as @function_tool (no parentheses)."""
...


@overload
def durable_function_tool(
*,
name_override: str | None = None,
description_override: str | None = None,
docstring_style: DocstringStyle | None = None,
use_docstring_info: bool = True,
failure_error_function: ToolErrorFunction | None = None,
strict_mode: bool = True,
is_enabled: bool | Callable[[RunContextWrapper[Any], AgentBase], MaybeAwaitable[bool]] = True,
) -> Callable[[ToolFunction[...]], FunctionTool]:
"""Overload for usage as @function_tool(...)."""
...


def durable_function_tool(
func: ToolFunction[...] | None = None,
*,
name_override: str | None = None,
description_override: str | None = None,
docstring_style: DocstringStyle | None = None,
use_docstring_info: bool = True,
failure_error_function: ToolErrorFunction | None = raise_terminal_errors,
strict_mode: bool = True,
is_enabled: bool | Callable[[RunContextWrapper[Any], AgentBase], MaybeAwaitable[bool]] = True,
) -> FunctionTool | Callable[[ToolFunction[...]], FunctionTool]:
failure_fn = propagate_cancellation(failure_error_function)

if callable(func):
return oai_function_tool(
func=func,
name_override=name_override,
description_override=description_override,
docstring_style=docstring_style,
use_docstring_info=use_docstring_info,
failure_error_function=failure_fn,
strict_mode=strict_mode,
is_enabled=is_enabled,
)
else:
return oai_function_tool(
name_override=name_override,
description_override=description_override,
docstring_style=docstring_style,
use_docstring_info=use_docstring_info,
failure_error_function=failure_fn,
strict_mode=strict_mode,
is_enabled=is_enabled,
)


def get_function_call_ids(response: list[TResponseOutputItem]) -> List[str]:
Expand All @@ -35,11 +147,21 @@ def _create_wrapper(state, captured_tool):
async def on_invoke_tool_wrapper(tool_context: ToolContext[Any], tool_input: Any) -> Any:
turnstile = state.turnstile
call_id = tool_context.tool_call_id
# wait for our turn
await turnstile.wait_for(call_id)
try:
await turnstile.wait_for(call_id)
return await captured_tool.on_invoke_tool(tool_context, tool_input)
finally:
# invoke the original tool
res = await captured_tool.on_invoke_tool(tool_context, tool_input)
# allow the next tool to proceed
turnstile.allow_next_after(call_id)
return res
except BaseException as ex:
# if there was an error, it will be propagated up, towards the handler
# but we need to make sure that all subsequent tools will not execute
# as they might interact with the restate context.
turnstile.cancel_all_after(call_id)
# re-raise the exception
raise ex from None

return on_invoke_tool_wrapper

Expand Down
9 changes: 9 additions & 0 deletions python/restate/ext/openai/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

from agents import (
Usage,
AgentsException,
)
from agents.items import TResponseOutputItem
from agents.items import TResponseInputItem
Expand All @@ -24,6 +25,7 @@
from pydantic import BaseModel

from restate.ext.turnstile import Turnstile
from restate import TerminalError


class State:
Expand Down Expand Up @@ -77,3 +79,10 @@ class RestateModelResponse(BaseModel):

def to_input_items(self) -> list[TResponseInputItem]:
return [it.model_dump(exclude_unset=True) for it in self.output] # type: ignore


class AgentsTerminalException(AgentsException, TerminalError):
"""Exception that is both an AgentsException and a restate.TerminalError."""

def __init__(self, *args: object) -> None:
super().__init__(*args)
48 changes: 1 addition & 47 deletions python/restate/ext/openai/runner_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,21 +16,17 @@

from agents import (
Model,
RunContextWrapper,
AgentsException,
RunConfig,
TContext,
RunResult,
Agent,
ModelBehaviorError,
Runner,
)
from agents.models.multi_provider import MultiProvider
from agents.items import TResponseStreamEvent, ModelResponse
from agents.items import TResponseInputItem
from typing import Any, AsyncIterator
from typing import AsyncIterator

from restate.exceptions import SdkInternalBaseException
from restate.ext.turnstile import Turnstile
from restate.extensions import current_context
from restate import RunOptions, TerminalError
Expand Down Expand Up @@ -105,48 +101,6 @@ def stream_response(self, *args, **kwargs) -> AsyncIterator[TResponseStreamEvent
raise TerminalError("Streaming is not supported in Restate. Use `get_response` instead.")


class AgentsTerminalException(AgentsException, TerminalError):
"""Exception that is both an AgentsException and a restate.TerminalError."""

def __init__(self, *args: object) -> None:
super().__init__(*args)


class AgentsSuspension(AgentsException, SdkInternalBaseException):
"""Exception that is both an AgentsException and a restate SdkInternalBaseException."""

def __init__(self, *args: object) -> None:
super().__init__(*args)


def raise_terminal_errors(context: RunContextWrapper[Any], error: Exception) -> str:
"""A custom function to provide a user-friendly error message."""
# Raise terminal errors and cancellations
if isinstance(error, TerminalError):
# For the agent SDK it needs to be an AgentsException, for restate it needs to be a TerminalError
# so we create a new exception that inherits from both
raise AgentsTerminalException(error.message)

if isinstance(error, ModelBehaviorError):
return f"An error occurred while calling the tool: {str(error)}"

raise error


def continue_on_terminal_errors(context: RunContextWrapper[Any], error: Exception) -> str:
"""A custom function to provide a user-friendly error message."""
# Raise terminal errors and cancellations
if isinstance(error, TerminalError):
# For the agent SDK it needs to be an AgentsException, for restate it needs to be a TerminalError
# so we create a new exception that inherits from both
return f"An error occurred while running the tool: {str(error)}"

if isinstance(error, ModelBehaviorError):
return f"An error occurred while calling the tool: {str(error)}"

raise error


class DurableRunner:
"""
A wrapper around Runner.run that automatically configures RunConfig for Restate contexts.
Expand Down
17 changes: 17 additions & 0 deletions python/restate/ext/turnstile.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from asyncio import Event
from restate.exceptions import SdkInternalException


class Turnstile:
Expand All @@ -14,6 +15,7 @@ def __init__(self, ids: list[str]):
self.turns = dict(zip(ids, ids[1:]))
# mapping of id to event that signals when that id's turn is allowed
self.events = {id: Event() for id in ids}
self.canceled = False
if ids:
# make sure that the first id can proceed immediately
event = self.events[ids[0]]
Expand All @@ -22,10 +24,25 @@ def __init__(self, ids: list[str]):
async def wait_for(self, id: str) -> None:
event = self.events[id]
await event.wait()
if self.canceled:
raise SdkInternalException() from None

def cancel_all_after(self, id: str) -> None:
self.canceled = True
next_id = self.turns.get(id)
while next_id is not None:
next_event = self.events[next_id]
next_event.set()
next_id = self.turns.get(next_id)

def allow_next_after(self, id: str) -> None:
next_id = self.turns.get(id)
if next_id is None:
return
next_event = self.events[next_id]
next_event.set()

def cancel_all(self) -> None:
self.canceled = True
for event in self.events.values():
event.set()