Skip to content

Commit fee89fb

Browse files
Merge pull request #381 from supertokens/fix/session-mt
feat: Session recipe multitenancy changes
2 parents 96cdec8 + e7fccfd commit fee89fb

File tree

16 files changed

+160
-42
lines changed

16 files changed

+160
-42
lines changed

supertokens_python/recipe/dashboard/api/userdetails/user_sessions_get.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,11 @@ async def handle_sessions_get(
2323
if user_id is None:
2424
raise_bad_input_exception("Missing required parameter 'userId'")
2525

26-
session_handles = await get_all_session_handles_for_user(user_id, user_context)
26+
# Passing tenant id as None sets fetch_across_all_tenants to True
27+
# which is what we want here.
28+
session_handles = await get_all_session_handles_for_user(
29+
user_id, None, user_context
30+
)
2731
sessions: List[Optional[SessionInfo]] = [None for _ in session_handles]
2832

2933
async def call_(i: int, session_handle: str):

supertokens_python/recipe/emailpassword/api/implementation.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,7 @@ async def sign_in_post(
183183
user.user_id,
184184
access_token_payload={},
185185
session_data_in_database={},
186+
tenant_id=tenant_id,
186187
user_context=user_context,
187188
)
188189
return SignInPostOkResult(user, session)
@@ -223,6 +224,7 @@ async def sign_up_post(
223224
user.user_id,
224225
access_token_payload={},
225226
session_data_in_database={},
227+
tenant_id=tenant_id,
226228
user_context=user_context,
227229
)
228230
return SignUpPostOkResult(user, session)

supertokens_python/recipe/emailverification/recipe.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -412,7 +412,7 @@ async def generate_email_verify_token_post(
412412
email_info = await EmailVerificationRecipe.get_instance().get_email_for_user_id(
413413
user_id, user_context
414414
)
415-
tenant_id = session.get_access_token_payload()["tId"]
415+
tenant_id = session.get_tenant_id()
416416

417417
if isinstance(email_info, EmailDoesNotExistError):
418418
log_debug_message(

supertokens_python/recipe/passwordless/api/implementation.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -307,6 +307,7 @@ async def consume_code_post(
307307
user.user_id,
308308
{},
309309
{},
310+
tenant_id,
310311
user_context=user_context,
311312
)
312313

supertokens_python/recipe/session/access_token.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@
2424
from .exceptions import raise_try_refresh_token_exception
2525
from .jwt import ParsedJWTInfo
2626

27+
from supertokens_python.recipe.multitenancy.constants import DEFAULT_TENANT_ID
28+
2729

2830
def sanitize_string(s: Any) -> Union[str, None]:
2931
if s == "":
@@ -102,6 +104,10 @@ def get_info_from_access_token(
102104
payload.get("parentRefreshTokenHash1")
103105
)
104106
anti_csrf_token = sanitize_string(payload.get("antiCsrfToken"))
107+
tenant_id = DEFAULT_TENANT_ID
108+
109+
if jwt_info.version >= 4:
110+
tenant_id = sanitize_string(payload.get("tId"))
105111

106112
if anti_csrf_token is None and do_anti_csrf_check:
107113
raise Exception("Access token does not contain the anti-csrf token")
@@ -120,6 +126,7 @@ def get_info_from_access_token(
120126
"antiCsrfToken": anti_csrf_token,
121127
"expiryTime": expiry_time,
122128
"timeCreated": time_created,
129+
"tenantId": tenant_id,
123130
}
124131
except Exception as e:
125132
log_debug_message(
@@ -145,6 +152,13 @@ def validate_access_token_structure(payload: Dict[str, Any], version: int) -> No
145152
raise Exception(
146153
"Access token does not contain all the information. Maybe the structure has changed?"
147154
)
155+
156+
if version >= 4:
157+
if not isinstance(payload.get("tId"), str):
158+
raise Exception(
159+
"Access token does not contain all the information. Maybe the structure has changed?"
160+
)
161+
148162
elif (
149163
not isinstance(payload.get("sessionHandle"), str)
150164
or payload.get("userData") is None

supertokens_python/recipe/session/asyncio/__init__.py

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,8 @@
4343
)
4444
from ..utils import get_required_claim_validators
4545

46+
from supertokens_python.recipe.multitenancy.constants import DEFAULT_TENANT_ID
47+
4648
_T = TypeVar("_T")
4749

4850

@@ -51,6 +53,7 @@ async def create_new_session(
5153
user_id: str,
5254
access_token_payload: Union[Dict[str, Any], None] = None,
5355
session_data_in_database: Union[Dict[str, Any], None] = None,
56+
tenant_id: Optional[str] = None,
5457
user_context: Union[None, Dict[str, Any]] = None,
5558
) -> SessionContainer:
5659
if user_context is None:
@@ -73,6 +76,7 @@ async def create_new_session(
7376
config,
7477
app_info,
7578
session_data_in_database,
79+
tenant_id or DEFAULT_TENANT_ID,
7680
)
7781

7882

@@ -81,6 +85,7 @@ async def create_new_session_without_request_response(
8185
access_token_payload: Union[Dict[str, Any], None] = None,
8286
session_data_in_database: Union[Dict[str, Any], None] = None,
8387
disable_anti_csrf: bool = False,
88+
tenant_id: Optional[str] = None,
8489
user_context: Union[None, Dict[str, Any]] = None,
8590
) -> SessionContainer:
8691
if user_context is None:
@@ -102,15 +107,17 @@ async def create_new_session_without_request_response(
102107
final_access_token_payload = {**access_token_payload, "iss": issuer}
103108

104109
for claim in claims_added_by_other_recipes:
105-
# TODO: Pass tenant id
106-
update = await claim.build(user_id, "pass-tenant-id", user_context)
110+
update = await claim.build(
111+
user_id, tenant_id or DEFAULT_TENANT_ID, user_context
112+
)
107113
final_access_token_payload = {**final_access_token_payload, **update}
108114

109115
return await SessionRecipe.get_instance().recipe_implementation.create_new_session(
110116
user_id,
111117
final_access_token_payload,
112118
session_data_in_database,
113119
disable_anti_csrf,
120+
tenant_id or DEFAULT_TENANT_ID,
114121
user_context=user_context,
115122
)
116123

@@ -421,22 +428,26 @@ async def revoke_session(
421428

422429

423430
async def revoke_all_sessions_for_user(
424-
user_id: str, user_context: Union[None, Dict[str, Any]] = None
431+
user_id: str,
432+
tenant_id: Optional[str] = None,
433+
user_context: Union[None, Dict[str, Any]] = None,
425434
) -> List[str]:
426435
if user_context is None:
427436
user_context = {}
428437
return await SessionRecipe.get_instance().recipe_implementation.revoke_all_sessions_for_user(
429-
user_id, user_context
438+
user_id, tenant_id or DEFAULT_TENANT_ID, tenant_id is None, user_context
430439
)
431440

432441

433442
async def get_all_session_handles_for_user(
434-
user_id: str, user_context: Union[None, Dict[str, Any]] = None
443+
user_id: str,
444+
tenant_id: Optional[str] = None,
445+
user_context: Union[None, Dict[str, Any]] = None,
435446
) -> List[str]:
436447
if user_context is None:
437448
user_context = {}
438449
return await SessionRecipe.get_instance().recipe_implementation.get_all_session_handles_for_user(
439-
user_id, user_context
450+
user_id, tenant_id or DEFAULT_TENANT_ID, tenant_id is None, user_context
440451
)
441452

442453

supertokens_python/recipe/session/interfaces.py

Lines changed: 35 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -41,10 +41,17 @@
4141

4242

4343
class SessionObj:
44-
def __init__(self, handle: str, user_id: str, user_data_in_jwt: Dict[str, Any]):
44+
def __init__(
45+
self,
46+
handle: str,
47+
user_id: str,
48+
user_data_in_jwt: Dict[str, Any],
49+
tenant_id: str,
50+
):
4551
self.handle = handle
4652
self.user_id = user_id
4753
self.user_data_in_jwt = user_data_in_jwt
54+
self.tenant_id = tenant_id
4855

4956

5057
class AccessTokenObj:
@@ -69,15 +76,17 @@ def __init__(
6976
expiry: int,
7077
custom_claims_in_access_token_payload: Dict[str, Any],
7178
time_created: int,
79+
tenant_id: str,
7280
):
73-
self.session_handle: str = session_handle
74-
self.user_id: str = user_id
75-
self.session_data_in_database: Dict[str, Any] = session_data_in_database
76-
self.expiry: int = expiry
77-
self.custom_claims_in_access_token_payload: Dict[
78-
str, Any
79-
] = custom_claims_in_access_token_payload
80-
self.time_created: int = time_created
81+
self.session_handle = session_handle
82+
self.user_id = user_id
83+
self.session_data_in_database = session_data_in_database
84+
self.expiry = expiry
85+
self.custom_claims_in_access_token_payload = (
86+
custom_claims_in_access_token_payload
87+
)
88+
self.time_created = time_created
89+
self.tenant_id = tenant_id
8190

8291

8392
class ReqResInfo:
@@ -137,6 +146,7 @@ async def create_new_session(
137146
access_token_payload: Optional[Dict[str, Any]],
138147
session_data_in_database: Optional[Dict[str, Any]],
139148
disable_anti_csrf: Optional[bool],
149+
tenant_id: str,
140150
user_context: Dict[str, Any],
141151
) -> SessionContainer:
142152
pass
@@ -206,13 +216,21 @@ async def revoke_session(
206216

207217
@abstractmethod
208218
async def revoke_all_sessions_for_user(
209-
self, user_id: str, user_context: Dict[str, Any]
219+
self,
220+
user_id: str,
221+
tenant_id: str,
222+
revoke_across_all_tenants: bool,
223+
user_context: Dict[str, Any],
210224
) -> List[str]:
211225
pass
212226

213227
@abstractmethod
214228
async def get_all_session_handles_for_user(
215-
self, user_id: str, user_context: Dict[str, Any]
229+
self,
230+
user_id: str,
231+
tenant_id: str,
232+
fetch_across_all_tenants: bool,
233+
user_context: Dict[str, Any],
216234
) -> List[str]:
217235
pass
218236

@@ -383,6 +401,7 @@ def __init__(
383401
user_data_in_access_token: Optional[Dict[str, Any]],
384402
req_res_info: Optional[ReqResInfo],
385403
access_token_updated: bool,
404+
tenant_id: str,
386405
):
387406
self.recipe_implementation = recipe_implementation
388407
self.config = config
@@ -395,6 +414,7 @@ def __init__(
395414
self.user_data_in_access_token = user_data_in_access_token
396415
self.req_res_info: Optional[ReqResInfo] = req_res_info
397416
self.access_token_updated = access_token_updated
417+
self.tenant_id = tenant_id
398418

399419
self.response_mutators: List[ResponseMutator] = []
400420

@@ -436,6 +456,10 @@ async def merge_into_access_token_payload(
436456
def get_user_id(self, user_context: Optional[Dict[str, Any]] = None) -> str:
437457
pass
438458

459+
@abstractmethod
460+
def get_tenant_id(self, user_context: Optional[Dict[str, Any]] = None) -> str:
461+
pass
462+
439463
@abstractmethod
440464
def get_access_token_payload(
441465
self, user_context: Optional[Dict[str, Any]] = None

supertokens_python/recipe/session/recipe.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,7 @@ def get_apis_handled(self) -> List[APIHandled]:
183183
async def handle_api_request(
184184
self,
185185
request_id: str,
186-
tenant_id: Optional[str],
186+
tenant_id: str,
187187
request: BaseRequest,
188188
path: NormalisedURLPath,
189189
method: str,

supertokens_python/recipe/session/recipe_implementation.py

Lines changed: 23 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ async def create_new_session(
6464
access_token_payload: Optional[Dict[str, Any]],
6565
session_data_in_database: Optional[Dict[str, Any]],
6666
disable_anti_csrf: Optional[bool],
67+
tenant_id: str,
6768
user_context: Dict[str, Any],
6869
) -> SessionContainer:
6970
log_debug_message("createNewSession: Started")
@@ -74,6 +75,7 @@ async def create_new_session(
7475
disable_anti_csrf is True,
7576
access_token_payload,
7677
session_data_in_database,
78+
tenant_id,
7779
)
7880
log_debug_message("createNewSession: Finished")
7981

@@ -95,6 +97,7 @@ async def create_new_session(
9597
payload,
9698
None,
9799
True,
100+
tenant_id,
98101
)
99102

100103
return new_session
@@ -262,6 +265,7 @@ async def get_session(
262265
payload,
263266
None,
264267
access_token_updated,
268+
response.session.tenant_id,
265269
)
266270

267271
return session
@@ -312,6 +316,7 @@ async def refresh_session(
312316
user_data_in_access_token=payload,
313317
req_res_info=None,
314318
access_token_updated=True,
319+
tenant_id=payload["tId"],
315320
)
316321

317322
return session
@@ -322,14 +327,26 @@ async def revoke_session(
322327
return await session_functions.revoke_session(self, session_handle)
323328

324329
async def revoke_all_sessions_for_user(
325-
self, user_id: str, user_context: Dict[str, Any]
330+
self,
331+
user_id: str,
332+
tenant_id: Optional[str],
333+
revoke_across_all_tenants: bool,
334+
user_context: Dict[str, Any],
326335
) -> List[str]:
327-
return await session_functions.revoke_all_sessions_for_user(self, user_id)
336+
return await session_functions.revoke_all_sessions_for_user(
337+
self, user_id, tenant_id, revoke_across_all_tenants
338+
)
328339

329340
async def get_all_session_handles_for_user(
330-
self, user_id: str, user_context: Dict[str, Any]
341+
self,
342+
user_id: str,
343+
tenant_id: Optional[str],
344+
fetch_across_all_tenants: bool,
345+
user_context: Dict[str, Any],
331346
) -> List[str]:
332-
return await session_functions.get_all_session_handles_for_user(self, user_id)
347+
return await session_functions.get_all_session_handles_for_user(
348+
self, user_id, tenant_id, fetch_across_all_tenants
349+
)
333350

334351
async def revoke_multiple_sessions(
335352
self, session_handles: List[str], user_context: Dict[str, Any]
@@ -383,9 +400,8 @@ async def fetch_and_set_claim(
383400
if session_info is None:
384401
return False
385402

386-
# TODO: Pass tenant id
387403
access_token_payload_update = await claim.build(
388-
session_info.user_id, "pass-tenant-id", user_context
404+
session_info.user_id, session_info.tenant_id, user_context
389405
)
390406
return await self.merge_into_access_token_payload(
391407
session_handle, access_token_payload_update, user_context
@@ -463,5 +479,6 @@ async def regenerate_access_token(
463479
response["session"]["handle"],
464480
response["session"]["userId"],
465481
response["session"]["userDataInJWT"],
482+
response["session"]["tenantId"],
466483
)
467484
return RegenerateAccessTokenOkResult(session, access_token_obj)

supertokens_python/recipe/session/session_class.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,9 @@ async def update_session_data_in_database(
133133
def get_user_id(self, user_context: Union[Dict[str, Any], None] = None) -> str:
134134
return self.user_id
135135

136+
def get_tenant_id(self, user_context: Union[Dict[str, Any], None] = None) -> str:
137+
return self.tenant_id
138+
136139
def get_access_token_payload(
137140
self, user_context: Union[Dict[str, Any], None] = None
138141
) -> Dict[str, Any]:

0 commit comments

Comments
 (0)