Skip to content
This repository was archived by the owner on Sep 8, 2025. It is now read-only.
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 22 additions & 22 deletions supabase_auth/_async/gotrue_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from functools import partial
from json import loads
from time import time
from typing import Callable, Dict, List, Tuple, Union
from typing import Awaitable, Callable, Dict, List, Tuple, Union
from urllib.parse import parse_qs, urlencode, urlparse
from uuid import uuid4

Expand Down Expand Up @@ -39,6 +39,7 @@
from ..http_clients import AsyncClient
from ..timer import Timer
from ..types import (
AsyncSubscription,
AuthChangeEvent,
AuthenticatorAssuranceLevels,
AuthFlowType,
Expand Down Expand Up @@ -71,7 +72,6 @@
SignInWithSSOCredentials,
SignOutOptions,
SignUpWithPasswordCredentials,
Subscription,
UserAttributes,
UserIdentity,
UserResponse,
Expand Down Expand Up @@ -111,7 +111,7 @@ def __init__(
self._in_memory_session: Union[Session, None] = None
self._refresh_token_timer: Union[Timer, None] = None
self._network_retries = 0
self._state_change_emitters: Dict[str, Subscription] = {}
self._state_change_emitters: Dict[str, AsyncSubscription] = {}
self._flow_type = flow_type

self.admin = AsyncGoTrueAdminAPI(
Expand Down Expand Up @@ -146,9 +146,9 @@ async def initialize_from_url(self, url: str) -> None:
if self._is_implicit_grant_flow(url):
session, redirect_type = await self._get_session_from_url(url)
await self._save_session(session)
self._notify_all_subscribers("SIGNED_IN", session)
await self._notify_all_subscribers("SIGNED_IN", session)
if redirect_type == "recovery":
self._notify_all_subscribers("PASSWORD_RECOVERY", session)
await self._notify_all_subscribers("PASSWORD_RECOVERY", session)
except Exception as e:
await self._remove_session()
raise e
Expand Down Expand Up @@ -180,7 +180,7 @@ async def sign_in_anonymously(
)
if response.session:
await self._save_session(response.session)
self._notify_all_subscribers("SIGNED_IN", response.session)
await self._notify_all_subscribers("SIGNED_IN", response.session)
return response

async def sign_up(
Expand Down Expand Up @@ -235,7 +235,7 @@ async def sign_up(
)
if response.session:
await self._save_session(response.session)
self._notify_all_subscribers("SIGNED_IN", response.session)
await self._notify_all_subscribers("SIGNED_IN", response.session)
return response

async def sign_in_with_password(
Expand Down Expand Up @@ -292,7 +292,7 @@ async def sign_in_with_password(
)
if response.session:
await self._save_session(response.session)
self._notify_all_subscribers("SIGNED_IN", response.session)
await self._notify_all_subscribers("SIGNED_IN", response.session)
return response

async def sign_in_with_id_token(
Expand Down Expand Up @@ -330,7 +330,7 @@ async def sign_in_with_id_token(

if response.session:
await self._save_session(response.session)
self._notify_all_subscribers("SIGNED_IN", response.session)
await self._notify_all_subscribers("SIGNED_IN", response.session)
return response

async def sign_in_with_sso(self, credentials: SignInWithSSOCredentials):
Expand Down Expand Up @@ -574,7 +574,7 @@ async def verify_otp(self, params: VerifyOtpParams) -> AuthResponse:
)
if response.session:
await self._save_session(response.session)
self._notify_all_subscribers("SIGNED_IN", response.session)
await self._notify_all_subscribers("SIGNED_IN", response.session)
return response

async def reauthenticate(self) -> AuthResponse:
Expand Down Expand Up @@ -649,7 +649,7 @@ async def update_user(self, attributes: UserAttributes) -> UserResponse:
)
session.user = response.user
await self._save_session(session)
self._notify_all_subscribers("USER_UPDATED", session)
await self._notify_all_subscribers("USER_UPDATED", session)
return response

async def set_session(self, access_token: str, refresh_token: str) -> AuthResponse:
Expand Down Expand Up @@ -695,7 +695,7 @@ async def set_session(self, access_token: str, refresh_token: str) -> AuthRespon
expires_at=expires_at,
)
await self._save_session(session)
self._notify_all_subscribers("TOKEN_REFRESHED", session)
await self._notify_all_subscribers("TOKEN_REFRESHED", session)
return AuthResponse(session=session, user=response.user)

async def refresh_session(
Expand Down Expand Up @@ -737,12 +737,12 @@ async def sign_out(self, options: SignOutOptions = {"scope": "global"}) -> None:

if options["scope"] != "others":
await self._remove_session()
self._notify_all_subscribers("SIGNED_OUT", None)
await self._notify_all_subscribers("SIGNED_OUT", None)

def on_auth_state_change(
self,
callback: Callable[[AuthChangeEvent, Union[Session, None]], None],
) -> Subscription:
callback: Callable[[AuthChangeEvent, Union[Session, None]], Awaitable[None]],
) -> AsyncSubscription:
"""
Receive a notification every time an auth event happens.
"""
Expand All @@ -751,7 +751,7 @@ def on_auth_state_change(
def _unsubscribe() -> None:
self._state_change_emitters.pop(unique_id)

subscription = Subscription(
subscription = AsyncSubscription(
id=unique_id,
callback=callback,
unsubscribe=_unsubscribe,
Expand Down Expand Up @@ -855,7 +855,7 @@ async def _verify(self, params: MFAVerifyParams) -> AuthMFAVerifyResponse:
)
session = model_validate(Session, model_dump(response))
await self._save_session(session)
self._notify_all_subscribers("MFA_CHALLENGE_VERIFIED", session)
await self._notify_all_subscribers("MFA_CHALLENGE_VERIFIED", session)
return response

async def _unenroll(self, params: MFAUnenrollParams) -> AuthMFAUnenrollResponse:
Expand Down Expand Up @@ -995,7 +995,7 @@ async def _recover_and_refresh(self) -> None:
return
if self._persist_session:
await self._save_session(current_session)
self._notify_all_subscribers("SIGNED_IN", current_session)
await self._notify_all_subscribers("SIGNED_IN", current_session)

async def _call_refresh_token(self, refresh_token: str) -> Session:
if not refresh_token:
Expand All @@ -1004,7 +1004,7 @@ async def _call_refresh_token(self, refresh_token: str) -> Session:
if not response.session:
raise AuthSessionMissingError()
await self._save_session(response.session)
self._notify_all_subscribers("TOKEN_REFRESHED", response.session)
await self._notify_all_subscribers("TOKEN_REFRESHED", response.session)
return response.session

async def _refresh_access_token(self, refresh_token: str) -> AuthResponse:
Expand Down Expand Up @@ -1057,13 +1057,13 @@ async def refresh_token_function():
self._refresh_token_timer = Timer(value, refresh_token_function)
self._refresh_token_timer.start()

def _notify_all_subscribers(
async def _notify_all_subscribers(
self,
event: AuthChangeEvent,
session: Union[Session, None],
) -> None:
for subscription in self._state_change_emitters.values():
subscription.callback(event, session)
await subscription.callback(event, session)

def _get_valid_session(
self,
Expand Down Expand Up @@ -1147,5 +1147,5 @@ async def exchange_code_for_session(self, params: CodeExchangeParams):
await self._storage.remove_item(f"{self._storage_key}-code-verifier")
if response.session:
await self._save_session(response.session)
self._notify_all_subscribers("SIGNED_IN", response.session)
await self._notify_all_subscribers("SIGNED_IN", response.session)
return response
17 changes: 16 additions & 1 deletion supabase_auth/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from datetime import datetime
from time import time
from typing import Any, Callable, Dict, List, Union
from typing import Any, Awaitable, Callable, Dict, List, Union

from pydantic import BaseModel, ConfigDict

Expand Down Expand Up @@ -276,6 +276,21 @@ class Subscription(BaseModel):
"""


class AsyncSubscription(BaseModel):
id: str
"""
The subscriber UUID. This will be set by the client.
"""
callback: Callable[[AuthChangeEvent, Union[Session, None]], Awaitable[None]]
"""
The async function to call every time there is an event.
"""
unsubscribe: Callable[[], None]
"""
Call this to remove the listener.
"""


class UpdatableFactorAttributes(TypedDict):
friendly_name: str

Expand Down