2525 Agent ,
2626 ModelBehaviorError ,
2727 ModelSettings ,
28+ Runner ,
2829)
2930from agents .models .multi_provider import MultiProvider
3031from agents .items import TResponseStreamEvent , TResponseOutputItem , ModelResponse
3132from agents .memory .session import SessionABC
3233from agents .items import TResponseInputItem
33- from agents .run import (
34- AgentRunner ,
35- DEFAULT_AGENT_RUNNER ,
36- )
3734from datetime import timedelta
3835from typing import List , Any , AsyncIterator , Optional , cast
3936from pydantic import BaseModel
4239from restate .extensions import current_context
4340from restate import RunOptions , ObjectContext , TerminalError
4441
42+
4543@dataclasses .dataclass
4644class LlmRetryOpts :
4745 max_attempts : Optional [int ] = 10
@@ -106,10 +104,10 @@ class RestateModelWrapper(Model):
106104 A wrapper around the OpenAI SDK's Model that persists LLM calls in the Restate journal.
107105 """
108106
109- def __init__ (self , model : Model , llm_retry_opts : LlmRetryOpts | None = LlmRetryOpts () ):
107+ def __init__ (self , model : Model , llm_retry_opts : LlmRetryOpts | None = None ):
110108 self .model = model
111109 self .model_name = "RestateModelWrapper"
112- self .llm_retry_opts = llm_retry_opts
110+ self .llm_retry_opts = llm_retry_opts if llm_retry_opts is not None else LlmRetryOpts ()
113111
114112 async def get_response (self , * args , ** kwargs ) -> ModelResponse :
115113 async def call_llm () -> RestateModelResponse :
@@ -155,26 +153,24 @@ def __init__(self):
155153 def _ctx (self ) -> ObjectContext :
156154 return cast (ObjectContext , current_context ())
157155
158- async def _load_items_if_needed (self ) -> None :
159- """Load items from context if not already loaded."""
160- if self ._items is None :
161- self ._items = await self ._ctx ().get ("items" , type_hint = List [TResponseInputItem ]) or []
162-
163156 async def get_items (self , limit : int | None = None ) -> List [TResponseInputItem ]:
164157 """Retrieve conversation history for this session."""
165- await self ._load_items_if_needed ()
158+ if self ._items is None :
159+ self ._items = await self ._ctx ().get ("items" ) or []
166160 if limit is not None :
167161 return self ._items [- limit :]
168162 return self ._items .copy ()
169163
170164 async def add_items (self , items : List [TResponseInputItem ]) -> None :
171165 """Store new items for this session."""
172- await self ._load_items_if_needed ()
166+ if self ._items is None :
167+ self ._items = await self ._ctx ().get ("items" ) or []
173168 self ._items .extend (items )
174169
175170 async def pop_item (self ) -> TResponseInputItem | None :
176171 """Remove and return the most recent item from this session."""
177- await self ._load_items_if_needed ()
172+ if self ._items is None :
173+ self ._items = await self ._ctx ().get ("items" ) or []
178174 if self ._items :
179175 return self ._items .pop ()
180176 return None
@@ -244,6 +240,8 @@ class DurableRunner:
244240 async def run (
245241 starting_agent : Agent [TContext ],
246242 input : str | list [TResponseInputItem ],
243+ * ,
244+ use_restate_session : bool = False ,
247245 ** kwargs ,
248246 ) -> RunResult :
249247 """
@@ -254,7 +252,7 @@ async def run(
254252 """
255253
256254 # Set persisting model calls
257- llm_retry_opts = kwargs .get ("llm_retry_opts" , None )
255+ llm_retry_opts = kwargs .pop ("llm_retry_opts" , None )
258256 run_config = kwargs .pop ("run_config" , RunConfig ())
259257 run_config = dataclasses .replace (run_config , model_provider = DurableModelCalls (llm_retry_opts ))
260258
@@ -272,12 +270,19 @@ async def run(
272270 model_settings = model_settings ,
273271 )
274272
275- runner = DEFAULT_AGENT_RUNNER or AgentRunner ()
273+ # Use Restate session if requested, otherwise use provided session
274+ session = kwargs .pop ("session" , None )
275+ if use_restate_session :
276+ if session is not None :
277+ raise TerminalError ("When use_restate_session is True, session config cannot be provided." )
278+ session = RestateSession ()
279+
276280 try :
277- result = await runner .run (starting_agent = starting_agent , input = input , run_config = run_config , ** kwargs )
281+ result = await Runner .run (
282+ starting_agent = starting_agent , input = input , run_config = run_config , session = session , ** kwargs
283+ )
278284 finally :
279285 # Flush session items to Restate
280- session = kwargs .get ("session" , None )
281286 if session is not None and isinstance (session , RestateSession ):
282287 session .flush ()
283288
0 commit comments