|
1 | 1 | from __future__ import annotations |
2 | 2 | from typing import TYPE_CHECKING, List, Protocol |
3 | 3 |
|
| 4 | +from functools import lru_cache |
4 | 5 | import json |
5 | | -import threading |
6 | 6 | from typing import Any, Dict, Optional, Union, cast |
7 | 7 | import jwt |
8 | 8 | from jwt import PyJWKClient |
|
22 | 22 | from workos.user_management import AsyncUserManagement, UserManagement |
23 | 23 |
|
24 | 24 |
|
25 | | -class _JWKSClientCache: |
26 | | - def __init__(self) -> None: |
27 | | - self._clients: Dict[str, PyJWKClient] = {} |
28 | | - self._lock = threading.Lock() |
29 | | - |
30 | | - def get_client(self, jwks_url: str) -> PyJWKClient: |
31 | | - if jwks_url in self._clients: |
32 | | - return self._clients[jwks_url] |
33 | | - |
34 | | - with self._lock: |
35 | | - if jwks_url in self._clients: |
36 | | - return self._clients[jwks_url] |
37 | | - |
38 | | - client = PyJWKClient(jwks_url) |
39 | | - self._clients[jwks_url] = client |
40 | | - return client |
41 | | - |
42 | | - def clear(self) -> None: |
43 | | - """Intended primarily for test cleanup and manual cache invalidation. |
44 | | - |
45 | | - Warning: If called concurrently with get_client(), some newly created |
46 | | - clients might be lost due to lock acquisition ordering. |
47 | | - """ |
48 | | - with self._lock: |
49 | | - self._clients.clear() |
50 | | - |
51 | | - |
52 | | -# Module-level cache instance |
53 | | -_jwks_cache = _JWKSClientCache() |
54 | | - |
55 | | - |
| 25 | +@lru_cache(maxsize=None) |
56 | 26 | def _get_jwks_client(jwks_url: str) -> PyJWKClient: |
57 | | - return _jwks_cache.get_client(jwks_url) |
| 27 | + return PyJWKClient(jwks_url) |
58 | 28 |
|
59 | 29 |
|
60 | 30 | class SessionModule(Protocol): |
|
0 commit comments