Skip to content

Commit 506f154

Browse files
committed
tests: Add more tests and minor refactors
1 parent ddaed87 commit 506f154

File tree

7 files changed

+180
-62
lines changed

7 files changed

+180
-62
lines changed

CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
88

99
## [unreleased]
1010

11+
### Changes
12+
13+
- Use `useStaticSigningKey` instead of `use_static_signing_key` in `create_jwt` function. This was a bug in the code.
14+
1115

1216
## [0.14.3] - 2023-06-7
1317

supertokens_python/recipe/session/access_token.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,22 +52,33 @@ def get_info_from_access_token(
5252
):
5353
try:
5454
payload: Optional[Dict[str, Any]] = None
55+
decode_algo = (
56+
jwt_info.parsed_header["alg"]
57+
if jwt_info.parsed_header is not None
58+
else "RS256"
59+
)
60+
5561
if jwt_info.version >= 3:
5662
matching_key = jwk_client.get_matching_key_from_jwt(
5763
jwt_info.raw_token_string
5864
)
5965
payload = jwt.decode( # type: ignore
6066
jwt_info.raw_token_string,
6167
matching_key.key, # type: ignore
62-
algorithms=["RS256"],
68+
algorithms=[decode_algo],
6369
options={"verify_signature": True, "verify_exp": True},
6470
)
6571
else:
6672
# It won't have kid. So we'll have to try the token against all the keys from all the jwk_clients
6773
# If any of them work, we'll use that payload
6874
for k in jwk_client.get_latest_keys():
6975
try:
70-
payload = jwt.decode(jwt_info.raw_token_string, k.key, algorithms=["RS256"]) # type: ignore
76+
payload = jwt.decode( # type: ignore
77+
jwt_info.raw_token_string,
78+
k.key, # type: ignore
79+
algorithms=[decode_algo],
80+
options={"verify_signature": True, "verify_exp": True},
81+
)
7182
break
7283
except DecodeError:
7384
pass

supertokens_python/recipe/session/jwt.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ def __init__(
4242
payload: Dict[str, Any],
4343
signature: str,
4444
kid: Optional[str],
45+
parsed_header: Optional[Dict[str, Any]] = None,
4546
) -> None:
4647
self.version = version
4748
self.raw_token_string = raw_token_string
@@ -50,24 +51,26 @@ def __init__(
5051
self.payload = payload
5152
self.signature = signature
5253
self.kid = kid
54+
self.parsed_header = parsed_header
5355

5456

5557
def parse_jwt_without_signature_verification(jwt: str) -> ParsedJWTInfo:
5658
splitted_input = jwt.split(".")
57-
latest_access_token_version = 3
59+
TOKEN_V3 = 3
5860
if len(splitted_input) != 3:
5961
raise Exception("invalid jwt")
6062

6163
# V1 and V2 are functionally identical, plus all legacy tokens should be V2 now.
6264
# So we can assume these defaults:
6365
version = 2
6466
kid = None
67+
parsed_header = None
6568
# V2 or older tokens didn't save the key id
6669
header, payload, signature = splitted_input
6770
# checking the header
6871
if header not in _allowed_headers:
6972
parsed_header = loads(utf_base64decode(header, True))
70-
header_version = parsed_header.get("version", str(latest_access_token_version))
73+
header_version = parsed_header.get("version", str(TOKEN_V3))
7174

7275
try:
7376
version = int(header_version)
@@ -79,7 +82,7 @@ def parse_jwt_without_signature_verification(jwt: str) -> ParsedJWTInfo:
7982
if (
8083
parsed_header["typ"] != "JWT"
8184
or not isinstance(version, int)
82-
or version < latest_access_token_version
85+
or version < TOKEN_V3
8386
or kid is None
8487
):
8588
raise Exception("JWT header mismatch")
@@ -94,4 +97,5 @@ def parse_jwt_without_signature_verification(jwt: str) -> ParsedJWTInfo:
9497
payload=loads(utf_base64decode(payload, True)),
9598
signature=signature,
9699
kid=kid,
100+
parsed_header=parsed_header,
97101
)

supertokens_python/recipe/session/recipe_implementation.py

Lines changed: 1 addition & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818

