|
1 | | -import json |
| 1 | +# Copyright (c) 2023, VRAI Labs and/or its affiliates. All rights reserved. |
| 2 | +# |
| 3 | +# This software is licensed under the Apache License, Version 2.0 (the |
| 4 | +# "License") as published by the Apache Software Foundation. |
| 5 | +# |
| 6 | +# You may not use this file except in compliance with the License. You may |
| 7 | +# obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT |
| 11 | +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the |
| 12 | +# License for the specific language governing permissions and limitations |
| 13 | +# under the License. |
| 14 | + |
2 | 15 | import requests |
| 16 | +from os import environ |
3 | 17 | from typing import List, Optional |
| 18 | +from typing_extensions import TypedDict |
4 | 19 |
|
5 | 20 | from jwt import PyJWK, PyJWKSet |
6 | | -from jwt.api_jwt import decode_complete as decode_token # type: ignore |
7 | | - |
8 | | -from supertokens_python.utils import get_timestamp_ms |
9 | 21 |
|
10 | 22 | from .constants import JWKCacheMaxAgeInMs, JWKRequestCooldownInMs |
11 | 23 |
|
| 24 | +from supertokens_python.utils import RWMutex, RWLockContext, get_timestamp_ms |
| 25 | +from supertokens_python.querier import Querier |
12 | 26 | from supertokens_python.logger import log_debug_message |
13 | 27 |
|
14 | 28 |
|
15 | | -class JWKClient: |
16 | | - def __init__( |
17 | | - self, |
18 | | - uri: str, |
19 | | - cooldown_duration: int = JWKRequestCooldownInMs, |
20 | | - cache_max_age: int = JWKCacheMaxAgeInMs, |
21 | | - ): |
22 | | - """A client for retrieving JSON Web Key Sets (JWKS) from a given URI. |
23 | | -
|
24 | | - Args: |
25 | | - uri (str): The URI of the JWKS. |
26 | | - cooldown_duration (int, optional): The cooldown duration in ms. Defaults to 500 seconds. |
27 | | - cache_max_age (int, optional): The cache max age in ms. Defaults to 5 minutes. |
28 | | -
|
29 | | - Note: The JSON Web Key Set is fetched when no key matches the selection |
30 | | - process but only as frequently as the `self.cooldown_duration` option |
31 | | - allows to prevent abuse. The `self.cache_max_age` option is used to |
32 | | - determine how long the JWKS is cached for. |
33 | | -
|
34 | | - Whenever you make a call to `get_signing_key_from_jwt`, the JWKS |
35 | | - is fetched if it is older than `self.cache_max_age` ms unless |
36 | | - cooldown is active. |
37 | | - """ |
38 | | - self.uri = uri |
39 | | - self.cooldown_duration = cooldown_duration |
40 | | - self.cache_max_age = cache_max_age |
41 | | - self.timeout_sec = 5 |
42 | | - self.last_fetch_time: int = 0 |
43 | | - self.cached_jwks: Optional[PyJWKSet] = None |
44 | | - |
45 | | - def fetch(self): |
46 | | - try: |
47 | | - log_debug_message("Fetching jwk set from the configured uri") |
48 | | - with requests.get(self.uri, timeout=self.timeout_sec) as response: |
49 | | - response.raise_for_status() |
50 | | - self.cached_jwks = PyJWKSet.from_dict(json.load(response)) # type: ignore |
51 | | - self.last_fetch_time = get_timestamp_ms() |
52 | | - except requests.HTTPError: |
53 | | - raise JWKSRequestError("Failed to fetch jwk set from the configured uri") |
54 | | - |
55 | | - def is_cooling_down(self) -> bool: |
56 | | - return (self.last_fetch_time > 0) and ( |
57 | | - get_timestamp_ms() - self.last_fetch_time < self.cooldown_duration |
58 | | - ) |
| 29 | +class JWKSConfigType(TypedDict): |
| 30 | + cache_max_age: int |
| 31 | + refresh_rate_limit: int |
| 32 | + request_timeout: int |
| 33 | + |
| 34 | + |
| 35 | +JWKSConfig: JWKSConfigType = { |
| 36 | + "cache_max_age": JWKCacheMaxAgeInMs, |
| 37 | + "refresh_rate_limit": JWKRequestCooldownInMs, # FIXME: Not used |
| 38 | + "request_timeout": 5000, # 5s |
| 39 | +} |
59 | 40 |
|
60 | | - def is_fresh(self) -> bool: |
61 | | - return (self.last_fetch_time > 0) and ( |
62 | | - get_timestamp_ms() - self.last_fetch_time < self.cache_max_age |
| 41 | + |
| 42 | +class CachedKeys: |
| 43 | + def __init__(self, path: str, keys: List[PyJWK]): |
| 44 | + self.path = path |
| 45 | + self.keys = keys |
| 46 | + self.last_refresh_time = get_timestamp_ms() |
| 47 | + |
| 48 | + def is_fresh(self): |
| 49 | + return ( |
| 50 | + get_timestamp_ms() - self.last_refresh_time < JWKSConfig["cache_max_age"] |
63 | 51 | ) |
64 | 52 |
|
65 | | - def get_latest_keys(self) -> List[PyJWK]: |
66 | | - if self.cached_jwks is None or not self.is_fresh(): |
67 | | - self.fetch() |
68 | 53 |
|
69 | | - if self.cached_jwks is None: |
70 | | - raise JWKSRequestError("Failed to fetch the latest keys") |
| 54 | +cached_keys: Optional[CachedKeys] = None |
| 55 | +mutex = RWMutex() |
71 | 56 |
|
72 | | - all_keys: List[PyJWK] = self.cached_jwks.keys # type: ignore |
| 57 | +# only for testing purposes |
| 58 | +def reset_jwks_cache(): |
| 59 | + with RWLockContext(mutex, read=False): |
| 60 | + global cached_keys |
| 61 | + cached_keys = None |
73 | 62 |
|
74 | | - return all_keys |
75 | 63 |
|
76 | | - def get_matching_key_from_jwt(self, token: str) -> PyJWK: |
77 | | - header = decode_token(token, options={"verify_signature": False})["header"] |
78 | | - kid: str = header["kid"] # type: ignore |
| 64 | +def get_cached_keys() -> Optional[CachedKeys]: |
| 65 | + with RWLockContext(mutex, read=True): |
| 66 | + if cached_keys is not None: |
| 67 | + # This means that we have valid JWKs for the given core path |
| 68 | + # We check if we need to refresh before returning |
79 | 69 |
|
80 | | - if self.cached_jwks is None or not self.is_fresh(): |
81 | | - self.fetch() |
| 70 | + # This means that the value in cache is not expired, in this case we return the cached value |
| 71 | + # Note that this also means that the SDK will not try to query any other core (if there are multiple) |
| 72 | + # if it has a valid cache entry from one of the core URLs. It will only attempt to fetch |
| 73 | + # from the cores again after the entry in the cache is expired |
| 74 | + if cached_keys.is_fresh(): |
| 75 | + if environ.get("SUPERTOKENS_ENV") == "testing": |
| 76 | + log_debug_message("Returning JWKS from cache") |
| 77 | + return cached_keys |
82 | 78 |
|
83 | | - assert self.cached_jwks is not None |
| 79 | + return None |
84 | 80 |
|
85 | | - try: |
86 | | - return self.cached_jwks[kid] # type: ignore |
87 | | - except KeyError: |
88 | | - if not self.is_cooling_down(): |
89 | | - # One more attempt to fetch the latest keys |
90 | | - # and then try to find the key again. |
91 | | - self.fetch() |
92 | | - try: |
93 | | - return self.cached_jwks[kid] # type: ignore |
94 | | - except KeyError: |
95 | | - pass |
96 | | - except Exception: |
97 | | - raise JWKSKeyNotFoundError("No key found for the given kid") |
98 | 81 |
|
99 | | - raise JWKSKeyNotFoundError("No key found for the given kid") |
| 82 | +def get_latest_keys(kid: Optional[str] = None) -> List[PyJWK]: |
| 83 | + global cached_keys |
100 | 84 |
|
| 85 | + if environ.get("SUPERTOKENS_ENV") == "testing": |
| 86 | + log_debug_message("Called find_jwk_client") |
101 | 87 |
|
102 | | -class JWKSKeyNotFoundError(Exception): |
103 | | - pass |
| 88 | + keys_from_cache = get_cached_keys() |
| 89 | + if keys_from_cache is not None: |
| 90 | + if kid is None: |
| 91 | + # return all keys since the token does not have a kid |
| 92 | + return keys_from_cache.keys |
| 93 | + |
| 94 | + # kid has been provided so filter the keys |
| 95 | + matching_keys = [key for key in keys_from_cache.keys if key.key_id == kid] # type: ignore |
| 96 | + if len(matching_keys) > 0: |
| 97 | + return matching_keys |
| 98 | + # otherwise unknown kid, will continue to reload the keys |
| 99 | + |
| 100 | + core_paths = Querier.get_instance().get_all_core_urls_for_path( |
| 101 | + "./.well-known/jwks.json" |
| 102 | + ) |
| 103 | + |
| 104 | + if len(core_paths) == 0: |
| 105 | + raise Exception( |
| 106 | + "No SuperTokens core available to query. Please pass supertokens > connection_uri to the init function, or override all the functions of the recipe you are using." |
| 107 | + ) |
| 108 | + |
| 109 | + last_error: Exception = Exception("No valid JWKS found") |
| 110 | + |
| 111 | + with RWLockContext(mutex, read=False): |
| 112 | + for path in core_paths: |
| 113 | + if environ.get("SUPERTOKENS_ENV") == "testing": |
| 114 | + log_debug_message("Attempting to fetch JWKS from path: %s", path) |
| 115 | + |
| 116 | + cached_jwks: Optional[List[PyJWK]] = None |
| 117 | + try: |
| 118 | + log_debug_message("Fetching jwk set from the configured uri") |
| 119 | + with requests.get(path, timeout=JWKSConfig['request_timeout']/1000) as response: # 5 second timeout |
| 120 | + response.raise_for_status() |
| 121 | + cached_jwks = PyJWKSet.from_dict(response.json()).keys # type: ignore |
| 122 | + except Exception as e: |
| 123 | + last_error = e |
| 124 | + |
| 125 | + if cached_jwks is not None: # we found a valid JWKS |
| 126 | + cached_keys = CachedKeys(path, cached_jwks) |
| 127 | + log_debug_message("Returning JWKS from fetch") |
| 128 | + return cached_keys.keys |
| 129 | + |
| 130 | + raise last_error |
104 | 131 |
|
105 | 132 |
|
106 | 133 | class JWKSRequestError(Exception): |
|
0 commit comments