Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ dev = [
"pytest-cov>=6.1.1",
"httpx>=0.28.1",
"pytest-pretty>=1.3.0",
"openai-agents[litellm] >= 0.2.3,<0.3"
]

[tool.poe.tasks]
Expand Down
58 changes: 38 additions & 20 deletions temporalio/contrib/openai_agents/_invoke_model_activity.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,18 @@
"""

import enum
import json
from dataclasses import dataclass
from datetime import timedelta
from typing import Any, Optional, Union, cast

from agents import (
AgentOutputSchemaBase,
CodeInterpreterTool,
FileSearchTool,
FunctionTool,
Handoff,
HostedMCPTool,
ImageGenerationTool,
ModelProvider,
ModelResponse,
ModelSettings,
Expand All @@ -25,13 +27,11 @@
UserError,
WebSearchTool,
)
from agents.models.multi_provider import MultiProvider
from openai import (
APIStatusError,
AsyncOpenAI,
AuthenticationError,
PermissionDeniedError,
)
from openai.types.responses.tool_param import Mcp
from typing_extensions import Required, TypedDict

from temporalio import activity
Expand Down Expand Up @@ -60,7 +60,21 @@ class FunctionToolInput:
strict_json_schema: bool = True


ToolInput = Union[FunctionToolInput, FileSearchTool, WebSearchTool]
@dataclass
class HostedMCPToolInput:
"""Data conversion friendly representation of a HostedMCPTool."""

tool_config: Mcp


ToolInput = Union[
FunctionToolInput,
FileSearchTool,
WebSearchTool,
ImageGenerationTool,
CodeInterpreterTool,
HostedMCPToolInput,
]


@dataclass
Expand Down Expand Up @@ -150,24 +164,28 @@ async def empty_on_invoke_handoff(
) -> Any:
return None

# workaround for https://github.com/pydantic/pydantic/issues/9541
# ValidatorIterator returned
input_json = json.dumps(input["input"], default=str)
input_input = json.loads(input_json)

def make_tool(tool: ToolInput) -> Tool:
if isinstance(tool, FileSearchTool):
return cast(FileSearchTool, tool)
elif isinstance(tool, WebSearchTool):
return cast(WebSearchTool, tool)
if isinstance(
tool,
(
FileSearchTool,
WebSearchTool,
ImageGenerationTool,
CodeInterpreterTool,
),
):
return cast(Tool, tool)
elif isinstance(tool, HostedMCPToolInput):
return HostedMCPTool(
tool_config=tool.tool_config,
)
elif isinstance(tool, FunctionToolInput):
t = cast(FunctionToolInput, tool)
return FunctionTool(
name=t.name,
description=t.description,
params_json_schema=t.params_json_schema,
name=tool.name,
description=tool.description,
params_json_schema=tool.params_json_schema,
on_invoke_tool=empty_on_invoke_tool,
strict_json_schema=t.strict_json_schema,
strict_json_schema=tool.strict_json_schema,
)
else:
raise UserError(f"Unknown tool type: {tool.name}")
Expand All @@ -188,7 +206,7 @@ def make_tool(tool: ToolInput) -> Tool:
try:
return await model.get_response(
system_instructions=input.get("system_instructions"),
input=input_input,
input=input["input"],
model_settings=input["model_settings"],
tools=tools,
output_schema=input.get("output_schema"),
Expand Down
22 changes: 16 additions & 6 deletions temporalio/contrib/openai_agents/_temporal_model_stub.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,13 @@
from agents import (
AgentOutputSchema,
AgentOutputSchemaBase,
CodeInterpreterTool,
ComputerTool,
FileSearchTool,
FunctionTool,
Handoff,
HostedMCPTool,
ImageGenerationTool,
Model,
ModelResponse,
ModelSettings,
Expand All @@ -33,6 +36,7 @@
AgentOutputSchemaInput,
FunctionToolInput,
HandoffInput,
HostedMCPToolInput,
ModelActivity,
ModelTracingInput,
ToolInput,
Expand Down Expand Up @@ -87,12 +91,18 @@ def get_summary(
return ""

def make_tool_info(tool: Tool) -> ToolInput:
if isinstance(tool, (FileSearchTool, WebSearchTool)):
if isinstance(
tool,
(
FileSearchTool,
WebSearchTool,
ImageGenerationTool,
CodeInterpreterTool,
),
):
return tool
elif isinstance(tool, ComputerTool):
raise NotImplementedError(
"Computer search preview is not supported in Temporal model"
)
elif isinstance(tool, HostedMCPTool):
return HostedMCPToolInput(tool_config=tool.tool_config)
elif isinstance(tool, FunctionTool):
return FunctionToolInput(
name=tool.name,
Expand All @@ -101,7 +111,7 @@ def make_tool_info(tool: Tool) -> ToolInput:
strict_json_schema=tool.strict_json_schema,
)
else:
raise ValueError(f"Unknown tool type: {tool.name}")
raise ValueError(f"Unsupported tool type: {tool.name}")

tool_infos = [make_tool_info(x) for x in tools]
handoff_infos = [
Expand Down
Loading
Loading