diff --git a/supabase_auth/_async/gotrue_client.py b/supabase_auth/_async/gotrue_client.py index 90013a9e..63e7fe5a 100644 --- a/supabase_auth/_async/gotrue_client.py +++ b/supabase_auth/_async/gotrue_client.py @@ -4,7 +4,7 @@ from functools import partial from json import loads from time import time -from typing import Callable, Dict, List, Tuple, Union +from typing import Awaitable, Callable, Dict, List, Tuple, Union from urllib.parse import parse_qs, urlencode, urlparse from uuid import uuid4 @@ -39,6 +39,7 @@ from ..http_clients import AsyncClient from ..timer import Timer from ..types import ( + AsyncSubscription, AuthChangeEvent, AuthenticatorAssuranceLevels, AuthFlowType, @@ -71,7 +72,6 @@ SignInWithSSOCredentials, SignOutOptions, SignUpWithPasswordCredentials, - Subscription, UserAttributes, UserIdentity, UserResponse, @@ -111,7 +111,7 @@ def __init__( self._in_memory_session: Union[Session, None] = None self._refresh_token_timer: Union[Timer, None] = None self._network_retries = 0 - self._state_change_emitters: Dict[str, Subscription] = {} + self._state_change_emitters: Dict[str, AsyncSubscription] = {} self._flow_type = flow_type self.admin = AsyncGoTrueAdminAPI( @@ -146,9 +146,9 @@ async def initialize_from_url(self, url: str) -> None: if self._is_implicit_grant_flow(url): session, redirect_type = await self._get_session_from_url(url) await self._save_session(session) - self._notify_all_subscribers("SIGNED_IN", session) + await self._notify_all_subscribers("SIGNED_IN", session) if redirect_type == "recovery": - self._notify_all_subscribers("PASSWORD_RECOVERY", session) + await self._notify_all_subscribers("PASSWORD_RECOVERY", session) except Exception as e: await self._remove_session() raise e @@ -180,7 +180,7 @@ async def sign_in_anonymously( ) if response.session: await self._save_session(response.session) - self._notify_all_subscribers("SIGNED_IN", response.session) + await self._notify_all_subscribers("SIGNED_IN", response.session) return response async def sign_up( @@ -235,7 +235,7 @@ async def sign_up( ) if response.session: await self._save_session(response.session) - self._notify_all_subscribers("SIGNED_IN", response.session) + await self._notify_all_subscribers("SIGNED_IN", response.session) return response async def sign_in_with_password( @@ -292,7 +292,7 @@ async def sign_in_with_password( ) if response.session: await self._save_session(response.session) - self._notify_all_subscribers("SIGNED_IN", response.session) + await self._notify_all_subscribers("SIGNED_IN", response.session) return response async def sign_in_with_id_token( @@ -330,7 +330,7 @@ async def sign_in_with_id_token( if response.session: await self._save_session(response.session) - self._notify_all_subscribers("SIGNED_IN", response.session) + await self._notify_all_subscribers("SIGNED_IN", response.session) return response async def sign_in_with_sso(self, credentials: SignInWithSSOCredentials): @@ -574,7 +574,7 @@ async def verify_otp(self, params: VerifyOtpParams) -> AuthResponse: ) if response.session: await self._save_session(response.session) - self._notify_all_subscribers("SIGNED_IN", response.session) + await self._notify_all_subscribers("SIGNED_IN", response.session) return response async def reauthenticate(self) -> AuthResponse: @@ -649,7 +649,7 @@ async def update_user(self, attributes: UserAttributes) -> UserResponse: ) session.user = response.user await self._save_session(session) - self._notify_all_subscribers("USER_UPDATED", session) + await self._notify_all_subscribers("USER_UPDATED", session) return response async def set_session(self, access_token: str, refresh_token: str) -> AuthResponse: @@ -695,7 +695,7 @@ async def set_session(self, access_token: str, refresh_token: str) -> AuthRespon expires_at=expires_at, ) await self._save_session(session) - self._notify_all_subscribers("TOKEN_REFRESHED", session) + await self._notify_all_subscribers("TOKEN_REFRESHED", session) return AuthResponse(session=session, user=response.user) async def refresh_session( @@ -737,12 +737,12 @@ async def sign_out(self, options: SignOutOptions = {"scope": "global"}) -> None: if options["scope"] != "others": await self._remove_session() - self._notify_all_subscribers("SIGNED_OUT", None) + await self._notify_all_subscribers("SIGNED_OUT", None) def on_auth_state_change( self, - callback: Callable[[AuthChangeEvent, Union[Session, None]], None], - ) -> Subscription: + callback: Callable[[AuthChangeEvent, Union[Session, None]], Awaitable[None]], + ) -> AsyncSubscription: """ Receive a notification every time an auth event happens. """ @@ -751,7 +751,7 @@ def on_auth_state_change( def _unsubscribe() -> None: self._state_change_emitters.pop(unique_id) - subscription = Subscription( + subscription = AsyncSubscription( id=unique_id, callback=callback, unsubscribe=_unsubscribe, @@ -855,7 +855,7 @@ async def _verify(self, params: MFAVerifyParams) -> AuthMFAVerifyResponse: ) session = model_validate(Session, model_dump(response)) await self._save_session(session) - self._notify_all_subscribers("MFA_CHALLENGE_VERIFIED", session) + await self._notify_all_subscribers("MFA_CHALLENGE_VERIFIED", session) return response async def _unenroll(self, params: MFAUnenrollParams) -> AuthMFAUnenrollResponse: @@ -995,7 +995,7 @@ async def _recover_and_refresh(self) -> None: return if self._persist_session: await self._save_session(current_session) - self._notify_all_subscribers("SIGNED_IN", current_session) + await self._notify_all_subscribers("SIGNED_IN", current_session) async def _call_refresh_token(self, refresh_token: str) -> Session: if not refresh_token: @@ -1004,7 +1004,7 @@ async def _call_refresh_token(self, refresh_token: str) -> Session: if not response.session: raise AuthSessionMissingError() await self._save_session(response.session) - self._notify_all_subscribers("TOKEN_REFRESHED", response.session) + await self._notify_all_subscribers("TOKEN_REFRESHED", response.session) return response.session async def _refresh_access_token(self, refresh_token: str) -> AuthResponse: @@ -1057,13 +1057,13 @@ async def refresh_token_function(): self._refresh_token_timer = Timer(value, refresh_token_function) self._refresh_token_timer.start() - def _notify_all_subscribers( + async def _notify_all_subscribers( self, event: AuthChangeEvent, session: Union[Session, None], ) -> None: for subscription in self._state_change_emitters.values(): - subscription.callback(event, session) + await subscription.callback(event, session) def _get_valid_session( self, @@ -1147,5 +1147,5 @@ async def exchange_code_for_session(self, params: CodeExchangeParams): await self._storage.remove_item(f"{self._storage_key}-code-verifier") if response.session: await self._save_session(response.session) - self._notify_all_subscribers("SIGNED_IN", response.session) + await self._notify_all_subscribers("SIGNED_IN", response.session) return response diff --git a/supabase_auth/types.py b/supabase_auth/types.py index bf5ba2d4..2b872b6a 100644 --- a/supabase_auth/types.py +++ b/supabase_auth/types.py @@ -2,7 +2,7 @@ from datetime import datetime from time import time -from typing import Any, Callable, Dict, List, Union +from typing import Any, Awaitable, Callable, Dict, List, Union from pydantic import BaseModel, ConfigDict @@ -276,6 +276,21 @@ class Subscription(BaseModel): """ +class AsyncSubscription(BaseModel): + id: str + """ + The subscriber UUID. This will be set by the client. + """ + callback: Callable[[AuthChangeEvent, Union[Session, None]], Awaitable[None]] + """ + The async function to call every time there is an event. + """ + unsubscribe: Callable[[], None] + """ + Call this to remove the listener. + """ + + class UpdatableFactorAttributes(TypedDict): friendly_name: str