Skip to content

Commit fd36c9c

Browse files
cleanup diff
1 parent 8971462 commit fd36c9c

File tree

3 files changed

+60
-59
lines changed

3 files changed

+60
-59
lines changed

temporalio/contrib/openai_agents/_temporal_openai_agents.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,16 +13,24 @@
1313

1414
from temporalio.contrib.openai_agents._invoke_model_activity import ModelActivity
1515
from temporalio.contrib.openai_agents._model_parameters import ModelActivityParameters
16-
from temporalio.contrib.openai_agents._openai_runner import TemporalOpenAIRunner
16+
from temporalio.contrib.openai_agents._openai_runner import (
17+
TemporalOpenAIRunner,
18+
)
1719
from temporalio.contrib.openai_agents._temporal_trace_provider import (
1820
TemporalTraceProvider,
1921
)
2022
from temporalio.contrib.openai_agents._trace_interceptor import (
2123
OpenAIAgentsTracingInterceptor,
2224
)
2325
from temporalio.contrib.openai_agents.workflow import AgentsWorkflowError
24-
from temporalio.contrib.pydantic import PydanticPayloadConverter, ToJsonOptions
25-
from temporalio.converter import DataConverter, DefaultPayloadConverter
26+
from temporalio.contrib.pydantic import (
27+
PydanticPayloadConverter,
28+
ToJsonOptions,
29+
)
30+
from temporalio.converter import (
31+
DataConverter,
32+
DefaultPayloadConverter,
33+
)
2634
from temporalio.plugin import SimplePlugin
2735
from temporalio.worker import WorkflowRunner
2836
from temporalio.worker.workflow_sandbox import SandboxedWorkflowRunner

tests/contrib/openai_agents/test_openai.py

Lines changed: 46 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -60,16 +60,20 @@
6060
HandoffOutputItem,
6161
ToolCallItem,
6262
ToolCallOutputItem,
63+
TResponseOutputItem,
6364
TResponseStreamEvent,
6465
)
6566
from openai import APIStatusError, AsyncOpenAI, BaseModel
6667
from openai.types.responses import (
6768
EasyInputMessageParam,
6869
ResponseCodeInterpreterToolCall,
6970
ResponseFileSearchToolCall,
71+
ResponseFunctionToolCall,
7072
ResponseFunctionToolCallParam,
7173
ResponseFunctionWebSearch,
7274
ResponseInputTextParam,
75+
ResponseOutputMessage,
76+
ResponseOutputText,
7377
)
7478
from openai.types.responses.response_file_search_tool_call import Result
7579
from openai.types.responses.response_function_web_search import ActionSearch
@@ -86,24 +90,23 @@
8690
from temporalio.client import Client, WorkflowFailureError, WorkflowHandle
8791
from temporalio.common import RetryPolicy
8892
from temporalio.contrib import openai_agents
89-
from temporalio.contrib.openai_agents import ModelActivityParameters
93+
from temporalio.contrib.openai_agents import (
94+
ModelActivityParameters,
95+
)
96+
from temporalio.contrib.openai_agents.testing import TestModel, TestModelProvider, ResponseBuilders, StaticTestModel
9097
from temporalio.contrib.openai_agents._model_parameters import ModelSummaryProvider
9198
from temporalio.contrib.openai_agents._openai_runner import _convert_agent
9299
from temporalio.contrib.openai_agents._temporal_model_stub import (
93100
_extract_summary,
94101
_TemporalModelStub,
95102
)
96-
from temporalio.contrib.openai_agents.testing import (
97-
ResponseBuilders,
98-
StaticTestModel,
99-
TestModel,
100-
TestModelProvider,
101-
)
102103
from temporalio.contrib.pydantic import pydantic_data_converter
103104
from temporalio.exceptions import ApplicationError, CancelledError, TemporalError
104105
from temporalio.testing import WorkflowEnvironment
105106
from temporalio.workflow import ActivityConfig
106-
from tests.contrib.openai_agents.research_agents.research_manager import ResearchManager
107+
from tests.contrib.openai_agents.research_agents.research_manager import (
108+
ResearchManager,
109+
)
107110
from tests.helpers import assert_eventually, new_worker
108111
from tests.helpers.nexus import create_nexus_endpoint, make_nexus_endpoint_name
109112

