Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/auth/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -64,9 +64,9 @@ select = [
# "SIM",
# isort
"I",
"ANN2"
]
ignore = ["F401", "F403", "F841", "E712", "E501", "E402", "E722", "E731", "UP006", "UP035"]
# isort.required-imports = ["from __future__ import annotations"]
ignore = ["F403", "E501", "E402", "UP006", "UP035"]

[tool.ruff.lint.pyupgrade]
# Preserve types, even if a file imports `from __future__ import annotations`.
Expand Down
4 changes: 2 additions & 2 deletions src/auth/scripts/gh-download.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,14 +73,14 @@ def download_directory(repository: Repository, sha: str, server_path: str) -> No
print("Error processing %s: %s", content.path, exc)


def usage():
def usage() -> None:
"""
Prints the usage command lines
"""
print("usage: gh-download --repo=repo --branch=branch --folder=folder")


def main(argv):
def main(argv) -> None:
"""
Main function block
"""
Expand Down
18 changes: 9 additions & 9 deletions src/auth/src/supabase_auth/__init__.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
from __future__ import annotations

from ._async.gotrue_admin_api import AsyncGoTrueAdminAPI
from ._async.gotrue_client import AsyncGoTrueClient
from ._async.gotrue_admin_api import AsyncGoTrueAdminAPI # noqa
from ._async.gotrue_client import AsyncGoTrueClient # noqa
from ._async.storage import (
AsyncMemoryStorage,
AsyncSupportedStorage,
AsyncMemoryStorage, # noqa
AsyncSupportedStorage, # noqa
)
from ._sync.gotrue_admin_api import SyncGoTrueAdminAPI
from ._sync.gotrue_client import SyncGoTrueClient
from ._sync.gotrue_admin_api import SyncGoTrueAdminAPI # noqa
from ._sync.gotrue_client import SyncGoTrueClient # noqa
from ._sync.storage import (
SyncMemoryStorage,
SyncSupportedStorage,
SyncMemoryStorage, # noqa
SyncSupportedStorage, # noqa
)
from .types import *
from .version import __version__
from .version import __version__ # noqa
5 changes: 2 additions & 3 deletions src/auth/src/supabase_auth/_async/gotrue_admin_api.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
from __future__ import annotations

from typing import Any, Dict, List, Optional
from typing import Dict, List, Optional

from httpx import QueryParams, Response
from pydantic import TypeAdapter
from httpx import QueryParams

