@@ -203,7 +203,7 @@ def name(self) -> str:
203203 return self ._name
204204
205205 async def connect (self ) -> None :
206- self ._config ["task_queue" ] = workflow . info (). workflow_id + "- " + self . name
206+ self ._config ["task_queue" ] = self . name + "@ " + workflow . info (). run_id
207207 self ._connect_handle = workflow .start_activity (
208208 self .name + "-server-session" ,
209209 args = [],
@@ -228,7 +228,9 @@ async def list_tools(
228228 agent : Optional [AgentBase ] = None ,
229229 ) -> list [MCPTool ]:
230230 if not self ._connect_handle :
231- raise ApplicationError ("Stateful MCP Server not connected. Call connect first." )
231+ raise ApplicationError (
232+ "Stateful MCP Server not connected. Call connect first."
233+ )
232234 return await workflow .execute_activity (
233235 self .name + "-list-tools" ,
234236 args = [],
@@ -241,7 +243,9 @@ async def call_tool(
241243 self , tool_name : str , arguments : Optional [dict [str , Any ]]
242244 ) -> CallToolResult :
243245 if not self ._connect_handle :
244- raise ApplicationError ("Stateful MCP Server not connected. Call connect first." )
246+ raise ApplicationError (
247+ "Stateful MCP Server not connected. Call connect first."
248+ )
245249 return await workflow .execute_activity (
246250 self .name + "-call-tool" ,
247251 args = [tool_name , arguments ],
@@ -252,7 +256,9 @@ async def call_tool(
252256 @_handle_worker_failure
253257 async def list_prompts (self ) -> ListPromptsResult :
254258 if not self ._connect_handle :
255- raise ApplicationError ("Stateful MCP Server not connected. Call connect first." )
259+ raise ApplicationError (
260+ "Stateful MCP Server not connected. Call connect first."
261+ )
256262 return await workflow .execute_activity (
257263 self .name + "-list-prompts" ,
258264 args = [],
@@ -265,7 +271,9 @@ async def get_prompt(
265271 self , name : str , arguments : Optional [dict [str , Any ]] = None
266272 ) -> GetPromptResult :
267273 if not self ._connect_handle :
268- raise ApplicationError ("Stateful MCP Server not connected. Call connect first." )
274+ raise ApplicationError (
275+ "Stateful MCP Server not connected. Call connect first."
276+ )
269277 return await workflow .execute_activity (
270278 self .name + "-get-prompt" ,
271279 args = [name , arguments ],
@@ -274,10 +282,10 @@ async def get_prompt(
274282 )
275283
276284
277- class StatefulMCPServer :
285+ class StatefulMCPServerProvider :
278286 """A stateful MCP server implementation for Temporal workflows.
279287
280- This class wraps an MCP server to maintain a persistent connection throughout
288+ This class wraps an function to create MCP servers to maintain a persistent connection throughout
281289 the workflow execution. It creates a dedicated worker that stays connected to
282290 the MCP server and processes operations on a dedicated task queue.
283291
@@ -292,16 +300,18 @@ class StatefulMCPServer:
292300
293301 def __init__ (
294302 self ,
295- server : MCPServer ,
303+ server_factory : Callable [[], MCPServer ] ,
296304 ):
297305 """Initialize the stateful temporal MCP server.
298306
299307 Args:
300- server: Either an MCPServer instance or a string name for the server.
308+ server_factory: A function which will produce MCPServer instances. It should return a new server each time
309+ so that state is not shared between workflow runs
301310 """
302- self ._server = server
303- self ._name = self . _server .name + "-stateful"
311+ self ._server_factory = server_factory
312+ self ._name = server_factory () .name + "-stateful"
304313 self ._connect_handle : Optional [ActivityHandle ] = None
314+ self ._servers : dict [str , MCPServer ] = {}
305315 super ().__init__ ()
306316
307317 @property
@@ -310,25 +320,28 @@ def name(self) -> str:
310320 return self ._name
311321
312322 def _get_activities (self ) -> Sequence [Callable ]:
323+ def _server_id ():
324+ return self .name + "@" + activity .info ().workflow_run_id
325+
313326 @activity .defn (name = self .name + "-list-tools" )
314327 async def list_tools () -> list [MCPTool ]:
315- return await self ._server .list_tools ()
328+ return await self ._servers [ _server_id ()] .list_tools ()
316329
317330 @activity .defn (name = self .name + "-call-tool" )
318331 async def call_tool (
319332 tool_name : str , arguments : Optional [dict [str , Any ]]
320333 ) -> CallToolResult :
321- return await self ._server .call_tool (tool_name , arguments )
334+ return await self ._servers [ _server_id ()] .call_tool (tool_name , arguments )
322335
323336 @activity .defn (name = self .name + "-list-prompts" )
324337 async def list_prompts () -> ListPromptsResult :
325- return await self ._server .list_prompts ()
338+ return await self ._servers [ _server_id ()] .list_prompts ()
326339
327340 @activity .defn (name = self .name + "-get-prompt" )
328341 async def get_prompt (
329342 name : str , arguments : Optional [dict [str , Any ]]
330343 ) -> GetPromptResult :
331- return await self ._server .get_prompt (name , arguments )
344+ return await self ._servers [ _server_id ()] .get_prompt (name , arguments )
332345
333346 async def heartbeat_every (delay : float , * details : Any ) -> None :
334347 """Heartbeat every so often while not cancelled"""
@@ -339,23 +352,34 @@ async def heartbeat_every(delay: float, *details: Any) -> None:
339352 @activity .defn (name = self ._name + "-server-session" )
340353 async def connect () -> None :
341354 heartbeat_task = asyncio .create_task (heartbeat_every (30 ))
342- try :
343- await self ._server .connect ()
344355
345- worker = Worker (
346- activity .client (),
347- task_queue = activity .info ().workflow_id + "-" + self .name ,
348- activities = [list_tools , call_tool , list_prompts , get_prompt ],
349- activity_task_poller_behavior = PollerBehaviorSimpleMaximum (1 ),
356+ server_id = self .name + "@" + activity .info ().workflow_run_id
357+ if server_id in self ._servers :
358+ raise ApplicationError (
359+ "Cannot connect to an already running server. Use a distinct name if running multiple servers in one workflow."
350360 )
351-
352- await worker .run ()
353- finally :
354- await self ._server .cleanup ()
355- heartbeat_task .cancel ()
361+ server = self ._server_factory ()
362+ try :
363+ self ._servers [server_id ] = server
356364 try :
357- await heartbeat_task
358- except asyncio .CancelledError :
359- pass
365+ await server .connect ()
366+
367+ worker = Worker (
368+ activity .client (),
369+ task_queue = server_id ,
370+ activities = [list_tools , call_tool , list_prompts , get_prompt ],
371+ activity_task_poller_behavior = PollerBehaviorSimpleMaximum (1 ),
372+ )
373+
374+ await worker .run ()
375+ finally :
376+ await server .cleanup ()
377+ heartbeat_task .cancel ()
378+ try :
379+ await heartbeat_task
380+ except asyncio .CancelledError :
381+ pass
382+ finally :
383+ del self ._servers [server_id ]
360384
361385 return (connect ,)
0 commit comments