1919
from supertokens_python.logger import log_debug_message
2020
from supertokens_python.normalised_url_path import NormalisedURLPath
21-
from supertokens_python.utils import resolve
21+
from supertokens_python.utils import resolve, RWMutex, RWLockContext
2222

2323
from ...types import MaybeAwaitable
2424
from . import session_functions
@@ -52,7 +52,6 @@
5252

5353

5454
from typing_extensions import TypedDict
55-
import threading
5655
from os import environ
5756

5857

@@ -66,61 +65,6 @@ class JWKSConfigType(TypedDict):
6665
"refresh_rate_limit": 500,
6766
}
6867

69-
70-
class RWMutex:
71-
def __init__(self):
72-
self._lock = threading.Lock()
73-
self._readers = threading.Condition(self._lock)
74-
self._writers = threading.Condition(self._lock)
75-
self._reader_count = 0
76-
self._writer_count = 0
77-
78-
def lock(self):
79-
with self._lock:
80-
while self._writer_count > 0 or self._reader_count > 0:
81-
self._writers.wait()
82-
self._writer_count += 1
83-
84-
def unlock(self):
85-
with self._lock:
86-
self._writer_count -= 1
87-
self._readers.notify_all()
88-
self._writers.notify_all()
89-
90-
def r_lock(self):
91-
with self._lock:
92-
while self._writer_count > 0:
93-
self._readers.wait()
94-
self._reader_count += 1
95-
96-
def r_unlock(self):
97-
with self._lock:
98-
self._reader_count -= 1
99-
if self._reader_count == 0:
100-
self._writers.notify_all()
101-
102-
103-
class RWLockContext:
104-
def __init__(self, mutex: RWMutex, read: bool = True):
105-
self.mutex = mutex
106-
self.read = read
107-
108-
def __enter__(self):
109-
if self.read:
110-
self.mutex.r_lock()
111-
else:
112-
self.mutex.lock()
113-
114-
def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any):
115-
if exc_type is not None:
116-
raise exc_type(exc_value).with_traceback(traceback)
117-
118-
if self.read:
119-
self.mutex.r_unlock()
120-
else:
121-
self.mutex.unlock()
122-
123-
12468
cached_jwk_client: Optional[JWKClient] = None
12569
mutex = RWMutex()
12670

supertokens_python/utils.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
import asyncio
1818
import json
19+
import threading
1920
import warnings
2021
from base64 import urlsafe_b64decode, urlsafe_b64encode, b64encode, b64decode
2122
from math import floor
@@ -305,3 +306,57 @@ def get_top_level_domain_for_same_site_resolution(url: str) -> str:
305306
)
306307

307308
return parsed_url.domain + "." + parsed_url.suffix # type: ignore
309+
310+
311+
class RWMutex:
312+
def __init__(self):
313+
self._lock = threading.Lock()
314+
self._readers = threading.Condition(self._lock)
315+
self._writers = threading.Condition(self._lock)
316+
self._reader_count = 0
317+
self._writer_count = 0
318+
319+
def lock(self):
320+
with self._lock:
321+
while self._writer_count > 0 or self._reader_count > 0:
322+
self._writers.wait()
323+
self._writer_count += 1
324+
325+
def unlock(self):
326+
with self._lock:
327+
self._writer_count -= 1
328+
self._readers.notify_all()
329+
self._writers.notify_all()
330+
331+
def r_lock(self):
332+
with self._lock:
333+
while self._writer_count > 0:
334+
self._readers.wait()
335+
self._reader_count += 1
336+
337+
def r_unlock(self):
338+
with self._lock:
339+
self._reader_count -= 1
340+
if self._reader_count == 0:
341+
self._writers.notify_all()
342+
343+
344+
class RWLockContext:
345+
def __init__(self, mutex: RWMutex, read: bool = True):
346+
self.mutex = mutex
347+
self.read = read
348+
349+
def __enter__(self):
350+
if self.read:
351+
self.mutex.r_lock()
352+
else:
353+
self.mutex.lock()
354+
355+
def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any):
356+
if exc_type is not None:
357+
raise exc_type(exc_value).with_traceback(traceback)
358+
359+
if self.read:
360+
self.mutex.r_unlock()
361+
else:
362+
self.mutex.unlock()

