Skip to content

Commit 1314595

Browse files
committed
update is_valid_jwt function
1 parent 3e0934f commit 1314595

File tree

3 files changed

+35
-17
lines changed

3 files changed

+35
-17
lines changed

supabase/_async/client.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from storage3.constants import DEFAULT_TIMEOUT as DEFAULT_STORAGE_CLIENT_TIMEOUT
1717
from supafunc import AsyncFunctionsClient
1818

19-
from supabase.lib.helpers import is_jwt
19+
from supabase.lib.helpers import is_valid_jwt
2020

2121
from ..lib.client_options import AsyncClientOptions as ClientOptions
2222
from .auth_client import AsyncSupabaseAuthClient
@@ -280,7 +280,7 @@ def _create_auth_header(self, token: str):
280280

281281
def _get_auth_headers(self, authorization: Optional[str] = None) -> Dict[str, str]:
282282
if authorization is None:
283-
if is_jwt(self.supabase_key):
283+
if is_valid_jwt(self.supabase_key):
284284
authorization = self.options.headers.get(
285285
"Authorization", self._create_auth_header(self.supabase_key)
286286
)
@@ -294,7 +294,9 @@ def _get_auth_headers(self, authorization: Optional[str] = None) -> Dict[str, st
294294
def _listen_to_auth_events(
295295
self, event: AuthChangeEvent, session: Optional[Session]
296296
):
297-
default_access_token = self.supabase_key if is_jwt(self.supabase_key) else None
297+
default_access_token = (
298+
self.supabase_key if is_valid_jwt(self.supabase_key) else None
299+
)
298300
access_token = default_access_token
299301
if event in ["SIGNED_IN", "TOKEN_REFRESHED", "SIGNED_OUT"]:
300302
# reset postgrest and storage instance on event change

supabase/_sync/client.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from storage3.constants import DEFAULT_TIMEOUT as DEFAULT_STORAGE_CLIENT_TIMEOUT
1616
from supafunc import SyncFunctionsClient
1717

18-
from supabase.lib.helpers import is_jwt
18+
from supabase.lib.helpers import is_valid_jwt
1919

2020
from ..lib.client_options import SyncClientOptions as ClientOptions
2121
from .auth_client import SyncSupabaseAuthClient
@@ -279,7 +279,7 @@ def _create_auth_header(self, token: str):
279279

280280
def _get_auth_headers(self, authorization: Optional[str] = None) -> Dict[str, str]:
281281
if authorization is None:
282-
if is_jwt(self.supabase_key):
282+
if is_valid_jwt(self.supabase_key):
283283
authorization = self.options.headers.get(
284284
"Authorization", self._create_auth_header(self.supabase_key)
285285
)
@@ -293,7 +293,9 @@ def _get_auth_headers(self, authorization: Optional[str] = None) -> Dict[str, st
293293
def _listen_to_auth_events(
294294
self, event: AuthChangeEvent, session: Optional[Session]
295295
):
296-
default_access_token = self.supabase_key if is_jwt(self.supabase_key) else None
296+
default_access_token = (
297+
self.supabase_key if is_valid_jwt(self.supabase_key) else None
298+
)
297299
access_token = default_access_token
298300
if event in ["SIGNED_IN", "TOKEN_REFRESHED", "SIGNED_OUT"]:
299301
# reset postgrest and storage instance on event change

supabase/lib/helpers.py

Lines changed: 25 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,41 @@
11
import re
2+
from typing import Dict
23

34
BASE64URL_REGEX = r"^([a-z0-9_-]{4})*($|[a-z0-9_-]{3}$|[a-z0-9_-]{2}$)$"
45

56

6-
def is_jwt(value: str) -> bool:
7-
if value.startswith("Bearer "):
8-
value = value.replace("Bearer ", "")
7+
def is_valid_jwt(value: str) -> bool:
8+
"""Checks if value looks like a JWT, does not do any extra parsing."""
9+
if not isinstance(value, str):
10+
return False
911

12+
# Remove trailing whitespaces if any.
1013
value = value.strip()
11-
if not value:
12-
return False
1314

14-
parts = value.split(".")
15-
if len(parts) != 3:
15+
# Remove "Bearer " prefix if any.
16+
if value.startswith("Bearer "):
17+
value = value[7:]
18+
19+
# Valid JWT must have 2 dots (Header.Paylod.Signature)
20+
if value.count(".") != 2:
1621
return False
1722

18-
# loop through the parts and test against regex
19-
for part in parts:
20-
if len(part) < 4 or not re.search(BASE64URL_REGEX, part, re.IGNORECASE):
23+
for part in value.split("."):
24+
if not re.search(BASE64URL_REGEX, part, re.IGNORECASE):
2125
return False
2226

2327
return True
2428

2529

26-
def check_authorization_header(headers):
30+
def check_authorization_header(headers: Dict[str, str]):
31+
authorization = headers.get("Authorization")
32+
if not authorization:
33+
return
34+
35+
if authorization.startswith("Bearer "):
36+
if not is_valid_jwt(authorization):
37+
raise ValueError(
38+
"create_client called with global Authorization header that does not contain a JWT"
39+
)
40+
2741
return True

0 commit comments

Comments
 (0)