Skip to content

Commit f817c5e

Browse files
Merge pull request #508 from supertokens/adds-caching-to-querier
adds caching for querier
2 parents dc12b28 + a6f4738 commit f817c5e

File tree

7 files changed

+599
-17
lines changed

7 files changed

+599
-17
lines changed

CHANGELOG.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
88

99
## [unreleased]
1010

11+
## [0.22.0] - 2024-06-05
12+
- Adds caching per API based on user context.
13+
14+
### Breaking change:
15+
- Changes general error in querier to normal python error.
16+
1117
## [0.21.0] - 2024-05-23
1218

1319
### Breaking change

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@
8383

8484
setup(
8585
name="supertokens_python",
86-
version="0.21.0",
86+
version="0.22.0",
8787
author="SuperTokens",
8888
license="Apache 2.0",
8989
author_email="[email protected]",

supertokens_python/constants.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from __future__ import annotations
1515

1616
SUPPORTED_CDI_VERSIONS = ["3.0"]
17-
VERSION = "0.21.0"
17+
VERSION = "0.22.0"
1818
TELEMETRY = "/telemetry"
1919
USER_COUNT = "/users/count"
2020
USER_DELETE = "/user/remove"

supertokens_python/querier.py

Lines changed: 91 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -35,11 +35,11 @@
3535

3636
from typing import List, Set, Union
3737

38-
from .exceptions import raise_general_exception
3938
from .process_state import AllowedProcessStates, ProcessState
4039
from .utils import find_max_version, is_4xx_error, is_5xx_error
4140
from sniffio import AsyncLibraryNotFoundError
4241
from supertokens_python.async_to_sync_wrapper import create_or_get_event_loop
42+
from supertokens_python.utils import get_timestamp_ms
4343

4444

4545
class Querier:
@@ -68,10 +68,13 @@ class Querier:
6868
],
6969
]
7070
] = None
71+
__global_cache_tag = get_timestamp_ms()
72+
__disable_cache = False
7173

7274
def __init__(self, hosts: List[Host], rid_to_core: Union[None, str] = None):
7375
self.__hosts = hosts
7476
self.__rid_to_core = None
77+
self.__global_cache_tag = get_timestamp_ms()
7578
if rid_to_core is not None:
7679
self.__rid_to_core = rid_to_core
7780

@@ -80,15 +83,15 @@ def reset():
8083
if ("SUPERTOKENS_ENV" not in environ) or (
8184
environ["SUPERTOKENS_ENV"] != "testing"
8285
):
83-
raise_general_exception("calling testing function in non testing env")
86+
raise Exception("calling testing function in non testing env")
8487
Querier.__init_called = False
8588

8689
@staticmethod
8790
def get_hosts_alive_for_testing():
8891
if ("SUPERTOKENS_ENV" not in environ) or (
8992
environ["SUPERTOKENS_ENV"] != "testing"
9093
):
91-
raise_general_exception("calling testing function in non testing env")
94+
raise Exception("calling testing function in non testing env")
9295
return Querier.__hosts_alive_for_testing
9396

