Skip to content

Commit 4a4378c

Browse files
committed
Restructure based on feedback - docstring need updating still
1 parent c441326 commit 4a4378c

File tree

5 files changed

+190
-99
lines changed

5 files changed

+190
-99
lines changed

temporalio/contrib/openai_agents/_mcp.py

Lines changed: 133 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
1+
import abc
12
import asyncio
3+
import functools
24
import logging
3-
import uuid
45
from datetime import timedelta
56
from typing import Any, Callable, Optional, Sequence, Union
67

@@ -22,7 +23,15 @@
2223
logger = logging.getLogger(__name__)
2324

2425

25-
class StatelessTemporalMCPServer(MCPServer):
26+
class TemporalMCPServer(abc.ABC):
27+
@property
28+
@abc.abstractmethod
29+
def name(self) -> str:
30+
"""Get the server name."""
31+
raise NotImplementedError()
32+
33+
34+
class StatelessTemporalMCPServerReference(MCPServer):
2635
"""A stateless MCP server implementation for Temporal workflows.
2736
2837
This class wraps an MCP server to make it stateless by executing each MCP operation
@@ -33,19 +42,16 @@ class StatelessTemporalMCPServer(MCPServer):
3342
and you don't need to maintain state between operations.
3443
"""
3544

36-
def __init__(
37-
self, server: Union[MCPServer, str], config: Optional[ActivityConfig] = None
38-
):
45+
def __init__(self, server: str, config: Optional[ActivityConfig] = None):
3946
"""Initialize the stateless temporal MCP server.
4047
4148
Args:
4249
server: Either an MCPServer instance or a string name for the server.
4350
config: Optional activity configuration for Temporal activities. Defaults to
4451
1-minute start-to-close timeout if not provided.
4552
"""
46-
self.server = server if isinstance(server, MCPServer) else None
47-
self._name = (server if isinstance(server, str) else server.name) + "-stateless"
48-
self.config = config or ActivityConfig(
53+
self._name = server + "-stateless"
54+
self._config = config or ActivityConfig(
4955
start_to_close_timeout=timedelta(minutes=1)
5056
)
5157
super().__init__()
@@ -99,7 +105,7 @@ async def list_tools(
99105
self.name + "-list-tools",
100106
args=[],
101107
result_type=list[MCPTool],
102-
**self.config,
108+
**self._config,
103109
)
104110

105111
async def call_tool(
@@ -124,7 +130,7 @@ async def call_tool(
124130
self.name + "-call-tool",
125131
args=[tool_name, arguments],
126132
result_type=CallToolResult,
127-
**self.config,
133+
**self._config,
128134
)
129135

130136
async def list_prompts(self) -> ListPromptsResult:
@@ -143,7 +149,7 @@ async def list_prompts(self) -> ListPromptsResult:
143149
self.name + "-list-prompts",
144150
args=[],
145151
result_type=ListPromptsResult,
146-
**self.config,
152+
**self._config,
147153
)
148154

149155
async def get_prompt(
@@ -168,9 +174,37 @@ async def get_prompt(
168174
self.name + "-get-prompt",
169175
args=[name, arguments],
170176
result_type=GetPromptResult,
171-
**self.config,
177+
**self._config,
172178
)
173179

180+
181+
class StatelessTemporalMCPServer(TemporalMCPServer):
182+
"""A stateless MCP server implementation for Temporal workflows.
183+
184+
This class wraps an MCP server to make it stateless by executing each MCP operation
185+
as a separate Temporal activity. Each operation (list_tools, call_tool, etc.) will
186+
connect to the underlying server, execute the operation, and then clean up the connection.
187+
188+
This approach is suitable for simple use cases where connection overhead is acceptable
189+
and you don't need to maintain state between operations. It is encouraged when possible as it provides
190+
a better set of durability guarantees that the stateful version.
191+
"""
192+
193+
def __init__(self, server: MCPServer):
194+
"""Initialize the stateless temporal MCP server.
195+
196+
Args:
197+
server: An MCPServer instance
198+
"""
199+
self._server = server
200+
self._name = server.name + "-stateless"
201+
super().__init__()
202+
203+
@property
204+
def name(self) -> str:
205+
"""Get the server name."""
206+
return self._name
207+
174208
def get_activities(self) -> Sequence[Callable]:
175209
"""Get the Temporal activities for this MCP server.
176210
@@ -183,11 +217,7 @@ def get_activities(self) -> Sequence[Callable]:
183217
Raises:
184218
ValueError: If no MCP server instance was provided during initialization.
185219
"""
186-
server = self.server
187-
if server is None:
188-
raise ValueError(
189-
"A full MCPServer implementation should have been provided when adding a server to the worker."
190-
)
220+
server = self._server
191221

192222
@activity.defn(name=self.name + "-list-tools")
193223
async def list_tools() -> list[MCPTool]:
@@ -228,43 +258,57 @@ async def get_prompt(
228258
return list_tools, call_tool, list_prompts, get_prompt
229259

230260

231-
class StatefulTemporalMCPServer(MCPServer):
232-
"""A stateful MCP server implementation for Temporal workflows.
233-
234-
This class wraps an MCP server to maintain a persistent connection throughout
235-
the workflow execution. It creates a dedicated worker that stays connected to
236-
the MCP server and processes operations on a dedicated task queue.
261+
def _handle_worker_failure(func):
262+
@functools.wraps(func)
263+
async def wrapper(*args, **kwargs):
264+
try:
265+
return await func(*args, **kwargs)
266+
except ActivityError as e:
267+
failure = e.failure
268+
if failure:
269+
cause = failure.cause
270+
if cause:
271+
if (
272+
cause.timeout_failure_info.timeout_type
273+
== TIMEOUT_TYPE_SCHEDULE_TO_START
274+
):
275+
raise ApplicationError(
276+
"MCP Stateful Server Worker failed to schedule activity."
277+
) from e
278+
if (
279+
cause.timeout_failure_info.timeout_type
280+
== TIMEOUT_TYPE_HEARTBEAT
281+
):
282+
raise ApplicationError(
283+
"MCP Stateful Server Worker failed to heartbeat."
284+
) from e
285+
raise e
237286

238-
This approach is more efficient for workflows that make multiple MCP calls,
239-
as it avoids connection overhead, but requires more resources to maintain
240-
the persistent connection and worker.
287+
return wrapper
241288

242-
The caller will have to handle cases where the dedicated worker fails, as Temporal is
243-
unable to seamlessly recreate any lost state in that case.
244-
"""
245289

290+
class StatefulTemporalMCPServerReference(MCPServer):
246291
def __init__(
247292
self,
248-
server: Union[MCPServer, str],
293+
server: str,
249294
config: Optional[ActivityConfig] = None,
250295
connect_config: Optional[ActivityConfig] = None,
251296
):
252297
"""Initialize the stateful temporal MCP server.
253298
254299
Args:
255-
server: Either an MCPServer instance or a string name for the server.
300+
server: A string name for the server. Should match that provided in the plugin.
256301
config: Optional activity configuration for MCP operation activities.
257302
Defaults to 1-minute start-to-close and 30-second schedule-to-start timeouts.
258303
connect_config: Optional activity configuration for the connection activity.
259304
Defaults to 1-hour start-to-close timeout.
260305
"""
261-
self.server = server if isinstance(server, MCPServer) else None
262-
self._name = (server if isinstance(server, str) else server.name) + "-stateful"
263-
self.config = config or ActivityConfig(
306+
self._name = server + "-stateful"
307+
self._config = config or ActivityConfig(
264308
start_to_close_timeout=timedelta(minutes=1),
265309
schedule_to_start_timeout=timedelta(seconds=30),
266310
)
267-
self.connect_config = connect_config or ActivityConfig(
311+
self._connect_config = connect_config or ActivityConfig(
268312
start_to_close_timeout=timedelta(hours=1),
269313
)
270314
self._connect_handle: Optional[ActivityHandle] = None
@@ -286,11 +330,11 @@ async def connect(self) -> None:
286330
a long-running activity that maintains the connection and runs a worker
287331
to handle MCP operations.
288332
"""
289-
self.config["task_queue"] = workflow.info().workflow_id + "-" + self.name
333+
self._config["task_queue"] = workflow.info().workflow_id + "-" + self.name
290334
self._connect_handle = workflow.start_activity(
291335
self.name + "-connect",
292336
args=[],
293-
**self.connect_config,
337+
**self._connect_config,
294338
)
295339

296340
async def cleanup(self) -> None:
@@ -322,6 +366,7 @@ async def __aexit__(self, exc_type, exc_value, traceback):
322366
"""
323367
await self.cleanup()
324368

369+
@_handle_worker_failure
325370
async def list_tools(
326371
self,
327372
run_context: Optional[RunContextWrapper[Any]] = None,
@@ -343,35 +388,14 @@ async def list_tools(
343388
ApplicationError: If the MCP worker fails to schedule or heartbeat.
344389
ActivityError: If the underlying Temporal activity fails.
345390
"""
346-
try:
347-
logger.info("Executing list-tools: %s", self.config)
348-
return await workflow.execute_activity(
349-
self.name + "-list-tools",
350-
args=[],
351-
result_type=list[MCPTool],
352-
**self.config,
353-
)
354-
except ActivityError as e:
355-
failure = e.failure
356-
if failure:
357-
cause = failure.cause
358-
if cause:
359-
if (
360-
cause.timeout_failure_info.timeout_type
361-
== TIMEOUT_TYPE_SCHEDULE_TO_START
362-
):
363-
raise ApplicationError(
364-
"MCP Stateful Server Worker failed to schedule activity."
365-
) from e
366-
if (
367-
cause.timeout_failure_info.timeout_type
368-
== TIMEOUT_TYPE_HEARTBEAT
369-
):
370-
raise ApplicationError(
371-
"MCP Stateful Server Worker failed to heartbeat."
372-
) from e
373-
raise e
391+
return await workflow.execute_activity(
392+
self.name + "-list-tools",
393+
args=[],
394+
result_type=list[MCPTool],
395+
**self._config,
396+
)
374397

398+
@_handle_worker_failure
375399
async def call_tool(
376400
self, tool_name: str, arguments: Optional[dict[str, Any]]
377401
) -> CallToolResult:
@@ -394,9 +418,10 @@ async def call_tool(
394418
self.name + "-call-tool",
395419
args=[tool_name, arguments],
396420
result_type=CallToolResult,
397-
**self.config,
421+
**self._config,
398422
)
399423

424+
@_handle_worker_failure
400425
async def list_prompts(self) -> ListPromptsResult:
401426
"""List available prompts from the MCP server.
402427
@@ -413,9 +438,10 @@ async def list_prompts(self) -> ListPromptsResult:
413438
self.name + "-list-prompts",
414439
args=[],
415440
result_type=ListPromptsResult,
416-
**self.config,
441+
**self._config,
417442
)
418443

444+
@_handle_worker_failure
419445
async def get_prompt(
420446
self, name: str, arguments: Optional[dict[str, Any]] = None
421447
) -> GetPromptResult:
@@ -438,9 +464,46 @@ async def get_prompt(
438464
self.name + "-get-prompt",
439465
args=[name, arguments],
440466
result_type=GetPromptResult,
441-
**self.config,
467+
**self._config,
442468
)
443469

470+
471+
class StatefulTemporalMCPServer(TemporalMCPServer):
472+
"""A stateful MCP server implementation for Temporal workflows.
473+
474+
This class wraps an MCP server to maintain a persistent connection throughout
475+
the workflow execution. It creates a dedicated worker that stays connected to
476+
the MCP server and processes operations on a dedicated task queue.
477+
478+
This approach is more efficient for workflows that make multiple MCP calls,
479+
as it avoids connection overhead, but requires more resources to maintain
480+
the persistent connection and worker.
481+
482+
The caller will have to handle cases where the dedicated worker fails, as Temporal is
483+
unable to seamlessly recreate any lost state in that case.
484+
"""
485+
486+
def __init__(
487+
self,
488+
server: MCPServer,
489+
):
490+
"""Initialize the stateful temporal MCP server.
491+
492+
Args:
493+
server: Either an MCPServer instance or a string name for the server.
494+
connect_config: Optional activity configuration for the connection activity.
495+
Defaults to 1-hour start-to-close timeout.
496+
"""
497+
self._server = server
498+
self._name = self._server.name + "-stateful"
499+
self._connect_handle: Optional[ActivityHandle] = None
500+
super().__init__()
501+
502+
@property
503+
def name(self) -> str:
504+
"""Get the server name."""
505+
return self._name
506+
444507
def get_activities(self) -> Sequence[Callable]:
445508
"""Get the Temporal activities for this stateful MCP server.
446509
@@ -454,11 +517,7 @@ def get_activities(self) -> Sequence[Callable]:
454517
Raises:
455518
ValueError: If no MCP server instance was provided during initialization.
456519
"""
457-
server = self.server
458-
if server is None:
459-
raise ValueError(
460-
"A full MCPServer implementation should have been provided when adding a server to the worker."
461-
)
520+
server = self._server
462521

463522
@activity.defn(name=self.name + "-list-tools")
464523
async def list_tools() -> list[MCPTool]:
@@ -486,9 +545,8 @@ async def heartbeat_every(delay: float, *details: Any) -> None:
486545
await asyncio.sleep(delay)
487546
activity.heartbeat(*details)
488547

489-
@activity.defn(name=self.name + "-connect")
548+
@activity.defn(name=self._name + "-connect")
490549
async def connect() -> None:
491-
logger.info("Connect activity")
492550
heartbeat_task = asyncio.create_task(heartbeat_every(30))
493551
try:
494552
await server.connect()

temporalio/contrib/openai_agents/_openai_runner.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -56,14 +56,18 @@ async def run(
5656
)
5757

5858
if starting_agent.mcp_servers:
59-
from temporalio.contrib.openai_agents import (
60-
StatefulTemporalMCPServer,
61-
StatelessTemporalMCPServer,
59+
from temporalio.contrib.openai_agents._mcp import (
60+
StatefulTemporalMCPServerReference,
61+
StatelessTemporalMCPServerReference,
6262
)
6363

6464
for s in starting_agent.mcp_servers:
6565
if not isinstance(
66-
s, (StatelessTemporalMCPServer, StatefulTemporalMCPServer)
66+
s,
67+
(
68+
StatelessTemporalMCPServerReference,
69+
StatefulTemporalMCPServerReference,
70+
),
6771
):
6872
warnings.warn(
6973
"Unknown mcp_server type {} may not work durably.".format(

0 commit comments

Comments
 (0)