Skip to content

Commit 11a8ae0

Browse files
committed
refactor: Remove JWKClient and use RWMutex to protect keys in cache
1 parent 0ff1e7d commit 11a8ae0

File tree

8 files changed

+186
-233
lines changed

8 files changed

+186
-233
lines changed

supertokens_python/recipe/session/access_token.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -42,12 +42,11 @@ def sanitize_number(n: Any) -> Union[Union[int, float], None]:
4242
return None
4343

4444

45-
from supertokens_python.recipe.session.jwks import JWKClient
45+
from supertokens_python.recipe.session.jwks import get_latest_keys
4646

4747

4848
def get_info_from_access_token(
4949
jwt_info: ParsedJWTInfo,
50-
jwk_client: JWKClient,
5150
do_anti_csrf_check: bool,
5251
):
5352
try:
@@ -59,19 +58,17 @@ def get_info_from_access_token(
5958
)
6059

6160
if jwt_info.version >= 3:
62-
matching_key = jwk_client.get_matching_key_from_jwt(
63-
jwt_info.raw_token_string
64-
)
61+
matching_keys = get_latest_keys(jwt_info.kid)
6562
payload = jwt.decode( # type: ignore
6663
jwt_info.raw_token_string,
67-
matching_key.key, # type: ignore
64+
matching_keys[0].key, # type: ignore
6865
algorithms=[decode_algo],
6966
options={"verify_signature": True, "verify_exp": True},
7067
)
7168
else:
7269
# It won't have kid. So we'll have to try the token against all the keys from all the jwk_clients
7370
# If any of them work, we'll use that payload
74-
for k in jwk_client.get_latest_keys():
71+
for k in get_latest_keys():
7572
try:
7673
payload = jwt.decode( # type: ignore
7774
jwt_info.raw_token_string,

supertokens_python/recipe/session/constants.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333

3434
available_token_transfer_methods: List[TokenTransferMethod] = ["cookie", "header"]
3535

36-
JWKCacheMaxAgeInMs = 60 * 1000 # 1min
36+
JWKCacheMaxAgeInMs = 6 * 1000 # 6s
3737
JWKRequestCooldownInMs = 500 # 0.5s
3838
protected_props = [
3939
"sub",

supertokens_python/recipe/session/interfaces.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,6 @@
3636

3737
if TYPE_CHECKING:
3838
from supertokens_python.framework import BaseRequest
39-
from .jwks import JWKClient
4039

4140
from supertokens_python.framework import BaseResponse
4241

@@ -128,8 +127,6 @@ class GetSessionTokensDangerouslyDict(TypedDict):
128127

129128

130129
class RecipeInterface(ABC): # pylint: disable=too-many-public-methods
131-
JWK_clients: List[JWKClient] = []
132-
133130
def __init__(self):
134131
pass
135132

supertokens_python/recipe/session/jwks.py

Lines changed: 107 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -1,106 +1,133 @@
1-
import json
1+
# Copyright (c) 2023, VRAI Labs and/or its affiliates. All rights reserved.
2+
#
3+
# This software is licensed under the Apache License, Version 2.0 (the
4+
# "License") as published by the Apache Software Foundation.
5+
#
6+
# You may not use this file except in compliance with the License. You may
7+
# obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11+
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12+
# License for the specific language governing permissions and limitations
13+
# under the License.
14+
215
import requests
16+
from os import environ
317
from typing import List, Optional
18+
from typing_extensions import TypedDict
419

520
from jwt import PyJWK, PyJWKSet
6-
from jwt.api_jwt import decode_complete as decode_token # type: ignore
7-
8-
from supertokens_python.utils import get_timestamp_ms
921

1022
from .constants import JWKCacheMaxAgeInMs, JWKRequestCooldownInMs
1123

24+
from supertokens_python.utils import RWMutex, RWLockContext, get_timestamp_ms
25+
from supertokens_python.querier import Querier
1226
from supertokens_python.logger import log_debug_message
1327

1428

15-
class JWKClient:
16-
def __init__(
17-
self,
18-
uri: str,
19-
cooldown_duration: int = JWKRequestCooldownInMs,
20-
cache_max_age: int = JWKCacheMaxAgeInMs,
21-
):
22-
"""A client for retrieving JSON Web Key Sets (JWKS) from a given URI.
23-
24-
Args:
25-
uri (str): The URI of the JWKS.
26-
cooldown_duration (int, optional): The cooldown duration in ms. Defaults to 500 seconds.
27-
cache_max_age (int, optional): The cache max age in ms. Defaults to 5 minutes.
28-
29-
Note: The JSON Web Key Set is fetched when no key matches the selection
30-
process but only as frequently as the `self.cooldown_duration` option
31-
allows to prevent abuse. The `self.cache_max_age` option is used to
32-
determine how long the JWKS is cached for.
33-
34-
Whenever you make a call to `get_signing_key_from_jwt`, the JWKS
35-
is fetched if it is older than `self.cache_max_age` ms unless
36-
cooldown is active.
37-
"""
38-
self.uri = uri
39-
self.cooldown_duration = cooldown_duration
40-
self.cache_max_age = cache_max_age
41-
self.timeout_sec = 5
42-
self.last_fetch_time: int = 0
43-
self.cached_jwks: Optional[PyJWKSet] = None
44-
45-
def fetch(self):
46-
try:
47-
log_debug_message("Fetching jwk set from the configured uri")
48-
with requests.get(self.uri, timeout=self.timeout_sec) as response:
49-
response.raise_for_status()
50-
self.cached_jwks = PyJWKSet.from_dict(json.load(response)) # type: ignore
51-
self.last_fetch_time = get_timestamp_ms()
52-
except requests.HTTPError:
53-
raise JWKSRequestError("Failed to fetch jwk set from the configured uri")
54-
55-
def is_cooling_down(self) -> bool:
56-
return (self.last_fetch_time > 0) and (
57-
get_timestamp_ms() - self.last_fetch_time < self.cooldown_duration
58-
)
29+
class JWKSConfigType(TypedDict):
30+
cache_max_age: int
31+
refresh_rate_limit: int
32+
request_timeout: int
33+
34+
35+
JWKSConfig: JWKSConfigType = {
36+
"cache_max_age": JWKCacheMaxAgeInMs,
37+
"refresh_rate_limit": JWKRequestCooldownInMs, # FIXME: Not used
38+
"request_timeout": 5000, # 5s
39+
}
5940

60-
def is_fresh(self) -> bool:
61-
return (self.last_fetch_time > 0) and (
62-
get_timestamp_ms() - self.last_fetch_time < self.cache_max_age
41+
42+
class CachedKeys:
43+
def __init__(self, path: str, keys: List[PyJWK]):
44+
self.path = path
45+
self.keys = keys
46+
self.last_refresh_time = get_timestamp_ms()
47+
48+
def is_fresh(self):
49+
return (
50+
get_timestamp_ms() - self.last_refresh_time < JWKSConfig["cache_max_age"]
6351
)
6452

65-
def get_latest_keys(self) -> List[PyJWK]:
66-
if self.cached_jwks is None or not self.is_fresh():
67-
self.fetch()
6853

69-
if self.cached_jwks is None:
70-
raise JWKSRequestError("Failed to fetch the latest keys")
54+
cached_keys: Optional[CachedKeys] = None
55+
mutex = RWMutex()
7156

72-
all_keys: List[PyJWK] = self.cached_jwks.keys # type: ignore
57+
# only for testing purposes
58+
def reset_jwks_cache():
59+
with RWLockContext(mutex, read=False):
60+
global cached_keys
61+
cached_keys = None
7362

74-
return all_keys
7563

76-
def get_matching_key_from_jwt(self, token: str) -> PyJWK:
77-
header = decode_token(token, options={"verify_signature": False})["header"]
78-
kid: str = header["kid"] # type: ignore
64+
def get_cached_keys() -> Optional[CachedKeys]:
65+
with RWLockContext(mutex, read=True):
66+
if cached_keys is not None:
67+
# This means that we have valid JWKs for the given core path
68+
# We check if we need to refresh before returning
7969

80-
if self.cached_jwks is None or not self.is_fresh():
81-
self.fetch()
70+
# This means that the value in cache is not expired, in this case we return the cached value
71+
# Note that this also means that the SDK will not try to query any other core (if there are multiple)
72+
# if it has a valid cache entry from one of the core URLs. It will only attempt to fetch
73+
# from the cores again after the entry in the cache is expired
74+
if cached_keys.is_fresh():
75+
if environ.get("SUPERTOKENS_ENV") == "testing":
76+
log_debug_message("Returning JWKS from cache")
77+
return cached_keys
8278

83-
assert self.cached_jwks is not None
79+
return None
8480

85-
try:
86-
return self.cached_jwks[kid] # type: ignore
87-
except KeyError:
88-
if not self.is_cooling_down():
89-
# One more attempt to fetch the latest keys
90-
# and then try to find the key again.
91-
self.fetch()
92-
try:
93-
return self.cached_jwks[kid] # type: ignore
94-
except KeyError:
95-
pass
96-
except Exception:
97-
raise JWKSKeyNotFoundError("No key found for the given kid")
9881

99-
raise JWKSKeyNotFoundError("No key found for the given kid")
82+
def get_latest_keys(kid: Optional[str] = None) -> List[PyJWK]:
83+
global cached_keys
10084

85+
if environ.get("SUPERTOKENS_ENV") == "testing":
86+
log_debug_message("Called find_jwk_client")
10187

102-
class JWKSKeyNotFoundError(Exception):
103-
pass
88+
keys_from_cache = get_cached_keys()
89+
if keys_from_cache is not None:
90+
if kid is None:
91+
# return all keys since the token does not have a kid
92+
return keys_from_cache.keys
93+
94+
# kid has been provided so filter the keys
95+
matching_keys = [key for key in keys_from_cache.keys if key.key_id == kid] # type: ignore
96+
if len(matching_keys) > 0:
97+
return matching_keys
98+
# otherwise unknown kid, will continue to reload the keys
99+
100+
core_paths = Querier.get_instance().get_all_core_urls_for_path(
101+
"./.well-known/jwks.json"
102+
)
103+
104+
if len(core_paths) == 0:
105+
raise Exception(
106+
"No SuperTokens core available to query. Please pass supertokens > connection_uri to the init function, or override all the functions of the recipe you are using."
107+
)
108+
109+
last_error: Exception = Exception("No valid JWKS found")
110+
111+
with RWLockContext(mutex, read=False):
112+
for path in core_paths:
113+
if environ.get("SUPERTOKENS_ENV") == "testing":
114+
log_debug_message("Attempting to fetch JWKS from path: %s", path)
115+
116+
cached_jwks: Optional[List[PyJWK]] = None
117+
try:
118+
log_debug_message("Fetching jwk set from the configured uri")
119+
with requests.get(path, timeout=JWKSConfig['request_timeout']/1000) as response: # 5 second timeout
120+
response.raise_for_status()
121+
cached_jwks = PyJWKSet.from_dict(response.json()).keys # type: ignore
122+
except Exception as e:
123+
last_error = e
124+
125+
if cached_jwks is not None: # we found a valid JWKS
126+
cached_keys = CachedKeys(path, cached_jwks)
127+
log_debug_message("Returning JWKS from fetch")
128+
return cached_keys.keys
129+
130+
raise last_error
104131

105132

106133
class JWKSRequestError(Exception):

supertokens_python/recipe/session/recipe_implementation.py

Lines changed: 1 addition & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818

1919
from supertokens_python.logger import log_debug_message
2020
from supertokens_python.normalised_url_path import NormalisedURLPath
21-
from supertokens_python.utils import resolve, RWMutex, RWLockContext
21+
from supertokens_python.utils import resolve
2222

2323
from ...types import MaybeAwaitable
2424
from . import session_functions
@@ -38,7 +38,6 @@
3838
SessionInformationResult,
3939
SessionObj,
4040
)
41-
from .jwks import JWKClient
4241
from .jwt import ParsedJWTInfo, parse_jwt_without_signature_verification
4342
from .session_class import Session
4443
from .utils import SessionConfig, validate_claims_in_payload
@@ -51,90 +50,6 @@
5150
from supertokens_python.querier import Querier
5251

5352

54-
from typing_extensions import TypedDict
55-
from os import environ
56-
57-
58-
class JWKSConfigType(TypedDict):
59-
cache_max_age: int
60-
refresh_rate_limit: int
61-
62-
63-
JWKSConfig: JWKSConfigType = {
64-
"cache_max_age": 6000,
65-
"refresh_rate_limit": 500,
66-
}
67-
68-
cached_jwk_client: Optional[JWKClient] = None
69-
mutex = RWMutex()
70-
71-
# only for testing purposes
72-
def reset_jwks_cache():
73-
with RWLockContext(mutex, read=False):
74-
global cached_jwk_client
75-
cached_jwk_client = None
76-
77-
78-
def get_jwk_client_from_cache() -> Optional[JWKClient]:
79-
with RWLockContext(mutex, read=True):
80-
if cached_jwk_client is not None:
81-
# This means that we have valid JWKs for the given core path
82-
# We check if we need to refresh before returning
83-
84-
# This means that the value in cache is not expired, in this case we return the cached value
85-
# Note that this also means that the SDK will not try to query any other Core (if there are multiple)
86-
# if it has a valid cache entry from one of the core URLs. It will only attempt to fetch
87-
# from the cores again after the entry in the cache is expired
88-
if cached_jwk_client.is_fresh():
89-
if environ.get("SUPERTOKENS_ENV") == "testing":
90-
log_debug_message("Returning JWKS from cache")
91-
return cached_jwk_client
92-
93-
return None
94-
95-
96-
def find_jwk_client() -> JWKClient:
97-
global cached_jwk_client
98-
99-
if environ.get("SUPERTOKENS_ENV") == "testing":
100-
log_debug_message("Called find_jwk_client")
101-
102-
core_paths = Querier.get_instance().get_all_core_urls_for_path(
103-
"./.well-known/jwks.json"
104-
)
105-
106-
if len(core_paths) == 0:
107-
raise Exception(
108-
"No SuperTokens core available to query. Please pass supertokens > connection_uri to the init function, or override all the functions of the recipe you are using."
109-
)
110-
111-
client_from_cache = get_jwk_client_from_cache()
112-
if client_from_cache is not None:
113-
return client_from_cache
114-
115-
last_error: Exception = Exception("No valid JWKS found")
116-
117-
with RWLockContext(mutex, read=False):
118-
for path in core_paths:
119-
if environ.get("SUPERTOKENS_ENV") == "testing":
120-
log_debug_message("Attempting to fetch JWKS from path: %s", path)
121-
122-
client = JWKClient(
123-
path, JWKSConfig["refresh_rate_limit"], JWKSConfig["cache_max_age"]
124-
)
125-
try:
126-
client.fetch()
127-
except Exception as e:
128-
last_error = e
129-
130-
if client.cached_jwks is not None: # we found a valid JWKS
131-
cached_jwk_client = client
132-
log_debug_message("Returning JWKS from fetch")
133-
return cached_jwk_client
134-
135-
raise last_error
136-
137-
13853
class RecipeImplementation(RecipeInterface): # pylint: disable=too-many-public-methods
13954
def __init__(self, querier: Querier, config: SessionConfig, app_info: AppInfo):
14055
super().__init__()

supertokens_python/recipe/session/session_functions.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -150,14 +150,8 @@ async def get_session(
150150
access_token_info: Optional[Dict[str, Any]] = None
151151

152152
try:
153-
from supertokens_python.recipe.session.recipe_implementation import (
154-
find_jwk_client,
155-
)
156-
157-
jwk_client = find_jwk_client()
158153
access_token_info = get_info_from_access_token(
159154
parsed_access_token,
160-
jwk_client,
161155
config.anti_csrf == "VIA_TOKEN" and do_anti_csrf_check,
162156
)
163157

0 commit comments

Comments
 (0)