|
2 | 2 | from unittest.mock import AsyncMock, Mock, patch |
3 | 3 | import jwt |
4 | 4 | from datetime import datetime, timezone |
| 5 | +import concurrent.futures |
5 | 6 |
|
6 | 7 | from tests.conftest import with_jwks_mock |
7 | 8 | from workos.session import AsyncSession, Session, _get_jwks_client, _jwks_cache |
|
22 | 23 | class SessionFixtures: |
23 | 24 | @pytest.fixture(autouse=True) |
24 | 25 | def clear_jwks_cache(self): |
25 | | - _jwks_cache._clients.clear() |
| 26 | + _jwks_cache.clear() |
26 | 27 | yield |
27 | | - _jwks_cache._clients.clear() |
| 28 | + _jwks_cache.clear() |
28 | 29 |
|
29 | 30 | @pytest.fixture |
30 | 31 | def session_constants(self): |
@@ -520,3 +521,20 @@ def test_jwks_client_caching_different_urls(self): |
520 | 521 | # Should be different instances |
521 | 522 | assert client1 is not client2 |
522 | 523 | assert id(client1) != id(client2) |
| 524 | + |
| 525 | + def test_jwks_cache_thread_safety(self): |
| 526 | + url = "https://api.workos.com/sso/jwks/thread_test" |
| 527 | + clients = [] |
| 528 | + |
| 529 | + def get_client(): |
| 530 | + return _get_jwks_client(url) |
| 531 | + |
| 532 | + with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor: |
| 533 | + futures = [executor.submit(get_client) for _ in range(10)] |
| 534 | + clients = [future.result() for future in futures] |
| 535 | + |
| 536 | + first_client = clients[0] |
| 537 | + for client in clients[1:]: |
| 538 | + assert ( |
| 539 | + client is first_client |
| 540 | + ), "All concurrent calls should return the same instance" |
0 commit comments