Skip to content

Commit 0fd7886

Browse files
committed
Include stateful option
1 parent 0ed952a commit 0fd7886

File tree

4 files changed

+399
-54
lines changed

4 files changed

+399
-54
lines changed

temporalio/contrib/openai_agents/__init__.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,8 @@
1111
# Best Effort mcp, as it is not supported on Python 3.9
1212
try:
1313
from temporalio.contrib.openai_agents._mcp import (
14-
TemporalMCPServer,
15-
TemporalMCPServerWorkflowShim,
14+
StatefulTemporalMCPServer,
15+
StatelessTemporalMCPServer,
1616
)
1717
except ImportError:
1818
pass
@@ -33,8 +33,8 @@
3333
"OpenAIAgentsPlugin",
3434
"ModelActivityParameters",
3535
"workflow",
36-
"TemporalMCPServer",
37-
"TemporalMCPServerWorkflowShim",
36+
"StatelessTemporalMCPServer",
37+
"StatefulTemporalMCPServer",
3838
"TestModel",
3939
"TestModelProvider",
4040
]

temporalio/contrib/openai_agents/_mcp.py

Lines changed: 207 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
import asyncio
2+
import logging
3+
import uuid
14
from datetime import timedelta
25
from typing import Any, Callable, Optional, Sequence, Union
36

@@ -8,13 +11,26 @@
811
from mcp.types import CallToolResult # type:ignore
912

1013
from temporalio import activity, workflow
11-
from temporalio.workflow import ActivityConfig
14+
from temporalio.api.enums.v1.workflow_pb2 import (
15+
TIMEOUT_TYPE_HEARTBEAT,
16+
TIMEOUT_TYPE_SCHEDULE_TO_START,
17+
)
18+
from temporalio.exceptions import ActivityError, ApplicationError
19+
from temporalio.worker import PollerBehaviorSimpleMaximum, Worker
20+
from temporalio.workflow import ActivityConfig, ActivityHandle
21+
22+
logger = logging.getLogger(__name__)
23+
1224

1325
class StatelessTemporalMCPServer(MCPServer):
14-
def __init__(self, server: Union[MCPServer, str], config: Optional[ActivityConfig] = None):
26+
def __init__(
27+
self, server: Union[MCPServer, str], config: Optional[ActivityConfig] = None
28+
):
1529
self.server = server if isinstance(server, MCPServer) else None
1630
self._name = (server if isinstance(server, str) else server.name) + "-stateless"
17-
self.config = config or ActivityConfig(start_to_close_timeout=timedelta(minutes=1))
31+
self.config = config or ActivityConfig(
32+
start_to_close_timeout=timedelta(minutes=1)
33+
)
1834
super().__init__()
1935

2036
@property
@@ -68,32 +84,209 @@ async def get_prompt(
6884
)
6985

7086
def get_activities(self) -> Sequence[Callable]:
71-
if self.server is None:
72-
raise ValueError("A full MCPServer implementation should have been provided when adding a server to the worker.")
87+
server = self.server
88+
if server is None:
89+
raise ValueError(
90+
"A full MCPServer implementation should have been provided when adding a server to the worker."
91+
)
7392

7493
@activity.defn(name=self.name + "-list-tools")
7594
async def list_tools() -> list[MCPTool]:
76-
activity.logger.info("Listing tools in activity")
77-
async with self.server:
78-
return await self.server.list_tools()
95+
try:
96+
await server.connect()
97+
return await server.list_tools()
98+
finally:
99+
await server.cleanup()
79100

80101
@activity.defn(name=self.name + "-call-tool")
81102
async def call_tool(
82103
tool_name: str, arguments: Optional[dict[str, Any]]
83104
) -> CallToolResult:
84-
async with self.server:
85-
return await self.server.call_tool(tool_name, arguments)
105+
try:
106+
await server.connect()
107+
return await server.call_tool(tool_name, arguments)
108+
finally:
109+
await server.cleanup()
86110

87111
@activity.defn(name=self.name + "-list-prompts")
88112
async def list_prompts() -> ListPromptsResult:
89-
async with self.server:
90-
return await self.server.list_prompts()
113+
try:
114+
await server.connect()
115+
return await server.list_prompts()
116+
finally:
117+
await server.cleanup()
91118

92119
@activity.defn(name=self.name + "-get-prompt")
93120
async def get_prompt(
94121
name: str, arguments: Optional[dict[str, Any]]
95122
) -> GetPromptResult:
96-
async with self.server:
97-
return await self.server.get_prompt(name, arguments)
123+
try:
124+
await server.connect()
125+
return await server.get_prompt(name, arguments)
126+
finally:
127+
await server.cleanup()
98128

