11import abc
22import asyncio
3+ import dataclasses
34import functools
5+ import inspect
46import logging
57from contextlib import AbstractAsyncContextManager
68from datetime import timedelta
7- from typing import Any , Callable , Optional , Sequence , Union
9+ from typing import Any , Callable , Optional , Sequence , Union , cast
810
911from agents import AgentBase , RunContextWrapper
1012from agents .mcp import MCPServer
2931logger = logging .getLogger (__name__ )
3032
3133
34+ @dataclasses .dataclass
35+ class _StatelessListToolsArguments :
36+ factory_argument : Optional [Any ]
37+
38+
39+ @dataclasses .dataclass
40+ class _StatelessCallToolsArguments :
41+ tool_name : str
42+ arguments : Optional [dict [str , Any ]]
43+ factory_argument : Optional [Any ]
44+
45+
46+ @dataclasses .dataclass
47+ class _StatelessListPromptsArguments :
48+ factory_argument : Optional [Any ]
49+
50+
51+ @dataclasses .dataclass
52+ class _StatelessGetPromptArguments :
53+ name : str
54+ arguments : Optional [dict [str , Any ]]
55+ factory_argument : Optional [Any ]
56+
57+
3258class _StatelessMCPServerReference (MCPServer ):
3359 def __init__ (
3460 self ,
3561 server : str ,
3662 config : Optional [ActivityConfig ],
3763 cache_tools_list : bool ,
64+ factory_argument : Optional [Any ] = None ,
3865 ):
3966 self ._name = server + "-stateless"
4067 self ._config = config or ActivityConfig (
4168 start_to_close_timeout = timedelta (minutes = 1 )
4269 )
4370 self ._cache_tools_list = cache_tools_list
4471 self ._tools = None
72+ self ._factory_argument = factory_argument
4573 super ().__init__ ()
4674
4775 @property
@@ -63,7 +91,7 @@ async def list_tools(
6391 return self ._tools
6492 tools = await workflow .execute_activity (
6593 self .name + "-list-tools" ,
66- args = [] ,
94+ _StatelessListToolsArguments ( self . _factory_argument ) ,
6795 result_type = list [MCPTool ],
6896 ** self ._config ,
6997 )
@@ -75,16 +103,16 @@ async def call_tool(
75103 self , tool_name : str , arguments : Optional [dict [str , Any ]]
76104 ) -> CallToolResult :
77105 return await workflow .execute_activity (
78- self .name + "-call-tool" ,
79- args = [ tool_name , arguments ] ,
106+ self .name + "-call-tool-v2 " ,
107+ _StatelessCallToolsArguments ( tool_name , arguments , self . _factory_argument ) ,
80108 result_type = CallToolResult ,
81109 ** self ._config ,
82110 )
83111
84112 async def list_prompts (self ) -> ListPromptsResult :
85113 return await workflow .execute_activity (
86114 self .name + "-list-prompts" ,
87- args = [] ,
115+ _StatelessListPromptsArguments ( self . _factory_argument ) ,
88116 result_type = ListPromptsResult ,
89117 ** self ._config ,
90118 )
@@ -93,8 +121,8 @@ async def get_prompt(
93121 self , name : str , arguments : Optional [dict [str , Any ]] = None
94122 ) -> GetPromptResult :
95123 return await workflow .execute_activity (
96- self .name + "-get-prompt" ,
97- args = [ name , arguments ] ,
124+ self .name + "-get-prompt-v2 " ,
125+ _StatelessGetPromptArguments ( name , arguments , self . _factory_argument ) ,
98126 result_type = GetPromptResult ,
99127 ** self ._config ,
100128 )
@@ -111,64 +139,107 @@ class StatelessMCPServerProvider:
111139 function, this cannot be used.
112140 """
113141
114- def __init__ (self , server_factory : Callable [[], MCPServer ]):
142+ def __init__ (
143+ self ,
144+ name : str ,
145+ server_factory : Union [
146+ Callable [[], MCPServer ], Callable [[Optional [Any ]], MCPServer ]
147+ ],
148+ ):
115149 """Initialize the stateless temporal MCP server.
116150
117151 Args:
152+ name: The name of the MCP server.
118153 server_factory: A function which will produce MCPServer instances. It should return a new server each time
119- so that state is not shared between workflow runs
154+ so that state is not shared between workflow runs.
120155 """
121156 self ._server_factory = server_factory
122- self ._name = server_factory ().name + "-stateless"
157+
158+ # Cache whether the server factory needs to be provided with arguments
159+ sig = inspect .signature (self ._server_factory )
160+ self ._server_accepts_arguments = len (sig .parameters ) != 0
161+
162+ self ._name = name + "-stateless"
123163 super ().__init__ ()
124164
165+ def _create_server (self , factory_argument : Optional [Any ]) -> MCPServer :
166+ if self ._server_accepts_arguments :
167+ return cast (Callable [[Optional [Any ]], MCPServer ], self ._server_factory )(
168+ factory_argument
169+ )
170+ else :
171+ return cast (Callable [[], MCPServer ], self ._server_factory )()
172+
125173 @property
126174 def name (self ) -> str :
127175 """Get the server name."""
128176 return self ._name
129177
130178 def _get_activities (self ) -> Sequence [Callable ]:
131179 @activity .defn (name = self .name + "-list-tools" )
132- async def list_tools () -> list [MCPTool ]:
133- server = self ._server_factory ()
180+ async def list_tools (
181+ args : Optional [_StatelessListToolsArguments ] = None ,
182+ ) -> list [MCPTool ]:
183+ server = self ._create_server (args .factory_argument if args else None )
134184 try :
135185 await server .connect ()
136186 return await server .list_tools ()
137187 finally :
138188 await server .cleanup ()
139189
140- @activity .defn (name = self .name + "-call-tool" )
141- async def call_tool (
142- tool_name : str , arguments : Optional [dict [str , Any ]]
143- ) -> CallToolResult :
144- server = self ._server_factory ()
190+ @activity .defn (name = self .name + "-call-tool-v2" )
191+ async def call_tool (args : _StatelessCallToolsArguments ) -> CallToolResult :
192+ server = self ._create_server (args .factory_argument )
145193 try :
146194 await server .connect ()
147- return await server .call_tool (tool_name , arguments )
195+ return await server .call_tool (args . tool_name , args . arguments )
148196 finally :
149197 await server .cleanup ()
150198
151199 @activity .defn (name = self .name + "-list-prompts" )
152- async def list_prompts () -> ListPromptsResult :
153- server = self ._server_factory ()
200+ async def list_prompts (
201+ args : Optional [_StatelessListPromptsArguments ] = None ,
202+ ) -> ListPromptsResult :
203+ server = self ._create_server (args .factory_argument if args else None )
154204 try :
155205 await server .connect ()
156206 return await server .list_prompts ()
157207 finally :
158208 await server .cleanup ()
159209
160- @activity .defn (name = self .name + "-get-prompt" )
161- async def get_prompt (
162- name : str , arguments : Optional [dict [str , Any ]]
163- ) -> GetPromptResult :
164- server = self ._server_factory ()
210+ @activity .defn (name = self .name + "-get-prompt-v2" )
211+ async def get_prompt (args : _StatelessGetPromptArguments ) -> GetPromptResult :
212+ server = self ._create_server (args .factory_argument )
165213 try :
166214 await server .connect ()
167- return await server .get_prompt (name , arguments )
215+ return await server .get_prompt (args . name , args . arguments )
168216 finally :
169217 await server .cleanup ()
170218
171- return list_tools , call_tool , list_prompts , get_prompt
219+ @activity .defn (name = self .name + "-call-tool" )
220+ async def call_tool_deprecated (
221+ tool_name : str ,
222+ arguments : Optional [dict [str , Any ]],
223+ ) -> CallToolResult :
224+ return await call_tool (
225+ _StatelessCallToolsArguments (tool_name , arguments , None )
226+ )
227+
228+ @activity .defn (name = self .name + "-get-prompt" )
229+ async def get_prompt_deprecated (
230+ name : str ,
231+ arguments : Optional [dict [str , Any ]],
232+ ) -> GetPromptResult :
233+ return await get_prompt (_StatelessGetPromptArguments (name , arguments , None ))
234+
235+ return (
236+ list_tools ,
237+ call_tool ,
238+ list_prompts ,
239+ get_prompt ,
240+ call_tool_deprecated ,
241+ get_prompt_deprecated ,
242+ )
172243
173244
174245def _handle_worker_failure (func ):
@@ -202,12 +273,30 @@ async def wrapper(*args, **kwargs):
202273 return wrapper
203274
204275
276+ @dataclasses .dataclass
277+ class _StatefulCallToolsArguments :
278+ tool_name : str
279+ arguments : Optional [dict [str , Any ]]
280+
281+
282+ @dataclasses .dataclass
283+ class _StatefulGetPromptArguments :
284+ name : str
285+ arguments : Optional [dict [str , Any ]]
286+
287+
288+ @dataclasses .dataclass
289+ class _StatefulServerSessionArguments :
290+ factory_argument : Optional [Any ]
291+
292+
205293class _StatefulMCPServerReference (MCPServer , AbstractAsyncContextManager ):
206294 def __init__ (
207295 self ,
208296 server : str ,
209297 config : Optional [ActivityConfig ],
210298 server_session_config : Optional [ActivityConfig ],
299+ factory_argument : Optional [Any ],
211300 ):
212301 self ._name = server + "-stateful"
213302 self ._config = config or ActivityConfig (
@@ -218,6 +307,7 @@ def __init__(
218307 start_to_close_timeout = timedelta (hours = 1 ),
219308 )
220309 self ._connect_handle : Optional [ActivityHandle ] = None
310+ self ._factory_argument = factory_argument
221311 super ().__init__ ()
222312
223313 @property
@@ -228,7 +318,7 @@ async def connect(self) -> None:
228318 self ._config ["task_queue" ] = self .name + "@" + workflow .info ().run_id
229319 self ._connect_handle = workflow .start_activity (
230320 self .name + "-server-session" ,
231- args = [] ,
321+ _StatefulServerSessionArguments ( self . _factory_argument ) ,
232322 ** self ._server_session_config ,
233323 )
234324
@@ -276,8 +366,8 @@ async def call_tool(
276366 "Stateful MCP Server not connected. Call connect first."
277367 )
278368 return await workflow .execute_activity (
279- self .name + "-call-tool" ,
280- args = [ tool_name , arguments ] ,
369+ self .name + "-call-tool-v2 " ,
370+ _StatefulCallToolsArguments ( tool_name , arguments ) ,
281371 result_type = CallToolResult ,
282372 ** self ._config ,
283373 )
@@ -304,8 +394,8 @@ async def get_prompt(
304394 "Stateful MCP Server not connected. Call connect first."
305395 )
306396 return await workflow .execute_activity (
307- self .name + "-get-prompt" ,
308- args = [ name , arguments ] ,
397+ self .name + "-get-prompt-v2 " ,
398+ _StatefulGetPromptArguments ( name , arguments ) ,
309399 result_type = GetPromptResult ,
310400 ** self ._config ,
311401 )
@@ -329,16 +419,18 @@ class StatefulMCPServerProvider:
329419
330420 def __init__ (
331421 self ,
332- server_factory : Callable [[], MCPServer ],
422+ name : str ,
423+ server_factory : Callable [[Optional [Any ]], MCPServer ],
333424 ):
334425 """Initialize the stateful temporal MCP server.
335426
336427 Args:
428+ name: The name of the MCP server.
337429 server_factory: A function which will produce MCPServer instances. It should return a new server each time
338430 so that state is not shared between workflow runs
339431 """
340432 self ._server_factory = server_factory
341- self ._name = server_factory (). name + "-stateful"
433+ self ._name = name + "-stateful"
342434 self ._connect_handle : Optional [ActivityHandle ] = None
343435 self ._servers : dict [str , MCPServer ] = {}
344436 super ().__init__ ()
@@ -357,37 +449,51 @@ async def list_tools() -> list[MCPTool]:
357449 return await self ._servers [_server_id ()].list_tools ()
358450
359451 @activity .defn (name = self .name + "-call-tool" )
360- async def call_tool (
452+ async def call_tool_deprecated (
361453 tool_name : str , arguments : Optional [dict [str , Any ]]
362454 ) -> CallToolResult :
363455 return await self ._servers [_server_id ()].call_tool (tool_name , arguments )
364456
457+ @activity .defn (name = self .name + "-call-tool-v2" )
458+ async def call_tool (args : _StatefulCallToolsArguments ) -> CallToolResult :
459+ return await self ._servers [_server_id ()].call_tool (
460+ args .tool_name , args .arguments
461+ )
462+
365463 @activity .defn (name = self .name + "-list-prompts" )
366464 async def list_prompts () -> ListPromptsResult :
367465 return await self ._servers [_server_id ()].list_prompts ()
368466
369467 @activity .defn (name = self .name + "-get-prompt" )
370- async def get_prompt (
468+ async def get_prompt_deprecated (
371469 name : str , arguments : Optional [dict [str , Any ]]
372470 ) -> GetPromptResult :
373471 return await self ._servers [_server_id ()].get_prompt (name , arguments )
374472
473+ @activity .defn (name = self .name + "-get-prompt-v2" )
474+ async def get_prompt (args : _StatefulGetPromptArguments ) -> GetPromptResult :
475+ return await self ._servers [_server_id ()].get_prompt (
476+ args .name , args .arguments
477+ )
478+
375479 async def heartbeat_every (delay : float , * details : Any ) -> None :
376480 """Heartbeat every so often while not cancelled"""
377481 while True :
378482 await asyncio .sleep (delay )
379483 activity .heartbeat (* details )
380484
381485 @activity .defn (name = self .name + "-server-session" )
382- async def connect () -> None :
486+ async def connect (
487+ args : Optional [_StatefulServerSessionArguments ] = None ,
488+ ) -> None :
383489 heartbeat_task = asyncio .create_task (heartbeat_every (30 ))
384490
385491 server_id = self .name + "@" + activity .info ().workflow_run_id
386492 if server_id in self ._servers :
387493 raise ApplicationError (
388494 "Cannot connect to an already running server. Use a distinct name if running multiple servers in one workflow."
389495 )
390- server = self ._server_factory ()
496+ server = self ._server_factory (args . factory_argument if args else None )
391497 try :
392498 self ._servers [server_id ] = server
393499 try :
@@ -396,7 +502,14 @@ async def connect() -> None:
396502 worker = Worker (
397503 activity .client (),
398504 task_queue = server_id ,
399- activities = [list_tools , call_tool , list_prompts , get_prompt ],
505+ activities = [
506+ list_tools ,
507+ call_tool ,
508+ list_prompts ,
509+ get_prompt ,
510+ call_tool_deprecated ,
511+ get_prompt_deprecated ,
512+ ],
400513 activity_task_poller_behavior = PollerBehaviorSimpleMaximum (1 ),
401514 )
402515
0 commit comments