1515"""Unit tests for A2A authentication middleware."""
1616
1717import pytest
18- from unittest .mock import Mock , AsyncMock , patch
18+ from unittest .mock import Mock , patch
1919from starlette .applications import Starlette
2020from starlette .requests import Request
2121from starlette .responses import Response
22- from starlette .testclient import TestClient
2322
2423from veadk .a2a .ve_middlewares import A2AAuthMiddleware , build_a2a_auth_middleware
2524from veadk .auth .credential_service import VeCredentialService
@@ -75,17 +74,19 @@ def test_middleware_initialization(self, credential_service, mock_identity_clien
7574 def test_middleware_default_identity_client (self , credential_service ):
7675 """Test middleware uses global identity client when not provided."""
7776 app = Starlette ()
78-
79- with patch ("veadk.a2a.ve_a2a_middlewares.get_default_identity_client" ) as mock_get_client :
77+
78+ with patch (
79+ "veadk.a2a.ve_middlewares.get_default_identity_client"
80+ ) as mock_get_client :
8081 mock_client = Mock ()
8182 mock_get_client .return_value = mock_client
82-
83+
8384 middleware = A2AAuthMiddleware (
8485 app = app ,
8586 app_name = "test_app" ,
8687 credential_service = credential_service ,
8788 )
88-
89+
8990 mock_get_client .assert_called_once ()
9091 assert middleware .identity_client == mock_client
9192
@@ -104,7 +105,7 @@ def test_extract_token_from_header_with_bearer(self, credential_service):
104105 mock_request .headers = {"Authorization" : "Bearer test_token_123" }
105106
106107 token , has_prefix = middleware ._extract_token (mock_request )
107-
108+
108109 assert token == "test_token_123"
109110 assert has_prefix is True
110111
@@ -122,7 +123,7 @@ def test_extract_token_from_header_without_bearer(self, credential_service):
122123 mock_request .headers = {"Authorization" : "test_token_123" }
123124
124125 token , has_prefix = middleware ._extract_token (mock_request )
125-
126+
126127 assert token == "test_token_123"
127128 assert has_prefix is False
128129
@@ -174,7 +175,9 @@ async def test_dispatch_with_valid_jwt_token(
174175 mock_workload_token = Mock ()
175176 mock_workload_token .workload_access_token = "workload_token_123"
176177 mock_workload_token .expires_at = 1234567890
177- mock_identity_client .get_workload_access_token .return_value = mock_workload_token
178+ mock_identity_client .get_workload_access_token .return_value = (
179+ mock_workload_token
180+ )
178181
179182 middleware = A2AAuthMiddleware (
180183 app = app ,
@@ -196,10 +199,14 @@ async def mock_call_next(request):
196199 return Response ("OK" , status_code = 200 )
197200
198201 # Execute dispatch
199- with patch ("veadk.a2a.ve_a2a_middlewares.extract_delegation_chain_from_jwt" ) as mock_extract :
202+ with patch (
203+ "veadk.a2a.ve_middlewares.extract_delegation_chain_from_jwt"
204+ ) as mock_extract :
200205 mock_extract .return_value = ("user123" , ["agent1" ])
201206
202- with patch ("veadk.a2a.ve_a2a_middlewares.build_auth_config" ) as mock_build_config :
207+ with patch (
208+ "veadk.a2a.ve_middlewares.build_auth_config"
209+ ) as mock_build_config :
203210 mock_auth_config = Mock ()
204211 mock_auth_config .exchanged_auth_credential = Mock ()
205212 mock_build_config .return_value = mock_auth_config
@@ -228,7 +235,9 @@ async def test_dispatch_with_tip_token(
228235 mock_workload_token = Mock ()
229236 mock_workload_token .workload_access_token = "workload_token_from_tip"
230237 mock_workload_token .expires_at = 1234567890
231- mock_identity_client .get_workload_access_token .return_value = mock_workload_token
238+ mock_identity_client .get_workload_access_token .return_value = (
239+ mock_workload_token
240+ )
232241
233242 middleware = A2AAuthMiddleware (
234243 app = app ,
@@ -252,10 +261,14 @@ async def mock_call_next(request):
252261 return Response ("OK" , status_code = 200 )
253262
254263 # Execute dispatch
255- with patch ("veadk.a2a.ve_a2a_middlewares.extract_delegation_chain_from_jwt" ) as mock_extract :
264+ with patch (
265+ "veadk.a2a.ve_middlewares.extract_delegation_chain_from_jwt"
266+ ) as mock_extract :
256267 mock_extract .return_value = ("user123" , ["agent1" ])
257268
258- with patch ("veadk.a2a.ve_a2a_middlewares.build_auth_config" ) as mock_build_config :
269+ with patch (
270+ "veadk.a2a.ve_middlewares.build_auth_config"
271+ ) as mock_build_config :
259272 mock_auth_config = Mock ()
260273 mock_auth_config .exchanged_auth_credential = Mock ()
261274 mock_build_config .return_value = mock_auth_config
@@ -264,8 +277,7 @@ async def mock_call_next(request):
264277
265278 # Verify TIP token was used for workload token exchange
266279 mock_identity_client .get_workload_access_token .assert_called_once_with (
267- user_token = tip_token ,
268- user_id = "user123"
280+ user_token = tip_token , user_id = "user123"
269281 )
270282
271283 # Verify workload token was set
@@ -285,7 +297,9 @@ def test_build_middleware_basic(self, credential_service):
285297 assert middleware_class is not None
286298 assert issubclass (middleware_class , A2AAuthMiddleware )
287299
288- def test_build_middleware_with_all_params (self , credential_service , mock_identity_client ):
300+ def test_build_middleware_with_all_params (
301+ self , credential_service , mock_identity_client
302+ ):
289303 """Test building middleware with all parameters."""
290304 middleware_class = build_a2a_auth_middleware (
291305 app_name = "test_app" ,
@@ -305,4 +319,3 @@ def test_build_middleware_with_all_params(self, credential_service, mock_identit
305319 assert instance .token_param == "access_token"
306320 assert instance .credential_key == "custom_key"
307321 assert instance .identity_client == mock_identity_client
308-
0 commit comments