|
60 | 60 | HandoffOutputItem, |
61 | 61 | ToolCallItem, |
62 | 62 | ToolCallOutputItem, |
| 63 | + TResponseOutputItem, |
63 | 64 | TResponseStreamEvent, |
64 | 65 | ) |
65 | 66 | from openai import APIStatusError, AsyncOpenAI, BaseModel |
66 | 67 | from openai.types.responses import ( |
67 | 68 | EasyInputMessageParam, |
68 | 69 | ResponseCodeInterpreterToolCall, |
69 | 70 | ResponseFileSearchToolCall, |
| 71 | + ResponseFunctionToolCall, |
70 | 72 | ResponseFunctionToolCallParam, |
71 | 73 | ResponseFunctionWebSearch, |
72 | 74 | ResponseInputTextParam, |
| 75 | + ResponseOutputMessage, |
| 76 | + ResponseOutputText, |
73 | 77 | ) |
74 | 78 | from openai.types.responses.response_file_search_tool_call import Result |
75 | 79 | from openai.types.responses.response_function_web_search import ActionSearch |
|
86 | 90 | from temporalio.client import Client, WorkflowFailureError, WorkflowHandle |
87 | 91 | from temporalio.common import RetryPolicy |
88 | 92 | from temporalio.contrib import openai_agents |
89 | | -from temporalio.contrib.openai_agents import ModelActivityParameters |
| 93 | +from temporalio.contrib.openai_agents import ( |
| 94 | + ModelActivityParameters, |
| 95 | +) |
| 96 | +from temporalio.contrib.openai_agents.testing import TestModel, TestModelProvider, ResponseBuilders, StaticTestModel |
90 | 97 | from temporalio.contrib.openai_agents._model_parameters import ModelSummaryProvider |
91 | 98 | from temporalio.contrib.openai_agents._openai_runner import _convert_agent |
92 | 99 | from temporalio.contrib.openai_agents._temporal_model_stub import ( |
93 | 100 | _extract_summary, |
94 | 101 | _TemporalModelStub, |
95 | 102 | ) |
96 | | -from temporalio.contrib.openai_agents.testing import ( |
97 | | - ResponseBuilders, |
98 | | - StaticTestModel, |
99 | | - TestModel, |
100 | | - TestModelProvider, |
101 | | -) |
102 | 103 | from temporalio.contrib.pydantic import pydantic_data_converter |
103 | 104 | from temporalio.exceptions import ApplicationError, CancelledError, TemporalError |
104 | 105 | from temporalio.testing import WorkflowEnvironment |
105 | 106 | from temporalio.workflow import ActivityConfig |
106 | | -from tests.contrib.openai_agents.research_agents.research_manager import ResearchManager |
| 107 | +from tests.contrib.openai_agents.research_agents.research_manager import ( |
| 108 | + ResearchManager, |
| 109 | +) |
107 | 110 | from tests.helpers import assert_eventually, new_worker |
108 | 111 | from tests.helpers.nexus import create_nexus_endpoint, make_nexus_endpoint_name |
109 | 112 |
|
@@ -134,9 +137,9 @@ async def test_hello_world_agent(client: Client, use_local_model: bool): |
134 | 137 | model_params=ModelActivityParameters( |
135 | 138 | start_to_close_timeout=timedelta(seconds=30) |
136 | 139 | ), |
137 | | - model_provider=( |
138 | | - TestModelProvider(TestHelloModel()) if use_local_model else None |
139 | | - ), |
| 140 | + model_provider=TestModelProvider(TestHelloModel()) |
| 141 | + if use_local_model |
| 142 | + else None, |
140 | 143 | ) |
141 | 144 | ] |
142 | 145 | client = Client(**new_config) |
@@ -316,9 +319,9 @@ async def test_tool_workflow(client: Client, use_local_model: bool): |
316 | 319 | model_params=ModelActivityParameters( |
317 | 320 | start_to_close_timeout=timedelta(seconds=30) |
318 | 321 | ), |
319 | | - model_provider=( |
320 | | - TestModelProvider(TestWeatherModel()) if use_local_model else None |
321 | | - ), |
| 322 | + model_provider=TestModelProvider(TestWeatherModel()) |
| 323 | + if use_local_model |
| 324 | + else None, |
322 | 325 | ) |
323 | 326 | ] |
324 | 327 | client = Client(**new_config) |
@@ -483,9 +486,9 @@ async def test_nexus_tool_workflow( |
483 | 486 | model_params=ModelActivityParameters( |
484 | 487 | start_to_close_timeout=timedelta(seconds=30) |
485 | 488 | ), |
486 | | - model_provider=( |
487 | | - TestModelProvider(TestNexusWeatherModel()) if use_local_model else None |
488 | | - ), |
| 489 | + model_provider=TestModelProvider(TestNexusWeatherModel()) |
| 490 | + if use_local_model |
| 491 | + else None, |
489 | 492 | ) |
490 | 493 | ] |
491 | 494 | client = Client(**new_config) |
@@ -586,9 +589,9 @@ async def test_research_workflow(client: Client, use_local_model: bool): |
586 | 589 | start_to_close_timeout=timedelta(seconds=120), |
587 | 590 | schedule_to_close_timeout=timedelta(seconds=120), |
588 | 591 | ), |
589 | | - model_provider=( |
590 | | - TestModelProvider(TestResearchModel()) if use_local_model else None |
591 | | - ), |
| 592 | + model_provider=TestModelProvider(TestResearchModel()) |
| 593 | + if use_local_model |
| 594 | + else None, |
592 | 595 | ) |
593 | 596 | ] |
594 | 597 | client = Client(**new_config) |
@@ -737,9 +740,9 @@ async def test_agents_as_tools_workflow(client: Client, use_local_model: bool): |
737 | 740 | model_params=ModelActivityParameters( |
738 | 741 | start_to_close_timeout=timedelta(seconds=30) |
739 | 742 | ), |
740 | | - model_provider=( |
741 | | - TestModelProvider(AgentAsToolsModel()) if use_local_model else None |
742 | | - ), |
| 743 | + model_provider=TestModelProvider(AgentAsToolsModel()) |
| 744 | + if use_local_model |
| 745 | + else None, |
743 | 746 | ) |
744 | 747 | ] |
745 | 748 | client = Client(**new_config) |
@@ -1001,9 +1004,9 @@ async def test_customer_service_workflow(client: Client, use_local_model: bool): |
1001 | 1004 | model_params=ModelActivityParameters( |
1002 | 1005 | start_to_close_timeout=timedelta(seconds=30) |
1003 | 1006 | ), |
1004 | | - model_provider=( |
1005 | | - TestModelProvider(CustomerServiceModel()) if use_local_model else None |
1006 | | - ), |
| 1007 | + model_provider=TestModelProvider(CustomerServiceModel()) |
| 1008 | + if use_local_model |
| 1009 | + else None, |
1007 | 1010 | ) |
1008 | 1011 | ] |
1009 | 1012 | client = Client(**new_config) |
@@ -1231,15 +1234,11 @@ async def test_input_guardrail(client: Client, use_local_model: bool): |
1231 | 1234 | model_params=ModelActivityParameters( |
1232 | 1235 | start_to_close_timeout=timedelta(seconds=30) |
1233 | 1236 | ), |
1234 | | - model_provider=( |
1235 | | - TestModelProvider( |
1236 | | - InputGuardrailModel( |
1237 | | - "", openai_client=AsyncOpenAI(api_key="Fake key") |
1238 | | - ) |
1239 | | - ) |
1240 | | - if use_local_model |
1241 | | - else None |
1242 | | - ), |
| 1237 | + model_provider=TestModelProvider( |
| 1238 | + InputGuardrailModel("", openai_client=AsyncOpenAI(api_key="Fake key")) |
| 1239 | + ) |
| 1240 | + if use_local_model |
| 1241 | + else None, |
1243 | 1242 | ) |
1244 | 1243 | ] |
1245 | 1244 | client = Client(**new_config) |
@@ -1334,9 +1333,9 @@ async def test_output_guardrail(client: Client, use_local_model: bool): |
1334 | 1333 | model_params=ModelActivityParameters( |
1335 | 1334 | start_to_close_timeout=timedelta(seconds=30) |
1336 | 1335 | ), |
1337 | | - model_provider=( |
1338 | | - TestModelProvider(OutputGuardrailModel()) if use_local_model else None |
1339 | | - ), |
| 1336 | + model_provider=TestModelProvider(OutputGuardrailModel()) |
| 1337 | + if use_local_model |
| 1338 | + else None, |
1340 | 1339 | ) |
1341 | 1340 | ] |
1342 | 1341 | client = Client(**new_config) |
@@ -1802,9 +1801,9 @@ async def test_file_search_tool(client: Client, use_local_model): |
1802 | 1801 | model_params=ModelActivityParameters( |
1803 | 1802 | start_to_close_timeout=timedelta(seconds=30) |
1804 | 1803 | ), |
1805 | | - model_provider=( |
1806 | | - TestModelProvider(FileSearchToolModel()) if use_local_model else None |
1807 | | - ), |
| 1804 | + model_provider=TestModelProvider(FileSearchToolModel()) |
| 1805 | + if use_local_model |
| 1806 | + else None, |
1808 | 1807 | ) |
1809 | 1808 | ] |
1810 | 1809 | client = Client(**new_config) |
@@ -1878,9 +1877,9 @@ async def test_image_generation_tool(client: Client, use_local_model): |
1878 | 1877 | model_params=ModelActivityParameters( |
1879 | 1878 | start_to_close_timeout=timedelta(seconds=30) |
1880 | 1879 | ), |
1881 | | - model_provider=( |
1882 | | - TestModelProvider(ImageGenerationModel()) if use_local_model else None |
1883 | | - ), |
| 1880 | + model_provider=TestModelProvider(ImageGenerationModel()) |
| 1881 | + if use_local_model |
| 1882 | + else None, |
1884 | 1883 | ) |
1885 | 1884 | ] |
1886 | 1885 | client = Client(**new_config) |
@@ -2419,9 +2418,9 @@ async def test_mcp_server( |
2419 | 2418 | model_params=ModelActivityParameters( |
2420 | 2419 | start_to_close_timeout=timedelta(seconds=120) |
2421 | 2420 | ), |
2422 | | - model_provider=( |
2423 | | - TestModelProvider(TrackingMCPModel()) if use_local_model else None |
2424 | | - ), |
| 2421 | + model_provider=TestModelProvider(TrackingMCPModel()) |
| 2422 | + if use_local_model |
| 2423 | + else None, |
2425 | 2424 | mcp_server_providers=[server], |
2426 | 2425 | ) |
2427 | 2426 | ] |
@@ -2633,11 +2632,3 @@ async def test_model_conversion_loops(): |
2633 | 2632 | triage_agent = seat_booking_agent.handoffs[0] |
2634 | 2633 | assert isinstance(triage_agent, Agent) |
2635 | 2634 | assert isinstance(triage_agent.model, _TemporalModelStub) |
2636 | | - seat_booking_agent = await seat_booking_handoff.on_invoke_handoff(context, "") |
2637 | | - triage_agent = seat_booking_agent.handoffs[0] |
2638 | | - assert isinstance(triage_agent, Agent) |
2639 | | - assert isinstance(triage_agent.model, _TemporalModelStub) |
2640 | | - seat_booking_agent = await seat_booking_handoff.on_invoke_handoff(context, "") |
2641 | | - triage_agent = seat_booking_agent.handoffs[0] |
2642 | | - assert isinstance(triage_agent, Agent) |
2643 | | - assert isinstance(triage_agent.model, _TemporalModelStub) |
0 commit comments