diff --git a/docs/content/3.agent/2.a2a-agent.md b/docs/content/3.agent/2.a2a-agent.md index 0f48faa8..6cceb3d8 100644 --- a/docs/content/3.agent/2.a2a-agent.md +++ b/docs/content/3.agent/2.a2a-agent.md @@ -18,7 +18,7 @@ navigation: 我们将借助 Google ADK 的工具函数来便捷地创建一个 A2A Server: ```python [server_agent.py] -from google.adk.a2a.utils.agent_to_a2a import to_a2a +from veadk.a2a.utils.agent_to_a2a import to_a2a from veadk import Agent from veadk.tools.demo_tools import get_city_weather @@ -27,6 +27,10 @@ agent = Agent(name="weather_reporter", tools=[get_city_weather]) app = to_a2a(agent) ``` +::callout{icon="i-lucide-info"} +默认情况下,A2A Server 不启用认证功能。如果需要启用 VeADK 的认证和凭据管理功能,请参考下面的 [启用认证功能](#启用认证功能) 章节。 +:: + ### 本地启动 A2A Server ```bash [Terminal] @@ -51,8 +55,51 @@ print(response) # 北京天气晴朗,气温25°C。 :: +## 启用认证功能 + +VeADK 提供了内置的认证和凭据管理功能,可以在 A2A Server 和 Client 之间进行安全的身份验证和凭据传递。 + +### Server 侧启用认证 + +在创建 A2A Server 时,通过设置 `enable_auth=True` 来启用认证功能: + +```python [server_agent.py] +from veadk.a2a.utils.agent_to_a2a import to_a2a +from veadk import Agent +from veadk.tools.demo_tools import get_city_weather + +agent = Agent(name="weather_reporter", tools=[get_city_weather]) + +# 启用 VeADK 认证功能 +app = to_a2a(agent, enable_auth=True) +``` + +启用认证后,Server 会: +- 自动创建 `VeCredentialService` 来管理凭据 +- 添加认证中间件来验证请求中的 token +- 支持凭据在 Server 和 Client 之间的安全传递 + +### 认证方式 + +`to_a2a` 支持两种认证方式,通过 `auth_method` 参数指定: + +```python +# 方式 1: 从 HTTP Header 中提取 token (默认) +app = to_a2a(agent, enable_auth=True, auth_method="header") + +# 方式 2: 从 Query String 中提取 token +app = to_a2a(agent, enable_auth=True, auth_method="querystring") +``` + + +### Client 侧使用认证 + +当 Server 启用认证后,Client 侧的 `RemoteVeAgent` 会**自动处理认证** + ## 初始化选项 +### RemoteVeAgent 参数 + ::field-group ::field{name="name" type="string"} 智能体的名称 @@ -62,3 +109,41 @@ print(response) # 北京天气晴朗,气温25°C。 远程智能体的访问端点 :: :: + +### to_a2a 参数 + +::field-group + ::field{name="agent" type="BaseAgent" required} + 要转换为 A2A Server 的智能体实例 + :: + + ::field{name="host" type="string" default="localhost"} + A2A Server 的主机地址 + :: + + ::field{name="port" type="int" default="8000"} + A2A Server 的端口号 + :: + + ::field{name="protocol" type="string" default="http"} + A2A Server 的协议(http 或 https) + :: + + ::field{name="agent_card" type="AgentCard | string"} + 可选的智能体卡片对象或 JSON 文件路径。如果不提供,将自动从智能体生成 + :: + + ::field{name="runner" type="Runner"} + 可选的 Runner 对象。如果不提供,将自动创建默认 Runner + :: + + ::field{name="enable_auth" type="bool" default="false"} + 是否启用 VeADK 认证功能。启用后会添加凭据服务和认证中间件 + :: + + ::field{name="auth_method" type="'header' | 'querystring'" default="header"} + 认证方式。仅在 `enable_auth=True` 时有效 + - `header`: 从 Authorization header 中提取 token + - `querystring`: 从 query parameter 中提取 token + :: +:: diff --git a/tests/auth/test_credential_service.py b/tests/auth/test_credential_service.py new file mode 100644 index 00000000..2ba39b38 --- /dev/null +++ b/tests/auth/test_credential_service.py @@ -0,0 +1,364 @@ +# Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for VeCredentialService.""" + +import pytest +from unittest.mock import Mock + +from google.adk.auth.auth_credential import ( + AuthCredential, + AuthCredentialTypes, + HttpAuth, + HttpCredentials, +) +from google.adk.auth.auth_tool import AuthConfig +from google.adk.agents.callback_context import CallbackContext + +from veadk.auth.ve_credential_service import VeCredentialService + + +@pytest.fixture +def credential_service(): + """Create a VeCredentialService instance for testing.""" + return VeCredentialService() + + +@pytest.fixture +def sample_auth_credential(): + """Create a sample AuthCredential for testing (HTTP Bearer type).""" + return AuthCredential( + auth_type=AuthCredentialTypes.HTTP, + http=HttpAuth( + scheme="bearer", + credentials=HttpCredentials(token="test_token_123"), + ), + ) + + +@pytest.fixture +def sample_api_key_credential(): + """Create a sample API Key AuthCredential for testing.""" + return AuthCredential( + auth_type=AuthCredentialTypes.API_KEY, + api_key="test_api_key_123", + ) + + +@pytest.fixture +def mock_callback_context(): + """Create a mock CallbackContext.""" + mock_ctx = Mock(spec=CallbackContext) + mock_ctx._invocation_context = Mock() + mock_ctx._invocation_context.app_name = "test_app" + mock_ctx._invocation_context.user_id = "user123" + return mock_ctx + + +@pytest.fixture +def mock_auth_config(): + """Create a mock AuthConfig.""" + mock_config = Mock(spec=AuthConfig) + mock_config.credential_key = "test_key" + mock_config.exchanged_auth_credential = AuthCredential( + auth_type=AuthCredentialTypes.HTTP, + http=HttpAuth( + scheme="bearer", + credentials=HttpCredentials(token="test_token"), + ), + ) + return mock_config + + +class TestVeCredentialService: + """Tests for VeCredentialService class.""" + + def test_initialization(self, credential_service): + """Test service initialization.""" + assert credential_service._credentials == {} + + @pytest.mark.asyncio + async def test_set_and_get_credential( + self, credential_service, sample_auth_credential + ): + """Test setting and getting credentials.""" + # Set credential + await credential_service.set_credential( + app_name="test_app", + user_id="user123", + credential_key="bearer_token", + credential=sample_auth_credential, + ) + + # Get credential + retrieved = await credential_service.get_credential( + app_name="test_app", + user_id="user123", + credential_key="bearer_token", + ) + + assert retrieved is not None + assert retrieved.http.credentials.token == "test_token_123" + + @pytest.mark.asyncio + async def test_get_nonexistent_credential(self, credential_service): + """Test getting a credential that doesn't exist.""" + credential = await credential_service.get_credential( + app_name="nonexistent_app", + user_id="nonexistent_user", + credential_key="nonexistent_key", + ) + + assert credential is None + + @pytest.mark.asyncio + async def test_multiple_users_same_app(self, credential_service): + """Test storing credentials for multiple users in the same app.""" + cred1 = AuthCredential( + auth_type=AuthCredentialTypes.HTTP, + http=HttpAuth( + scheme="bearer", + credentials=HttpCredentials(token="token_user1"), + ), + ) + cred2 = AuthCredential( + auth_type=AuthCredentialTypes.HTTP, + http=HttpAuth( + scheme="bearer", + credentials=HttpCredentials(token="token_user2"), + ), + ) + + # Set credentials for two different users + await credential_service.set_credential( + app_name="test_app", + user_id="user1", + credential_key="bearer_token", + credential=cred1, + ) + await credential_service.set_credential( + app_name="test_app", + user_id="user2", + credential_key="bearer_token", + credential=cred2, + ) + + # Verify both credentials are stored separately + retrieved1 = await credential_service.get_credential( + app_name="test_app", + user_id="user1", + credential_key="bearer_token", + ) + retrieved2 = await credential_service.get_credential( + app_name="test_app", + user_id="user2", + credential_key="bearer_token", + ) + + assert retrieved1.http.credentials.token == "token_user1" + assert retrieved2.http.credentials.token == "token_user2" + + @pytest.mark.asyncio + async def test_multiple_credential_keys_same_user(self, credential_service): + """Test storing multiple credential keys for the same user.""" + cred1 = AuthCredential( + auth_type=AuthCredentialTypes.HTTP, + http=HttpAuth( + scheme="bearer", + credentials=HttpCredentials(token="bearer_token_123"), + ), + ) + cred2 = AuthCredential( + auth_type=AuthCredentialTypes.API_KEY, + api_key="api_key_456", + ) + + # Set different credential types for the same user + await credential_service.set_credential( + app_name="test_app", + user_id="user123", + credential_key="bearer_token", + credential=cred1, + ) + await credential_service.set_credential( + app_name="test_app", + user_id="user123", + credential_key="api_key", + credential=cred2, + ) + + # Verify both credentials are stored + retrieved1 = await credential_service.get_credential( + app_name="test_app", + user_id="user123", + credential_key="bearer_token", + ) + retrieved2 = await credential_service.get_credential( + app_name="test_app", + user_id="user123", + credential_key="api_key", + ) + + assert retrieved1.http.credentials.token == "bearer_token_123" + assert retrieved2.api_key == "api_key_456" + + @pytest.mark.asyncio + async def test_save_credential_via_adk_interface( + self, credential_service, mock_callback_context, mock_auth_config + ): + """Test saving credential via ADK BaseCredentialService interface.""" + # Save credential using ADK interface + await credential_service.save_credential( + auth_config=mock_auth_config, + callback_context=mock_callback_context, + ) + + # Verify credential was stored + credential = await credential_service.get_credential( + app_name="test_app", + user_id="user123", + credential_key="test_key", + ) + + assert credential is not None + assert credential.http.credentials.token == "test_token" + + @pytest.mark.asyncio + async def test_load_credential_via_adk_interface( + self, + credential_service, + mock_callback_context, + mock_auth_config, + sample_auth_credential, + ): + """Test loading credential via ADK BaseCredentialService interface.""" + # First set a credential + await credential_service.set_credential( + app_name="test_app", + user_id="user123", + credential_key="test_key", + credential=sample_auth_credential, + ) + + # Load credential using ADK interface + loaded = await credential_service.load_credential( + auth_config=mock_auth_config, + callback_context=mock_callback_context, + ) + + assert loaded is not None + assert loaded.http.credentials.token == "test_token_123" + + @pytest.mark.asyncio + async def test_load_nonexistent_credential_via_adk_interface( + self, credential_service, mock_callback_context, mock_auth_config + ): + """Test loading a nonexistent credential via ADK interface.""" + # Try to load a credential that doesn't exist + loaded = await credential_service.load_credential( + auth_config=mock_auth_config, + callback_context=mock_callback_context, + ) + + assert loaded is None + + @pytest.mark.asyncio + async def test_overwrite_existing_credential(self, credential_service): + """Test overwriting an existing credential.""" + cred1 = AuthCredential( + auth_type=AuthCredentialTypes.HTTP, + http=HttpAuth( + scheme="bearer", + credentials=HttpCredentials(token="old_token"), + ), + ) + cred2 = AuthCredential( + auth_type=AuthCredentialTypes.HTTP, + http=HttpAuth( + scheme="bearer", + credentials=HttpCredentials(token="new_token"), + ), + ) + + # Set initial credential + await credential_service.set_credential( + app_name="test_app", + user_id="user123", + credential_key="bearer_token", + credential=cred1, + ) + + # Overwrite with new credential + await credential_service.set_credential( + app_name="test_app", + user_id="user123", + credential_key="bearer_token", + credential=cred2, + ) + + # Verify new credential replaced the old one + retrieved = await credential_service.get_credential( + app_name="test_app", + user_id="user123", + credential_key="bearer_token", + ) + + assert retrieved.http.credentials.token == "new_token" + + @pytest.mark.asyncio + async def test_credential_isolation_between_apps(self, credential_service): + """Test that credentials are isolated between different apps.""" + cred1 = AuthCredential( + auth_type=AuthCredentialTypes.HTTP, + http=HttpAuth( + scheme="bearer", + credentials=HttpCredentials(token="app1_token"), + ), + ) + cred2 = AuthCredential( + auth_type=AuthCredentialTypes.HTTP, + http=HttpAuth( + scheme="bearer", + credentials=HttpCredentials(token="app2_token"), + ), + ) + + # Set credentials for two different apps + await credential_service.set_credential( + app_name="app1", + user_id="user123", + credential_key="bearer_token", + credential=cred1, + ) + await credential_service.set_credential( + app_name="app2", + user_id="user123", + credential_key="bearer_token", + credential=cred2, + ) + + # Verify credentials are isolated + retrieved1 = await credential_service.get_credential( + app_name="app1", + user_id="user123", + credential_key="bearer_token", + ) + retrieved2 = await credential_service.get_credential( + app_name="app2", + user_id="user123", + credential_key="bearer_token", + ) + + assert retrieved1.http.credentials.token == "app1_token" + assert retrieved2.http.credentials.token == "app2_token" diff --git a/tests/test_ve_a2a_middlewares.py b/tests/test_ve_a2a_middlewares.py new file mode 100644 index 00000000..1e5de3b4 --- /dev/null +++ b/tests/test_ve_a2a_middlewares.py @@ -0,0 +1,321 @@ +# Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for A2A authentication middleware.""" + +import pytest +from unittest.mock import Mock, patch +from starlette.applications import Starlette +from starlette.requests import Request +from starlette.responses import Response + +from veadk.a2a.ve_middlewares import A2AAuthMiddleware, build_a2a_auth_middleware +from veadk.auth.ve_credential_service import VeCredentialService +from veadk.utils.auth import VE_TIP_TOKEN_HEADER + + +@pytest.fixture +def credential_service(): + """Create a VeCredentialService instance for testing.""" + return VeCredentialService() + + +@pytest.fixture +def mock_identity_client(): + """Create a mock IdentityClient.""" + mock_client = Mock() + mock_client.get_workload_access_token = Mock() + return mock_client + + +@pytest.fixture +def sample_jwt_token(): + """Sample JWT token for testing.""" + # This is a sample JWT with sub="user123" and act claim + # Header: {"alg": "HS256", "typ": "JWT"} + # Payload: {"sub": "user123", "act": {"sub": "agent1"}} + return "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiJ1c2VyMTIzIiwiYWN0Ijp7InN1YiI6ImFnZW50MSJ9fQ.signature" + + +class TestA2AAuthMiddleware: + """Tests for A2AAuthMiddleware class.""" + + def test_middleware_initialization(self, credential_service, mock_identity_client): + """Test middleware initialization with all parameters.""" + app = Starlette() + middleware = A2AAuthMiddleware( + app=app, + app_name="test_app", + credential_service=credential_service, + auth_method="header", + token_param="token", + credential_key="test_key", + identity_client=mock_identity_client, + ) + + assert middleware.app_name == "test_app" + assert middleware.credential_service == credential_service + assert middleware.auth_method == "header" + assert middleware.token_param == "token" + assert middleware.credential_key == "test_key" + assert middleware.identity_client == mock_identity_client + + def test_middleware_default_identity_client(self, credential_service): + """Test middleware uses global identity client when not provided.""" + app = Starlette() + + with patch( + "veadk.a2a.ve_middlewares.get_default_identity_client" + ) as mock_get_client: + mock_client = Mock() + mock_get_client.return_value = mock_client + + middleware = A2AAuthMiddleware( + app=app, + app_name="test_app", + credential_service=credential_service, + ) + + mock_get_client.assert_called_once() + assert middleware.identity_client == mock_client + + def test_extract_token_from_header_with_bearer(self, credential_service): + """Test extracting token from Authorization header with Bearer prefix.""" + app = Starlette() + middleware = A2AAuthMiddleware( + app=app, + app_name="test_app", + credential_service=credential_service, + auth_method="header", + ) + + # Create mock request + mock_request = Mock(spec=Request) + mock_request.headers = {"Authorization": "Bearer test_token_123"} + + token, has_prefix = middleware._extract_token(mock_request) + + assert token == "test_token_123" + assert has_prefix is True + + def test_extract_token_from_header_without_bearer(self, credential_service): + """Test extracting token from Authorization header without Bearer prefix.""" + app = Starlette() + middleware = A2AAuthMiddleware( + app=app, + app_name="test_app", + credential_service=credential_service, + auth_method="header", + ) + + mock_request = Mock(spec=Request) + mock_request.headers = {"Authorization": "test_token_123"} + + token, has_prefix = middleware._extract_token(mock_request) + + assert token == "test_token_123" + assert has_prefix is False + + def test_extract_token_from_query_string(self, credential_service): + """Test extracting token from query string.""" + app = Starlette() + middleware = A2AAuthMiddleware( + app=app, + app_name="test_app", + credential_service=credential_service, + auth_method="querystring", + token_param="access_token", + ) + + mock_request = Mock(spec=Request) + mock_request.query_params = {"access_token": "test_token_123"} + + token, has_prefix = middleware._extract_token(mock_request) + + assert token == "test_token_123" + assert has_prefix is False + + def test_extract_token_no_token_found(self, credential_service): + """Test extracting token when no token is present.""" + app = Starlette() + middleware = A2AAuthMiddleware( + app=app, + app_name="test_app", + credential_service=credential_service, + auth_method="header", + ) + + mock_request = Mock(spec=Request) + mock_request.headers = {} + + token, has_prefix = middleware._extract_token(mock_request) + + assert token is None + assert has_prefix is False + + @pytest.mark.asyncio + async def test_dispatch_with_valid_jwt_token( + self, credential_service, mock_identity_client, sample_jwt_token + ): + """Test dispatch with valid JWT token.""" + app = Starlette() + + # Mock WorkloadToken + mock_workload_token = Mock() + mock_workload_token.workload_access_token = "workload_token_123" + mock_workload_token.expires_at = 1234567890 + mock_identity_client.get_workload_access_token.return_value = ( + mock_workload_token + ) + + middleware = A2AAuthMiddleware( + app=app, + app_name="test_app", + credential_service=credential_service, + auth_method="header", + identity_client=mock_identity_client, + ) + + # Create mock request with JWT token + mock_request = Mock(spec=Request) + mock_request.headers = { + "Authorization": f"Bearer {sample_jwt_token}", + } + mock_request.scope = {} + + # Mock call_next + async def mock_call_next(request): + return Response("OK", status_code=200) + + # Execute dispatch + with patch( + "veadk.a2a.ve_middlewares.extract_delegation_chain_from_jwt" + ) as mock_extract: + mock_extract.return_value = ("user123", ["agent1"]) + + with patch( + "veadk.a2a.ve_middlewares.build_auth_config" + ) as mock_build_config: + mock_auth_config = Mock() + mock_auth_config.exchanged_auth_credential = Mock() + mock_build_config.return_value = mock_auth_config + + response = await middleware.dispatch(mock_request, mock_call_next) + + # Verify response + assert response.status_code == 200 + + # Verify user was set in request scope + assert "user" in mock_request.scope + assert mock_request.scope["user"].username == "user123" + + # Verify workload token was set + assert "auth" in mock_request.scope + assert mock_request.scope["auth"] == mock_workload_token + + @pytest.mark.asyncio + async def test_dispatch_with_tip_token( + self, credential_service, mock_identity_client, sample_jwt_token + ): + """Test dispatch with TIP token in header.""" + app = Starlette() + + # Mock WorkloadToken + mock_workload_token = Mock() + mock_workload_token.workload_access_token = "workload_token_from_tip" + mock_workload_token.expires_at = 1234567890 + mock_identity_client.get_workload_access_token.return_value = ( + mock_workload_token + ) + + middleware = A2AAuthMiddleware( + app=app, + app_name="test_app", + credential_service=credential_service, + auth_method="header", + identity_client=mock_identity_client, + ) + + # Create mock request with both JWT and TIP token + tip_token = "tip_token_123" + mock_request = Mock(spec=Request) + mock_request.headers = { + "Authorization": f"Bearer {sample_jwt_token}", + VE_TIP_TOKEN_HEADER: tip_token, + } + mock_request.scope = {} + + # Mock call_next + async def mock_call_next(request): + return Response("OK", status_code=200) + + # Execute dispatch + with patch( + "veadk.a2a.ve_middlewares.extract_delegation_chain_from_jwt" + ) as mock_extract: + mock_extract.return_value = ("user123", ["agent1"]) + + with patch( + "veadk.a2a.ve_middlewares.build_auth_config" + ) as mock_build_config: + mock_auth_config = Mock() + mock_auth_config.exchanged_auth_credential = Mock() + mock_build_config.return_value = mock_auth_config + + _ = await middleware.dispatch(mock_request, mock_call_next) + + # Verify TIP token was used for workload token exchange + mock_identity_client.get_workload_access_token.assert_called_once_with( + user_token=tip_token, user_id="user123" + ) + + # Verify workload token was set + assert mock_request.scope["auth"] == mock_workload_token + + +class TestBuildA2AAuthMiddleware: + """Tests for build_a2a_auth_middleware factory function.""" + + def test_build_middleware_basic(self, credential_service): + """Test building middleware with basic parameters.""" + middleware_class = build_a2a_auth_middleware( + app_name="test_app", + credential_service=credential_service, + ) + + assert middleware_class is not None + assert issubclass(middleware_class, A2AAuthMiddleware) + + def test_build_middleware_with_all_params( + self, credential_service, mock_identity_client + ): + """Test building middleware with all parameters.""" + middleware_class = build_a2a_auth_middleware( + app_name="test_app", + credential_service=credential_service, + auth_method="querystring", + token_param="access_token", + credential_key="custom_key", + identity_client=mock_identity_client, + ) + + # Create instance to verify parameters + app = Starlette() + instance = middleware_class(app) + + assert instance.app_name == "test_app" + assert instance.auth_method == "querystring" + assert instance.token_param == "access_token" + assert instance.credential_key == "custom_key" + assert instance.identity_client == mock_identity_client diff --git a/veadk/a2a/remote_ve_agent.py b/veadk/a2a/remote_ve_agent.py index 14e7e3d2..1eddedaa 100644 --- a/veadk/a2a/remote_ve_agent.py +++ b/veadk/a2a/remote_ve_agent.py @@ -13,14 +13,22 @@ # limitations under the License. import json -from typing import Literal, Optional +import functools +from typing import AsyncGenerator, Literal, Optional +from a2a.client.base_client import BaseClient import httpx import requests from a2a.types import AgentCard from google.adk.agents.remote_a2a_agent import RemoteA2aAgent +from veadk.integrations.ve_identity.utils import generate_headers +from veadk.utils.auth import VE_TIP_TOKEN_CREDENTIAL_KEY, VE_TIP_TOKEN_HEADER from veadk.utils.logger import get_logger +from google.adk.utils.context_utils import Aclosing +from google.adk.events.event import Event +from google.adk.agents.invocation_context import InvocationContext + logger = get_logger(__name__) @@ -60,12 +68,16 @@ class RemoteVeAgent(RemoteA2aAgent): with a configured `base_url` is provided. If both are given, they must not conflict. auth_token (Optional[str]): - Optional authentication token used for secure access. If not provided, - the agent will be accessed without authentication. + Optional authentication token used for secure access during initialization. + If not provided, the agent will be accessed without authentication. + Note: For runtime authentication, use the credential service in InvocationContext. auth_method (Literal["header", "querystring"] | None): - The method of attaching the authentication token. - - `"header"`: Token is passed via HTTP `Authorization` header. - - `"querystring"`: Token is passed as a query parameter. + The method of attaching the authentication token at runtime. + - `"header"`: Token is retrieved from credential service and passed via HTTP `Authorization` header. + - `"querystring"`: Token is retrieved from credential service and passed as a query parameter. + - `None`: No runtime authentication injection (default). + The credential is loaded from `InvocationContext.credential_service` using the + app_name and user_id from the context. httpx_client (Optional[httpx.AsyncClient]): An optional, pre-configured `httpx.AsyncClient` to use for communication. This allows for client sharing and advanced configurations (e.g., proxies). @@ -81,13 +93,13 @@ class RemoteVeAgent(RemoteA2aAgent): Examples: ```python - # Example 1: Connect using a URL + # Example 1: Connect using a URL (no authentication) agent = RemoteVeAgent( name="public_agent", url="https://vefaas.example.com/agents/public" ) - # Example 2: Using Bearer token in header + # Example 2: Using static Bearer token in header for initialization agent = RemoteVeAgent( name="secured_agent", url="https://vefaas.example.com/agents/secure", @@ -95,7 +107,15 @@ class RemoteVeAgent(RemoteA2aAgent): auth_method="header" ) - # Example 3: Using a pre-configured httpx_client + # Example 3: Using runtime authentication with credential service + # The auth token will be automatically injected from InvocationContext.credential_service + agent = RemoteVeAgent( + name="dynamic_auth_agent", + url="https://vefaas.example.com/agents/secure", + auth_method="header" # Will load credential at runtime + ) + + # Example 4: Using a pre-configured httpx_client import httpx client = httpx.AsyncClient( base_url="https://vefaas.example.com/agents/query", @@ -103,13 +123,14 @@ class RemoteVeAgent(RemoteA2aAgent): ) agent = RemoteVeAgent( name="query_agent", - auth_token="my_secret_token", - auth_method="querystring", + auth_method="querystring", # Will load credential at runtime httpx_client=client ) ``` """ + auth_method: Literal["header", "querystring"] | None = None + def __init__( self, name: str, @@ -196,3 +217,174 @@ def __init__( # are properly cleaned up. if not client_was_provided: self._httpx_client_needs_cleanup = True + + # Set auth_method if provided + if auth_method: + self.auth_method = auth_method + + # Wrap _run_async_impl with pre-run hook to ensure initialization + # and authentication logic always executes, even if users override _run_async_impl + self._wrap_run_async_impl() + + def _wrap_run_async_impl(self) -> None: + """Wrap _run_async_impl with a decorator that ensures pre-run logic executes. + + This method wraps the _run_async_impl method with a decorator that: + 1. Executes _pre_run before the actual implementation + 2. Handles errors from _pre_run and yields error events + 3. Ensures the wrapper works even if users override _run_async_impl + + The wrapper is applied by replacing the bound method on the instance. + """ + # Store the original _run_async_impl method + original_run_async_impl = self._run_async_impl + + @functools.wraps(original_run_async_impl) + async def wrapped_run_async_impl( + ctx: InvocationContext, + ) -> AsyncGenerator[Event, None]: + """Wrapped version of _run_async_impl with pre-run hook.""" + # Execute pre-run initialization + try: + await self._pre_run(ctx) + except Exception as e: + yield Event( + author=self.name, + error_message=f"Failed to initialize remote A2A agent: {e}", + invocation_id=ctx.invocation_id, + branch=ctx.branch, + ) + return + + # Call the original (or overridden) _run_async_impl + async with Aclosing(original_run_async_impl(ctx)) as agen: + async for event in agen: + yield event + + # Replace the instance method with the wrapped version + self._run_async_impl = wrapped_run_async_impl + + async def _pre_run(self, ctx: InvocationContext) -> None: + """Pre-run initialization and authentication setup. + + This method is called before the actual agent execution to: + 1. Ensure the agent is resolved (agent card fetched, client initialized) + 2. Inject authentication token from credential service if available + + This method is separated from _run_async_impl to ensure these critical + initialization steps are always executed, even if users override _run_async_impl. + + Args: + ctx: Invocation context containing session and user information + + Raises: + Exception: If agent initialization fails + """ + # Ensure agent is resolved + await self._ensure_resolved() + + # Inject auth token if credential service is available + await self._inject_auth_token(ctx) + + async def _inject_auth_token(self, ctx: InvocationContext) -> None: + """Inject authentication token from credential service into the HTTP client. + + This method retrieves the authentication token from the credential service + in the InvocationContext and updates the HTTP client headers or query params + based on the configured auth_method. + + Args: + ctx: Invocation context containing credential service and user information + """ + # Skip if no credential service in context + if not ctx.credential_service: + logger.debug( + "No credential service in InvocationContext, skipping auth token injection" + ) + return + + # Skip if client is not initialized or not a BaseClient + if not hasattr(self, "_a2a_client") or not isinstance( + self._a2a_client, BaseClient + ): + logger.debug( + "A2A client not initialized or not a BaseClient, skipping auth token injection" + ) + return + + # Skip if transport is not available + if not hasattr(self._a2a_client, "_transport"): + logger.debug( + "A2A client transport not available, skipping auth token injection" + ) + return + + # Skip if httpx_client is not available + if not hasattr(self._a2a_client._transport, "httpx_client"): + logger.debug( + "A2A client httpx_client not available, skipping auth token injection" + ) + return + + try: + from veadk.utils.auth import build_auth_config + from google.adk.agents.callback_context import CallbackContext + + # Inject TIP token via header + workload_auth_config = build_auth_config( + auth_method="apikey", + credential_key=VE_TIP_TOKEN_CREDENTIAL_KEY, + header_name=VE_TIP_TOKEN_HEADER, + ) + + tip_credential = await ctx.credential_service.load_credential( + auth_config=workload_auth_config, + callback_context=CallbackContext(ctx), + ) + + if tip_credential: + self._a2a_client._transport.httpx_client.headers.update( + {VE_TIP_TOKEN_HEADER: tip_credential.api_key} + ) + logger.debug( + f"Injected TIP token via header for app={ctx.app_name}, user={ctx.user_id}" + ) + + # Build auth config based on auth_method + auth_config = build_auth_config( + credential_key="inbound_auth", + auth_method=self.auth_method or "header", + header_scheme="bearer", + ) + + # Load credential from credential service + credential = await ctx.credential_service.load_credential( + auth_config=auth_config, + callback_context=CallbackContext(ctx), + ) + + if not credential: + logger.debug( + f"No credential loaded, skipping auth token injection for app={ctx.app_name}, user={ctx.user_id}" + ) + return + + # Inject credential based on auth_method + if self.auth_method == "querystring": + # Extract API key + api_key = credential.api_key + new_params = dict(self._a2a_client._transport.httpx_client.params) + new_params.update({"token": api_key}) + self._a2a_client._transport.httpx_client.params = new_params + logger.debug( + f"Injected auth token via querystring for app={ctx.app_name}, user={ctx.user_id}" + ) + else: + if headers := generate_headers(credential): + self._a2a_client._transport.httpx_client.headers.update(headers) + logger.debug( + f"Injected auth token via header for app={ctx.app_name}, user={ctx.user_id}" + ) + + except Exception as e: + logger.warning(f"Failed to inject auth token: {e}", exc_info=True) diff --git a/veadk/a2a/utils/__init__.py b/veadk/a2a/utils/__init__.py new file mode 100644 index 00000000..7f463206 --- /dev/null +++ b/veadk/a2a/utils/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/veadk/a2a/utils/agent_to_a2a.py b/veadk/a2a/utils/agent_to_a2a.py new file mode 100644 index 00000000..ebd104ed --- /dev/null +++ b/veadk/a2a/utils/agent_to_a2a.py @@ -0,0 +1,170 @@ +# Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations +from typing import Literal, Optional, Union + +from a2a.types import AgentCard +from google.adk.agents import BaseAgent +from starlette.applications import Starlette +from google.adk.artifacts.in_memory_artifact_service import InMemoryArtifactService +from google.adk.sessions.in_memory_session_service import InMemorySessionService +from google.adk.memory.in_memory_memory_service import InMemoryMemoryService +from google.adk.a2a.utils.agent_to_a2a import to_a2a as google_adk_to_a2a +from veadk import Runner +from veadk.a2a.ve_middlewares import build_a2a_auth_middleware +from veadk.auth.ve_credential_service import VeCredentialService +from veadk.consts import DEFAULT_AGENT_NAME + + +def to_a2a( + agent: BaseAgent, + *, + host: str = "localhost", + port: int = 8000, + protocol: str = "http", + agent_card: Optional[Union[AgentCard, str]] = None, + runner: Optional[Runner] = None, + enable_auth: bool = False, + auth_method: Literal["header", "querystring"] = "header", +) -> Starlette: + """Convert an ADK agent to a A2A Starlette application with optional VeADK enhancements. + + This function wraps Google ADK's to_a2a utility and optionally adds: + - VeCredentialService for authentication management + - A2A authentication middleware for token validation + + Args: + agent: The ADK agent to convert to A2A server + host: The host for the A2A RPC URL (default: "localhost") + port: The port for the A2A RPC URL (default: 8000) + protocol: The protocol for the A2A RPC URL (default: "http") + agent_card: Optional pre-built AgentCard object or path to agent card + JSON. If not provided, will be built automatically from the + agent. + runner: Optional pre-built Runner object. If not provided, a default + runner will be created using in-memory services. + When enable_auth=True: + - If runner is provided and has a credential_service, it must be + a VeCredentialService instance (raises TypeError otherwise) + - If runner is provided without credential_service, a new + VeCredentialService will be created and set + - If runner is not provided, a new runner with VeCredentialService + will be created + auth_method: Authentication method for A2A requests (only used when + enable_auth=True). Options: + - "header": Extract token from Authorization header (default) + - "querystring": Extract token from query parameter + enable_auth: Whether to enable VeADK authentication features. + When True, enables credential service and auth middleware. + When False, uses standard Google ADK behavior. + Default: False + + Returns: + A Starlette application that can be run with uvicorn + + Raises: + TypeError: If enable_auth=True and runner has a credential_service + that is not a VeCredentialService instance + + Example: + Basic usage (without VeADK auth): + ```python + from veadk import Agent + from veadk.a2a.utils.agent_to_a2a import to_a2a + + agent = Agent(name="my_agent", tools=[...]) + app = to_a2a(agent, host="localhost", port=8000) + # Run with: uvicorn module:app --host localhost --port 8000 + ``` + + With VeADK authentication enabled: + ```python + app = to_a2a(agent, enable_auth=True) + ``` + + With custom runner and VeADK auth: + ```python + from veadk import Agent, Runner + from veadk.memory.short_term_memory import ShortTermMemory + from veadk.auth.ve_credential_service import VeCredentialService + + agent = Agent(name="my_agent") + runner = Runner( + agent=agent, + short_term_memory=ShortTermMemory(), + app_name="my_app", + credential_service=VeCredentialService() # Optional + ) + app = to_a2a(agent, runner=runner, enable_auth=True) + ``` + + With querystring authentication: + ```python + app = to_a2a(agent, enable_auth=True, auth_method="querystring") + ``` + """ + app_name = agent.name or DEFAULT_AGENT_NAME + middleware = None # May need support multiple middlewares in the future + + # Handle VeADK authentication setup + if enable_auth: + # Create credential service if not provided + credential_service = VeCredentialService() + if runner is not None: + # Check if runner has credential_service + if runner.credential_service is not None: + # Validate that it's a VeCredentialService + if not isinstance(runner.credential_service, VeCredentialService): + raise TypeError( + f"When enable_auth=True, runner.credential_service must be " + f"a VeCredentialService instance, got {type(runner.credential_service).__name__}" + ) + # Use existing credential service + credential_service = runner.credential_service + else: + # Add credential_service to runner + runner.credential_service = credential_service + else: + # Create runner with credential_service + runner = Runner( + app_name=app_name, + agent=agent, + artifact_service=InMemoryArtifactService(), + session_service=InMemorySessionService(), + memory_service=InMemoryMemoryService(), + credential_service=credential_service, + ) + + middleware = build_a2a_auth_middleware( + app_name=app_name, + credential_service=credential_service, + auth_method=auth_method, + ) + + # Convert agent to A2A Starlette app using Google ADK utility + app: Starlette = google_adk_to_a2a( + agent=agent, + host=host, + port=port, + protocol=protocol, + agent_card=agent_card, + runner=runner, + ) + + # Add VeADK authentication middleware only if enabled + if middleware: + app.add_middleware(middleware) + + return app diff --git a/veadk/a2a/ve_a2a_server.py b/veadk/a2a/ve_a2a_server.py index 95e53c1a..f5fb5438 100644 --- a/veadk/a2a/ve_a2a_server.py +++ b/veadk/a2a/ve_a2a_server.py @@ -23,11 +23,19 @@ from veadk.memory.short_term_memory import ShortTermMemory from google.adk.a2a.executor.a2a_agent_executor import A2aAgentExecutor +from google.adk.auth.credential_service.base_credential_service import ( + BaseCredentialService, +) class VeA2AServer: def __init__( - self, agent: Agent, url: str, app_name: str, short_term_memory: ShortTermMemory + self, + agent: Agent, + url: str, + app_name: str, + short_term_memory: ShortTermMemory, + credential_service: BaseCredentialService | None = None, ): self.agent_card = get_agent_card(agent, url) @@ -36,7 +44,8 @@ def __init__( agent=agent, app_name=app_name, short_term_memory=short_term_memory, - ), + credential_service=credential_service, + ) ) self.task_store = InMemoryTaskStore() @@ -56,7 +65,11 @@ def build(self) -> FastAPI: def init_app( - server_url: str, app_name: str, agent: Agent, short_term_memory: ShortTermMemory + server_url: str, + app_name: str, + agent: Agent, + short_term_memory: ShortTermMemory, + credential_service: BaseCredentialService | None = None, ) -> FastAPI: """Init the fastapi application in terms of VeADK agent. @@ -75,5 +88,6 @@ def init_app( url=server_url, app_name=app_name, short_term_memory=short_term_memory, + credential_service=credential_service, ) return server.build() diff --git a/veadk/a2a/ve_middlewares.py b/veadk/a2a/ve_middlewares.py new file mode 100644 index 00000000..3da16d52 --- /dev/null +++ b/veadk/a2a/ve_middlewares.py @@ -0,0 +1,313 @@ +# Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""A2A Authentication Middleware for FastAPI/Starlette. + +This module provides middleware for extracting authentication credentials +from HTTP requests and storing them in the credential service. +""" + +import logging +from typing import Callable, Literal, Optional + +from starlette.middleware.base import BaseHTTPMiddleware +from starlette.requests import Request +from starlette.responses import Response +from volcenginesdkcore.rest import ApiException + + +from veadk.auth.ve_credential_service import VeCredentialService +from veadk.utils.auth import ( + extract_delegation_chain_from_jwt, + build_auth_config, + VE_TIP_TOKEN_HEADER, + VE_TIP_TOKEN_CREDENTIAL_KEY, +) +from veadk.integrations.ve_identity import ( + WorkloadToken, + IdentityClient, + get_default_identity_client, +) + +logger = logging.getLogger(__name__) + + +class A2AAuthMiddleware(BaseHTTPMiddleware): + """Middleware to extract and store authentication credentials from requests. + + This middleware: + 1. Extracts auth tokens from Authorization header or query string + 2. Parses JWT tokens to extract user_id and delegation chain + 3. Builds AuthConfig based on the authentication method + 4. Stores credentials in the credential service + 5. Extracts TIP token from X-Ve-TIP-Token header for trust propagation + 6. Exchanges TIP token for workload access token using IdentityClient + 7. Sets workload token in request.scope["auth"] for downstream use + + Examples: + ```python + from fastapi import FastAPI + from veadk.a2a.ve_middlewares import build_a2a_auth_middleware + from veadk.auth.ve_credential_service import VeCredentialService + + app = FastAPI() + credential_service = VeCredentialService() + + # Add middleware with Authorization header support + app.add_middleware( + build_a2a_auth_middleware( + app_name="my_app", + credential_service=credential_service, + auth_method="header" + ) + ) + + # Or with query string support + app.add_middleware( + build_a2a_auth_middleware( + app_name="my_app", + credential_service=credential_service, + auth_method="querystring", + token_param="token" + ) + ) + ``` + """ + + def __init__( + self, + app, + app_name: str, + credential_service: VeCredentialService, + auth_method: Literal["header", "querystring"] = "header", + token_param: str = "token", + credential_key: str = "inbound_auth", + identity_client: Optional[IdentityClient] = None, + ): + """Initialize the middleware. + + Args: + app: The ASGI application + app_name: Application name for credential storage + credential_service: Credential service to store credentials + auth_method: Authentication method - "header" or "querystring" + token_param: Query parameter name for token (when auth_method="querystring") + credential_key: Key to identify the credential in the store + identity_client: Optional IdentityClient for TIP token exchange. + If not provided, uses the global IdentityClient from VeIdentityConfig. + """ + super().__init__(app) + self.app_name = app_name + self.credential_service = credential_service + self.auth_method = auth_method + self.token_param = token_param + self.credential_key = credential_key + self.identity_client = identity_client or get_default_identity_client() + + async def dispatch(self, request: Request, call_next: Callable) -> Response: + """Process the request and extract authentication credentials. + + This method: + 1. Extracts token from Authorization header or query string + 2. Parses JWT to get user_id and stores credentials + 3. Sets request.scope["user"] with SimpleUser instance + 4. Extracts TIP token from X-Ve-TIP-Token header + 5. Exchanges TIP token for workload token via IdentityClient + 6. Sets request.scope["auth"] with WorkloadToken object + + Args: + request: The incoming HTTP request + call_next: The next middleware or route handler + + Returns: + The HTTP response + """ + from starlette.authentication import SimpleUser + + token, has_prefix = self._extract_token(request) + user_id = None + + if token: + user_id, _ = extract_delegation_chain_from_jwt(token) + if user_id: + # Build auth config based on authentication method + auth_config = build_auth_config( + token=token, + auth_method=self.auth_method, + credential_key=self.credential_key, + header_scheme="bearer" if has_prefix else None, + query_param_name=self.token_param, + ) + + await self.credential_service.set_credential( + app_name=self.app_name, + user_id=user_id, + credential_key=self.credential_key, + credential=auth_config.exchanged_auth_credential, + ) + + logger.debug( + f"Stored credential for app={self.app_name}, user={user_id}, " + f"method={self.auth_method}" + ) + + request.scope["user"] = SimpleUser(user_id) + else: + logger.warning("Failed to extract user_id from JWT token") + + # Extract TIP token from X-Ve-TIP-Token header for trust propagation + tip_token = request.headers.get(VE_TIP_TOKEN_HEADER) + + try: + workload_token: WorkloadToken = ( + self.identity_client.get_workload_access_token( + user_token=tip_token, user_id=user_id + ) + ) + workload_auth_config = build_auth_config( + token=workload_token.workload_access_token, + auth_method="apikey", + credential_key=VE_TIP_TOKEN_CREDENTIAL_KEY, + header_name=VE_TIP_TOKEN_HEADER, + ) + + await self.credential_service.set_credential( + app_name=self.app_name, + user_id=user_id, + credential_key=VE_TIP_TOKEN_CREDENTIAL_KEY, + credential=workload_auth_config.exchanged_auth_credential, + ) + except ApiException as e: + logger.warning(f"Failed to get workload token: {e.reason}") + workload_token = None + request.scope["auth"] = workload_token + # Continue processing the request + response = await call_next(request) + return response + + def _extract_token(self, request: Request) -> tuple[Optional[str], bool]: + """Extract authentication token from the request. + + Args: + request: The HTTP request + + Returns: + The extracted token, or None if not found + """ + has_prefix = False + token = None + + if self.auth_method == "header": + # Extract from Authorization header + auth_header = request.headers.get("Authorization") or request.headers.get( + "authorization" + ) + if auth_header: + # Strip "Bearer " prefix if present + if auth_header.lower().startswith("bearer "): + has_prefix = True + token = auth_header[7:] + else: + token = auth_header + else: + token = None + elif self.auth_method == "querystring": + # Extract from query string + token = request.query_params.get(self.token_param) + + return token, has_prefix + + +def build_a2a_auth_middleware( + app_name: str, + credential_service: VeCredentialService, + auth_method: Literal["header", "querystring"] = "header", + token_param: str = "token", + credential_key: str = "inbound_auth", + identity_client: Optional[IdentityClient] = None, +) -> type[A2AAuthMiddleware]: + """Build an A2A authentication middleware class. + + This is a factory function that creates a middleware class with the + specified configuration. Use this with FastAPI's add_middleware(). + + The middleware extracts authentication tokens from incoming requests, + parses JWT delegation chains, stores credentials, and sets user information + in the request state for downstream handlers. + + TIP Token Support: + The middleware will: + 1. Extract TIP token from X-Ve-TIP-Token header + 2. Exchange TIP token for workload token using IdentityClient + 3. Set WorkloadToken object in request.scope["auth"] for downstream use + + If identity_client is not provided, uses the global IdentityClient + from VeIdentityConfig. + + Args: + app_name: Application name for credential storage + credential_service: Credential service to store credentials + auth_method: Authentication method - "header" or "querystring" + token_param: Query parameter name for token (when auth_method="querystring") + credential_key: Key to identify the credential in the store + identity_client: Optional IdentityClient for TIP token exchange. + If not provided, uses the global IdentityClient from VeIdentityConfig. + + Returns: + A configured middleware class + + Request Attributes: + After successful authentication, the following attributes are set: + - request.scope["user"]: SimpleUser instance with the user_id from JWT + - request.scope["auth"]: WorkloadToken object containing workload_access_token + + Examples: + ```python + from fastapi import FastAPI, Request + from veadk.a2a.ve_middlewares import build_a2a_auth_middleware + from veadk.auth.ve_credential_service import VeCredentialService + from veadk.integrations.ve_identity import IdentityClient + + app = FastAPI() + credential_service = VeCredentialService() + + # Optional: Create identity client for TIP token support + # If not provided, uses global client from VeIdentityConfig + identity_client = IdentityClient(region="cn-beijing") + + # Add middleware with TIP token support + app.add_middleware( + build_a2a_auth_middleware( + app_name="my_app", + credential_service=credential_service, + auth_method="header", + identity_client=identity_client, # Optional, uses global if not provided + ) + ) + ``` + """ + + class ConfiguredA2AAuthMiddleware(A2AAuthMiddleware): + def __init__(self, app): + super().__init__( + app=app, + app_name=app_name, + credential_service=credential_service, + auth_method=auth_method, + token_param=token_param, + credential_key=credential_key, + identity_client=identity_client, + ) + + return ConfiguredA2AAuthMiddleware diff --git a/veadk/auth/ve_credential_service.py b/veadk/auth/ve_credential_service.py new file mode 100644 index 00000000..b29cb425 --- /dev/null +++ b/veadk/auth/ve_credential_service.py @@ -0,0 +1,203 @@ +# Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""A2A Credential Service for VeADK. + +This module provides a credential service that supports direct credential +management by app_name and user_id, extending the ADK BaseCredentialService. +""" + +import logging +from typing import Optional + +from typing_extensions import override + +from google.adk.agents.callback_context import CallbackContext +from google.adk.auth.auth_credential import AuthCredential +from google.adk.auth.auth_tool import AuthConfig +from google.adk.auth.credential_service.base_credential_service import ( + BaseCredentialService, +) + +logger = logging.getLogger(__name__) + + +class VeCredentialService(BaseCredentialService): + """In-memory credential service with direct app_name/user_id access. + + This service extends BaseCredentialService to support both: + 1. Standard ADK credential operations (load_credential, save_credential) + 2. Direct credential access by app_name and user_id + + The credential store is organized as: + { + app_name: { + user_id: { + credential_key: AuthCredential + } + } + } + + Examples: + ```python + # Create service + service = VeCredentialService() + + # Direct set/get + await service.set_credential( + app_name="my_app", + user_id="user123", + credential_key="bearer_token", + credential=AuthCredential(...) + ) + + credential = await service.get_credential( + app_name="my_app", + user_id="user123", + credential_key="bearer_token" + ) + + # Standard ADK operations + await service.save_credential(auth_config, callback_context) + credential = await service.load_credential(auth_config, callback_context) + ``` + """ + + def __init__(self): + """Initialize the credential service with empty storage.""" + super().__init__() + self._credentials: dict[str, dict[str, dict[str, AuthCredential]]] = {} + + @override + async def load_credential( + self, + auth_config: AuthConfig, + callback_context: CallbackContext, + ) -> Optional[AuthCredential]: + """Load credential from the store using auth config and callback context. + + Args: + auth_config: Auth configuration containing credential_key + callback_context: Callback context containing app_name and user_id + + Returns: + The stored AuthCredential, or None if not found + """ + app_name = callback_context._invocation_context.app_name + user_id = callback_context._invocation_context.user_id + + return await self.get_credential( + app_name=app_name, + user_id=user_id, + credential_key=auth_config.credential_key, + ) + + @override + async def save_credential( + self, + auth_config: AuthConfig, + callback_context: CallbackContext, + ) -> None: + """Save credential to the store using auth config and callback context. + + Args: + auth_config: Auth configuration containing credential_key and exchanged_auth_credential + callback_context: Callback context containing app_name and user_id + """ + app_name = callback_context._invocation_context.app_name + user_id = callback_context._invocation_context.user_id + + await self.set_credential( + app_name=app_name, + user_id=user_id, + credential_key=auth_config.credential_key, + credential=auth_config.exchanged_auth_credential, + ) + + async def set_credential( + self, + app_name: str, + user_id: str, + credential_key: str, + credential: AuthCredential, + ) -> None: + """Directly set a credential by app_name, user_id, and credential_key (async). + + This method allows setting credentials without requiring a CallbackContext, + useful for middleware and request interceptors. + + Args: + app_name: Application name + user_id: User identifier + credential_key: Key to identify the credential + credential: The AuthCredential to store + + Examples: + ```python + from google.adk.auth.auth_credential import AuthCredential, AuthCredentialTypes + + service = VeCredentialService() + await service.set_credential( + app_name="my_app", + user_id="user123", + credential_key="bearer_token", + credential=AuthCredential( + auth_type=AuthCredentialTypes.BEARER_TOKEN, + bearer_token="eyJhbGc..." + ) + ) + ``` + """ + if app_name not in self._credentials: + self._credentials[app_name] = {} + if user_id not in self._credentials[app_name]: + self._credentials[app_name][user_id] = {} + + self._credentials[app_name][user_id][credential_key] = credential + logger.debug( + f"Set credential for app={app_name}, user={user_id}, key={credential_key}" + ) + + async def get_credential( + self, + app_name: str, + user_id: str, + credential_key: str, + ) -> Optional[AuthCredential]: + """Directly get a credential by app_name, user_id, and credential_key (async). + + This method allows retrieving credentials without requiring a CallbackContext, + useful for middleware and request interceptors. + + Args: + app_name: Application name + user_id: User identifier + credential_key: Key to identify the credential + + Returns: + The stored AuthCredential, or None if not found + + Examples: + ```python + service = VeCredentialService() + credential = await service.get_credential( + app_name="my_app", + user_id="user123", + credential_key="bearer_token" + ) + if credential: + print(f"Found token: {credential.bearer_token}") + ``` + """ + return self._credentials.get(app_name, {}).get(user_id, {}).get(credential_key) diff --git a/veadk/configs/auth_configs.py b/veadk/configs/auth_configs.py index fbbd81ee..4d7ae9bc 100644 --- a/veadk/configs/auth_configs.py +++ b/veadk/configs/auth_configs.py @@ -12,8 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import TYPE_CHECKING, Optional, Any from pydantic_settings import BaseSettings, SettingsConfigDict +if TYPE_CHECKING: + from veadk.integrations.ve_identity.identity_client import IdentityClient +else: + # For runtime, use Any to avoid circular import issues + IdentityClient = Any + class VeIdentityConfig(BaseSettings): """Configuration for VolcEngine Identity Service. @@ -21,6 +28,11 @@ class VeIdentityConfig(BaseSettings): This configuration class manages settings for Agent Identity service, including region and endpoint information. + It also provides a global singleton IdentityClient instance to ensure: + - Credential caching is shared across the application + - HTTP connection pooling is reused + - Consistent configuration throughout the application + Attributes: region: The VolcEngine region for Identity service. endpoint: The endpoint URL for Identity service API. @@ -35,7 +47,7 @@ class VeIdentityConfig(BaseSettings): endpoint: str = "" """The endpoint URL for Identity service API. - + If not provided, the endpoint will be auto-generated based on the region. """ @@ -43,6 +55,9 @@ class VeIdentityConfig(BaseSettings): """Role session name, used to distinguish different sessions in audit logs. """ + # Global singleton IdentityClient instance + _identity_client: Optional["IdentityClient"] = None + def get_endpoint(self) -> str: """Get the endpoint URL for Identity service. @@ -59,3 +74,60 @@ def get_endpoint(self) -> str: return self.endpoint return f"id.{self.region}.volces.com" + + def get_identity_client( + self, + access_key: Optional[str] = None, + secret_key: Optional[str] = None, + session_token: Optional[str] = None, + ) -> IdentityClient: + """Get or create the global IdentityClient instance. + + This method implements a singleton pattern to ensure only one IdentityClient + instance is created per configuration. This allows: + - Credential caching to be shared across the application + - HTTP connection pooling to be reused + - Consistent configuration throughout the application + + Args: + access_key: Optional VolcEngine access key. If not provided, uses env vars. + secret_key: Optional VolcEngine secret key. If not provided, uses env vars. + session_token: Optional VolcEngine session token. If not provided, uses env vars. + + Returns: + The global IdentityClient instance. + + Examples: + ```python + from veadk.config import settings + + # Get the global identity client + identity_client = settings.veidentity.get_identity_client() + + # Use it to get workload tokens + token = identity_client.get_workload_access_token( + workload_name="my_workload", + user_id="user123" + ) + ``` + """ + # Lazy initialization: create client only when first requested + if self._identity_client is None: + from veadk.integrations.ve_identity import IdentityClient + + self._identity_client = IdentityClient( + access_key=access_key, + secret_key=secret_key, + session_token=session_token, + region=self.region, + ) + + return self._identity_client + + def reset_identity_client(self) -> None: + """Reset the global IdentityClient instance. + + This forces the next call to get_identity_client() to create a new instance. + Useful for testing or when credentials need to be refreshed. + """ + self._identity_client = None diff --git a/veadk/integrations/ve_identity/__init__.py b/veadk/integrations/ve_identity/__init__.py index 3a971f85..4f441318 100644 --- a/veadk/integrations/ve_identity/__init__.py +++ b/veadk/integrations/ve_identity/__init__.py @@ -61,6 +61,7 @@ async def get_github_repos(access_token: str): OAuth2AuthConfig, WorkloadAuthConfig, VeIdentityAuthConfig, + get_default_identity_client, ) from veadk.integrations.ve_identity.function_tool import VeIdentityFunctionTool from veadk.integrations.ve_identity.mcp_tool import VeIdentityMcpTool @@ -104,4 +105,6 @@ async def get_github_repos(access_token: str): "WorkloadToken", "OAuth2AuthPoller", "MockOauth2AuthPoller", + # Utils + "get_default_identity_client", ] diff --git a/veadk/integrations/ve_identity/auth_config.py b/veadk/integrations/ve_identity/auth_config.py index 6ce68acc..1d14d7a6 100644 --- a/veadk/integrations/ve_identity/auth_config.py +++ b/veadk/integrations/ve_identity/auth_config.py @@ -42,6 +42,21 @@ def _get_default_region() -> str: return "cn-beijing" +def get_default_identity_client(region: Optional[str] = None) -> IdentityClient: + """Get the default IdentityClient from VeADK configuration. + + Returns: + The configured IdentityClient from VeIdentityConfig, or a new instance as fallback. + """ + try: + from veadk.config import settings + + return settings.veidentity.get_identity_client() + except Exception: + # Fallback to new instance if config loading fails + return IdentityClient(region=region or _get_default_region()) + + class AuthConfig(BaseModel, ABC): """Base authentication configuration.""" diff --git a/veadk/integrations/ve_identity/auth_mixins.py b/veadk/integrations/ve_identity/auth_mixins.py index ca8d3331..3e5095cb 100644 --- a/veadk/integrations/ve_identity/auth_mixins.py +++ b/veadk/integrations/ve_identity/auth_mixins.py @@ -44,7 +44,7 @@ ApiKeyAuthConfig, OAuth2AuthConfig, WorkloadAuthConfig, - _get_default_region, + get_default_identity_client, ) from veadk.integrations.ve_identity.token_manager import get_workload_token @@ -108,11 +108,7 @@ def __init__( # call it without arguments super().__init__() - # Use provided region or get from config - if region is None: - region = _get_default_region() - - self._identity_client = identity_client or IdentityClient(region=region) + self._identity_client = identity_client or get_default_identity_client(region) self._provider_name = provider_name @abstractmethod diff --git a/veadk/integrations/ve_identity/auth_processor.py b/veadk/integrations/ve_identity/auth_processor.py index 096a10d2..5f981b68 100644 --- a/veadk/integrations/ve_identity/auth_processor.py +++ b/veadk/integrations/ve_identity/auth_processor.py @@ -25,9 +25,8 @@ from google.genai import types from google.adk.auth.auth_credential import OAuth2Auth -from veadk.integrations.ve_identity.auth_config import _get_default_region +from veadk.integrations.ve_identity.auth_config import get_default_identity_client from veadk.processors.base_run_processor import BaseRunProcessor -from veadk.integrations.ve_identity.identity_client import IdentityClient from veadk.integrations.ve_identity.models import AuthRequestConfig, OAuth2AuthPoller from veadk.integrations.ve_identity.utils import ( get_function_call_auth_config, @@ -179,12 +178,10 @@ def __init__(self, *, config: Optional[AuthRequestConfig] = None): f"Please open this URL in your browser to authorize: {url}" ) ) - # Use provided region or get from config - if self.config.region is None: - self.config.region = _get_default_region() - self._identity_client = self.config.identity_client or IdentityClient( - region=self.config.region + self._identity_client = ( + self.config.identity_client + or get_default_identity_client(self.config.region) ) async def process_auth_request( diff --git a/veadk/integrations/ve_identity/identity_client.py b/veadk/integrations/ve_identity/identity_client.py index cd9d8e2d..16c48009 100644 --- a/veadk/integrations/ve_identity/identity_client.py +++ b/veadk/integrations/ve_identity/identity_client.py @@ -181,6 +181,7 @@ def __init__( configuration.ak = self._initial_access_key configuration.sk = self._initial_secret_key configuration.session_token = self._initial_session_token + configuration.logger = {} self._api_client = volcenginesdkid.IDApi( volcenginesdkcore.ApiClient(configuration) @@ -260,6 +261,7 @@ def _assume_role( sts_config.region = self.region sts_config.ak = access_key sts_config.sk = secret_key + sts_config.logger = {} # Create an STS API client sts_client = volcenginesdksts.STSApi(volcenginesdkcore.ApiClient(sts_config)) @@ -712,14 +714,16 @@ def check_permission( principal: Dict[str, str], operation: Dict[str, str], resource: Dict[str, str], + original_callers: Optional[List[Dict[str, str]]] = None, namespace: str = "default", ) -> bool: """Check if the principal has permission to perform the operation on the resource. Args: - principal: Principal information, e.g., {"Type": "User", "Id": "user123"} - operation: Operation to check, e.g., {"Type": "Action", "Id": "invoke"} - resource: Resource information, e.g., {"Type": "Agent", "Id": "agent456"} + principal: Principal information, e.g., {"Type": "user", "Id": "user123"} + operation: Operation to check, e.g., {"Type": "action", "Id": "invoke"} + resource: Resource information, e.g., {"Type": "agent", "Id": "agent456"} + original_callers: Optional list of original callers. namespace: Namespace of the resource. Defaults to "default". Returns: @@ -738,6 +742,7 @@ def check_permission( operation=operation, principal=principal, resource=resource, + original_callers=original_callers, ) response: volcenginesdkid.CheckPermissionResponse = ( diff --git a/veadk/integrations/ve_identity/token_manager.py b/veadk/integrations/ve_identity/token_manager.py index 3d719932..6c54d406 100644 --- a/veadk/integrations/ve_identity/token_manager.py +++ b/veadk/integrations/ve_identity/token_manager.py @@ -20,9 +20,12 @@ from typing import Optional, Union from google.adk.tools.tool_context import ToolContext +from google.adk.agents.callback_context import CallbackContext from google.adk.agents.readonly_context import ReadonlyContext -from veadk.integrations.ve_identity.auth_config import _get_default_region +from veadk.integrations.ve_identity.auth_config import ( + get_default_identity_client, +) from veadk.utils.logger import get_logger from veadk.integrations.ve_identity.identity_client import IdentityClient @@ -49,21 +52,25 @@ class WorkloadTokenManager: def __init__( self, - identity_client: IdentityClient = None, + identity_client: Optional[IdentityClient] = None, region: Optional[str] = None, ): """Initialize the token manager. Args: - identity_client: The IdentityClient instance to use for token requests. + identity_client: Optional IdentityClient instance to use for token requests. + If not provided and use_global_client is True, uses the global client + from VeIdentityConfig. + region: Optional region for creating a new IdentityClient. + Only used if identity_client is not provided and use_global_client is False. """ - if region is None: - region = _get_default_region() - self._identity_client = identity_client or IdentityClient(region=region) + self._identity_client = identity_client or get_default_identity_client( + region=region + ) def _build_cache_key( - self, tool_context: Union[ToolContext | ReadonlyContext] + self, tool_context: Union[ToolContext | CallbackContext | ReadonlyContext] ) -> str: """Build a unique cache key for storing the workload token. @@ -93,7 +100,7 @@ def _is_token_expired(self, expires_at: int) -> bool: async def get_workload_token( self, - tool_context: Union[ToolContext | ReadonlyContext], + tool_context: Union[ToolContext | CallbackContext | ReadonlyContext], workload_name: Optional[str] = None, user_token: Optional[str] = None, ) -> str: @@ -148,7 +155,7 @@ async def get_workload_token( async def get_workload_token( - tool_context: Union[ToolContext | ReadonlyContext], + tool_context: Union[ToolContext | CallbackContext | ReadonlyContext], identity_client: Optional[IdentityClient] = None, workload_name: Optional[str] = None, user_token: Optional[str] = None, diff --git a/veadk/runner.py b/veadk/runner.py index 369011dd..de3d03a0 100644 --- a/veadk/runner.py +++ b/veadk/runner.py @@ -345,7 +345,7 @@ def __init__( self.long_term_memory = None self.short_term_memory = short_term_memory self.upload_inline_data_to_tos = upload_inline_data_to_tos - + credential_service = kwargs.pop("credential_service", None) session_service = kwargs.pop("session_service", None) memory_service = kwargs.pop("memory_service", None) @@ -397,6 +397,7 @@ def __init__( agent=agent, session_service=session_service, memory_service=memory_service, + credential_service=credential_service, app_name=app_name, *args, **kwargs, diff --git a/veadk/tools/builtin_tools/agent_authorization.py b/veadk/tools/builtin_tools/agent_authorization.py index 850e7afa..dc5db6c5 100644 --- a/veadk/tools/builtin_tools/agent_authorization.py +++ b/veadk/tools/builtin_tools/agent_authorization.py @@ -12,104 +12,78 @@ # See the License for the specific language governing permissions and # limitations under the License. -import base64 -import json from typing import Optional from google.genai import types from google.adk.agents.callback_context import CallbackContext -from veadk.integrations.ve_identity.auth_config import _get_default_region -from veadk.integrations.ve_identity.identity_client import IdentityClient -from veadk.integrations.ve_identity.token_manager import get_workload_token +from veadk.integrations.ve_identity import ( + get_default_identity_client, + get_workload_token, +) from veadk.utils.logger import get_logger +from veadk.utils.auth import extract_delegation_chain_from_jwt logger = get_logger(__name__) - -region = _get_default_region() -identity_client = IdentityClient(region=region) - - -def _strip_bearer_prefix(token: str) -> str: - """Remove 'Bearer ' prefix from token if present. - Args: - token: Token string that may contain "Bearer " prefix - Returns: - Token without "Bearer " prefix - """ - return token[7:] if token.lower().startswith("bearer ") else token - - -def _extract_role_id_from_jwt(token: str) -> Optional[str]: - """Extract role_id (sub field) from JWT token. - Args: - token: JWT token string (with or without "Bearer " prefix) - Returns: - Role ID from sub field, or None if parsing fails - """ - try: - # Remove "Bearer " prefix if present - token = _strip_bearer_prefix(token) - - # JWT token has 3 parts separated by dots: header.payload.signature - parts = token.split(".") - if len(parts) != 3: - logger.error("Invalid JWT format: expected 3 parts") - return None - - # Decode payload (second part) - payload_part = parts[1] - - # Add padding for base64url decoding (JWT doesn't use padding) - missing_padding = len(payload_part) % 4 - if missing_padding: - payload_part += "=" * (4 - missing_padding) - - # Decode base64 and parse JSON - decoded_bytes = base64.urlsafe_b64decode(payload_part) - payload = json.loads(decoded_bytes.decode("utf-8")) - - # Extract sub field as role_id - return payload.get("act").get("sub") - - except (ValueError, json.JSONDecodeError) as e: - logger.error(f"Failed to parse JWT token: {e}") - return None - except Exception as e: - logger.error(f"Unexpected error parsing JWT: {e}") - return None +identity_client = get_default_identity_client() async def check_agent_authorization( callback_context: CallbackContext, ) -> Optional[types.Content]: - """Check if the agent is authorized to run using VeIdentity.""" - user_id = callback_context._invocation_context.user_id - + """Check if the agent is authorized to run using Agent Identity.""" try: workload_token = await get_workload_token( tool_context=callback_context, identity_client=identity_client ) - # Parse role_id from workload_token - role_id = _extract_role_id_from_jwt(workload_token) + # Parse user_id and actors from workload_token + user_id, actors = extract_delegation_chain_from_jwt(workload_token) - principal = {"Type": "User", "Id": user_id} - operation = {"Type": "Action", "Id": "invoke"} - resource = {"Type": "Agent", "Id": role_id} + if not user_id: + logger.warning("Failed to extract user_id from JWT token") + return types.Content( + parts=[types.Part(text="Failed to verify agent authorization.")], + role="model", + ) + + if len(actors) == 0: + logger.warning("Failed to extract actors from JWT token") + return types.Content( + parts=[types.Part(text="Failed to verify agent authorization.")], + role="model", + ) + + # The first actor in the chain is the agent itself + role_id = actors[0] + + principal = {"Type": "user", "Id": user_id} + operation = {"Type": "action", "Id": "invoke"} + resource = {"Type": "agent", "Id": role_id} + original_callers = [{"Type": "agent", "Id": actor} for actor in actors[1:]] allowed = identity_client.check_permission( - principal=principal, operation=operation, resource=resource + principal=principal, + operation=operation, + resource=resource, + original_callers=original_callers, ) if allowed: - logger.info("Agent is authorized to run.") + logger.info(f"Agent {role_id} is authorized to run by user {user_id}.") return None else: - logger.warning("Agent is not authorized to run.") + logger.warning( + f"Agent {role_id} is not authorized to run by user {user_id}." + ) return types.Content( - parts=[types.Part(text="Agent is not authorized to run.")], role="model" + parts=[ + types.Part( + text=f"Agent {role_id} is not authorized to run by user {user_id}." + ) + ], + role="model", ) except Exception as e: diff --git a/veadk/utils/auth.py b/veadk/utils/auth.py new file mode 100644 index 00000000..565a26b7 --- /dev/null +++ b/veadk/utils/auth.py @@ -0,0 +1,294 @@ +# Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Authentication utilities for VeADK. + +This module provides utilities for authentication, including: +- JWT token parsing and delegation chain extraction +- AuthConfig building for various authentication methods +- TIP (Trust Identity Propagation) token management +- Workload token caching for middleware +""" + +import base64 +import json +from typing import Literal, Optional + +from google.adk.auth.auth_credential import ( + AuthCredential, + AuthCredentialTypes, + HttpAuth, + HttpCredentials, +) +from fastapi.openapi.models import APIKey, HTTPBearer, APIKeyIn +from google.adk.auth.auth_tool import AuthConfig + +# TIP Token Header is used for Trust Identity Propagation (TIP) tokens +VE_TIP_TOKEN_HEADER = "X-Ve-TIP-Token" +VE_TIP_TOKEN_CREDENTIAL_KEY = "ve_tip_token" + + +def strip_bearer_prefix(token: str) -> str: + """Remove 'Bearer ' prefix from token if present. + + Args: + token: Token string that may contain "Bearer " prefix + + Returns: + Token without "Bearer " prefix + """ + return token[7:] if token.lower().startswith("bearer ") else token + + +def extract_delegation_chain_from_jwt(token: str) -> tuple[Optional[str], list[str]]: + """Extract subject and delegation chain from JWT token. + + Parses JWT tokens containing delegation information per RFC 8693. + Returns the primary subject and the chain of actors who acted on behalf. + + Args: + token: JWT token string (with or without "Bearer " prefix) + + Returns: + A tuple of (subject, actors) where: + - subject: The end user or resource owner (from `sub` field) + - actors: Chain of intermediaries who acted on behalf (from nested `act` claims) + + Examples: + ```python + # User → Agent1 → Agent2 + subject, actors = extract_delegation_chain_from_jwt(token) + # Returns: ("user1", ["agent2", "agent1"]) + # Meaning: user1 delegated to agent1, who delegated to agent2 + ``` + """ + try: + # Remove "Bearer " prefix if present + token = strip_bearer_prefix(token) + + # JWT token has 3 parts separated by dots: header.payload.signature + parts = token.split(".") + if len(parts) != 3: + return None, [] + + # Decode payload (second part) + payload_part = parts[1] + + # Add padding for base64url decoding (JWT doesn't use padding) + missing_padding = len(payload_part) % 4 + if missing_padding: + payload_part += "=" * (4 - missing_padding) + + # Decode base64 and parse JSON + decoded_bytes = base64.urlsafe_b64decode(payload_part) + payload: dict = json.loads(decoded_bytes.decode("utf-8")) + + # Extract subject from JWT + subject = payload.get("sub") + if not subject: + return None, [] + + # Extract actor chain from nested "act" claims + actors = [] + current_act = payload.get("act") + while current_act and isinstance(current_act, dict): + actor_sub = current_act.get("sub") + if actor_sub: + actors.append(str(actor_sub)) + # Move to next level in the chain + current_act = current_act.get("act") + + return str(subject), actors + + except (ValueError, json.JSONDecodeError, Exception): + return None, [] + + +def _build_http_bearer_auth( + token: Optional[str], scheme: str +) -> tuple[HTTPBearer, Optional[AuthCredential]]: + """Build HTTP Bearer authentication scheme and credential. + + Args: + token: The authentication token + scheme: HTTP authentication scheme (e.g., "bearer", "basic") + + Returns: + Tuple of (auth_scheme, auth_credential) + """ + auth_scheme = HTTPBearer() + + auth_credential = None + if token: + auth_credential = AuthCredential( + auth_type=AuthCredentialTypes.HTTP, + http=HttpAuth( + scheme=scheme.lower(), + credentials=HttpCredentials(token=token), + ), + ) + + return auth_scheme, auth_credential + + +def _build_api_key_header_auth( + token: Optional[str], header_name: str +) -> tuple[APIKey, Optional[AuthCredential]]: + """Build API Key in header authentication scheme and credential. + + Args: + token: The authentication token + header_name: Name of the HTTP header + + Returns: + Tuple of (auth_scheme, auth_credential) + """ + auth_scheme = APIKey(**{"in": APIKeyIn.header, "name": header_name}) + + auth_credential = None + if token: + auth_credential = AuthCredential( + auth_type=AuthCredentialTypes.API_KEY, + api_key=token, + ) + + return auth_scheme, auth_credential + + +def _build_api_key_query_auth( + token: Optional[str], query_param_name: str +) -> tuple[APIKey, Optional[AuthCredential]]: + """Build API Key in query string authentication scheme and credential. + + Args: + token: The authentication token + query_param_name: Name of the query parameter + + Returns: + Tuple of (auth_scheme, auth_credential) + """ + auth_scheme = APIKey(**{"in": APIKeyIn.query, "name": query_param_name}) + + auth_credential = None + if token: + auth_credential = AuthCredential( + auth_type=AuthCredentialTypes.API_KEY, + api_key=token, + ) + + return auth_scheme, auth_credential + + +def build_auth_config( + *, + credential_key: str, + token: Optional[str] = None, + auth_method: Literal["header", "querystring", "bearer", "apikey"] = "header", + # Header-specific options + header_name: str = "Authorization", + header_scheme: Optional[str] = None, + # Query string-specific options + query_param_name: str = "token", +) -> AuthConfig: + """Build AuthConfig for various authentication methods. + + This is a general-purpose utility function for constructing AuthConfig objects + that can be used with ADK's credential service. It supports multiple authentication + methods and provides flexible configuration options. + + Args: + token: The authentication token/credential value. If None, only the auth scheme + will be configured without credentials. + auth_method: The authentication method to use: + - "header": Generic header-based authentication (API Key in header) + - "bearer": HTTP Bearer token authentication (Authorization: Bearer ) + - "querystring" or "apikey": API Key in query string parameter + credential_key: Key to identify this credential in the credential service. + Default is "inbound_auth". + header_name: Name of the HTTP header for header-based auth. Default is "Authorization". + header_scheme: HTTP authentication scheme (e.g., "bearer", "basic"). If provided, + uses HTTP auth; otherwise uses API Key auth for headers. + query_param_name: Name of the query parameter for query string auth. Default is "token". + + Returns: + AuthConfig object with the appropriate auth scheme and credential. + + Raises: + ValueError: If an unsupported auth_method is provided. + + Examples: + ```python + # Example 1: HTTP Bearer token + config = build_auth_config( + token="eyJhbGc...", + auth_method="bearer", + credential_key="my_auth" + ) + + # Example 2: API Key in header + config = build_auth_config( + token="sk-1234567890", + auth_method="header", + header_name="X-API-Key", + credential_key="api_key_auth" + ) + + # Example 3: API Key in query string + config = build_auth_config( + token="abc123", + auth_method="querystring", + query_param_name="api_key", + credential_key="query_auth" + ) + + # Example 4: Only auth scheme (no credential) + config = build_auth_config( + auth_method="bearer", + credential_key="bearer_auth" + ) + # Returns AuthConfig with scheme but no credential + ``` + """ + # Determine which builder function to use based on auth_method + if auth_method == "bearer": + # Bearer is a special case of HTTP auth with bearer scheme + auth_scheme, auth_credential = _build_http_bearer_auth(token, "bearer") + + elif auth_method == "header": + if header_scheme: + # HTTP authentication (e.g., Bearer, Basic) + auth_scheme, auth_credential = _build_http_bearer_auth(token, header_scheme) + else: + # API Key in header + auth_scheme, auth_credential = _build_api_key_header_auth( + token, header_name + ) + + elif auth_method in ("querystring", "apikey"): + # API Key in query string + auth_scheme, auth_credential = _build_api_key_query_auth( + token, query_param_name + ) + + else: + raise ValueError( + f"Unsupported auth_method: {auth_method}. " + f"Supported methods: 'header', 'bearer', 'querystring', 'apikey'" + ) + + return AuthConfig( + auth_scheme=auth_scheme, + exchanged_auth_credential=auth_credential, + credential_key=credential_key, + )