Skip to content

Commit 2a999a2

Browse files
committed
Change stateless to a provider model, add caching option
1 parent 975b2e1 commit 2a999a2

File tree

5 files changed

+136
-69
lines changed

5 files changed

+136
-69
lines changed

temporalio/contrib/openai_agents/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
try:
1313
from temporalio.contrib.openai_agents._mcp import (
1414
StatefulMCPServerProvider,
15-
StatelessMCPServer,
15+
StatelessMCPServerProvider,
1616
)
1717
except ImportError:
1818
pass
@@ -36,7 +36,7 @@
3636
"ModelActivityParameters",
3737
"OpenAIAgentsPlugin",
3838
"OpenAIPayloadConverter",
39-
"StatelessMCPServer",
39+
"StatelessMCPServerProvider",
4040
"StatefulMCPServerProvider",
4141
"TestModel",
4242
"TestModelProvider",

temporalio/contrib/openai_agents/_mcp.py

Lines changed: 50 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -30,11 +30,18 @@
3030

3131

3232
class _StatelessMCPServerReference(MCPServer):
33-
def __init__(self, server: str, config: Optional[ActivityConfig] = None):
33+
def __init__(
34+
self,
35+
server: str,
36+
config: Optional[ActivityConfig],
37+
cache_tools_list: bool,
38+
):
3439
self._name = server + "-stateless"
3540
self._config = config or ActivityConfig(
3641
start_to_close_timeout=timedelta(minutes=1)
3742
)
43+
self._cache_tools_list = cache_tools_list
44+
self._tools = None
3845
super().__init__()
3946

