Skip to content

Commit 9c91c53

Browse files
committed
Switch to lru_cache
1 parent 826bb3f commit 9c91c53

File tree

2 files changed

+7
-37
lines changed

2 files changed

+7
-37
lines changed

tests/test_session.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import concurrent.futures
66

77
from tests.conftest import with_jwks_mock
8-
from workos.session import AsyncSession, Session, _get_jwks_client, _jwks_cache
8+
from workos.session import AsyncSession, Session, _get_jwks_client
99
from workos.types.user_management.authentication_response import (
1010
RefreshTokenAuthenticationResponse,
1111
)
@@ -23,9 +23,9 @@
2323
class SessionFixtures:
2424
@pytest.fixture(autouse=True)
2525
def clear_jwks_cache(self):
26-
_jwks_cache.clear()
26+
_get_jwks_client.cache_clear()
2727
yield
28-
_jwks_cache.clear()
28+
_get_jwks_client.cache_clear()
2929

3030
@pytest.fixture
3131
def session_constants(self):
@@ -521,7 +521,7 @@ def test_jwks_client_caching_different_urls(self):
521521
# Should be different instances
522522
assert client1 is not client2
523523
assert id(client1) != id(client2)
524-
524+
525525
def test_jwks_cache_thread_safety(self):
526526
url = "https://api.workos.com/sso/jwks/thread_test"
527527
clients = []

workos/session.py

Lines changed: 3 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
from __future__ import annotations
22
from typing import TYPE_CHECKING, List, Protocol
33

4+
from functools import lru_cache
45
import json
5-
import threading
66
from typing import Any, Dict, Optional, Union, cast
77
import jwt
88
from jwt import PyJWKClient
@@ -22,39 +22,9 @@
2222
from workos.user_management import AsyncUserManagement, UserManagement
2323

2424

25-
class _JWKSClientCache:
26-
def __init__(self) -> None:
27-
self._clients: Dict[str, PyJWKClient] = {}
28-
self._lock = threading.Lock()
29-
30-
def get_client(self, jwks_url: str) -> PyJWKClient:
31-
if jwks_url in self._clients:
32-
return self._clients[jwks_url]
33-
34-
with self._lock:
35-
if jwks_url in self._clients:
36-
return self._clients[jwks_url]
37-
38-
client = PyJWKClient(jwks_url)
39-
self._clients[jwks_url] = client
40-
return client
41-
42-
def clear(self) -> None:
43-
"""Intended primarily for test cleanup and manual cache invalidation.
44-
45-
Warning: If called concurrently with get_client(), some newly created
46-
clients might be lost due to lock acquisition ordering.
47-
"""
48-
with self._lock:
49-
self._clients.clear()
50-
51-
52-
# Module-level cache instance
53-
_jwks_cache = _JWKSClientCache()
54-
55-
25+
@lru_cache(maxsize=None)
5626
def _get_jwks_client(jwks_url: str) -> PyJWKClient:
57-
return _jwks_cache.get_client(jwks_url)
27+
return PyJWKClient(jwks_url)
5828

5929

6030
class SessionModule(Protocol):

0 commit comments

Comments
 (0)