@@ -90,21 +90,27 @@ def jwks_data():
90
90
key_size = 2048 ,
91
91
)
92
92
93
- # Convert to JWK format
94
- rsa_key = RSAKey (key = key , algorithm = "RS256" )
95
- jwk_dict = rsa_key .to_dict ()
96
- jwk_dict ["kid" ] = "test-kid-123"
97
- jwk_dict ["use" ] = "sig"
98
- jwk_dict ["alg" ] = "RS256"
99
-
93
+ # Convert to PEM format first
100
94
private_pem = key .private_bytes (
101
95
encoding = serialization .Encoding .PEM ,
102
96
format = serialization .PrivateFormat .PKCS8 ,
103
97
encryption_algorithm = serialization .NoEncryption (),
104
98
)
105
99
100
+ # Convert to JWK format using PEM string
101
+ rsa_key = RSAKey (key = private_pem .decode ("utf-8" ), algorithm = "RS256" )
102
+ jwk_dict = rsa_key .to_dict ()
103
+
104
+ # Remove private key components to make it a public-only JWK
105
+ public_jwk_dict = {
106
+ k : v for k , v in jwk_dict .items () if k not in ["d" , "p" , "q" , "dp" , "dq" , "qi" ]
107
+ }
108
+ public_jwk_dict ["kid" ] = "test-kid-123"
109
+ public_jwk_dict ["use" ] = "sig"
110
+ public_jwk_dict ["alg" ] = "RS256"
111
+
106
112
return {
107
- "keys" : [jwk_dict ],
113
+ "keys" : [public_jwk_dict ],
108
114
"private_key" : private_pem .decode ("utf-8" ),
109
115
"kid" : "test-kid-123" ,
110
116
}
@@ -181,9 +187,10 @@ async def test_jwks_cache_fetch_success(self, jwks_data):
181
187
mock_response .json .return_value = {"keys" : jwks_data ["keys" ]}
182
188
mock_response .raise_for_status .return_value = None
183
189
184
- mock_context = Mock ()
185
- mock_context .__enter__ .return_value .get .return_value = mock_response
186
- mock_client .return_value = mock_context
190
+ mock_client .return_value .__enter__ .return_value .get .return_value = (
191
+ mock_response
192
+ )
193
+ mock_client .return_value .__exit__ .return_value = None
187
194
188
195
result = cache .get_jwks (jwks_url )
189
196
@@ -216,9 +223,10 @@ async def test_jwks_cache_refresh_expired(self, jwks_data):
216
223
mock_response .json .return_value = {"keys" : jwks_data ["keys" ]}
217
224
mock_response .raise_for_status .return_value = None
218
225
219
- mock_context = Mock ()
220
- mock_context .__enter__ .return_value .get .return_value = mock_response
221
- mock_client .return_value = mock_context
226
+ mock_client .return_value .__enter__ .return_value .get .return_value = (
227
+ mock_response
228
+ )
229
+ mock_client .return_value .__exit__ .return_value = None
222
230
223
231
result = cache .get_jwks ("https://test-issuer.com/.well-known/jwks.json" )
224
232
@@ -232,11 +240,11 @@ async def test_jwks_cache_http_error(self):
232
240
jwks_url = "https://test-issuer.com/.well-known/jwks.json"
233
241
234
242
with patch ("httpx.Client" ) as mock_client :
235
- mock_context = Mock ()
236
- mock_context . __enter__ . return_value . get .side_effect = httpx .HTTPError (
237
- "Connection failed"
238
- )
239
- mock_client .return_value = mock_context
243
+ mock_response = Mock ()
244
+ mock_response . get .side_effect = httpx .HTTPError ("Connection failed" )
245
+
246
+ mock_client . return_value . __enter__ . return_value = mock_response
247
+ mock_client .return_value . __exit__ . return_value = None
240
248
241
249
with pytest .raises (HTTPException ) as exc_info :
242
250
cache .get_jwks (jwks_url )
@@ -407,8 +415,8 @@ def test_verify_jwt_success(self, mock_settings, jwks_data, valid_token):
407
415
# Extract public key from JWKS data
408
416
from jose .jwk import RSAKey
409
417
410
- rsa_key = RSAKey (jwks_data ["keys" ][0 ])
411
- mock_get_key .return_value = rsa_key .to_pem ()
418
+ rsa_key = RSAKey (jwks_data ["keys" ][0 ], algorithm = "RS256" )
419
+ mock_get_key .return_value = rsa_key .to_pem (). decode ( "utf-8" )
412
420
413
421
result = verify_jwt (valid_token )
414
422
@@ -437,8 +445,8 @@ def test_verify_jwt_expired_token(self, mock_settings, jwks_data, expired_token)
437
445
with patch ("agent_memory_server.auth.get_public_key" ) as mock_get_key :
438
446
from jose .jwk import RSAKey
439
447
440
- rsa_key = RSAKey (jwks_data ["keys" ][0 ])
441
- mock_get_key .return_value = rsa_key .to_pem ()
448
+ rsa_key = RSAKey (jwks_data ["keys" ][0 ], algorithm = "RS256" )
449
+ mock_get_key .return_value = rsa_key .to_pem (). decode ( "utf-8" )
442
450
443
451
with pytest .raises (HTTPException ) as exc_info :
444
452
verify_jwt (expired_token )
@@ -470,8 +478,8 @@ def test_verify_jwt_future_token(self, mock_settings, jwks_data):
470
478
with patch ("agent_memory_server.auth.get_public_key" ) as mock_get_key :
471
479
from jose .jwk import RSAKey
472
480
473
- rsa_key = RSAKey (jwks_data ["keys" ][0 ])
474
- mock_get_key .return_value = rsa_key .to_pem ()
481
+ rsa_key = RSAKey (jwks_data ["keys" ][0 ], algorithm = "RS256" )
482
+ mock_get_key .return_value = rsa_key .to_pem (). decode ( "utf-8" )
475
483
476
484
with pytest .raises (HTTPException ) as exc_info :
477
485
verify_jwt (future_token )
@@ -503,8 +511,8 @@ def test_verify_jwt_wrong_audience(self, mock_settings, jwks_data):
503
511
with patch ("agent_memory_server.auth.get_public_key" ) as mock_get_key :
504
512
from jose .jwk import RSAKey
505
513
506
- rsa_key = RSAKey (jwks_data ["keys" ][0 ])
507
- mock_get_key .return_value = rsa_key .to_pem ()
514
+ rsa_key = RSAKey (jwks_data ["keys" ][0 ], algorithm = "RS256" )
515
+ mock_get_key .return_value = rsa_key .to_pem (). decode ( "utf-8" )
508
516
509
517
with pytest .raises (HTTPException ) as exc_info :
510
518
verify_jwt (wrong_aud_token )
@@ -536,8 +544,8 @@ def test_verify_jwt_audience_list(self, mock_settings, jwks_data):
536
544
with patch ("agent_memory_server.auth.get_public_key" ) as mock_get_key :
537
545
from jose .jwk import RSAKey
538
546
539
- rsa_key = RSAKey (jwks_data ["keys" ][0 ])
540
- mock_get_key .return_value = rsa_key .to_pem ()
547
+ rsa_key = RSAKey (jwks_data ["keys" ][0 ], algorithm = "RS256" )
548
+ mock_get_key .return_value = rsa_key .to_pem (). decode ( "utf-8" )
541
549
542
550
result = verify_jwt (list_aud_token )
543
551
assert result .sub == "test-user-123"
@@ -565,8 +573,8 @@ def test_verify_jwt_missing_subject(self, mock_settings, jwks_data):
565
573
with patch ("agent_memory_server.auth.get_public_key" ) as mock_get_key :
566
574
from jose .jwk import RSAKey
567
575
568
- rsa_key = RSAKey (jwks_data ["keys" ][0 ])
569
- mock_get_key .return_value = rsa_key .to_pem ()
576
+ rsa_key = RSAKey (jwks_data ["keys" ][0 ], algorithm = "RS256" )
577
+ mock_get_key .return_value = rsa_key .to_pem (). decode ( "utf-8" )
570
578
571
579
with pytest .raises (HTTPException ) as exc_info :
572
580
verify_jwt (no_sub_token )
@@ -599,8 +607,8 @@ def test_verify_jwt_scope_string_conversion(self, mock_settings, jwks_data):
599
607
with patch ("agent_memory_server.auth.get_public_key" ) as mock_get_key :
600
608
from jose .jwk import RSAKey
601
609
602
- rsa_key = RSAKey (jwks_data ["keys" ][0 ])
603
- mock_get_key .return_value = rsa_key .to_pem ()
610
+ rsa_key = RSAKey (jwks_data ["keys" ][0 ], algorithm = "RS256" )
611
+ mock_get_key .return_value = rsa_key .to_pem (). decode ( "utf-8" )
604
612
605
613
result = verify_jwt (scope_str_token )
606
614
assert result .scope is None
@@ -630,8 +638,8 @@ def test_verify_jwt_roles_string_conversion(self, mock_settings, jwks_data):
630
638
with patch ("agent_memory_server.auth.get_public_key" ) as mock_get_key :
631
639
from jose .jwk import RSAKey
632
640
633
- rsa_key = RSAKey (jwks_data ["keys" ][0 ])
634
- mock_get_key .return_value = rsa_key .to_pem ()
641
+ rsa_key = RSAKey (jwks_data ["keys" ][0 ], algorithm = "RS256" )
642
+ mock_get_key .return_value = rsa_key .to_pem (). decode ( "utf-8" )
635
643
636
644
result = verify_jwt (roles_str_token )
637
645
assert result .roles == ["admin" ]
@@ -659,8 +667,8 @@ def test_verify_jwt_no_audience_validation(self, mock_settings, jwks_data):
659
667
with patch ("agent_memory_server.auth.get_public_key" ) as mock_get_key :
660
668
from jose .jwk import RSAKey
661
669
662
- rsa_key = RSAKey (jwks_data ["keys" ][0 ])
663
- mock_get_key .return_value = rsa_key .to_pem ()
670
+ rsa_key = RSAKey (jwks_data ["keys" ][0 ], algorithm = "RS256" )
671
+ mock_get_key .return_value = rsa_key .to_pem (). decode ( "utf-8" )
664
672
665
673
result = verify_jwt (no_aud_token )
666
674
assert result .sub == "test-user-123"
@@ -948,8 +956,8 @@ async def test_auth0_integration_scenario(self, mock_settings, jwks_data):
948
956
with patch ("agent_memory_server.auth.get_public_key" ) as mock_get_key :
949
957
from jose .jwk import RSAKey
950
958
951
- rsa_key = RSAKey (jwks_data ["keys" ][0 ])
952
- mock_get_key .return_value = rsa_key .to_pem ()
959
+ rsa_key = RSAKey (jwks_data ["keys" ][0 ], algorithm = "RS256" )
960
+ mock_get_key .return_value = rsa_key .to_pem (). decode ( "utf-8" )
953
961
954
962
result = verify_jwt (auth0_token )
955
963
@@ -994,8 +1002,8 @@ async def test_aws_cognito_integration_scenario(self, mock_settings, jwks_data):
994
1002
with patch ("agent_memory_server.auth.get_public_key" ) as mock_get_key :
995
1003
from jose .jwk import RSAKey
996
1004
997
- rsa_key = RSAKey (jwks_data ["keys" ][0 ])
998
- mock_get_key .return_value = rsa_key .to_pem ()
1005
+ rsa_key = RSAKey (jwks_data ["keys" ][0 ], algorithm = "RS256" )
1006
+ mock_get_key .return_value = rsa_key .to_pem (). decode ( "utf-8" )
999
1007
1000
1008
result = verify_jwt (cognito_token )
1001
1009
@@ -1079,11 +1087,11 @@ async def test_network_timeout_handling(self, mock_settings):
1079
1087
cache = JWKSCache ()
1080
1088
1081
1089
with patch ("httpx.Client" ) as mock_client :
1082
- mock_context = Mock ()
1083
- mock_context . __enter__ . return_value . get .side_effect = (
1084
- httpx . TimeoutException ( "Timeout" )
1085
- )
1086
- mock_client .return_value = mock_context
1090
+ mock_response = Mock ()
1091
+ mock_response . get .side_effect = httpx . TimeoutException ( "Timeout" )
1092
+
1093
+ mock_client . return_value . __enter__ . return_value = mock_response
1094
+ mock_client .return_value . __exit__ . return_value = None
1087
1095
1088
1096
with pytest .raises (HTTPException ) as exc_info :
1089
1097
cache .get_jwks ("https://test-issuer.com/.well-known/jwks.json" )
@@ -1121,9 +1129,8 @@ async def test_jwks_endpoint_returns_invalid_json(self, mock_settings):
1121
1129
mock_response .json .side_effect = json .JSONDecodeError ("Invalid JSON" , "" , 0 )
1122
1130
mock_response .raise_for_status .return_value = None
1123
1131
1124
- mock_context = Mock ()
1125
- mock_context .__enter__ .return_value .get .return_value = mock_response
1126
- mock_client .return_value = mock_context
1132
+ mock_client .return_value .__enter__ .return_value = mock_response
1133
+ mock_client .return_value .__exit__ .return_value = None
1127
1134
1128
1135
with pytest .raises (HTTPException ) as exc_info :
1129
1136
cache .get_jwks ("https://test-issuer.com/.well-known/jwks.json" )
@@ -1147,9 +1154,10 @@ async def test_memory_pressure_scenarios(self, mock_settings, jwks_data):
1147
1154
mock_response .json .return_value = large_jwks
1148
1155
mock_response .raise_for_status .return_value = None
1149
1156
1150
- mock_context = Mock ()
1151
- mock_context .__enter__ .return_value .get .return_value = mock_response
1152
- mock_client .return_value = mock_context
1157
+ mock_client .return_value .__enter__ .return_value .get .return_value = (
1158
+ mock_response
1159
+ )
1160
+ mock_client .return_value .__exit__ .return_value = None
1153
1161
1154
1162
# Should handle large responses gracefully
1155
1163
result = cache .get_jwks ("https://test-issuer.com/.well-known/jwks.json" )
0 commit comments