Commit c7b5fe9

Integrate turnstile into OpenAI ext
1 parent cb252f0 commit c7b5fe9

4 files changed: +273 -98 lines changed
python/restate/ext/openai/models.py

Lines changed: 79 additions & 0 deletions
@@ -0,0 +1,79 @@
+#
+# Copyright (c) 2023-2024 - Restate Software, Inc., Restate GmbH
+#
+# This file is part of the Restate SDK for Python,
+# which is released under the MIT license.
+#
+# You can find a copy of the license in file LICENSE in the root
+# directory of this repository or package, or at
+# https://github.com/restatedev/sdk-typescript/blob/main/LICENSE
+#
+"""
+This module contains the optional OpenAI integration for Restate.
+"""
+
+import dataclasses
+
+from agents import (
+    Usage,
+)
+from agents.items import TResponseOutputItem
+from agents.items import TResponseInputItem
+from datetime import timedelta
+from typing import Optional
+from pydantic import BaseModel
+
+from restate.ext.turnstile import Turnstile
+
+
+class State:
+    __slots__ = ("turnstile",)
+
+    def __init__(self) -> None:
+        self.turnstile = Turnstile([])
+
+
+@dataclasses.dataclass
+class LlmRetryOpts:
+    max_attempts: Optional[int] = 10
+    """Max number of attempts (including the initial one) before giving up.
+
+    When giving up, the LLM call will throw a `TerminalError` wrapping the original error message."""
+    max_duration: Optional[timedelta] = None
+    """Max duration of retries before giving up.
+
+    When giving up, the LLM call will throw a `TerminalError` wrapping the original error message."""
+    initial_retry_interval: Optional[timedelta] = timedelta(seconds=1)
+    """Initial interval for the first retry attempt.
+    The retry interval grows by the factor specified in `retry_interval_factor`.
+
+    If any of the other retry-related fields is specified, the default for this field is 50 milliseconds; otherwise Restate falls back to the overall invocation retry policy."""
+    max_retry_interval: Optional[timedelta] = None
+    """Max interval between retries.
+    The retry interval grows by the factor specified in `retry_interval_factor`.
+
+    The default is 10 seconds."""
+    retry_interval_factor: Optional[float] = None
+    """Exponentiation factor to use when computing the next retry delay.
+
+    If any of the other retry-related fields is specified, the default for this field is `2`, meaning the retry interval doubles at each attempt; otherwise Restate falls back to the overall invocation retry policy."""
+
+
+# The OpenAI ModelResponse class is a dataclass with Pydantic fields.
+# The Restate SDK cannot serialize this, so we turn the ModelResponse into a Pydantic model.
+class RestateModelResponse(BaseModel):
+    output: list[TResponseOutputItem]
+    """A list of outputs (messages, tool calls, etc.) generated by the model."""
+
+    usage: Usage
+    """The usage information for the response."""
+
+    response_id: str | None
+    """An ID for the response which can be used to refer to the response in subsequent calls to the
+    model. Not supported by all model providers.
+    If using OpenAI models via the Responses API, this is the `response_id` parameter, and it can
+    be passed to `Runner.run`.
+    """
+
+    def to_input_items(self) -> list[TResponseInputItem]:
+        return [it.model_dump(exclude_unset=True) for it in self.output]  # type: ignore
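A rough usage sketch of the retry options above. The values are hypothetical; the fields come from the `LlmRetryOpts` dataclass in this file, and the `run()` wrapper in `runner_wrapper.py` below pops the instance from an `llm_retry_opts` keyword argument. The import path is assumed from this commit's file layout.

from datetime import timedelta

from restate.ext.openai.models import LlmRetryOpts  # path assumed from this commit

# Retry each LLM call up to 5 times, starting at 500 ms and doubling per attempt,
# capped at 10 s between attempts and 2 minutes of retrying overall.
retry_opts = LlmRetryOpts(
    max_attempts=5,
    initial_retry_interval=timedelta(milliseconds=500),
    retry_interval_factor=2.0,
    max_retry_interval=timedelta(seconds=10),
    max_duration=timedelta(minutes=2),
)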

python/restate/ext/openai/runner_wrapper.py

Lines changed: 23 additions & 98 deletions
@@ -15,7 +15,6 @@
import dataclasses

from agents import (
-    Usage,
    Model,
    RunContextWrapper,
    AgentsException,
@@ -28,84 +27,43 @@
    Runner,
)
from agents.models.multi_provider import MultiProvider
-from agents.items import TResponseStreamEvent, TResponseOutputItem, ModelResponse
-from agents.memory.session import SessionABC
+from agents.items import TResponseStreamEvent, ModelResponse
from agents.items import TResponseInputItem
-from datetime import timedelta
-from typing import List, Any, AsyncIterator, Optional, cast
-from pydantic import BaseModel
+from typing import Any, AsyncIterator

from restate.exceptions import SdkInternalBaseException
+from restate.ext.turnstile import Turnstile
from restate.extensions import current_context
-from restate import RunOptions, ObjectContext, TerminalError
+from restate import RunOptions, TerminalError

-
-@dataclasses.dataclass
-class LlmRetryOpts:
-    max_attempts: Optional[int] = 10
-    """Max number of attempts (including the initial), before giving up.
-
-    When giving up, the LLM call will throw a `TerminalError` wrapping the original error message."""
-    max_duration: Optional[timedelta] = None
-    """Max duration of retries, before giving up.
-
-    When giving up, the LLM call will throw a `TerminalError` wrapping the original error message."""
-    initial_retry_interval: Optional[timedelta] = timedelta(seconds=1)
-    """Initial interval for the first retry attempt.
-    Retry interval will grow by a factor specified in `retry_interval_factor`.
-
-    If any of the other retry related fields is specified, the default for this field is 50 milliseconds, otherwise restate will fallback to the overall invocation retry policy."""
-    max_retry_interval: Optional[timedelta] = None
-    """Max interval between retries.
-    Retry interval will grow by a factor specified in `retry_interval_factor`.
-
-    The default is 10 seconds."""
-    retry_interval_factor: Optional[float] = None
-    """Exponentiation factor to use when computing the next retry delay.
-
-    If any of the other retry related fields is specified, the default for this field is `2`, meaning retry interval will double at each attempt, otherwise restate will fallback to the overall invocation retry policy."""
-
-
-# The OpenAI ModelResponse class is a dataclass with Pydantic fields.
-# The Restate SDK cannot serialize this. So we turn the ModelResponse int a Pydantic model.
-class RestateModelResponse(BaseModel):
-    output: list[TResponseOutputItem]
-    """A list of outputs (messages, tool calls, etc) generated by the model"""
-
-    usage: Usage
-    """The usage information for the response."""
-
-    response_id: str | None
-    """An ID for the response which can be used to refer to the response in subsequent calls to the
-    model. Not supported by all model providers.
-    If using OpenAI models via the Responses API, this is the `response_id` parameter, and it can
-    be passed to `Runner.run`.
-    """
-
-    def to_input_items(self) -> list[TResponseInputItem]:
-        return [it.model_dump(exclude_unset=True) for it in self.output]  # type: ignore
+from .utils import get_function_call_ids, wrap_agent_tools
+from .models import LlmRetryOpts, RestateModelResponse, State
+from .session import RestateSession


class DurableModelCalls(MultiProvider):
    """
    A Restate model provider that wraps the OpenAI SDK's default MultiProvider.
    """

-    def __init__(self, llm_retry_opts: LlmRetryOpts | None = None):
+    def __init__(self, state: State, llm_retry_opts: LlmRetryOpts | None = None):
        super().__init__()
        self.llm_retry_opts = llm_retry_opts
+        self.state = state

    def get_model(self, model_name: str | None) -> Model:
-        return RestateModelWrapper(super().get_model(model_name or None), self.llm_retry_opts)
+        model = super().get_model(model_name or None)
+        return RestateModelWrapper(model, self.state, self.llm_retry_opts)


class RestateModelWrapper(Model):
    """
    A wrapper around the OpenAI SDK's Model that persists LLM calls in the Restate journal.
    """

-    def __init__(self, model: Model, llm_retry_opts: LlmRetryOpts | None = None):
+    def __init__(self, model: Model, state: State, llm_retry_opts: LlmRetryOpts | None = None):
        self.model = model
+        self.state = state
        self.model_name = "RestateModelWrapper"
        self.llm_retry_opts = llm_retry_opts if llm_retry_opts is not None else LlmRetryOpts()
@@ -133,6 +91,10 @@ async def call_llm() -> RestateModelResponse:
                retry_interval_factor=self.llm_retry_opts.retry_interval_factor,
            ),
        )
+        # collect function call IDs to initialize the turnstile for the generated tool calls
+        ids = get_function_call_ids(result.output)
+        self.state.turnstile = Turnstile(ids)
+
        # convert back to original ModelResponse
        return ModelResponse(
            output=result.output,
@@ -144,47 +106,6 @@ def stream_response(self, *args, **kwargs) -> AsyncIterator[TResponseStreamEvent
        raise TerminalError("Streaming is not supported in Restate. Use `get_response` instead.")


-class RestateSession(SessionABC):
-    """Restate session implementation following the Session protocol."""
-
-    def __init__(self):
-        self._items: List[TResponseInputItem] | None = None
-
-    def _ctx(self) -> ObjectContext:
-        return cast(ObjectContext, current_context())
-
-    async def get_items(self, limit: int | None = None) -> List[TResponseInputItem]:
-        """Retrieve conversation history for this session."""
-        if self._items is None:
-            self._items = await self._ctx().get("items") or []
-        if limit is not None:
-            return self._items[-limit:]
-        return self._items.copy()
-
-    async def add_items(self, items: List[TResponseInputItem]) -> None:
-        """Store new items for this session."""
-        if self._items is None:
-            self._items = await self._ctx().get("items") or []
-        self._items.extend(items)
-
-    async def pop_item(self) -> TResponseInputItem | None:
-        """Remove and return the most recent item from this session."""
-        if self._items is None:
-            self._items = await self._ctx().get("items") or []
-        if self._items:
-            return self._items.pop()
-        return None
-
-    def flush(self) -> None:
-        """Flush the session items to the context."""
-        self._ctx().set("items", self._items)
-
-    async def clear_session(self) -> None:
-        """Clear all items for this session."""
-        self._items = []
-        self._ctx().clear("items")
-
-
class AgentsTerminalException(AgentsException, TerminalError):
    """Exception that is both an AgentsException and a restate.TerminalError."""
@@ -255,10 +176,13 @@ async def run(
        The result from Runner.run
    """

+    # execution state
+    state = State()
+
    # Set persisting model calls
    llm_retry_opts = kwargs.pop("llm_retry_opts", None)
    run_config = kwargs.pop("run_config", RunConfig())
-    run_config = dataclasses.replace(run_config, model_provider=DurableModelCalls(llm_retry_opts))
+    run_config = dataclasses.replace(run_config, model_provider=DurableModelCalls(state, llm_retry_opts))

    # Disable parallel tool calls
    model_settings = run_config.model_settings
@@ -281,9 +205,10 @@ async def run(
            raise TerminalError("When use_restate_session is True, session config cannot be provided.")
        session = RestateSession()

+    agent = wrap_agent_tools(starting_agent, state)
    try:
        result = await Runner.run(
-            starting_agent=starting_agent, input=input, run_config=run_config, session=session, **kwargs
+            starting_agent=agent, input=input, run_config=run_config, session=session, **kwargs
        )
    finally:
        # Flush session items to Restate
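For orientation, a condensed sketch of the wiring this file performs: a fresh per-invocation `State` (holding the turnstile) is threaded into `DurableModelCalls`, which is swapped into the `RunConfig` so every model call goes through the journal-persisting `RestateModelWrapper`. Only names from the diffs above are used; the `.models` import path is assumed from this commit's layout and the values are illustrative.

import dataclasses

from agents import RunConfig

from restate.ext.openai.runner_wrapper import DurableModelCalls
from restate.ext.openai.models import LlmRetryOpts, State  # path assumed from this commit

# Per-invocation execution state; its turnstile is rebuilt after each model response.
state = State()

# Replace the model provider so get_model() returns the journaling wrapper.
run_config = dataclasses.replace(
    RunConfig(),
    model_provider=DurableModelCalls(state, LlmRetryOpts(max_attempts=5)),
)
# The run() wrapper above then passes this run_config (plus a RestateSession) to Runner.run().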
python/restate/ext/openai/session.py

Lines changed: 61 additions & 0 deletions
@@ -0,0 +1,61 @@
+#
+# Copyright (c) 2023-2025 - Restate Software, Inc., Restate GmbH
+#
+# This file is part of the Restate SDK for Python,
+# which is released under the MIT license.
+#
+# You can find a copy of the license in file LICENSE in the root
+# directory of this repository or package, or at
+# https://github.com/restatedev/sdk-typescript/blob/main/LICENSE
+#
+"""
+This module contains the optional OpenAI integration for Restate.
+"""
+
+from agents.memory.session import SessionABC
+from agents.items import TResponseInputItem
+from typing import List, cast
+
+from restate.extensions import current_context
+from restate import ObjectContext
+
+
+class RestateSession(SessionABC):
+    """Restate session implementation following the Session protocol."""
+
+    def __init__(self):
+        self._items: List[TResponseInputItem] | None = None
+
+    def _ctx(self) -> ObjectContext:
+        return cast(ObjectContext, current_context())
+
+    async def get_items(self, limit: int | None = None) -> List[TResponseInputItem]:
+        """Retrieve conversation history for this session."""
+        if self._items is None:
+            self._items = await self._ctx().get("items") or []
+        if limit is not None:
+            return self._items[-limit:]
+        return self._items.copy()
+
+    async def add_items(self, items: List[TResponseInputItem]) -> None:
+        """Store new items for this session."""
+        if self._items is None:
+            self._items = await self._ctx().get("items") or []
+        self._items.extend(items)
+
+    async def pop_item(self) -> TResponseInputItem | None:
+        """Remove and return the most recent item from this session."""
+        if self._items is None:
+            self._items = await self._ctx().get("items") or []
+        if self._items:
+            return self._items.pop()
+        return None
+
+    def flush(self) -> None:
+        """Flush the session items to the context."""
+        self._ctx().set("items", self._items)
+
+    async def clear_session(self) -> None:
+        """Clear all items for this session."""
+        self._items = []
+        self._ctx().clear("items")
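A brief, illustrative sketch of the session's lazy-load and flush cycle as the `run()` wrapper above drives it. It assumes the code executes inside a Restate handler so that `current_context()` resolves to an `ObjectContext`; the import path and the helper function are hypothetical, while the item dict follows the Agents SDK input-item shape.

from agents.items import TResponseInputItem

from restate.ext.openai.session import RestateSession  # path assumed from this commit


async def remember(user_text: str) -> list[TResponseInputItem]:
    # Must run inside a Restate handler, so current_context() returns the ObjectContext.
    session = RestateSession()
    await session.add_items([{"role": "user", "content": user_text}])  # loads "items" lazily, then appends
    session.flush()  # persists the in-memory items back via ctx.set("items", ...)
    return await session.get_items()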
