2121 from workos .user_management import AsyncUserManagement , UserManagement
2222
2323
24+ class _JWKSClientCache :
25+ def __init__ (self ) -> None :
26+ self ._clients : Dict [str , PyJWKClient ] = {}
27+
28+ def get_client (self , jwks_url : str ) -> PyJWKClient :
29+ if jwks_url not in self ._clients :
30+ self ._clients [jwks_url ] = PyJWKClient (jwks_url )
31+ return self ._clients [jwks_url ]
32+
33+
34+ # Module-level cache instance
35+ _jwks_cache = _JWKSClientCache ()
36+
37+
38+ def _get_jwks_client (jwks_url : str ) -> PyJWKClient :
39+ return _jwks_cache .get_client (jwks_url )
40+
41+
2442class SessionModule (Protocol ):
2543 user_management : "UserManagementModule"
2644 client_id : str
@@ -46,7 +64,7 @@ def __init__(
4664 self .session_data = session_data
4765 self .cookie_password = cookie_password
4866
49- self .jwks = PyJWKClient (self .user_management .get_jwks_url ())
67+ self .jwks = _get_jwks_client (self .user_management .get_jwks_url ())
5068
5169 # Algorithms are hardcoded for security reasons. See https://pyjwt.readthedocs.io/en/stable/algorithms.html#specifying-an-algorithm
5270 self .jwk_algorithms = ["RS256" ]
@@ -164,7 +182,7 @@ def __init__(
164182 self .session_data = session_data
165183 self .cookie_password = cookie_password
166184
167- self .jwks = PyJWKClient (self .user_management .get_jwks_url ())
185+ self .jwks = _get_jwks_client (self .user_management .get_jwks_url ())
168186
169187 # Algorithms are hardcoded for security reasons. See https://pyjwt.readthedocs.io/en/stable/algorithms.html#specifying-an-algorithm
170188 self .jwk_algorithms = ["RS256" ]
@@ -254,7 +272,7 @@ def __init__(
254272 self .session_data = session_data
255273 self .cookie_password = cookie_password
256274
257- self .jwks = PyJWKClient (self .user_management .get_jwks_url ())
275+ self .jwks = _get_jwks_client (self .user_management .get_jwks_url ())
258276
259277 # Algorithms are hardcoded for security reasons. See https://pyjwt.readthedocs.io/en/stable/algorithms.html#specifying-an-algorithm
260278 self .jwk_algorithms = ["RS256" ]
0 commit comments