Skip to content

Commit ef1563d

Browse files
committed
Refactor auth token refresh logic
1 parent 40ac692 commit ef1563d

File tree

6 files changed

+247
-179
lines changed

6 files changed

+247
-179
lines changed

tests/aio/test_credentials.py

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
import tempfile
66
import os
77
import json
8+
import asyncio
9+
from unittest.mock import patch, AsyncMock
810

911
import tests.auth.test_credentials
1012
import tests.oauth2_token_exchange
@@ -112,3 +114,101 @@ def serve(s):
112114
except Exception:
113115
os.remove(cfg_file_name)
114116
raise
117+
118+
119+
@pytest.mark.asyncio
120+
async def test_token_lazy_refresh():
121+
credentials = ServiceAccountCredentialsForTest(
122+
tests.auth.test_credentials.SERVICE_ACCOUNT_ID,
123+
tests.auth.test_credentials.ACCESS_KEY_ID,
124+
tests.auth.test_credentials.PRIVATE_KEY,
125+
"localhost:0",
126+
)
127+
128+
mock_response = {"access_token": "token_v1", "expires_in": 3600}
129+
credentials._make_token_request = AsyncMock(return_value=mock_response)
130+
131+
with patch("time.time") as mock_time:
132+
mock_time.return_value = 1000
133+
134+
token1 = await credentials.token()
135+
assert token1 == "token_v1"
136+
assert credentials._make_token_request.call_count == 1
137+
138+
token2 = await credentials.token()
139+
assert token2 == "token_v1"
140+
assert credentials._make_token_request.call_count == 1
141+
142+
mock_time.return_value = 2000
143+
credentials._make_token_request.return_value = {"access_token": "token_v2", "expires_in": 3600}
144+
145+
token3 = await credentials.token()
146+
assert token3 == "token_v2"
147+
assert credentials._make_token_request.call_count == 2
148+
149+
150+
@pytest.mark.asyncio
151+
async def test_token_double_check_locking():
152+
credentials = ServiceAccountCredentialsForTest(
153+
tests.auth.test_credentials.SERVICE_ACCOUNT_ID,
154+
tests.auth.test_credentials.ACCESS_KEY_ID,
155+
tests.auth.test_credentials.PRIVATE_KEY,
156+
"localhost:0",
157+
)
158+
159+
call_count = 0
160+
161+
async def mock_make_request():
162+
nonlocal call_count
163+
call_count += 1
164+
await asyncio.sleep(0.01)
165+
return {"access_token": f"token_v{call_count}", "expires_in": 3600}
166+
167+
credentials._make_token_request = mock_make_request
168+
169+
with patch("time.time") as mock_time:
170+
mock_time.return_value = 1000
171+
172+
tasks = [credentials.token() for _ in range(10)]
173+
results = await asyncio.gather(*tasks)
174+
175+
assert len(set(results)) == 1
176+
assert call_count == 1
177+
178+
179+
@pytest.mark.asyncio
180+
async def test_token_expiration_calculation():
181+
credentials = ServiceAccountCredentialsForTest(
182+
tests.auth.test_credentials.SERVICE_ACCOUNT_ID,
183+
tests.auth.test_credentials.ACCESS_KEY_ID,
184+
tests.auth.test_credentials.PRIVATE_KEY,
185+
"localhost:0",
186+
)
187+
188+
with patch("time.time") as mock_time:
189+
mock_time.return_value = 1000
190+
191+
credentials._make_token_request = AsyncMock(return_value={"access_token": "token", "expires_in": 3600})
192+
193+
await credentials.token()
194+
195+
expected_expires = 1000 + min(1800, 3600 / 4)
196+
assert credentials._expires_in == expected_expires
197+
198+
199+
@pytest.mark.asyncio
200+
async def test_token_refresh_error_handling():
201+
credentials = ServiceAccountCredentialsForTest(
202+
tests.auth.test_credentials.SERVICE_ACCOUNT_ID,
203+
tests.auth.test_credentials.ACCESS_KEY_ID,
204+
tests.auth.test_credentials.PRIVATE_KEY,
205+
"localhost:0",
206+
)
207+
208+
credentials._make_token_request = AsyncMock(side_effect=Exception("Network error"))
209+
210+
with pytest.raises(Exception) as exc_info:
211+
await credentials.token()
212+
213+
assert "Network error" in str(exc_info.value)
214+
assert credentials.last_error == "Network error"