9497
async def api_request(
@@ -100,7 +103,7 @@ async def api_request(
100103
**kwargs: Any,
101104
) -> Response:
102105
if attempts_remaining == 0:
103-
raise_general_exception("Retry request failed")
106+
raise Exception("Retry request failed")
104107

105108
try:
106109
async with AsyncClient() as client:
@@ -141,7 +144,7 @@ async def f(url: str, method: str) -> Response:
141144
api_version = find_max_version(cdi_supported_by_server, SUPPORTED_CDI_VERSIONS)
142145

143146
if api_version is None:
144-
raise_general_exception(
147+
raise Exception(
145148
"The running SuperTokens core version is not compatible with this python "
146149
"SDK. Please visit https://supertokens.io/docs/community/compatibility-table "
147150
"to find the right versions"
@@ -152,7 +155,7 @@ async def f(url: str, method: str) -> Response:
152155

153156
@staticmethod
154157
def get_instance(rid_to_core: Union[str, None] = None):
155-
if (not Querier.__init_called) or (Querier.__hosts is None):
158+
if not Querier.__init_called:
156159
raise Exception(
157160
"Please call the supertokens.init function before using SuperTokens"
158161
)
@@ -181,6 +184,7 @@ def init(
181184
],
182185
]
183186
] = None,
187+
disable_cache: bool = False,
184188
):
185189
if not Querier.__init_called:
186190
Querier.__init_called = True
@@ -190,6 +194,7 @@ def init(
190194
Querier.__last_tried_index = 0
191195
Querier.__hosts_alive_for_testing = set()
192196
Querier.network_interceptor = network_interceptor
197+
Querier.__disable_cache = disable_cache
193198

194199
async def __get_headers_with_api_version(self, path: NormalisedURLPath):
195200
headers = {API_VERSION_HEADER: await self.get_api_version()}
@@ -211,6 +216,41 @@ async def send_get_request(
211216
async def f(url: str, method: str) -> Response:
212217
headers = await self.__get_headers_with_api_version(path)
213218
nonlocal params
219+
220+
assert params is not None
221+
222+
# Sort the keys for deterministic order
223+
sorted_keys = sorted(params.keys())
224+
sorted_header_keys = sorted(headers.keys())
225+
226+
# Start with the path as the unique key
227+
unique_key = path.get_as_string_dangerous()
228+
229+
# Append sorted params to the unique key
230+
for key in sorted_keys:
231+
value = params[key]
232+
unique_key += f";{key}={value}"
233+
234+
# Append a separator for headers
235+
unique_key += ";hdrs"
236+
237+
# Append sorted headers to the unique key
238+
for key in sorted_header_keys:
239+
value = headers[key]
240+
unique_key += f";{key}={value}"
241+
242+
if user_context is not None:
243+
if (
244+
user_context.get("_default", {}).get("global_cache_tag", -1)
245+
!= self.__global_cache_tag
246+
):
247+
self.invalidate_core_call_cache(user_context, False)
248+
249+
if not Querier.__disable_cache and unique_key in user_context.get(
250+
"_default", {}
251+
).get("core_call_cache", {}):
252+
return user_context["_default"]["core_call_cache"][unique_key]
253+
214254
if Querier.network_interceptor is not None:
215255
(
216256
url,
@@ -222,14 +262,30 @@ async def f(url: str, method: str) -> Response:
222262
url, method, headers, params, {}, user_context
223263
)
224264

225-
return await self.api_request(
265+
response = await self.api_request(
226266
url,
227267
method,
228268
2,
229269
headers=headers,
230270
params=params,
231271
)
232272

273+
if (
274+
response.status_code == 200
275+
and not Querier.__disable_cache
276+
and user_context is not None
277+
):
278+
user_context["_default"] = {
279+
**user_context.get("_default", {}),
280+
"core_call_cache": {
281+
**user_context.get("_default", {}).get("core_call_cache", {}),
282+
unique_key: response,
283+
},
284+
"global_cache_tag": self.__global_cache_tag,
285+
}
286+
287+
return response
288+
233289
return await self.__send_request_helper(path, "GET", f, len(self.__hosts))
234290

235291
async def send_post_request(
@@ -239,6 +295,7 @@ async def send_post_request(
239295
user_context: Union[Dict[str, Any], None],
240296
test: bool = False,
241297
) -> Dict[str, Any]:
298+
self.invalidate_core_call_cache(user_context)
242299
if data is None:
243300
data = {}
244301

@@ -280,6 +337,8 @@ async def send_delete_request(
280337
params: Union[Dict[str, Any], None],
281338
user_context: Union[Dict[str, Any], None],
282339
) -> Dict[str, Any]:
340+
if user_context is not None:
341+
self.invalidate_core_call_cache(user_context)
283342
if params is None:
284343
params = {}
285344

@@ -312,6 +371,7 @@ async def send_put_request(
312371
data: Union[Dict[str, Any], None],
313372
user_context: Union[Dict[str, Any], None],
314373
) -> Dict[str, Any]:
374+
self.invalidate_core_call_cache(user_context)
315375
if data is None:
316376
data = {}
317377

@@ -334,10 +394,29 @@ async def f(url: str, method: str) -> Response:
334394

335395
return await self.__send_request_helper(path, "PUT", f, len(self.__hosts))
336396

337-
def get_all_core_urls_for_path(self, path: str) -> List[str]:
338-
if self.__hosts is None:
339-
return []
397+
def invalidate_core_call_cache(
398+
self,
399+
user_context: Union[Dict[str, Any], None],
400+
upd_global_cache_tag_if_necessary: bool = True,
401+
):
402+
if user_context is None:
403+
# this is done so that the code below runs as expected.
404+
# It will reset the __global_cache_tag if needed, and the
405+
# stuff we assign to the user_context will just be ignored (as expected)
406+
user_context = {}
407+
408+
if upd_global_cache_tag_if_necessary and (
409+
user_context.get("_default", {}).get("keep_cache_alive", False) is not True
410+
):
411+
# there can be race conditions here, but i think we can ignore them.
412+
self.__global_cache_tag = get_timestamp_ms()
340413

414+
user_context["_default"] = {
415+
**user_context.get("_default", {}),
416+
"core_call_cache": {},
417+
}
418+
419+
def get_all_core_urls_for_path(self, path: str) -> List[str]:
341420
normalized_path = NormalisedURLPath(path)
342421

343422
result: List[str] = []
@@ -362,7 +441,7 @@ async def __send_request_helper(
362441
retry_info_map: Optional[Dict[str, int]] = None,
363442
) -> Dict[str, Any]:
364443
if no_of_tries == 0:
365-
raise_general_exception("No SuperTokens core available to query")
444+
raise Exception("No SuperTokens core available to query")
366445

367446
try:
368447
current_host_domain = self.__hosts[
@@ -408,7 +487,7 @@ async def __send_request_helper(
408487
)
409488

410489
if is_4xx_error(response.status_code) or is_5xx_error(response.status_code): # type: ignore
411-
raise_general_exception(
490+
raise Exception(
412491
"SuperTokens core threw an error for a "
413492
+ method
414493
+ " request to path: "
@@ -432,5 +511,3 @@ async def __send_request_helper(
432511
return await self.__send_request_helper(
433512
path, method, http_function, no_of_tries - 1, retry_info_map
434513
)
435-
except Exception as e:
436-
raise_general_exception(e)

supertokens_python/supertokens.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,10 +86,12 @@ def __init__(
8686
],
8787
]
8888
] = None,
89+
disable_core_call_cache: bool = False,
8990
): # We keep this = None here because this is directly used by the user.
9091
self.connection_uri = connection_uri
9192
self.api_key = api_key
9293
self.network_interceptor = network_interceptor
94+
self.disable_core_call_cache = disable_core_call_cache
9395

9496

9597
class Host:
@@ -243,7 +245,10 @@ def __init__(
243245
)
244246
)
245247
Querier.init(
246-
hosts, supertokens_config.api_key, supertokens_config.network_interceptor
248+
hosts,
249+
supertokens_config.api_key,
250+
supertokens_config.network_interceptor,
251+
supertokens_config.disable_core_call_cache,
247252
)
248253

249254
if len(recipe_list) == 0:

supertokens_python/utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -250,6 +250,7 @@ def set_request_in_user_context_if_not_defined(
250250

251251
if isinstance(user_context["_default"], dict):
252252
user_context["_default"]["request"] = request
253+
user_context["_default"]["keep_cache_alive"] = True
253254

254255
return user_context
255256

0 commit comments

Comments
 (0)