Skip to content

Commit 5718300

Browse files
Merge pull request #337 from supertokens/refactor/jwks
refactor: Improve JWKS fetching and caching
2 parents 13dee35 + 0c3134d commit 5718300

File tree

14 files changed

+926
-157
lines changed

14 files changed

+926
-157
lines changed

CHANGELOG.md

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
88

99
## [unreleased]
1010

11-
### Changes
11+
## [0.14.4] - 2023-06-14
12+
13+
### Changes and fixes
1214

15+
- Use `useStaticSigningKey` instead of `use_static_signing_key` in `create_jwt` function. This was a bug in the code.
16+
- Use request library instead of urllib to fetch JWKS keys ([#344](https://github.com/supertokens/supertokens-python/issues/344))
1317
- Throw error when `verify_sesion` is used with a view that allows `OPTIONS` or `TRACE` requests
1418
- Allow `verify_session` decorator to be with `@app.before_request` in Flask without returning a response
1519

supertokens_python/recipe/jwt/recipe_implementation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ async def create_jwt(
5252
data = {
5353
"payload": payload,
5454
"validity": validity_seconds,
55-
"use_static_signing_key": use_static_signing_key is not False,
55+
"useStaticSigningKey": use_static_signing_key is not False,
5656
"algorithm": "RS256",
5757
"jwksDomain": self.app_info.api_domain.get_as_string_dangerous(),
5858
}

supertokens_python/recipe/session/access_token.py

Lines changed: 32 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -13,16 +13,15 @@
1313
# under the License.
1414
from __future__ import annotations
1515

16-
from typing import Any, Dict, List, Optional, Union
16+
from typing import Any, Dict, Optional, Union
1717

1818
import jwt
19-
from jwt.exceptions import DecodeError, PyJWKClientError
19+
from jwt.exceptions import DecodeError
2020

2121
from supertokens_python.logger import log_debug_message
2222
from supertokens_python.utils import get_timestamp_ms
2323

2424
from .exceptions import raise_try_refresh_token_exception
25-
from .jwks import JWKClient, JWKSRequestError, PyJWK
2625
from .jwt import ParsedJWTInfo
2726

2827

@@ -43,47 +42,44 @@ def sanitize_number(n: Any) -> Union[Union[int, float], None]:
4342
return None
4443

4544

45+
from supertokens_python.recipe.session.jwks import get_latest_keys
46+
47+
4648
def get_info_from_access_token(
4749
jwt_info: ParsedJWTInfo,
48-
jwk_clients: List[JWKClient],
4950
do_anti_csrf_check: bool,
5051
):
5152
try:
5253
payload: Optional[Dict[str, Any]] = None
53-
client: Optional[JWKClient] = None
54-
keys: Optional[List[PyJWK]] = None
55-
56-
# Get the keys from the first available client
57-
for c in jwk_clients:
58-
try:
59-
client = c
60-
keys = c.get_latest_keys()
61-
break
62-
except JWKSRequestError:
63-
continue
64-
65-
if keys is None or client is None:
66-
raise PyJWKClientError("No key found")
67-
68-
if jwt_info.version < 3:
54+
decode_algo = (
55+
jwt_info.parsed_header["alg"]
56+
if jwt_info.parsed_header is not None
57+
else "RS256"
58+
)
59+
60+
if jwt_info.version >= 3:
61+
matching_keys = get_latest_keys(jwt_info.kid)
62+
payload = jwt.decode( # type: ignore
63+
jwt_info.raw_token_string,
64+
matching_keys[0].key, # type: ignore
65+
algorithms=[decode_algo],
66+
options={"verify_signature": True, "verify_exp": True},
67+
)
68+
else:
6969
# It won't have kid. So we'll have to try the token against all the keys from all the jwk_clients
7070
# If any of them work, we'll use that payload
71-
for k in keys:
71+
for k in get_latest_keys():
7272
try:
73-
payload = jwt.decode(jwt_info.raw_token_string, k.key, algorithms=["RS256"]) # type: ignore
73+
payload = jwt.decode( # type: ignore
74+
jwt_info.raw_token_string,
75+
k.key, # type: ignore
76+
algorithms=[decode_algo],
77+
options={"verify_signature": True, "verify_exp": True},
78+
)
7479
break
7580
except DecodeError:
7681
pass
7782

78-
elif jwt_info.version >= 3:
79-
matching_key = client.get_matching_key_from_jwt(jwt_info.raw_token_string)
80-
payload = jwt.decode( # type: ignore
81-
jwt_info.raw_token_string,
82-
matching_key.key, # type: ignore
83-
algorithms=["RS256"],
84-
options={"verify_signature": True, "verify_exp": True},
85-
)
86-
8783
if payload is None:
8884
raise DecodeError("Could not decode the token")
8985

@@ -110,7 +106,7 @@ def get_info_from_access_token(
110106
if anti_csrf_token is None and do_anti_csrf_check:
111107
raise Exception("Access token does not contain the anti-csrf token")
112108

113-
assert isinstance(expiry_time, int)
109+
assert isinstance(expiry_time, (float, int))
114110

115111
if expiry_time < get_timestamp_ms():
116112
raise Exception("Access token expired")
@@ -137,8 +133,8 @@ def validate_access_token_structure(payload: Dict[str, Any], version: int) -> No
137133
if version >= 3:
138134
if (
139135
not isinstance(payload.get("sub"), str)
140-
or not isinstance(payload.get("exp"), int)
141-
or not isinstance(payload.get("iat"), int)
136+
or not isinstance(payload.get("exp"), (int, float))
137+
or not isinstance(payload.get("iat"), (int, float))
142138
or not isinstance(payload.get("sessionHandle"), str)
143139
or not isinstance(payload.get("refreshTokenHash1"), str)
144140
):
@@ -153,8 +149,8 @@ def validate_access_token_structure(payload: Dict[str, Any], version: int) -> No
153149
not isinstance(payload.get("sessionHandle"), str)
154150
or payload.get("userData") is None
155151
or not isinstance(payload.get("refreshTokenHash1"), str)
156-
or not isinstance(payload.get("expiryTime"), int)
157-
or not isinstance(payload.get("timeCreated"), int)
152+
or not isinstance(payload.get("expiryTime"), (float, int))
153+
or not isinstance(payload.get("timeCreated"), (float, int))
158154
):
159155
log_debug_message(
160156
"validateAccessTokenStructure: Access token is using version < 3"

supertokens_python/recipe/session/constants.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,7 @@
3333

3434
available_token_transfer_methods: List[TokenTransferMethod] = ["cookie", "header"]
3535

36-
JWKCacheMaxAgeInMs = 60 * 1000 # 1min
37-
JWKRequestCooldownInMs = 500 # 0.5s
36+
JWKCacheMaxAgeInMs = 60 * 1000 # 60s
3837
protected_props = [
3938
"sub",
4039
"iat",

supertokens_python/recipe/session/interfaces.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,6 @@
3636

3737
if TYPE_CHECKING:
3838
from supertokens_python.framework import BaseRequest
39-
from .jwks import JWKClient
4039

4140
from supertokens_python.framework import BaseResponse
4241

@@ -128,8 +127,6 @@ class GetSessionTokensDangerouslyDict(TypedDict):
128127

129128

130129
class RecipeInterface(ABC): # pylint: disable=too-many-public-methods
131-
JWK_clients: List[JWKClient] = []
132-
133130
def __init__(self):
134131
pass
135132

0 commit comments

Comments
 (0)