99129
return list_tools, call_tool, list_prompts, get_prompt
130+
131+
132+
class StatefulTemporalMCPServer(MCPServer):
133+
def __init__(
134+
self,
135+
server: Union[MCPServer, str],
136+
config: Optional[ActivityConfig] = None,
137+
connect_config: Optional[ActivityConfig] = None,
138+
):
139+
self.server = server if isinstance(server, MCPServer) else None
140+
self._name = (server if isinstance(server, str) else server.name) + "-stateful"
141+
self.config = config or ActivityConfig(
142+
start_to_close_timeout=timedelta(minutes=1),
143+
schedule_to_start_timeout=timedelta(seconds=30),
144+
)
145+
self.connect_config = connect_config or ActivityConfig(
146+
start_to_close_timeout=timedelta(hours=1),
147+
)
148+
self._connect_handle: Optional[ActivityHandle] = None
149+
super().__init__()
150+
151+
@property
152+
def name(self) -> str:
153+
return self._name
154+
155+
async def connect(self) -> None:
156+
self.config["task_queue"] = workflow.info().workflow_id + "-" + self.name
157+
self._connect_handle = workflow.start_activity(
158+
self.name + "-connect",
159+
args=[],
160+
**self.connect_config,
161+
)
162+
163+
async def cleanup(self) -> None:
164+
if self._connect_handle:
165+
self._connect_handle.cancel()
166+
167+
async def __aenter__(self):
168+
await self.connect()
169+
return self
170+
171+
async def __aexit__(self, exc_type, exc_value, traceback):
172+
await self.cleanup()
173+
174+
async def list_tools(
175+
self,
176+
run_context: Optional[RunContextWrapper[Any]] = None,
177+
agent: Optional[AgentBase] = None,
178+
) -> list[MCPTool]:
179+
try:
180+
logger.info("Executing list-tools: %s", self.config)
181+
return await workflow.execute_activity(
182+
self.name + "-list-tools",
183+
args=[],
184+
result_type=list[MCPTool],
185+
**self.config,
186+
)
187+
except ActivityError as e:
188+
failure = e.failure
189+
if failure:
190+
cause = failure.cause
191+
if cause:
192+
if (
193+
cause.timeout_failure_info.timeout_type
194+
== TIMEOUT_TYPE_SCHEDULE_TO_START
195+
):
196+
raise ApplicationError(
197+
"MCP Stateful Server Worker failed to schedule activity."
198+
) from e
199+
if (
200+
cause.timeout_failure_info.timeout_type
201+
== TIMEOUT_TYPE_HEARTBEAT
202+
):
203+
raise ApplicationError(
204+
"MCP Stateful Server Worker failed to heartbeat."
205+
) from e
206+
raise e
207+
208+
async def call_tool(
209+
self, tool_name: str, arguments: Optional[dict[str, Any]]
210+
) -> CallToolResult:
211+
return await workflow.execute_activity(
212+
self.name + "-call-tool",
213+
args=[tool_name, arguments],
214+
result_type=CallToolResult,
215+
**self.config,
216+
)
217+
218+
async def list_prompts(self) -> ListPromptsResult:
219+
return await workflow.execute_activity(
220+
self.name + "-list-prompts",
221+
args=[],
222+
result_type=ListPromptsResult,
223+
**self.config,
224+
)
225+
226+
async def get_prompt(
227+
self, name: str, arguments: Optional[dict[str, Any]] = None
228+
) -> GetPromptResult:
229+
return await workflow.execute_activity(
230+
self.name + "-get-prompt",
231+
args=[name, arguments],
232+
result_type=GetPromptResult,
233+
**self.config,
234+
)
235+
236+
def get_activities(self) -> Sequence[Callable]:
237+
server = self.server
238+
if server is None:
239+
raise ValueError(
240+
"A full MCPServer implementation should have been provided when adding a server to the worker."
241+
)
242+
243+
@activity.defn(name=self.name + "-list-tools")
244+
async def list_tools() -> list[MCPTool]:
245+
return await server.list_tools()
246+
247+
@activity.defn(name=self.name + "-call-tool")
248+
async def call_tool(
249+
tool_name: str, arguments: Optional[dict[str, Any]]
250+
) -> CallToolResult:
251+
return await server.call_tool(tool_name, arguments)
252+
253+
@activity.defn(name=self.name + "-list-prompts")
254+
async def list_prompts() -> ListPromptsResult:
255+
return await server.list_prompts()
256+
257+
@activity.defn(name=self.name + "-get-prompt")
258+
async def get_prompt(
259+
name: str, arguments: Optional[dict[str, Any]]
260+
) -> GetPromptResult:
261+
return await server.get_prompt(name, arguments)
262+
263+
async def heartbeat_every(delay: float, *details: Any) -> None:
264+
"""Heartbeat every so often while not cancelled"""
265+
while True:
266+
await asyncio.sleep(delay)
267+
activity.heartbeat(*details)
268+
269+
@activity.defn(name=self.name + "-connect")
270+
async def connect() -> None:
271+
logger.info("Connect activity")
272+
heartbeat_task = asyncio.create_task(heartbeat_every(30))
273+
try:
274+
await server.connect()
275+
276+
worker = Worker(
277+
activity.client(),
278+
task_queue=activity.info().workflow_id + "-" + self.name,
279+
activities=[list_tools, call_tool, list_prompts, get_prompt],
280+
activity_task_poller_behavior=PollerBehaviorSimpleMaximum(1),
281+
)
282+
283+
await worker.run()
284+
finally:
285+
await server.cleanup()
286+
heartbeat_task.cancel()
287+
try:
288+
await heartbeat_task
289+
except asyncio.CancelledError:
290+
pass
291+
292+
return (connect,)