@@ -134,9 +137,9 @@ async def test_hello_world_agent(client: Client, use_local_model: bool):
134137
model_params=ModelActivityParameters(
135138
start_to_close_timeout=timedelta(seconds=30)
136139
),
137-
model_provider=(
138-
TestModelProvider(TestHelloModel()) if use_local_model else None
139-
),
140+
model_provider=TestModelProvider(TestHelloModel())
141+
if use_local_model
142+
else None,
140143
)
141144
]
142145
client = Client(**new_config)
@@ -316,9 +319,9 @@ async def test_tool_workflow(client: Client, use_local_model: bool):
316319
model_params=ModelActivityParameters(
317320
start_to_close_timeout=timedelta(seconds=30)
318321
),
319-
model_provider=(
320-
TestModelProvider(TestWeatherModel()) if use_local_model else None
321-
),
322+
model_provider=TestModelProvider(TestWeatherModel())
323+
if use_local_model
324+
else None,
322325
)
323326
]
324327
client = Client(**new_config)
@@ -483,9 +486,9 @@ async def test_nexus_tool_workflow(
483486
model_params=ModelActivityParameters(
484487
start_to_close_timeout=timedelta(seconds=30)
485488
),
486-
model_provider=(
487-
TestModelProvider(TestNexusWeatherModel()) if use_local_model else None
488-
),
489+
model_provider=TestModelProvider(TestNexusWeatherModel())
490+
if use_local_model
491+
else None,
489492
)
490493
]
491494
client = Client(**new_config)
@@ -586,9 +589,9 @@ async def test_research_workflow(client: Client, use_local_model: bool):
586589
start_to_close_timeout=timedelta(seconds=120),
587590
schedule_to_close_timeout=timedelta(seconds=120),
588591
),
589-
model_provider=(
590-
TestModelProvider(TestResearchModel()) if use_local_model else None
591-
),
592+
model_provider=TestModelProvider(TestResearchModel())
593+
if use_local_model
594+
else None,
592595
)
593596
]
594597
client = Client(**new_config)
@@ -737,9 +740,9 @@ async def test_agents_as_tools_workflow(client: Client, use_local_model: bool):
737740
model_params=ModelActivityParameters(
738741
start_to_close_timeout=timedelta(seconds=30)
739742
),
740-
model_provider=(
741-
TestModelProvider(AgentAsToolsModel()) if use_local_model else None
742-
),
743+
model_provider=TestModelProvider(AgentAsToolsModel())
744+
if use_local_model
745+
else None,
743746
)
744747
]
745748
client = Client(**new_config)
@@ -1001,9 +1004,9 @@ async def test_customer_service_workflow(client: Client, use_local_model: bool):
10011004
model_params=ModelActivityParameters(
10021005
start_to_close_timeout=timedelta(seconds=30)
10031006
),
1004-
model_provider=(
1005-
TestModelProvider(CustomerServiceModel()) if use_local_model else None
1006-
),
1007+
model_provider=TestModelProvider(CustomerServiceModel())
1008+
if use_local_model
1009+
else None,
10071010
)
10081011
]
10091012
client = Client(**new_config)
@@ -1231,15 +1234,11 @@ async def test_input_guardrail(client: Client, use_local_model: bool):
12311234
model_params=ModelActivityParameters(
12321235
start_to_close_timeout=timedelta(seconds=30)
12331236
),
1234-
model_provider=(
1235-
TestModelProvider(
1236-
InputGuardrailModel(
1237-
"", openai_client=AsyncOpenAI(api_key="Fake key")
1238-
)
1239-
)
1240-
if use_local_model
1241-
else None
1242-
),
1237+
model_provider=TestModelProvider(
1238+
InputGuardrailModel("", openai_client=AsyncOpenAI(api_key="Fake key"))
1239+
)
1240+
if use_local_model
1241+
else None,
12431242
)
12441243
]
12451244
client = Client(**new_config)
@@ -1334,9 +1333,9 @@ async def test_output_guardrail(client: Client, use_local_model: bool):
13341333
model_params=ModelActivityParameters(
13351334
start_to_close_timeout=timedelta(seconds=30)
13361335
),
1337-
model_provider=(
1338-
TestModelProvider(OutputGuardrailModel()) if use_local_model else None
1339-
),
1336+
model_provider=TestModelProvider(OutputGuardrailModel())
1337+
if use_local_model
1338+
else None,
13401339
)
13411340
]
13421341
client = Client(**new_config)
@@ -1802,9 +1801,9 @@ async def test_file_search_tool(client: Client, use_local_model):
18021801
model_params=ModelActivityParameters(
18031802
start_to_close_timeout=timedelta(seconds=30)
18041803
),
1805-
model_provider=(
1806-
TestModelProvider(FileSearchToolModel()) if use_local_model else None
1807-
),
1804+
model_provider=TestModelProvider(FileSearchToolModel())
1805+
if use_local_model
1806+
else None,
18081807
)
18091808
]
18101809
client = Client(**new_config)
@@ -1878,9 +1877,9 @@ async def test_image_generation_tool(client: Client, use_local_model):
18781877
model_params=ModelActivityParameters(
18791878
start_to_close_timeout=timedelta(seconds=30)
18801879
),
1881-
model_provider=(
1882-
TestModelProvider(ImageGenerationModel()) if use_local_model else None
1883-
),
1880+
model_provider=TestModelProvider(ImageGenerationModel())
1881+
if use_local_model
1882+
else None,
18841883
)
18851884
]
18861885
client = Client(**new_config)
@@ -2419,9 +2418,9 @@ async def test_mcp_server(
24192418
model_params=ModelActivityParameters(
24202419
start_to_close_timeout=timedelta(seconds=120)
24212420
),
2422-
model_provider=(
2423-
TestModelProvider(TrackingMCPModel()) if use_local_model else None
2424-
),
2421+
model_provider=TestModelProvider(TrackingMCPModel())
2422+
if use_local_model
2423+
else None,
24252424
mcp_server_providers=[server],
24262425
)
24272426
]
@@ -2633,11 +2632,3 @@ async def test_model_conversion_loops():
26332632
triage_agent = seat_booking_agent.handoffs[0]
26342633
assert isinstance(triage_agent, Agent)
26352634
assert isinstance(triage_agent.model, _TemporalModelStub)
2636-
seat_booking_agent = await seat_booking_handoff.on_invoke_handoff(context, "")
2637-
triage_agent = seat_booking_agent.handoffs[0]
2638-
assert isinstance(triage_agent, Agent)
2639-
assert isinstance(triage_agent.model, _TemporalModelStub)
2640-
seat_booking_agent = await seat_booking_handoff.on_invoke_handoff(context, "")
2641-
triage_agent = seat_booking_agent.handoffs[0]
2642-
assert isinstance(triage_agent, Agent)
2643-
assert isinstance(triage_agent.model, _TemporalModelStub)

tests/contrib/openai_agents/test_openai_tracing.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,9 @@
77

88
from temporalio.client import Client
99
from temporalio.contrib import openai_agents
10-
from temporalio.contrib.openai_agents.testing import TestModelProvider
10+
from temporalio.contrib.openai_agents.testing import (
11+
TestModelProvider,
12+
)
1113
from tests.contrib.openai_agents.test_openai import ResearchWorkflow, TestResearchModel
1214
from tests.helpers import new_worker
1315

0 commit comments

Comments
 (0)