Skip to content

Commit 76bdd71

Browse files
committed
Merge remote-tracking branch 'origin/main' into openai/test_features
2 parents d296b00 + 8b727e5 commit 76bdd71

File tree

7 files changed

+170
-42
lines changed

7 files changed

+170
-42
lines changed

temporalio/contrib/openai_agents/__init__.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,14 @@
99
"""
1010

1111
from temporalio.contrib.openai_agents._model_parameters import ModelActivityParameters
12-
from temporalio.contrib.openai_agents.temporal_openai_agents import (
12+
from temporalio.contrib.openai_agents._temporal_openai_agents import (
1313
OpenAIAgentsPlugin,
1414
TestModel,
1515
TestModelProvider,
1616
)
17+
from temporalio.contrib.openai_agents._trace_interceptor import (
18+
OpenAIAgentsTracingInterceptor,
19+
)
1720

1821
from . import workflow
1922

temporalio/contrib/openai_agents/_invoke_model_activity.py

Lines changed: 61 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@
44
"""
55

66
import enum
7-
import json
87
from dataclasses import dataclass
8+
from datetime import timedelta
99
from typing import Any, Optional, Union, cast
1010

1111
from agents import (
@@ -20,19 +20,23 @@
2020
ModelResponse,
2121
ModelSettings,
2222
ModelTracing,
23+
OpenAIProvider,
2324
RunContextWrapper,
2425
Tool,
2526
TResponseInputItem,
2627
UserError,
2728
WebSearchTool,
2829
)
29-
from agents.models.multi_provider import MultiProvider
30+
from openai import (
31+
APIStatusError,
32+
AsyncOpenAI,
33+
)
3034
from openai.types.responses.tool_param import Mcp
31-
from pydantic_core import to_json, to_jsonable_python
3235
from typing_extensions import Required, TypedDict
3336

3437
from temporalio import activity
3538
from temporalio.contrib.openai_agents._heartbeat_decorator import _auto_heartbeater
39+
from temporalio.exceptions import ApplicationError
3640

3741

3842
@dataclass
@@ -136,11 +140,15 @@ class ActivityModelInput(TypedDict, total=False):
136140

137141

138142
class ModelActivity:
139-
"""Class wrapper for model invocation activities to allow model customization."""
143+
"""Class wrapper for model invocation activities to allow model customization. By default, we use an OpenAIProvider with retries disabled.
144+
Disabling retries in your model of choice is recommended to allow activity retries to define the retry model.
145+
"""
140146

141147
def __init__(self, model_provider: Optional[ModelProvider] = None):
142148
"""Initialize the activity with a model provider."""
143-
self._model_provider = model_provider or MultiProvider()
149+
self._model_provider = model_provider or OpenAIProvider(
150+
openai_client=AsyncOpenAI(max_retries=0)
151+
)
144152

145153
@activity.defn
146154
@_auto_heartbeater
@@ -194,14 +202,51 @@ def make_tool(tool: ToolInput) -> Tool:
194202
)
195203
for x in input.get("handoffs", [])
196204
]
197-
return await model.get_response(
198-
system_instructions=input.get("system_instructions"),
199-
input=input["input"],
200-
model_settings=input["model_settings"],
201-
tools=tools,
202-
output_schema=input.get("output_schema"),
203-
handoffs=handoffs,
204-
tracing=ModelTracing(input["tracing"]),
205-
previous_response_id=input.get("previous_response_id"),
206-
prompt=input.get("prompt"),
207-
)
205+
206+
try:
207+
return await model.get_response(
208+
system_instructions=input.get("system_instructions"),
209+
input=input["input"],
210+
model_settings=input["model_settings"],
211+
tools=tools,
212+
output_schema=input.get("output_schema"),
213+
handoffs=handoffs,
214+
tracing=ModelTracing(input["tracing"]),
215+
previous_response_id=input.get("previous_response_id"),
216+
prompt=input.get("prompt"),
217+
)
218+
except APIStatusError as e:
219+
# Listen to server hints
220+
retry_after = None
221+
retry_after_ms_header = e.response.headers.get("retry-after-ms")
222+
if retry_after_ms_header is not None:
223+
retry_after = timedelta(milliseconds=float(retry_after_ms_header))
224+
225+
if retry_after is None:
226+
retry_after_header = e.response.headers.get("retry-after")
227+
if retry_after_header is not None:
228+
retry_after = timedelta(seconds=float(retry_after_header))
229+
230+
should_retry_header = e.response.headers.get("x-should-retry")
231+
if should_retry_header == "true":
232+
raise e
233+
if should_retry_header == "false":
234+
raise ApplicationError(
235+
"Non retryable OpenAI error",
236+
non_retryable=True,
237+
next_retry_delay=retry_after,
238+
) from e
239+
240+
# Specifically retryable status codes
241+
if e.response.status_code in [408, 409, 429, 500]:
242+
raise ApplicationError(
243+
"Retryable OpenAI status code",
244+
non_retryable=False,
245+
next_retry_delay=retry_after,
246+
) from e
247+
248+
raise ApplicationError(
249+
"Non retryable OpenAI status code",
250+
non_retryable=True,
251+
next_retry_delay=retry_after,
252+
) from e

