Skip to content

Commit 01fe346

Browse files
committed
Fix auth tests
1 parent c750097 commit 01fe346

File tree

3 files changed

+72
-54
lines changed

3 files changed

+72
-54
lines changed

agent_memory_server/auth.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818

1919
class UserInfo(BaseModel):
2020
sub: str
21-
aud: str | None = None
21+
aud: str | list[str] | None = None
2222
scope: str | None = None
2323
exp: int | None = None
2424
iat: int | None = None
@@ -128,7 +128,8 @@ def get_public_key(token: str) -> str:
128128
for key in keys:
129129
if key.get("kid") == kid:
130130
try:
131-
public_key = jwk.construct(key).to_pem()
131+
public_key_bytes = jwk.construct(key).to_pem()
132+
public_key = public_key_bytes.decode("utf-8")
132133
break
133134
except Exception as e:
134135
logger.error("Failed to construct public key", kid=kid, error=str(e))

tests/conftest.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -386,6 +386,15 @@ async def mock_get_redis_conn(*args, **kwargs):
386386
return app
387387

388388

389+
@pytest.fixture(autouse=True)
390+
def disable_auth_for_tests():
391+
"""Disable authentication for all tests"""
392+
original_value = settings.disable_auth
393+
settings.disable_auth = True
394+
yield
395+
settings.disable_auth = original_value
396+
397+
389398
@pytest.fixture()
390399
async def client(app):
391400
async with AsyncClient(

tests/test_auth.py

Lines changed: 60 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -90,21 +90,27 @@ def jwks_data():
9090
key_size=2048,
9191
)
9292

93-
# Convert to JWK format
94-
rsa_key = RSAKey(key=key, algorithm="RS256")
95-
jwk_dict = rsa_key.to_dict()
96-
jwk_dict["kid"] = "test-kid-123"
97-
jwk_dict["use"] = "sig"
98-
jwk_dict["alg"] = "RS256"
99-
93+
# Convert to PEM format first
10094
private_pem = key.private_bytes(
10195
encoding=serialization.Encoding.PEM,
10296
format=serialization.PrivateFormat.PKCS8,
10397
encryption_algorithm=serialization.NoEncryption(),
10498
)
10599

100+
# Convert to JWK format using PEM string
101+
rsa_key = RSAKey(key=private_pem.decode("utf-8"), algorithm="RS256")
102+
jwk_dict = rsa_key.to_dict()
103+
104+
# Remove private key components to make it a public-only JWK
105+
public_jwk_dict = {
106+
k: v for k, v in jwk_dict.items() if k not in ["d", "p", "q", "dp", "dq", "qi"]
107+
}
108+
public_jwk_dict["kid"] = "test-kid-123"
109+
public_jwk_dict["use"] = "sig"
110+
public_jwk_dict["alg"] = "RS256"
111+
106112
return {
107-
"keys": [jwk_dict],
113+
"keys": [public_jwk_dict],
108114
"private_key": private_pem.decode("utf-8"),
109115
"kid": "test-kid-123",
110116
}
@@ -181,9 +187,10 @@ async def test_jwks_cache_fetch_success(self, jwks_data):
181187
mock_response.json.return_value = {"keys": jwks_data["keys"]}
182188
mock_response.raise_for_status.return_value = None
183189

184-
mock_context = Mock()
185-
mock_context.__enter__.return_value.get.return_value = mock_response
186-
mock_client.return_value = mock_context
190+
mock_client.return_value.__enter__.return_value.get.return_value = (
191+
mock_response
192+
)
193+
mock_client.return_value.__exit__.return_value = None
187194

188195
result = cache.get_jwks(jwks_url)
189196

@@ -216,9 +223,10 @@ async def test_jwks_cache_refresh_expired(self, jwks_data):
216223
mock_response.json.return_value = {"keys": jwks_data["keys"]}
217224
mock_response.raise_for_status.return_value = None
218225

219-
mock_context = Mock()
220-
mock_context.__enter__.return_value.get.return_value = mock_response
221-
mock_client.return_value = mock_context
226+
mock_client.return_value.__enter__.return_value.get.return_value = (
227+
mock_response
228+
)
229+
mock_client.return_value.__exit__.return_value = None
222230

223231
result = cache.get_jwks("https://test-issuer.com/.well-known/jwks.json")
224232

@@ -232,11 +240,11 @@ async def test_jwks_cache_http_error(self):
232240
jwks_url = "https://test-issuer.com/.well-known/jwks.json"
233241

234242
with patch("httpx.Client") as mock_client:
235-
mock_context = Mock()
236-
mock_context.__enter__.return_value.get.side_effect = httpx.HTTPError(
237-
"Connection failed"
238-
)
239-
mock_client.return_value = mock_context
243+
mock_response = Mock()
244+
mock_response.get.side_effect = httpx.HTTPError("Connection failed")
245+
246+
mock_client.return_value.__enter__.return_value = mock_response
247+
mock_client.return_value.__exit__.return_value = None
240248

