Skip to content

Commit 32432a1

Browse files
Sync monorepo state at "[python sdk] Add support for caching access token across client instances" (#107)
Syncing from userclouds/userclouds@d444b579997500a1ff5124d961d68cbab6279cc4
1 parent 75d81b3 commit 32432a1

File tree

6 files changed

+76
-60
lines changed

6 files changed

+76
-60
lines changed

CHANGELOG.md

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
# Changelog
22

3-
## UNPUBLISHED
3+
## 1.8.0 - UNPUBLISHED
44

55
- Breaking change: Add "ending_before" argument to all paginated methods, add pagination to ExecuteAccessor, change "starting_after" and "ending_before" to be str instead of uuid
6+
- Add the ability to cache the access token globally (in process) and share it across clients instances. Pass the optional `use_global_cache_for_token=true` to the client in order to enable.
67

7-
## 1.7.0
8+
## 1.7.0 - 09-05-2024
89

910
- Update userstore sample to exercise partial update columns
1011
- Add methods for creating, retrieving, updating, and deleting ColumnDataTypes
@@ -13,7 +14,7 @@
1314
- Add ColumnDataType resource IDs for native ColumnDataTypes
1415
- Update userstore_sample.py to interact with ColumnDataTypes
1516

16-
## 1.6.1
17+
## 1.6.1 - Not published
1718

1819
- Add SDK method for data import via ExecuteMutator
1920

requirements-dev.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
black>=24.2.0
1+
black>=24.4.2
22
ruff
33
pyupgrade
44
isort

src/usercloudssdk/asyncclient.py

Lines changed: 13 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,11 @@
22

33
import asyncio
44
import base64
5-
import time
65
import urllib.parse
76
import uuid
87
from dataclasses import asdict
98
from pathlib import Path
109

11-
import jwt
12-
1310
from . import ucjson
1411
from .client_helpers import _SDK_VERSION, _id_from_identical_conflict, _read_env
1512
from .constants import _JSON_CONTENT_TYPE, AuthnType, Region
@@ -37,6 +34,7 @@
3734
UpdateColumnRetentionDurationsRequest,
3835
UserResponse,
3936
)
37+
from .token import cache_token, get_cached_token, is_token_expiring
4038
from .uchttpclient import create_default_uc_http_async_client
4139

4240