temporalio/contrib/openai_agents/_model_parameters.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,13 +20,13 @@ class ModelActivityParameters:
2020
task_queue: Optional[str] = None
2121
"""Specific task queue to use for model activities."""
2222

23-
schedule_to_close_timeout: Optional[timedelta] = timedelta(seconds=60)
23+
schedule_to_close_timeout: Optional[timedelta] = None
2424
"""Maximum time from scheduling to completion."""
2525

2626
schedule_to_start_timeout: Optional[timedelta] = None
2727
"""Maximum time from scheduling to starting."""
2828

29-
start_to_close_timeout: Optional[timedelta] = None
29+
start_to_close_timeout: Optional[timedelta] = timedelta(seconds=60)
3030
"""Maximum time for the activity to complete."""
3131

3232
heartbeat_timeout: Optional[timedelta] = None

temporalio/contrib/openai_agents/_openai_runner.py

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import typing
12
from dataclasses import replace
23
from typing import Any, Union
34

@@ -7,6 +8,7 @@
78
RunResult,
89
RunResultStreaming,
910
TContext,
11+
Tool,
1012
TResponseInputItem,
1113
)
1214
from agents.run import DEFAULT_AGENT_RUNNER, DEFAULT_MAX_TURNS, AgentRunner
@@ -42,6 +44,13 @@ async def run(
4244
**kwargs,
4345
)
4446

47+
tool_types = typing.get_args(Tool)
48+
for t in starting_agent.tools:
49+
if not isinstance(t, tool_types):
50+
raise ValueError(
51+
"Provided tool is not a tool type. If using an activity, make sure to wrap it with openai_agents.workflow.activity_as_tool."
52+
)
53+
4554
context = kwargs.get("context")
4655
max_turns = kwargs.get("max_turns", DEFAULT_MAX_TURNS)
4756
hooks = kwargs.get("hooks")
@@ -63,16 +72,15 @@ async def run(
6372
),
6473
)
6574

66-
with workflow.unsafe.imports_passed_through():
67-
return await self._runner.run(
68-
starting_agent=starting_agent,
69-
input=input,
70-
context=context,
71-
max_turns=max_turns,
72-
hooks=hooks,
73-
run_config=updated_run_config,
74-
previous_response_id=previous_response_id,
75-
)
75+
return await self._runner.run(
76+
starting_agent=starting_agent,
77+
input=input,
78+
context=context,
79+
max_turns=max_turns,
80+
hooks=hooks,
81+
run_config=updated_run_config,
82+
previous_response_id=previous_response_id,
83+
)
7684

