This module contains the optional OpenAI integration for Restate.
"""
1414
15- import asyncio
1615import dataclasses
17- import typing
1816
1917from agents import (
20- Tool ,
2118 Usage ,
2219 Model ,
2320 RunContextWrapper ,
2421 AgentsException ,
25- Runner as OpenAIRunner ,
2622 RunConfig ,
2723 TContext ,
2824 RunResult ,
2925 Agent ,
3026 ModelBehaviorError ,
27+ ModelSettings ,
28+ Runner ,
3129)
32-
3330from agents .models .multi_provider import MultiProvider
3431from agents .items import TResponseStreamEvent , TResponseOutputItem , ModelResponse
3532from agents .memory .session import SessionABC
3633from agents .items import TResponseInputItem
37- from typing import List , Any
38- from typing import AsyncIterator
39-
40- from agents .tool import FunctionTool
41- from agents .tool_context import ToolContext
34+ from datetime import timedelta
35+ from typing import List , Any , AsyncIterator , Optional , cast
4236from pydantic import BaseModel
37+
4338from restate .exceptions import SdkInternalBaseException
4439from restate .extensions import current_context
45-
4640from restate import RunOptions , ObjectContext , TerminalError
4741
4842
@dataclasses.dataclass
class LlmRetryOpts:
    """Retry policy for durable LLM calls made through Restate.

    When the policy gives up (attempts or duration exhausted), the LLM call
    raises a `TerminalError` wrapping the original error message.
    """

    max_attempts: Optional[int] = 10
    """Max number of attempts (including the initial one) before giving up."""

    max_duration: Optional[timedelta] = None
    """Max total duration of retries before giving up."""

    initial_retry_interval: Optional[timedelta] = timedelta(seconds=1)
    """Interval before the first retry attempt; subsequent intervals grow by
    `retry_interval_factor`.

    Defaults to 1 second. When set to `None`, Restate falls back to the
    overall invocation retry policy."""

    max_retry_interval: Optional[timedelta] = None
    """Upper bound on the interval between retries.

    When `None`, Restate's own default for the run-block retry policy applies."""

    retry_interval_factor: Optional[float] = None
    """Exponentiation factor used to compute the next retry delay.

    When `None`, Restate's own default applies (the interval grows
    geometrically, e.g. doubling per attempt); otherwise the configured
    factor is used."""
