Skip to content

Commit 58708c0

Browse files
committed
Remove caching from stateful. Underlying server can handle it
1 parent 2a999a2 commit 58708c0

File tree

3 files changed

+18
-38
lines changed

3 files changed

+18
-38
lines changed

temporalio/contrib/openai_agents/_mcp.py

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -208,7 +208,6 @@ def __init__(
208208
server: str,
209209
config: Optional[ActivityConfig],
210210
server_session_config: Optional[ActivityConfig],
211-
cache_tools_list: bool,
212211
):
213212
self._name = server + "-stateful"
214213
self._config = config or ActivityConfig(
@@ -219,8 +218,6 @@ def __init__(
219218
start_to_close_timeout=timedelta(hours=1),
220219
)
221220
self._connect_handle: Optional[ActivityHandle] = None
222-
self._cache_tools_list = cache_tools_list
223-
self._tools = None
224221
super().__init__()
225222

226223
@property
@@ -259,22 +256,16 @@ async def list_tools(
259256
run_context: Optional[RunContextWrapper[Any]] = None,
260257
agent: Optional[AgentBase] = None,
261258
) -> list[MCPTool]:
262-
if self._tools:
263-
return self._tools
264-
265259
if not self._connect_handle:
266260
raise ApplicationError(
267261
"Stateful MCP Server not connected. Call connect first."
268262
)
269-
tools = await workflow.execute_activity(
263+
return await workflow.execute_activity(
270264
self.name + "-list-tools",
271265
args=[],
272266
result_type=list[MCPTool],
273267
**self._config,
274268
)
275-
if self._cache_tools_list:
276-
self._tools = tools
277-
return tools
278269

279270
@_handle_worker_failure
280271
async def call_tool(

temporalio/contrib/openai_agents/workflow.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -276,7 +276,6 @@ def stateful_mcp_server(
276276
name: str,
277277
config: Optional[ActivityConfig] = None,
278278
server_session_config: Optional[ActivityConfig] = None,
279-
cache_tools_list: bool = False,
280279
) -> AbstractAsyncContextManager["MCPServer"]:
281280
"""A stateful MCP server implementation for Temporal workflows.
282281
@@ -301,15 +300,12 @@ def stateful_mcp_server(
301300
Defaults to 1-minute start-to-close and 30-second schedule-to-start timeouts.
302301
server_session_config: Optional activity configuration for the connection activity.
303302
Defaults to 1-hour start-to-close timeout.
304-
cache_tools_list: If true, the list of tools will be cached for the duration of the server
305303
"""
306304
from temporalio.contrib.openai_agents._mcp import (
307305
_StatefulMCPServerReference,
308306
)
309307

310-
return _StatefulMCPServerReference(
311-
name, config, server_session_config, cache_tools_list
312-
)
308+
return _StatefulMCPServerReference(name, config, server_session_config)
313309

314310

315311
class ToolSerializationError(TemporalError):

tests/contrib/openai_agents/test_openai.py

Lines changed: 16 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -2289,7 +2289,7 @@ async def test_output_type(client: Client):
22892289
@workflow.defn
22902290
class McpServerWorkflow:
22912291
@workflow.run
2292-
async def run(self, timeout: timedelta, caching: bool) -> str:
2292+
async def run(self, caching: bool) -> str:
22932293
from agents.mcp import MCPServer
22942294

22952295
server: MCPServer = openai_agents.workflow.stateless_mcp_server(
@@ -2309,14 +2309,13 @@ async def run(self, timeout: timedelta, caching: bool) -> str:
23092309
@workflow.defn
23102310
class McpServerStatefulWorkflow:
23112311
@workflow.run
2312-
async def run(self, timeout: timedelta, caching: bool) -> str:
2312+
async def run(self, timeout: timedelta) -> str:
23132313
async with openai_agents.workflow.stateful_mcp_server(
23142314
"HelloServer",
23152315
config=ActivityConfig(
23162316
schedule_to_start_timeout=timeout,
23172317
start_to_close_timeout=timedelta(seconds=30),
23182318
),
2319-
cache_tools_list=caching,
23202319
) as server:
23212320
agent = Agent[str](
23222321
name="MCP ServerWorkflow",
@@ -2355,6 +2354,9 @@ async def test_mcp_server(
23552354
if sys.version_info < (3, 10):
23562355
pytest.skip("Mcp not supported on Python 3.9")
23572356

2357+
if stateful and caching:
2358+
pytest.skip("Caching is only supported for stateless MCP servers")
2359+
23582360
from agents.mcp import MCPServer
23592361
from mcp import GetPromptResult, ListPromptsResult # type: ignore
23602362
from mcp import Tool as MCPTool # type: ignore
@@ -2447,15 +2449,15 @@ async def get_prompt(
24472449
if stateful:
24482450
result = await client.execute_workflow(
24492451
McpServerStatefulWorkflow.run,
2450-
args=[timedelta(seconds=30), caching],
2452+
args=[timedelta(seconds=30)],
24512453
id=f"mcp-server-{uuid.uuid4()}",
24522454
task_queue=worker.task_queue,
24532455
execution_timeout=timedelta(seconds=30),
24542456
)
24552457
else:
24562458
result = await client.execute_workflow(
24572459
McpServerWorkflow.run,
2458-
args=[timedelta(seconds=30), caching],
2460+
args=[caching],
24592461
id=f"mcp-server-{uuid.uuid4()}",
24602462
task_queue=worker.task_queue,
24612463
execution_timeout=timedelta(seconds=30),
@@ -2465,24 +2467,15 @@ async def get_prompt(
24652467
if use_local_model:
24662468
print(tracking_server.calls)
24672469
if stateful:
2468-
if caching:
2469-
assert tracking_server.calls == [
2470-
"connect",
2471-
"list_tools",
2472-
"call_tool",
2473-
"call_tool",
2474-
"cleanup",
2475-
]
2476-
else:
2477-
assert tracking_server.calls == [
2478-
"connect",
2479-
"list_tools",
2480-
"call_tool",
2481-
"list_tools",
2482-
"call_tool",
2483-
"list_tools",
2484-
"cleanup",
2485-
]
2470+
assert tracking_server.calls == [
2471+
"connect",
2472+
"list_tools",
2473+
"call_tool",
2474+
"list_tools",
2475+
"call_tool",
2476+
"list_tools",
2477+
"cleanup",
2478+
]
24862479
assert len(cast(StatefulMCPServerProvider, server)._servers) == 0
24872480
else:
24882481
if caching:

0 commit comments

Comments
 (0)