Skip to content

Commit 7457275

Browse files
committed
Adding a few more tool types and tests
1 parent b9842f4 commit 7457275

File tree

3 files changed

+224
-45
lines changed

3 files changed

+224
-45
lines changed

temporalio/contrib/openai_agents/_invoke_model_activity.py

Lines changed: 39 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,12 @@
1010

1111
from agents import (
1212
AgentOutputSchemaBase,
13+
CodeInterpreterTool,
1314
FileSearchTool,
1415
FunctionTool,
1516
Handoff,
17+
HostedMCPTool,
18+
ImageGenerationTool,
1619
ModelProvider,
1720
ModelResponse,
1821
ModelSettings,
@@ -21,9 +24,11 @@
2124
Tool,
2225
TResponseInputItem,
2326
UserError,
24-
WebSearchTool, ImageGenerationTool, CodeInterpreterTool,
27+
WebSearchTool,
2528
)
2629
from agents.models.multi_provider import MultiProvider
30+
from openai.types.responses.tool_param import Mcp
31+
from pydantic_core import to_json, to_jsonable_python
2732
from typing_extensions import Required, TypedDict
2833

2934
from temporalio import activity
@@ -51,7 +56,21 @@ class FunctionToolInput:
5156
strict_json_schema: bool = True
5257

5358

54-
ToolInput = Union[FunctionToolInput, FileSearchTool, WebSearchTool, ImageGenerationTool, CodeInterpreterTool]
59+
@dataclass
60+
class HostedMCPToolInput:
61+
"""Data conversion friendly representation of a HostedMCPTool."""
62+
63+
tool_config: Mcp
64+
65+
66+
ToolInput = Union[
67+
FunctionToolInput,
68+
FileSearchTool,
69+
WebSearchTool,
70+
ImageGenerationTool,
71+
CodeInterpreterTool,
72+
HostedMCPToolInput,
73+
]
5574

5675

5776
@dataclass
@@ -137,22 +156,28 @@ async def empty_on_invoke_handoff(
137156
) -> Any:
138157
return None
139158

140-
# workaround for https://github.com/pydantic/pydantic/issues/9541
141-
# ValidatorIterator returned
142-
input_json = json.dumps(input["input"], default=str)
143-
input_input = json.loads(input_json)
144-
145159
def make_tool(tool: ToolInput) -> Tool:
146-
if isinstance(tool, (FileSearchTool, WebSearchTool, ImageGenerationTool, CodeInterpreterTool)):
160+
if isinstance(
161+
tool,
162+
(
163+
FileSearchTool,
164+
WebSearchTool,
165+
ImageGenerationTool,
166+
CodeInterpreterTool,
167+
),
168+
):
147169
return cast(Tool, tool)
170+
elif isinstance(tool, HostedMCPToolInput):
171+
return HostedMCPTool(
172+
tool_config=tool.tool_config,
173+
)
148174
elif isinstance(tool, FunctionToolInput):
149-
t = cast(FunctionToolInput, tool)
150175
return FunctionTool(
151-
name=t.name,
152-
description=t.description,
153-
params_json_schema=t.params_json_schema,
176+
name=tool.name,
177+
description=tool.description,
178+
params_json_schema=tool.params_json_schema,
154179
on_invoke_tool=empty_on_invoke_tool,
155-
strict_json_schema=t.strict_json_schema,
180+
strict_json_schema=tool.strict_json_schema,
156181
)
157182
else:
158183
raise UserError(f"Unknown tool type: {tool.name}")
@@ -171,7 +196,7 @@ def make_tool(tool: ToolInput) -> Tool:
171196
]
172197
return await model.get_response(
173198
system_instructions=input.get("system_instructions"),
174-
input=input_input,
199+
input=input["input"],
175200
model_settings=input["model_settings"],
176201
tools=tools,
177202
output_schema=input.get("output_schema"),

temporalio/contrib/openai_agents/_temporal_model_stub.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,17 +13,20 @@
1313
from agents import (
1414
AgentOutputSchema,
1515
AgentOutputSchemaBase,
16+
CodeInterpreterTool,
1617
ComputerTool,
1718
FileSearchTool,
1819
FunctionTool,
1920
Handoff,
21+
HostedMCPTool,
22+
ImageGenerationTool,
2023
Model,
2124
ModelResponse,
2225
ModelSettings,
2326
ModelTracing,
2427
Tool,
2528
TResponseInputItem,
26-
WebSearchTool, ImageGenerationTool, CodeInterpreterTool,
29+
WebSearchTool,
2730
)
2831
from agents.items import TResponseStreamEvent
2932
from openai.types.responses.response_prompt_param import ResponsePromptParam
@@ -33,6 +36,7 @@
3336
AgentOutputSchemaInput,
3437
FunctionToolInput,
3538
HandoffInput,
39+
HostedMCPToolInput,
3640
ModelActivity,
3741
ModelTracingInput,
3842
ToolInput,
@@ -87,8 +91,22 @@ def get_summary(
8791
return ""
8892

8993
def make_tool_info(tool: Tool) -> ToolInput:
90-
if isinstance(tool, (FileSearchTool, WebSearchTool, ImageGenerationTool, CodeInterpreterTool)):
94+
if isinstance(
95+
tool,
96+
(
97+
FileSearchTool,
98+
WebSearchTool,
99+
ImageGenerationTool,
100+
CodeInterpreterTool,
101+
),
102+
):
91103
return tool
104+
elif isinstance(tool, HostedMCPTool):
105+
# if tool.on_approval_request is not None:
106+
# raise ValueError(
107+
# "HostedMCPTool with approval functions not currently supported."
108+
# )
109+
return HostedMCPToolInput(tool_config=tool.tool_config)
92110
elif isinstance(tool, FunctionTool):
93111
return FunctionToolInput(
94112
name=tool.name,

0 commit comments

Comments
 (0)