tests/auth/test_static_credentials.py

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import pytest
22
import ydb
3+
from unittest.mock import patch, MagicMock
34

45

56
USERNAME = "root"
@@ -45,3 +46,86 @@ def test_static_credentials_wrong_creds(endpoint, database):
4546
with pytest.raises(ydb.ConnectionFailure):
4647
with ydb.Driver(driver_config=driver_config) as driver:
4748
driver.wait(5, fail_fast=True)
49+
50+
51+
def test_token_lazy_refresh():
52+
credentials = ydb.StaticCredentials.from_user_password(USERNAME, PASSWORD)
53+
54+
mock_response = {"access_token": "token_v1", "expires_in": 3600}
55+
credentials._make_token_request = MagicMock(return_value=mock_response)
56+
57+
with patch("time.time") as mock_time:
58+
mock_time.return_value = 1000
59+
60+
token1 = credentials.token
61+
assert token1 == "token_v1"
62+
assert credentials._make_token_request.call_count == 1
63+
64+
token2 = credentials.token
65+
assert token2 == "token_v1"
66+
assert credentials._make_token_request.call_count == 1
67+
68+
mock_time.return_value = 2000
69+
credentials._make_token_request.return_value = {"access_token": "token_v2", "expires_in": 3600}
70+
71+
token3 = credentials.token
72+
assert token3 == "token_v2"
73+
assert credentials._make_token_request.call_count == 2
74+
75+
76+
def test_token_double_check_locking():
77+
credentials = ydb.StaticCredentials.from_user_password(USERNAME, PASSWORD)
78+
79+
call_count = 0
80+
81+
def mock_make_request():
82+
nonlocal call_count
83+
call_count += 1
84+
return {"access_token": f"token_v{call_count}", "expires_in": 3600}
85+
86+
credentials._make_token_request = mock_make_request
87+
88+
with patch("time.time") as mock_time:
89+
mock_time.return_value = 1000
90+
91+
import threading
92+
93+
results = []
94+
95+
def get_token():
96+
results.append(credentials.token)
97+
98+
threads = [threading.Thread(target=get_token) for _ in range(10)]
99+
for t in threads:
100+
t.start()
101+
for t in threads:
102+
t.join()
103+
104+
assert len(set(results)) == 1
105+
assert call_count == 1
106+
107+
108+
def test_token_expiration_calculation():
109+
credentials = ydb.StaticCredentials.from_user_password(USERNAME, PASSWORD)
110+
111+
with patch("time.time") as mock_time:
112+
mock_time.return_value = 1000
113+
114+
credentials._make_token_request = MagicMock(return_value={"access_token": "token", "expires_in": 3600})
115+
116+
credentials.token
117+
118+
expected_expires = 1000 + min(1800, 3600 / 4)
119+
assert credentials._expires_in == expected_expires
120+
121+
122+
def test_token_refresh_error_handling():
123+
credentials = ydb.StaticCredentials.from_user_password(USERNAME, PASSWORD)
124+
125+
credentials._make_token_request = MagicMock(side_effect=Exception("Network error"))
126+
127+
with pytest.raises(ydb.ConnectionError) as exc_info:
128+
credentials.token
129+
130+
assert "Network error" in str(exc_info.value)
131+
assert credentials.last_error == "Network error"

ydb/aio/credentials.py

Lines changed: 27 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -10,57 +10,10 @@
1010
YDB_AUTH_TICKET_HEADER = "x-ydb-auth-ticket"
1111

1212

