diff --git a/temporalio/contrib/openai_agents/_mcp.py b/temporalio/contrib/openai_agents/_mcp.py index e41cc2d3a..07ae8d838 100644 --- a/temporalio/contrib/openai_agents/_mcp.py +++ b/temporalio/contrib/openai_agents/_mcp.py @@ -1,10 +1,12 @@ import abc import asyncio +import dataclasses import functools +import inspect import logging from contextlib import AbstractAsyncContextManager from datetime import timedelta -from typing import Any, Callable, Optional, Sequence, Union +from typing import Any, Callable, Optional, Sequence, Union, cast from agents import AgentBase, RunContextWrapper from agents.mcp import MCPServer @@ -29,12 +31,37 @@ logger = logging.getLogger(__name__) +@dataclasses.dataclass +class _StatelessListToolsArguments: + factory_argument: Optional[Any] + + +@dataclasses.dataclass +class _StatelessCallToolsArguments: + tool_name: str + arguments: Optional[dict[str, Any]] + factory_argument: Optional[Any] + + +@dataclasses.dataclass +class _StatelessListPromptsArguments: + factory_argument: Optional[Any] + + +@dataclasses.dataclass +class _StatelessGetPromptArguments: + name: str + arguments: Optional[dict[str, Any]] + factory_argument: Optional[Any] + + class _StatelessMCPServerReference(MCPServer): def __init__( self, server: str, config: Optional[ActivityConfig], cache_tools_list: bool, + factory_argument: Optional[Any] = None, ): self._name = server + "-stateless" self._config = config or ActivityConfig( @@ -42,6 +69,7 @@ def __init__( ) self._cache_tools_list = cache_tools_list self._tools = None + self._factory_argument = factory_argument super().__init__() @property @@ -63,7 +91,7 @@ async def list_tools( return self._tools tools = await workflow.execute_activity( self.name + "-list-tools", - args=[], + _StatelessListToolsArguments(self._factory_argument), result_type=list[MCPTool], **self._config, ) @@ -75,8 +103,8 @@ async def call_tool( self, tool_name: str, arguments: Optional[dict[str, Any]] ) -> CallToolResult: return await workflow.execute_activity( - self.name + "-call-tool", - args=[tool_name, arguments], + self.name + "-call-tool-v2", + _StatelessCallToolsArguments(tool_name, arguments, self._factory_argument), result_type=CallToolResult, **self._config, ) @@ -84,7 +112,7 @@ async def call_tool( async def list_prompts(self) -> ListPromptsResult: return await workflow.execute_activity( self.name + "-list-prompts", - args=[], + _StatelessListPromptsArguments(self._factory_argument), result_type=ListPromptsResult, **self._config, ) @@ -93,8 +121,8 @@ async def get_prompt( self, name: str, arguments: Optional[dict[str, Any]] = None ) -> GetPromptResult: return await workflow.execute_activity( - self.name + "-get-prompt", - args=[name, arguments], + self.name + "-get-prompt-v2", + _StatelessGetPromptArguments(name, arguments, self._factory_argument), result_type=GetPromptResult, **self._config, ) @@ -111,17 +139,37 @@ class StatelessMCPServerProvider: function, this cannot be used. """ - def __init__(self, server_factory: Callable[[], MCPServer]): + def __init__( + self, + name: str, + server_factory: Union[ + Callable[[], MCPServer], Callable[[Optional[Any]], MCPServer] + ], + ): """Initialize the stateless temporal MCP server. Args: + name: The name of the MCP server. server_factory: A function which will produce MCPServer instances. It should return a new server each time - so that state is not shared between workflow runs + so that state is not shared between workflow runs. """ self._server_factory = server_factory - self._name = server_factory().name + "-stateless" + + # Cache whether the server factory needs to be provided with arguments + sig = inspect.signature(self._server_factory) + self._server_accepts_arguments = len(sig.parameters) != 0 + + self._name = name + "-stateless" super().__init__() + def _create_server(self, factory_argument: Optional[Any]) -> MCPServer: + if self._server_accepts_arguments: + return cast(Callable[[Optional[Any]], MCPServer], self._server_factory)( + factory_argument + ) + else: + return cast(Callable[[], MCPServer], self._server_factory)() + @property def name(self) -> str: """Get the server name.""" @@ -129,46 +177,69 @@ def name(self) -> str: def _get_activities(self) -> Sequence[Callable]: @activity.defn(name=self.name + "-list-tools") - async def list_tools() -> list[MCPTool]: - server = self._server_factory() + async def list_tools( + args: Optional[_StatelessListToolsArguments] = None, + ) -> list[MCPTool]: + server = self._create_server(args.factory_argument if args else None) try: await server.connect() return await server.list_tools() finally: await server.cleanup() - @activity.defn(name=self.name + "-call-tool") - async def call_tool( - tool_name: str, arguments: Optional[dict[str, Any]] - ) -> CallToolResult: - server = self._server_factory() + @activity.defn(name=self.name + "-call-tool-v2") + async def call_tool(args: _StatelessCallToolsArguments) -> CallToolResult: + server = self._create_server(args.factory_argument) try: await server.connect() - return await server.call_tool(tool_name, arguments) + return await server.call_tool(args.tool_name, args.arguments) finally: await server.cleanup() @activity.defn(name=self.name + "-list-prompts") - async def list_prompts() -> ListPromptsResult: - server = self._server_factory() + async def list_prompts( + args: Optional[_StatelessListPromptsArguments] = None, + ) -> ListPromptsResult: + server = self._create_server(args.factory_argument if args else None) try: await server.connect() return await server.list_prompts() finally: await server.cleanup() - @activity.defn(name=self.name + "-get-prompt") - async def get_prompt( - name: str, arguments: Optional[dict[str, Any]] - ) -> GetPromptResult: - server = self._server_factory() + @activity.defn(name=self.name + "-get-prompt-v2") + async def get_prompt(args: _StatelessGetPromptArguments) -> GetPromptResult: + server = self._create_server(args.factory_argument) try: await server.connect() - return await server.get_prompt(name, arguments) + return await server.get_prompt(args.name, args.arguments) finally: await server.cleanup() - return list_tools, call_tool, list_prompts, get_prompt + @activity.defn(name=self.name + "-call-tool") + async def call_tool_deprecated( + tool_name: str, + arguments: Optional[dict[str, Any]], + ) -> CallToolResult: + return await call_tool( + _StatelessCallToolsArguments(tool_name, arguments, None) + ) + + @activity.defn(name=self.name + "-get-prompt") + async def get_prompt_deprecated( + name: str, + arguments: Optional[dict[str, Any]], + ) -> GetPromptResult: + return await get_prompt(_StatelessGetPromptArguments(name, arguments, None)) + + return ( + list_tools, + call_tool, + list_prompts, + get_prompt, + call_tool_deprecated, + get_prompt_deprecated, + ) def _handle_worker_failure(func): @@ -202,12 +273,30 @@ async def wrapper(*args, **kwargs): return wrapper +@dataclasses.dataclass +class _StatefulCallToolsArguments: + tool_name: str + arguments: Optional[dict[str, Any]] + + +@dataclasses.dataclass +class _StatefulGetPromptArguments: + name: str + arguments: Optional[dict[str, Any]] + + +@dataclasses.dataclass +class _StatefulServerSessionArguments: + factory_argument: Optional[Any] + + class _StatefulMCPServerReference(MCPServer, AbstractAsyncContextManager): def __init__( self, server: str, config: Optional[ActivityConfig], server_session_config: Optional[ActivityConfig], + factory_argument: Optional[Any], ): self._name = server + "-stateful" self._config = config or ActivityConfig( @@ -218,6 +307,7 @@ def __init__( start_to_close_timeout=timedelta(hours=1), ) self._connect_handle: Optional[ActivityHandle] = None + self._factory_argument = factory_argument super().__init__() @property @@ -228,7 +318,7 @@ async def connect(self) -> None: self._config["task_queue"] = self.name + "@" + workflow.info().run_id self._connect_handle = workflow.start_activity( self.name + "-server-session", - args=[], + _StatefulServerSessionArguments(self._factory_argument), **self._server_session_config, ) @@ -276,8 +366,8 @@ async def call_tool( "Stateful MCP Server not connected. Call connect first." ) return await workflow.execute_activity( - self.name + "-call-tool", - args=[tool_name, arguments], + self.name + "-call-tool-v2", + _StatefulCallToolsArguments(tool_name, arguments), result_type=CallToolResult, **self._config, ) @@ -304,8 +394,8 @@ async def get_prompt( "Stateful MCP Server not connected. Call connect first." ) return await workflow.execute_activity( - self.name + "-get-prompt", - args=[name, arguments], + self.name + "-get-prompt-v2", + _StatefulGetPromptArguments(name, arguments), result_type=GetPromptResult, **self._config, ) @@ -329,16 +419,18 @@ class StatefulMCPServerProvider: def __init__( self, - server_factory: Callable[[], MCPServer], + name: str, + server_factory: Callable[[Optional[Any]], MCPServer], ): """Initialize the stateful temporal MCP server. Args: + name: The name of the MCP server. server_factory: A function which will produce MCPServer instances. It should return a new server each time so that state is not shared between workflow runs """ self._server_factory = server_factory - self._name = server_factory().name + "-stateful" + self._name = name + "-stateful" self._connect_handle: Optional[ActivityHandle] = None self._servers: dict[str, MCPServer] = {} super().__init__() @@ -357,21 +449,33 @@ async def list_tools() -> list[MCPTool]: return await self._servers[_server_id()].list_tools() @activity.defn(name=self.name + "-call-tool") - async def call_tool( + async def call_tool_deprecated( tool_name: str, arguments: Optional[dict[str, Any]] ) -> CallToolResult: return await self._servers[_server_id()].call_tool(tool_name, arguments) + @activity.defn(name=self.name + "-call-tool-v2") + async def call_tool(args: _StatefulCallToolsArguments) -> CallToolResult: + return await self._servers[_server_id()].call_tool( + args.tool_name, args.arguments + ) + @activity.defn(name=self.name + "-list-prompts") async def list_prompts() -> ListPromptsResult: return await self._servers[_server_id()].list_prompts() @activity.defn(name=self.name + "-get-prompt") - async def get_prompt( + async def get_prompt_deprecated( name: str, arguments: Optional[dict[str, Any]] ) -> GetPromptResult: return await self._servers[_server_id()].get_prompt(name, arguments) + @activity.defn(name=self.name + "-get-prompt-v2") + async def get_prompt(args: _StatefulGetPromptArguments) -> GetPromptResult: + return await self._servers[_server_id()].get_prompt( + args.name, args.arguments + ) + async def heartbeat_every(delay: float, *details: Any) -> None: """Heartbeat every so often while not cancelled""" while True: @@ -379,7 +483,9 @@ async def heartbeat_every(delay: float, *details: Any) -> None: activity.heartbeat(*details) @activity.defn(name=self.name + "-server-session") - async def connect() -> None: + async def connect( + args: Optional[_StatefulServerSessionArguments] = None, + ) -> None: heartbeat_task = asyncio.create_task(heartbeat_every(30)) server_id = self.name + "@" + activity.info().workflow_run_id @@ -387,7 +493,7 @@ async def connect() -> None: raise ApplicationError( "Cannot connect to an already running server. Use a distinct name if running multiple servers in one workflow." ) - server = self._server_factory() + server = self._server_factory(args.factory_argument if args else None) try: self._servers[server_id] = server try: @@ -396,7 +502,14 @@ async def connect() -> None: worker = Worker( activity.client(), task_queue=server_id, - activities=[list_tools, call_tool, list_prompts, get_prompt], + activities=[ + list_tools, + call_tool, + list_prompts, + get_prompt, + call_tool_deprecated, + get_prompt_deprecated, + ], activity_task_poller_behavior=PollerBehaviorSimpleMaximum(1), ) diff --git a/temporalio/contrib/openai_agents/workflow.py b/temporalio/contrib/openai_agents/workflow.py index 51eff86a3..1023531a1 100644 --- a/temporalio/contrib/openai_agents/workflow.py +++ b/temporalio/contrib/openai_agents/workflow.py @@ -250,6 +250,7 @@ def stateless_mcp_server( name: str, config: Optional[ActivityConfig] = None, cache_tools_list: bool = False, + factory_argument: Optional[Any] = None, ) -> "MCPServer": """A stateless MCP server implementation for Temporal workflows. @@ -269,18 +270,22 @@ def stateless_mcp_server( config: Optional activity configuration for MCP operation activities. Defaults to 1-minute start-to-close timeout. cache_tools_list: If true, the list of tools will be cached for the duration of the server + factory_argument: Optional argument to be provided to the factory when producing an MCPServer """ from temporalio.contrib.openai_agents._mcp import ( _StatelessMCPServerReference, ) - return _StatelessMCPServerReference(name, config, cache_tools_list) + return _StatelessMCPServerReference( + name, config, cache_tools_list, factory_argument + ) def stateful_mcp_server( name: str, config: Optional[ActivityConfig] = None, server_session_config: Optional[ActivityConfig] = None, + factory_argument: Optional[Any] = None, ) -> AbstractAsyncContextManager["MCPServer"]: """A stateful MCP server implementation for Temporal workflows. @@ -305,12 +310,15 @@ def stateful_mcp_server( Defaults to 1-minute start-to-close and 30-second schedule-to-start timeouts. server_session_config: Optional activity configuration for the connection activity. Defaults to 1-hour start-to-close timeout. + factory_argument: Optional argument to be provided to the factory when producing an MCPServer """ from temporalio.contrib.openai_agents._mcp import ( _StatefulMCPServerReference, ) - return _StatefulMCPServerReference(name, config, server_session_config) + return _StatefulMCPServerReference( + name, config, server_session_config, factory_argument + ) class ToolSerializationError(TemporalError): diff --git a/tests/contrib/openai_agents/test_openai.py b/tests/contrib/openai_agents/test_openai.py index aa81e32c8..812731be2 100644 --- a/tests/contrib/openai_agents/test_openai.py +++ b/tests/contrib/openai_agents/test_openai.py @@ -108,7 +108,7 @@ from tests.contrib.openai_agents.research_agents.research_manager import ( ResearchManager, ) -from tests.helpers import assert_eventually, assert_task_fail_eventually, new_worker +from tests.helpers import assert_eventually, new_worker from tests.helpers.nexus import create_nexus_endpoint, make_nexus_endpoint_name @@ -2322,11 +2322,11 @@ async def test_output_type(client: Client): @workflow.defn class McpServerWorkflow: @workflow.run - async def run(self, caching: bool) -> str: + async def run(self, caching: bool, factory_argument: Optional[Any]) -> str: from agents.mcp import MCPServer server: MCPServer = openai_agents.workflow.stateless_mcp_server( - "HelloServer", cache_tools_list=caching + "HelloServer", cache_tools_list=caching, factory_argument=factory_argument ) agent = Agent[str]( name="MCP ServerWorkflow", @@ -2342,13 +2342,14 @@ async def run(self, caching: bool) -> str: @workflow.defn class McpServerStatefulWorkflow: @workflow.run - async def run(self, timeout: timedelta) -> str: + async def run(self, timeout: timedelta, factory_argument: Optional[Any]) -> str: async with openai_agents.workflow.stateful_mcp_server( "HelloServer", config=ActivityConfig( schedule_to_start_timeout=timeout, start_to_close_timeout=timedelta(seconds=30), ), + factory_argument=factory_argument, ) as server: agent = Agent[str]( name="MCP ServerWorkflow", @@ -2375,31 +2376,12 @@ class TrackingMCPModel(StaticTestModel): ] -@pytest.mark.parametrize("use_local_model", [True, False]) -@pytest.mark.parametrize("stateful", [True, False]) -@pytest.mark.parametrize("caching", [True, False]) -async def test_mcp_server( - client: Client, use_local_model: bool, stateful: bool, caching: bool -): - if not use_local_model and not os.environ.get("OPENAI_API_KEY"): - pytest.skip("No openai API key") - - if sys.version_info < (3, 10): - pytest.skip("Mcp not supported on Python 3.9") - - if stateful and caching: - pytest.skip("Caching is only supported for stateless MCP servers") - - from agents.mcp import MCPServer +def get_tracking_server(name: str): + from agents.mcp import MCPServer # type: ignore from mcp import GetPromptResult, ListPromptsResult # type: ignore from mcp import Tool as MCPTool # type: ignore from mcp.types import CallToolResult, TextContent # type: ignore - from temporalio.contrib.openai_agents import ( - StatefulMCPServerProvider, - StatelessMCPServerProvider, - ) - class TrackingMCPServer(MCPServer): calls: list[str] @@ -2455,11 +2437,36 @@ async def get_prompt( ) -> GetPromptResult: raise NotImplementedError() - tracking_server = TrackingMCPServer(name="HelloServer") + return TrackingMCPServer(name) + + +@pytest.mark.parametrize("use_local_model", [True, False]) +@pytest.mark.parametrize("stateful", [True, False]) +@pytest.mark.parametrize("caching", [True, False]) +async def test_mcp_server( + client: Client, use_local_model: bool, stateful: bool, caching: bool +): + if not use_local_model and not os.environ.get("OPENAI_API_KEY"): + pytest.skip("No openai API key") + + if sys.version_info < (3, 10): + pytest.skip("Mcp not supported on Python 3.9") + + if stateful and caching: + pytest.skip("Caching is only supported for stateless MCP servers") + + from agents.mcp import MCPServer # type: ignore + + from temporalio.contrib.openai_agents import ( + StatefulMCPServerProvider, + StatelessMCPServerProvider, + ) + + tracking_server = get_tracking_server(name="HelloServer") server: Union[StatefulMCPServerProvider, StatelessMCPServerProvider] = ( - StatefulMCPServerProvider(lambda: tracking_server) + StatefulMCPServerProvider("HelloServer", lambda _: tracking_server) if stateful - else StatelessMCPServerProvider(lambda: tracking_server) + else StatelessMCPServerProvider("HelloServer", lambda _: tracking_server) ) new_config = client.config() @@ -2482,7 +2489,7 @@ async def get_prompt( if stateful: result = await client.execute_workflow( McpServerStatefulWorkflow.run, - args=[timedelta(seconds=30)], + args=[timedelta(seconds=30), None], id=f"mcp-server-{uuid.uuid4()}", task_queue=worker.task_queue, execution_timeout=timedelta(seconds=30), @@ -2490,7 +2497,7 @@ async def get_prompt( else: result = await client.execute_workflow( McpServerWorkflow.run, - args=[caching], + args=[caching, None], id=f"mcp-server-{uuid.uuid4()}", task_queue=worker.task_queue, execution_timeout=timedelta(seconds=30), @@ -2543,6 +2550,69 @@ async def get_prompt( ] +@pytest.mark.parametrize("stateful", [True, False]) +async def test_mcp_server_factory_argument(client: Client, stateful: bool): + if sys.version_info < (3, 10): + pytest.skip("Mcp not supported on Python 3.9") + + from agents.mcp import MCPServer # type: ignore + from mcp import GetPromptResult, ListPromptsResult # type: ignore + from mcp import Tool as MCPTool # type: ignore + from mcp.types import CallToolResult, TextContent # type: ignore + + from temporalio.contrib.openai_agents import ( + StatefulMCPServerProvider, + StatelessMCPServerProvider, + ) + + def factory(args: Optional[Any]) -> MCPServer: + print("Invoking factory: ", args) + if args is not None: + assert args is not None + assert cast(dict[str, str], args).get("user") == "blah" + + return get_tracking_server("HelloServer") + + server: Union[StatefulMCPServerProvider, StatelessMCPServerProvider] = ( + StatefulMCPServerProvider("HelloServer", factory) + if stateful + else StatelessMCPServerProvider("HelloServer", factory) + ) + + new_config = client.config() + new_config["plugins"] = [ + openai_agents.OpenAIAgentsPlugin( + model_params=ModelActivityParameters( + start_to_close_timeout=timedelta(seconds=120) + ), + model_provider=TestModelProvider(TrackingMCPModel()), + mcp_server_providers=[server], + ) + ] + client = Client(**new_config) + + headers = {"user": "blah"} + async with new_worker( + client, McpServerStatefulWorkflow, McpServerWorkflow + ) as worker: + if stateful: + result = await client.execute_workflow( + McpServerStatefulWorkflow.run, + args=[timedelta(seconds=30), headers], + id=f"mcp-server-{uuid.uuid4()}", + task_queue=worker.task_queue, + execution_timeout=timedelta(seconds=30), + ) + else: + result = await client.execute_workflow( + McpServerWorkflow.run, + args=[False, headers], + id=f"mcp-server-{uuid.uuid4()}", + task_queue=worker.task_queue, + execution_timeout=timedelta(seconds=30), + ) + + async def test_stateful_mcp_server_no_worker(client: Client): if sys.version_info < (3, 10): pytest.skip("Mcp not supported on Python 3.9") @@ -2551,7 +2621,8 @@ async def test_stateful_mcp_server_no_worker(client: Client): from temporalio.contrib.openai_agents import StatefulMCPServerProvider server = StatefulMCPServerProvider( - lambda: MCPServerStdio( + "Filesystem-Server", + lambda _: MCPServerStdio( name="Filesystem-Server", params={ "command": "npx", @@ -2561,7 +2632,7 @@ async def test_stateful_mcp_server_no_worker(client: Client): os.path.dirname(os.path.abspath(__file__)), ], }, - ) + ), ) # Override the connect activity to not actually start a worker @@ -2592,7 +2663,7 @@ def override_get_activities() -> Sequence[Callable]: ) as worker: workflow_handle = await client.start_workflow( McpServerStatefulWorkflow.run, - args=[timedelta(seconds=1)], + args=[timedelta(seconds=1), None], id=f"mcp-server-{uuid.uuid4()}", task_queue=worker.task_queue, execution_timeout=timedelta(seconds=30), diff --git a/tests/contrib/openai_agents/test_openai_replay.py b/tests/contrib/openai_agents/test_openai_replay.py index 2d76cf765..6db463392 100644 --- a/tests/contrib/openai_agents/test_openai_replay.py +++ b/tests/contrib/openai_agents/test_openai_replay.py @@ -1,11 +1,9 @@ from pathlib import Path import pytest -from agents import OpenAIProvider -from openai import AsyncOpenAI from temporalio.client import WorkflowHistory -from temporalio.contrib.openai_agents import ModelActivityParameters, OpenAIAgentsPlugin +from temporalio.contrib.openai_agents import OpenAIAgentsPlugin from temporalio.worker import Replayer from tests.contrib.openai_agents.test_openai import ( AgentsAsToolsWorkflow,