3030
3131
3232class _StatelessMCPServerReference (MCPServer ):
33- def __init__ (self , server : str , config : Optional [ActivityConfig ] = None ):
33+ def __init__ (
34+ self ,
35+ server : str ,
36+ config : Optional [ActivityConfig ],
37+ cache_tools_list : bool ,
38+ ):
3439 self ._name = server + "-stateless"
3540 self ._config = config or ActivityConfig (
3641 start_to_close_timeout = timedelta (minutes = 1 )
3742 )
43+ self ._cache_tools_list = cache_tools_list
44+ self ._tools = None
3845 super ().__init__ ()
3946
4047 @property
@@ -52,12 +59,17 @@ async def list_tools(
5259 run_context : Optional [RunContextWrapper [Any ]] = None ,
5360 agent : Optional [AgentBase ] = None ,
5461 ) -> list [MCPTool ]:
55- return await workflow .execute_activity (
62+ if self ._tools :
63+ return self ._tools
64+ tools = await workflow .execute_activity (
5665 self .name + "-list-tools" ,
5766 args = [],
5867 result_type = list [MCPTool ],
5968 ** self ._config ,
6069 )
70+ if self ._cache_tools_list :
71+ self ._tools = tools
72+ return tools
6173
6274 async def call_tool (
6375 self , tool_name : str , arguments : Optional [dict [str , Any ]]
@@ -88,25 +100,26 @@ async def get_prompt(
88100 )
89101
90102
91- class StatelessMCPServer :
103+ class StatelessMCPServerProvider :
92104 """A stateless MCP server implementation for Temporal workflows.
93105
94- This class wraps an MCP server to make it stateless by executing each MCP operation
106+ This class wraps a function to create MCP servers to make them stateless by executing each MCP operation
95107 as a separate Temporal activity. Each operation (list_tools, call_tool, etc.) will
96108 connect to the underlying server, execute the operation, and then clean up the connection.
97109
98110 This approach will not maintain state across calls. If the desired MCPServer needs persistent state in order to
99111 function, this cannot be used.
100112 """
101113
102- def __init__ (self , server : MCPServer ):
114+ def __init__ (self , server_factory : Callable [[], MCPServer ] ):
103115 """Initialize the stateless temporal MCP server.
104116
105117 Args:
106- server: An MCPServer instance
118+ server_factory: A function which will produce MCPServer instances. It should return a new server each time
119+ so that state is not shared between workflow runs
107120 """
108- self ._server = server
109- self ._name = server .name + "-stateless"
121+ self ._server_factory = server_factory
122+ self ._name = server_factory () .name + "-stateless"
110123 super ().__init__ ()
111124
112125 @property
@@ -117,39 +130,43 @@ def name(self) -> str:
117130 def _get_activities (self ) -> Sequence [Callable ]:
118131 @activity .defn (name = self .name + "-list-tools" )
119132 async def list_tools () -> list [MCPTool ]:
133+ server = self ._server_factory ()
120134 try :
121- await self . _server .connect ()
122- return await self . _server .list_tools ()
135+ await server .connect ()
136+ return await server .list_tools ()
123137 finally :
124- await self . _server .cleanup ()
138+ await server .cleanup ()
125139
126140 @activity .defn (name = self .name + "-call-tool" )
127141 async def call_tool (
128142 tool_name : str , arguments : Optional [dict [str , Any ]]
129143 ) -> CallToolResult :
144+ server = self ._server_factory ()
130145 try :
131- await self . _server .connect ()
132- return await self . _server .call_tool (tool_name , arguments )
146+ await server .connect ()
147+ return await server .call_tool (tool_name , arguments )
133148 finally :
134- await self . _server .cleanup ()
149+ await server .cleanup ()
135150
136151 @activity .defn (name = self .name + "-list-prompts" )
137152 async def list_prompts () -> ListPromptsResult :
153+ server = self ._server_factory ()
138154 try :
139- await self . _server .connect ()
140- return await self . _server .list_prompts ()
155+ await server .connect ()
156+ return await server .list_prompts ()
141157 finally :
142- await self . _server .cleanup ()
158+ await server .cleanup ()
143159
144160 @activity .defn (name = self .name + "-get-prompt" )
145161 async def get_prompt (
146162 name : str , arguments : Optional [dict [str , Any ]]
147163 ) -> GetPromptResult :
164+ server = self ._server_factory ()
148165 try :
149- await self . _server .connect ()
150- return await self . _server .get_prompt (name , arguments )
166+ await server .connect ()
167+ return await server .get_prompt (name , arguments )
151168 finally :
152- await self . _server .cleanup ()
169+ await server .cleanup ()
153170
154171 return list_tools , call_tool , list_prompts , get_prompt
155172
@@ -189,8 +206,9 @@ class _StatefulMCPServerReference(MCPServer, AbstractAsyncContextManager):
189206 def __init__ (
190207 self ,
191208 server : str ,
192- config : Optional [ActivityConfig ] = None ,
193- server_session_config : Optional [ActivityConfig ] = None ,
209+ config : Optional [ActivityConfig ],
210+ server_session_config : Optional [ActivityConfig ],
211+ cache_tools_list : bool ,
194212 ):
195213 self ._name = server + "-stateful"
196214 self ._config = config or ActivityConfig (
@@ -201,6 +219,8 @@ def __init__(
201219 start_to_close_timeout = timedelta (hours = 1 ),
202220 )
203221 self ._connect_handle : Optional [ActivityHandle ] = None
222+ self ._cache_tools_list = cache_tools_list
223+ self ._tools = None
204224 super ().__init__ ()
205225
206226 @property
@@ -239,16 +259,22 @@ async def list_tools(
239259 run_context : Optional [RunContextWrapper [Any ]] = None ,
240260 agent : Optional [AgentBase ] = None ,
241261 ) -> list [MCPTool ]:
262+ if self ._tools :
263+ return self ._tools
264+
242265 if not self ._connect_handle :
243266 raise ApplicationError (
244267 "Stateful MCP Server not connected. Call connect first."
245268 )
246- return await workflow .execute_activity (
269+ tools = await workflow .execute_activity (
247270 self .name + "-list-tools" ,
248271 args = [],
249272 result_type = list [MCPTool ],
250273 ** self ._config ,
251274 )
275+ if self ._cache_tools_list :
276+ self ._tools = tools
277+ return tools
252278
253279 @_handle_worker_failure
254280 async def call_tool (
@@ -361,7 +387,7 @@ async def heartbeat_every(delay: float, *details: Any) -> None:
361387 await asyncio .sleep (delay )
362388 activity .heartbeat (* details )
363389
364- @activity .defn (name = self ._name + "-server-session" )
390+ @activity .defn (name = self .name + "-server-session" )
365391 async def connect () -> None :
366392 heartbeat_task = asyncio .create_task (heartbeat_every (30 ))
367393
0 commit comments