Skip to content

Commit 64c2ffd

Browse files
committed
refactor: Check jwks cache after entering the write lock
1 parent dfb8de4 commit 64c2ffd

File tree

3 files changed

+42
-28
lines changed

3 files changed

+42
-28
lines changed

supertokens_python/recipe/session/constants.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333

3434
available_token_transfer_methods: List[TokenTransferMethod] = ["cookie", "header"]
3535

36-
JWKCacheMaxAgeInMs = 6 * 1000 # 6s
36+
JWKCacheMaxAgeInMs = 60 * 1000 # 60s
3737
protected_props = [
3838
"sub",
3939
"iat",

supertokens_python/recipe/session/jwks.py

Lines changed: 38 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ class JWKSConfigType(TypedDict):
3333

3434
JWKSConfig: JWKSConfigType = {
3535
"cache_max_age": JWKCacheMaxAgeInMs,
36-
"request_timeout": 5000, # 5s
36+
"request_timeout": 10000, # 10s
3737
}
3838

3939

@@ -56,22 +56,36 @@ def reset_jwks_cache():
5656
cached_keys = None
5757

5858

59-
def get_cached_keys() -> Optional[CachedKeys]:
60-
with RWLockContext(mutex, read=True):
61-
if cached_keys is not None:
62-
# This means that we have valid JWKs for the given core path
63-
# We check if we need to refresh before returning
59+
def get_cached_keys() -> Optional[List[PyJWK]]:
60+
if cached_keys is not None:
61+
# This means that we have valid JWKs for the given core path
62+
# We check if we need to refresh before returning
63+
64+
# This means that the value in cache is not expired, in this case we return the cached value
65+
# Note that this also means that the SDK will not try to query any other core (if there are multiple)
66+
# if it has a valid cache entry from one of the core URLs. It will only attempt to fetch
67+
# from the cores again after the entry in the cache is expired
68+
if cached_keys.is_fresh():
69+
if environ.get("SUPERTOKENS_ENV") == "testing":
70+
log_debug_message("Returning JWKS from cache")
71+
return cached_keys.keys
72+
73+
return None
74+
6475

65-
# This means that the value in cache is not expired, in this case we return the cached value
66-
# Note that this also means that the SDK will not try to query any other core (if there are multiple)
67-
# if it has a valid cache entry from one of the core URLs. It will only attempt to fetch
68-
# from the cores again after the entry in the cache is expired
69-
if cached_keys.is_fresh():
70-
if environ.get("SUPERTOKENS_ENV") == "testing":
71-
log_debug_message("Returning JWKS from cache")
72-
return cached_keys
76+
def find_matching_keys(
77+
keys: Optional[List[PyJWK]], kid: Optional[str]
78+
) -> Optional[List[PyJWK]]:
79+
if kid is None or keys is None:
80+
# return all keys since the token does not have a kid
81+
return keys
7382

74-
return None
83+
# kid has been provided so filter the keys
84+
matching_keys = [key for key in keys if key.key_id == kid] # type: ignore
85+
if len(matching_keys) > 0:
86+
return matching_keys
87+
88+
return None
7589

7690

7791
def get_latest_keys(kid: Optional[str] = None) -> List[PyJWK]:
@@ -80,15 +94,9 @@ def get_latest_keys(kid: Optional[str] = None) -> List[PyJWK]:
8094
if environ.get("SUPERTOKENS_ENV") == "testing":
8195
log_debug_message("Called find_jwk_client")
8296

83-
keys_from_cache = get_cached_keys()
84-
if keys_from_cache is not None:
85-
if kid is None:
86-
# return all keys since the token does not have a kid
87-
return keys_from_cache.keys
88-
89-
# kid has been provided so filter the keys
90-
matching_keys = [key for key in keys_from_cache.keys if key.key_id == kid] # type: ignore
91-
if len(matching_keys) > 0:
97+
with RWLockContext(mutex, read=True):
98+
matching_keys = find_matching_keys(get_cached_keys(), kid)
99+
if matching_keys is not None:
92100
return matching_keys
93101
# otherwise unknown kid, will continue to reload the keys
94102

@@ -104,6 +112,12 @@ def get_latest_keys(kid: Optional[str] = None) -> List[PyJWK]:
104112
last_error: Exception = Exception("No valid JWKS found")
105113

106114
with RWLockContext(mutex, read=False):
115+
# check again if the keys are in cache
116+
# because another thread might have fetched the keys while this one was waiting for the lock
117+
matching_keys = find_matching_keys(get_cached_keys(), kid)
118+
if matching_keys is not None:
119+
return matching_keys
120+
107121
for path in core_paths:
108122
if environ.get("SUPERTOKENS_ENV") == "testing":
109123
log_debug_message("Attempting to fetch JWKS from path: %s", path)

supertokens_python/recipe/session/jwt.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ def __init__(
5656

5757
def parse_jwt_without_signature_verification(jwt: str) -> ParsedJWTInfo:
5858
splitted_input = jwt.split(".")
59-
TOKEN_V3 = 3
59+
LATEST_TOKEN_VERSION = 3
6060
if len(splitted_input) != 3:
6161
raise Exception("invalid jwt")
6262

@@ -70,7 +70,7 @@ def parse_jwt_without_signature_verification(jwt: str) -> ParsedJWTInfo:
7070
# checking the header
7171
if header not in _allowed_headers:
7272
parsed_header = loads(utf_base64decode(header, True))
73-
header_version = parsed_header.get("version", str(TOKEN_V3))
73+
header_version = parsed_header.get("version", str(LATEST_TOKEN_VERSION))
7474

7575
try:
7676
version = int(header_version)
@@ -82,7 +82,7 @@ def parse_jwt_without_signature_verification(jwt: str) -> ParsedJWTInfo:
8282
if (
8383
parsed_header["typ"] != "JWT"
8484
or not isinstance(version, int)
85-
or version < TOKEN_V3
85+
or version < 3
8686
or kid is None
8787
):
8888
raise Exception("JWT header mismatch")

0 commit comments

Comments
 (0)