Skip to content

Commit f68e2d0

Browse files
committed
feat(identity): add unit test for middleware and credential service
1 parent 5fd0c92 commit f68e2d0

File tree

3 files changed

+683
-0
lines changed

3 files changed

+683
-0
lines changed

tests/a2a/__init__.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
# Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""A2A tests package."""
16+
Lines changed: 308 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,308 @@
1+
# Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Unit tests for A2A authentication middleware."""
16+
17+
import pytest
18+
from unittest.mock import Mock, AsyncMock, patch
19+
from starlette.applications import Starlette
20+
from starlette.requests import Request
21+
from starlette.responses import Response
22+
from starlette.testclient import TestClient
23+
24+
from veadk.a2a.ve_middlewares import A2AAuthMiddleware, build_a2a_auth_middleware
25+
from veadk.auth.credential_service import VeCredentialService
26+
from veadk.utils.auth import VE_TIP_TOKEN_HEADER
27+
28+
29+
@pytest.fixture
30+
def credential_service():
31+
"""Create a VeCredentialService instance for testing."""
32+
return VeCredentialService()
33+
34+
35+
@pytest.fixture
36+
def mock_identity_client():
37+
"""Create a mock IdentityClient."""
38+
mock_client = Mock()
39+
mock_client.get_workload_access_token = Mock()
40+
return mock_client
41+
42+
43+
@pytest.fixture
44+
def sample_jwt_token():
45+
"""Sample JWT token for testing."""
46+
# This is a sample JWT with sub="user123" and act claim
47+
# Header: {"alg": "HS256", "typ": "JWT"}
48+
# Payload: {"sub": "user123", "act": {"sub": "agent1"}}
49+
return "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiJ1c2VyMTIzIiwiYWN0Ijp7InN1YiI6ImFnZW50MSJ9fQ.signature"
50+
51+
52+
class TestA2AAuthMiddleware:
53+
"""Tests for A2AAuthMiddleware class."""
54+
55+
def test_middleware_initialization(self, credential_service, mock_identity_client):
56+
"""Test middleware initialization with all parameters."""
57+
app = Starlette()
58+
middleware = A2AAuthMiddleware(
59+
app=app,
60+
app_name="test_app",
61+
credential_service=credential_service,
62+
auth_method="header",
63+
token_param="token",
64+
credential_key="test_key",
65+
identity_client=mock_identity_client,
66+
)
67+
68+
assert middleware.app_name == "test_app"
69+
assert middleware.credential_service == credential_service
70+
assert middleware.auth_method == "header"
71+
assert middleware.token_param == "token"
72+
assert middleware.credential_key == "test_key"
73+
assert middleware.identity_client == mock_identity_client
74+
75+
def test_middleware_default_identity_client(self, credential_service):
76+
"""Test middleware uses global identity client when not provided."""
77+
app = Starlette()
78+
79+
with patch("veadk.a2a.ve_a2a_middlewares.get_default_identity_client") as mock_get_client:
80+
mock_client = Mock()
81+
mock_get_client.return_value = mock_client
82+
83+
middleware = A2AAuthMiddleware(
84+
app=app,
85+
app_name="test_app",
86+
credential_service=credential_service,
87+
)
88+
89+
mock_get_client.assert_called_once()
90+
assert middleware.identity_client == mock_client
91+
92+
def test_extract_token_from_header_with_bearer(self, credential_service):
93+
"""Test extracting token from Authorization header with Bearer prefix."""
94+
app = Starlette()
95+
middleware = A2AAuthMiddleware(
96+
app=app,
97+
app_name="test_app",
98+
credential_service=credential_service,
99+
auth_method="header",
100+
)
101+
102+
# Create mock request
103+
mock_request = Mock(spec=Request)
104+
mock_request.headers = {"Authorization": "Bearer test_token_123"}
105+
106+
token, has_prefix = middleware._extract_token(mock_request)
107+
108+
assert token == "test_token_123"
109+
assert has_prefix is True
110+
111+
def test_extract_token_from_header_without_bearer(self, credential_service):
112+
"""Test extracting token from Authorization header without Bearer prefix."""
113+
app = Starlette()
114+
middleware = A2AAuthMiddleware(
115+
app=app,
116+
app_name="test_app",
117+
credential_service=credential_service,
118+
auth_method="header",
119+
)
120+
121+
mock_request = Mock(spec=Request)
122+
mock_request.headers = {"Authorization": "test_token_123"}
123+
124+
token, has_prefix = middleware._extract_token(mock_request)
125+
126+
assert token == "test_token_123"
127+
assert has_prefix is False
128+
129+
def test_extract_token_from_query_string(self, credential_service):
130+
"""Test extracting token from query string."""
131+
app = Starlette()
132+
middleware = A2AAuthMiddleware(
133+
app=app,
134+
app_name="test_app",
135+
credential_service=credential_service,
136+
auth_method="querystring",
137+
token_param="access_token",
138+
)
139+
140+
mock_request = Mock(spec=Request)
141+
mock_request.query_params = {"access_token": "test_token_123"}
142+
143+
token, has_prefix = middleware._extract_token(mock_request)
144+
145+
assert token == "test_token_123"
146+
assert has_prefix is False
147+
148+
def test_extract_token_no_token_found(self, credential_service):
149+
"""Test extracting token when no token is present."""
150+
app = Starlette()
151+
middleware = A2AAuthMiddleware(
152+
app=app,
153+
app_name="test_app",
154+
credential_service=credential_service,
155+
auth_method="header",
156+
)
157+
158+
mock_request = Mock(spec=Request)
159+
mock_request.headers = {}
160+
161+
token, has_prefix = middleware._extract_token(mock_request)
162+
163+
assert token is None
164+
assert has_prefix is False
165+
166+
@pytest.mark.asyncio
167+
async def test_dispatch_with_valid_jwt_token(
168+
self, credential_service, mock_identity_client, sample_jwt_token
169+
):
170+
"""Test dispatch with valid JWT token."""
171+
app = Starlette()
172+
173+
# Mock WorkloadToken
174+
mock_workload_token = Mock()
175+
mock_workload_token.workload_access_token = "workload_token_123"
176+
mock_workload_token.expires_at = 1234567890
177+
mock_identity_client.get_workload_access_token.return_value = mock_workload_token
178+
179+
middleware = A2AAuthMiddleware(
180+
app=app,
181+
app_name="test_app",
182+
credential_service=credential_service,
183+
auth_method="header",
184+
identity_client=mock_identity_client,
185+
)
186+
187+
# Create mock request with JWT token
188+
mock_request = Mock(spec=Request)
189+
mock_request.headers = {
190+
"Authorization": f"Bearer {sample_jwt_token}",
191+
}
192+
mock_request.scope = {}
193+
194+
# Mock call_next
195+
async def mock_call_next(request):
196+
return Response("OK", status_code=200)
197+
198+
# Execute dispatch
199+
with patch("veadk.a2a.ve_a2a_middlewares.extract_delegation_chain_from_jwt") as mock_extract:
200+
mock_extract.return_value = ("user123", ["agent1"])
201+
202+
with patch("veadk.a2a.ve_a2a_middlewares.build_auth_config") as mock_build_config:
203+
mock_auth_config = Mock()
204+
mock_auth_config.exchanged_auth_credential = Mock()
205+
mock_build_config.return_value = mock_auth_config
206+
207+
response = await middleware.dispatch(mock_request, mock_call_next)
208+
209+
# Verify response
210+
assert response.status_code == 200
211+
212+
# Verify user was set in request scope
213+
assert "user" in mock_request.scope
214+
assert mock_request.scope["user"].username == "user123"
215+
216+
# Verify workload token was set
217+
assert "auth" in mock_request.scope
218+
assert mock_request.scope["auth"] == mock_workload_token
219+
220+
@pytest.mark.asyncio
221+
async def test_dispatch_with_tip_token(
222+
self, credential_service, mock_identity_client, sample_jwt_token
223+
):
224+
"""Test dispatch with TIP token in header."""
225+
app = Starlette()
226+
227+
# Mock WorkloadToken
228+
mock_workload_token = Mock()
229+
mock_workload_token.workload_access_token = "workload_token_from_tip"
230+
mock_workload_token.expires_at = 1234567890
231+
mock_identity_client.get_workload_access_token.return_value = mock_workload_token
232+
233+
middleware = A2AAuthMiddleware(
234+
app=app,
235+
app_name="test_app",
236+
credential_service=credential_service,
237+
auth_method="header",
238+
identity_client=mock_identity_client,
239+
)
240+
241+
# Create mock request with both JWT and TIP token
242+
tip_token = "tip_token_123"
243+
mock_request = Mock(spec=Request)
244+
mock_request.headers = {
245+
"Authorization": f"Bearer {sample_jwt_token}",
246+
VE_TIP_TOKEN_HEADER: tip_token,
247+
}
248+
mock_request.scope = {}
249+
250+
# Mock call_next
251+
async def mock_call_next(request):
252+
return Response("OK", status_code=200)
253+
254+
# Execute dispatch
255+
with patch("veadk.a2a.ve_a2a_middlewares.extract_delegation_chain_from_jwt") as mock_extract:
256+
mock_extract.return_value = ("user123", ["agent1"])
257+
258+
with patch("veadk.a2a.ve_a2a_middlewares.build_auth_config") as mock_build_config:
259+
mock_auth_config = Mock()
260+
mock_auth_config.exchanged_auth_credential = Mock()
261+
mock_build_config.return_value = mock_auth_config
262+
263+
response = await middleware.dispatch(mock_request, mock_call_next)
264+
265+
# Verify TIP token was used for workload token exchange
266+
mock_identity_client.get_workload_access_token.assert_called_once_with(
267+
user_token=tip_token,
268+
user_id="user123"
269+
)
270+
271+
# Verify workload token was set
272+
assert mock_request.scope["auth"] == mock_workload_token
273+
274+
275+
class TestBuildA2AAuthMiddleware:
276+
"""Tests for build_a2a_auth_middleware factory function."""
277+
278+
def test_build_middleware_basic(self, credential_service):
279+
"""Test building middleware with basic parameters."""
280+
middleware_class = build_a2a_auth_middleware(
281+
app_name="test_app",
282+
credential_service=credential_service,
283+
)
284+
285+
assert middleware_class is not None
286+
assert issubclass(middleware_class, A2AAuthMiddleware)
287+
288+
def test_build_middleware_with_all_params(self, credential_service, mock_identity_client):
289+
"""Test building middleware with all parameters."""
290+
middleware_class = build_a2a_auth_middleware(
291+
app_name="test_app",
292+
credential_service=credential_service,
293+
auth_method="querystring",
294+
token_param="access_token",
295+
credential_key="custom_key",
296+
identity_client=mock_identity_client,
297+
)
298+
299+
# Create instance to verify parameters
300+
app = Starlette()
301+
instance = middleware_class(app)
302+
303+
assert instance.app_name == "test_app"
304+
assert instance.auth_method == "querystring"
305+
assert instance.token_param == "access_token"
306+
assert instance.credential_key == "custom_key"
307+
assert instance.identity_client == mock_identity_client
308+

0 commit comments

Comments
 (0)