Skip to content

Commit bc9a5b6

Browse files
committed
Address review comments
1 parent 2372241 commit bc9a5b6

File tree

2 files changed

+24
-26
lines changed

2 files changed

+24
-26
lines changed

python/restate/ext/openai/__init__.py

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,7 @@
1414

1515
import typing
1616

17-
from .runner_wrapper import (
18-
DurableRunner,
19-
continue_on_terminal_errors,
20-
raise_terminal_errors,
21-
RestateSession,
22-
LlmRetryOpts
23-
)
17+
from .runner_wrapper import DurableRunner, continue_on_terminal_errors, raise_terminal_errors, LlmRetryOpts
2418
from restate import ObjectContext, Context
2519
from restate.server_context import current_context
2620

@@ -43,7 +37,6 @@ def restate_context() -> Context:
4337

4438
__all__ = [
4539
"DurableRunner",
46-
"RestateSession",
4740
"LlmRetryOpts",
4841
"restate_object_context",
4942
"restate_context",

python/restate/ext/openai/runner_wrapper.py

Lines changed: 23 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -25,15 +25,12 @@
2525
Agent,
2626
ModelBehaviorError,
2727
ModelSettings,
28+
Runner,
2829
)
2930
from agents.models.multi_provider import MultiProvider
3031
from agents.items import TResponseStreamEvent, TResponseOutputItem, ModelResponse
3132
from agents.memory.session import SessionABC
3233
from agents.items import TResponseInputItem
33-
from agents.run import (
34-
AgentRunner,
35-
DEFAULT_AGENT_RUNNER,
36-
)
3734
from datetime import timedelta
3835
from typing import List, Any, AsyncIterator, Optional, cast
3936
from pydantic import BaseModel
@@ -42,6 +39,7 @@
4239
from restate.extensions import current_context
4340
from restate import RunOptions, ObjectContext, TerminalError
4441

42+
4543
@dataclasses.dataclass
4644
class LlmRetryOpts:
4745
max_attempts: Optional[int] = 10
@@ -106,10 +104,10 @@ class RestateModelWrapper(Model):
106104
A wrapper around the OpenAI SDK's Model that persists LLM calls in the Restate journal.
107105
"""
108106

109-
def __init__(self, model: Model, llm_retry_opts: LlmRetryOpts | None = LlmRetryOpts()):
107+
def __init__(self, model: Model, llm_retry_opts: LlmRetryOpts | None = None):
110108
self.model = model
111109
self.model_name = "RestateModelWrapper"
112-
self.llm_retry_opts = llm_retry_opts
110+
self.llm_retry_opts = llm_retry_opts if llm_retry_opts is not None else LlmRetryOpts()
113111

114112
async def get_response(self, *args, **kwargs) -> ModelResponse:
115113
async def call_llm() -> RestateModelResponse:
@@ -155,26 +153,24 @@ def __init__(self):
155153
def _ctx(self) -> ObjectContext:
156154
return cast(ObjectContext, current_context())
157155

158-
async def _load_items_if_needed(self) -> None:
159-
"""Load items from context if not already loaded."""
160-
if self._items is None:
161-
self._items = await self._ctx().get("items", type_hint=List[TResponseInputItem]) or []
162-
163156
async def get_items(self, limit: int | None = None) -> List[TResponseInputItem]:
164157
"""Retrieve conversation history for this session."""
165-
await self._load_items_if_needed()
158+
if self._items is None:
159+
self._items = await self._ctx().get("items") or []
166160
if limit is not None:
167161
return self._items[-limit:]
168162
return self._items.copy()
169163

170164
async def add_items(self, items: List[TResponseInputItem]) -> None:
171165
"""Store new items for this session."""
172-
await self._load_items_if_needed()
166+
if self._items is None:
167+
self._items = await self._ctx().get("items") or []
173168
self._items.extend(items)
174169

175170
async def pop_item(self) -> TResponseInputItem | None:
176171
"""Remove and return the most recent item from this session."""
177-
await self._load_items_if_needed()
172+
if self._items is None:
173+
self._items = await self._ctx().get("items") or []
178174
if self._items:
179175
return self._items.pop()
180176
return None
@@ -244,6 +240,8 @@ class DurableRunner:
244240
async def run(
245241
starting_agent: Agent[TContext],
246242
input: str | list[TResponseInputItem],
243+
*,
244+
use_restate_session: bool = False,
247245
**kwargs,
248246
) -> RunResult:
249247
"""
@@ -254,7 +252,7 @@ async def run(
254252
"""
255253

256254
# Set persisting model calls
257-
llm_retry_opts = kwargs.get("llm_retry_opts", None)
255+
llm_retry_opts = kwargs.pop("llm_retry_opts", None)
258256
run_config = kwargs.pop("run_config", RunConfig())
259257
run_config = dataclasses.replace(run_config, model_provider=DurableModelCalls(llm_retry_opts))
260258

@@ -272,12 +270,19 @@ async def run(
272270
model_settings=model_settings,
273271
)
274272

275-
runner = DEFAULT_AGENT_RUNNER or AgentRunner()
273+
# Use Restate session if requested, otherwise use provided session
274+
session = kwargs.pop("session", None)
275+
if use_restate_session:
276+
if session is not None:
277+
raise TerminalError("When use_restate_session is True, session config cannot be provided.")
278+
session = RestateSession()
279+
276280
try:
277-
result = await runner.run(starting_agent=starting_agent, input=input, run_config=run_config, **kwargs)
281+
result = await Runner.run(
282+
starting_agent=starting_agent, input=input, run_config=run_config, session=session, **kwargs
283+
)
278284
finally:
279285
# Flush session items to Restate
280-
session = kwargs.get("session", None)
281286
if session is not None and isinstance(session, RestateSession):
282287
session.flush()
283288

0 commit comments

Comments
 (0)