Skip to content

Commit 8c871c3

Browse files
committed
test fixes
1 parent dc8c897 commit 8c871c3

File tree

5 files changed

+156
-23
lines changed

5 files changed

+156
-23
lines changed

tests/emailpassword/test_emailverify.py

Lines changed: 2 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,6 @@
5656
from supertokens_python.recipe.session.constants import ANTI_CSRF_HEADER_KEY
5757
from supertokens_python.utils import (
5858
is_version_gte,
59-
set_request_in_user_context_if_not_defined,
6059
)
6160
from tests.utils import (
6261
TEST_ACCESS_TOKEN_MAX_AGE_CONFIG_KEY,
@@ -1326,19 +1325,14 @@ async def send_email(
13261325
nonlocal email_verify_link
13271326
email_verify_link = template_vars.email_verify_link
13281327

1329-
def get_origin(
1330-
req: Optional[BaseRequest], user_context: Optional[Dict[str, Any]]
1331-
) -> str:
1332-
if req is not None:
1333-
set_request_in_user_context_if_not_defined(user_context, req)
1334-
return user_context["url"] # type: ignore
1328+
def get_origin(_: Optional[BaseRequest], user_context: Dict[str, Any]) -> str:
1329+
return user_context["url"]
13351330

13361331
init(
13371332
supertokens_config=SupertokensConfig("http://localhost:3567"),
13381333
app_info=InputAppInfo(
13391334
app_name="SuperTokens Demo",
13401335
api_domain="http://api.supertokens.io",
1341-
website_domain=None,
13421336
origin=get_origin,
13431337
api_base_path="/auth",
13441338
),
@@ -1354,11 +1348,6 @@ def get_origin(
13541348
)
13551349
start_st()
13561350

1357-
version = await Querier.get_instance().get_api_version()
1358-
if not is_version_gte(version, "2.9"):
1359-
# If the version less than 2.9, the recipe doesn't exist. So skip the test
1360-
skip()
1361-
13621351
response_1 = sign_up_request(driver_config_client, "[email protected]", "testPass123")
13631352
assert response_1.status_code == 200
13641353
dict_response = json.loads(response_1.text)

tests/emailpassword/test_passwordreset.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -409,8 +409,10 @@ async def test_reset_password_link_uses_correct_origin(
409409
password_reset_url = ""
410410

411411
def get_origin(req: Optional[BaseRequest], _: Optional[Dict[str, Any]]) -> str:
412-
if req is not None and req.get_header("origin") is not None:
413-
return req.get_header("origin") # type: ignore
412+
if req is not None:
413+
value = req.get_header("origin")
414+
if value is not None:
415+
return value
414416
return "localhost:3000"
415417

416418
class CustomEmailService(

tests/passwordless/test_emaildelivery.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -540,14 +540,16 @@ async def send_email_override(
540540

541541

542542
@mark.asyncio
543-
async def test_reset_password_link_uses_correct_origin(
543+
async def test_magic_link_uses_correct_origin(
544544
driver_config_client: TestClient,
545545
):
546546
login_url = ""
547547

548-
def get_origin(req: Optional[BaseRequest], _: Optional[Dict[str, Any]]) -> str:
549-
if req is not None and req.get_header("origin") is not None:
550-
return req.get_header("origin") # type: ignore
548+
def get_origin(req: Optional[BaseRequest], _: Dict[str, Any]) -> str:
549+
if req is not None:
550+
value = req.get_header("origin")
551+
if value is not None:
552+
return value
551553
return "localhost:3000"
552554

553555
class CustomEmailService(

tests/test_config.py

Lines changed: 139 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,14 @@
1313
# under the License.
1414
from pytest import mark
1515
from unittest.mock import MagicMock
16-
from supertokens_python import InputAppInfo, SupertokensConfig, init
16+
from supertokens_python import InputAppInfo, SupertokensConfig, init, Supertokens
1717
from supertokens_python.normalised_url_domain import NormalisedURLDomain
1818
from supertokens_python.normalised_url_path import NormalisedURLPath
1919
from supertokens_python.recipe import session
2020
from supertokens_python.recipe.session import SessionRecipe
2121
from supertokens_python.recipe.session.asyncio import create_new_session
22+
from typing import Optional, Dict, Any
23+
from supertokens_python.framework import BaseRequest
2224

2325
from tests.utils import clean_st, reset, setup_st, start_st
2426

@@ -814,3 +816,139 @@ async def test_cookie_samesite_with_ec2_public_url():
814816
assert SessionRecipe.get_instance().config.cookie_domain is None
815817
assert SessionRecipe.get_instance().config.get_cookie_same_site(None, {}) == "lax"
816818
assert SessionRecipe.get_instance().config.cookie_secure is False
819+
820+
821+
@mark.asyncio
822+
async def test_samesite_explicit_config():
823+
init(
824+
supertokens_config=SupertokensConfig("http://localhost:3567"),
825+
app_info=InputAppInfo(
826+
app_name="SuperTokens Demo",
827+
origin="http://localhost:3000",
828+
api_domain="http://localhost:3001",
829+
),
830+
framework="fastapi",
831+
recipe_list=[
832+
session.init(
833+
cookie_same_site="strict",
834+
)
835+
],
836+
)
837+
assert (
838+
SessionRecipe.get_instance().config.get_cookie_same_site(None, {}) == "strict"
839+
)
840+
841+
842+
@mark.asyncio
843+
async def test_that_exception_is_thrown_if_website_domain_and_origin_are_not_passed():
844+
try:
845+
init(
846+
supertokens_config=SupertokensConfig("http://localhost:3567"),
847+
app_info=InputAppInfo(
848+
app_name="SuperTokens Demo",
849+
api_domain="http://localhost:3001",
850+
),
851+
framework="fastapi",
852+
recipe_list=[session.init()],
853+
)
854+
except Exception as e:
855+
assert str(e) == "Please provide at least one of website_domain or origin"
856+
else:
857+
assert False, "Exception not thrown"
858+
859+
860+
@mark.asyncio
861+
async def test_that_init_works_fine_when_using_origin_string():
862+
init(
863+
supertokens_config=SupertokensConfig("http://localhost:3567"),
864+
app_info=InputAppInfo(
865+
app_name="SuperTokens Demo",
866+
api_domain="http://localhost:3001",
867+
origin="localhost:3000",
868+
),
869+
framework="fastapi",
870+
recipe_list=[session.init()],
871+
)
872+
873+
assert (
874+
Supertokens.get_instance()
875+
.app_info.get_origin(None, {})
876+
.get_as_string_dangerous()
877+
== "http://localhost:3000"
878+
)
879+
880+
881+
@mark.asyncio
882+
async def test_that_init_works_fine_when_using_website_domain_string():
883+
init(
884+
supertokens_config=SupertokensConfig("http://localhost:3567"),
885+
app_info=InputAppInfo(
886+
app_name="SuperTokens Demo",
887+
api_domain="http://localhost:3001",
888+
website_domain="localhost:3000",
889+
),
890+
framework="fastapi",
891+
recipe_list=[session.init()],
892+
)
893+
894+
assert (
895+
Supertokens.get_instance()
896+
.app_info.get_origin(None, {})
897+
.get_as_string_dangerous()
898+
== "http://localhost:3000"
899+
)
900+
901+
902+
@mark.asyncio
903+
async def test_that_init_works_fine_when_using_origin_function():
904+
def get_origin(_: Optional[BaseRequest], user_context: Dict[str, Any]) -> str:
905+
if "input" in user_context:
906+
return user_context["input"]
907+
return "localhost:3000"
908+
909+
init(
910+
supertokens_config=SupertokensConfig("http://localhost:3567"),
911+
app_info=InputAppInfo(
912+
app_name="SuperTokens Demo",
913+
api_domain="http://localhost:3001",
914+
origin=get_origin,
915+
),
916+
framework="fastapi",
917+
recipe_list=[session.init()],
918+
)
919+
920+
assert (
921+
Supertokens.get_instance()
922+
.app_info.get_origin(None, {"input": "localhost:1000"})
923+
.get_as_string_dangerous()
924+
== "http://localhost:1000"
925+
)
926+
927+
assert (
928+
Supertokens.get_instance()
929+
.app_info.get_origin(None, {})
930+
.get_as_string_dangerous()
931+
== "http://localhost:3000"
932+
)
933+
934+
935+
@mark.asyncio
936+
async def test_that_init_chooses_origin_over_website_domain():
937+
init(
938+
supertokens_config=SupertokensConfig("http://localhost:3567"),
939+
app_info=InputAppInfo(
940+
app_name="SuperTokens Demo",
941+
api_domain="http://localhost:3001",
942+
website_domain="localhost:3000",
943+
origin="supertokens.io",
944+
),
945+
framework="fastapi",
946+
recipe_list=[session.init()],
947+
)
948+
949+
assert (
950+
Supertokens.get_instance()
951+
.app_info.get_origin(None, {})
952+
.get_as_string_dangerous()
953+
== "https://supertokens.io"
954+
)

tests/test_session.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -828,9 +828,11 @@ async def test_expose_access_token_to_frontend_in_cookie_based_auth(
828828
async def test_token_transfer_method_works_when_using_origin_function(
829829
driver_config_client: TestClient,
830830
):
831-
def get_origin(req: Optional[BaseRequest], _: Optional[Dict[str, Any]]) -> str:
832-
if req is not None and req.get_header("origin") is not None:
833-
return req.get_header("origin") # type: ignore
831+
def get_origin(req: Optional[BaseRequest], _: Dict[str, Any]) -> str:
832+
if req is not None:
833+
value = req.get_header("origin")
834+
if value is not None:
835+
return value
834836
return "localhost:3000"
835837

836838
def token_transfer_method(req: BaseRequest, _: bool, __: Dict[str, Any]):

0 commit comments

Comments
 (0)