@@ -60,6 +58,7 @@ def __init__(
6058
client_secret: str,
6159
client_factory=create_default_uc_http_async_client,
6260
session_name: str | None = None,
61+
use_global_cache_for_token: bool = False,
6362
**kwargs,
6463
):
6564
self._authorization = base64.b64encode(
@@ -68,9 +67,9 @@ def __init__(
6867
"ISO-8859-1",
6968
)
7069
).decode("ascii")
71-
7270
self._client = client_factory(base_url=url, **kwargs)
7371
self._access_token: str | None = None # lazy loaded
72+
self._use_global_cache_for_token = use_global_cache_for_token
7473
self._access_token_lock = asyncio.Lock()
7574
base_ua = f"UserClouds Python SDK v{_SDK_VERSION}"
7675
self._common_headers = {
@@ -1126,6 +1125,10 @@ async def CheckDataImportStatusAsync(self, import_id: uuid.UUID) -> dict:
11261125
# Access Token Helpers
11271126

11281127
async def _get_access_token_async(self) -> str:
1128+
if self._use_global_cache_for_token:
1129+
token = get_cached_token(self._authorization)
1130+
if token:
1131+
return token
11291132
# Encode the client ID and client secret
11301133
headers = {
11311134
"Authorization": f"Basic {self._authorization}",
@@ -1142,7 +1145,10 @@ async def _get_access_token_async(self) -> str:
11421145
if resp.status_code >= 400:
11431146
raise UserCloudsSDKError.from_response(resp)
11441147
json_data = ucjson.loads(resp.text)
1145-
return json_data.get("access_token")
1148+
token = json_data["access_token"]
1149+
if self._use_global_cache_for_token:
1150+
cache_token(self._authorization, token)
1151+
return token
11461152

11471153
async def _refresh_access_token_if_needed_async(self) -> None:
11481154
if self._access_token is None:
@@ -1151,23 +1157,9 @@ async def _refresh_access_token_if_needed_async(self) -> None:
11511157
self._access_token = await self._get_access_token_async()
11521158
return
11531159

1154-
# TODO: this takes advantage of an implementation detail that we use JWTs for
1155-
# access tokens, but we should probably either expose an endpoint to verify
1156-
# expiration time, or expect to retry requests with a well-formed error, or
1157-
# change our bearer token format in time.
1158-
if (
1159-
jwt.decode(self._access_token, options={"verify_signature": False}).get(
1160-
"exp"
1161-
)
1162-
< time.time()
1163-
):
1160+
if is_token_expiring(self._access_token):
11641161
async with self._access_token_lock:
1165-
if (
1166-
jwt.decode(
1167-
self._access_token, options={"verify_signature": False}
1168-
).get("exp")
1169-
< time.time()
1170-
):
1162+
if is_token_expiring(self._access_token):
11711163
self._access_token = await self._get_access_token_async()
11721164

11731165
# Request Helpers

src/usercloudssdk/client.py

Lines changed: 12 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,11 @@
11
from __future__ import annotations
22

33
import base64
4-
import time
54
import urllib.parse
65
import uuid
76
from dataclasses import asdict
87
from pathlib import Path
98

10-
import jwt
11-
129
from . import ucjson
1310
from .client_helpers import _SDK_VERSION, _id_from_identical_conflict, _read_env
1411
from .constants import _JSON_CONTENT_TYPE, AuthnType, Region
@@ -36,6 +33,7 @@
3633
UpdateColumnRetentionDurationsRequest,
3734
UserResponse,
3835
)
36+
from .token import cache_token, get_cached_token, is_token_expiring
3937
from .uchttpclient import create_default_uc_http_client
4038

4139

@@ -59,6 +57,7 @@ def __init__(
5957
client_secret: str,
6058
client_factory=create_default_uc_http_client,
6159
session_name: str | None = None,
60+
use_global_cache_for_token: bool = False,
6261
**kwargs,
6362
):
6463
self._authorization = base64.b64encode(
@@ -70,6 +69,7 @@ def __init__(
7069

7170
self._client = client_factory(base_url=url, **kwargs)
7271
self._access_token: str | None = None # lazy loaded
72+
self._use_global_cache_for_token = use_global_cache_for_token
7373
base_ua = f"UserClouds Python SDK v{_SDK_VERSION}"
7474
self._common_headers = {
7575
"User-Agent": f"{base_ua} [{session_name}]" if session_name else base_ua,
@@ -1082,6 +1082,10 @@ def CheckDataImportStatus(self, import_id: uuid.UUID) -> dict:
10821082
# Access Token Helpers
10831083

10841084
def _get_access_token(self) -> str:
1085+
if self._use_global_cache_for_token:
1086+
token = get_cached_token(self._authorization)
1087+
if token:
1088+
return token
10851089
# Encode the client ID and client secret
10861090
headers = {
10871091
"Authorization": f"Basic {self._authorization}",
@@ -1096,23 +1100,17 @@ def _get_access_token(self) -> str:
10961100
if resp.status_code >= 400:
10971101
raise UserCloudsSDKError.from_response(resp)
10981102
json_data = ucjson.loads(resp.text)
1099-
return json_data.get("access_token")
1103+
token = json_data.get("access_token")
1104+
if self._use_global_cache_for_token:
1105+
cache_token(self._authorization, token)
1106+
return token
11001107

11011108
def _refresh_access_token_if_needed(self) -> None:
11021109
if self._access_token is None:
11031110
self._access_token = self._get_access_token()
11041111
return
11051112

1106-
# TODO: this takes advantage of an implementation detail that we use JWTs for
1107-
# access tokens, but we should probably either expose an endpoint to verify
1108-
# expiration time, or expect to retry requests with a well-formed error, or
1109-
# change our bearer token format in time.
1110-
if (
1111-
jwt.decode(self._access_token, options={"verify_signature": False}).get(
1112-
"exp"
1113-
)
1114-
< time.time()
1115-
):
1113+
if is_token_expiring(self._access_token):
11161114
self._access_token = self._get_access_token()
11171115

11181116
# Request Helpers

src/usercloudssdk/token.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
from __future__ import annotations
2+
3+
import time
4+
5+
import jwt
6+
7+
_token_cache: dict[str, str] = {}
8+
9+
10+
def get_cached_token(authorization: str) -> str | None:
11+
global _token_cache
12+
return _token_cache.get(authorization, None)
13+
14+
15+
def cache_token(authorization: str, token: str) -> None:
16+
if not token:
17+
return
18+
global _token_cache
19+
_token_cache[authorization] = token
20+
21+
22+
def is_token_expiring(token: str | None) -> bool:
23+
if not token:
24+
return True
25+
# TODO: this takes advantage of an implementation detail that we use JWTs for
26+
# access tokens, but we should probably either expose an endpoint to verify
27+
# expiration time, or expect to retry requests with a well-formed error, or
28+
# change our bearer token format in time.
29+
decoded_token = jwt.decode(token, options={"verify_signature": False})
30+
expiration_time = decoded_token.get("exp")
31+
if not expiration_time:
32+
return True
33+
return expiration_time > time.time()

tools/genasyncclient.py

Lines changed: 13 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -8,14 +8,11 @@
88
99
import asyncio
1010
import base64
11-
import time
1211
import urllib.parse
1312
import uuid
1413
from dataclasses import asdict
1514
from pathlib import Path
1615
17-
import jwt
18-
1916
from . import ucjson
2017
from .client_helpers import _SDK_VERSION, _id_from_identical_conflict, _read_env
2118
from .constants import _JSON_CONTENT_TYPE, AuthnType, Region
@@ -43,6 +40,7 @@
4340
UpdateColumnRetentionDurationsRequest,
4441
UserResponse,
4542
)
43+
from .token import cache_token, get_cached_token, is_token_expiring
4644
from .uchttpclient import create_default_uc_http_async_client
4745
4846
@@ -64,6 +62,7 @@ def __init__(
6462
client_secret: str,
6563
client_factory=create_default_uc_http_async_client,
6664
session_name: str | None = None,
65+
use_global_cache_for_token: bool = False,
6766
**kwargs,
6867
):
6968
self._authorization = base64.b64encode(
@@ -72,9 +71,9 @@ def __init__(
7271
"ISO-8859-1",
7372
)
7473
).decode("ascii")
75-
7674
self._client = client_factory(base_url=url, **kwargs)
7775
self._access_token: str | None = None # lazy loaded
76+
self._use_global_cache_for_token = use_global_cache_for_token
7877
self._access_token_lock = asyncio.Lock()
7978
base_ua = f"UserClouds Python SDK v{_SDK_VERSION}"
8079
self._common_headers = {
@@ -85,6 +84,10 @@ def __init__(
8584

8685
non_auto_generated_footer = """
8786
async def _get_access_token_async(self) -> str:
87+
if self._use_global_cache_for_token:
88+
token = get_cached_token(self._authorization)
89+
if token:
90+
return token
8891
# Encode the client ID and client secret
8992
headers = {
9093
"Authorization": f"Basic {self._authorization}",
@@ -101,7 +104,10 @@ async def _get_access_token_async(self) -> str:
101104
if resp.status_code >= 400:
102105
raise UserCloudsSDKError.from_response(resp)
103106
json_data = ucjson.loads(resp.text)
104-
return json_data.get("access_token")
107+
token = json_data["access_token"]
108+
if self._use_global_cache_for_token:
109+
cache_token(self._authorization, token)
110+
return token
105111
106112
async def _refresh_access_token_if_needed_async(self) -> None:
107113
if self._access_token is None:
@@ -110,23 +116,9 @@ async def _refresh_access_token_if_needed_async(self) -> None:
110116
self._access_token = await self._get_access_token_async()
111117
return
112118
113-
# TODO: this takes advantage of an implementation detail that we use JWTs for
114-
# access tokens, but we should probably either expose an endpoint to verify
115-
# expiration time, or expect to retry requests with a well-formed error, or
116-
# change our bearer token format in time.
117-
if (
118-
jwt.decode(self._access_token, options={"verify_signature": False}).get(
119-
"exp"
120-
)
121-
< time.time()
122-
):
119+
if is_token_expiring(self._access_token):
123120
async with self._access_token_lock:
124-
if (
125-
jwt.decode(
126-
self._access_token, options={"verify_signature": False}
127-
).get("exp")
128-
< time.time()
129-
):
121+
if is_token_expiring(self._access_token):
130122
self._access_token = await self._get_access_token_async()
131123
132124
# Request Helpers

0 commit comments

Comments
 (0)