diff --git a/python/restate/ext/openai/__init__.py b/python/restate/ext/openai/__init__.py
index e2c431c..969d42f 100644
--- a/python/restate/ext/openai/__init__.py
+++ b/python/restate/ext/openai/__init__.py
@@ -12,11 +12,34 @@
 This module contains the optional OpenAI integration for Restate.
 """
 
-from .runner_wrapper import Runner, DurableModelCalls, continue_on_terminal_errors, raise_terminal_errors
+import typing
+
+from .runner_wrapper import DurableRunner, continue_on_terminal_errors, raise_terminal_errors, LlmRetryOpts
+from restate import ObjectContext, Context
+from restate.server_context import current_context
+
+
+def restate_object_context() -> ObjectContext:
+    """Get the current Restate ObjectContext."""
+    ctx = current_context()
+    if ctx is None:
+        raise RuntimeError("No Restate context found.")
+    return typing.cast(ObjectContext, ctx)
+
+
+def restate_context() -> Context:
+    """Get the current Restate Context."""
+    ctx = current_context()
+    if ctx is None:
+        raise RuntimeError("No Restate context found.")
+    return ctx
+
 
 __all__ = [
-    "DurableModelCalls",
+    "DurableRunner",
+    "LlmRetryOpts",
+    "restate_object_context",
+    "restate_context",
    "continue_on_terminal_errors",
    "raise_terminal_errors",
-    "Runner",
 ]
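For illustration, a minimal sketch of how the new context accessors might be used from agent code running inside a Restate handler. The tool function and its name are hypothetical, not part of this diff:

```python
from restate.ext.openai import restate_context


async def lookup_weather(city: str) -> str:
    ctx = restate_context()  # raises RuntimeError outside a Restate handler

    async def fetch() -> str:
        return f"Sunny in {city}"  # imagine a real HTTP call here

    # Journal the side effect so a retry replays the recorded result.
    return await ctx.run_typed("lookup weather", fetch)
```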
diff --git a/python/restate/ext/openai/runner_wrapper.py b/python/restate/ext/openai/runner_wrapper.py
index c9e9f5a..572b819 100644
--- a/python/restate/ext/openai/runner_wrapper.py
+++ b/python/restate/ext/openai/runner_wrapper.py
@@ -12,40 +12,60 @@
 This module contains the optional OpenAI integration for Restate.
 """
 
-import asyncio
 import dataclasses
-import typing
 
 from agents import (
-    Tool,
     Usage,
     Model,
     RunContextWrapper,
     AgentsException,
-    Runner as OpenAIRunner,
     RunConfig,
     TContext,
     RunResult,
     Agent,
     ModelBehaviorError,
+    ModelSettings,
+    Runner,
 )
-
 from agents.models.multi_provider import MultiProvider
 from agents.items import TResponseStreamEvent, TResponseOutputItem, ModelResponse
 from agents.memory.session import SessionABC
 from agents.items import TResponseInputItem
-from typing import List, Any
-from typing import AsyncIterator
-
-from agents.tool import FunctionTool
-from agents.tool_context import ToolContext
+from datetime import timedelta
+from typing import List, Any, AsyncIterator, Optional, cast
 
 from pydantic import BaseModel
+
 from restate.exceptions import SdkInternalBaseException
 from restate.extensions import current_context
-
 from restate import RunOptions, ObjectContext, TerminalError
 
+@dataclasses.dataclass
+class LlmRetryOpts:
+    max_attempts: Optional[int] = 10
+    """Max number of attempts (including the initial one) before giving up.
+
+    When giving up, the LLM call throws a `TerminalError` wrapping the original error message."""
+    max_duration: Optional[timedelta] = None
+    """Max duration of retries before giving up.
+
+    When giving up, the LLM call throws a `TerminalError` wrapping the original error message."""
+    initial_retry_interval: Optional[timedelta] = timedelta(seconds=1)
+    """Initial interval for the first retry attempt. Defaults to 1 second.
+
+    The retry interval grows by the factor specified in `retry_interval_factor`."""
+    max_retry_interval: Optional[timedelta] = None
+    """Max interval between retries, as the interval grows by `retry_interval_factor`.
+
+    If unset, the default is 10 seconds."""
+    retry_interval_factor: Optional[float] = None
+    """Exponentiation factor to use when computing the next retry delay.
+
+    If any of the other retry-related fields is specified, the default for this field is `2`,
+    meaning the retry interval doubles at each attempt; otherwise Restate falls back to the
+    overall invocation retry policy."""
+
+
 # The OpenAI ModelResponse class is a dataclass with Pydantic fields.
 # The Restate SDK cannot serialize this. So we turn the ModelResponse into a Pydantic model.
 class RestateModelResponse(BaseModel):
@@ -71,12 +91,12 @@ class DurableModelCalls(MultiProvider):
     A Restate model provider that wraps the OpenAI SDK's default MultiProvider.
     """
 
-    def __init__(self, max_retries: int | None = 3):
+    def __init__(self, llm_retry_opts: LlmRetryOpts | None = None):
         super().__init__()
-        self.max_retries = max_retries
+        self.llm_retry_opts = llm_retry_opts
 
     def get_model(self, model_name: str | None) -> Model:
-        return RestateModelWrapper(super().get_model(model_name or None), self.max_retries)
+        return RestateModelWrapper(super().get_model(model_name or None), self.llm_retry_opts)
 
 
 class RestateModelWrapper(Model):
@@ -84,10 +104,10 @@
     A wrapper around the OpenAI SDK's Model that persists LLM calls in the Restate journal.
     """
 
-    def __init__(self, model: Model, max_retries: int | None = 3):
+    def __init__(self, model: Model, llm_retry_opts: LlmRetryOpts | None = None):
         self.model = model
         self.model_name = "RestateModelWrapper"
-        self.max_retries = max_retries
+        self.llm_retry_opts = llm_retry_opts if llm_retry_opts is not None else LlmRetryOpts()
 
     async def get_response(self, *args, **kwargs) -> ModelResponse:
         async def call_llm() -> RestateModelResponse:
@@ -102,7 +122,17 @@
         ctx = current_context()
         if ctx is None:
             raise RuntimeError("No current Restate context found, make sure to run inside a Restate handler")
-        result = await ctx.run_typed("call LLM", call_llm, RunOptions(max_attempts=self.max_retries))
+        result = await ctx.run_typed(
+            "call LLM",
+            call_llm,
+            RunOptions(
+                max_attempts=self.llm_retry_opts.max_attempts,
+                max_duration=self.llm_retry_opts.max_duration,
+                initial_retry_interval=self.llm_retry_opts.initial_retry_interval,
+                max_retry_interval=self.llm_retry_opts.max_retry_interval,
+                retry_interval_factor=self.llm_retry_opts.retry_interval_factor,
+            ),
+        )
         # convert back to original ModelResponse
         return ModelResponse(
             output=result.output,
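As the hunk above shows, `LlmRetryOpts` maps field-for-field onto the `RunOptions` of the journaled `call LLM` step; once retries are exhausted, the call surfaces as a `TerminalError`. A usage sketch with arbitrary values:

```python
from datetime import timedelta

from restate.ext.openai import LlmRetryOpts

# Retry transient LLM failures up to 5 times within 2 minutes, backing off
# from 1s by a factor of 2 up to 30s; after that, a TerminalError is raised.
retry_opts = LlmRetryOpts(
    max_attempts=5,
    max_duration=timedelta(minutes=2),
    initial_retry_interval=timedelta(seconds=1),
    max_retry_interval=timedelta(seconds=30),
    retry_interval_factor=2.0,
)
```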
@@ -117,33 +147,42 @@ def stream_response(self, *args, **kwargs) -> AsyncIterator[TResponseStreamEvent]:
 
 class RestateSession(SessionABC):
     """Restate session implementation following the Session protocol."""
 
+    def __init__(self):
+        self._items: List[TResponseInputItem] | None = None
+
     def _ctx(self) -> ObjectContext:
-        return typing.cast(ObjectContext, current_context())
+        return cast(ObjectContext, current_context())
 
     async def get_items(self, limit: int | None = None) -> List[TResponseInputItem]:
         """Retrieve conversation history for this session."""
-        current_items = await self._ctx().get("items", type_hint=List[TResponseInputItem]) or []
+        if self._items is None:
+            self._items = await self._ctx().get("items") or []
         if limit is not None:
-            return current_items[-limit:]
-        return current_items
+            return self._items[-limit:]
+        return self._items.copy()
 
     async def add_items(self, items: List[TResponseInputItem]) -> None:
         """Store new items for this session."""
-        # Your implementation here
-        current_items = await self.get_items() or []
-        self._ctx().set("items", current_items + items)
+        if self._items is None:
+            self._items = await self._ctx().get("items") or []
+        self._items.extend(items)
 
     async def pop_item(self) -> TResponseInputItem | None:
         """Remove and return the most recent item from this session."""
-        current_items = await self.get_items() or []
-        if current_items:
-            item = current_items.pop()
-            self._ctx().set("items", current_items)
-            return item
+        if self._items is None:
+            self._items = await self._ctx().get("items") or []
+        if self._items:
+            return self._items.pop()
         return None
 
+    def flush(self) -> None:
+        """Flush the buffered session items to Restate state."""
+        if self._items is not None:
+            self._ctx().set("items", self._items)
+
     async def clear_session(self) -> None:
         """Clear all items for this session."""
+        self._items = []
         self._ctx().clear("items")
 
@@ -189,7 +228,7 @@ def continue_on_terminal_errors(context: RunContextWrapper[Any], error: Exceptio
     raise error
 
 
-class Runner:
+class DurableRunner:
     """
     A wrapper around Runner.run that automatically configures RunConfig
     for Restate contexts.
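The session now buffers items in memory and writes them back in a single state update. A sketch of the lifecycle; normally `DurableRunner.run` drives these calls for you, so this is illustrative only:

```python
session = RestateSession()

# First access lazily loads the "items" key from the Virtual Object's state.
history = await session.get_items()

# Mutations only touch the in-memory buffer...
await session.add_items([{"role": "user", "content": "hello"}])

# ...until flush() persists the buffer with a single ctx.set("items", ...).
session.flush()
```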
@@ -201,83 +240,55 @@
 
     @staticmethod
     async def run(
         starting_agent: Agent[TContext],
-        disable_tool_autowrapping: bool = False,
-        *args: typing.Any,
-        run_config: RunConfig | None = None,
+        input: str | list[TResponseInputItem],
+        *,
+        use_restate_session: bool = False,
         **kwargs,
     ) -> RunResult:
         """
         Run an agent with automatic Restate configuration.
 
+        Args:
+            use_restate_session: If True, creates a RestateSession for conversation persistence.
+                Requires running within a Restate Virtual Object context.
+
         Returns:
             The result from Runner.run
         """
-        current_run_config = run_config or RunConfig()
-        new_run_config = dataclasses.replace(
-            current_run_config,
-            model_provider=DurableModelCalls(),
-        )
-        restate_agent = sequentialize_and_wrap_tools(starting_agent, disable_tool_autowrapping)
-        return await OpenAIRunner.run(restate_agent, *args, run_config=new_run_config, **kwargs)
-
-
-def sequentialize_and_wrap_tools(
-    agent: Agent[TContext],
-    disable_tool_autowrapping: bool,
-) -> Agent[TContext]:
-    """
-    Wrap the tools of an agent to use the Restate error handling.
-
-    Returns:
-        A new agent with wrapped tools.
-    """
+        # Persist model calls in the Restate journal
+        llm_retry_opts = kwargs.pop("llm_retry_opts", None)
+        run_config = kwargs.pop("run_config", RunConfig())
+        run_config = dataclasses.replace(run_config, model_provider=DurableModelCalls(llm_retry_opts))
 
-    # Restate does not allow parallel tool calls, so we use a lock to ensure sequential execution.
-    # This lock only affects tools for this agent; handoff agents are wrapped recursively.
-    sequential_tools_lock = asyncio.Lock()
-    wrapped_tools: list[Tool] = []
-    for tool in agent.tools:
-        if isinstance(tool, FunctionTool):
-
-            def create_wrapper(captured_tool):
-                async def on_invoke_tool_wrapper(tool_context: ToolContext[Any], tool_input: Any) -> Any:
-                    await sequential_tools_lock.acquire()
-
-                    async def invoke():
-                        result = await captured_tool.on_invoke_tool(tool_context, tool_input)
-                        # Ensure Pydantic objects are serialized to dict for LLM compatibility
-                        if hasattr(result, "model_dump"):
-                            return result.model_dump()
-                        elif hasattr(result, "dict"):
-                            return result.dict()
-                        return result
-
-                    try:
-                        if disable_tool_autowrapping:
-                            return await invoke()
-
-                        ctx = current_context()
-                        if ctx is None:
-                            raise RuntimeError(
-                                "No current Restate context found, make sure to run inside a Restate handler"
-                            )
-                        return await ctx.run_typed(captured_tool.name, invoke)
-                    finally:
-                        sequential_tools_lock.release()
-
-                return on_invoke_tool_wrapper
-
-            wrapped_tools.append(dataclasses.replace(tool, on_invoke_tool=create_wrapper(tool)))
+        # Disable parallel tool calls
+        model_settings = run_config.model_settings
+        if model_settings is None:
+            model_settings = ModelSettings(parallel_tool_calls=False)
         else:
-            wrapped_tools.append(tool)
+            model_settings = dataclasses.replace(
+                model_settings,
+                parallel_tool_calls=False,
+            )
+        run_config = dataclasses.replace(
+            run_config,
+            model_settings=model_settings,
+        )
+
+        # Use Restate session if requested, otherwise use provided session
+        session = kwargs.pop("session", None)
+        if use_restate_session:
+            if session is not None:
+                raise TerminalError("When use_restate_session is True, session config cannot be provided.")
+            session = RestateSession()
 
-    handoffs_with_wrapped_tools = []
-    for handoff in agent.handoffs:
-        # recursively wrap tools in handoff agents
-        handoffs_with_wrapped_tools.append(sequentialize_and_wrap_tools(handoff, disable_tool_autowrapping))  # type: ignore
+        try:
+            result = await Runner.run(
+                starting_agent=starting_agent, input=input, run_config=run_config, session=session, **kwargs
+            )
+        finally:
+            # Flush buffered session items to Restate state
+            if session is not None and isinstance(session, RestateSession):
+                session.flush()
 
-    return agent.clone(
-        tools=wrapped_tools,
-        handoffs=handoffs_with_wrapped_tools,
-    )
+        return result
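Putting the pieces together, a minimal sketch of the intended call site. The Virtual Object, its handler, and the agent below are made-up names; `use_restate_session=True` requires a Virtual Object so the conversation history has keyed state to live in:

```python
import restate
from agents import Agent

from restate.ext.openai import DurableRunner, LlmRetryOpts

assistant = Agent(name="Assistant", instructions="Answer concisely.")
chat = restate.VirtualObject("Chat")


@chat.handler()
async def message(ctx: restate.ObjectContext, user_input: str) -> str:
    # Model calls are journaled; history persists under this object's "items" state.
    result = await DurableRunner.run(
        assistant,
        user_input,
        use_restate_session=True,
        llm_retry_opts=LlmRetryOpts(max_attempts=5),
    )
    return result.final_output
```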