Skip to content

Commit a32145b

Browse files
authored
OpenAI DurableRunner with LLM retry opts and lazy session flushing
2 parents a4765b0 + 73bfa10 commit a32145b

File tree

2 files changed

+133
-100
lines changed

2 files changed

+133
-100
lines changed

python/restate/ext/openai/__init__.py

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,34 @@
1212
This module contains the optional OpenAI integration for Restate.
1313
"""
1414

15-
from .runner_wrapper import Runner, DurableModelCalls, continue_on_terminal_errors, raise_terminal_errors
15+
import typing
16+
17+
from .runner_wrapper import DurableRunner, continue_on_terminal_errors, raise_terminal_errors, LlmRetryOpts
18+
from restate import ObjectContext, Context
19+
from restate.server_context import current_context
20+
21+
22+
def restate_object_context() -> ObjectContext:
23+
"""Get the current Restate ObjectContext."""
24+
ctx = current_context()
25+
if ctx is None:
26+
raise RuntimeError("No Restate context found.")
27+
return typing.cast(ObjectContext, ctx)
28+
29+
30+
def restate_context() -> Context:
31+
"""Get the current Restate Context."""
32+
ctx = current_context()
33+
if ctx is None:
34+
raise RuntimeError("No Restate context found.")
35+
return ctx
36+
1637

1738
__all__ = [
18-
"DurableModelCalls",
39+
"DurableRunner",
40+
"LlmRetryOpts",
41+
"restate_object_context",
42+
"restate_context",
1943
"continue_on_terminal_errors",
2044
"raise_terminal_errors",
21-
"Runner",
2245
]

python/restate/ext/openai/runner_wrapper.py

Lines changed: 107 additions & 97 deletions
Original file line numberDiff line numberDiff line change
@@ -12,40 +12,60 @@
1212
This module contains the optional OpenAI integration for Restate.
1313
"""
1414

15-
import asyncio
1615
import dataclasses
17-
import typing
1816

1917
from agents import (
20-
Tool,
2118
Usage,
2219
Model,
2320
RunContextWrapper,
2421
AgentsException,
25-
Runner as OpenAIRunner,
2622
RunConfig,
2723
TContext,
2824
RunResult,
2925
Agent,
3026
ModelBehaviorError,
27+
ModelSettings,
28+
Runner,
3129
)
32-
3330
from agents.models.multi_provider import MultiProvider
3431
from agents.items import TResponseStreamEvent, TResponseOutputItem, ModelResponse
3532
from agents.memory.session import SessionABC
3633
from agents.items import TResponseInputItem
37-
from typing import List, Any
38-
from typing import AsyncIterator
39-
40-
from agents.tool import FunctionTool
41-
from agents.tool_context import ToolContext
34+
from datetime import timedelta
35+
from typing import List, Any, AsyncIterator, Optional, cast
4236
from pydantic import BaseModel
37+
4338
from restate.exceptions import SdkInternalBaseException
4439
from restate.extensions import current_context
45-
4640
from restate import RunOptions, ObjectContext, TerminalError
4741

4842

