Skip to content

Commit da9b902

Browse files
committed
fixes a bunch of bugs
1 parent 568c913 commit da9b902

File tree

2 files changed

+250
-28
lines changed

2 files changed

+250
-28
lines changed

supertokens_python/querier.py

Lines changed: 17 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -239,20 +239,17 @@ async def f(url: str, method: str) -> Response:
239239
value = headers[key]
240240
unique_key += f";{key}={value}"
241241

242-
if user_context is not None and "_default" in user_context:
242+
if user_context is not None:
243243
if (
244-
"global_cache_tag" in user_context["_default"]
245-
and user_context["_default"]["global_cache_tag"]
244+
user_context.get("_default", {}).get("global_cache_tag", -1)
246245
!= self.__global_cache_tag
247246
):
248247
self.invalidate_core_call_cache(user_context, False)
249248

250-
if (
251-
not Querier.__disable_cache
252-
and "core_cache_call" in user_context["_default"]
253-
and unique_key in user_context["_default"]["core_cache_call"]
254-
):
255-
return user_context["_default"]["core_cache_call"][unique_key]
249+
if not Querier.__disable_cache and unique_key in user_context.get(
250+
"_default", {}
251+
).get("core_call_cache", {}):
252+
return user_context["_default"]["core_call_cache"][unique_key]
256253

