Skip to content

Commit ceecaee

Browse files
committed
Cache JWKS clients per URL
1 parent 10b70fa commit ceecaee

File tree

2 files changed

+51
-4
lines changed

2 files changed

+51
-4
lines changed

tests/test_session.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from datetime import datetime, timezone
55

66
from tests.conftest import with_jwks_mock
7-
from workos.session import AsyncSession, Session
7+
from workos.session import AsyncSession, Session, _get_jwks_client, _jwks_cache
88
from workos.types.user_management.authentication_response import (
99
RefreshTokenAuthenticationResponse,
1010
)
@@ -20,6 +20,12 @@
2020

2121

2222
class SessionFixtures:
23+
@pytest.fixture(autouse=True)
24+
def clear_jwks_cache(self):
25+
_jwks_cache._clients.clear()
26+
yield
27+
_jwks_cache._clients.clear()
28+
2329
@pytest.fixture
2430
def session_constants(self):
2531
# Generate RSA key pair for testing
@@ -491,3 +497,26 @@ async def test_refresh_success_with_aud_claim(
491497
response = await session.refresh()
492498

493499
assert isinstance(response, RefreshWithSessionCookieSuccessResponse)
500+
501+
502+
class TestJWKSCaching:
503+
def test_jwks_client_caching_same_url(self):
504+
url = "https://api.workos.com/sso/jwks/test"
505+
506+
client1 = _get_jwks_client(url)
507+
client2 = _get_jwks_client(url)
508+
509+
# Should be the exact same instance
510+
assert client1 is client2
511+
assert id(client1) == id(client2)
512+
513+
def test_jwks_client_caching_different_urls(self):
514+
url1 = "https://api.workos.com/sso/jwks/client1"
515+
url2 = "https://api.workos.com/sso/jwks/client2"
516+
517+
client1 = _get_jwks_client(url1)
518+
client2 = _get_jwks_client(url2)
519+
520+
# Should be different instances
521+
assert client1 is not client2
522+
assert id(client1) != id(client2)

workos/session.py

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,24 @@
2121
from workos.user_management import AsyncUserManagement, UserManagement
2222

2323

24+
class _JWKSClientCache:
25+
def __init__(self) -> None:
26+
self._clients: Dict[str, PyJWKClient] = {}
27+
28+
def get_client(self, jwks_url: str) -> PyJWKClient:
29+
if jwks_url not in self._clients:
30+
self._clients[jwks_url] = PyJWKClient(jwks_url)
31+
return self._clients[jwks_url]
32+
33+
34+
# Module-level cache instance
35+
_jwks_cache = _JWKSClientCache()
36+
37+
38+
def _get_jwks_client(jwks_url: str) -> PyJWKClient:
39+
return _jwks_cache.get_client(jwks_url)
40+
41+
2442
class SessionModule(Protocol):
2543
user_management: "UserManagementModule"
2644
client_id: str
@@ -46,7 +64,7 @@ def __init__(
4664
self.session_data = session_data
4765
self.cookie_password = cookie_password
4866

49-
self.jwks = PyJWKClient(self.user_management.get_jwks_url())
67+
self.jwks = _get_jwks_client(self.user_management.get_jwks_url())
5068

5169
# Algorithms are hardcoded for security reasons. See https://pyjwt.readthedocs.io/en/stable/algorithms.html#specifying-an-algorithm
5270
self.jwk_algorithms = ["RS256"]
@@ -164,7 +182,7 @@ def __init__(
164182
self.session_data = session_data
165183
self.cookie_password = cookie_password
166184

167-
self.jwks = PyJWKClient(self.user_management.get_jwks_url())
185+
self.jwks = _get_jwks_client(self.user_management.get_jwks_url())
168186

169187
# Algorithms are hardcoded for security reasons. See https://pyjwt.readthedocs.io/en/stable/algorithms.html#specifying-an-algorithm
170188
self.jwk_algorithms = ["RS256"]
@@ -254,7 +272,7 @@ def __init__(
254272
self.session_data = session_data
255273
self.cookie_password = cookie_password
256274

257-
self.jwks = PyJWKClient(self.user_management.get_jwks_url())
275+
self.jwks = _get_jwks_client(self.user_management.get_jwks_url())
258276

259277
# Algorithms are hardcoded for security reasons. See https://pyjwt.readthedocs.io/en/stable/algorithms.html#specifying-an-algorithm
260278
self.jwk_algorithms = ["RS256"]

0 commit comments

Comments
 (0)