Skip to content

Commit 8ef444b

Browse files
committed
Overhaul stateful mcp server
1 parent d0e7355 commit 8ef444b

File tree

4 files changed

+70
-43
lines changed

4 files changed

+70
-43
lines changed

temporalio/contrib/openai_agents/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
# Best Effort mcp, as it is not supported on Python 3.9
1212
try:
1313
from temporalio.contrib.openai_agents._mcp import (
14-
StatefulMCPServer,
14+
StatefulMCPServerProvider,
1515
StatelessMCPServer,
1616
)
1717
except ImportError:
@@ -37,7 +37,7 @@
3737
"OpenAIAgentsPlugin",
3838
"OpenAIPayloadConverter",
3939
"StatelessMCPServer",
40-
"StatefulMCPServer",
40+
"StatefulMCPServerProvider",
4141
"TestModel",
4242
"TestModelProvider",
4343
"workflow",

temporalio/contrib/openai_agents/_mcp.py

Lines changed: 54 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -203,7 +203,7 @@ def name(self) -> str:
203203
return self._name
204204

205205
async def connect(self) -> None:
206-
self._config["task_queue"] = workflow.info().workflow_id + "-" + self.name
206+
self._config["task_queue"] = self.name + "@" + workflow.info().run_id
207207
self._connect_handle = workflow.start_activity(
208208
self.name + "-server-session",
209209
args=[],
@@ -228,7 +228,9 @@ async def list_tools(
228228
agent: Optional[AgentBase] = None,
229229
) -> list[MCPTool]:
230230
if not self._connect_handle:
231-
raise ApplicationError("Stateful MCP Server not connected. Call connect first.")
231+
raise ApplicationError(
232+
"Stateful MCP Server not connected. Call connect first."
233+
)
232234
return await workflow.execute_activity(
233235
self.name + "-list-tools",
234236
args=[],
@@ -241,7 +243,9 @@ async def call_tool(
241243
self, tool_name: str, arguments: Optional[dict[str, Any]]
242244
) -> CallToolResult:
243245
if not self._connect_handle:
244-
raise ApplicationError("Stateful MCP Server not connected. Call connect first.")
246+
raise ApplicationError(
247+
"Stateful MCP Server not connected. Call connect first."
248+
)
245249
return await workflow.execute_activity(
246250
self.name + "-call-tool",
247251
args=[tool_name, arguments],
@@ -252,7 +256,9 @@ async def call_tool(
252256
@_handle_worker_failure
253257
async def list_prompts(self) -> ListPromptsResult:
254258
if not self._connect_handle:
255-
raise ApplicationError("Stateful MCP Server not connected. Call connect first.")
259+
raise ApplicationError(
260+
"Stateful MCP Server not connected. Call connect first."
261+
)
256262
return await workflow.execute_activity(
257263
self.name + "-list-prompts",
258264
args=[],
@@ -265,7 +271,9 @@ async def get_prompt(
265271
self, name: str, arguments: Optional[dict[str, Any]] = None
266272
) -> GetPromptResult:
267273
if not self._connect_handle:
268-
raise ApplicationError("Stateful MCP Server not connected. Call connect first.")
274+
raise ApplicationError(
275+
"Stateful MCP Server not connected. Call connect first."
276+
)
269277
return await workflow.execute_activity(
270278
self.name + "-get-prompt",
271279
args=[name, arguments],
@@ -274,10 +282,10 @@ async def get_prompt(
274282
)
275283

276284

277-
class StatefulMCPServer:
285+
class StatefulMCPServerProvider:
278286
"""A stateful MCP server implementation for Temporal workflows.
279287
280-
This class wraps an MCP server to maintain a persistent connection throughout
288+
This class wraps an function to create MCP servers to maintain a persistent connection throughout
281289
the workflow execution. It creates a dedicated worker that stays connected to
282290
the MCP server and processes operations on a dedicated task queue.
283291
@@ -292,16 +300,18 @@ class StatefulMCPServer:
292300

293301
def __init__(
294302
self,
295-
server: MCPServer,
303+
server_factory: Callable[[], MCPServer],
296304
):
297305
"""Initialize the stateful temporal MCP server.
298306
299307
Args:
300-
server: Either an MCPServer instance or a string name for the server.
308+
server_factory: A function which will produce MCPServer instances. It should return a new server each time
309+
so that state is not shared between workflow runs
301310
"""
302-
self._server = server
303-
self._name = self._server.name + "-stateful"
311+
self._server_factory = server_factory
312+
self._name = server_factory().name + "-stateful"
304313
self._connect_handle: Optional[ActivityHandle] = None
314+
self._servers: dict[str, MCPServer] = {}
305315
super().__init__()
306316

307317
@property
@@ -310,25 +320,28 @@ def name(self) -> str:
310320
return self._name
311321

312322
def _get_activities(self) -> Sequence[Callable]:
323+
def _server_id():
324+
return self.name + "@" + activity.info().workflow_run_id
325+
313326
@activity.defn(name=self.name + "-list-tools")
314327
async def list_tools() -> list[MCPTool]:
315-
return await self._server.list_tools()
328+
return await self._servers[_server_id()].list_tools()
316329

317330
@activity.defn(name=self.name + "-call-tool")
318331
async def call_tool(
319332
tool_name: str, arguments: Optional[dict[str, Any]]
320333
) -> CallToolResult:
321-
return await self._server.call_tool(tool_name, arguments)
334+
return await self._servers[_server_id()].call_tool(tool_name, arguments)
322335

323336
@activity.defn(name=self.name + "-list-prompts")
324337
async def list_prompts() -> ListPromptsResult:
325-
return await self._server.list_prompts()
338+
return await self._servers[_server_id()].list_prompts()
326339

327340
@activity.defn(name=self.name + "-get-prompt")
328341
async def get_prompt(
329342
name: str, arguments: Optional[dict[str, Any]]
330343
) -> GetPromptResult:
331-
return await self._server.get_prompt(name, arguments)
344+
return await self._servers[_server_id()].get_prompt(name, arguments)
332345

333346
async def heartbeat_every(delay: float, *details: Any) -> None:
334347
"""Heartbeat every so often while not cancelled"""
@@ -339,23 +352,34 @@ async def heartbeat_every(delay: float, *details: Any) -> None:
339352
@activity.defn(name=self._name + "-server-session")
340353
async def connect() -> None:
341354
heartbeat_task = asyncio.create_task(heartbeat_every(30))
342-
try:
343-
await self._server.connect()
344355

345-
worker = Worker(
346-
activity.client(),
347-
task_queue=activity.info().workflow_id + "-" + self.name,
348-
activities=[list_tools, call_tool, list_prompts, get_prompt],
349-
activity_task_poller_behavior=PollerBehaviorSimpleMaximum(1),
356+
server_id = self.name + "@" + activity.info().workflow_run_id
357+
if server_id in self._servers:
358+
raise ApplicationError(
359+
"Cannot connect to an already running server. Use a distinct name if running multiple servers in one workflow."
350360
)
351-
352-
await worker.run()
353-
finally:
354-
await self._server.cleanup()
355-
heartbeat_task.cancel()
361+
server = self._server_factory()
362+
try:
363+
self._servers[server_id] = server
356364
try:
357-
await heartbeat_task
358-
except asyncio.CancelledError:
359-
pass
365+
await server.connect()
366+
367+
worker = Worker(
368+
activity.client(),
369+
task_queue=server_id,
370+
activities=[list_tools, call_tool, list_prompts, get_prompt],
371+
activity_task_poller_behavior=PollerBehaviorSimpleMaximum(1),
372+
)
373+
374+
await worker.run()
375+
finally:
376+
await server.cleanup()
377+
heartbeat_task.cancel()
378+
try:
379+
await heartbeat_task
380+
except asyncio.CancelledError:
381+
pass
382+
finally:
383+
del self._servers[server_id]
360384

361385
return (connect,)

temporalio/contrib/openai_agents/_temporal_openai_agents.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@
6464

6565
if typing.TYPE_CHECKING:
6666
from temporalio.contrib.openai_agents import (
67-
StatefulMCPServer,
67+
StatefulMCPServerProvider,
6868
StatelessMCPServer,
6969
)
7070

@@ -242,7 +242,9 @@ def __init__(
242242
self,
243243
model_params: Optional[ModelActivityParameters] = None,
244244
model_provider: Optional[ModelProvider] = None,
245-
mcp_servers: Sequence[Union["StatelessMCPServer", "StatefulMCPServer"]] = (),
245+
mcp_servers: Sequence[
246+
Union["StatelessMCPServer", "StatefulMCPServerProvider"]
247+
] = (),
246248
) -> None:
247249
"""Initialize the OpenAI agents plugin.
248250

tests/contrib/openai_agents/test_openai.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2341,9 +2341,7 @@ class TrackingMCPModel(StaticTestModel):
23412341

23422342
@pytest.mark.parametrize("use_local_model", [True, False])
23432343
@pytest.mark.parametrize("stateful", [True, False])
2344-
async def test_stateful_mcp_server(
2345-
client: Client, use_local_model: bool, stateful: bool
2346-
):
2344+
async def test_mcp_server(client: Client, use_local_model: bool, stateful: bool):
23472345
if not use_local_model and not os.environ.get("OPENAI_API_KEY"):
23482346
pytest.skip("No openai API key")
23492347

@@ -2354,7 +2352,10 @@ async def test_stateful_mcp_server(
23542352
from mcp import Tool as MCPTool # type: ignore
23552353
from mcp.types import CallToolResult, TextContent # type: ignore
23562354

2357-
from temporalio.contrib.openai_agents import StatefulMCPServer, StatelessMCPServer
2355+
from temporalio.contrib.openai_agents import (
2356+
StatefulMCPServerProvider,
2357+
StatelessMCPServer,
2358+
)
23582359

23592360
class TrackingMCPServer(MCPServer):
23602361
calls: list[str]
@@ -2412,8 +2413,8 @@ async def get_prompt(
24122413
raise NotImplementedError()
24132414

24142415
tracking_server = TrackingMCPServer(name="HelloServer")
2415-
server: Union[StatefulMCPServer, StatelessMCPServer] = (
2416-
StatefulMCPServer(tracking_server)
2416+
server: Union[StatefulMCPServerProvider, StatelessMCPServer] = (
2417+
StatefulMCPServerProvider(lambda: tracking_server)
24172418
if stateful
24182419
else StatelessMCPServer(tracking_server)
24192420
)
@@ -2490,10 +2491,10 @@ async def test_stateful_mcp_server_no_worker(client: Client):
24902491
pytest.skip("Mcp not supported on Python 3.9")
24912492
from agents.mcp import MCPServerStdio
24922493

2493-
from temporalio.contrib.openai_agents import StatefulMCPServer
2494+
from temporalio.contrib.openai_agents import StatefulMCPServerProvider
24942495

2495-
server = StatefulMCPServer(
2496-
MCPServerStdio(
2496+
server = StatefulMCPServerProvider(
2497+
lambda: MCPServerStdio(
24972498
name="Filesystem-Server",
24982499
params={
24992500
"command": "npx",

0 commit comments

Comments
 (0)