68+
4969# The OpenAI ModelResponse class is a dataclass with Pydantic fields.
5070# The Restate SDK cannot serialize this. So we turn the ModelResponse int a Pydantic model.
5171class RestateModelResponse (BaseModel ):
@@ -71,23 +91,23 @@ class DurableModelCalls(MultiProvider):
7191 A Restate model provider that wraps the OpenAI SDK's default MultiProvider.
7292 """
7393
74- def __init__ (self , max_retries : int | None = 3 ):
94+ def __init__ (self , llm_retry_opts : LlmRetryOpts | None = None ):
7595 super ().__init__ ()
76- self .max_retries = max_retries
96+ self .llm_retry_opts = llm_retry_opts
7797
7898 def get_model (self , model_name : str | None ) -> Model :
79- return RestateModelWrapper (super ().get_model (model_name or None ), self .max_retries )
99+ return RestateModelWrapper (super ().get_model (model_name or None ), self .llm_retry_opts )
80100
81101
82102class RestateModelWrapper (Model ):
83103 """
84104 A wrapper around the OpenAI SDK's Model that persists LLM calls in the Restate journal.
85105 """
86106
87- def __init__ (self , model : Model , max_retries : int | None = 3 ):
107+ def __init__ (self , model : Model , llm_retry_opts : LlmRetryOpts | None = None ):
88108 self .model = model
89109 self .model_name = "RestateModelWrapper"
90- self .max_retries = max_retries
110+ self .llm_retry_opts = llm_retry_opts if llm_retry_opts is not None else LlmRetryOpts ()
91111
92112 async def get_response (self , * args , ** kwargs ) -> ModelResponse :
93113 async def call_llm () -> RestateModelResponse :
@@ -102,7 +122,17 @@ async def call_llm() -> RestateModelResponse:
102122 ctx = current_context ()
103123 if ctx is None :
104124 raise RuntimeError ("No current Restate context found, make sure to run inside a Restate handler" )
105- result = await ctx .run_typed ("call LLM" , call_llm , RunOptions (max_attempts = self .max_retries ))
125+ result = await ctx .run_typed (
126+ "call LLM" ,
127+ call_llm ,
128+ RunOptions (
129+ max_attempts = self .llm_retry_opts .max_attempts ,
130+ max_duration = self .llm_retry_opts .max_duration ,
131+ initial_retry_interval = self .llm_retry_opts .initial_retry_interval ,
132+ max_retry_interval = self .llm_retry_opts .max_retry_interval ,
133+ retry_interval_factor = self .llm_retry_opts .retry_interval_factor ,
134+ ),
135+ )
106136 # convert back to original ModelResponse
107137 return ModelResponse (
108138 output = result .output ,
@@ -117,33 +147,41 @@ def stream_response(self, *args, **kwargs) -> AsyncIterator[TResponseStreamEvent
class RestateSession(SessionABC):
    """Restate session implementation following the Session protocol.

    Conversation items are cached in memory and only written back to Restate
    state when `flush()` is called (or removed via `clear_session`).
    """

    def __init__(self):
        # Lazily-loaded in-memory cache of the journaled conversation items.
        # None means "not loaded from Restate state yet".
        self._items: List[TResponseInputItem] | None = None

    def _ctx(self) -> ObjectContext:
        # NOTE(review): assumes we are running inside a Restate Virtual Object
        # handler so that current_context() yields an ObjectContext — confirm.
        return cast(ObjectContext, current_context())

    async def _load(self) -> List[TResponseInputItem]:
        """Load items from Restate state into the cache on first access."""
        if self._items is None:
            self._items = await self._ctx().get("items") or []
        return self._items

    async def get_items(self, limit: int | None = None) -> List[TResponseInputItem]:
        """Retrieve conversation history for this session.

        Args:
            limit: When given, only the most recent `limit` items are returned.

        Returns:
            A copy of the cached items (callers cannot mutate the cache).
        """
        items = await self._load()
        if limit is not None:
            # Slicing already produces a fresh list.
            return items[-limit:]
        return items.copy()

    async def add_items(self, items: List[TResponseInputItem]) -> None:
        """Store new items for this session (in the cache; see `flush`)."""
        (await self._load()).extend(items)

    async def pop_item(self) -> TResponseInputItem | None:
        """Remove and return the most recent item, or None if empty."""
        items = await self._load()
        if items:
            return items.pop()
        return None

    def flush(self) -> None:
        """Flush the cached session items to Restate state."""
        # Guard: if nothing was ever loaded or added there is nothing to
        # write, and writing None would clobber the stored history.
        if self._items is not None:
            self._ctx().set("items", self._items)

    async def clear_session(self) -> None:
        """Clear all items for this session, both cached and stored."""
        self._items = []
        self._ctx().clear("items")
148186
149187
@@ -189,7 +227,7 @@ def continue_on_terminal_errors(context: RunContextWrapper[Any], error: Exceptio
189227 raise error
190228
191229
192- class Runner :
230+ class DurableRunner :
193231 """
194232 A wrapper around Runner.run that automatically configures RunConfig for Restate contexts.
195233
@@ -201,83 +239,55 @@ class Runner:
201239 @staticmethod
202240 async def run (
203241 starting_agent : Agent [TContext ],
204- disable_tool_autowrapping : bool = False ,
205- * args : typing . Any ,
206- run_config : RunConfig | None = None ,
242+ input : str | list [ TResponseInputItem ] ,
243+ * ,
244+ use_restate_session : bool = False ,
207245 ** kwargs ,
208246 ) -> RunResult :
209247 """
210248 Run an agent with automatic Restate configuration.
211249
250+ Args:
251+ use_restate_session: If True, creates a RestateSession for conversation persistence.
252+ Requires running within a Restate Virtual Object context.
253+
212254 Returns:
213255 The result from Runner.run
214256 """
215257
216- current_run_config = run_config or RunConfig ()
217- new_run_config = dataclasses .replace (
218- current_run_config ,
219- model_provider = DurableModelCalls (),
220- )
221- restate_agent = sequentialize_and_wrap_tools (starting_agent , disable_tool_autowrapping )
222- return await OpenAIRunner .run (restate_agent , * args , run_config = new_run_config , ** kwargs )
223-
224-
225- def sequentialize_and_wrap_tools (
226- agent : Agent [TContext ],
227- disable_tool_autowrapping : bool ,
228- ) -> Agent [TContext ]:
229- """
230- Wrap the tools of an agent to use the Restate error handling.
231-
232- Returns:
233- A new agent with wrapped tools.
234- """
258+ # Set persisting model calls
259+ llm_retry_opts = kwargs .pop ("llm_retry_opts" , None )
260+ run_config = kwargs .pop ("run_config" , RunConfig ())
261+ run_config = dataclasses .replace (run_config , model_provider = DurableModelCalls (llm_retry_opts ))
235262
236- # Restate does not allow parallel tool calls, so we use a lock to ensure sequential execution.
237- # This lock only affects tools for this agent; handoff agents are wrapped recursively.
238- sequential_tools_lock = asyncio .Lock ()
239- wrapped_tools : list [Tool ] = []
240- for tool in agent .tools :
241- if isinstance (tool , FunctionTool ):
242-
243- def create_wrapper (captured_tool ):
244- async def on_invoke_tool_wrapper (tool_context : ToolContext [Any ], tool_input : Any ) -> Any :
245- await sequential_tools_lock .acquire ()
246-
247- async def invoke ():
248- result = await captured_tool .on_invoke_tool (tool_context , tool_input )
249- # Ensure Pydantic objects are serialized to dict for LLM compatibility
250- if hasattr (result , "model_dump" ):
251- return result .model_dump ()
252- elif hasattr (result , "dict" ):
253- return result .dict ()
254- return result
255-
256- try :
257- if disable_tool_autowrapping :
258- return await invoke ()
259-
260- ctx = current_context ()
261- if ctx is None :
262- raise RuntimeError (
263- "No current Restate context found, make sure to run inside a Restate handler"
264- )
265- return await ctx .run_typed (captured_tool .name , invoke )
266- finally :
267- sequential_tools_lock .release ()
268-
269- return on_invoke_tool_wrapper
270-
271- wrapped_tools .append (dataclasses .replace (tool , on_invoke_tool = create_wrapper (tool )))
263+ # Disable parallel tool calls
264+ model_settings = run_config .model_settings
265+ if model_settings is None :
266+ model_settings = ModelSettings (parallel_tool_calls = False )
272267 else :
273- wrapped_tools .append (tool )
268+ model_settings = dataclasses .replace (
269+ model_settings ,
270+ parallel_tool_calls = False ,
271+ )
272+ run_config = dataclasses .replace (
273+ run_config ,
274+ model_settings = model_settings ,
275+ )
276+
277+ # Use Restate session if requested, otherwise use provided session
278+ session = kwargs .pop ("session" , None )
279+ if use_restate_session :
280+ if session is not None :
281+ raise TerminalError ("When use_restate_session is True, session config cannot be provided." )
282+ session = RestateSession ()
274283
275- handoffs_with_wrapped_tools = []
276- for handoff in agent .handoffs :
277- # recursively wrap tools in handoff agents
278- handoffs_with_wrapped_tools .append (sequentialize_and_wrap_tools (handoff , disable_tool_autowrapping )) # type: ignore
284+ try :
285+ result = await Runner .run (
286+ starting_agent = starting_agent , input = input , run_config = run_config , session = session , ** kwargs
287+ )
288+ finally :
289+ # Flush session items to Restate
290+ if session is not None and isinstance (session , RestateSession ):
291+ session .flush ()
279292
280- return agent .clone (
281- tools = wrapped_tools ,
282- handoffs = handoffs_with_wrapped_tools ,
283- )
293+ return result
0 commit comments