Skip to content

Commit 33be81a

Browse files
committed
test: Add test for invalid kid that doesnt match on refreshing
1 parent 64c2ffd commit 33be81a

File tree

3 files changed

+30
-4
lines changed

3 files changed

+30
-4
lines changed

supertokens_python/recipe/session/jwks.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,10 @@ def get_latest_keys(kid: Optional[str] = None) -> List[PyJWK]:
136136
if cached_jwks is not None: # we found a valid JWKS
137137
cached_keys = CachedKeys(cached_jwks)
138138
log_debug_message("Returning JWKS from fetch")
139-
return cached_keys.keys
139+
matching_keys = find_matching_keys(get_cached_keys(), kid)
140+
if matching_keys is not None:
141+
return matching_keys
142+
143+
raise Exception("No matching JWKS found")
140144

141145
raise last_error

supertokens_python/utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -353,10 +353,10 @@ def __enter__(self):
353353
self.mutex.lock()
354354

355355
def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any):
356-
if exc_type is not None:
357-
raise exc_type(exc_value).with_traceback(traceback)
358-
359356
if self.read:
360357
self.mutex.r_unlock()
361358
else:
362359
self.mutex.unlock()
360+
361+
if exc_type is not None:
362+
raise exc_type(exc_value).with_traceback(traceback)

tests/sessions/test_jwks.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import pytest
33
import logging
44
import threading
5+
import json
56
import requests
67

78
from typing import List, Any, Callable
@@ -29,6 +30,7 @@
2930
get_cached_keys,
3031
get_latest_keys,
3132
)
33+
from supertokens_python.utils import utf_base64encode
3234

3335
from _pytest.logging import LogCaptureFixture
3436

@@ -198,6 +200,26 @@ async def test_that_jwks_are_refresh_if_kid_is_unknown(caplog: LogCaptureFixture
198200

199201
assert next(well_known_count) == 2 # no change
200202

203+
# use an access token with an invalid kid that doesn't match even on refreshing
204+
_, payload, signature = tokens["accessToken"].split(".")
205+
new_header = utf_base64encode(
206+
json.dumps(
207+
{"kid": "d-1234567890123", "typ": "JWT", "version": "3", "alg": "RS256"}
208+
),
209+
urlsafe=False,
210+
)
211+
new_token = ".".join([new_header, payload, signature])
212+
213+
with pytest.raises(Exception) as e:
214+
await get_session_without_request_response(
215+
new_token, tokens.get("antiCsrfToken")
216+
)
217+
218+
assert str(e.value) == "No matching JWKS found"
219+
assert (
220+
next(well_known_count) == 3
221+
) # kid not found in cache, so should have refreshed
222+
201223

202224
async def test_that_invalid_connection_uri_doesnot_throw_during_init_for_jwks():
203225
"""This test makes sure that initialising SuperTokens and Session with an invalid connection uri does not

0 commit comments

Comments
 (0)