temporalio/contrib/openai_agents/_temporal_openai_agents.py

Lines changed: 29 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
"""Initialize Temporal OpenAI Agents overrides."""
22

33
import dataclasses
4-
import typing
5-
from contextlib import AsyncExitStack, asynccontextmanager, contextmanager
4+
import warnings
5+
from contextlib import asynccontextmanager, contextmanager
66
from datetime import timedelta
77
from typing import AsyncIterator, Callable, Optional, Sequence, Union
88

@@ -28,6 +28,7 @@
2828
import temporalio.worker
2929
from temporalio.client import ClientConfig
3030
from temporalio.contrib.openai_agents._invoke_model_activity import ModelActivity
31+
3132
from temporalio.contrib.openai_agents._model_parameters import ModelActivityParameters
3233
from temporalio.contrib.openai_agents._openai_runner import TemporalOpenAIRunner
3334
from temporalio.contrib.openai_agents._temporal_trace_provider import (
@@ -52,10 +53,15 @@
5253
)
5354
from temporalio.worker.workflow_sandbox import SandboxedWorkflowRunner
5455

55-
from agents.mcp import MCPServer
56-
57-
from temporalio.contrib.openai_agents._mcp import StatelessTemporalMCPServer
58-
56+
# Unsupported on python 3.9
57+
try:
58+
from agents.mcp import MCPServer
59+
from temporalio.contrib.openai_agents._mcp import (
60+
StatefulTemporalMCPServer,
61+
StatelessTemporalMCPServer,
62+
)
63+
except ImportError:
64+
pass
5965

6066
@contextmanager
6167
def set_open_ai_agent_temporal_overrides(
@@ -165,18 +171,6 @@ def __init__(self) -> None:
165171
super().__init__(ToJsonOptions(exclude_unset=True))
166172

167173

168-
def _transform_mcp_servers(mcp_servers: Sequence[MCPServer]) -> list[MCPServer]:
169-
def _transform_mcp_server(server: MCPServer) -> MCPServer:
170-
if isinstance(server, StatelessTemporalMCPServer):
171-
return server
172-
else:
173-
raise TypeError(f"Unsupported mcp server type {type(server)}")
174-
return [
175-
_transform_mcp_server(server)
176-
for server in mcp_servers
177-
]
178-
179-
180174
class OpenAIAgentsPlugin(temporalio.client.Plugin, temporalio.worker.Plugin):
181175
"""Temporal plugin for integrating OpenAI agents with Temporal workflows.
182176
@@ -281,7 +275,18 @@ def __init__(
281275
self._model_provider = model_provider
282276

283277
if mcp_servers:
284-
self._mcp_servers = _transform_mcp_servers(mcp_servers)
278+
def _transform_mcp_server(server: "MCPServer") -> "MCPServer":
279+
if not (
280+
isinstance(server, StatelessTemporalMCPServer)
281+
or isinstance(server, StatefulTemporalMCPServer)
282+
):
283+
warnings.warn(
284+
f"Unsupported mcp server type {type(server)} is not guaranteed to behave reasonably."
285+
)
286+
287+
return server
288+
289+
self._mcp_servers = [_transform_mcp_server(server) for server in mcp_servers]
285290
else:
286291
self._mcp_servers = []
287292

@@ -335,7 +340,11 @@ def configure_worker(self, config: WorkerConfig) -> WorkerConfig:
335340
]
336341
new_activities = [ModelActivity(self._model_provider).invoke_model_activity]
337342
for mcp_server in self._mcp_servers:
338-
new_activities.extend(mcp_server.get_activities())
343+
if hasattr(mcp_server, "get_activities"):
344+
get_activities: Callable[[], Sequence[Callable]] = getattr(
345+
mcp_server, "get_activities"
346+
)
347+
new_activities.extend(get_activities())
339348
config["activities"] = list(config.get("activities") or []) + new_activities
340349

341350
runner = config.get("workflow_runner")

0 commit comments

Comments
 (0)