Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion src/google/adk/tools/mcp_tool/mcp_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import inspect
import logging
from typing import Any
from typing import Awaitable
from typing import Callable
from typing import Dict
from typing import List
Expand Down Expand Up @@ -142,7 +143,10 @@ def __init__(
auth_credential: Optional[AuthCredential] = None,
require_confirmation: Union[bool, Callable[..., bool]] = False,
header_provider: Optional[
Callable[[ReadonlyContext], Dict[str, str]]
Callable[
[ReadonlyContext],
Union[Dict[str, str], Awaitable[Dict[str, str]]],
]
] = None,
progress_callback: Optional[
Union[ProgressFnT, ProgressCallbackFactory]
Expand Down Expand Up @@ -379,6 +383,8 @@ async def _run_async_impl(
dynamic_headers = self._header_provider(
ReadonlyContext(tool_context._invocation_context)
)
if inspect.isawaitable(dynamic_headers):
dynamic_headers = await dynamic_headers

headers: Dict[str, str] = {}
if auth_headers:
Expand Down
8 changes: 7 additions & 1 deletion src/google/adk/tools/mcp_tool/mcp_toolset.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import asyncio
import base64
import inspect
import logging
import sys
from typing import Any
Expand Down Expand Up @@ -108,7 +109,10 @@ def __init__(
auth_credential: Optional[AuthCredential] = None,
require_confirmation: Union[bool, Callable[..., bool]] = False,
header_provider: Optional[
Callable[[ReadonlyContext], Dict[str, str]]
Callable[
[ReadonlyContext],
Union[Dict[str, str], Awaitable[Dict[str, str]]],
]
] = None,
progress_callback: Optional[
Union[ProgressFnT, ProgressCallbackFactory]
Expand Down Expand Up @@ -293,6 +297,8 @@ async def _execute_with_session(
# Add headers from header_provider if available
if self._header_provider and readonly_context:
provider_headers = self._header_provider(readonly_context)
if inspect.isawaitable(provider_headers):
provider_headers = await provider_headers
if provider_headers:
headers.update(provider_headers)

Expand Down
35 changes: 35 additions & 0 deletions tests/unittests/tools/mcp_tool/test_mcp_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -950,6 +950,41 @@ async def test_run_async_impl_with_header_provider_no_auth(self):
"test_tool", arguments=args, progress_callback=None, meta=None
)

@pytest.mark.asyncio
async def test_run_async_impl_with_async_header_provider_no_auth(self):
"""Test running tool with an async header_provider but no auth."""
expected_headers = {"X-Tenant-ID": "test-tenant"}

async def header_provider(_context):
return expected_headers

tool = MCPTool(
mcp_tool=self.mock_mcp_tool,
mcp_session_manager=self.mock_session_manager,
header_provider=header_provider,
)

mcp_response = CallToolResult(
content=[TextContent(type="text", text="success")]
)
self.mock_session.call_tool = AsyncMock(return_value=mcp_response)

tool_context = Mock(spec=ToolContext)
tool_context._invocation_context = Mock()
args = {"param1": "test_value"}

result = await tool._run_async_impl(
args=args, tool_context=tool_context, credential=None
)

assert result == mcp_response.model_dump(exclude_none=True, mode="json")
self.mock_session_manager.create_session.assert_called_once_with(
headers=expected_headers
)
self.mock_session.call_tool.assert_called_once_with(
"test_tool", arguments=args, progress_callback=None, meta=None
)

@pytest.mark.asyncio
async def test_run_async_impl_with_header_provider_and_oauth2(self):
"""Test running tool with header_provider and OAuth2 auth."""
Expand Down
26 changes: 26 additions & 0 deletions tests/unittests/tools/mcp_tool/test_mcp_toolset.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,32 @@ async def test_get_tools_with_header_provider(self):
headers=expected_headers
)

@pytest.mark.asyncio
async def test_get_tools_with_async_header_provider(self):
"""Test get_tools with an async header_provider."""
mock_tools = [MockMCPTool("tool1"), MockMCPTool("tool2")]
self.mock_session.list_tools = AsyncMock(
return_value=MockListToolsResult(mock_tools)
)
mock_readonly_context = Mock(spec=ReadonlyContext)
expected_headers = {"X-Tenant-ID": "test-tenant"}

async def header_provider(_context):
return expected_headers

toolset = McpToolset(
connection_params=self.mock_stdio_params,
header_provider=header_provider,
)
toolset._mcp_session_manager = self.mock_session_manager

tools = await toolset.get_tools(readonly_context=mock_readonly_context)

assert len(tools) == 2
self.mock_session_manager.create_session.assert_called_once_with(
headers=expected_headers
)

@pytest.mark.asyncio
async def test_close_success(self):
"""Test successful cleanup."""
Expand Down