Skip to content

Commit 4527bf0

Browse files
committed
Added error propagation to main thread
1 parent a9c200c commit 4527bf0

File tree

3 files changed

+103
-0
lines changed

3 files changed

+103
-0
lines changed

redis/event.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -247,8 +247,10 @@ def listen(self, event: AfterPooledConnectionsInstantiationEvent):
247247

248248
if event.client_type == ClientType.SYNC:
249249
event.credential_provider.on_next(self._re_auth)
250+
event.credential_provider.on_error(self._raise_on_error)
250251
else:
251252
event.credential_provider.on_next(self._re_auth_async)
253+
event.credential_provider.on_error(self._raise_on_error_async)
252254

253255
def _re_auth(self, token):
254256
for pool in self._event.connection_pools:
@@ -258,6 +260,12 @@ async def _re_auth_async(self, token):
258260
for pool in self._event.connection_pools:
259261
await pool.re_auth_callback(token)
260262

263+
def _raise_on_error(self, error: Exception):
264+
raise error
265+
266+
async def _raise_on_error_async(self, error: Exception):
267+
raise error
268+
261269

262270
class RegisterReAuthForSingleConnection(EventListenerInterface):
263271
"""
@@ -273,8 +281,10 @@ def listen(self, event: AfterSingleConnectionInstantiationEvent):
273281

274282
if event.client_type == ClientType.SYNC:
275283
event.connection.credential_provider.on_next(self._re_auth)
284+
event.connection.credential_provider.on_error(self._raise_on_error)
276285
else:
277286
event.connection.credential_provider.on_next(self._re_auth_async)
287+
event.connection.credential_provider.on_error(self._raise_on_error_async)
278288

279289
def _re_auth(self, token):
280290
with self._event.connection_lock:
@@ -286,6 +296,12 @@ async def _re_auth_async(self, token):
286296
await self._event.connection.send_command('AUTH', token.try_get('oid'), token.get_value())
287297
await self._event.connection.read_response()
288298

299+
def _raise_on_error(self, error: Exception):
300+
raise error
301+
302+
async def _raise_on_error_async(self, error: Exception):
303+
raise error
304+
289305

290306
class RegisterReAuthForAsyncClusterNodes(EventListenerInterface):
291307
def __init__(self):
@@ -295,11 +311,15 @@ def listen(self, event: AfterAsyncClusterInstantiationEvent):
295311
if isinstance(event.credential_provider, StreamingCredentialProvider):
296312
self._event = event
297313
event.credential_provider.on_next(self._re_auth)
314+
event.credential_provider.on_error(self._raise_on_error)
298315

299316
async def _re_auth(self, token: TokenInterface):
300317
for key in self._event.nodes:
301318
await self._event.nodes[key].re_auth_callback(token)
302319

320+
async def _raise_on_error(self, error: Exception):
321+
raise error
322+
303323

304324
class RegisterReAuthForPubSub(EventListenerInterface):
305325
def __init__(self):
@@ -320,8 +340,10 @@ def listen(self, event: AfterPubSubConnectionInstantiationEvent):
320340

321341
if self._client_type == ClientType.SYNC:
322342
self._connection.credential_provider.on_next(self._re_auth)
343+
self._connection.credential_provider.on_error(self._raise_on_error)
323344
else:
324345
self._connection.credential_provider.on_next(self._re_auth_async)
346+
self._connection.credential_provider.on_error(self._raise_on_error_async)
325347

326348
def _re_auth(self, token: TokenInterface):
327349
with self._connection_lock:
@@ -336,3 +358,9 @@ async def _re_auth_async(self, token: TokenInterface):
336358
await self._connection.read_response()
337359

338360
await self._connection_pool.re_auth_callback(token)
361+
362+
def _raise_on_error(self, error: Exception):
363+
raise error
364+
365+
async def _raise_on_error_async(self, error: Exception):
366+
raise error

tests/test_asyncio/test_credentials.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from redis import AuthenticationError, DataError, ResponseError, RedisError
1515
from redis.asyncio import Redis, Connection, ConnectionPool
1616
from redis.asyncio.retry import Retry
17+
from redis.auth.err import RequestTokenErr
1718
from redis.exceptions import ConnectionError
1819
from redis.backoff import NoBackoff
1920
from redis.credentials import CredentialProvider, UsernamePasswordCredentialProvider
@@ -534,6 +535,42 @@ async def re_auth_callback(token):
534535
call('AUTH', auth_token.try_get('oid'), auth_token.get_value())
535536
])
536537

538+
@pytest.mark.parametrize(
539+
"credential_provider",
540+
[
541+
{
542+
"cred_provider_class": EntraIdCredentialsProvider,
543+
"cred_provider_kwargs": {"expiration_refresh_ratio": 0.00005},
544+
"mock_idp": True,
545+
}
546+
],
547+
indirect=True,
548+
)
549+
async def test_fails_on_token_renewal(self, credential_provider):
550+
credential_provider._token_mgr._idp.request_token.side_effect = [
551+
RequestTokenErr,
552+
RequestTokenErr,
553+
RequestTokenErr,
554+
RequestTokenErr
555+
]
556+
mock_connection = Mock(spec=Connection)
557+
mock_connection.retry = Retry(NoBackoff(), 0)
558+
mock_another_connection = Mock(spec=Connection)
559+
mock_pool = Mock(spec=ConnectionPool)
560+
mock_pool.connection_kwargs = {
561+
"credential_provider": credential_provider,
562+
}
563+
mock_pool.get_connection.return_value = mock_connection
564+
mock_pool._available_connections = [mock_connection, mock_another_connection]
565+
566+
await Redis(
567+
connection_pool=mock_pool,
568+
credential_provider=credential_provider,
569+
)
570+
571+
with pytest.raises(RequestTokenErr):
572+
await credential_provider.get_credentials()
573+
537574

538575
@pytest.mark.asyncio
539576
@pytest.mark.onlynoncluster

tests/test_credentials.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from redis import AuthenticationError, DataError, ResponseError, Redis, asyncio
1818
from redis.asyncio import Redis as AsyncRedis, Connection
1919
from redis.asyncio import ConnectionPool as AsyncConnectionPool
20+
from redis.auth.err import RequestTokenErr
2021
from redis.auth.idp import IdentityProviderInterface
2122
from redis.exceptions import ConnectionError, RedisError
2223
from redis.backoff import NoBackoff
@@ -512,6 +513,43 @@ def re_auth_callback(token):
512513
call('AUTH', auth_token.try_get('oid'), auth_token.get_value())
513514
])
514515

516+
@pytest.mark.parametrize(
517+
"credential_provider",
518+
[
519+
{
520+
"cred_provider_class": EntraIdCredentialsProvider,
521+
"cred_provider_kwargs": {"expiration_refresh_ratio": 0.00005},
522+
"mock_idp": True,
523+
}
524+
],
525+
indirect=True,
526+
)
527+
def test_fails_on_token_renewal(self, credential_provider):
528+
credential_provider._token_mgr._idp.request_token.side_effect = [
529+
RequestTokenErr,
530+
RequestTokenErr,
531+
RequestTokenErr,
532+
RequestTokenErr
533+
]
534+
mock_connection = Mock(spec=ConnectionInterface)
535+
mock_connection.retry = Retry(NoBackoff(), 0)
536+
mock_another_connection = Mock(spec=ConnectionInterface)
537+
mock_pool = Mock(spec=ConnectionPool)
538+
mock_pool.connection_kwargs = {
539+
"credential_provider": credential_provider,
540+
}
541+
mock_pool.get_connection.return_value = mock_connection
542+
mock_pool._available_connections = [mock_connection, mock_another_connection]
543+
mock_pool._lock = threading.Lock()
544+
545+
Redis(
546+
connection_pool=mock_pool,
547+
credential_provider=credential_provider,
548+
)
549+
550+
with pytest.raises(RequestTokenErr):
551+
credential_provider.get_credentials()
552+
515553

516554
@pytest.mark.onlynoncluster
517555
@pytest.mark.cp_integration

0 commit comments

Comments
 (0)