Skip to content

Commit 36c94d5

Browse files
committed
use hybrid approach
1 parent 3b33d1a commit 36c94d5

File tree

5 files changed

+185
-26
lines changed

5 files changed

+185
-26
lines changed

tests/aio/test_credentials.py

Lines changed: 54 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import os
77
import json
88
import asyncio
9-
from unittest.mock import patch, AsyncMock
9+
from unittest.mock import patch, AsyncMock, MagicMock
1010

1111
import tests.auth.test_credentials
1212
import tests.oauth2_token_exchange
@@ -125,6 +125,8 @@ async def test_token_lazy_refresh():
125125
"localhost:0",
126126
)
127127

128+
credentials._tp.submit = MagicMock()
129+
128130
mock_response = {"access_token": "token_v1", "expires_in": 3600}
129131
credentials._make_token_request = AsyncMock(return_value=mock_response)
130132

@@ -139,7 +141,7 @@ async def test_token_lazy_refresh():
139141
assert token2 == "token_v1"
140142
assert credentials._make_token_request.call_count == 1
141143

142-
mock_time.return_value = 2000
144+
mock_time.return_value = 1000 + 3600 - 30 + 1
143145
credentials._make_token_request.return_value = {"access_token": "token_v2", "expires_in": 3600}
144146

145147
token3 = await credentials.token()
@@ -156,6 +158,8 @@ async def test_token_double_check_locking():
156158
"localhost:0",
157159
)
158160

161+
credentials._tp.submit = MagicMock()
162+
159163
call_count = 0
160164

161165
async def mock_make_request():
@@ -185,14 +189,16 @@ async def test_token_expiration_calculation():
185189
"localhost:0",
186190
)
187191

192+
credentials._tp.submit = MagicMock()
193+
188194
with patch("time.time") as mock_time:
189195
mock_time.return_value = 1000
190196

191197
credentials._make_token_request = AsyncMock(return_value={"access_token": "token", "expires_in": 3600})
192198

193199
await credentials.token()
194200

195-
expected_expires = 1000 + min(1800, 3600 / 4)
201+
expected_expires = 1000 + 3600 - 30
196202
assert credentials._expires_in == expected_expires
197203

198204

@@ -205,10 +211,55 @@ async def test_token_refresh_error_handling():
205211
"localhost:0",
206212
)
207213

214+
credentials._tp.submit = MagicMock()
215+
208216
credentials._make_token_request = AsyncMock(side_effect=Exception("Network error"))
209217

210218
with pytest.raises(Exception) as exc_info:
211219
await credentials.token()
212220

213221
assert "Network error" in str(exc_info.value)
214222
assert credentials.last_error == "Network error"
223+
224+
225+
@pytest.mark.asyncio
226+
async def test_hybrid_background_and_sync_refresh():
227+
credentials = ServiceAccountCredentialsForTest(
228+
tests.auth.test_credentials.SERVICE_ACCOUNT_ID,
229+
tests.auth.test_credentials.ACCESS_KEY_ID,
230+
tests.auth.test_credentials.PRIVATE_KEY,
231+
"localhost:0",
232+
)
233+
234+
call_count = 0
235+
background_calls = []
236+
237+
async def mock_make_request():
238+
nonlocal call_count
239+
call_count += 1
240+
return {"access_token": f"token_v{call_count}", "expires_in": 3600}
241+
242+
def mock_submit(callback):
243+
background_calls.append(callback)
244+
245+
credentials._make_token_request = mock_make_request
246+
credentials._tp.submit = mock_submit
247+
248+
with patch("time.time") as mock_time:
249+
mock_time.return_value = 1000
250+
251+
token1 = await credentials.token()
252+
assert token1 == "token_v1"
253+
assert call_count == 1
254+
assert len(background_calls) == 0
255+
256+
mock_time.return_value = 1000 + min(1800, 3600 / 10) + 1
257+
token2 = await credentials.token()
258+
assert token2 == "token_v1"
259+
assert call_count == 1
260+
assert len(background_calls) == 1
261+
262+
mock_time.return_value = 1000 + 3600 - 30 + 1
263+
token3 = await credentials.token()
264+
assert token3 == "token_v2"
265+
assert call_count == 2

tests/auth/test_static_credentials.py

Lines changed: 52 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,8 @@ def test_static_credentials_wrong_creds(endpoint, database):
5151
def test_token_lazy_refresh():
5252
credentials = ydb.StaticCredentials.from_user_password(USERNAME, PASSWORD)
5353

54+
credentials._tp.submit = MagicMock()
55+
5456
mock_response = {"access_token": "token_v1", "expires_in": 3600}
5557
credentials._make_token_request = MagicMock(return_value=mock_response)
5658

@@ -65,7 +67,7 @@ def test_token_lazy_refresh():
6567
assert token2 == "token_v1"
6668
assert credentials._make_token_request.call_count == 1
6769

68-
mock_time.return_value = 2000
70+
mock_time.return_value = 1000 + 3600 - 30 + 1
6971
credentials._make_token_request.return_value = {"access_token": "token_v2", "expires_in": 3600}
7072

7173
token3 = credentials.token
@@ -75,6 +77,7 @@ def test_token_lazy_refresh():
7577

7678
def test_token_double_check_locking():
7779
credentials = ydb.StaticCredentials.from_user_password(USERNAME, PASSWORD)
80+
credentials._tp.submit = MagicMock()
7881

7982
call_count = 0
8083

@@ -108,24 +111,66 @@ def get_token():
108111
def test_token_expiration_calculation():
109112
credentials = ydb.StaticCredentials.from_user_password(USERNAME, PASSWORD)
110113

114+
credentials._tp.submit = MagicMock()
115+
111116
with patch("time.time") as mock_time:
112117
mock_time.return_value = 1000
113118