241249
with pytest.raises(HTTPException) as exc_info:
242250
cache.get_jwks(jwks_url)
@@ -407,8 +415,8 @@ def test_verify_jwt_success(self, mock_settings, jwks_data, valid_token):
407415
# Extract public key from JWKS data
408416
from jose.jwk import RSAKey
409417

410-
rsa_key = RSAKey(jwks_data["keys"][0])
411-
mock_get_key.return_value = rsa_key.to_pem()
418+
rsa_key = RSAKey(jwks_data["keys"][0], algorithm="RS256")
419+
mock_get_key.return_value = rsa_key.to_pem().decode("utf-8")
412420

413421
result = verify_jwt(valid_token)
414422

@@ -437,8 +445,8 @@ def test_verify_jwt_expired_token(self, mock_settings, jwks_data, expired_token)
437445
with patch("agent_memory_server.auth.get_public_key") as mock_get_key:
438446
from jose.jwk import RSAKey
439447

440-
rsa_key = RSAKey(jwks_data["keys"][0])
441-
mock_get_key.return_value = rsa_key.to_pem()
448+
rsa_key = RSAKey(jwks_data["keys"][0], algorithm="RS256")
449+
mock_get_key.return_value = rsa_key.to_pem().decode("utf-8")
442450

443451
with pytest.raises(HTTPException) as exc_info:
444452
verify_jwt(expired_token)
@@ -470,8 +478,8 @@ def test_verify_jwt_future_token(self, mock_settings, jwks_data):
470478
with patch("agent_memory_server.auth.get_public_key") as mock_get_key:
471479
from jose.jwk import RSAKey
472480

473-
rsa_key = RSAKey(jwks_data["keys"][0])
474-
mock_get_key.return_value = rsa_key.to_pem()
481+
rsa_key = RSAKey(jwks_data["keys"][0], algorithm="RS256")
482+
mock_get_key.return_value = rsa_key.to_pem().decode("utf-8")
475483

476484
with pytest.raises(HTTPException) as exc_info:
477485
verify_jwt(future_token)
@@ -503,8 +511,8 @@ def test_verify_jwt_wrong_audience(self, mock_settings, jwks_data):
503511
with patch("agent_memory_server.auth.get_public_key") as mock_get_key:
504512
from jose.jwk import RSAKey
505513

506-
rsa_key = RSAKey(jwks_data["keys"][0])
507-
mock_get_key.return_value = rsa_key.to_pem()
514+
rsa_key = RSAKey(jwks_data["keys"][0], algorithm="RS256")
515+
mock_get_key.return_value = rsa_key.to_pem().decode("utf-8")
508516

509517
with pytest.raises(HTTPException) as exc_info:
510518
verify_jwt(wrong_aud_token)
@@ -536,8 +544,8 @@ def test_verify_jwt_audience_list(self, mock_settings, jwks_data):
536544
with patch("agent_memory_server.auth.get_public_key") as mock_get_key:
537545
from jose.jwk import RSAKey
538546

539-
rsa_key = RSAKey(jwks_data["keys"][0])
540-
mock_get_key.return_value = rsa_key.to_pem()
547+
rsa_key = RSAKey(jwks_data["keys"][0], algorithm="RS256")
548+
mock_get_key.return_value = rsa_key.to_pem().decode("utf-8")
541549

542550
result = verify_jwt(list_aud_token)
543551
assert result.sub == "test-user-123"
@@ -565,8 +573,8 @@ def test_verify_jwt_missing_subject(self, mock_settings, jwks_data):
565573
with patch("agent_memory_server.auth.get_public_key") as mock_get_key:
566574
from jose.jwk import RSAKey
567575

568-
rsa_key = RSAKey(jwks_data["keys"][0])
569-
mock_get_key.return_value = rsa_key.to_pem()
576+
rsa_key = RSAKey(jwks_data["keys"][0], algorithm="RS256")
577+
mock_get_key.return_value = rsa_key.to_pem().decode("utf-8")
570578

571579
with pytest.raises(HTTPException) as exc_info:
572580
verify_jwt(no_sub_token)
@@ -599,8 +607,8 @@ def test_verify_jwt_scope_string_conversion(self, mock_settings, jwks_data):
599607
with patch("agent_memory_server.auth.get_public_key") as mock_get_key:
600608
from jose.jwk import RSAKey
601609

602-
rsa_key = RSAKey(jwks_data["keys"][0])
603-
mock_get_key.return_value = rsa_key.to_pem()
610+
rsa_key = RSAKey(jwks_data["keys"][0], algorithm="RS256")
611+
mock_get_key.return_value = rsa_key.to_pem().decode("utf-8")
604612

