Skip to content

Commit 2004d95

Browse files
committed
Only set timeouts for default model provider
1 parent 48569ad commit 2004d95

File tree

3 files changed

+21
-11
lines changed

3 files changed

+21
-11
lines changed

temporalio/contrib/openai_agents/_temporal_openai_agents.py

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""Initialize Temporal OpenAI Agents overrides."""
22

33
from contextlib import contextmanager
4+
from datetime import timedelta
45
from typing import AsyncIterator, Callable, Optional, Union
56

67
from agents import (
@@ -39,7 +40,7 @@
3940

4041
@contextmanager
4142
def set_open_ai_agent_temporal_overrides(
42-
model_params: Optional[ModelActivityParameters] = None,
43+
model_params: ModelActivityParameters,
4344
auto_close_tracing_in_workflows: bool = False,
4445
):
4546
"""Configure Temporal-specific overrides for OpenAI agents.
@@ -69,14 +70,6 @@ def set_open_ai_agent_temporal_overrides(
6970
if model_params is None:
7071
model_params = ModelActivityParameters()
7172

72-
if (
73-
not model_params.start_to_close_timeout
74-
and not model_params.schedule_to_close_timeout
75-
):
76-
raise ValueError(
77-
"Activity must have start_to_close_timeout or schedule_to_close_timeout"
78-
)
79-
8073
previous_runner = get_default_agent_runner()
8174
previous_trace_provider = get_trace_provider()
8275
provider = TemporalTraceProvider(
@@ -208,6 +201,22 @@ def __init__(
208201
model_provider: Optional model provider for custom model implementations.
209202
Useful for testing or custom model integrations.
210203
"""
204+
if model_params is None:
205+
model_params = ModelActivityParameters()
206+
207+
# For the default provider, we provide a default start_to_close_timeout of 60 seconds.
208+
# Other providers will need to define their own.
209+
if (
210+
model_params.start_to_close_timeout is None
211+
and model_params.schedule_to_close_timeout is None
212+
):
213+
if model_provider is None:
214+
model_params.start_to_close_timeout = timedelta(seconds=60)
215+
else:
216+
raise ValueError(
217+
"When configuring a custom provider, the model activity must have start_to_close_timeout or schedule_to_close_timeout"
218+
)
219+
211220
self._model_params = model_params
212221
self._model_provider = model_provider
213222

tests/contrib/openai_agents/test_openai.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1800,7 +1800,7 @@ def status_error(status: int):
18001800
model_params=ModelActivityParameters(
18011801
retry_policy=RetryPolicy(maximum_attempts=2),
18021802
),
1803-
model_provider=TestModelProvider(TestModel(lambda: status_error(status)))
1803+
model_provider=TestModelProvider(TestModel(lambda: status_error(status))),
18041804
)
18051805
]
18061806
client = Client(**new_config)

tests/contrib/openai_agents/test_openai_replay.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import pytest
44

55
from temporalio.client import WorkflowHistory
6+
from temporalio.contrib.openai_agents import ModelActivityParameters
67
from temporalio.contrib.openai_agents._temporal_openai_agents import (
78
set_open_ai_agent_temporal_overrides,
89
)
@@ -35,7 +36,7 @@ async def test_replay(file_name: str) -> None:
3536
with (Path(__file__).with_name("histories") / file_name).open("r") as f:
3637
history_json = f.read()
3738

38-
with set_open_ai_agent_temporal_overrides():
39+
with set_open_ai_agent_temporal_overrides(ModelActivityParameters()):
3940
await Replayer(
4041
workflows=[
4142
ResearchWorkflow,

0 commit comments

Comments
 (0)