43+
@dataclasses.dataclass
44+
class LlmRetryOpts:
45+
max_attempts: Optional[int] = 10
46+
"""Max number of attempts (including the initial attempt), before giving up.
47+
48+
When giving up, the LLM call will throw a `TerminalError` wrapping the original error message."""
49+
max_duration: Optional[timedelta] = None
50+
"""Max duration of retries, before giving up.
51+
52+
When giving up, the LLM call will throw a `TerminalError` wrapping the original error message."""
53+
initial_retry_interval: Optional[timedelta] = timedelta(seconds=1)
54+
"""Initial interval for the first retry attempt.
55+
Retry interval will grow by a factor specified in `retry_interval_factor`.
56+
57+
If any of the other retry-related fields is specified, the default for this field is 1 second, otherwise Restate will fall back to the overall invocation retry policy."""
58+
max_retry_interval: Optional[timedelta] = None
59+
"""Max interval between retries.
60+
Retry interval will grow by a factor specified in `retry_interval_factor`.
61+
62+
The default is 10 seconds."""
63+
retry_interval_factor: Optional[float] = None
64+
"""Exponentiation factor to use when computing the next retry delay.
65+
66+
If any of the other retry-related fields is specified, the default for this field is `2`, meaning the retry interval will double at each attempt, otherwise Restate will fall back to the overall invocation retry policy."""
67+
68+
4969
# The OpenAI ModelResponse class is a dataclass with Pydantic fields.
5070
# The Restate SDK cannot serialize this. So we turn the ModelResponse into a Pydantic model.
5171
class RestateModelResponse(BaseModel):
@@ -71,23 +91,23 @@ class DurableModelCalls(MultiProvider):
7191
A Restate model provider that wraps the OpenAI SDK's default MultiProvider.
7292
"""
7393

74-
def __init__(self, max_retries: int | None = 3):
94+
def __init__(self, llm_retry_opts: LlmRetryOpts | None = None):
7595
super().__init__()
76-
self.max_retries = max_retries
96+
self.llm_retry_opts = llm_retry_opts
7797

7898
def get_model(self, model_name: str | None) -> Model:
79-
return RestateModelWrapper(super().get_model(model_name or None), self.max_retries)
99+
return RestateModelWrapper(super().get_model(model_name or None), self.llm_retry_opts)
80100

81101

82102
class RestateModelWrapper(Model):
83103
"""
84104
A wrapper around the OpenAI SDK's Model that persists LLM calls in the Restate journal.
85105
"""
86106

87-
def __init__(self, model: Model, max_retries: int | None = 3):
107+
def __init__(self, model: Model, llm_retry_opts: LlmRetryOpts | None = None):
88108
self.model = model
89109
self.model_name = "RestateModelWrapper"
90-
self.max_retries = max_retries
110+
self.llm_retry_opts = llm_retry_opts if llm_retry_opts is not None else LlmRetryOpts()
91111

92112
async def get_response(self, *args, **kwargs) -> ModelResponse:
93113
async def call_llm() -> RestateModelResponse:
@@ -102,7 +122,17 @@ async def call_llm() -> RestateModelResponse:
102122
ctx = current_context()
103123
if ctx is None:
104124
raise RuntimeError("No current Restate context found, make sure to run inside a Restate handler")
105-
result = await ctx.run_typed("call LLM", call_llm, RunOptions(max_attempts=self.max_retries))
125+
result = await ctx.run_typed(
126+
"call LLM",
127+
call_llm,
128+
RunOptions(
129+
max_attempts=self.llm_retry_opts.max_attempts,
130+
max_duration=self.llm_retry_opts.max_duration,
131+
initial_retry_interval=self.llm_retry_opts.initial_retry_interval,
132+
max_retry_interval=self.llm_retry_opts.max_retry_interval,
133+
retry_interval_factor=self.llm_retry_opts.retry_interval_factor,
134+
),
135+
)
106136
# convert back to original ModelResponse
107137
return ModelResponse(
108138
output=result.output,
@@ -117,33 +147,41 @@ def stream_response(self, *args, **kwargs) -> AsyncIterator[TResponseStreamEvent
117147
class RestateSession(SessionABC):
118148
"""Restate session implementation following the Session protocol."""
119149

150+
def __init__(self):
151+
self._items: List[TResponseInputItem] | None = None
152+
120153
def _ctx(self) -> ObjectContext:
121-
return typing.cast(ObjectContext, current_context())
154+
return cast(ObjectContext, current_context())
122155

123156
async def get_items(self, limit: int | None = None) -> List[TResponseInputItem]:
124157
"""Retrieve conversation history for this session."""
125-
current_items = await self._ctx().get("items", type_hint=List[TResponseInputItem]) or []
158+
if self._items is None:
159+
self._items = await self._ctx().get("items") or []
126160
if limit is not None:
127-
return current_items[-limit:]
128-
return current_items
161+
return self._items[-limit:]
162+
return self._items.copy()
129163

130164
async def add_items(self, items: List[TResponseInputItem]) -> None:
131165
"""Store new items for this session."""
132-
# Your implementation here
133-
current_items = await self.get_items() or []
134-
self._ctx().set("items", current_items + items)
166+
if self._items is None:
167+
self._items = await self._ctx().get("items") or []
168+
self._items.extend(items)
135169

136170
async def pop_item(self) -> TResponseInputItem | None:
137171
"""Remove and return the most recent item from this session."""
138-
current_items = await self.get_items() or []
139-
if current_items:
140-
item = current_items.pop()
141-
self._ctx().set("items", current_items)
142-
return item
172+
if self._items is None:
173+
self._items = await self._ctx().get("items") or []
174+
if self._items:
175+
return self._items.pop()
143176
return None
144177

178+
def flush(self) -> None:
179+
"""Flush the session items to the context."""
180+
self._ctx().set("items", self._items)
181+
145182
async def clear_session(self) -> None:
146183
"""Clear all items for this session."""
184+
self._items = []
147185
self._ctx().clear("items")
148186

149187

@@ -189,7 +227,7 @@ def continue_on_terminal_errors(context: RunContextWrapper[Any], error: Exceptio
189227
raise error
190228

191229

192-
class Runner:
230+
class DurableRunner:
193231
"""
194232
A wrapper around Runner.run that automatically configures RunConfig for Restate contexts.
195233
@@ -201,83 +239,55 @@ class Runner:
201239
@staticmethod
202240
async def run(
203241
starting_agent: Agent[TContext],
204-
disable_tool_autowrapping: bool = False,
205-
*args: typing.Any,
206-
run_config: RunConfig | None = None,
242+
input: str | list[TResponseInputItem],
243+
*,
244+
use_restate_session: bool = False,
207245
**kwargs,
208246
) -> RunResult:
209247
"""
210248
Run an agent with automatic Restate configuration.
211249
250+
Args:
251+
use_restate_session: If True, creates a RestateSession for conversation persistence.
252+
Requires running within a Restate Virtual Object context.
253+
212254
Returns:
213255
The result from Runner.run
214256
"""
215257

