Skip to content

Commit c197efd

Browse files
committed
Moving mcp classes into plugin
1 parent 28b974e commit c197efd

File tree

4 files changed

+204
-136
lines changed

4 files changed

+204
-136
lines changed

temporalio/contrib/openai_agents/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
Use with caution in production environments.
99
"""
1010

11+
from temporalio.contrib.openai_agents._mcp import TemporalMCPServerWorkflowShim
1112
from temporalio.contrib.openai_agents._model_parameters import ModelActivityParameters
1213
from temporalio.contrib.openai_agents._temporal_openai_agents import (
1314
OpenAIAgentsPlugin,
@@ -24,6 +25,7 @@
2425
"OpenAIAgentsPlugin",
2526
"ModelActivityParameters",
2627
"workflow",
28+
"TemporalMCPServerWorkflowShim",
2729
"TestModel",
2830
"TestModelProvider",
2931
]
Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
from datetime import timedelta
2+
from typing import Any, Callable, Optional, Sequence
3+
4+
from agents import AgentBase, RunContextWrapper
5+
from agents.mcp import MCPServer
6+
from mcp import GetPromptResult, ListPromptsResult
7+
from mcp import Tool as MCPTool
8+
from mcp.types import CallToolResult
9+
10+
from temporalio import activity, workflow
11+
12+
13+
class TemporalMCPServerWorkflowShim(MCPServer):
14+
def __init__(self, name: str):
15+
self.server_name = name
16+
super().__init__()
17+
18+
@property
19+
def name(self) -> str:
20+
return self.server_name
21+
22+
async def connect(self) -> None:
23+
raise ValueError("Cannot connect to a server shim")
24+
25+
async def cleanup(self) -> None:
26+
raise ValueError("Cannot clean up a server shim")
27+
28+
async def list_tools(
29+
self,
30+
run_context: Optional[RunContextWrapper[Any]] = None,
31+
agent: Optional[AgentBase] = None,
32+
) -> list[MCPTool]:
33+
workflow.logger.info("Listing tools")
34+
tools: list[MCPTool] = await workflow.execute_local_activity(
35+
self.name + "-list-tools",
36+
start_to_close_timeout=timedelta(seconds=30),
37+
result_type=list[MCPTool],
38+
)
39+
print(tools[0])
40+
print("Tool type:", type(tools[0]))
41+
# print(type(MCPTool(**tools[0])))
42+
return tools
43+
44+
async def call_tool(
45+
self, tool_name: str, arguments: Optional[dict[str, Any]]
46+
) -> CallToolResult:
47+
return await workflow.execute_local_activity(
48+
self.name + "-call-tool",
49+
args=[tool_name, arguments],
50+
start_to_close_timeout=timedelta(seconds=30),
51+
result_type=CallToolResult,
52+
)
53+
54+
async def list_prompts(self) -> ListPromptsResult:
55+
raise NotImplementedError()
56+
57+
async def get_prompt(
58+
self, name: str, arguments: Optional[dict[str, Any]] = None
59+
) -> GetPromptResult:
60+
raise NotImplementedError()
61+
62+
63+
class TemporalMCPServer(TemporalMCPServerWorkflowShim):
64+
def __init__(self, server: MCPServer):
65+
self.server = server
66+
super().__init__(server.name)
67+
68+
@property
69+
def name(self) -> str:
70+
return self.server.name
71+
72+
async def connect(self) -> None:
73+
await self.server.connect()
74+
75+
async def cleanup(self) -> None:
76+
await self.server.cleanup()
77+
78+
async def list_tools(
79+
self,
80+
run_context: Optional[RunContextWrapper[Any]] = None,
81+
agent: Optional[AgentBase] = None,
82+
) -> list[MCPTool]:
83+
if not workflow.in_workflow():
84+
return await self.server.list_tools(run_context, agent)
85+
86+
return await super().list_tools(run_context, agent)
87+
88+
async def call_tool(
89+
self, tool_name: str, arguments: Optional[dict[str, Any]]
90+
) -> CallToolResult:
91+
if not workflow.in_workflow():
92+
return await self.server.call_tool(tool_name, arguments)
93+
94+
return await super().call_tool(tool_name, arguments)
95+
96+
async def __aenter__(self):
97+
await self.connect()
98+
return self
99+
100+
async def __aexit__(self, exc_type, exc_value, traceback):
101+
await self.cleanup()
102+
103+
def get_activities(self) -> Sequence[Callable]:
104+
@activity.defn(name=self.name + "-list-tools")
105+
async def list_tools() -> list[MCPTool]:
106+
activity.logger.info("Listing tools in activity")
107+
return await self.server.list_tools()
108+
109+
@activity.defn(name=self.name + "-call-tool")
110+
async def call_tool(
111+
tool_name: str, arguments: Optional[dict[str, Any]]
112+
) -> CallToolResult:
113+
return await self.server.call_tool(tool_name, arguments)
114+
115+
return list_tools, call_tool

temporalio/contrib/openai_agents/_temporal_openai_agents.py

Lines changed: 30 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
"""Initialize Temporal OpenAI Agents overrides."""
22

3-
from contextlib import contextmanager
3+
import dataclasses
4+
from contextlib import AsyncExitStack, contextmanager
45
from datetime import timedelta
5-
from typing import AsyncIterator, Callable, Optional, Union
6+
from typing import AsyncIterator, Callable, Optional, Sequence, Union
67

78
from agents import (
89
AgentOutputSchemaBase,
@@ -17,6 +18,7 @@
1718
set_trace_provider,
1819
)
1920
from agents.items import TResponseStreamEvent
21+
from agents.mcp import MCPServer
2022
from agents.run import get_default_agent_runner, set_default_agent_runner
2123
from agents.tracing import get_trace_provider
2224
from agents.tracing.provider import DefaultTraceProvider
@@ -26,6 +28,7 @@
2628
import temporalio.worker
2729
from temporalio.client import ClientConfig
2830
from temporalio.contrib.openai_agents._invoke_model_activity import ModelActivity
31+
from temporalio.contrib.openai_agents._mcp import TemporalMCPServer
2932
from temporalio.contrib.openai_agents._model_parameters import ModelActivityParameters
3033
from temporalio.contrib.openai_agents._openai_runner import TemporalOpenAIRunner
3134
from temporalio.contrib.openai_agents._temporal_trace_provider import (
@@ -42,6 +45,7 @@
4245
DataConverter,
4346
)
4447
from temporalio.worker import Worker, WorkerConfig
48+
from temporalio.worker.workflow_sandbox import SandboxedWorkflowRunner
4549

4650

4751
@contextmanager
@@ -203,6 +207,7 @@ def __init__(
203207
self,
204208
model_params: Optional[ModelActivityParameters] = None,
205209
model_provider: Optional[ModelProvider] = None,
210+
mcp_servers: Sequence[MCPServer] = (),
206211
) -> None:
207212
"""Initialize the OpenAI agents plugin.
208213
@@ -231,6 +236,13 @@ def __init__(
231236
self._model_params = model_params
232237
self._model_provider = model_provider
233238

239+
self._mcp_servers = [
240+
server
241+
if isinstance(server, TemporalMCPServer)
242+
else TemporalMCPServer(server)
243+
for server in mcp_servers
244+
]
245+
234246
def configure_client(self, config: ClientConfig) -> ClientConfig:
235247
"""Configure the Temporal client for OpenAI agents integration.
236248
@@ -265,9 +277,18 @@ def configure_worker(self, config: WorkerConfig) -> WorkerConfig:
265277
config["interceptors"] = list(config.get("interceptors") or []) + [
266278
OpenAIAgentsTracingInterceptor()
267279
]
268-
config["activities"] = list(config.get("activities") or []) + [
269-
ModelActivity(self._model_provider).invoke_model_activity
270-
]
280+
new_activities = [ModelActivity(self._model_provider).invoke_model_activity]
281+
for mcp_server in self._mcp_servers:
282+
new_activities.extend(mcp_server.get_activities())
283+
config["activities"] = list(config.get("activities") or []) + new_activities
284+
285+
runner = config.get("workflow_runner")
286+
if isinstance(runner, SandboxedWorkflowRunner):
287+
config["workflow_runner"] = dataclasses.replace(
288+
runner,
289+
restrictions=runner.restrictions.with_passthrough_modules("mcp"),
290+
)
291+
271292
return super().configure_worker(config)
272293

273294
async def run_worker(self, worker: Worker) -> None:
@@ -281,4 +302,7 @@ async def run_worker(self, worker: Worker) -> None:
281302
worker: The worker instance to run.
282303
"""
283304
with set_open_ai_agent_temporal_overrides(self._model_params):
284-
await super().run_worker(worker)
305+
async with AsyncExitStack() as stack:
306+
for mcp_server in self._mcp_servers:
307+
await stack.enter_async_context(mcp_server)
308+
await super().run_worker(worker)

0 commit comments

Comments
 (0)