Commit 6185d25

Merge branch 'main' into openai/trace_random
2 parents 08b54ed + beb9c9d commit 6185d25

File tree: 9 files changed, +1699 -100 lines changed

pyproject.toml

Lines changed: 1 addition & 0 deletions
@@ -57,6 +57,7 @@ dev = [
     "pytest-cov>=6.1.1",
     "httpx>=0.28.1",
     "pytest-pretty>=1.3.0",
+    "openai-agents[litellm] >= 0.2.3,<0.3"
 ]
 
 [tool.poe.tasks]

temporalio/contrib/openai_agents/_invoke_model_activity.py

Lines changed: 48 additions & 17 deletions
@@ -7,13 +7,16 @@
 import json
 from dataclasses import dataclass
 from datetime import timedelta
-from typing import Any, Optional, Union, cast
+from typing import Any, Optional, Union
 
 from agents import (
     AgentOutputSchemaBase,
+    CodeInterpreterTool,
     FileSearchTool,
     FunctionTool,
     Handoff,
+    HostedMCPTool,
+    ImageGenerationTool,
     ModelProvider,
     ModelResponse,
     ModelSettings,
@@ -25,13 +28,12 @@
     UserError,
     WebSearchTool,
 )
-from agents.models.multi_provider import MultiProvider
 from openai import (
     APIStatusError,
     AsyncOpenAI,
-    AuthenticationError,
-    PermissionDeniedError,
 )
+from openai.types.responses.tool_param import Mcp
+from pydantic_core import to_json
 from typing_extensions import Required, TypedDict
 
 from temporalio import activity
@@ -41,7 +43,9 @@
 
 @dataclass
 class HandoffInput:
-    """Data conversion friendly representation of a Handoff."""
+    """Data conversion friendly representation of a Handoff. Contains only the fields which are needed by the model
+    execution to determine what to handoff to, not the actual handoff invocation, which remains in the workflow context.
+    """
 
     tool_name: str
     tool_description: str
@@ -52,15 +56,33 @@ class HandoffInput:
 
 @dataclass
 class FunctionToolInput:
-    """Data conversion friendly representation of a FunctionTool."""
+    """Data conversion friendly representation of a FunctionTool. Contains only the fields which are needed by the model
+    execution to determine what tool to call, not the actual tool invocation, which remains in the workflow context.
+    """
 
     name: str
     description: str
     params_json_schema: dict[str, Any]
     strict_json_schema: bool = True
 
 
-ToolInput = Union[FunctionToolInput, FileSearchTool, WebSearchTool]
+@dataclass
+class HostedMCPToolInput:
+    """Data conversion friendly representation of a HostedMCPTool. Contains only the fields which are needed by the model
+    execution to determine what tool to call, not the actual tool invocation, which remains in the workflow context.
+    """
+
+    tool_config: Mcp
+
+
+ToolInput = Union[
+    FunctionToolInput,
+    FileSearchTool,
+    WebSearchTool,
+    ImageGenerationTool,
+    CodeInterpreterTool,
+    HostedMCPToolInput,
+]
 
 
 @dataclass
@@ -152,22 +174,31 @@ async def empty_on_invoke_handoff(
 
         # workaround for https://github.com/pydantic/pydantic/issues/9541
        # ValidatorIterator returned
-        input_json = json.dumps(input["input"], default=str)
+        input_json = to_json(input["input"])
         input_input = json.loads(input_json)
 
         def make_tool(tool: ToolInput) -> Tool:
-            if isinstance(tool, FileSearchTool):
-                return cast(FileSearchTool, tool)
-            elif isinstance(tool, WebSearchTool):
-                return cast(WebSearchTool, tool)
+            if isinstance(
+                tool,
+                (
+                    FileSearchTool,
+                    WebSearchTool,
+                    ImageGenerationTool,
+                    CodeInterpreterTool,
+                ),
+            ):
+                return tool
+            elif isinstance(tool, HostedMCPToolInput):
+                return HostedMCPTool(
+                    tool_config=tool.tool_config,
+                )
             elif isinstance(tool, FunctionToolInput):
-                t = cast(FunctionToolInput, tool)
                 return FunctionTool(
-                    name=t.name,
-                    description=t.description,
-                    params_json_schema=t.params_json_schema,
+                    name=tool.name,
+                    description=tool.description,
+                    params_json_schema=tool.params_json_schema,
                     on_invoke_tool=empty_on_invoke_tool,
-                    strict_json_schema=t.strict_json_schema,
+                    strict_json_schema=tool.strict_json_schema,
                 )
             else:
                 raise UserError(f"Unknown tool type: {tool.name}")

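The hunk above adds a serialization-friendly stand-in for hosted MCP tools and has make_tool rebuild the real SDK objects inside the activity. A minimal sketch of that round trip, assuming the private module path shown in the file header; the MCP config values are illustrative only:

from agents import HostedMCPTool

from temporalio.contrib.openai_agents._invoke_model_activity import HostedMCPToolInput

# Illustrative MCP config; real values depend on the server being exposed.
mcp_config = {
    "type": "mcp",
    "server_label": "docs",
    "server_url": "https://example.com/mcp",
    "require_approval": "never",
}

# What crosses the activity boundary: plain data, no callables.
tool_input = HostedMCPToolInput(tool_config=mcp_config)

# What make_tool hands back to the model inside the activity.
tool = HostedMCPTool(tool_config=tool_input.tool_config)
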
temporalio/contrib/openai_agents/_openai_runner.py

Lines changed: 15 additions & 3 deletions
@@ -7,6 +7,7 @@
     RunConfig,
     RunResult,
     RunResultStreaming,
+    SQLiteSession,
     TContext,
     Tool,
     TResponseInputItem,
@@ -51,23 +52,33 @@ async def run(
                     "Provided tool is not a tool type. If using an activity, make sure to wrap it with openai_agents.workflow.activity_as_tool."
                 )
 
+        if starting_agent.mcp_servers:
+            raise ValueError(
+                "Temporal OpenAI agent does not support on demand MCP servers."
+            )
+
         context = kwargs.get("context")
         max_turns = kwargs.get("max_turns", DEFAULT_MAX_TURNS)
         hooks = kwargs.get("hooks")
         run_config = kwargs.get("run_config")
         previous_response_id = kwargs.get("previous_response_id")
+        session = kwargs.get("session")
+
+        if isinstance(session, SQLiteSession):
+            raise ValueError("Temporal workflows don't support SQLite sessions.")
 
         if run_config is None:
             run_config = RunConfig()
 
-        if run_config.model is not None and not isinstance(run_config.model, str):
+        model_name = run_config.model or starting_agent.model
+        if model_name is not None and not isinstance(model_name, str):
             raise ValueError(
-                "Temporal workflows require a model name to be a string in the run config."
+                "Temporal workflows require a model name to be a string in the run config and/or agent."
             )
         updated_run_config = replace(
             run_config,
             model=_TemporalModelStub(
-                run_config.model,
+                model_name=model_name,
                 model_params=self.model_params,
            ),
        )
@@ -80,6 +91,7 @@ async def run(
             hooks=hooks,
             run_config=updated_run_config,
             previous_response_id=previous_response_id,
+            session=session,
         )
 
     def run_sync(

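The runner now resolves the model name from the run config first and then from the agent, rejects non-string models, and refuses SQLiteSession objects, which cannot live inside a workflow. A standalone sketch of the resolution rule, mirroring the hunk above (the helper name is made up for illustration):

from typing import Optional, Union

from agents import Agent, Model, RunConfig


def resolve_model_name(run_config: RunConfig, agent: Agent) -> Optional[str]:
    # Run config wins; otherwise fall back to the agent's own model setting.
    model: Union[str, Model, None] = run_config.model or agent.model
    if model is not None and not isinstance(model, str):
        raise ValueError(
            "Temporal workflows require a model name to be a string in the run config and/or agent."
        )
    return model
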
temporalio/contrib/openai_agents/_temporal_model_stub.py

Lines changed: 50 additions & 32 deletions
@@ -8,15 +8,17 @@
 
 logger = logging.getLogger(__name__)
 
-from typing import Any, AsyncIterator, Sequence, Union, cast
+from typing import Any, AsyncIterator, Union, cast
 
 from agents import (
     AgentOutputSchema,
     AgentOutputSchemaBase,
-    ComputerTool,
+    CodeInterpreterTool,
     FileSearchTool,
     FunctionTool,
     Handoff,
+    HostedMCPTool,
+    ImageGenerationTool,
     Model,
     ModelResponse,
     ModelSettings,
@@ -33,6 +35,7 @@
     AgentOutputSchemaInput,
     FunctionToolInput,
     HandoffInput,
+    HostedMCPToolInput,
     ModelActivity,
     ModelTracingInput,
     ToolInput,
@@ -54,7 +57,7 @@ def __init__(
     async def get_response(
         self,
         system_instructions: Optional[str],
-        input: Union[str, list[TResponseInputItem], dict[str, str]],
+        input: Union[str, list[TResponseInputItem]],
         model_settings: ModelSettings,
         tools: list[Tool],
         output_schema: Optional[AgentOutputSchemaBase],
@@ -64,35 +67,19 @@ async def get_response(
         previous_response_id: Optional[str],
         prompt: Optional[ResponsePromptParam],
     ) -> ModelResponse:
-        def get_summary(
-            input: Union[str, list[TResponseInputItem], dict[str, str]],
-        ) -> str:
-            ### Activity summary shown in the UI
-            try:
-                max_size = 100
-                if isinstance(input, str):
-                    return input[:max_size]
-                elif isinstance(input, list):
-                    seq_input = cast(Sequence[Any], input)
-                    last_item = seq_input[-1]
-                    if isinstance(last_item, dict):
-                        return last_item.get("content", "")[:max_size]
-                    elif hasattr(last_item, "content"):
-                        return str(getattr(last_item, "content"))[:max_size]
-                    return str(last_item)[:max_size]
-                elif isinstance(input, dict):
-                    return input.get("content", "")[:max_size]
-            except Exception as e:
-                logger.error(f"Error getting summary: {e}")
-            return ""
-
         def make_tool_info(tool: Tool) -> ToolInput:
-            if isinstance(tool, (FileSearchTool, WebSearchTool)):
+            if isinstance(
+                tool,
+                (
+                    FileSearchTool,
+                    WebSearchTool,
+                    ImageGenerationTool,
+                    CodeInterpreterTool,
+                ),
+            ):
                 return tool
-            elif isinstance(tool, ComputerTool):
-                raise NotImplementedError(
-                    "Computer search preview is not supported in Temporal model"
-                )
+            elif isinstance(tool, HostedMCPTool):
+                return HostedMCPToolInput(tool_config=tool.tool_config)
             elif isinstance(tool, FunctionTool):
                 return FunctionToolInput(
                     name=tool.name,
@@ -101,7 +88,7 @@ def make_tool_info(tool: Tool) -> ToolInput:
                     strict_json_schema=tool.strict_json_schema,
                 )
             else:
-                raise ValueError(f"Unknown tool type: {tool.name}")
+                raise ValueError(f"Unsupported tool type: {tool.name}")
 
         tool_infos = [make_tool_info(x) for x in tools]
         handoff_infos = [
@@ -150,7 +137,7 @@ def make_tool_info(tool: Tool) -> ToolInput:
         return await workflow.execute_activity_method(
             ModelActivity.invoke_model_activity,
             activity_input,
-            summary=self.model_params.summary_override or get_summary(input),
+            summary=self.model_params.summary_override or _extract_summary(input),
             task_queue=self.model_params.task_queue,
             schedule_to_close_timeout=self.model_params.schedule_to_close_timeout,
             schedule_to_start_timeout=self.model_params.schedule_to_start_timeout,
@@ -176,3 +163,34 @@ def stream_response(
         prompt: ResponsePromptParam | None,
     ) -> AsyncIterator[TResponseStreamEvent]:
         raise NotImplementedError("Temporal model doesn't support streams yet")
+
+
+def _extract_summary(input: Union[str, list[TResponseInputItem]]) -> str:
+    ### Activity summary shown in the UI
+    try:
+        max_size = 100
+        if isinstance(input, str):
+            return input[:max_size]
+        elif isinstance(input, list):
+            # Find all message inputs, which are reasonably summarizable
+            messages: list[TResponseInputItem] = [
+                item for item in input if item.get("type", "message") == "message"
+            ]
+            if not messages:
+                return ""
+
+            content: Any = messages[-1].get("content", "")
+
+            # In the case of multiple contents, take the last one
+            if isinstance(content, list):
+                if not content:
+                    return ""
+                content = content[-1]
+
+            # Take the text field from the content if present
+            if isinstance(content, dict) and content.get("text") is not None:
+                content = content.get("text")
+            return str(content)[:max_size]
+    except Exception as e:
+        logger.error(f"Error getting summary: {e}")
+    return ""

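The relocated _extract_summary helper scans the input for message items and keeps the last piece of text, truncated to 100 characters, as the activity summary shown in the UI. A rough illustration of the shapes it handles; the message dicts below are hypothetical Responses-style input items:

from temporalio.contrib.openai_agents._temporal_model_stub import _extract_summary

# A plain string is truncated directly (here it is already under 100 characters).
_extract_summary("Summarize the latest deployment")

# For list input, only "message" items count and the last text content wins,
# so this returns "Focus on the workflow changes".
items = [
    {"type": "function_call_output", "call_id": "call_1", "output": "ok"},
    {"type": "message", "role": "user", "content": "What changed in this release?"},
    {
        "type": "message",
        "role": "user",
        "content": [{"type": "input_text", "text": "Focus on the workflow changes"}],
    },
]
_extract_summary(items)
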
temporalio/contrib/openai_agents/_temporal_openai_agents.py

Lines changed: 15 additions & 2 deletions
@@ -34,7 +34,13 @@
 from temporalio.contrib.openai_agents._trace_interceptor import (
     OpenAIAgentsTracingInterceptor,
 )
-from temporalio.contrib.pydantic import pydantic_data_converter
+from temporalio.contrib.pydantic import (
+    PydanticPayloadConverter,
+    ToJsonOptions,
+)
+from temporalio.converter import (
+    DataConverter,
+)
 from temporalio.worker import Worker, WorkerConfig
 
 
@@ -137,6 +143,11 @@ def stream_response(
         raise NotImplementedError()
 
 
+class _OpenAIPayloadConverter(PydanticPayloadConverter):
+    def __init__(self) -> None:
+        super().__init__(ToJsonOptions(exclude_unset=True))
+
+
 class OpenAIAgentsPlugin(temporalio.client.Plugin, temporalio.worker.Plugin):
     """Temporal plugin for integrating OpenAI agents with Temporal workflows.
 
@@ -232,7 +243,9 @@ def configure_client(self, config: ClientConfig) -> ClientConfig:
         Returns:
             The modified client configuration.
         """
-        config["data_converter"] = pydantic_data_converter
+        config["data_converter"] = DataConverter(
+            payload_converter_class=_OpenAIPayloadConverter
+        )
         return super().configure_client(config)
 
     def configure_worker(self, config: WorkerConfig) -> WorkerConfig:

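The plugin now installs a Pydantic payload converter configured with exclude_unset=True, so model fields that were never explicitly set are omitted from payloads rather than serialized as nulls. A minimal sketch of an equivalent converter wired into a data converter, using only the types the hunks above import:

from temporalio.contrib.pydantic import PydanticPayloadConverter, ToJsonOptions
from temporalio.converter import DataConverter


class ExcludeUnsetPayloadConverter(PydanticPayloadConverter):
    # Same idea as the private _OpenAIPayloadConverter added above.
    def __init__(self) -> None:
        super().__init__(ToJsonOptions(exclude_unset=True))


# Roughly what OpenAIAgentsPlugin.configure_client now places on the client config.
data_converter = DataConverter(payload_converter_class=ExcludeUnsetPayloadConverter)
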
temporalio/contrib/openai_agents/workflow.py

Lines changed: 1 addition & 1 deletion
@@ -134,7 +134,7 @@ async def run_activity(ctx: RunContextWrapper[Any], input: str) -> Any:
             cancellation_type=cancellation_type,
             activity_id=activity_id,
             versioning_intent=versioning_intent,
-            summary=summary,
+            summary=summary or schema.description,
             priority=priority,
         )
         try:

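With the change above, an activity exposed as a tool falls back to the generated schema's description (presumably derived from the activity's docstring) for the activity summary when no explicit summary is passed. A hedged usage sketch; the activity and the exact keyword arguments accepted by activity_as_tool beyond what the hunk shows are assumptions:

from datetime import timedelta

from temporalio import activity
from temporalio.contrib import openai_agents


@activity.defn
async def lookup_weather(city: str) -> str:
    """Look up the current weather for a city."""
    return f"Sunny in {city}"


# No summary passed: per the diff, the schema description would be used
# as the activity summary shown in the Temporal UI.
weather_tool = openai_agents.workflow.activity_as_tool(
    lookup_weather,
    start_to_close_timeout=timedelta(seconds=30),
)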