Skip to content

Commit 162cff7

Browse files
authored
💥 Allow arguments to be provided to MCP Server creation (#1147)
* Add replay tests for MCP * Remove history generation * Stateful too * Deprecate exisitng activities, add dataclasses * Clean up imports * Fix tests * Cache server function needing args * Add name argument * Tighten test assertions
1 parent 70deb94 commit 162cff7

File tree

4 files changed

+267
-77
lines changed

4 files changed

+267
-77
lines changed

temporalio/contrib/openai_agents/_mcp.py

Lines changed: 152 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
import abc
22
import asyncio
3+
import dataclasses
34
import functools
5+
import inspect
46
import logging
57
from contextlib import AbstractAsyncContextManager
68
from datetime import timedelta
7-
from typing import Any, Callable, Optional, Sequence, Union
9+
from typing import Any, Callable, Optional, Sequence, Union, cast
810

911
from agents import AgentBase, RunContextWrapper
1012
from agents.mcp import MCPServer
@@ -29,19 +31,45 @@
2931
logger = logging.getLogger(__name__)
3032

3133

34+
@dataclasses.dataclass
35+
class _StatelessListToolsArguments:
36+
factory_argument: Optional[Any]
37+
38+
39+
@dataclasses.dataclass
40+
class _StatelessCallToolsArguments:
41+
tool_name: str
42+
arguments: Optional[dict[str, Any]]
43+
factory_argument: Optional[Any]
44+
45+
46+
@dataclasses.dataclass
47+
class _StatelessListPromptsArguments:
48+
factory_argument: Optional[Any]
49+
50+
51+
@dataclasses.dataclass
52+
class _StatelessGetPromptArguments:
53+
name: str
54+
arguments: Optional[dict[str, Any]]
55+
factory_argument: Optional[Any]
56+
57+
3258
class _StatelessMCPServerReference(MCPServer):
3359
def __init__(
3460
self,
3561
server: str,
3662
config: Optional[ActivityConfig],
3763
cache_tools_list: bool,
64+
factory_argument: Optional[Any] = None,
3865
):
3966
self._name = server + "-stateless"
4067
self._config = config or ActivityConfig(
4168
start_to_close_timeout=timedelta(minutes=1)
4269
)
4370
self._cache_tools_list = cache_tools_list
4471
self._tools = None
72+
self._factory_argument = factory_argument
4573
super().__init__()
4674

4775
@property
@@ -63,7 +91,7 @@ async def list_tools(
6391
return self._tools
6492
tools = await workflow.execute_activity(
6593
self.name + "-list-tools",
66-
args=[],
94+
_StatelessListToolsArguments(self._factory_argument),
6795
result_type=list[MCPTool],
6896
**self._config,
6997
)
@@ -75,16 +103,16 @@ async def call_tool(
75103
self, tool_name: str, arguments: Optional[dict[str, Any]]
76104
) -> CallToolResult:
77105
return await workflow.execute_activity(
78-
self.name + "-call-tool",
79-
args=[tool_name, arguments],
106+
self.name + "-call-tool-v2",
107+
_StatelessCallToolsArguments(tool_name, arguments, self._factory_argument),
80108
result_type=CallToolResult,
81109
**self._config,
82110
)
83111

84112
async def list_prompts(self) -> ListPromptsResult:
85113
return await workflow.execute_activity(
86114
self.name + "-list-prompts",
87-
args=[],
115+
_StatelessListPromptsArguments(self._factory_argument),
88116
result_type=ListPromptsResult,
89117
**self._config,
90118
)
@@ -93,8 +121,8 @@ async def get_prompt(
93121
self, name: str, arguments: Optional[dict[str, Any]] = None
94122
) -> GetPromptResult:
95123
return await workflow.execute_activity(
96-
self.name + "-get-prompt",
97-
args=[name, arguments],
124+
self.name + "-get-prompt-v2",
125+
_StatelessGetPromptArguments(name, arguments, self._factory_argument),
98126
result_type=GetPromptResult,
99127
**self._config,
100128
)
@@ -111,64 +139,107 @@ class StatelessMCPServerProvider:
111139
function, this cannot be used.
112140
"""
113141

114-
def __init__(self, server_factory: Callable[[], MCPServer]):
142+
def __init__(
143+
self,
144+
name: str,
145+
server_factory: Union[
146+
Callable[[], MCPServer], Callable[[Optional[Any]], MCPServer]
147+
],
148+
):
115149
"""Initialize the stateless temporal MCP server.
116150
117151
Args:
152+
name: The name of the MCP server.
118153
server_factory: A function which will produce MCPServer instances. It should return a new server each time
119-
so that state is not shared between workflow runs
154+
so that state is not shared between workflow runs.
120155
"""
121156
self._server_factory = server_factory
122-
self._name = server_factory().name + "-stateless"
157+
158+
# Cache whether the server factory needs to be provided with arguments
159+
sig = inspect.signature(self._server_factory)
160+
self._server_accepts_arguments = len(sig.parameters) != 0
161+
162+
self._name = name + "-stateless"
123163
super().__init__()
124164

165+
def _create_server(self, factory_argument: Optional[Any]) -> MCPServer:
166+
if self._server_accepts_arguments:
167+
return cast(Callable[[Optional[Any]], MCPServer], self._server_factory)(
168+
factory_argument
169+
)
170+
else:
171+
return cast(Callable[[], MCPServer], self._server_factory)()
172+
125173
@property
126174
def name(self) -> str:
127175
"""Get the server name."""
128176
return self._name
129177

130178
def _get_activities(self) -> Sequence[Callable]:
131179
@activity.defn(name=self.name + "-list-tools")
132-
async def list_tools() -> list[MCPTool]:
133-
server = self._server_factory()
180+
async def list_tools(
181+
args: Optional[_StatelessListToolsArguments] = None,
182+
) -> list[MCPTool]:
183+
server = self._create_server(args.factory_argument if args else None)
134184
try:
135185
await server.connect()
136186
return await server.list_tools()
137187
finally:
138188
await server.cleanup()
139189

140-
@activity.defn(name=self.name + "-call-tool")
141-
async def call_tool(
142-
tool_name: str, arguments: Optional[dict[str, Any]]
143-
) -> CallToolResult:
144-
server = self._server_factory()
190+
@activity.defn(name=self.name + "-call-tool-v2")
191+
async def call_tool(args: _StatelessCallToolsArguments) -> CallToolResult:
192+
server = self._create_server(args.factory_argument)
145193
try:
146194
await server.connect()
147-
return await server.call_tool(tool_name, arguments)
195+
return await server.call_tool(args.tool_name, args.arguments)
148196
finally:
149197
await server.cleanup()
150198

151199
@activity.defn(name=self.name + "-list-prompts")
152-
async def list_prompts() -> ListPromptsResult:
153-
server = self._server_factory()
200+
async def list_prompts(
201+
args: Optional[_StatelessListPromptsArguments] = None,
202+
) -> ListPromptsResult:
203+
server = self._create_server(args.factory_argument if args else None)
154204
try:
155205
await server.connect()
156206
return await server.list_prompts()
157207
finally:
158208
await server.cleanup()
159209

160-
@activity.defn(name=self.name + "-get-prompt")
161-
async def get_prompt(
162-
name: str, arguments: Optional[dict[str, Any]]
163-
) -> GetPromptResult:
164-
server = self._server_factory()
210+
@activity.defn(name=self.name + "-get-prompt-v2")
211+
async def get_prompt(args: _StatelessGetPromptArguments) -> GetPromptResult:
212+
server = self._create_server(args.factory_argument)
165213
try:
166214
await server.connect()
167-
return await server.get_prompt(name, arguments)
215+
return await server.get_prompt(args.name, args.arguments)
168216
finally:
169217
await server.cleanup()
170218

171-
return list_tools, call_tool, list_prompts, get_prompt
219+
@activity.defn(name=self.name + "-call-tool")
220+
async def call_tool_deprecated(
221+
tool_name: str,
222+
arguments: Optional[dict[str, Any]],
223+
) -> CallToolResult:
224+
return await call_tool(
225+
_StatelessCallToolsArguments(tool_name, arguments, None)
226+
)
227+
228+
@activity.defn(name=self.name + "-get-prompt")
229+
async def get_prompt_deprecated(
230+
name: str,
231+
arguments: Optional[dict[str, Any]],
232+
) -> GetPromptResult:
233+
return await get_prompt(_StatelessGetPromptArguments(name, arguments, None))
234+
235+
return (
236+
list_tools,
237+
call_tool,
238+
list_prompts,
239+
get_prompt,
240+
call_tool_deprecated,
241+
get_prompt_deprecated,
242+
)
172243

173244

174245
def _handle_worker_failure(func):
@@ -202,12 +273,30 @@ async def wrapper(*args, **kwargs):
202273
return wrapper
203274

204275

276+
@dataclasses.dataclass
277+
class _StatefulCallToolsArguments:
278+
tool_name: str
279+
arguments: Optional[dict[str, Any]]
280+
281+
282+
@dataclasses.dataclass
283+
class _StatefulGetPromptArguments:
284+
name: str
285+
arguments: Optional[dict[str, Any]]
286+
287+
288+
@dataclasses.dataclass
289+
class _StatefulServerSessionArguments:
290+
factory_argument: Optional[Any]
291+
292+
205293
class _StatefulMCPServerReference(MCPServer, AbstractAsyncContextManager):
206294
def __init__(
207295
self,
208296
server: str,
209297
config: Optional[ActivityConfig],
210298
server_session_config: Optional[ActivityConfig],
299+
factory_argument: Optional[Any],
211300
):
212301
self._name = server + "-stateful"
213302
self._config = config or ActivityConfig(
@@ -218,6 +307,7 @@ def __init__(
218307
start_to_close_timeout=timedelta(hours=1),
219308
)
220309
self._connect_handle: Optional[ActivityHandle] = None
310+
self._factory_argument = factory_argument
221311
super().__init__()
222312

223313
@property
@@ -228,7 +318,7 @@ async def connect(self) -> None:
228318
self._config["task_queue"] = self.name + "@" + workflow.info().run_id
229319
self._connect_handle = workflow.start_activity(
230320
self.name + "-server-session",
231-
args=[],
321+
_StatefulServerSessionArguments(self._factory_argument),
232322
**self._server_session_config,
233323
)
234324

@@ -276,8 +366,8 @@ async def call_tool(
276366
"Stateful MCP Server not connected. Call connect first."
277367
)
278368
return await workflow.execute_activity(
279-
self.name + "-call-tool",
280-
args=[tool_name, arguments],
369+
self.name + "-call-tool-v2",
370+
_StatefulCallToolsArguments(tool_name, arguments),
281371
result_type=CallToolResult,
282372
**self._config,
283373
)
@@ -304,8 +394,8 @@ async def get_prompt(
304394
"Stateful MCP Server not connected. Call connect first."
305395
)
306396
return await workflow.execute_activity(
307-
self.name + "-get-prompt",
308-
args=[name, arguments],
397+
self.name + "-get-prompt-v2",
398+
_StatefulGetPromptArguments(name, arguments),
309399
result_type=GetPromptResult,
310400
**self._config,
311401
)
@@ -329,16 +419,18 @@ class StatefulMCPServerProvider:
329419

330420
def __init__(
331421
self,
332-
server_factory: Callable[[], MCPServer],
422+
name: str,
423+
server_factory: Callable[[Optional[Any]], MCPServer],
333424
):
334425
"""Initialize the stateful temporal MCP server.
335426
336427
Args:
428+
name: The name of the MCP server.
337429
server_factory: A function which will produce MCPServer instances. It should return a new server each time
338430
so that state is not shared between workflow runs
339431
"""
340432
self._server_factory = server_factory
341-
self._name = server_factory().name + "-stateful"
433+
self._name = name + "-stateful"
342434
self._connect_handle: Optional[ActivityHandle] = None
343435
self._servers: dict[str, MCPServer] = {}
344436
super().__init__()
@@ -357,37 +449,51 @@ async def list_tools() -> list[MCPTool]:
357449
return await self._servers[_server_id()].list_tools()
358450

359451
@activity.defn(name=self.name + "-call-tool")
360-
async def call_tool(
452+
async def call_tool_deprecated(
361453
tool_name: str, arguments: Optional[dict[str, Any]]
362454
) -> CallToolResult:
363455
return await self._servers[_server_id()].call_tool(tool_name, arguments)
364456

457+
@activity.defn(name=self.name + "-call-tool-v2")
458+
async def call_tool(args: _StatefulCallToolsArguments) -> CallToolResult:
459+
return await self._servers[_server_id()].call_tool(
460+
args.tool_name, args.arguments
461+
)
462+
365463
@activity.defn(name=self.name + "-list-prompts")
366464
async def list_prompts() -> ListPromptsResult:
367465
return await self._servers[_server_id()].list_prompts()
368466

369467
@activity.defn(name=self.name + "-get-prompt")
370-
async def get_prompt(
468+
async def get_prompt_deprecated(
371469
name: str, arguments: Optional[dict[str, Any]]
372470
) -> GetPromptResult:
373471
return await self._servers[_server_id()].get_prompt(name, arguments)
374472

473+
@activity.defn(name=self.name + "-get-prompt-v2")
474+
async def get_prompt(args: _StatefulGetPromptArguments) -> GetPromptResult:
475+
return await self._servers[_server_id()].get_prompt(
476+
args.name, args.arguments
477+
)
478+
375479
async def heartbeat_every(delay: float, *details: Any) -> None:
376480
"""Heartbeat every so often while not cancelled"""
377481
while True:
378482
await asyncio.sleep(delay)
379483
activity.heartbeat(*details)
380484

381485
@activity.defn(name=self.name + "-server-session")
382-
async def connect() -> None:
486+
async def connect(
487+
args: Optional[_StatefulServerSessionArguments] = None,
488+
) -> None:
383489
heartbeat_task = asyncio.create_task(heartbeat_every(30))
384490

385491
server_id = self.name + "@" + activity.info().workflow_run_id
386492
if server_id in self._servers:
387493
raise ApplicationError(
388494
"Cannot connect to an already running server. Use a distinct name if running multiple servers in one workflow."
389495
)
390-
server = self._server_factory()
496+
server = self._server_factory(args.factory_argument if args else None)
391497
try:
392498
self._servers[server_id] = server
393499
try:
@@ -396,7 +502,14 @@ async def connect() -> None:
396502
worker = Worker(
397503
activity.client(),
398504
task_queue=server_id,
399-
activities=[list_tools, call_tool, list_prompts, get_prompt],
505+
activities=[
506+
list_tools,
507+
call_tool,
508+
list_prompts,
509+
get_prompt,
510+
call_tool_deprecated,
511+
get_prompt_deprecated,
512+
],
400513
activity_task_poller_behavior=PollerBehaviorSimpleMaximum(1),
401514
)
402515

0 commit comments

Comments
 (0)