7785
def run_sync(
7886
self,

temporalio/contrib/openai_agents/temporal_openai_agents.py renamed to temporalio/contrib/openai_agents/_temporal_openai_agents.py

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""Initialize Temporal OpenAI Agents overrides."""
22

33
from contextlib import contextmanager
4+
from datetime import timedelta
45
from typing import AsyncIterator, Callable, Optional, Union
56

67
from agents import (
@@ -39,7 +40,7 @@
3940

4041
@contextmanager
4142
def set_open_ai_agent_temporal_overrides(
42-
model_params: Optional[ModelActivityParameters] = None,
43+
model_params: ModelActivityParameters,
4344
auto_close_tracing_in_workflows: bool = False,
4445
):
4546
"""Configure Temporal-specific overrides for OpenAI agents.
@@ -69,14 +70,6 @@ def set_open_ai_agent_temporal_overrides(
6970
if model_params is None:
7071
model_params = ModelActivityParameters()
7172

72-
if (
73-
not model_params.start_to_close_timeout
74-
and not model_params.schedule_to_close_timeout
75-
):
76-
raise ValueError(
77-
"Activity must have start_to_close_timeout or schedule_to_close_timeout"
78-
)
79-
8073
previous_runner = get_default_agent_runner()
8174
previous_trace_provider = get_trace_provider()
8275
provider = TemporalTraceProvider(
@@ -208,6 +201,22 @@ def __init__(
208201
model_provider: Optional model provider for custom model implementations.
209202
Useful for testing or custom model integrations.
210203
"""
204+
if model_params is None:
205+
model_params = ModelActivityParameters()
206+
207+
# For the default provider, we provide a default start_to_close_timeout of 60 seconds.
208+
# Other providers will need to define their own.
209+
if (
210+
model_params.start_to_close_timeout is None
211+
and model_params.schedule_to_close_timeout is None
212+
):
213+
if model_provider is None:
214+
model_params.start_to_close_timeout = timedelta(seconds=60)
215+
else:
216+
raise ValueError(
217+
"When configuring a custom provider, the model activity must have start_to_close_timeout or schedule_to_close_timeout"
218+
)
219+
211220
self._model_params = model_params
212221
self._model_provider = model_provider
213222

tests/contrib/openai_agents/test_openai.py

Lines changed: 64 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@
4545
ToolCallItem,
4646
ToolCallOutputItem,
4747
)
48-
from openai import AsyncOpenAI, BaseModel
48+
from openai import APIStatusError, AsyncOpenAI, BaseModel
4949
from openai.types.responses import (
5050
ResponseCodeInterpreterToolCall,
5151
ResponseFileSearchToolCall,
@@ -60,16 +60,18 @@
6060
from openai.types.responses.response_prompt_param import ResponsePromptParam
6161
from pydantic import ConfigDict, Field, TypeAdapter
6262

63+
import temporalio.api.cloud.namespace.v1
6364
from temporalio import activity, workflow
6465
from temporalio.client import Client, WorkflowFailureError, WorkflowHandle
66+
from temporalio.common import RetryPolicy, SearchAttributeValueType
6567
from temporalio.contrib import openai_agents
6668
from temporalio.contrib.openai_agents import (
6769
ModelActivityParameters,
6870
TestModel,
6971
TestModelProvider,
7072
)
7173
from temporalio.contrib.pydantic import pydantic_data_converter
72-
from temporalio.exceptions import CancelledError
74+
from temporalio.exceptions import ApplicationError, CancelledError
7375
from temporalio.testing import WorkflowEnvironment
7476
from tests.contrib.openai_agents.research_agents.research_manager import (
7577
ResearchManager,
@@ -1791,6 +1793,66 @@ async def test_response_serialization():
17911793
encoded = await pydantic_data_converter.encode([model_response])
17921794

17931795

1796+
async def assert_status_retry_behavior(status: int, client: Client, should_retry: bool):
1797+
def status_error(status: int):
1798+
with workflow.unsafe.imports_passed_through():
1799+
with workflow.unsafe.sandbox_unrestricted():
1800+
import httpx
1801+
raise APIStatusError(
1802+
message="Something went wrong.",
1803+
response=httpx.Response(
1804+
status_code=status, request=httpx.Request("GET", url="")
1805+
),
1806+
body=None,
1807+
)
1808+
1809+
new_config = client.config()
1810+
new_config["plugins"] = [
1811+
openai_agents.OpenAIAgentsPlugin(
1812+
model_params=ModelActivityParameters(
1813+
retry_policy=RetryPolicy(maximum_attempts=2),
1814+
),
1815+
model_provider=TestModelProvider(TestModel(lambda: status_error(status))),
1816+
)
1817+
]
1818+
client = Client(**new_config)
1819+
1820+
async with new_worker(
1821+
client,
1822+
HelloWorldAgent,
1823+
) as worker:
1824+
workflow_handle = await client.start_workflow(
1825+
HelloWorldAgent.run,
1826+
"Input",
1827+
id=f"workflow-tool-{uuid.uuid4()}",
1828+
task_queue=worker.task_queue,
1829+
execution_timeout=timedelta(seconds=10),
1830+
)
1831+
with pytest.raises(WorkflowFailureError) as e:
1832+
await workflow_handle.result()
1833+
1834+
found = False
1835+
async for event in workflow_handle.fetch_history_events():
1836+
if event.HasField("activity_task_started_event_attributes"):
1837+
found = True
1838+
if should_retry:
1839+
assert event.activity_task_started_event_attributes.attempt == 2
1840+
else:
1841+
assert event.activity_task_started_event_attributes.attempt == 1
1842+
assert found
1843+
1844+
1845+
async def test_exception_handling(client: Client):
1846+
await assert_status_retry_behavior(408, client, should_retry=True)
1847+
await assert_status_retry_behavior(409, client, should_retry=True)
1848+
await assert_status_retry_behavior(429, client, should_retry=True)
1849+
await assert_status_retry_behavior(500, client, should_retry=True)
1850+
1851+
await assert_status_retry_behavior(400, client, should_retry=False)
1852+
await assert_status_retry_behavior(403, client, should_retry=False)
1853+
await assert_status_retry_behavior(404, client, should_retry=False)
1854+
1855+
17941856
async def test_lite_llm(client: Client):
17951857
if not os.environ.get("OPENAI_API_KEY"):
17961858
pytest.skip("No openai API key")

tests/contrib/openai_agents/test_openai_replay.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,8 @@
33
import pytest
44

55
from temporalio.client import WorkflowHistory
6-
from temporalio.contrib.openai_agents.temporal_openai_agents import (
6+
from temporalio.contrib.openai_agents import ModelActivityParameters
7+
from temporalio.contrib.openai_agents._temporal_openai_agents import (
78
set_open_ai_agent_temporal_overrides,
89
)
910
from temporalio.contrib.pydantic import pydantic_data_converter
@@ -35,7 +36,7 @@ async def test_replay(file_name: str) -> None:
3536
with (Path(__file__).with_name("histories") / file_name).open("r") as f:
3637
history_json = f.read()
3738

38-
with set_open_ai_agent_temporal_overrides():
39+
with set_open_ai_agent_temporal_overrides(ModelActivityParameters()):
3940
await Replayer(
4041
workflows=[
4142
ResearchWorkflow,

0 commit comments

Comments
 (0)