@@ -41,16 +41,15 @@ <h1 class="title">Module <code>supertokens_python.recipe.session.access_token</c
4141# under the License.
4242from __future__ import annotations
4343
44- from typing import Any, Dict, List, Optional, Union
44+ from typing import Any, Dict, Optional, Union
4545
4646import jwt
47- from jwt.exceptions import DecodeError, PyJWKClientError
47+ from jwt.exceptions import DecodeError
4848
4949from supertokens_python.logger import log_debug_message
5050from supertokens_python.utils import get_timestamp_ms
5151
5252from .exceptions import raise_try_refresh_token_exception
53- from .jwks import JWKClient, JWKSRequestError, PyJWK
5453from .jwt import ParsedJWTInfo
5554
5655
@@ -71,47 +70,44 @@ <h1 class="title">Module <code>supertokens_python.recipe.session.access_token</c
7170 return None
7271
7372
73+ from supertokens_python.recipe.session.jwks import get_latest_keys
74+
75+
7476def get_info_from_access_token(
7577 jwt_info: ParsedJWTInfo,
76- jwk_clients: List[JWKClient],
7778 do_anti_csrf_check: bool,
7879):
7980 try:
8081 payload: Optional[Dict[str, Any]] = None
81- client: Optional[JWKClient] = None
82- keys: Optional[List[PyJWK]] = None
83-
84- # Get the keys from the first available client
85- for c in jwk_clients:
86- try:
87- client = c
88- keys = c.get_latest_keys()
89- break
90- except JWKSRequestError:
91- continue
92-
93- if keys is None or client is None:
94- raise PyJWKClientError("No key found")
95-
96- if jwt_info.version < 3:
82+ decode_algo = (
83+ jwt_info.parsed_header["alg"]
84+ if jwt_info.parsed_header is not None
85+ else "RS256"
86+ )
87+
88+ if jwt_info.version >= 3:
89+ matching_keys = get_latest_keys(jwt_info.kid)
90+ payload = jwt.decode( # type: ignore
91+ jwt_info.raw_token_string,
92+ matching_keys[0].key, # type: ignore
93+ algorithms=[decode_algo],
94+ options={"verify_signature": True, "verify_exp": True},
95+ )
96+ else:
9797 # It won't have kid. So we'll have to try the token against all the keys from all the jwk_clients
9898 # If any of them work, we'll use that payload
99- for k in keys :
99+ for k in get_latest_keys() :
100100 try:
101- payload = jwt.decode(jwt_info.raw_token_string, k.key, algorithms=["RS256"]) # type: ignore
101+ payload = jwt.decode( # type: ignore
102+ jwt_info.raw_token_string,
103+ k.key, # type: ignore
104+ algorithms=[decode_algo],
105+ options={"verify_signature": True, "verify_exp": True},
106+ )
102107 break
103108 except DecodeError:
104109 pass
105110
106- elif jwt_info.version >= 3:
107- matching_key = client.get_matching_key_from_jwt(jwt_info.raw_token_string)
108- payload = jwt.decode( # type: ignore
109- jwt_info.raw_token_string,
110- matching_key.key, # type: ignore
111- algorithms=["RS256"],
112- options={"verify_signature": True, "verify_exp": True},
113- )
114-
115111 if payload is None:
116112 raise DecodeError("Could not decode the token")
117113
@@ -138,7 +134,7 @@ <h1 class="title">Module <code>supertokens_python.recipe.session.access_token</c
138134 if anti_csrf_token is None and do_anti_csrf_check:
139135 raise Exception("Access token does not contain the anti-csrf token")
140136
141- assert isinstance(expiry_time, int)
137+ assert isinstance(expiry_time, (float, int) )
142138
143139 if expiry_time < get_timestamp_ms():
144140 raise Exception("Access token expired")
@@ -165,8 +161,8 @@ <h1 class="title">Module <code>supertokens_python.recipe.session.access_token</c
165161 if version >= 3:
166162 if (
167163 not isinstance(payload.get("sub"), str)
168- or not isinstance(payload.get("exp"), int)
169- or not isinstance(payload.get("iat"), int)
164+ or not isinstance(payload.get("exp"), ( int, float) )
165+ or not isinstance(payload.get("iat"), ( int, float) )
170166 or not isinstance(payload.get("sessionHandle"), str)
171167 or not isinstance(payload.get("refreshTokenHash1"), str)
172168 ):
@@ -181,8 +177,8 @@ <h1 class="title">Module <code>supertokens_python.recipe.session.access_token</c
181177 not isinstance(payload.get("sessionHandle"), str)
182178 or payload.get("userData") is None
183179 or not isinstance(payload.get("refreshTokenHash1"), str)
184- or not isinstance(payload.get("expiryTime"), int)
185- or not isinstance(payload.get("timeCreated"), int)
180+ or not isinstance(payload.get("expiryTime"), (float, int) )
181+ or not isinstance(payload.get("timeCreated"), (float, int) )
186182 ):
187183 log_debug_message(
188184 "validateAccessTokenStructure: Access token is using version < 3"
@@ -201,7 +197,7 @@ <h1 class="title">Module <code>supertokens_python.recipe.session.access_token</c
201197< h2 class ="section-title " id ="header-functions "> Functions</ h2 >
202198< dl >
203199< dt id ="supertokens_python.recipe.session.access_token.get_info_from_access_token "> < code class ="name flex ">
204- < span > def < span class ="ident "> get_info_from_access_token</ span > </ span > (< span > jwt_info: ParsedJWTInfo, jwk_clients: List[JWKClient], do_anti_csrf_check: bool)</ span >
200+ < span > def < span class ="ident "> get_info_from_access_token</ span > </ span > (< span > jwt_info: ParsedJWTInfo, do_anti_csrf_check: bool)</ span >
205201</ code > </ dt >
206202< dd >
207203< div class ="desc "> </ div >
@@ -211,45 +207,39 @@ <h2 class="section-title" id="header-functions">Functions</h2>
211207</ summary >
212208< pre > < code class ="python "> def get_info_from_access_token(
213209 jwt_info: ParsedJWTInfo,
214- jwk_clients: List[JWKClient],
215210 do_anti_csrf_check: bool,
216211):
217212 try:
218213 payload: Optional[Dict[str, Any]] = None
219- client: Optional[JWKClient] = None
220- keys: Optional[List[PyJWK]] = None
221-
222- # Get the keys from the first available client
223- for c in jwk_clients:
224- try:
225- client = c
226- keys = c.get_latest_keys()
227- break
228- except JWKSRequestError:
229- continue
230-
231- if keys is None or client is None:
232- raise PyJWKClientError("No key found")
233-
234- if jwt_info.version < 3:
214+ decode_algo = (
215+ jwt_info.parsed_header["alg"]
216+ if jwt_info.parsed_header is not None
217+ else "RS256"
218+ )
219+
220+ if jwt_info.version >= 3:
221+ matching_keys = get_latest_keys(jwt_info.kid)
222+ payload = jwt.decode( # type: ignore
223+ jwt_info.raw_token_string,
224+ matching_keys[0].key, # type: ignore
225+ algorithms=[decode_algo],
226+ options={"verify_signature": True, "verify_exp": True},
227+ )
228+ else:
235229 # It won't have kid. So we'll have to try the token against all the keys from all the jwk_clients
236230 # If any of them work, we'll use that payload
237- for k in keys :
231+ for k in get_latest_keys() :
238232 try:
239- payload = jwt.decode(jwt_info.raw_token_string, k.key, algorithms=["RS256"]) # type: ignore
233+ payload = jwt.decode( # type: ignore
234+ jwt_info.raw_token_string,
235+ k.key, # type: ignore
236+ algorithms=[decode_algo],
237+ options={"verify_signature": True, "verify_exp": True},
238+ )
240239 break
241240 except DecodeError:
242241 pass
243242
244- elif jwt_info.version >= 3:
245- matching_key = client.get_matching_key_from_jwt(jwt_info.raw_token_string)
246- payload = jwt.decode( # type: ignore
247- jwt_info.raw_token_string,
248- matching_key.key, # type: ignore
249- algorithms=["RS256"],
250- options={"verify_signature": True, "verify_exp": True},
251- )
252-
253243 if payload is None:
254244 raise DecodeError("Could not decode the token")
255245
@@ -276,7 +266,7 @@ <h2 class="section-title" id="header-functions">Functions</h2>
276266 if anti_csrf_token is None and do_anti_csrf_check:
277267 raise Exception("Access token does not contain the anti-csrf token")
278268
279- assert isinstance(expiry_time, int)
269+ assert isinstance(expiry_time, (float, int) )
280270
281271 if expiry_time < get_timestamp_ms():
282272 raise Exception("Access token expired")
@@ -347,8 +337,8 @@ <h2 class="section-title" id="header-functions">Functions</h2>
347337 if version >= 3:
348338 if (
349339 not isinstance(payload.get("sub"), str)
350- or not isinstance(payload.get("exp"), int)
351- or not isinstance(payload.get("iat"), int)
340+ or not isinstance(payload.get("exp"), ( int, float) )
341+ or not isinstance(payload.get("iat"), ( int, float) )
352342 or not isinstance(payload.get("sessionHandle"), str)
353343 or not isinstance(payload.get("refreshTokenHash1"), str)
354344 ):
@@ -363,8 +353,8 @@ <h2 class="section-title" id="header-functions">Functions</h2>
363353 not isinstance(payload.get("sessionHandle"), str)
364354 or payload.get("userData") is None
365355 or not isinstance(payload.get("refreshTokenHash1"), str)
366- or not isinstance(payload.get("expiryTime"), int)
367- or not isinstance(payload.get("timeCreated"), int)
356+ or not isinstance(payload.get("expiryTime"), (float, int) )
357+ or not isinstance(payload.get("timeCreated"), (float, int) )
368358 ):
369359 log_debug_message(
370360 "validateAccessTokenStructure: Access token is using version < 3"
0 commit comments