605613
result = verify_jwt(scope_str_token)
606614
assert result.scope is None
@@ -630,8 +638,8 @@ def test_verify_jwt_roles_string_conversion(self, mock_settings, jwks_data):
630638
with patch("agent_memory_server.auth.get_public_key") as mock_get_key:
631639
from jose.jwk import RSAKey
632640

633-
rsa_key = RSAKey(jwks_data["keys"][0])
634-
mock_get_key.return_value = rsa_key.to_pem()
641+
rsa_key = RSAKey(jwks_data["keys"][0], algorithm="RS256")
642+
mock_get_key.return_value = rsa_key.to_pem().decode("utf-8")
635643

636644
result = verify_jwt(roles_str_token)
637645
assert result.roles == ["admin"]
@@ -659,8 +667,8 @@ def test_verify_jwt_no_audience_validation(self, mock_settings, jwks_data):
659667
with patch("agent_memory_server.auth.get_public_key") as mock_get_key:
660668
from jose.jwk import RSAKey
661669

662-
rsa_key = RSAKey(jwks_data["keys"][0])
663-
mock_get_key.return_value = rsa_key.to_pem()
670+
rsa_key = RSAKey(jwks_data["keys"][0], algorithm="RS256")
671+
mock_get_key.return_value = rsa_key.to_pem().decode("utf-8")
664672

665673
result = verify_jwt(no_aud_token)
666674
assert result.sub == "test-user-123"
@@ -948,8 +956,8 @@ async def test_auth0_integration_scenario(self, mock_settings, jwks_data):
948956
with patch("agent_memory_server.auth.get_public_key") as mock_get_key:
949957
from jose.jwk import RSAKey
950958

951-
rsa_key = RSAKey(jwks_data["keys"][0])
952-
mock_get_key.return_value = rsa_key.to_pem()
959+
rsa_key = RSAKey(jwks_data["keys"][0], algorithm="RS256")
960+
mock_get_key.return_value = rsa_key.to_pem().decode("utf-8")
953961

954962
result = verify_jwt(auth0_token)
955963

@@ -994,8 +1002,8 @@ async def test_aws_cognito_integration_scenario(self, mock_settings, jwks_data):
9941002
with patch("agent_memory_server.auth.get_public_key") as mock_get_key:
9951003
from jose.jwk import RSAKey
9961004

997-
rsa_key = RSAKey(jwks_data["keys"][0])
998-
mock_get_key.return_value = rsa_key.to_pem()
1005+
rsa_key = RSAKey(jwks_data["keys"][0], algorithm="RS256")
1006+
mock_get_key.return_value = rsa_key.to_pem().decode("utf-8")
9991007

10001008
result = verify_jwt(cognito_token)
10011009

@@ -1079,11 +1087,11 @@ async def test_network_timeout_handling(self, mock_settings):
10791087
cache = JWKSCache()
10801088

10811089
with patch("httpx.Client") as mock_client:
1082-
mock_context = Mock()
1083-
mock_context.__enter__.return_value.get.side_effect = (
1084-
httpx.TimeoutException("Timeout")
1085-
)
1086-
mock_client.return_value = mock_context
1090+
mock_response = Mock()
1091+
mock_response.get.side_effect = httpx.TimeoutException("Timeout")
1092+
1093+
mock_client.return_value.__enter__.return_value = mock_response
1094+
mock_client.return_value.__exit__.return_value = None
10871095

10881096
with pytest.raises(HTTPException) as exc_info:
10891097
cache.get_jwks("https://test-issuer.com/.well-known/jwks.json")
@@ -1121,9 +1129,8 @@ async def test_jwks_endpoint_returns_invalid_json(self, mock_settings):
11211129
mock_response.json.side_effect = json.JSONDecodeError("Invalid JSON", "", 0)
11221130
mock_response.raise_for_status.return_value = None
11231131

1124-
mock_context = Mock()
1125-
mock_context.__enter__.return_value.get.return_value = mock_response
1126-
mock_client.return_value = mock_context
1132+
mock_client.return_value.__enter__.return_value = mock_response
1133+
mock_client.return_value.__exit__.return_value = None
11271134

11281135
with pytest.raises(HTTPException) as exc_info:
11291136
cache.get_jwks("https://test-issuer.com/.well-known/jwks.json")
@@ -1147,9 +1154,10 @@ async def test_memory_pressure_scenarios(self, mock_settings, jwks_data):
11471154
mock_response.json.return_value = large_jwks
11481155
mock_response.raise_for_status.return_value = None
11491156

1150-
mock_context = Mock()
1151-
mock_context.__enter__.return_value.get.return_value = mock_response
1152-
mock_client.return_value = mock_context
1157+
mock_client.return_value.__enter__.return_value.get.return_value = (
1158+
mock_response
1159+
)
1160+
mock_client.return_value.__exit__.return_value = None
11531161

11541162
# Should handle large responses gracefully
11551163
result = cache.get_jwks("https://test-issuer.com/.well-known/jwks.json")

0 commit comments

Comments
 (0)