Skip to content

Commit cdb71eb

Browse files
committed
use hybrid approach
1 parent 3b33d1a commit cdb71eb

File tree

5 files changed

+184
-26
lines changed

5 files changed

+184
-26
lines changed

tests/aio/test_credentials.py

Lines changed: 54 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import os
77
import json
88
import asyncio
9-
from unittest.mock import patch, AsyncMock
9+
from unittest.mock import patch, AsyncMock, MagicMock
1010

1111
import tests.auth.test_credentials
1212
import tests.oauth2_token_exchange
@@ -125,6 +125,8 @@ async def test_token_lazy_refresh():
125125
"localhost:0",
126126
)
127127

128+
credentials._tp.submit = MagicMock()
129+
128130
mock_response = {"access_token": "token_v1", "expires_in": 3600}
129131
credentials._make_token_request = AsyncMock(return_value=mock_response)
130132

@@ -139,7 +141,7 @@ async def test_token_lazy_refresh():
139141
assert token2 == "token_v1"
140142
assert credentials._make_token_request.call_count == 1
141143

142-
mock_time.return_value = 2000
144+
mock_time.return_value = 1000 + 3600 - 30 + 1
143145
credentials._make_token_request.return_value = {"access_token": "token_v2", "expires_in": 3600}
144146

145147
token3 = await credentials.token()
@@ -156,6 +158,8 @@ async def test_token_double_check_locking():
156158
"localhost:0",
157159
)
158160

161+
credentials._tp.submit = MagicMock()
162+
159163
call_count = 0
160164

161165
async def mock_make_request():
@@ -185,14 +189,16 @@ async def test_token_expiration_calculation():
185189
"localhost:0",
186190
)
187191

192+
credentials._tp.submit = MagicMock()
193+
188194
with patch("time.time") as mock_time:
189195
mock_time.return_value = 1000
190196

191197
credentials._make_token_request = AsyncMock(return_value={"access_token": "token", "expires_in": 3600})
192198

193199
await credentials.token()
194200

195-
expected_expires = 1000 + min(1800, 3600 / 4)
201+
expected_expires = 1000 + 3600 - 30
196202
assert credentials._expires_in == expected_expires
197203

198204

@@ -205,10 +211,55 @@ async def test_token_refresh_error_handling():
205211
"localhost:0",
206212
)
207213

214+
credentials._tp.submit = MagicMock()
215+
208216
credentials._make_token_request = AsyncMock(side_effect=Exception("Network error"))
209217

210218
with pytest.raises(Exception) as exc_info:
211219
await credentials.token()
212220

213221
assert "Network error" in str(exc_info.value)
214222
assert credentials.last_error == "Network error"
223+
224+
225+
@pytest.mark.asyncio
226+
async def test_hybrid_background_and_sync_refresh():
227+
credentials = ServiceAccountCredentialsForTest(
228+
tests.auth.test_credentials.SERVICE_ACCOUNT_ID,
229+
tests.auth.test_credentials.ACCESS_KEY_ID,
230+
tests.auth.test_credentials.PRIVATE_KEY,
231+
"localhost:0",
232+
)
233+
234+
call_count = 0
235+
background_calls = []
236+
237+
async def mock_make_request():
238+
nonlocal call_count
239+
call_count += 1
240+
return {"access_token": f"token_v{call_count}", "expires_in": 3600}
241+
242+
def mock_submit(callback):
243+
background_calls.append(callback)
244+
245+
credentials._make_token_request = mock_make_request
246+
credentials._tp.submit = mock_submit
247+
248+
with patch("time.time") as mock_time:
249+
mock_time.return_value = 1000
250+
251+
token1 = await credentials.token()
252+
assert token1 == "token_v1"
253+
assert call_count == 1
254+
assert len(background_calls) == 0
255+
256+
mock_time.return_value = 1000 + min(1800, 3600 / 10) + 1
257+
token2 = await credentials.token()
258+
assert token2 == "token_v1"
259+
assert call_count == 1
260+
assert len(background_calls) == 1
261+
262+
mock_time.return_value = 1000 + 3600 - 30 + 1
263+
token3 = await credentials.token()
264+
assert token3 == "token_v2"
265+
assert call_count == 2

tests/auth/test_static_credentials.py

Lines changed: 52 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,8 @@ def test_static_credentials_wrong_creds(endpoint, database):
5151
def test_token_lazy_refresh():
5252
credentials = ydb.StaticCredentials.from_user_password(USERNAME, PASSWORD)
5353

54+
credentials._tp.submit = MagicMock()
55+
5456
mock_response = {"access_token": "token_v1", "expires_in": 3600}
5557
credentials._make_token_request = MagicMock(return_value=mock_response)
5658

@@ -65,7 +67,7 @@ def test_token_lazy_refresh():
6567
assert token2 == "token_v1"
6668
assert credentials._make_token_request.call_count == 1
6769

68-
mock_time.return_value = 2000
70+
mock_time.return_value = 1000 + 3600 - 30 + 1
6971
credentials._make_token_request.return_value = {"access_token": "token_v2", "expires_in": 3600}
7072

7173
token3 = credentials.token
@@ -75,6 +77,7 @@ def test_token_lazy_refresh():
7577

7678
def test_token_double_check_locking():
7779
credentials = ydb.StaticCredentials.from_user_password(USERNAME, PASSWORD)
80+
credentials._tp.submit = MagicMock()
7881

7982
call_count = 0
8083

@@ -108,24 +111,66 @@ def get_token():
108111
def test_token_expiration_calculation():
109112
credentials = ydb.StaticCredentials.from_user_password(USERNAME, PASSWORD)
110113