from ..helpers import (
model_validate,
Expand Down
4 changes: 2 additions & 2 deletions src/auth/src/supabase_auth/_async/gotrue_base_api.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import Any, Callable, Dict, Optional, TypeVar, overload
from typing import Any, Dict, Optional

from httpx import HTTPStatusError, QueryParams, Response
from pydantic import BaseModel
Expand All @@ -20,7 +20,7 @@ def __init__(
http_client: Optional[AsyncClient],
verify: bool = True,
proxy: Optional[str] = None,
):
) -> None:
self._url = url
self._headers = headers
self._http_client = http_client or AsyncClient(
Expand Down
23 changes: 11 additions & 12 deletions src/auth/src/supabase_auth/_async/gotrue_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,11 @@

import time
from contextlib import suppress
from functools import partial
from json import loads
from typing import Callable, Dict, List, Mapping, Optional, Tuple, Union
from urllib.parse import parse_qs, urlencode, urlparse
from typing import Callable, Dict, List, Optional, Tuple
from urllib.parse import parse_qs, urlparse
from uuid import uuid4

from httpx import QueryParams
from httpx import QueryParams, Response
from jwt import get_algorithm_by_name
from typing_extensions import cast

Expand All @@ -33,7 +31,6 @@
decode_jwt,
generate_pkce_challenge,
generate_pkce_verifier,
model_dump,
model_dump_json,
model_validate,
parse_auth_otp_response,
Expand All @@ -50,7 +47,6 @@
JWK,
AMREntry,
AuthChangeEvent,
AuthenticatorAssuranceLevels,
AuthFlowType,
AuthMFAChallengeResponse,
AuthMFAEnrollResponse,
Expand Down Expand Up @@ -86,6 +82,7 @@
SignUpWithEmailAndPasswordCredentialsOptions,
SignUpWithPasswordCredentials,
SignUpWithPhoneAndPasswordCredentialsOptions,
SSOResponse,
Subscription,
UpdateUserOptions,
UserAttributes,
Expand Down Expand Up @@ -361,7 +358,9 @@ async def sign_in_with_id_token(
self._notify_all_subscribers("SIGNED_IN", auth_response.session)
return auth_response

async def sign_in_with_sso(self, credentials: SignInWithSSOCredentials):
async def sign_in_with_sso(
self, credentials: SignInWithSSOCredentials
) -> SSOResponse:
"""
Attempts a single-sign on using an enterprise Identity Provider. A
successful SSO attempt will redirect the current page to the identity
Expand Down Expand Up @@ -476,7 +475,7 @@ async def get_user_identities(self) -> IdentitiesResponse:
return IdentitiesResponse(identities=response.user.identities or [])
raise AuthSessionMissingError()

async def unlink_identity(self, identity: UserIdentity):
async def unlink_identity(self, identity: UserIdentity) -> Response:
session = await self.get_session()
if not session:
raise AuthSessionMissingError()
Expand Down Expand Up @@ -621,7 +620,7 @@ async def reauthenticate(self) -> AuthResponse:
if not session:
raise AuthSessionMissingError()

response = await self._request(
await self._request(
"GET",
"reauthenticate",
jwt=session.access_token,
Expand Down Expand Up @@ -1090,7 +1089,7 @@ async def _start_auto_refresh_token(self, value: float) -> None:
if value <= 0 or not self._auto_refresh_token:
return

async def refresh_token_function():
async def refresh_token_function() -> None:
self._network_retries += 1
try:
session = await self.get_session()
Expand Down Expand Up @@ -1275,7 +1274,7 @@ def __del__(self) -> None:
try:
# Try to cancel the timer
self._refresh_token_timer.cancel()
except:
except Exception:
# Ignore errors if event loop is closed or selector is not registered
pass
finally:
Expand Down
5 changes: 2 additions & 3 deletions src/auth/src/supabase_auth/_sync/gotrue_admin_api.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
from __future__ import annotations

from typing import Any, Dict, List, Optional
from typing import Dict, List, Optional

from httpx import QueryParams, Response
from pydantic import TypeAdapter
from httpx import QueryParams

from ..helpers import (
model_validate,
Expand Down
4 changes: 2 additions & 2 deletions src/auth/src/supabase_auth/_sync/gotrue_base_api.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import Any, Callable, Dict, Optional, TypeVar, overload
from typing import Any, Dict, Optional

from httpx import HTTPStatusError, QueryParams, Response
from pydantic import BaseModel
Expand All @@ -20,7 +20,7 @@ def __init__(
http_client: Optional[SyncClient],
verify: bool = True,
proxy: Optional[str] = None,
):
) -> None:
self._url = url
self._headers = headers
self._http_client = http_client or SyncClient(
Expand Down
21 changes: 9 additions & 12 deletions src/auth/src/supabase_auth/_sync/gotrue_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,11 @@

import time
from contextlib import suppress
from functools import partial
from json import loads
from typing import Callable, Dict, List, Mapping, Optional, Tuple, Union
from urllib.parse import parse_qs, urlencode, urlparse
from typing import Callable, Dict, List, Optional, Tuple
from urllib.parse import parse_qs, urlparse
from uuid import uuid4

from httpx import QueryParams
from httpx import QueryParams, Response
from jwt import get_algorithm_by_name
from typing_extensions import cast

Expand All @@ -33,7 +31,6 @@
decode_jwt,
generate_pkce_challenge,
generate_pkce_verifier,
model_dump,
model_dump_json,
model_validate,
parse_auth_otp_response,
Expand All @@ -50,7 +47,6 @@
JWK,
AMREntry,
AuthChangeEvent,
AuthenticatorAssuranceLevels,
AuthFlowType,
AuthMFAChallengeResponse,
AuthMFAEnrollResponse,
Expand Down Expand Up @@ -86,6 +82,7 @@
SignUpWithEmailAndPasswordCredentialsOptions,
SignUpWithPasswordCredentials,
SignUpWithPhoneAndPasswordCredentialsOptions,
SSOResponse,
Subscription,
UpdateUserOptions,
UserAttributes,
Expand Down Expand Up @@ -361,7 +358,7 @@ def sign_in_with_id_token(
self._notify_all_subscribers("SIGNED_IN", auth_response.session)
return auth_response

def sign_in_with_sso(self, credentials: SignInWithSSOCredentials):
def sign_in_with_sso(self, credentials: SignInWithSSOCredentials) -> SSOResponse:
"""
Attempts a single-sign on using an enterprise Identity Provider. A
successful SSO attempt will redirect the current page to the identity
Expand Down Expand Up @@ -474,7 +471,7 @@ def get_user_identities(self) -> IdentitiesResponse:
return IdentitiesResponse(identities=response.user.identities or [])
raise AuthSessionMissingError()

def unlink_identity(self, identity: UserIdentity):
def unlink_identity(self, identity: UserIdentity) -> Response:
session = self.get_session()
if not session:
raise AuthSessionMissingError()
Expand Down Expand Up @@ -619,7 +616,7 @@ def reauthenticate(self) -> AuthResponse:
if not session:
raise AuthSessionMissingError()

response = self._request(
self._request(
"GET",
"reauthenticate",
jwt=session.access_token,
Expand Down Expand Up @@ -1086,7 +1083,7 @@ def _start_auto_refresh_token(self, value: float) -> None:
if value <= 0 or not self._auto_refresh_token:
return

def refresh_token_function():
def refresh_token_function() -> None:
self._network_retries += 1
try:
session = self.get_session()
Expand Down Expand Up @@ -1267,7 +1264,7 @@ def __del__(self) -> None:
try:
# Try to cancel the timer
self._refresh_token_timer.cancel()
except:
except Exception:
# Ignore errors if event loop is closed or selector is not registered
pass
finally:
Expand Down
2 changes: 1 addition & 1 deletion src/auth/src/supabase_auth/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@


class UserDoesntExist(Exception):
def __init__(self, access_token: str):
def __init__(self, access_token: str) -> None:
self.access_token = access_token


Expand Down
22 changes: 11 additions & 11 deletions src/auth/src/supabase_auth/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,15 @@
import uuid
from base64 import urlsafe_b64decode
from datetime import datetime
from typing import Any, Dict, Optional, Type, TypedDict, TypeVar, Union, cast
from typing import Any, Dict, Optional, Type, TypedDict, TypeVar, Union
from urllib.parse import urlparse

from httpx import HTTPStatusError, Response
from pydantic import BaseModel, TypeAdapter
from pydantic import BaseModel, TypeAdapter, ValidationError

from .constants import (
API_VERSION_HEADER_NAME,
API_VERSIONS_2024_01_01_TIMESTAMP,
BASE64URL_REGEX,
)
from .errors import (
AuthApiError,
Expand Down Expand Up @@ -81,7 +80,7 @@ def parse_auth_response(response: Response) -> AuthResponse:
try:
session = model_validate(Session, response.content)
user = session.user
except:
except ValidationError:
session = None
user = model_validate(User, response.content)
return AuthResponse(user=user, session=session)
Expand Down Expand Up @@ -126,9 +125,10 @@ def parse_jwks(response: Response) -> JWKSet:

def get_error_message(error: Any) -> str:
props = ["msg", "message", "error_description", "error"]
filter = lambda prop: (
prop in error if isinstance(error, dict) else hasattr(error, prop)
)

def filter(prop) -> bool:
return prop in error if isinstance(error, dict) else hasattr(error, prop)

return next((error[prop] for prop in props if filter(prop)), str(error))


Expand Down Expand Up @@ -240,7 +240,7 @@ def decode_jwt(token: str) -> DecodedJWT:
)


def generate_pkce_verifier(length=64):
def generate_pkce_verifier(length=64) -> str:
"""Generate a random PKCE verifier of the specified length."""
if length < 43 or length > 128:
raise ValueError("PKCE verifier length must be between 43 and 128 characters")
Expand All @@ -251,7 +251,7 @@ def generate_pkce_verifier(length=64):
return "".join(secrets.choice(charset) for _ in range(length))


def generate_pkce_challenge(code_verifier):
def generate_pkce_challenge(code_verifier) -> str:
"""Generate a code challenge from a PKCE verifier."""
# Hash the verifier using SHA-256
verifier_bytes = code_verifier.encode("utf-8")
Expand All @@ -263,7 +263,7 @@ def generate_pkce_challenge(code_verifier):
API_VERSION_REGEX = r"^2[0-9]{3}-(0[1-9]|1[0-2])-(0[1-9]|1[0-9]|2[0-9]|3[0-1])$"


def parse_response_api_version(response: Response):
def parse_response_api_version(response: Response) -> Optional[datetime]:
api_version = response.headers.get(API_VERSION_HEADER_NAME)

if not api_version:
Expand All @@ -275,7 +275,7 @@ def parse_response_api_version(response: Response):
try:
dt = datetime.strptime(api_version, "%Y-%m-%d")
return dt
except Exception as e:
except Exception:
return None


Expand Down
4 changes: 2 additions & 2 deletions src/auth/src/supabase_auth/timer.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,11 @@ def __init__(
def start(self) -> None:
if asyncio.iscoroutinefunction(self._function):

async def schedule():
async def schedule() -> None:
await asyncio.sleep(self._milliseconds / 1000)
await cast(Coroutine[Any, Any, None], self._function())

def cleanup(_):
def cleanup(_) -> None:
self._task = None

self._task = asyncio.create_task(schedule())
Expand Down
Loading