Skip to content

Commit 09b0320

Browse files
committed
refactor: Add tenant_id variable in session functions
1 parent caec759 commit 09b0320

File tree

7 files changed

+32
-93
lines changed

7 files changed

+32
-93
lines changed

supertokens_python/recipe/emailverification/recipe.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -312,7 +312,9 @@ class EmailVerificationClaimClass(BooleanClaim):
312312
def __init__(self):
313313
default_max_age_in_sec = 300
314314

315-
async def fetch_value(user_id: str, user_context: Dict[str, Any]) -> bool:
315+
async def fetch_value(
316+
user_id: str, _tenant_id: str, user_context: Dict[str, Any]
317+
) -> bool:
316318
recipe = EmailVerificationRecipe.get_instance()
317319
email_info = await recipe.get_email_for_user_id(user_id, user_context)
318320

supertokens_python/recipe/multitenancy/allowed_domains_claim.py

Lines changed: 0 additions & 65 deletions
This file was deleted.

supertokens_python/recipe/multitenancy/recipe.py

Lines changed: 25 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -259,41 +259,43 @@ async def login_methods_get(
259259

260260
class AllowedDomainsClaimClass(PrimitiveArrayClaim[List[str]]):
261261
def __init__(self):
262-
async def fetch_value(_user_id: str, user_context: Dict[str, Any]) -> List[str]:
262+
default_max_age_in_sec = 60 * 60 * 24 * 7
263+
264+
async def fetch_value(
265+
_: str, tenant_id: str, user_context: Dict[str, Any]
266+
) -> Optional[List[str]]:
263267
recipe = MultitenancyRecipe.get_instance()
264-
tenant_id = (
265-
None # TODO fetch value will be passed with tenant_id as well later
266-
)
267268

268-
if recipe.config.get_allowed_domains_for_tenant_id is None:
269-
return (
270-
[]
271-
) # User did not provide a function to get allowed domains, but is using a validator. So we don't allow any domains by default
269+
if recipe.get_allowed_domains_for_tenant_id is None:
270+
# User did not provide a function to get allowed domains, but is using a validator. So we don't allow any domains by default
271+
return None
272272

273-
domains_res = await recipe.config.get_allowed_domains_for_tenant_id(
273+
return await recipe.get_allowed_domains_for_tenant_id(
274274
tenant_id, user_context
275275
)
276-
return domains_res
277276

278-
super().__init__(
279-
key="st-tenant-domains",
280-
fetch_value=fetch_value,
281-
default_max_age_in_sec=3600,
282-
)
277+
super().__init__("st-t-dmns", fetch_value, default_max_age_in_sec)
283278

284279
def get_value_from_payload(
285-
self, payload: JSONObject, user_context: Union[Dict[str, Any], None] = None
286-
) -> Union[List[str], None]:
287-
if self.key not in payload:
280+
self, payload: JSONObject, user_context: Optional[Dict[str, Any]] = None
281+
) -> Optional[List[str]]:
282+
_ = user_context
283+
284+
res = payload.get(self.key, {}).get("v")
285+
if res is None:
288286
return []
289-
return super().get_value_from_payload(payload, user_context)
287+
return res
290288

291289
def get_last_refetch_time(
292-
self, payload: JSONObject, user_context: Union[Dict[str, Any], None] = None
293-
) -> Union[int, None]:
294-
if self.key not in payload:
290+
self, payload: JSONObject, user_context: Optional[Dict[str, Any]] = None
291+
) -> Optional[int]:
292+
_ = user_context
293+
294+
res = payload.get(self.key, {}).get("t")
295+
if res is None:
295296
return get_timestamp_ms()
296-
return super().get_last_refetch_time(payload, user_context)
297+
298+
return res
297299

298300

299301
AllowedDomainsClaim = AllowedDomainsClaimClass()

supertokens_python/recipe/session/claim_base_classes/boolean_claim.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ def __init__(
3131
self,
3232
key: str,
3333
fetch_value: Callable[
34-
[str, Dict[str, Any]],
34+
[str, str, Dict[str, Any]],
3535
MaybeAwaitable[Optional[bool]],
3636
],
3737
default_max_age_in_sec: Optional[int] = None,

supertokens_python/recipe/session/recipe_implementation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -384,7 +384,7 @@ async def fetch_and_set_claim(
384384
return False
385385

386386
access_token_payload_update = await claim.build(
387-
session_info.user_id, user_context
387+
session_info.user_id, tenant_id, user_context
388388
)
389389
return await self.merge_into_access_token_payload(
390390
session_handle, access_token_payload_update, user_context

supertokens_python/recipe/session/session_class.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -220,7 +220,7 @@ async def fetch_and_set_claim(
220220
if user_context is None:
221221
user_context = {}
222222

223-
update = await claim.build(self.get_user_id(), user_context)
223+
update = await claim.build(self.get_user_id(), tenant_id, user_context)
224224
return await self.merge_into_access_token_payload(update, user_context)
225225

226226
async def set_claim_value(

supertokens_python/recipe/session/session_request_functions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -238,7 +238,7 @@ async def create_new_session_in_request(
238238
final_access_token_payload = {**access_token_payload, "iss": issuer}
239239

240240
for claim in claims_added_by_other_recipes:
241-
update = await claim.build(user_id, user_context)
241+
update = await claim.build(user_id, tenant_id, user_context)
242242
final_access_token_payload = {**final_access_token_payload, **update}
243243

244244
log_debug_message("createNewSession: Access token payload built")

0 commit comments

Comments
 (0)