tests/sessions/test_jwks.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -462,6 +462,43 @@ async def test_session_verification_of_jwt_based_on_session_payload_with_check_d
462462
assert s_.get_user_id() == "userId"
463463

464464

465+
async def test_session_verification_of_jwt_with_dynamic_signing_key():
466+
init(
467+
**get_st_init_args(
468+
recipe_list=[session.init(use_dynamic_access_token_signing_key=False)]
469+
)
470+
)
471+
start_st()
472+
473+
s = await create_new_session_without_request_response("userId", {}, {})
474+
475+
payload = s.get_access_token_payload()
476+
del payload["iat"]
477+
del payload["exp"]
478+
payload["tId"] = "public" # tenant id
479+
480+
now = get_timestamp_ms()
481+
jwt_expiry = now + 10 * 1000 # expiry jwt after 10sec
482+
483+
jwt_with_dynamic_key = await create_jwt(
484+
payload, jwt_expiry, use_static_signing_key=False
485+
)
486+
assert isinstance(jwt_with_dynamic_key, CreateJwtOkResult)
487+
try:
488+
await get_session_without_request_response(jwt_with_dynamic_key.jwt)
489+
assert False
490+
except Exception:
491+
pass
492+
493+
jwt_with_static_key = await create_jwt(
494+
payload, jwt_expiry, use_static_signing_key=True
495+
)
496+
assert isinstance(jwt_with_static_key, CreateJwtOkResult)
497+
s_ = await get_session_without_request_response(jwt_with_static_key.jwt)
498+
assert s_ is not None
499+
assert s_.get_user_id() == "userId"
500+
501+
465502
async def test_that_locking_for_jwks_cache_works(caplog: LogCaptureFixture):
466503
caplog.set_level(logging.DEBUG)
467504
not_returned_from_cache_count = get_log_occurence_count(

tests/test_utils.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,11 @@
11
from typing import Union, List, Any, Dict
22

33
import pytest
4+
import threading
5+
46
from supertokens_python.utils import humanize_time, is_version_gte
7+
from supertokens_python.utils import RWMutex
8+
59
from tests.utils import is_subset
610

711

@@ -102,3 +106,62 @@ def test_is_subset(
102106
assert is_subset(d1, d2)
103107
else:
104108
assert not is_subset(d1, d2)
109+
110+
111+
class BankAccount:
112+
def __init__(self):
113+
self.balance = 0
114+
self.mutex = RWMutex()
115+
116+
def deposit(self, amount: int):
117+
self.mutex.lock()
118+
self.balance += amount
119+
self.mutex.unlock()
120+
121+
def withdraw(self, amount: int):
122+
self.mutex.lock()
123+
self.balance -= amount
124+
self.mutex.unlock()
125+
126+
def get_balance(self):
127+
self.mutex.r_lock()
128+
balance = self.balance
129+
self.mutex.r_unlock()
130+
return balance
131+
132+
133+
def test_rw_mutex_writes():
134+
account = BankAccount()
135+
threads: List[threading.Thread] = []
136+
137+
# Create 10 deposit threads
138+
for _ in range(10):
139+
t = threading.Thread(target=account.deposit, args=(10,))
140+
threads.append(t)
141+
142+
def balance_is_valid():
143+
balance = account.get_balance()
144+
assert balance % 5 == 0 and balance >= 0
145+
146+
# Create 15 balance checking threads
147+
for _ in range(15):
148+
t = threading.Thread(target=balance_is_valid)
149+
threads.append(t)
150+
151+
# Create 10 withdraw threads
152+
for _ in range(10):
153+
t = threading.Thread(target=account.withdraw, args=(5,))
154+
threads.append(t)
155+
156+
# Start all threads
157+
for t in threads:
158+
t.start()
159+
160+
# Wait for all threads to finish
161+
for t in threads:
162+
t.join()
163+
164+
# Check account balance
165+
expected_balance = 10 * 10 # 10 threads depositing 10 each
166+
expected_balance -= 10 * 5 # 10 threads withdrawing 5 each
167+
assert account.get_balance() == expected_balance, "Incorrect account balance"

0 commit comments

Comments
 (0)