diff --git a/pyproject.toml b/pyproject.toml index 9419d35fe..8fddea817 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,7 +27,8 @@ opentelemetry = [ pydantic = ["pydantic>=2.0.0,<3"] openai-agents = [ "openai-agents>=0.3,<0.4", - "eval-type-backport>=0.2.2; python_version < '3.10'" + "eval-type-backport>=0.2.2; python_version < '3.10'", + "mcp>=1.9.4, <2; python_version >= '3.10'", ] [project.urls] diff --git a/temporalio/contrib/openai_agents/README.md b/temporalio/contrib/openai_agents/README.md index d28ceb7b5..12c29bec5 100644 --- a/temporalio/contrib/openai_agents/README.md +++ b/temporalio/contrib/openai_agents/README.md @@ -351,6 +351,111 @@ Of course, code running in the workflow can invoke a Temporal activity at any ti Tools that run in the workflow can also update OpenAI Agents context, which is read-only for tools run as Temporal activities. +## MCP Support + +This integration provides support for Model Context Protocol (MCP) servers through two wrapper approaches designed to handle different implications of failures. + +While Temporal provides durable execution for your workflows, this durability does not extend to MCP servers, which operate independently of the workflow and must provide their own durability. The integration handles this by offering stateless and stateful wrappers that you can choose based on your MCP server's design. + +### Stateless vs Stateful MCP Servers + +You need to understand your MCP server's behavior to choose the correct wrapper: + +**Stateless MCP servers** treat each operation independently. For example, a weather server with a `get_weather(location)` tool is stateless because each call is self-contained and includes all necessary information. These servers can be safely restarted or reconnected to without changing their behavior. + +**Stateful MCP servers** maintain session state between calls. 
For example, a weather server that requires calling `set_location(location)` followed by `get_weather()` is stateful because it remembers the configured location and uses it for subsequent calls. If the session or the server is restarted, state crucial for operation is lost. Temporal identifies such failures and raises an `ApplicationError` to signal the need for application-level failure handling. + +### Usage Example (Stateless MCP) + +The code below gives an example of using a stateless MCP server. + +#### Worker Configuration + +```python +import asyncio +from datetime import timedelta +from agents.mcp import MCPServerStdio +from temporalio.client import Client +from temporalio.contrib.openai_agents import ( + ModelActivityParameters, + OpenAIAgentsPlugin, + StatelessMCPServerProvider, +) +from temporalio.worker import Worker + +async def main(): + # Create the MCP server provider + filesystem_server = StatelessMCPServerProvider( + lambda: MCPServerStdio( + name="FileSystemServer", + params={ + "command": "npx", + "args": ["-y", "@modelcontextprotocol/server-filesystem", "/path/to/files"], + }, + ) + ) + + # Register the MCP server with the OpenAI Agents plugin + client = await Client.connect( + "localhost:7233", + plugins=[ + OpenAIAgentsPlugin( + model_params=ModelActivityParameters( + start_to_close_timeout=timedelta(seconds=60) + ), + mcp_servers=[filesystem_server], + ), + ], + ) + + worker = Worker( + client, + task_queue="my-task-queue", + workflows=[FileSystemWorkflow], + ) + await worker.run() + +if __name__ == "__main__": + asyncio.run(main()) +``` + +#### Workflow Implementation + +```python +from temporalio import workflow +from temporalio.contrib import openai_agents +from agents import Agent, Runner + +@workflow.defn +class FileSystemWorkflow: + @workflow.run + async def run(self, query: str) -> str: + # Reference the MCP server by name (matches name in worker configuration) + server = 
openai_agents.workflow.stateless_mcp_server("FileSystemServer") + + agent = Agent( + name="File Assistant", + instructions="Use the filesystem tools to read files and answer questions.", + mcp_servers=[server], + ) + + result = await Runner.run(agent, input=query) + return result.final_output +``` + +The `StatelessMCPServerProvider` takes a factory function that creates new MCP server instances. The server name used in `stateless_mcp_server()` must match the name configured in the MCP server instance. In this example, the name is `FileSystemServer`. + +### Stateful MCP Servers + +For implementation details and examples, see the [samples repository](https://github.com/temporalio/samples-python/tree/main/openai_agents/mcp). + +When using stateful servers, the dedicated worker maintaining the connection may fail due to network issues or server problems. When this happens, Temporal raises an `ApplicationError` and cannot automatically recover because it cannot restore the lost server state. +To recover from such failures, you need to implement your own application-level retry logic. + +### Hosted MCP Tool + +For network-accessible MCP servers, you can also use `HostedMCPTool` from the OpenAI Agents SDK, which uses an MCP client hosted by OpenAI. + ## Feature Support This integration is presently subject to certain limitations. @@ -403,14 +508,17 @@ As described in [Tool Calling](#tool-calling), context propagation is read-only ### MCP -Presently, MCP is supported only via `HostedMCPTool`, which uses the OpenAI Responses API and cloud MCP client behind it. -The OpenAI Agents SDK also supports MCP clients that run in application code, but this integration does not. +The MCP protocol is stateful, but many MCP servers are stateless. +We let you choose between two MCP wrappers, one designed for stateless MCP servers and one for stateful MCP servers. +These wrappers work with all transport varieties. 
+ +Note that when using network-accessible MCP servers, you can also use the tool `HostedMCPTool`, which is part of the OpenAI Responses API and uses an MCP client hosted by OpenAI. | MCP Class | Supported | |:-----------------------|:---------:| -| MCPServerStdio | No | -| MCPServerSse | No | -| MCPServerStreamableHttp| No | +| MCPServerStdio | Yes | +| MCPServerSse | Yes | +| MCPServerStreamableHttp| Yes | ### Guardrails diff --git a/temporalio/contrib/openai_agents/__init__.py b/temporalio/contrib/openai_agents/__init__.py index 2d9777aa8..4074d1ebd 100644 --- a/temporalio/contrib/openai_agents/__init__.py +++ b/temporalio/contrib/openai_agents/__init__.py @@ -8,6 +8,15 @@ Use with caution in production environments. """ +# Best Effort mcp, as it is not supported on Python 3.9 +try: + from temporalio.contrib.openai_agents._mcp import ( + StatefulMCPServerProvider, + StatelessMCPServerProvider, + ) +except ImportError: + pass + from temporalio.contrib.openai_agents._model_parameters import ModelActivityParameters from temporalio.contrib.openai_agents._temporal_openai_agents import ( OpenAIAgentsPlugin, @@ -27,6 +36,8 @@ "ModelActivityParameters", "OpenAIAgentsPlugin", "OpenAIPayloadConverter", + "StatelessMCPServerProvider", + "StatefulMCPServerProvider", "TestModel", "TestModelProvider", "workflow", diff --git a/temporalio/contrib/openai_agents/_mcp.py b/temporalio/contrib/openai_agents/_mcp.py new file mode 100644 index 000000000..e41cc2d3a --- /dev/null +++ b/temporalio/contrib/openai_agents/_mcp.py @@ -0,0 +1,414 @@ +import abc +import asyncio +import functools +import logging +from contextlib import AbstractAsyncContextManager +from datetime import timedelta +from typing import Any, Callable, Optional, Sequence, Union + +from agents import AgentBase, RunContextWrapper +from agents.mcp import MCPServer +from mcp import GetPromptResult, ListPromptsResult # type:ignore +from mcp import Tool as MCPTool # type:ignore +from mcp.types import CallToolResult # 
type:ignore + +from temporalio import activity, workflow +from temporalio.api.enums.v1.workflow_pb2 import ( + TIMEOUT_TYPE_HEARTBEAT, + TIMEOUT_TYPE_SCHEDULE_TO_START, +) +from temporalio.exceptions import ( + ActivityError, + ApplicationError, + CancelledError, + is_cancelled_exception, +) +from temporalio.worker import PollerBehaviorSimpleMaximum, Worker +from temporalio.workflow import ActivityConfig, ActivityHandle + +logger = logging.getLogger(__name__) + + +class _StatelessMCPServerReference(MCPServer): + def __init__( + self, + server: str, + config: Optional[ActivityConfig], + cache_tools_list: bool, + ): + self._name = server + "-stateless" + self._config = config or ActivityConfig( + start_to_close_timeout=timedelta(minutes=1) + ) + self._cache_tools_list = cache_tools_list + self._tools = None + super().__init__() + + @property + def name(self) -> str: + return self._name + + async def connect(self) -> None: + pass + + async def cleanup(self) -> None: + pass + + async def list_tools( + self, + run_context: Optional[RunContextWrapper[Any]] = None, + agent: Optional[AgentBase] = None, + ) -> list[MCPTool]: + if self._tools: + return self._tools + tools = await workflow.execute_activity( + self.name + "-list-tools", + args=[], + result_type=list[MCPTool], + **self._config, + ) + if self._cache_tools_list: + self._tools = tools + return tools + + async def call_tool( + self, tool_name: str, arguments: Optional[dict[str, Any]] + ) -> CallToolResult: + return await workflow.execute_activity( + self.name + "-call-tool", + args=[tool_name, arguments], + result_type=CallToolResult, + **self._config, + ) + + async def list_prompts(self) -> ListPromptsResult: + return await workflow.execute_activity( + self.name + "-list-prompts", + args=[], + result_type=ListPromptsResult, + **self._config, + ) + + async def get_prompt( + self, name: str, arguments: Optional[dict[str, Any]] = None + ) -> GetPromptResult: + return await workflow.execute_activity( + self.name + 
"-get-prompt", + args=[name, arguments], + result_type=GetPromptResult, + **self._config, + ) + + +class StatelessMCPServerProvider: + """A stateless MCP server implementation for Temporal workflows. + + This class wraps a function to create MCP servers to make them stateless by executing each MCP operation + as a separate Temporal activity. Each operation (list_tools, call_tool, etc.) will + connect to the underlying server, execute the operation, and then clean up the connection. + + This approach will not maintain state across calls. If the desired MCPServer needs persistent state in order to + function, this cannot be used. + """ + + def __init__(self, server_factory: Callable[[], MCPServer]): + """Initialize the stateless temporal MCP server. + + Args: + server_factory: A function which will produce MCPServer instances. It should return a new server each time + so that state is not shared between workflow runs + """ + self._server_factory = server_factory + self._name = server_factory().name + "-stateless" + super().__init__() + + @property + def name(self) -> str: + """Get the server name.""" + return self._name + + def _get_activities(self) -> Sequence[Callable]: + @activity.defn(name=self.name + "-list-tools") + async def list_tools() -> list[MCPTool]: + server = self._server_factory() + try: + await server.connect() + return await server.list_tools() + finally: + await server.cleanup() + + @activity.defn(name=self.name + "-call-tool") + async def call_tool( + tool_name: str, arguments: Optional[dict[str, Any]] + ) -> CallToolResult: + server = self._server_factory() + try: + await server.connect() + return await server.call_tool(tool_name, arguments) + finally: + await server.cleanup() + + @activity.defn(name=self.name + "-list-prompts") + async def list_prompts() -> ListPromptsResult: + server = self._server_factory() + try: + await server.connect() + return await server.list_prompts() + finally: + await server.cleanup() + + @activity.defn(name=self.name 
+ "-get-prompt") + async def get_prompt( + name: str, arguments: Optional[dict[str, Any]] + ) -> GetPromptResult: + server = self._server_factory() + try: + await server.connect() + return await server.get_prompt(name, arguments) + finally: + await server.cleanup() + + return list_tools, call_tool, list_prompts, get_prompt + + +def _handle_worker_failure(func): + @functools.wraps(func) + async def wrapper(*args, **kwargs): + try: + return await func(*args, **kwargs) + except ActivityError as e: + failure = e.failure + if failure: + cause = failure.cause + if cause: + if ( + cause.timeout_failure_info.timeout_type + == TIMEOUT_TYPE_SCHEDULE_TO_START + ): + raise ApplicationError( + "MCP Stateful Server Worker failed to schedule activity.", + type="DedicatedWorkerFailure", + ) from e + if ( + cause.timeout_failure_info.timeout_type + == TIMEOUT_TYPE_HEARTBEAT + ): + raise ApplicationError( + "MCP Stateful Server Worker failed to heartbeat.", + type="DedicatedWorkerFailure", + ) from e + raise e + + return wrapper + + +class _StatefulMCPServerReference(MCPServer, AbstractAsyncContextManager): + def __init__( + self, + server: str, + config: Optional[ActivityConfig], + server_session_config: Optional[ActivityConfig], + ): + self._name = server + "-stateful" + self._config = config or ActivityConfig( + start_to_close_timeout=timedelta(minutes=1), + schedule_to_start_timeout=timedelta(seconds=30), + ) + self._server_session_config = server_session_config or ActivityConfig( + start_to_close_timeout=timedelta(hours=1), + ) + self._connect_handle: Optional[ActivityHandle] = None + super().__init__() + + @property + def name(self) -> str: + return self._name + + async def connect(self) -> None: + self._config["task_queue"] = self.name + "@" + workflow.info().run_id + self._connect_handle = workflow.start_activity( + self.name + "-server-session", + args=[], + **self._server_session_config, + ) + + async def cleanup(self) -> None: + if self._connect_handle: + 
self._connect_handle.cancel() + try: + await self._connect_handle + except Exception as e: + if is_cancelled_exception(e): + pass + else: + raise + + async def __aenter__(self): + await self.connect() + return self + + async def __aexit__(self, exc_type, exc_value, traceback): + await self.cleanup() + + @_handle_worker_failure + async def list_tools( + self, + run_context: Optional[RunContextWrapper[Any]] = None, + agent: Optional[AgentBase] = None, + ) -> list[MCPTool]: + if not self._connect_handle: + raise ApplicationError( + "Stateful MCP Server not connected. Call connect first." + ) + return await workflow.execute_activity( + self.name + "-list-tools", + args=[], + result_type=list[MCPTool], + **self._config, + ) + + @_handle_worker_failure + async def call_tool( + self, tool_name: str, arguments: Optional[dict[str, Any]] + ) -> CallToolResult: + if not self._connect_handle: + raise ApplicationError( + "Stateful MCP Server not connected. Call connect first." + ) + return await workflow.execute_activity( + self.name + "-call-tool", + args=[tool_name, arguments], + result_type=CallToolResult, + **self._config, + ) + + @_handle_worker_failure + async def list_prompts(self) -> ListPromptsResult: + if not self._connect_handle: + raise ApplicationError( + "Stateful MCP Server not connected. Call connect first." + ) + return await workflow.execute_activity( + self.name + "-list-prompts", + args=[], + result_type=ListPromptsResult, + **self._config, + ) + + @_handle_worker_failure + async def get_prompt( + self, name: str, arguments: Optional[dict[str, Any]] = None + ) -> GetPromptResult: + if not self._connect_handle: + raise ApplicationError( + "Stateful MCP Server not connected. Call connect first." + ) + return await workflow.execute_activity( + self.name + "-get-prompt", + args=[name, arguments], + result_type=GetPromptResult, + **self._config, + ) + + +class StatefulMCPServerProvider: + """A stateful MCP server implementation for Temporal workflows. 
+ + This class wraps a function to create MCP servers to maintain a persistent connection throughout + the workflow execution. It creates a dedicated worker that stays connected to + the MCP server and processes operations on a dedicated task queue. + + This approach will allow the MCPServer to maintain state across calls if needed, but the caller + will have to handle cases where the dedicated worker fails, as Temporal is unable to seamlessly + recreate any lost state in that case. It is discouraged to use this approach unless necessary. + + Handling dedicated worker failure will entail catching ApplicationError with type "DedicatedWorkerFailure". + Depending on the usage pattern, the caller will then have to either restart from the point at which the Stateful + server was needed or handle continuing from that loss of state in some other way. + """ + + def __init__( + self, + server_factory: Callable[[], MCPServer], + ): + """Initialize the stateful temporal MCP server. + + Args: + server_factory: A function which will produce MCPServer instances. 
It should return a new server each time + so that state is not shared between workflow runs + """ + self._server_factory = server_factory + self._name = server_factory().name + "-stateful" + self._connect_handle: Optional[ActivityHandle] = None + self._servers: dict[str, MCPServer] = {} + super().__init__() + + @property + def name(self) -> str: + """Get the server name.""" + return self._name + + def _get_activities(self) -> Sequence[Callable]: + def _server_id(): + return self.name + "@" + activity.info().workflow_run_id + + @activity.defn(name=self.name + "-list-tools") + async def list_tools() -> list[MCPTool]: + return await self._servers[_server_id()].list_tools() + + @activity.defn(name=self.name + "-call-tool") + async def call_tool( + tool_name: str, arguments: Optional[dict[str, Any]] + ) -> CallToolResult: + return await self._servers[_server_id()].call_tool(tool_name, arguments) + + @activity.defn(name=self.name + "-list-prompts") + async def list_prompts() -> ListPromptsResult: + return await self._servers[_server_id()].list_prompts() + + @activity.defn(name=self.name + "-get-prompt") + async def get_prompt( + name: str, arguments: Optional[dict[str, Any]] + ) -> GetPromptResult: + return await self._servers[_server_id()].get_prompt(name, arguments) + + async def heartbeat_every(delay: float, *details: Any) -> None: + """Heartbeat every so often while not cancelled""" + while True: + await asyncio.sleep(delay) + activity.heartbeat(*details) + + @activity.defn(name=self.name + "-server-session") + async def connect() -> None: + heartbeat_task = asyncio.create_task(heartbeat_every(30)) + + server_id = self.name + "@" + activity.info().workflow_run_id + if server_id in self._servers: + raise ApplicationError( + "Cannot connect to an already running server. Use a distinct name if running multiple servers in one workflow." 
+ ) + server = self._server_factory() + try: + self._servers[server_id] = server + try: + await server.connect() + + worker = Worker( + activity.client(), + task_queue=server_id, + activities=[list_tools, call_tool, list_prompts, get_prompt], + activity_task_poller_behavior=PollerBehaviorSimpleMaximum(1), + ) + + await worker.run() + finally: + await server.cleanup() + heartbeat_task.cancel() + try: + await heartbeat_task + except asyncio.CancelledError: + pass + finally: + del self._servers[server_id] + + return (connect,) diff --git a/temporalio/contrib/openai_agents/_openai_runner.py b/temporalio/contrib/openai_agents/_openai_runner.py index f7030cdd4..de97987d3 100644 --- a/temporalio/contrib/openai_agents/_openai_runner.py +++ b/temporalio/contrib/openai_agents/_openai_runner.py @@ -59,10 +59,23 @@ async def run( ) if starting_agent.mcp_servers: - raise ValueError( - "Temporal OpenAI agent does not support on demand MCP servers." + from temporalio.contrib.openai_agents._mcp import ( + _StatefulMCPServerReference, + _StatelessMCPServerReference, ) + for s in starting_agent.mcp_servers: + if not isinstance( + s, + ( + _StatelessMCPServerReference, + _StatefulMCPServerReference, + ), + ): + raise ValueError( + f"Unknown mcp_server type {type(s)} may not work durably." 
+ ) + context = kwargs.get("context") max_turns = kwargs.get("max_turns", DEFAULT_MAX_TURNS) hooks = kwargs.get("hooks") diff --git a/temporalio/contrib/openai_agents/_temporal_openai_agents.py b/temporalio/contrib/openai_agents/_temporal_openai_agents.py index ea2e1ebc0..49b186d98 100644 --- a/temporalio/contrib/openai_agents/_temporal_openai_agents.py +++ b/temporalio/contrib/openai_agents/_temporal_openai_agents.py @@ -1,9 +1,10 @@ """Initialize Temporal OpenAI Agents overrides.""" import dataclasses +import typing from contextlib import asynccontextmanager, contextmanager from datetime import timedelta -from typing import AsyncIterator, Callable, Optional, Union +from typing import AsyncIterator, Callable, Optional, Sequence, Union from agents import ( AgentOutputSchemaBase, @@ -53,6 +54,19 @@ WorkerConfig, WorkflowReplayResult, ) +from temporalio.worker.workflow_sandbox import SandboxedWorkflowRunner + +# Unsupported on python 3.9 +try: + from agents.mcp import MCPServer +except ImportError: + pass + +if typing.TYPE_CHECKING: + from temporalio.contrib.openai_agents import ( + StatefulMCPServerProvider, + StatelessMCPServerProvider, + ) @contextmanager @@ -101,6 +115,8 @@ def set_open_ai_agent_temporal_overrides( class TestModelProvider(ModelProvider): """Test model provider which simply returns the given module.""" + __test__ = False + def __init__(self, model: Model): """Initialize a test model provider with a model.""" self._model = model @@ -113,6 +129,8 @@ def get_model(self, model_name: Union[str, None]) -> Model: class TestModel(Model): """Test model for use mocking model responses.""" + __test__ = False + def __init__(self, fn: Callable[[], ModelResponse]) -> None: """Initialize a test model with a callable.""" self.fn = fn @@ -170,18 +188,24 @@ class OpenAIAgentsPlugin(temporalio.client.Plugin, temporalio.worker.Plugin): 1. Configures the Pydantic data converter for type-safe serialization 2. 
Sets up tracing interceptors for OpenAI agent interactions 3. Registers model execution activities - 4. Manages the OpenAI agent runtime overrides during worker execution + 4. Automatically registers MCP server activities and manages their lifecycles + 5. Manages the OpenAI agent runtime overrides during worker execution Args: model_params: Configuration parameters for Temporal activity execution of model calls. If None, default parameters will be used. model_provider: Optional model provider for custom model implementations. Useful for testing or custom model integrations. + mcp_servers: Sequence of MCP servers to automatically register with the worker. + The plugin will wrap each server in a TemporalMCPServer if needed and + manage their connection lifecycles tied to the worker lifetime. This is + the recommended way to use MCP servers with Temporal workflows. Example: >>> from temporalio.client import Client >>> from temporalio.worker import Worker - >>> from temporalio.contrib.openai_agents import OpenAIAgentsPlugin, ModelActivityParameters + >>> from temporalio.contrib.openai_agents import OpenAIAgentsPlugin, ModelActivityParameters, StatelessMCPServerProvider + >>> from agents.mcp import MCPServerStdio >>> from datetime import timedelta >>> >>> # Configure model parameters @@ -190,8 +214,17 @@ class OpenAIAgentsPlugin(temporalio.client.Plugin, temporalio.worker.Plugin): ... retry_policy=RetryPolicy(maximum_attempts=3) ... ) >>> - >>> # Create plugin - >>> plugin = OpenAIAgentsPlugin(model_params=model_params) + >>> # Create MCP servers + >>> filesystem_server = StatelessMCPServerProvider(MCPServerStdio( + ... name="Filesystem Server", + ... params={"command": "npx", "args": ["-y", "@modelcontextprotocol/server-filesystem", "."]} + ... )) + >>> + >>> # Create plugin with MCP servers + >>> plugin = OpenAIAgentsPlugin( + ... model_params=model_params, + ... mcp_servers=[filesystem_server] + ... 
) >>> >>> # Use with client and worker >>> client = await Client.connect( @@ -209,6 +242,9 @@ def __init__( self, model_params: Optional[ModelActivityParameters] = None, model_provider: Optional[ModelProvider] = None, + mcp_servers: Sequence[ + Union["StatelessMCPServerProvider", "StatefulMCPServerProvider"] + ] = (), ) -> None: """Initialize the OpenAI agents plugin. @@ -217,6 +253,10 @@ def __init__( of model calls. If None, default parameters will be used. model_provider: Optional model provider for custom model implementations. Useful for testing or custom model integrations. + mcp_servers: Sequence of MCP servers to automatically register with the worker. + Each server will be wrapped in a TemporalMCPServer if not already wrapped, + and their activities will be automatically registered with the worker. + The plugin manages the connection lifecycle of these servers. """ if model_params is None: model_params = ModelActivityParameters() @@ -236,6 +276,7 @@ def __init__( self._model_params = model_params self._model_provider = model_provider + self._mcp_servers = mcp_servers def init_client_plugin(self, next: temporalio.client.Plugin) -> None: """Set the next client plugin""" @@ -297,9 +338,25 @@ def configure_worker(self, config: WorkerConfig) -> WorkerConfig: config["interceptors"] = list(config.get("interceptors") or []) + [ OpenAIAgentsTracingInterceptor() ] - config["activities"] = list(config.get("activities") or []) + [ - ModelActivity(self._model_provider).invoke_model_activity - ] + new_activities = [ModelActivity(self._model_provider).invoke_model_activity] + + server_names = [server.name for server in self._mcp_servers] + if len(server_names) != len(set(server_names)): + raise ValueError( + f"More than one mcp server registered with the same name. Please provide unique names." 
+ ) + + for mcp_server in self._mcp_servers: + new_activities.extend(mcp_server._get_activities()) + config["activities"] = list(config.get("activities") or []) + new_activities + + runner = config.get("workflow_runner") + if isinstance(runner, SandboxedWorkflowRunner): + config["workflow_runner"] = dataclasses.replace( + runner, + restrictions=runner.restrictions.with_passthrough_modules("mcp"), + ) + config["workflow_failure_exception_types"] = list( config.get("workflow_failure_exception_types") or [] ) + [AgentsWorkflowError] diff --git a/temporalio/contrib/openai_agents/workflow.py b/temporalio/contrib/openai_agents/workflow.py index 2f69866ce..63ec43154 100644 --- a/temporalio/contrib/openai_agents/workflow.py +++ b/temporalio/contrib/openai_agents/workflow.py @@ -3,31 +3,33 @@ import functools import inspect import json +import typing +from contextlib import AbstractAsyncContextManager from datetime import timedelta -from typing import Any, Callable, Optional, Type, Union, overload +from typing import Any, Callable, Optional, Type import nexusrpc from agents import ( - Agent, RunContextWrapper, Tool, ) -from agents.function_schema import DocstringStyle, function_schema +from agents.function_schema import function_schema from agents.tool import ( FunctionTool, - ToolErrorFunction, - ToolFunction, - ToolParams, - default_tool_error_function, - function_tool, ) -from agents.util._types import MaybeAwaitable from temporalio import activity from temporalio import workflow as temporal_workflow from temporalio.common import Priority, RetryPolicy from temporalio.exceptions import ApplicationError, TemporalError -from temporalio.workflow import ActivityCancellationType, VersioningIntent +from temporalio.workflow import ( + ActivityCancellationType, + ActivityConfig, + VersioningIntent, +) + +if typing.TYPE_CHECKING: + from agents.mcp import MCPServer def activity_as_tool( @@ -239,6 +241,73 @@ async def run_operation(ctx: RunContextWrapper[Any], input: str) -> Any: ) 
+def stateless_mcp_server( + name: str, + config: Optional[ActivityConfig] = None, + cache_tools_list: bool = False, +) -> "MCPServer": + """A stateless MCP server implementation for Temporal workflows. + + .. warning:: + This API is experimental and may change in future versions. + Use with caution in production environments. + + This uses a TemporalMCPServer of the same name registered with the OpenAIAgents plugin to implement + durable MCP operations statelessly. + + This approach is suitable for simple use cases where connection overhead is acceptable + and you don't need to maintain state between operations. It should be preferred to stateful when possible due to its + superior durability guarantees. + + Args: + name: A string name for the server. Should match that provided in the plugin. + config: Optional activity configuration for MCP operation activities. + Defaults to 1-minute start-to-close timeout. + cache_tools_list: If true, the list of tools will be cached for the duration of the server + """ + from temporalio.contrib.openai_agents._mcp import ( + _StatelessMCPServerReference, + ) + + return _StatelessMCPServerReference(name, config, cache_tools_list) + + +def stateful_mcp_server( + name: str, + config: Optional[ActivityConfig] = None, + server_session_config: Optional[ActivityConfig] = None, +) -> AbstractAsyncContextManager["MCPServer"]: + """A stateful MCP server implementation for Temporal workflows. + + .. warning:: + This API is experimental and may change in future versions. + Use with caution in production environments. + + This wraps an MCP server to maintain a persistent connection throughout + the workflow execution. It creates a dedicated worker that stays connected to + the MCP server and processes operations on a dedicated task queue. + + This approach is more efficient for workflows that make multiple MCP calls, + as it avoids connection overhead, but requires more resources to maintain + the persistent connection and worker. 
+ + The caller will have to handle cases where the dedicated worker fails, as Temporal is + unable to seamlessly recreate any lost state in that case. + + Args: + name: A string name for the server. Should match that provided in the plugin. + config: Optional activity configuration for MCP operation activities. + Defaults to 1-minute start-to-close and 30-second schedule-to-start timeouts. + server_session_config: Optional activity configuration for the connection activity. + Defaults to 1-hour start-to-close timeout. + """ + from temporalio.contrib.openai_agents._mcp import ( + _StatefulMCPServerReference, + ) + + return _StatefulMCPServerReference(name, config, server_session_config) + + class ToolSerializationError(TemporalError): """Error that occurs when a tool output could not be serialized. diff --git a/tests/contrib/openai_agents/test_openai.py b/tests/contrib/openai_agents/test_openai.py index ca45a54b8..bcaeccfd9 100644 --- a/tests/contrib/openai_agents/test_openai.py +++ b/tests/contrib/openai_agents/test_openai.py @@ -5,13 +5,23 @@ import uuid from dataclasses import dataclass from datetime import timedelta -from typing import Any, AsyncIterator, Optional, Union, no_type_check +from typing import ( + Any, + AsyncIterator, + Callable, + Optional, + Sequence, + Union, + cast, + no_type_check, +) import nexusrpc import pydantic import pytest from agents import ( Agent, + AgentBase, AgentOutputSchemaBase, CodeInterpreterTool, FileSearchTool, @@ -21,7 +31,6 @@ ImageGenerationTool, InputGuardrailTripwireTriggered, ItemHelpers, - LocalShellTool, MCPToolApprovalFunctionResult, MCPToolApprovalRequest, MessageOutputItem, @@ -77,10 +86,9 @@ from openai.types.responses.response_prompt_param import ResponsePromptParam from pydantic import ConfigDict, Field, TypeAdapter -import temporalio.api.cloud.namespace.v1 from temporalio import activity, workflow from temporalio.client import Client, WorkflowFailureError, WorkflowHandle -from temporalio.common import 
RetryPolicy, SearchAttributeValueType +from temporalio.common import RetryPolicy from temporalio.contrib import openai_agents from temporalio.contrib.openai_agents import ( ModelActivityParameters, @@ -92,6 +100,7 @@ from temporalio.contrib.pydantic import pydantic_data_converter from temporalio.exceptions import ApplicationError, CancelledError from temporalio.testing import WorkflowEnvironment +from temporalio.workflow import ActivityConfig from tests.contrib.openai_agents.research_agents.research_manager import ( ResearchManager, ) @@ -2275,3 +2284,290 @@ async def test_output_type(client: Client): result = await workflow_handle.result() assert isinstance(result, OutputType) assert result.answer == "My answer" + + +@workflow.defn +class McpServerWorkflow: + @workflow.run + async def run(self, caching: bool) -> str: + from agents.mcp import MCPServer + + server: MCPServer = openai_agents.workflow.stateless_mcp_server( + "HelloServer", cache_tools_list=caching + ) + agent = Agent[str]( + name="MCP ServerWorkflow", + instructions="Use the tools to assist the customer.", + mcp_servers=[server], + ) + result = await Runner.run( + starting_agent=agent, input="Say hello to Tom and Tim." + ) + return result.final_output + + +@workflow.defn +class McpServerStatefulWorkflow: + @workflow.run + async def run(self, timeout: timedelta) -> str: + async with openai_agents.workflow.stateful_mcp_server( + "HelloServer", + config=ActivityConfig( + schedule_to_start_timeout=timeout, + start_to_close_timeout=timedelta(seconds=30), + ), + ) as server: + agent = Agent[str]( + name="MCP ServerWorkflow", + instructions="Use the tools to assist the customer.", + mcp_servers=[server], + ) + result = await Runner.run( + starting_agent=agent, input="Say hello to Tom and Tim." 
+ ) + return result.final_output + + +class TrackingMCPModel(StaticTestModel): + responses = [ + ResponseBuilders.tool_call( + arguments='{"name":"Tom"}', + name="Say-Hello", + ), + ResponseBuilders.tool_call( + arguments='{"name":"Tim"}', + name="Say-Hello", + ), + ResponseBuilders.output_message("Hi Tom and Tim!"), + ] + + +@pytest.mark.parametrize("use_local_model", [True, False]) +@pytest.mark.parametrize("stateful", [True, False]) +@pytest.mark.parametrize("caching", [True, False]) +async def test_mcp_server( + client: Client, use_local_model: bool, stateful: bool, caching: bool +): + if not use_local_model and not os.environ.get("OPENAI_API_KEY"): + pytest.skip("No openai API key") + + if sys.version_info < (3, 10): + pytest.skip("Mcp not supported on Python 3.9") + + if stateful and caching: + pytest.skip("Caching is only supported for stateless MCP servers") + + from agents.mcp import MCPServer + from mcp import GetPromptResult, ListPromptsResult # type: ignore + from mcp import Tool as MCPTool # type: ignore + from mcp.types import CallToolResult, TextContent # type: ignore + + from temporalio.contrib.openai_agents import ( + StatefulMCPServerProvider, + StatelessMCPServerProvider, + ) + + class TrackingMCPServer(MCPServer): + calls: list[str] + + def __init__(self, name: str): + self._name = name + self.calls = [] + super().__init__() + + async def connect(self): + self.calls.append("connect") + + @property + def name(self) -> str: + return self._name + + async def cleanup(self): + self.calls.append("cleanup") + + async def list_tools( + self, + run_context: Optional[RunContextWrapper[Any]] = None, + agent: Optional[AgentBase] = None, + ) -> list[MCPTool]: + self.calls.append("list_tools") + return [ + MCPTool( + name="Say-Hello", + inputSchema={ + "type": "object", + "properties": { + "name": {"type": "string"}, + }, + "required": ["name"], + "$schema": "http://json-schema.org/draft-07/schema#", + }, + ) + ] + + async def call_tool( + self, tool_name: 
str, arguments: Optional[dict[str, Any]] + ) -> CallToolResult: + self.calls.append("call_tool") + name = (arguments or {}).get("name") or "John Doe" + return CallToolResult( + content=[TextContent(type="text", text=f"Hello {name}")] + ) + + async def list_prompts(self) -> ListPromptsResult: + raise NotImplementedError() + + async def get_prompt( + self, name: str, arguments: Optional[dict[str, Any]] = None + ) -> GetPromptResult: + raise NotImplementedError() + + tracking_server = TrackingMCPServer(name="HelloServer") + server: Union[StatefulMCPServerProvider, StatelessMCPServerProvider] = ( + StatefulMCPServerProvider(lambda: tracking_server) + if stateful + else StatelessMCPServerProvider(lambda: tracking_server) + ) + + new_config = client.config() + new_config["plugins"] = [ + openai_agents.OpenAIAgentsPlugin( + model_params=ModelActivityParameters( + start_to_close_timeout=timedelta(seconds=120) + ), + model_provider=TestModelProvider(TrackingMCPModel()) + if use_local_model + else None, + mcp_servers=[server], + ) + ] + client = Client(**new_config) + + async with new_worker( + client, McpServerStatefulWorkflow, McpServerWorkflow + ) as worker: + if stateful: + result = await client.execute_workflow( + McpServerStatefulWorkflow.run, + args=[timedelta(seconds=30)], + id=f"mcp-server-{uuid.uuid4()}", + task_queue=worker.task_queue, + execution_timeout=timedelta(seconds=30), + ) + else: + result = await client.execute_workflow( + McpServerWorkflow.run, + args=[caching], + id=f"mcp-server-{uuid.uuid4()}", + task_queue=worker.task_queue, + execution_timeout=timedelta(seconds=30), + ) + if use_local_model: + assert result == "Hi Tom and Tim!" 
+ if use_local_model: + print(tracking_server.calls) + if stateful: + assert tracking_server.calls == [ + "connect", + "list_tools", + "call_tool", + "list_tools", + "call_tool", + "list_tools", + "cleanup", + ] + assert len(cast(StatefulMCPServerProvider, server)._servers) == 0 + else: + if caching: + assert tracking_server.calls == [ + "connect", + "list_tools", + "cleanup", + "connect", + "call_tool", + "cleanup", + "connect", + "call_tool", + "cleanup", + ] + else: + assert tracking_server.calls == [ + "connect", + "list_tools", + "cleanup", + "connect", + "call_tool", + "cleanup", + "connect", + "list_tools", + "cleanup", + "connect", + "call_tool", + "cleanup", + "connect", + "list_tools", + "cleanup", + ] + + +async def test_stateful_mcp_server_no_worker(client: Client): + if sys.version_info < (3, 10): + pytest.skip("Mcp not supported on Python 3.9") + from agents.mcp import MCPServerStdio + + from temporalio.contrib.openai_agents import StatefulMCPServerProvider + + server = StatefulMCPServerProvider( + lambda: MCPServerStdio( + name="Filesystem-Server", + params={ + "command": "npx", + "args": [ + "-y", + "@modelcontextprotocol/server-filesystem", + os.path.dirname(os.path.abspath(__file__)), + ], + }, + ) + ) + + # Override the connect activity to not actually start a worker + @activity.defn(name="Filesystem-Server-stateful-connect") + async def connect() -> None: + await asyncio.sleep(30) + + def override_get_activities() -> Sequence[Callable]: + return (connect,) + + server.get_activities = override_get_activities # type:ignore + + new_config = client.config() + new_config["plugins"] = [ + openai_agents.OpenAIAgentsPlugin( + model_params=ModelActivityParameters( + start_to_close_timeout=timedelta(seconds=120) + ), + model_provider=TestModelProvider(TrackingMCPModel()), + mcp_servers=[server], + ) + ] + client = Client(**new_config) + + async with new_worker( + client, + McpServerStatefulWorkflow, + ) as worker: + workflow_handle = await 
client.start_workflow( + McpServerStatefulWorkflow.run, + args=[timedelta(seconds=1)], + id=f"mcp-server-{uuid.uuid4()}", + task_queue=worker.task_queue, + execution_timeout=timedelta(seconds=30), + ) + with pytest.raises(WorkflowFailureError) as err: + await workflow_handle.result() + assert isinstance(err.value.cause, ApplicationError) + assert ( + err.value.cause.message + == "MCP Stateful Server Worker failed to schedule activity." + ) diff --git a/uv.lock b/uv.lock index f990b7372..e12ecda13 100644 --- a/uv.lock +++ b/uv.lock @@ -1,5 +1,5 @@ version = 1 -revision = 3 +revision = 2 requires-python = ">=3.9" resolution-markers = [ "python_full_version >= '3.10'", @@ -2809,6 +2809,7 @@ grpc = [ ] openai-agents = [ { name = "eval-type-backport", marker = "python_full_version < '3.10'" }, + { name = "mcp", marker = "python_full_version >= '3.10'" }, { name = "openai-agents" }, ] opentelemetry = [ @@ -2846,11 +2847,12 @@ dev = [ requires-dist = [ { name = "eval-type-backport", marker = "python_full_version < '3.10' and extra == 'openai-agents'", specifier = ">=0.2.2" }, { name = "grpcio", marker = "extra == 'grpc'", specifier = ">=1.48.2,<2" }, + { name = "mcp", marker = "python_full_version >= '3.10' and extra == 'openai-agents'", specifier = ">=1.9.4,<2" }, { name = "nexus-rpc", specifier = "==1.1.0" }, { name = "openai-agents", marker = "extra == 'openai-agents'", specifier = ">=0.3,<0.4" }, { name = "opentelemetry-api", marker = "extra == 'opentelemetry'", specifier = ">=1.11.1,<2" }, { name = "opentelemetry-sdk", marker = "extra == 'opentelemetry'", specifier = ">=1.11.1,<2" }, - { name = "protobuf", specifier = ">=6.31.1,<7.0.0" }, + { name = "protobuf", specifier = ">=3.20,<7.0.0" }, { name = "pydantic", marker = "extra == 'pydantic'", specifier = ">=2.0.0,<3" }, { name = "python-dateutil", marker = "python_full_version < '3.11'", specifier = ">=2.8.2,<3" }, { name = "types-protobuf", specifier = ">=3.20" },