@@ -90,21 +90,27 @@ def jwks_data():
9090 key_size = 2048 ,
9191 )
9292
93- # Convert to JWK format
94- rsa_key = RSAKey (key = key , algorithm = "RS256" )
95- jwk_dict = rsa_key .to_dict ()
96- jwk_dict ["kid" ] = "test-kid-123"
97- jwk_dict ["use" ] = "sig"
98- jwk_dict ["alg" ] = "RS256"
99-
93+ # Convert to PEM format first
10094 private_pem = key .private_bytes (
10195 encoding = serialization .Encoding .PEM ,
10296 format = serialization .PrivateFormat .PKCS8 ,
10397 encryption_algorithm = serialization .NoEncryption (),
10498 )
10599
100+ # Convert to JWK format using PEM string
101+ rsa_key = RSAKey (key = private_pem .decode ("utf-8" ), algorithm = "RS256" )
102+ jwk_dict = rsa_key .to_dict ()
103+
104+ # Remove private key components to make it a public-only JWK
105+ public_jwk_dict = {
106+ k : v for k , v in jwk_dict .items () if k not in ["d" , "p" , "q" , "dp" , "dq" , "qi" ]
107+ }
108+ public_jwk_dict ["kid" ] = "test-kid-123"
109+ public_jwk_dict ["use" ] = "sig"
110+ public_jwk_dict ["alg" ] = "RS256"
111+
106112 return {
107- "keys" : [jwk_dict ],
113+ "keys" : [public_jwk_dict ],
108114 "private_key" : private_pem .decode ("utf-8" ),
109115 "kid" : "test-kid-123" ,
110116 }
@@ -181,9 +187,10 @@ async def test_jwks_cache_fetch_success(self, jwks_data):
181187 mock_response .json .return_value = {"keys" : jwks_data ["keys" ]}
182188 mock_response .raise_for_status .return_value = None
183189
184- mock_context = Mock ()
185- mock_context .__enter__ .return_value .get .return_value = mock_response
186- mock_client .return_value = mock_context
190+ mock_client .return_value .__enter__ .return_value .get .return_value = (
191+ mock_response
192+ )
193+ mock_client .return_value .__exit__ .return_value = None
187194
188195 result = cache .get_jwks (jwks_url )
189196
@@ -216,9 +223,10 @@ async def test_jwks_cache_refresh_expired(self, jwks_data):
216223 mock_response .json .return_value = {"keys" : jwks_data ["keys" ]}
217224 mock_response .raise_for_status .return_value = None
218225
219- mock_context = Mock ()
220- mock_context .__enter__ .return_value .get .return_value = mock_response
221- mock_client .return_value = mock_context
226+ mock_client .return_value .__enter__ .return_value .get .return_value = (
227+ mock_response
228+ )
229+ mock_client .return_value .__exit__ .return_value = None
222230
223231 result = cache .get_jwks ("https://test-issuer.com/.well-known/jwks.json" )
224232
@@ -232,11 +240,11 @@ async def test_jwks_cache_http_error(self):
232240 jwks_url = "https://test-issuer.com/.well-known/jwks.json"
233241
234242 with patch ("httpx.Client" ) as mock_client :
235- mock_context = Mock ()
236- mock_context . __enter__ . return_value . get .side_effect = httpx .HTTPError (
237- "Connection failed"
238- )
239- mock_client .return_value = mock_context
243+ mock_response = Mock ()
244+ mock_response . get .side_effect = httpx .HTTPError ("Connection failed" )
245+
246+ mock_client . return_value . __enter__ . return_value = mock_response
247+ mock_client .return_value . __exit__ . return_value = None
240248
241249 with pytest .raises (HTTPException ) as exc_info :
242250 cache .get_jwks (jwks_url )
@@ -407,8 +415,8 @@ def test_verify_jwt_success(self, mock_settings, jwks_data, valid_token):
407415 # Extract public key from JWKS data
408416 from jose .jwk import RSAKey
409417
410- rsa_key = RSAKey (jwks_data ["keys" ][0 ])
411- mock_get_key .return_value = rsa_key .to_pem ()
418+ rsa_key = RSAKey (jwks_data ["keys" ][0 ], algorithm = "RS256" )
419+ mock_get_key .return_value = rsa_key .to_pem (). decode ( "utf-8" )
412420
413421 result = verify_jwt (valid_token )
414422
@@ -437,8 +445,8 @@ def test_verify_jwt_expired_token(self, mock_settings, jwks_data, expired_token)
437445 with patch ("agent_memory_server.auth.get_public_key" ) as mock_get_key :
438446 from jose .jwk import RSAKey
439447
440- rsa_key = RSAKey (jwks_data ["keys" ][0 ])
441- mock_get_key .return_value = rsa_key .to_pem ()
448+ rsa_key = RSAKey (jwks_data ["keys" ][0 ], algorithm = "RS256" )
449+ mock_get_key .return_value = rsa_key .to_pem (). decode ( "utf-8" )
442450
443451 with pytest .raises (HTTPException ) as exc_info :
444452 verify_jwt (expired_token )
@@ -470,8 +478,8 @@ def test_verify_jwt_future_token(self, mock_settings, jwks_data):
470478 with patch ("agent_memory_server.auth.get_public_key" ) as mock_get_key :
471479 from jose .jwk import RSAKey
472480
473- rsa_key = RSAKey (jwks_data ["keys" ][0 ])
474- mock_get_key .return_value = rsa_key .to_pem ()
481+ rsa_key = RSAKey (jwks_data ["keys" ][0 ], algorithm = "RS256" )
482+ mock_get_key .return_value = rsa_key .to_pem (). decode ( "utf-8" )
475483
476484 with pytest .raises (HTTPException ) as exc_info :
477485 verify_jwt (future_token )
@@ -503,8 +511,8 @@ def test_verify_jwt_wrong_audience(self, mock_settings, jwks_data):
503511 with patch ("agent_memory_server.auth.get_public_key" ) as mock_get_key :
504512 from jose .jwk import RSAKey
505513
506- rsa_key = RSAKey (jwks_data ["keys" ][0 ])
507- mock_get_key .return_value = rsa_key .to_pem ()
514+ rsa_key = RSAKey (jwks_data ["keys" ][0 ], algorithm = "RS256" )
515+ mock_get_key .return_value = rsa_key .to_pem (). decode ( "utf-8" )
508516
509517 with pytest .raises (HTTPException ) as exc_info :
510518 verify_jwt (wrong_aud_token )
@@ -536,8 +544,8 @@ def test_verify_jwt_audience_list(self, mock_settings, jwks_data):
536544 with patch ("agent_memory_server.auth.get_public_key" ) as mock_get_key :
537545 from jose .jwk import RSAKey
538546
539- rsa_key = RSAKey (jwks_data ["keys" ][0 ])
540- mock_get_key .return_value = rsa_key .to_pem ()
547+ rsa_key = RSAKey (jwks_data ["keys" ][0 ], algorithm = "RS256" )
548+ mock_get_key .return_value = rsa_key .to_pem (). decode ( "utf-8" )
541549
542550 result = verify_jwt (list_aud_token )
543551 assert result .sub == "test-user-123"
@@ -565,8 +573,8 @@ def test_verify_jwt_missing_subject(self, mock_settings, jwks_data):
565573 with patch ("agent_memory_server.auth.get_public_key" ) as mock_get_key :
566574 from jose .jwk import RSAKey
567575
568- rsa_key = RSAKey (jwks_data ["keys" ][0 ])
569- mock_get_key .return_value = rsa_key .to_pem ()
576+ rsa_key = RSAKey (jwks_data ["keys" ][0 ], algorithm = "RS256" )
577+ mock_get_key .return_value = rsa_key .to_pem (). decode ( "utf-8" )
570578
571579 with pytest .raises (HTTPException ) as exc_info :
572580 verify_jwt (no_sub_token )
@@ -599,8 +607,8 @@ def test_verify_jwt_scope_string_conversion(self, mock_settings, jwks_data):
599607 with patch ("agent_memory_server.auth.get_public_key" ) as mock_get_key :
600608 from jose .jwk import RSAKey
601609
602- rsa_key = RSAKey (jwks_data ["keys" ][0 ])
603- mock_get_key .return_value = rsa_key .to_pem ()
610+ rsa_key = RSAKey (jwks_data ["keys" ][0 ], algorithm = "RS256" )
611+ mock_get_key .return_value = rsa_key .to_pem (). decode ( "utf-8" )
604612
605613 result = verify_jwt (scope_str_token )
606614 assert result .scope is None
@@ -630,8 +638,8 @@ def test_verify_jwt_roles_string_conversion(self, mock_settings, jwks_data):
630638 with patch ("agent_memory_server.auth.get_public_key" ) as mock_get_key :
631639 from jose .jwk import RSAKey
632640
633- rsa_key = RSAKey (jwks_data ["keys" ][0 ])
634- mock_get_key .return_value = rsa_key .to_pem ()
641+ rsa_key = RSAKey (jwks_data ["keys" ][0 ], algorithm = "RS256" )
642+ mock_get_key .return_value = rsa_key .to_pem (). decode ( "utf-8" )
635643
636644 result = verify_jwt (roles_str_token )
637645 assert result .roles == ["admin" ]
@@ -659,8 +667,8 @@ def test_verify_jwt_no_audience_validation(self, mock_settings, jwks_data):
659667 with patch ("agent_memory_server.auth.get_public_key" ) as mock_get_key :
660668 from jose .jwk import RSAKey
661669
662- rsa_key = RSAKey (jwks_data ["keys" ][0 ])
663- mock_get_key .return_value = rsa_key .to_pem ()
670+ rsa_key = RSAKey (jwks_data ["keys" ][0 ], algorithm = "RS256" )
671+ mock_get_key .return_value = rsa_key .to_pem (). decode ( "utf-8" )
664672
665673 result = verify_jwt (no_aud_token )
666674 assert result .sub == "test-user-123"
@@ -948,8 +956,8 @@ async def test_auth0_integration_scenario(self, mock_settings, jwks_data):
948956 with patch ("agent_memory_server.auth.get_public_key" ) as mock_get_key :
949957 from jose .jwk import RSAKey
950958
951- rsa_key = RSAKey (jwks_data ["keys" ][0 ])
952- mock_get_key .return_value = rsa_key .to_pem ()
959+ rsa_key = RSAKey (jwks_data ["keys" ][0 ], algorithm = "RS256" )
960+ mock_get_key .return_value = rsa_key .to_pem (). decode ( "utf-8" )
953961
954962 result = verify_jwt (auth0_token )
955963
@@ -994,8 +1002,8 @@ async def test_aws_cognito_integration_scenario(self, mock_settings, jwks_data):
9941002 with patch ("agent_memory_server.auth.get_public_key" ) as mock_get_key :
9951003 from jose .jwk import RSAKey
9961004
997- rsa_key = RSAKey (jwks_data ["keys" ][0 ])
998- mock_get_key .return_value = rsa_key .to_pem ()
1005+ rsa_key = RSAKey (jwks_data ["keys" ][0 ], algorithm = "RS256" )
1006+ mock_get_key .return_value = rsa_key .to_pem (). decode ( "utf-8" )
9991007
10001008 result = verify_jwt (cognito_token )
10011009
@@ -1079,11 +1087,11 @@ async def test_network_timeout_handling(self, mock_settings):
10791087 cache = JWKSCache ()
10801088
10811089 with patch ("httpx.Client" ) as mock_client :
1082- mock_context = Mock ()
1083- mock_context . __enter__ . return_value . get .side_effect = (
1084- httpx . TimeoutException ( "Timeout" )
1085- )
1086- mock_client .return_value = mock_context
1090+ mock_response = Mock ()
1091+ mock_response . get .side_effect = httpx . TimeoutException ( "Timeout" )
1092+
1093+ mock_client . return_value . __enter__ . return_value = mock_response
1094+ mock_client .return_value . __exit__ . return_value = None
10871095
10881096 with pytest .raises (HTTPException ) as exc_info :
10891097 cache .get_jwks ("https://test-issuer.com/.well-known/jwks.json" )
@@ -1121,9 +1129,8 @@ async def test_jwks_endpoint_returns_invalid_json(self, mock_settings):
11211129 mock_response .json .side_effect = json .JSONDecodeError ("Invalid JSON" , "" , 0 )
11221130 mock_response .raise_for_status .return_value = None
11231131
1124- mock_context = Mock ()
1125- mock_context .__enter__ .return_value .get .return_value = mock_response
1126- mock_client .return_value = mock_context
1132+ mock_client .return_value .__enter__ .return_value = mock_response
1133+ mock_client .return_value .__exit__ .return_value = None
11271134
11281135 with pytest .raises (HTTPException ) as exc_info :
11291136 cache .get_jwks ("https://test-issuer.com/.well-known/jwks.json" )
@@ -1147,9 +1154,10 @@ async def test_memory_pressure_scenarios(self, mock_settings, jwks_data):
11471154 mock_response .json .return_value = large_jwks
11481155 mock_response .raise_for_status .return_value = None
11491156
1150- mock_context = Mock ()
1151- mock_context .__enter__ .return_value .get .return_value = mock_response
1152- mock_client .return_value = mock_context
1157+ mock_client .return_value .__enter__ .return_value .get .return_value = (
1158+ mock_response
1159+ )
1160+ mock_client .return_value .__exit__ .return_value = None
11531161
11541162 # Should handle large responses gracefully
11551163 result = cache .get_jwks ("https://test-issuer.com/.well-known/jwks.json" )
0 commit comments