Skip to content

Commit bd6a4f8

Browse files
committed
Merge remote-tracking branch 'origin/main' into openai/heartbeat
2 parents 8b94f6d + 8b727e5 commit bd6a4f8

26 files changed

+1348
-866
lines changed

pyproject.toml

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ opentelemetry = [
2626
]
2727
pydantic = ["pydantic>=2.0.0,<3"]
2828
openai-agents = [
29-
"openai-agents >= 0.1,<0.2",
29+
"openai-agents >= 0.2.3,<0.3",
3030
"eval-type-backport>=0.2.2; python_version < '3.10'"
3131
]
3232

@@ -165,6 +165,7 @@ reportAny = "none"
165165
reportCallInDefaultInitializer = "none"
166166
reportExplicitAny = "none"
167167
reportIgnoreCommentWithoutRule = "none"
168+
reportImplicitAbstractClass = "none"
168169
reportImplicitOverride = "none"
169170
reportImplicitStringConcatenation = "none"
170171
reportImportCycles = "none"
@@ -184,11 +185,6 @@ exclude = [
184185
"temporalio/bridge/proto",
185186
"tests/worker/workflow_sandbox/testmodules/proto",
186187
"temporalio/bridge/worker.py",
187-
"temporalio/contrib/opentelemetry.py",
188-
"temporalio/contrib/pydantic.py",
189-
"temporalio/converter.py",
190-
"temporalio/testing/_workflow.py",
191-
"temporalio/worker/_activity.py",
192188
"temporalio/worker/_replayer.py",
193189
"temporalio/worker/_worker.py",
194190
"temporalio/worker/workflow_sandbox/_importer.py",
@@ -203,9 +199,7 @@ exclude = [
203199
"tests/contrib/pydantic/workflows.py",
204200
"tests/test_converter.py",
205201
"tests/test_service.py",
206-
"tests/test_workflow.py",
207202
"tests/worker/test_activity.py",
208-
"tests/worker/test_workflow.py",
209203
"tests/worker/workflow_sandbox/test_importer.py",
210204
"tests/worker/workflow_sandbox/test_restrictions.py",
211205
# TODO: these pass locally but fail in CI with

temporalio/client.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2889,7 +2889,7 @@ def _from_raw_info(
28892889
cls,
28902890
info: temporalio.api.workflow.v1.WorkflowExecutionInfo,
28912891
converter: temporalio.converter.DataConverter,
2892-
**additional_fields,
2892+
**additional_fields: Any,
28932893
) -> WorkflowExecution:
28942894
return cls(
28952895
close_time=info.close_time.ToDatetime().replace(tzinfo=timezone.utc)

temporalio/contrib/openai_agents/__init__.py

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -8,25 +8,22 @@
88
Use with caution in production environments.
99
"""
1010

11-
from temporalio.contrib.openai_agents._invoke_model_activity import ModelActivity
1211
from temporalio.contrib.openai_agents._model_parameters import ModelActivityParameters
13-
from temporalio.contrib.openai_agents._trace_interceptor import (
14-
OpenAIAgentsTracingInterceptor,
15-
)
16-
from temporalio.contrib.openai_agents.temporal_openai_agents import (
12+
from temporalio.contrib.openai_agents._temporal_openai_agents import (
13+
OpenAIAgentsPlugin,
1714
TestModel,
1815
TestModelProvider,
19-
set_open_ai_agent_temporal_overrides,
16+
)
17+
from temporalio.contrib.openai_agents._trace_interceptor import (
18+
OpenAIAgentsTracingInterceptor,
2019
)
2120

2221
from . import workflow
2322

2423
__all__ = [
25-
"ModelActivity",
24+
"OpenAIAgentsPlugin",
2625
"ModelActivityParameters",
2726
"workflow",
28-
"set_open_ai_agent_temporal_overrides",
29-
"OpenAIAgentsTracingInterceptor",
3027
"TestModel",
3128
"TestModelProvider",
3229
]

temporalio/contrib/openai_agents/_heartbeat_decorator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
def _auto_heartbeater(fn: F) -> F:
1111
# Propagate type hints from the original callable.
1212
@wraps(fn)
13-
async def wrapper(*args, **kwargs):
13+
async def wrapper(*args: Any, **kwargs: Any) -> Any:
1414
heartbeat_timeout = activity.info().heartbeat_timeout
1515
heartbeat_task = None
1616
if heartbeat_timeout:

temporalio/contrib/openai_agents/_invoke_model_activity.py

Lines changed: 64 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import enum
77
import json
88
from dataclasses import dataclass
9+
from datetime import timedelta
910
from typing import Any, Optional, Union, cast
1011

1112
from agents import (
@@ -17,17 +18,25 @@
1718
ModelResponse,
1819
ModelSettings,
1920
ModelTracing,
21+
OpenAIProvider,
2022
RunContextWrapper,
2123
Tool,
2224
TResponseInputItem,
2325
UserError,
2426
WebSearchTool,
2527
)
2628
from agents.models.multi_provider import MultiProvider
29+
from openai import (
30+
APIStatusError,
31+
AsyncOpenAI,
32+
AuthenticationError,
33+
PermissionDeniedError,
34+
)
2735
from typing_extensions import Required, TypedDict
2836

2937
from temporalio import activity
3038
from temporalio.contrib.openai_agents._heartbeat_decorator import _auto_heartbeater
39+
from temporalio.exceptions import ApplicationError
3140

3241

3342
@dataclass
@@ -117,11 +126,15 @@ class ActivityModelInput(TypedDict, total=False):
117126

118127

119128
class ModelActivity:
120-
"""Class wrapper for model invocation activities to allow model customization."""
129+
"""Class wrapper for model invocation activities to allow model customization. By default, we use an OpenAIProvider with retries disabled.
130+
Disabling retries in your model of choice is recommended to allow activity retries to define the retry model.
131+
"""
121132

122133
def __init__(self, model_provider: Optional[ModelProvider] = None):
123134
"""Initialize the activity with a model provider."""
124-
self._model_provider = model_provider or MultiProvider()
135+
self._model_provider = model_provider or OpenAIProvider(
136+
openai_client=AsyncOpenAI(max_retries=0)
137+
)
125138

126139
@activity.defn
127140
@_auto_heartbeater
@@ -160,7 +173,7 @@ def make_tool(tool: ToolInput) -> Tool:
160173
raise UserError(f"Unknown tool type: {tool.name}")
161174

162175
tools = [make_tool(x) for x in input.get("tools", [])]
163-
handoffs = [
176+
handoffs: list[Handoff[Any, Any]] = [
164177
Handoff(
165178
tool_name=x.tool_name,
166179
tool_description=x.tool_description,
@@ -171,14 +184,51 @@ def make_tool(tool: ToolInput) -> Tool:
171184
)
172185
for x in input.get("handoffs", [])
173186
]
174-
return await model.get_response(
175-
system_instructions=input.get("system_instructions"),
176-
input=input_input,
177-
model_settings=input["model_settings"],
178-
tools=tools,
179-
output_schema=input.get("output_schema"),
180-
handoffs=handoffs,
181-
tracing=ModelTracing(input["tracing"]),
182-
previous_response_id=input.get("previous_response_id"),
183-
prompt=input.get("prompt"),
184-
)
187+
188+
try:
189+
return await model.get_response(
190+
system_instructions=input.get("system_instructions"),
191+
input=input_input,
192+
model_settings=input["model_settings"],
193+
tools=tools,
194+
output_schema=input.get("output_schema"),
195+
handoffs=handoffs,
196+
tracing=ModelTracing(input["tracing"]),
197+
previous_response_id=input.get("previous_response_id"),
198+
prompt=input.get("prompt"),
199+
)
200+
except APIStatusError as e:
201+
# Listen to server hints
202+
retry_after = None
203+
retry_after_ms_header = e.response.headers.get("retry-after-ms")
204+
if retry_after_ms_header is not None:
205+
retry_after = timedelta(milliseconds=float(retry_after_ms_header))
206+
207+
if retry_after is None:
208+
retry_after_header = e.response.headers.get("retry-after")
209+
if retry_after_header is not None:
210+
retry_after = timedelta(seconds=float(retry_after_header))
211+
212+
should_retry_header = e.response.headers.get("x-should-retry")
213+
if should_retry_header == "true":
214+
raise e
215+
if should_retry_header == "false":
216+
raise ApplicationError(
217+
"Non retryable OpenAI error",
218+
non_retryable=True,
219+
next_retry_delay=retry_after,
220+
) from e
221+
222+
# Specifically retryable status codes
223+
if e.response.status_code in [408, 409, 429, 500]:
224+
raise ApplicationError(
225+
"Retryable OpenAI status code",
226+
non_retryable=False,
227+
next_retry_delay=retry_after,
228+
) from e
229+
230+
raise ApplicationError(
231+
"Non retryable OpenAI status code",
232+
non_retryable=True,
233+
next_retry_delay=retry_after,
234+
) from e

temporalio/contrib/openai_agents/_model_parameters.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,13 +20,13 @@ class ModelActivityParameters:
2020
task_queue: Optional[str] = None
2121
"""Specific task queue to use for model activities."""
2222

23-
schedule_to_close_timeout: Optional[timedelta] = timedelta(seconds=60)
23+
schedule_to_close_timeout: Optional[timedelta] = None
2424
"""Maximum time from scheduling to completion."""
2525

2626
schedule_to_start_timeout: Optional[timedelta] = None
2727
"""Maximum time from scheduling to starting."""
2828

29-
start_to_close_timeout: Optional[timedelta] = None
29+
start_to_close_timeout: Optional[timedelta] = timedelta(seconds=60)
3030
"""Maximum time for the activity to complete."""
3131

3232
heartbeat_timeout: Optional[timedelta] = None

temporalio/contrib/openai_agents/_openai_runner.py

Lines changed: 22 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,21 @@
1+
import typing
12
from dataclasses import replace
2-
from datetime import timedelta
3-
from typing import Optional, Union
3+
from typing import Any, Union
44

55
from agents import (
66
Agent,
77
RunConfig,
8-
RunHooks,
98
RunResult,
109
RunResultStreaming,
1110
TContext,
11+
Tool,
1212
TResponseInputItem,
1313
)
1414
from agents.run import DEFAULT_AGENT_RUNNER, DEFAULT_MAX_TURNS, AgentRunner
1515

1616
from temporalio import workflow
17-
from temporalio.common import Priority, RetryPolicy
1817
from temporalio.contrib.openai_agents._model_parameters import ModelActivityParameters
1918
from temporalio.contrib.openai_agents._temporal_model_stub import _TemporalModelStub
20-
from temporalio.workflow import ActivityCancellationType, VersioningIntent
2119

2220

2321
class TemporalOpenAIRunner(AgentRunner):
@@ -36,7 +34,7 @@ async def run(
3634
self,
3735
starting_agent: Agent[TContext],
3836
input: Union[str, list[TResponseInputItem]],
39-
**kwargs,
37+
**kwargs: Any,
4038
) -> RunResult:
4139
"""Run the agent in a Temporal workflow."""
4240
if not workflow.in_workflow():
@@ -46,6 +44,13 @@ async def run(
4644
**kwargs,
4745
)
4846

47+
tool_types = typing.get_args(Tool)
48+
for t in starting_agent.tools:
49+
if not isinstance(t, tool_types):
50+
raise ValueError(
51+
"Provided tool is not a tool type. If using an activity, make sure to wrap it with openai_agents.workflow.activity_as_tool."
52+
)
53+
4954
context = kwargs.get("context")
5055
max_turns = kwargs.get("max_turns", DEFAULT_MAX_TURNS)
5156
hooks = kwargs.get("hooks")
@@ -67,22 +72,21 @@ async def run(
6772
),
6873
)
6974

70-
with workflow.unsafe.imports_passed_through():
71-
return await self._runner.run(
72-
starting_agent=starting_agent,
73-
input=input,
74-
context=context,
75-
max_turns=max_turns,
76-
hooks=hooks,
77-
run_config=updated_run_config,
78-
previous_response_id=previous_response_id,
79-
)
75+
return await self._runner.run(
76+
starting_agent=starting_agent,
77+
input=input,
78+
context=context,
79+
max_turns=max_turns,
80+
hooks=hooks,
81+
run_config=updated_run_config,
82+
previous_response_id=previous_response_id,
83+
)
8084

8185
def run_sync(
8286
self,
8387
starting_agent: Agent[TContext],
8488
input: Union[str, list[TResponseInputItem]],
85-
**kwargs,
89+
**kwargs: Any,
8690
) -> RunResult:
8791
"""Run the agent synchronously (not supported in Temporal workflows)."""
8892
if not workflow.in_workflow():
@@ -97,7 +101,7 @@ def run_streamed(
97101
self,
98102
starting_agent: Agent[TContext],
99103
input: Union[str, list[TResponseInputItem]],
100-
**kwargs,
104+
**kwargs: Any,
101105
) -> RunResultStreaming:
102106
"""Run the agent with streaming responses (not supported in Temporal workflows)."""
103107
if not workflow.in_workflow():

0 commit comments

Comments
 (0)