13-
class _OneToManyValue(object):
14-
def __init__(self):
15-
self._value = None
16-
self._condition = asyncio.Condition()
17-
18-
async def consume(self, timeout=3):
19-
async with self._condition:
20-
if self._value is None:
21-
try:
22-
await asyncio.wait_for(self._condition.wait(), timeout=timeout)
23-
except Exception:
24-
return self._value
25-
return self._value
26-
27-
async def update(self, n_value):
28-
async with self._condition:
29-
prev_value = self._value
30-
self._value = n_value
31-
if prev_value is None:
32-
self._condition.notify_all()
33-
34-
35-
class _AtMostOneExecution(object):
36-
def __init__(self):
37-
self._can_schedule = True
38-
self._lock = asyncio.Lock() # Lock to guarantee only one execution
39-
40-
async def _wrapped_execution(self, callback):
41-
await self._lock.acquire()
42-
try:
43-
res = callback()
44-
if asyncio.iscoroutine(res):
45-
await res
46-
except Exception:
47-
pass
48-
49-
finally:
50-
self._lock.release()
51-
self._can_schedule = True
52-
53-
def submit(self, callback):
54-
if self._can_schedule:
55-
self._can_schedule = False
56-
asyncio.ensure_future(self._wrapped_execution(callback))
57-
58-
5913
class AbstractExpiringTokenCredentials(credentials.AbstractExpiringTokenCredentials):
6014
def __init__(self):
6115
super(AbstractExpiringTokenCredentials, self).__init__()
62-
self._tp = _AtMostOneExecution()
63-
self._cached_token = _OneToManyValue()
16+
self._token_lock = asyncio.Lock()
6417

6518
@abc.abstractmethod
6619
async def _make_token_request(self):
@@ -72,51 +25,45 @@ async def get_auth_token(self) -> str:
7225
return token
7326
return ""
7427

75-
async def _refresh(self):
28+
async def _refresh_token(self):
7629
current_time = time.time()
77-
self._log_refresh_start(current_time)
7830

7931
try:
80-
auth_metadata = await self._make_token_request()
81-
await self._cached_token.update(auth_metadata["access_token"])
82-
self._update_expiration_info(auth_metadata)
83-
self.logger.info(
84-
"Token refresh successful. current_time %s, refresh_in %s",
32+
self.logger.debug(
33+
"Refreshing token async, current_time: %s, expires_in: %s",
8534
current_time,
86-
self._refresh_in,
35+
self._expires_in,
8736
)
8837

89-
except (KeyboardInterrupt, SystemExit):
90-
return
38+
token_response = await self._make_token_request()
39+
self._update_token_info(token_response, current_time)
9140

92-
except Exception as e:
93-
self.last_error = str(e)
94-
await asyncio.sleep(1)
95-
self._tp.submit(self._refresh)
41+
self.logger.info("Token refreshed successfully async, expires_in: %s", self._expires_in)
42+
self.last_error = None
9643

97-
except BaseException as e:
44+
except Exception as e:
9845
self.last_error = str(e)
99-
raise
46+
self.logger.error("Failed to refresh token async: %s", e)
47+
raise issues.ConnectionError(
48+
"%s: %s.\n%s" % (self.__class__.__name__, self.last_error, self.extra_error_message)
49+
)
10050

10151
async def token(self):
102-
current_time = time.time()
103-
if current_time > self._refresh_in:
104-
self._tp.submit(self._refresh)
105-
106-
cached_token = await self._cached_token.consume(timeout=3)
107-
if cached_token is None:
108-
if self.last_error is None:
109-
raise issues.ConnectionError(
110-
"%s: timeout occurred while waiting for token.\n%s"
111-
% (
112-
self.__class__.__name__,
113-
self.extra_error_message,
114-
)
115-
)
52+
if self._is_token_valid():
53+
return self._cached_token
54+
55+
async with self._token_lock:
56+
if self._is_token_valid():
57+
return self._cached_token
58+
59+
await self._refresh_token()
60+
61+
if self._cached_token is None:
11662
raise issues.ConnectionError(
117-
"%s: %s.\n%s" % (self.__class__.__name__, self.last_error, self.extra_error_message)
63+
"%s: No token available.\n%s" % (self.__class__.__name__, self.extra_error_message)
11864
)
119-
return cached_token
65+
66+
return self._cached_token
12067

12168
async def auth_metadata(self):
12269
return [(credentials.YDB_AUTH_TICKET_HEADER, await self.token())]

ydb/aio/iam.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,6 @@ def __init__(self, metadata_url=None):
102102
super(MetadataUrlCredentials, self).__init__()
103103
assert aiohttp is not None, "Install aiohttp library to use metadata credentials provider"
104104
self._metadata_url = auth.DEFAULT_METADATA_URL if metadata_url is None else metadata_url
105-
self._tp.submit(self._refresh)
106105
self.extra_error_message = "Check that metadata service configured properly and application deployed in VM or function at Yandex.Cloud."
107106

108107
async def _make_token_request(self):

0 commit comments

Comments
 (0)