4047
@property
@@ -52,12 +59,17 @@ async def list_tools(
5259
run_context: Optional[RunContextWrapper[Any]] = None,
5360
agent: Optional[AgentBase] = None,
5461
) -> list[MCPTool]:
55-
return await workflow.execute_activity(
62+
if self._tools:
63+
return self._tools
64+
tools = await workflow.execute_activity(
5665
self.name + "-list-tools",
5766
args=[],
5867
result_type=list[MCPTool],
5968
**self._config,
6069
)
70+
if self._cache_tools_list:
71+
self._tools = tools
72+
return tools
6173

6274
async def call_tool(
6375
self, tool_name: str, arguments: Optional[dict[str, Any]]
@@ -88,25 +100,26 @@ async def get_prompt(
88100
)
89101

90102

91-
class StatelessMCPServer:
103+
class StatelessMCPServerProvider:
92104
"""A stateless MCP server implementation for Temporal workflows.
93105
94-
This class wraps an MCP server to make it stateless by executing each MCP operation
106+
This class wraps a function to create MCP servers to make them stateless by executing each MCP operation
95107
as a separate Temporal activity. Each operation (list_tools, call_tool, etc.) will
96108
connect to the underlying server, execute the operation, and then clean up the connection.
97109
98110
This approach will not maintain state across calls. If the desired MCPServer needs persistent state in order to
99111
function, this cannot be used.
100112
"""
101113

102-
def __init__(self, server: MCPServer):
114+
def __init__(self, server_factory: Callable[[], MCPServer]):
103115
"""Initialize the stateless temporal MCP server.
104116
105117
Args:
106-
server: An MCPServer instance
118+
server_factory: A function which will produce MCPServer instances. It should return a new server each time
119+
so that state is not shared between workflow runs
107120
"""
108-
self._server = server
109-
self._name = server.name + "-stateless"
121+
self._server_factory = server_factory
122+
self._name = server_factory().name + "-stateless"
110123
super().__init__()
111124

112125
@property
@@ -117,39 +130,43 @@ def name(self) -> str:
117130
def _get_activities(self) -> Sequence[Callable]:
118131
@activity.defn(name=self.name + "-list-tools")
119132
async def list_tools() -> list[MCPTool]:
133+
server = self._server_factory()
120134
try:
121-
await self._server.connect()
122-
return await self._server.list_tools()
135+
await server.connect()
136+
return await server.list_tools()
123137
finally:
124-
await self._server.cleanup()
138+
await server.cleanup()
125139

126140
@activity.defn(name=self.name + "-call-tool")
127141
async def call_tool(
128142
tool_name: str, arguments: Optional[dict[str, Any]]
129143
) -> CallToolResult:
144+
server = self._server_factory()
130145
try:
131-
await self._server.connect()
132-
return await self._server.call_tool(tool_name, arguments)
146+
await server.connect()
147+
return await server.call_tool(tool_name, arguments)
133148
finally:
134-
await self._server.cleanup()
149+
await server.cleanup()
135150

136151
@activity.defn(name=self.name + "-list-prompts")
137152
async def list_prompts() -> ListPromptsResult:
153+
server = self._server_factory()
138154
try:
139-
await self._server.connect()
140-
return await self._server.list_prompts()
155+
await server.connect()
156+
return await server.list_prompts()
141157
finally:
142-
await self._server.cleanup()
158+
await server.cleanup()
143159

144160
@activity.defn(name=self.name + "-get-prompt")
145161
async def get_prompt(
146162
name: str, arguments: Optional[dict[str, Any]]
147163
) -> GetPromptResult:
164+
server = self._server_factory()
148165
try:
149-
await self._server.connect()
150-
return await self._server.get_prompt(name, arguments)
166+
await server.connect()
167+
return await server.get_prompt(name, arguments)
151168
finally:
152-
await self._server.cleanup()
169+
await server.cleanup()
153170

154171
return list_tools, call_tool, list_prompts, get_prompt
155172

@@ -189,8 +206,9 @@ class _StatefulMCPServerReference(MCPServer, AbstractAsyncContextManager):
189206
def __init__(
190207
self,
191208
server: str,
192-
config: Optional[ActivityConfig] = None,
193-
server_session_config: Optional[ActivityConfig] = None,
209+
config: Optional[ActivityConfig],
210+
server_session_config: Optional[ActivityConfig],
211+
cache_tools_list: bool,
194212
):
195213
self._name = server + "-stateful"
196214
self._config = config or ActivityConfig(
@@ -201,6 +219,8 @@ def __init__(
201219
start_to_close_timeout=timedelta(hours=1),
202220
)
203221
self._connect_handle: Optional[ActivityHandle] = None
222+
self._cache_tools_list = cache_tools_list
223+
self._tools = None
204224
super().__init__()
205225

206226
@property
@@ -239,16 +259,22 @@ async def list_tools(
239259
run_context: Optional[RunContextWrapper[Any]] = None,
240260
agent: Optional[AgentBase] = None,
241261
) -> list[MCPTool]:
262+
if self._tools:
263+
return self._tools
264+
242265
if not self._connect_handle:
243266
raise ApplicationError(
244267
"Stateful MCP Server not connected. Call connect first."
245268
)
246-
return await workflow.execute_activity(
269+
tools = await workflow.execute_activity(
247270
self.name + "-list-tools",
248271
args=[],
249272
result_type=list[MCPTool],
250273
**self._config,
251274
)
275+
if self._cache_tools_list:
276+
self._tools = tools
277+
return tools
252278

253279
@_handle_worker_failure
254280
async def call_tool(
@@ -361,7 +387,7 @@ async def heartbeat_every(delay: float, *details: Any) -> None:
361387
await asyncio.sleep(delay)
362388
activity.heartbeat(*details)
363389

364-
@activity.defn(name=self._name + "-server-session")
390+
@activity.defn(name=self.name + "-server-session")
365391
async def connect() -> None:
366392
heartbeat_task = asyncio.create_task(heartbeat_every(30))
367393

temporalio/contrib/openai_agents/_temporal_openai_agents.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@
6565
if typing.TYPE_CHECKING:
6666
from temporalio.contrib.openai_agents import (
6767
StatefulMCPServerProvider,
68-
StatelessMCPServer,
68+
StatelessMCPServerProvider,
6969
)
7070

7171

@@ -204,7 +204,7 @@ class OpenAIAgentsPlugin(temporalio.client.Plugin, temporalio.worker.Plugin):
204204
Example:
205205
>>> from temporalio.client import Client
206206
>>> from temporalio.worker import Worker
207-
>>> from temporalio.contrib.openai_agents import OpenAIAgentsPlugin, ModelActivityParameters, StatelessMCPServer
207+
>>> from temporalio.contrib.openai_agents import OpenAIAgentsPlugin, ModelActivityParameters, StatelessMCPServerProvider
208208
>>> from agents.mcp import MCPServerStdio
209209
>>> from datetime import timedelta
210210
>>>
@@ -215,7 +215,7 @@ class OpenAIAgentsPlugin(temporalio.client.Plugin, temporalio.worker.Plugin):
215215
... )
216216
>>>
217217
>>> # Create MCP servers
218-
>>> filesystem_server = StatelessMCPServer(MCPServerStdio(
218+
>>> filesystem_server = StatelessMCPServerProvider(MCPServerStdio(
219219
... name="Filesystem Server",
220220
... params={"command": "npx", "args": ["-y", "@modelcontextprotocol/server-filesystem", "."]}
221221
... ))
@@ -243,7 +243,7 @@ def __init__(
243243
model_params: Optional[ModelActivityParameters] = None,
244244
model_provider: Optional[ModelProvider] = None,
245245
mcp_servers: Sequence[
246-
Union["StatelessMCPServer", "StatefulMCPServerProvider"]
246+
Union["StatelessMCPServerProvider", "StatefulMCPServerProvider"]
247247
] = (),
248248
) -> None:
249249
"""Initialize the OpenAI agents plugin.

temporalio/contrib/openai_agents/workflow.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -242,7 +242,9 @@ async def run_operation(ctx: RunContextWrapper[Any], input: str) -> Any:
242242

243243

244244
def stateless_mcp_server(
245-
name: str, config: Optional[ActivityConfig] = None
245+
name: str,
246+
config: Optional[ActivityConfig] = None,
247+
cache_tools_list: bool = False,
246248
) -> "MCPServer":
247249
"""A stateless MCP server implementation for Temporal workflows.
248250
@@ -256,18 +258,25 @@ def stateless_mcp_server(
256258
This approach is suitable for simple use cases where connection overhead is acceptable
257259
and you don't need to maintain state between operations. It should be preferred to stateful when possible due to its
258260
superior durability guarantees.
261+
262+
Args:
263+
name: A string name for the server. Should match that provided in the plugin.
264+
config: Optional activity configuration for MCP operation activities.
265+
Defaults to 1-minute start-to-close timeout.
266+
cache_tools_list: If true, the list of tools will be cached for the duration of the server
259267
"""
260268
from temporalio.contrib.openai_agents._mcp import (
261269
_StatelessMCPServerReference,
262270
)
263271

264-
return _StatelessMCPServerReference(name, config)
272+
return _StatelessMCPServerReference(name, config, cache_tools_list)
265273

266274

267275
def stateful_mcp_server(
268276
name: str,
269277
config: Optional[ActivityConfig] = None,
270278
server_session_config: Optional[ActivityConfig] = None,
279+
cache_tools_list: bool = False,
271280
) -> AbstractAsyncContextManager["MCPServer"]:
272281
"""A stateful MCP server implementation for Temporal workflows.
273282
@@ -292,12 +301,15 @@ def stateful_mcp_server(
292301
Defaults to 1-minute start-to-close and 30-second schedule-to-start timeouts.
293302
server_session_config: Optional activity configuration for the connection activity.
294303
Defaults to 1-hour start-to-close timeout.
304+
cache_tools_list: If true, the list of tools will be cached for the duration of the server
295305
"""
296306
from temporalio.contrib.openai_agents._mcp import (
297307
_StatefulMCPServerReference,
298308
)
299309

300-
return _StatefulMCPServerReference(name, config, server_session_config)
310+
return _StatefulMCPServerReference(
311+
name, config, server_session_config, cache_tools_list
312+
)
301313

302314

303315
class ToolSerializationError(TemporalError):

0 commit comments

Comments
 (0)