3535
3636from typing import List , Set , Union
3737
38- from .exceptions import raise_general_exception
3938from .process_state import AllowedProcessStates , ProcessState
4039from .utils import find_max_version , is_4xx_error , is_5xx_error
4140from sniffio import AsyncLibraryNotFoundError
4241from supertokens_python .async_to_sync_wrapper import create_or_get_event_loop
42+ from supertokens_python .utils import get_timestamp_ms
4343
4444
4545class Querier :
@@ -68,10 +68,13 @@ class Querier:
6868 ],
6969 ]
7070 ] = None
71+ __global_cache_tag = get_timestamp_ms ()
72+ __disable_cache = False
7173
7274 def __init__ (self , hosts : List [Host ], rid_to_core : Union [None , str ] = None ):
7375 self .__hosts = hosts
7476 self .__rid_to_core = None
77+ self .__global_cache_tag = get_timestamp_ms ()
7578 if rid_to_core is not None :
7679 self .__rid_to_core = rid_to_core
7780
@@ -80,15 +83,15 @@ def reset():
8083 if ("SUPERTOKENS_ENV" not in environ ) or (
8184 environ ["SUPERTOKENS_ENV" ] != "testing"
8285 ):
83- raise_general_exception ("calling testing function in non testing env" )
86+ raise Exception ("calling testing function in non testing env" )
8487 Querier .__init_called = False
8588
8689 @staticmethod
8790 def get_hosts_alive_for_testing ():
8891 if ("SUPERTOKENS_ENV" not in environ ) or (
8992 environ ["SUPERTOKENS_ENV" ] != "testing"
9093 ):
91- raise_general_exception ("calling testing function in non testing env" )
94+ raise Exception ("calling testing function in non testing env" )
9295 return Querier .__hosts_alive_for_testing
9396
9497 async def api_request (
@@ -100,7 +103,7 @@ async def api_request(
100103 ** kwargs : Any ,
101104 ) -> Response :
102105 if attempts_remaining == 0 :
103- raise_general_exception ("Retry request failed" )
106+ raise Exception ("Retry request failed" )
104107
105108 try :
106109 async with AsyncClient () as client :
@@ -141,7 +144,7 @@ async def f(url: str, method: str) -> Response:
141144 api_version = find_max_version (cdi_supported_by_server , SUPPORTED_CDI_VERSIONS )
142145
143146 if api_version is None :
144- raise_general_exception (
147+ raise Exception (
145148 "The running SuperTokens core version is not compatible with this python "
146149 "SDK. Please visit https://supertokens.io/docs/community/compatibility-table "
147150 "to find the right versions"
@@ -152,7 +155,7 @@ async def f(url: str, method: str) -> Response:
152155
153156 @staticmethod
154157 def get_instance (rid_to_core : Union [str , None ] = None ):
155- if ( not Querier .__init_called ) or ( Querier . __hosts is None ) :
158+ if not Querier .__init_called :
156159 raise Exception (
157160 "Please call the supertokens.init function before using SuperTokens"
158161 )
@@ -181,6 +184,7 @@ def init(
181184 ],
182185 ]
183186 ] = None ,
187+ disable_cache : bool = False ,
184188 ):
185189 if not Querier .__init_called :
186190 Querier .__init_called = True
@@ -190,6 +194,7 @@ def init(
190194 Querier .__last_tried_index = 0
191195 Querier .__hosts_alive_for_testing = set ()
192196 Querier .network_interceptor = network_interceptor
197+ Querier .__disable_cache = disable_cache
193198
194199 async def __get_headers_with_api_version (self , path : NormalisedURLPath ):
195200 headers = {API_VERSION_HEADER : await self .get_api_version ()}
@@ -211,6 +216,41 @@ async def send_get_request(
211216 async def f (url : str , method : str ) -> Response :
212217 headers = await self .__get_headers_with_api_version (path )
213218 nonlocal params
219+
220+ assert params is not None
221+
222+ # Sort the keys for deterministic order
223+ sorted_keys = sorted (params .keys ())
224+ sorted_header_keys = sorted (headers .keys ())
225+
226+ # Start with the path as the unique key
227+ unique_key = path .get_as_string_dangerous ()
228+
229+ # Append sorted params to the unique key
230+ for key in sorted_keys :
231+ value = params [key ]
232+ unique_key += f";{ key } ={ value } "
233+
234+ # Append a separator for headers
235+ unique_key += ";hdrs"
236+
237+ # Append sorted headers to the unique key
238+ for key in sorted_header_keys :
239+ value = headers [key ]
240+ unique_key += f";{ key } ={ value } "
241+
242+ if user_context is not None :
243+ if (
244+ user_context .get ("_default" , {}).get ("global_cache_tag" , - 1 )
245+ != self .__global_cache_tag
246+ ):
247+ self .invalidate_core_call_cache (user_context , False )
248+
249+ if not Querier .__disable_cache and unique_key in user_context .get (
250+ "_default" , {}
251+ ).get ("core_call_cache" , {}):
252+ return user_context ["_default" ]["core_call_cache" ][unique_key ]
253+
214254 if Querier .network_interceptor is not None :
215255 (
216256 url ,
@@ -222,14 +262,30 @@ async def f(url: str, method: str) -> Response:
222262 url , method , headers , params , {}, user_context
223263 )
224264
225- return await self .api_request (
265+ response = await self .api_request (
226266 url ,
227267 method ,
228268 2 ,
229269 headers = headers ,
230270 params = params ,
231271 )
232272
273+ if (
274+ response .status_code == 200
275+ and not Querier .__disable_cache
276+ and user_context is not None
277+ ):
278+ user_context ["_default" ] = {
279+ ** user_context .get ("_default" , {}),
280+ "core_call_cache" : {
281+ ** user_context .get ("_default" , {}).get ("core_call_cache" , {}),
282+ unique_key : response ,
283+ },
284+ "global_cache_tag" : self .__global_cache_tag ,
285+ }
286+
287+ return response
288+
233289 return await self .__send_request_helper (path , "GET" , f , len (self .__hosts ))
234290
235291 async def send_post_request (
@@ -239,6 +295,7 @@ async def send_post_request(
239295 user_context : Union [Dict [str , Any ], None ],
240296 test : bool = False ,
241297 ) -> Dict [str , Any ]:
298+ self .invalidate_core_call_cache (user_context )
242299 if data is None :
243300 data = {}
244301
@@ -280,6 +337,8 @@ async def send_delete_request(
280337 params : Union [Dict [str , Any ], None ],
281338 user_context : Union [Dict [str , Any ], None ],
282339 ) -> Dict [str , Any ]:
340+ if user_context is not None :
341+ self .invalidate_core_call_cache (user_context )
283342 if params is None :
284343 params = {}
285344
@@ -312,6 +371,7 @@ async def send_put_request(
312371 data : Union [Dict [str , Any ], None ],
313372 user_context : Union [Dict [str , Any ], None ],
314373 ) -> Dict [str , Any ]:
374+ self .invalidate_core_call_cache (user_context )
315375 if data is None :
316376 data = {}
317377
@@ -334,10 +394,29 @@ async def f(url: str, method: str) -> Response:
334394
335395 return await self .__send_request_helper (path , "PUT" , f , len (self .__hosts ))
336396
337- def get_all_core_urls_for_path (self , path : str ) -> List [str ]:
338- if self .__hosts is None :
339- return []
397+ def invalidate_core_call_cache (
398+ self ,
399+ user_context : Union [Dict [str , Any ], None ],
400+ upd_global_cache_tag_if_necessary : bool = True ,
401+ ):
402+ if user_context is None :
403+ # this is done so that the code below runs as expected.
404+ # It will reset the __global_cache_tag if needed, and the
405+ # stuff we assign to the user_context will just be ignored (as expected)
406+ user_context = {}
407+
408+ if upd_global_cache_tag_if_necessary and (
409+ user_context .get ("_default" , {}).get ("keep_cache_alive" , False ) is not True
410+ ):
411+ # there can be race conditions here, but i think we can ignore them.
412+ self .__global_cache_tag = get_timestamp_ms ()
340413
414+ user_context ["_default" ] = {
415+ ** user_context .get ("_default" , {}),
416+ "core_call_cache" : {},
417+ }
418+
419+ def get_all_core_urls_for_path (self , path : str ) -> List [str ]:
341420 normalized_path = NormalisedURLPath (path )
342421
343422 result : List [str ] = []
@@ -362,7 +441,7 @@ async def __send_request_helper(
362441 retry_info_map : Optional [Dict [str , int ]] = None ,
363442 ) -> Dict [str , Any ]:
364443 if no_of_tries == 0 :
365- raise_general_exception ("No SuperTokens core available to query" )
444+ raise Exception ("No SuperTokens core available to query" )
366445
367446 try :
368447 current_host_domain = self .__hosts [
@@ -408,7 +487,7 @@ async def __send_request_helper(
408487 )
409488
410489 if is_4xx_error (response .status_code ) or is_5xx_error (response .status_code ): # type: ignore
411- raise_general_exception (
490+ raise Exception (
412491 "SuperTokens core threw an error for a "
413492 + method
414493 + " request to path: "
@@ -432,5 +511,3 @@ async def __send_request_helper(
432511 return await self .__send_request_helper (
433512 path , method , http_function , no_of_tries - 1 , retry_info_map
434513 )
435- except Exception as e :
436- raise_general_exception (e )
0 commit comments