216-
current_run_config = run_config or RunConfig()
217-
new_run_config = dataclasses.replace(
218-
current_run_config,
219-
model_provider=DurableModelCalls(),
220-
)
221-
restate_agent = sequentialize_and_wrap_tools(starting_agent, disable_tool_autowrapping)
222-
return await OpenAIRunner.run(restate_agent, *args, run_config=new_run_config, **kwargs)
223-
224-
225-
def sequentialize_and_wrap_tools(
226-
agent: Agent[TContext],
227-
disable_tool_autowrapping: bool,
228-
) -> Agent[TContext]:
229-
"""
230-
Wrap the tools of an agent to use the Restate error handling.
231-
232-
Returns:
233-
A new agent with wrapped tools.
234-
"""
258+
# Set persisting model calls
259+
llm_retry_opts = kwargs.pop("llm_retry_opts", None)
260+
run_config = kwargs.pop("run_config", RunConfig())
261+
run_config = dataclasses.replace(run_config, model_provider=DurableModelCalls(llm_retry_opts))
235262

236-
# Restate does not allow parallel tool calls, so we use a lock to ensure sequential execution.
237-
# This lock only affects tools for this agent; handoff agents are wrapped recursively.
238-
sequential_tools_lock = asyncio.Lock()
239-
wrapped_tools: list[Tool] = []
240-
for tool in agent.tools:
241-
if isinstance(tool, FunctionTool):
242-
243-
def create_wrapper(captured_tool):
244-
async def on_invoke_tool_wrapper(tool_context: ToolContext[Any], tool_input: Any) -> Any:
245-
await sequential_tools_lock.acquire()
246-
247-
async def invoke():
248-
result = await captured_tool.on_invoke_tool(tool_context, tool_input)
249-
# Ensure Pydantic objects are serialized to dict for LLM compatibility
250-
if hasattr(result, "model_dump"):
251-
return result.model_dump()
252-
elif hasattr(result, "dict"):
253-
return result.dict()
254-
return result
255-
256-
try:
257-
if disable_tool_autowrapping:
258-
return await invoke()
259-
260-
ctx = current_context()
261-
if ctx is None:
262-
raise RuntimeError(
263-
"No current Restate context found, make sure to run inside a Restate handler"
264-
)
265-
return await ctx.run_typed(captured_tool.name, invoke)
266-
finally:
267-
sequential_tools_lock.release()
268-
269-
return on_invoke_tool_wrapper
270-
271-
wrapped_tools.append(dataclasses.replace(tool, on_invoke_tool=create_wrapper(tool)))
263+
# Disable parallel tool calls
264+
model_settings = run_config.model_settings
265+
if model_settings is None:
266+
model_settings = ModelSettings(parallel_tool_calls=False)
272267
else:
273-
wrapped_tools.append(tool)
268+
model_settings = dataclasses.replace(
269+
model_settings,
270+
parallel_tool_calls=False,
271+
)
272+
run_config = dataclasses.replace(
273+
run_config,
274+
model_settings=model_settings,
275+
)
276+
277+
# Use Restate session if requested, otherwise use provided session
278+
session = kwargs.pop("session", None)
279+
if use_restate_session:
280+
if session is not None:
281+
raise TerminalError("When use_restate_session is True, session config cannot be provided.")
282+
session = RestateSession()
274283

275-
handoffs_with_wrapped_tools = []
276-
for handoff in agent.handoffs:
277-
# recursively wrap tools in handoff agents
278-
handoffs_with_wrapped_tools.append(sequentialize_and_wrap_tools(handoff, disable_tool_autowrapping)) # type: ignore
284+
try:
285+
result = await Runner.run(
286+
starting_agent=starting_agent, input=input, run_config=run_config, session=session, **kwargs
287+
)
288+
finally:
289+
# Flush session items to Restate
290+
if session is not None and isinstance(session, RestateSession):
291+
session.flush()
279292

280-
return agent.clone(
281-
tools=wrapped_tools,
282-
handoffs=handoffs_with_wrapped_tools,
283-
)
293+
return result

0 commit comments

Comments
 (0)