1313# under the License.
1414from __future__ import annotations
1515
16- from typing import Any , Dict , List , Optional , Union
16+ from typing import Any , Dict , Optional , Union
1717
1818import jwt
19- from jwt .exceptions import DecodeError , PyJWKClientError
19+ from jwt .exceptions import DecodeError
2020
2121from supertokens_python .logger import log_debug_message
2222from supertokens_python .utils import get_timestamp_ms
2323
2424from .exceptions import raise_try_refresh_token_exception
25- from .jwks import JWKClient , JWKSRequestError , PyJWK
2625from .jwt import ParsedJWTInfo
2726
2827
@@ -43,47 +42,44 @@ def sanitize_number(n: Any) -> Union[Union[int, float], None]:
4342 return None
4443
4544
45+ from supertokens_python .recipe .session .jwks import get_latest_keys
46+
47+
4648def get_info_from_access_token (
4749 jwt_info : ParsedJWTInfo ,
48- jwk_clients : List [JWKClient ],
4950 do_anti_csrf_check : bool ,
5051):
5152 try :
5253 payload : Optional [Dict [str , Any ]] = None
53- client : Optional [JWKClient ] = None
54- keys : Optional [List [PyJWK ]] = None
55-
56- # Get the keys from the first available client
57- for c in jwk_clients :
58- try :
59- client = c
60- keys = c .get_latest_keys ()
61- break
62- except JWKSRequestError :
63- continue
64-
65- if keys is None or client is None :
66- raise PyJWKClientError ("No key found" )
67-
68- if jwt_info .version < 3 :
54+ decode_algo = (
55+ jwt_info .parsed_header ["alg" ]
56+ if jwt_info .parsed_header is not None
57+ else "RS256"
58+ )
59+
60+ if jwt_info .version >= 3 :
61+ matching_keys = get_latest_keys (jwt_info .kid )
62+ payload = jwt .decode ( # type: ignore
63+ jwt_info .raw_token_string ,
64+ matching_keys [0 ].key , # type: ignore
65+ algorithms = [decode_algo ],
66+ options = {"verify_signature" : True , "verify_exp" : True },
67+ )
68+ else :
6969 # It won't have kid. So we'll have to try the token against all the keys from all the jwk_clients
7070 # If any of them work, we'll use that payload
71- for k in keys :
71+ for k in get_latest_keys () :
7272 try :
73- payload = jwt .decode (jwt_info .raw_token_string , k .key , algorithms = ["RS256" ]) # type: ignore
73+ payload = jwt .decode ( # type: ignore
74+ jwt_info .raw_token_string ,
75+ k .key , # type: ignore
76+ algorithms = [decode_algo ],
77+ options = {"verify_signature" : True , "verify_exp" : True },
78+ )
7479 break
7580 except DecodeError :
7681 pass
7782
78- elif jwt_info .version >= 3 :
79- matching_key = client .get_matching_key_from_jwt (jwt_info .raw_token_string )
80- payload = jwt .decode ( # type: ignore
81- jwt_info .raw_token_string ,
82- matching_key .key , # type: ignore
83- algorithms = ["RS256" ],
84- options = {"verify_signature" : True , "verify_exp" : True },
85- )
86-
8783 if payload is None :
8884 raise DecodeError ("Could not decode the token" )
8985
@@ -110,7 +106,7 @@ def get_info_from_access_token(
110106 if anti_csrf_token is None and do_anti_csrf_check :
111107 raise Exception ("Access token does not contain the anti-csrf token" )
112108
113- assert isinstance (expiry_time , int )
109+ assert isinstance (expiry_time , ( float , int ) )
114110
115111 if expiry_time < get_timestamp_ms ():
116112 raise Exception ("Access token expired" )
@@ -137,8 +133,8 @@ def validate_access_token_structure(payload: Dict[str, Any], version: int) -> No
137133 if version >= 3 :
138134 if (
139135 not isinstance (payload .get ("sub" ), str )
140- or not isinstance (payload .get ("exp" ), int )
141- or not isinstance (payload .get ("iat" ), int )
136+ or not isinstance (payload .get ("exp" ), ( int , float ) )
137+ or not isinstance (payload .get ("iat" ), ( int , float ) )
142138 or not isinstance (payload .get ("sessionHandle" ), str )
143139 or not isinstance (payload .get ("refreshTokenHash1" ), str )
144140 ):
@@ -153,8 +149,8 @@ def validate_access_token_structure(payload: Dict[str, Any], version: int) -> No
153149 not isinstance (payload .get ("sessionHandle" ), str )
154150 or payload .get ("userData" ) is None
155151 or not isinstance (payload .get ("refreshTokenHash1" ), str )
156- or not isinstance (payload .get ("expiryTime" ), int )
157- or not isinstance (payload .get ("timeCreated" ), int )
152+ or not isinstance (payload .get ("expiryTime" ), ( float , int ) )
153+ or not isinstance (payload .get ("timeCreated" ), ( float , int ) )
158154 ):
159155 log_debug_message (
160156 "validateAccessTokenStructure: Access token is using version < 3"
0 commit comments