Skip to content

Commit 19e57eb

Browse files
committed
Add local test
1 parent 8fff2aa commit 19e57eb

File tree

2 files changed

+73
-11
lines changed

2 files changed

+73
-11
lines changed

temporalio/contrib/openai_agents/_temporal_openai_agents.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,8 @@ def set_open_ai_agent_temporal_overrides(
104104
class TestModelProvider(ModelProvider):
105105
"""Test model provider which simply returns the given module."""
106106

107+
__test__ = False
108+
107109
def __init__(self, model: Model):
108110
"""Initialize a test model provider with a model."""
109111
self._model = model
@@ -116,6 +118,8 @@ def get_model(self, model_name: Union[str, None]) -> Model:
116118
class TestModel(Model):
117119
"""Test model for use mocking model responses."""
118120

121+
__test__ = False
122+
119123
def __init__(self, fn: Callable[[], ModelResponse]) -> None:
120124
"""Initialize a test model with a callable."""
121125
self.fn = fn

tests/contrib/openai_agents/test_openai.py

Lines changed: 69 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,7 @@
88
from typing import (
99
Any,
1010
AsyncIterator,
11-
Callable,
1211
Optional,
13-
Sequence,
1412
Union,
1513
no_type_check,
1614
)
@@ -19,7 +17,6 @@
1917
import pytest
2018
from agents import (
2119
Agent,
22-
AgentBase,
2320
AgentOutputSchemaBase,
2421
CodeInterpreterTool,
2522
FileSearchTool,
@@ -29,7 +26,6 @@
2926
ImageGenerationTool,
3027
InputGuardrailTripwireTriggered,
3128
ItemHelpers,
32-
LocalShellTool,
3329
MCPToolApprovalFunctionResult,
3430
MCPToolApprovalRequest,
3531
MessageOutputItem,
@@ -58,6 +54,7 @@
5854
HandoffOutputItem,
5955
ToolCallItem,
6056
ToolCallOutputItem,
57+
TResponseOutputItem,
6158
TResponseStreamEvent,
6259
)
6360
from agents.mcp import MCPServer, MCPServerStdio
@@ -98,10 +95,6 @@
9895
from temporalio.contrib.pydantic import pydantic_data_converter
9996
from temporalio.exceptions import CancelledError
10097
from temporalio.testing import WorkflowEnvironment
101-
from temporalio.worker.workflow_sandbox import (
102-
SandboxedWorkflowRunner,
103-
SandboxRestrictions,
104-
)
10598
from tests.contrib.openai_agents.research_agents.research_manager import (
10699
ResearchManager,
107100
)
@@ -2530,7 +2523,69 @@ async def run(self, question: str) -> str:
25302523
return result.final_output
25312524

25322525

2533-
async def test_mcp_server(client: Client):
2526+
class ResponseBuilders:
2527+
@staticmethod
2528+
def model_response(output: TResponseOutputItem) -> ModelResponse:
2529+
return ModelResponse(
2530+
output=[output],
2531+
usage=Usage(),
2532+
response_id=None,
2533+
)
2534+
2535+
@staticmethod
2536+
def tool_call(arguments: str, name: str) -> ModelResponse:
2537+
return ResponseBuilders.model_response(
2538+
ResponseFunctionToolCall(
2539+
arguments=arguments,
2540+
call_id="call",
2541+
name=name,
2542+
type="function_call",
2543+
id="id",
2544+
status="completed",
2545+
)
2546+
)
2547+
2548+
@staticmethod
2549+
def output_message(text: str) -> ModelResponse:
2550+
return ResponseBuilders.model_response(
2551+
ResponseOutputMessage(
2552+
id="",
2553+
content=[
2554+
ResponseOutputText(
2555+
text=text,
2556+
annotations=[],
2557+
type="output_text",
2558+
)
2559+
],
2560+
role="assistant",
2561+
status="completed",
2562+
type="message",
2563+
)
2564+
)
2565+
2566+
2567+
class McpServerModel(StaticTestModel):
2568+
responses = [
2569+
ResponseBuilders.tool_call(
2570+
arguments='{"path":"/"}',
2571+
name="list_directory",
2572+
),
2573+
ResponseBuilders.tool_call(
2574+
arguments="{}",
2575+
name="list_allowed_directories",
2576+
),
2577+
ResponseBuilders.tool_call(
2578+
arguments='{"path":"."}',
2579+
name="list_directory",
2580+
),
2581+
ResponseBuilders.output_message(
2582+
"Here are the files and directories in the allowed path."
2583+
),
2584+
]
2585+
2586+
2587+
@pytest.mark.parametrize("use_local_model", [True, False])
2588+
async def test_mcp_server(client: Client, use_local_model: bool):
25342589
if not os.environ.get("OPENAI_API_KEY"):
25352590
pytest.skip("No openai API key")
25362591

@@ -2562,6 +2617,9 @@ async def test_mcp_server(client: Client):
25622617
model_params=ModelActivityParameters(
25632618
start_to_close_timeout=timedelta(seconds=120)
25642619
),
2620+
model_provider=TestModelProvider(McpServerModel())
2621+
if use_local_model
2622+
else None,
25652623
mcp_servers=[server, server2],
25662624
)
25672625
]
@@ -2579,5 +2637,5 @@ async def test_mcp_server(client: Client):
25792637
execution_timeout=timedelta(seconds=30),
25802638
)
25812639
result = await workflow_handle.result()
2582-
print(result)
2583-
assert False
2640+
if use_local_model:
2641+
assert result == "Here are the files and directories in the allowed path."

0 commit comments

Comments
 (0)