@@ -100,6 +100,9 @@ def __init__(
100100 )
101101
102102 self .static_third_party_providers : List [ProviderInput ] = []
103+ self .get_allowed_domains_for_tenant_id = (
104+ self .config .get_allowed_domains_for_tenant_id
105+ )
103106
104107 def is_error_from_this_recipe_based_on_instance (self , err : Exception ) -> bool :
105108 return isinstance (err , (TenantDoesNotExistError , RecipeDisabledForTenantError ))
@@ -259,41 +262,43 @@ async def login_methods_get(
259262
260263class AllowedDomainsClaimClass (PrimitiveArrayClaim [List [str ]]):
261264 def __init__ (self ):
262- async def fetch_value (_user_id : str , user_context : Dict [str , Any ]) -> List [str ]:
265+ default_max_age_in_sec = 60 * 60
266+
267+ async def fetch_value (
268+ _ : str , tenant_id : str , user_context : Dict [str , Any ]
269+ ) -> Optional [List [str ]]:
263270 recipe = MultitenancyRecipe .get_instance ()
264- tenant_id = (
265- None # TODO fetch value will be passed with tenant_id as well later
266- )
267271
268- if recipe .config .get_allowed_domains_for_tenant_id is None :
269- return (
270- []
271- ) # User did not provide a function to get allowed domains, but is using a validator. So we don't allow any domains by default
272+ if recipe .get_allowed_domains_for_tenant_id is None :
273+ # User did not provide a function to get allowed domains, but is using a validator. So we don't allow any domains by default
274+ return None
272275
273- domains_res = await recipe . config .get_allowed_domains_for_tenant_id (
276+ return await recipe .get_allowed_domains_for_tenant_id (
274277 tenant_id , user_context
275278 )
276- return domains_res
277279
278- super ().__init__ (
279- key = "st-tenant-domains" ,
280- fetch_value = fetch_value ,
281- default_max_age_in_sec = 3600 ,
282- )
280+ super ().__init__ ("st-t-dmns" , fetch_value , default_max_age_in_sec )
283281
284282 def get_value_from_payload (
285- self , payload : JSONObject , user_context : Union [Dict [str , Any ], None ] = None
286- ) -> Union [List [str ], None ]:
287- if self .key not in payload :
283+ self , payload : JSONObject , user_context : Optional [Dict [str , Any ]] = None
284+ ) -> Optional [List [str ]]:
285+ _ = user_context
286+
287+ res = payload .get (self .key , {}).get ("v" )
288+ if res is None :
288289 return []
289- return super (). get_value_from_payload ( payload , user_context )
290+ return res
290291
291292 def get_last_refetch_time (
292- self , payload : JSONObject , user_context : Union [Dict [str , Any ], None ] = None
293- ) -> Union [int , None ]:
294- if self .key not in payload :
293+ self , payload : JSONObject , user_context : Optional [Dict [str , Any ]] = None
294+ ) -> Optional [int ]:
295+ _ = user_context
296+
297+ res = payload .get (self .key , {}).get ("t" )
298+ if res is None :
295299 return get_timestamp_ms ()
296- return super ().get_last_refetch_time (payload , user_context )
300+
301+ return res
297302
298303
299304AllowedDomainsClaim = AllowedDomainsClaimClass ()
0 commit comments