114119
credentials._make_token_request = MagicMock(return_value={"access_token": "token", "expires_in": 3600})
115120

116121
credentials.token
117122

118-
expected_expires = 1000 + min(1800, 3600 / 4)
123+
expected_expires = 1000 + 3600 - 30
119124
assert credentials._expires_in == expected_expires
120125

121126

122127
def test_token_refresh_error_handling():
123128
credentials = ydb.StaticCredentials.from_user_password(USERNAME, PASSWORD)
124-
129+
credentials._tp.submit = MagicMock()
125130
credentials._make_token_request = MagicMock(side_effect=Exception("Network error"))
126131

127-
with pytest.raises(ydb.ConnectionError) as exc_info:
128-
credentials.token
132+
with patch("time.time") as mock_time:
133+
mock_time.return_value = 1000 + 3600
134+
135+
with pytest.raises(ydb.ConnectionError) as exc_info:
136+
credentials.token
137+
138+
assert "Network error" in str(exc_info.value)
139+
assert credentials.last_error == "Network error"
140+
141+
142+
def test_hybrid_background_and_sync_refresh():
143+
credentials = ydb.StaticCredentials.from_user_password(USERNAME, PASSWORD)
144+
145+
call_count = 0
146+
background_calls = []
147+
148+
def mock_make_request():
149+
nonlocal call_count
150+
call_count += 1
151+
return {"access_token": f"token_v{call_count}", "expires_in": 3600}
129152

130-
assert "Network error" in str(exc_info.value)
131-
assert credentials.last_error == "Network error"
153+
def mock_submit(callback):
154+
background_calls.append(callback)
155+
156+
credentials._make_token_request = mock_make_request
157+
credentials._tp.submit = mock_submit
158+
159+
with patch("time.time") as mock_time:
160+
mock_time.return_value = 1000
161+
162+
token1 = credentials.token
163+
assert token1 == "token_v1"
164+
assert call_count == 1
165+
assert len(background_calls) == 0
166+
167+
mock_time.return_value = 1000 + min(1800, 3600 / 10) + 1
168+
token2 = credentials.token
169+
assert token2 == "token_v1"
170+
assert call_count == 1
171+
assert len(background_calls) == 1
172+
173+
mock_time.return_value = 1000 + 3600 - 30 + 1
174+
token3 = credentials.token
175+
assert token3 == "token_v2"
176+
assert call_count == 2

ydb/aio/credentials.py

Lines changed: 32 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import asyncio
33
import logging
44
import time
5+
from concurrent import futures
56

67
from ydb import credentials
78
from ydb import issues
@@ -10,10 +11,31 @@
1011
YDB_AUTH_TICKET_HEADER = "x-ydb-auth-ticket"
1112

1213

14+
class AtMostOneExecution(object):
15+
def __init__(self):
16+
self._can_schedule = True
17+
self._lock = asyncio.Lock()
18+
19+
async def wrapped_execution(self, callback):
20+
async with self._lock:
21+
try:
22+
res = await callback()
23+
except Exception:
24+
pass
25+
finally:
26+
self._can_schedule = True
27+
28+
def submit(self, callback):
29+
if self._can_schedule:
30+
self._can_schedule = False
31+
asyncio.create_task(self.wrapped_execution(callback))
32+
33+
1334
class AbstractExpiringTokenCredentials(credentials.AbstractExpiringTokenCredentials):
1435
def __init__(self):
1536
super(AbstractExpiringTokenCredentials, self).__init__()
1637
self._token_lock = asyncio.Lock()
38+
self._tp = AtMostOneExecution()
1739

1840
@abc.abstractmethod
1941
async def _make_token_request(self):
@@ -25,14 +47,12 @@ async def get_auth_token(self) -> str:
2547
return token
2648
return ""
2749

28-
async def _refresh_token(self):
50+
async def _refresh_token(self, should_raise=False):
2951
current_time = time.time()
3052

3153
try:
3254
self.logger.debug(
33-
"Refreshing token async, current_time: %s, expires_in: %s",
34-
current_time,
35-
self._expires_in,
55+
"Refreshing token async, current_time: %s, expires_in: %s", current_time, self._expires_in
3656
)
3757

3858
token_response = await self._make_token_request()
@@ -44,19 +64,23 @@ async def _refresh_token(self):
4464
except Exception as e:
4565
self.last_error = str(e)
4666
self.logger.error("Failed to refresh token async: %s", e)
47-
raise issues.ConnectionError(
48-
"%s: %s.\n%s" % (self.__class__.__name__, self.last_error, self.extra_error_message)
49-
)
67+
if should_raise:
68+
raise issues.ConnectionError(
69+
"%s: %s.\n%s" % (self.__class__.__name__, self.last_error, self.extra_error_message)
70+
)
5071

5172
async def token(self):
5273
if self._is_token_valid():
74+
if self._should_refresh():
75+
self._tp.submit(self._refresh_token)
76+
5377
return self._cached_token
5478

5579
async with self._token_lock:
5680
if self._is_token_valid():
5781
return self._cached_token
5882

59-
await self._refresh_token()
83+
await self._refresh_token(should_raise=True)
6084

6185
return self._cached_token
6286

ydb/aio/iam.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,7 @@ def __init__(self, metadata_url=None):
103103
assert aiohttp is not None, "Install aiohttp library to use metadata credentials provider"
104104
self._metadata_url = auth.DEFAULT_METADATA_URL if metadata_url is None else metadata_url
105105
self.extra_error_message = "Check that metadata service configured properly and application deployed in VM or function at Yandex.Cloud."
106+
self._tp.submit(self._refresh_token)
106107

107108
async def _make_token_request(self):
108109
timeout = aiohttp.ClientTimeout(total=2)

0 commit comments

Comments
 (0)