@@ -33,7 +33,7 @@ class JWKSConfigType(TypedDict):
3333
3434JWKSConfig : JWKSConfigType = {
3535 "cache_max_age" : JWKCacheMaxAgeInMs ,
36- "request_timeout" : 5000 , # 5s
36+ "request_timeout" : 10000 , # 10s
3737}
3838
3939
@@ -56,22 +56,36 @@ def reset_jwks_cache():
5656 cached_keys = None
5757
5858
59- def get_cached_keys () -> Optional [CachedKeys ]:
60- with RWLockContext (mutex , read = True ):
61- if cached_keys is not None :
62- # This means that we have valid JWKs for the given core path
63- # We check if we need to refresh before returning
59+ def get_cached_keys () -> Optional [List [PyJWK ]]:
60+ if cached_keys is not None :
61+ # This means that we have valid JWKs for the given core path
62+ # We check if we need to refresh before returning
63+
64+ # This means that the value in cache is not expired, in this case we return the cached value
65+ # Note that this also means that the SDK will not try to query any other core (if there are multiple)
66+ # if it has a valid cache entry from one of the core URLs. It will only attempt to fetch
67+ # from the cores again after the entry in the cache is expired
68+ if cached_keys .is_fresh ():
69+ if environ .get ("SUPERTOKENS_ENV" ) == "testing" :
70+ log_debug_message ("Returning JWKS from cache" )
71+ return cached_keys .keys
72+
73+ return None
74+
6475
65- # This means that the value in cache is not expired, in this case we return the cached value
66- # Note that this also means that the SDK will not try to query any other core (if there are multiple)
67- # if it has a valid cache entry from one of the core URLs. It will only attempt to fetch
68- # from the cores again after the entry in the cache is expired
69- if cached_keys .is_fresh ():
70- if environ .get ("SUPERTOKENS_ENV" ) == "testing" :
71- log_debug_message ("Returning JWKS from cache" )
72- return cached_keys
76+ def find_matching_keys (
77+ keys : Optional [List [PyJWK ]], kid : Optional [str ]
78+ ) -> Optional [List [PyJWK ]]:
79+ if kid is None or keys is None :
80+ # return all keys since the token does not have a kid
81+ return keys
7382
74- return None
83+ # kid has been provided so filter the keys
84+ matching_keys = [key for key in keys if key .key_id == kid ] # type: ignore
85+ if len (matching_keys ) > 0 :
86+ return matching_keys
87+
88+ return None
7589
7690
7791def get_latest_keys (kid : Optional [str ] = None ) -> List [PyJWK ]:
@@ -80,15 +94,9 @@ def get_latest_keys(kid: Optional[str] = None) -> List[PyJWK]:
8094 if environ .get ("SUPERTOKENS_ENV" ) == "testing" :
8195 log_debug_message ("Called find_jwk_client" )
8296
83- keys_from_cache = get_cached_keys ()
84- if keys_from_cache is not None :
85- if kid is None :
86- # return all keys since the token does not have a kid
87- return keys_from_cache .keys
88-
89- # kid has been provided so filter the keys
90- matching_keys = [key for key in keys_from_cache .keys if key .key_id == kid ] # type: ignore
91- if len (matching_keys ) > 0 :
97+ with RWLockContext (mutex , read = True ):
98+ matching_keys = find_matching_keys (get_cached_keys (), kid )
99+ if matching_keys is not None :
92100 return matching_keys
93101 # otherwise unknown kid, will continue to reload the keys
94102
@@ -104,6 +112,12 @@ def get_latest_keys(kid: Optional[str] = None) -> List[PyJWK]:
104112 last_error : Exception = Exception ("No valid JWKS found" )
105113
106114 with RWLockContext (mutex , read = False ):
115+ # check again if the keys are in cache
116+ # because another thread might have fetched the keys while this one was waiting for the lock
117+ matching_keys = find_matching_keys (get_cached_keys (), kid )
118+ if matching_keys is not None :
119+ return matching_keys
120+
107121 for path in core_paths :
108122 if environ .get ("SUPERTOKENS_ENV" ) == "testing" :
109123 log_debug_message ("Attempting to fetch JWKS from path: %s" , path )
0 commit comments