257254
if Querier.network_interceptor is not None:
258255
(
@@ -280,11 +277,11 @@ async def f(url: str, method: str) -> Response:
280277
):
281278
user_context["_default"] = {
282279
**user_context.get("_default", {}),
283-
"core_cache_call": {
284-
**user_context.get("_default", {}).get("core_cache_call", {}),
280+
"core_call_cache": {
281+
**user_context.get("_default", {}).get("core_call_cache", {}),
285282
unique_key: response,
286283
},
287-
"global_cache_key": self.__global_cache_tag,
284+
"global_cache_tag": self.__global_cache_tag,
288285
}
289286

290287
return response
@@ -298,8 +295,7 @@ async def send_post_request(
298295
user_context: Union[Dict[str, Any], None],
299296
test: bool = False,
300297
) -> Dict[str, Any]:
301-
if user_context is not None:
302-
self.invalidate_core_call_cache(user_context)
298+
self.invalidate_core_call_cache(user_context)
303299
if data is None:
304300
data = {}
305301

@@ -375,8 +371,7 @@ async def send_put_request(
375371
data: Union[Dict[str, Any], None],
376372
user_context: Union[Dict[str, Any], None],
377373
) -> Dict[str, Any]:
378-
if user_context is not None:
379-
self.invalidate_core_call_cache(user_context)
374+
self.invalidate_core_call_cache(user_context)
380375
if data is None:
381376
data = {}
382377

@@ -405,22 +400,17 @@ def invalidate_core_call_cache(
405400
upd_global_cache_tag_if_necessary: bool = True,
406401
):
407402
if user_context is None:
408-
return
403+
user_context = {}
409404

410405
if upd_global_cache_tag_if_necessary and (
411-
"_default" in user_context
412-
and "keep_cache_alive" in user_context["_default"]
413-
and user_context["_default"]["keep_cache_alive"] is not True
406+
user_context.get("_default", {}).get("keep_cache_alive", False) is not True
414407
):
415408
self.__global_cache_tag = get_timestamp_ms()
416409

417-
if "_default" in user_context:
418-
user_context["_default"] = {
419-
**user_context["_default"],
420-
"core_call_cache": {},
421-
}
422-
else:
423-
user_context["_default"] = {"core_call_cache": {}}
410+
user_context["_default"] = {
411+
**user_context.get("_default", {}),
412+
"core_call_cache": {},
413+
}
424414

425415
def get_all_core_urls_for_path(self, path: str) -> List[str]:
426416
normalized_path = NormalisedURLPath(path)

tests/test_querier.py

Lines changed: 233 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,13 @@
1717
emailpassword,
1818
emailverification,
1919
dashboard,
20+
thirdparty,
2021
)
2122
from supertokens_python import InputAppInfo
22-
from supertokens_python.recipe.emailpassword.asyncio import get_user_by_id
23+
from supertokens_python.recipe.emailpassword.asyncio import get_user_by_id, sign_up
24+
from supertokens_python.recipe.thirdparty.asyncio import (
25+
get_user_by_id as tp_get_user_by_id,
26+
)
2327
import asyncio
2428
import respx
2529
import httpx
@@ -227,6 +231,7 @@ def intercept(
227231
session.init(),
228232
emailpassword.init(),
229233
dashboard.init(),
234+
thirdparty.init(),
230235
],
231236
) # type: ignore
232237
start_st()
@@ -241,3 +246,230 @@ def intercept(
241246
user = await get_user_by_id("random", user_context)
242247
assert user is None
243248
assert not called_core
249+
250+
user = await tp_get_user_by_id("random", user_context)
251+
252+
assert user is None
253+
assert called_core
254+
255+
called_core = False
256+
257+
user = await tp_get_user_by_id("random", user_context)
258+
assert user is None
259+
assert not called_core
260+
261+
user = await get_user_by_id("random", user_context)
262+
assert user is None
263+
assert not called_core
264+
265+
266+
async def test_caching_gets_clear_with_non_get():
267+
268+
called_core = False
269+
270+
def intercept(
271+
url: str,
272+
method: str,
273+
headers: Dict[str, Any],
274+
params: Optional[Dict[str, Any]],
275+
body: Optional[Dict[str, Any]],
276+
_: Optional[Dict[str, Any]],
277+
):
278+
nonlocal called_core
279+
called_core = True
280+
return url, method, headers, params, body
281+
282+
init(
283+
supertokens_config=SupertokensConfig(
284+
connection_uri="http://localhost:3567", network_interceptor=intercept
285+
),
286+
app_info=InputAppInfo(
287+
app_name="ST",
288+
api_domain="http://api.supertokens.io",
289+
website_domain="http://supertokens.io",
290+
api_base_path="/auth",
291+
),
292+
framework="fastapi",
293+
mode="asgi",
294+
recipe_list=[
295+
session.init(),
296+
emailpassword.init(),
297+
dashboard.init(),
298+
],
299+
) # type: ignore
300+
start_st()
301+
user_context: Dict[str, Any] = {}
302+
user = await get_user_by_id("random", user_context)
303+
304+
assert user is None
305+
assert called_core
306+
307+
await sign_up("public", "[email protected]", "abcd1234", user_context)
308+
309+
called_core = False
310+
311+
user = await get_user_by_id("random", user_context)
312+
assert user is None
313+
assert called_core
314+
315+
called_core = False
316+
317+
user = await get_user_by_id("random", user_context)
318+
assert user is None
319+
assert not called_core
320+
321+
322+
async def test_no_caching_if_disabled_by_user():
323+
324+
called_core = False
325+
326+
def intercept(
327+
url: str,
328+
method: str,
329+
headers: Dict[str, Any],
330+
params: Optional[Dict[str, Any]],
331+
body: Optional[Dict[str, Any]],
332+
_: Optional[Dict[str, Any]],
333+
):
334+
nonlocal called_core
335+
called_core = True
336+
return url, method, headers, params, body
337+
338+
init(
339+
supertokens_config=SupertokensConfig(
340+
connection_uri="http://localhost:3567",
341+
network_interceptor=intercept,
342+
disable_core_call_cache=True,
343+
),
344+
app_info=InputAppInfo(
345+
app_name="ST",
346+
api_domain="http://api.supertokens.io",
347+
website_domain="http://supertokens.io",
348+
api_base_path="/auth",
349+
),
350+
framework="fastapi",
351+
mode="asgi",
352+
recipe_list=[
353+
session.init(),
354+
emailpassword.init(),
355+
dashboard.init(),
356+
],
357+
) # type: ignore
358+
start_st()
359+
user_context: Dict[str, Any] = {}
360+
user = await get_user_by_id("random", user_context)
361+
362+
assert user is None
363+
assert called_core
364+
365+
called_core = False
366+
367+
user = await get_user_by_id("random", user_context)
368+
assert user is None
369+
assert called_core
370+
371+
372+
async def test_no_caching_if_headers_are_different():
373+
374+
called_core = False
375+
376+
def intercept(
377+
url: str,
378+
method: str,
379+
headers: Dict[str, Any],
380+
params: Optional[Dict[str, Any]],
381+
body: Optional[Dict[str, Any]],
382+
_: Optional[Dict[str, Any]],
383+
):
384+
nonlocal called_core
385+
called_core = True
386+
return url, method, headers, params, body
387+
388+
init(
389+
supertokens_config=SupertokensConfig(
390+
connection_uri="http://localhost:3567",
391+
network_interceptor=intercept,
392+
),
393+
app_info=InputAppInfo(
394+
app_name="ST",
395+
api_domain="http://api.supertokens.io",
396+
website_domain="http://supertokens.io",
397+
api_base_path="/auth",
398+
),
399+
framework="fastapi",
400+
mode="asgi",
401+
recipe_list=[
402+
session.init(),
403+
emailpassword.init(),
404+
dashboard.init(),
405+
thirdparty.init(),
406+
],
407+
) # type: ignore
408+
start_st()
409+
user_context: Dict[str, Any] = {}
410+
user = await get_user_by_id("random", user_context)
411+
412+
assert user is None
413+
assert called_core
414+
415+
called_core = False
416+
417+
user = await get_user_by_id("random", user_context)
418+
assert user is None
419+
assert not called_core
420+
421+
called_core = False
422+
423+
user = await tp_get_user_by_id("random", user_context)
424+
assert user is None
425+
assert called_core
426+
427+
428+
async def test_caching_gets_clear_when_query_without_user_context():
429+
430+
called_core = False
431+
432+
def intercept(
433+
url: str,
434+
method: str,
435+
headers: Dict[str, Any],
436+
params: Optional[Dict[str, Any]],
437+
body: Optional[Dict[str, Any]],
438+
_: Optional[Dict[str, Any]],
439+
):
440+
nonlocal called_core
441+
called_core = True
442+
return url, method, headers, params, body
443+
444+
init(
445+
supertokens_config=SupertokensConfig(
446+
connection_uri="http://localhost:3567", network_interceptor=intercept
447+
),
448+
app_info=InputAppInfo(
449+
app_name="ST",
450+
api_domain="http://api.supertokens.io",
451+
website_domain="http://supertokens.io",
452+
api_base_path="/auth",
453+
),
454+
framework="fastapi",
455+
mode="asgi",
456+
recipe_list=[
457+
session.init(),
458+
emailpassword.init(),
459+
dashboard.init(),
460+
],
461+
) # type: ignore
462+
start_st()
463+
user_context: Dict[str, Any] = {}
464+
user = await get_user_by_id("random", user_context)
465+
466+
assert user is None
467+
assert called_core
468+
469+
await sign_up("public", "[email protected]", "abcd1234")
470+
471+
called_core = False
472+
473+
user = await get_user_by_id("random", user_context)
474+
assert user is None
475+
assert called_core

0 commit comments

Comments
 (0)