114+
credentials._tp.submit = MagicMock()
115+
111116
with patch("time.time") as mock_time:
112117
mock_time.return_value = 1000
113118

114119
credentials._make_token_request = MagicMock(return_value={"access_token": "token", "expires_in": 3600})
115120

116121
credentials.token
117122

118-
expected_expires = 1000 + min(1800, 3600 / 4)
123+
expected_expires = 1000 + 3600 - 30
119124
assert credentials._expires_in == expected_expires
120125

121126

122127
def test_token_refresh_error_handling():
123128
credentials = ydb.StaticCredentials.from_user_password(USERNAME, PASSWORD)
124-
129+
credentials._tp.submit = MagicMock()
125130
credentials._make_token_request = MagicMock(side_effect=Exception("Network error"))
126131

127-
with pytest.raises(ydb.ConnectionError) as exc_info:
128-
credentials.token
132+
with patch("time.time") as mock_time:
133+
mock_time.return_value = 1000 + 3600
134+
135+
with pytest.raises(ydb.ConnectionError) as exc_info:
136+
credentials.token
137+
138+
assert "Network error" in str(exc_info.value)
139+
assert credentials.last_error == "Network error"
140+
141+
142+
def test_hybrid_background_and_sync_refresh():
143+
credentials = ydb.StaticCredentials.from_user_password(USERNAME, PASSWORD)
144+
145+
call_count = 0
146+
background_calls = []
147+
148+
def mock_make_request():
149+
nonlocal call_count
150+
call_count += 1
151+
return {"access_token": f"token_v{call_count}", "expires_in": 3600}
129152

130-
assert "Network error" in str(exc_info.value)
131-
assert credentials.last_error == "Network error"
153+
def mock_submit(callback):
154+
background_calls.append(callback)
155+
156+
credentials._make_token_request = mock_make_request
157+
credentials._tp.submit = mock_submit
158+
159+
with patch("time.time") as mock_time:
160+
mock_time.return_value = 1000
161+
162+
token1 = credentials.token
163+
assert token1 == "token_v1"
164+
assert call_count == 1
165+
assert len(background_calls) == 0
166+
167+
mock_time.return_value = 1000 + min(1800, 3600 / 10) + 1
168+
token2 = credentials.token
169+
assert token2 == "token_v1"
170+
assert call_count == 1
171+
assert len(background_calls) == 1
172+
173+
mock_time.return_value = 1000 + 3600 - 30 + 1
174+
token3 = credentials.token
175+
assert token3 == "token_v2"
176+
assert call_count == 2

ydb/aio/credentials.py

Lines changed: 31 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,31 @@
1010
YDB_AUTH_TICKET_HEADER = "x-ydb-auth-ticket"
1111

1212

13+
class AtMostOneExecution(object):
14+
def __init__(self):
15+
self._can_schedule = True
16+
self._lock = asyncio.Lock()
17+
18+
async def wrapped_execution(self, callback):
19+
async with self._lock:
20+
try:
21+
await callback()
22+
except Exception:
23+
pass
24+
finally:
25+
self._can_schedule = True
26+
27+
def submit(self, callback):
28+
if self._can_schedule:
29+
self._can_schedule = False
30+
asyncio.create_task(self.wrapped_execution(callback))
31+
32+
1333
class AbstractExpiringTokenCredentials(credentials.AbstractExpiringTokenCredentials):
1434
def __init__(self):
1535
super(AbstractExpiringTokenCredentials, self).__init__()
1636
self._token_lock = asyncio.Lock()
37+
self._tp = AtMostOneExecution()
1738

1839
@abc.abstractmethod
1940
async def _make_token_request(self):
@@ -25,14 +46,12 @@ async def get_auth_token(self) -> str:
2546
return token
2647
return ""
2748

28-
async def _refresh_token(self):
49+
async def _refresh_token(self, should_raise=False):
2950
current_time = time.time()
3051

3152
try:
3253
self.logger.debug(
33-
"Refreshing token async, current_time: %s, expires_in: %s",
34-
current_time,
35-
self._expires_in,
54+
"Refreshing token async, current_time: %s, expires_in: %s", current_time, self._expires_in
3655
)
3756

3857
token_response = await self._make_token_request()
@@ -44,19 +63,23 @@ async def _refresh_token(self):
4463
except Exception as e:
4564
self.last_error = str(e)
4665
self.logger.error("Failed to refresh token async: %s", e)
47-
raise issues.ConnectionError(
48-
"%s: %s.\n%s" % (self.__class__.__name__, self.last_error, self.extra_error_message)
49-
)
66+
if should_raise:
67+
raise issues.ConnectionError(
68+
"%s: %s.\n%s" % (self.__class__.__name__, self.last_error, self.extra_error_message)
69+
)
5070

5171
async def token(self):
5272
if self._is_token_valid():
73+
if self._should_refresh():
74+
self._tp.submit(self._refresh_token)
75+
5376
return self._cached_token
5477

5578
async with self._token_lock:
5679
if self._is_token_valid():
5780
return self._cached_token
5881

59-
await self._refresh_token()
82+
await self._refresh_token(should_raise=True)
6083

6184
return self._cached_token
6285

ydb/aio/iam.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,7 @@ def __init__(self, metadata_url=None):
103103
assert aiohttp is not None, "Install aiohttp library to use metadata credentials provider"
104104
self._metadata_url = auth.DEFAULT_METADATA_URL if metadata_url is None else metadata_url
105105
self.extra_error_message = "Check that metadata service configured properly and application deployed in VM or function at Yandex.Cloud."
106+
self._tp.submit(self._refresh_token)
106107

107108
async def _make_token_request(self):
108109
timeout = aiohttp.ClientTimeout(total=2)

0 commit comments

Comments
 (0)