Skip to content

Commit 0de0f4d

Browse files
committed
Added support for Pub/Sub
1 parent 90204e7 commit 0de0f4d

File tree

7 files changed

+324
-5
lines changed

7 files changed

+324
-5
lines changed

redis/asyncio/client.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@
5454
)
5555
from redis.credentials import CredentialProvider
5656
from redis.event import EventDispatcher, AfterPooledConnectionsInstantiationEvent, ClientType, \
57-
AfterSingleConnectionInstantiationEvent
57+
AfterSingleConnectionInstantiationEvent, AfterPubSubConnectionInstantiationEvent
5858
from redis.exceptions import (
5959
ConnectionError,
6060
ExecAbortError,
@@ -539,7 +539,7 @@ def pubsub(self, **kwargs) -> "PubSub":
539539
subscribe to channels and listen for messages that get published to
540540
them.
541541
"""
542-
return PubSub(self.connection_pool, **kwargs)
542+
return PubSub(self.connection_pool, event_dispatcher=self._event_dispatcher, **kwargs)
543543

544544
def monitor(self) -> "Monitor":
545545
return Monitor(self.connection_pool)
@@ -777,6 +777,7 @@ def __init__(
777777
ignore_subscribe_messages: bool = False,
778778
encoder=None,
779779
push_handler_func: Optional[Callable] = None,
780+
event_dispatcher: Optional["EventDispatcher"] = EventDispatcher(),
780781
):
781782
self.connection_pool = connection_pool
782783
self.shard_hint = shard_hint
@@ -804,6 +805,7 @@ def __init__(
804805
self.pending_unsubscribe_channels = set()
805806
self.patterns = {}
806807
self.pending_unsubscribe_patterns = set()
808+
self._event_dispatcher = event_dispatcher
807809
self._lock = asyncio.Lock()
808810

809811
async def __aenter__(self):
@@ -894,6 +896,15 @@ async def connect(self):
894896
if self.push_handler_func is not None and not HIREDIS_AVAILABLE:
895897
self.connection._parser.set_pubsub_push_handler(self.push_handler_func)
896898

899+
self._event_dispatcher.dispatch(
900+
AfterPubSubConnectionInstantiationEvent(
901+
self.connection,
902+
self.connection_pool,
903+
ClientType.ASYNC,
904+
self._lock
905+
)
906+
)
907+
897908
async def _disconnect_raise_connect(self, conn, error):
898909
"""
899910
Close the connection and raise an exception

redis/asyncio/connection.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -334,6 +334,9 @@ def _host_error(self) -> str:
334334
def _error_message(self, exception: BaseException) -> str:
335335
return format_error_message(self._host_error(), exception)
336336

337+
def get_protocol(self):
338+
return self.protocol
339+
337340
async def on_connect(self) -> None:
338341
"""Initialize the connection, authenticate and select a database"""
339342
self._parser.on_connect(self)

redis/client.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
)
3030
from redis.credentials import CredentialProvider
3131
from redis.event import EventDispatcher, AfterPooledConnectionsInstantiationEvent, ClientType, \
32-
AfterSingleConnectionInstantiationEvent
32+
AfterSingleConnectionInstantiationEvent, AfterPubSubConnectionInstantiationEvent
3333
from redis.exceptions import (
3434
ConnectionError,
3535
ExecAbortError,
@@ -332,6 +332,7 @@ def __init__(
332332
))
333333

334334
self.connection_pool = connection_pool
335+
self._event_dispatcher = event_dispatcher
335336

336337
if (cache_config or cache) and self.connection_pool.get_protocol() not in [
337338
3,
@@ -518,7 +519,7 @@ def pubsub(self, **kwargs):
518519
subscribe to channels and listen for messages that get published to
519520
them.
520521
"""
521-
return PubSub(self.connection_pool, **kwargs)
522+
return PubSub(self.connection_pool, event_dispatcher=self._event_dispatcher, **kwargs)
522523

523524
def monitor(self):
524525
return Monitor(self.connection_pool)
@@ -709,6 +710,7 @@ def __init__(
709710
ignore_subscribe_messages: bool = False,
710711
encoder: Optional["Encoder"] = None,
711712
push_handler_func: Union[None, Callable[[str], None]] = None,
713+
event_dispatcher: Optional["EventDispatcher"] = EventDispatcher(),
712714
):
713715
self.connection_pool = connection_pool
714716
self.shard_hint = shard_hint
@@ -719,6 +721,8 @@ def __init__(
719721
# to lookup channel and pattern names for callback handlers.
720722
self.encoder = encoder
721723
self.push_handler_func = push_handler_func
724+
self._event_dispatcher = event_dispatcher
725+
self._lock = threading.Lock()
722726
if self.encoder is None:
723727
self.encoder = self.connection_pool.get_encoder()
724728
self.health_check_response_b = self.encoder.encode(self.HEALTH_CHECK_MESSAGE)
@@ -809,11 +813,20 @@ def execute_command(self, *args):
809813
self.connection.register_connect_callback(self.on_connect)
810814
if self.push_handler_func is not None and not HIREDIS_AVAILABLE:
811815
self.connection._parser.set_pubsub_push_handler(self.push_handler_func)
816+
self._event_dispatcher.dispatch(
817+
AfterPubSubConnectionInstantiationEvent(
818+
self.connection,
819+
self.connection_pool,
820+
ClientType.SYNC,
821+
self._lock
822+
)
823+
)
812824
connection = self.connection
813825
kwargs = {"check_health": not self.subscribed}
814826
if not self.subscribed:
815827
self.clean_health_check_responses()
816-
self._execute(connection, connection.send_command, *args, **kwargs)
828+
with self._lock:
829+
self._execute(connection, connection.send_command, *args, **kwargs)
817830

818831
def clean_health_check_responses(self) -> None:
819832
"""

redis/connection.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,10 @@ def deregister_connect_callback(self, callback):
153153
def set_parser(self, parser_class):
154154
pass
155155

156+
@abstractmethod
157+
def get_protocol(self):
158+
pass
159+
156160
@abstractmethod
157161
def connect(self):
158162
pass

redis/event.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,9 @@ def __init__(self):
5858
AfterSingleConnectionInstantiationEvent: [
5959
RegisterReAuthForSingleConnection()
6060
],
61+
AfterPubSubConnectionInstantiationEvent: [
62+
RegisterReAuthForPubSub()
63+
],
6164
AfterAsyncClusterInstantiationEvent: [
6265
RegisterReAuthForAsyncClusterNodes()
6366
],
@@ -158,6 +161,36 @@ def connection_lock(self) -> Union[threading.Lock, asyncio.Lock]:
158161
return self._connection_lock
159162

160163

164+
class AfterPubSubConnectionInstantiationEvent:
165+
def __init__(
166+
self,
167+
pubsub_connection,
168+
connection_pool,
169+
client_type: ClientType,
170+
connection_lock: Union[threading.Lock, asyncio.Lock]
171+
):
172+
self._pubsub_connection = pubsub_connection
173+
self._connection_pool = connection_pool
174+
self._client_type = client_type
175+
self._connection_lock = connection_lock
176+
177+
@property
178+
def pubsub_connection(self):
179+
return self._pubsub_connection
180+
181+
@property
182+
def connection_pool(self):
183+
return self._connection_pool
184+
185+
@property
186+
def client_type(self) -> ClientType:
187+
return self._client_type
188+
189+
@property
190+
def connection_lock(self) -> Union[threading.Lock, asyncio.Lock]:
191+
return self._connection_lock
192+
193+
161194
class AfterAsyncClusterInstantiationEvent:
162195
"""
163196
Event that will be fired after async cluster instance was created.
@@ -266,3 +299,40 @@ def listen(self, event: AfterAsyncClusterInstantiationEvent):
266299
async def _re_auth(self, token: TokenInterface):
267300
for key in self._event.nodes:
268301
await self._event.nodes[key].re_auth_callback(token)
302+
303+
304+
class RegisterReAuthForPubSub(EventListenerInterface):
305+
def __init__(self):
306+
self._connection = None
307+
self._connection_pool = None
308+
self._client_type = None
309+
self._connection_lock = None
310+
311+
def listen(self, event: AfterPubSubConnectionInstantiationEvent):
312+
if (
313+
isinstance(event.pubsub_connection.credential_provider, StreamingCredentialProvider)
314+
and event.pubsub_connection.get_protocol() in [3, "3"]
315+
):
316+
self._connection = event.pubsub_connection
317+
self._connection_pool = event.connection_pool
318+
self._client_type = event.client_type
319+
self._connection_lock = event.connection_lock
320+
321+
if self._client_type == ClientType.SYNC:
322+
self._connection.credential_provider.on_next(self._re_auth)
323+
else:
324+
self._connection.credential_provider.on_next(self._re_auth_async)
325+
326+
def _re_auth(self, token: TokenInterface):
327+
with self._connection_lock:
328+
self._connection.send_command('AUTH', token.try_get('oid'), token.get_value())
329+
self._connection.read_response()
330+
331+
self._connection_pool.re_auth_callback(token)
332+
333+
async def _re_auth_async(self, token: TokenInterface):
334+
async with self._connection_lock:
335+
await self._connection.send_command('AUTH', token.try_get('oid'), token.get_value())
336+
await self._connection.read_response()
337+
338+
await self._connection_pool.re_auth_callback(token)

tests/test_asyncio/test_credentials.py

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -414,6 +414,115 @@ async def re_auth_callback(token):
414414
mock_another_connection.read_response.assert_has_calls([call()])
415415
mock_failed_connection.read_response.assert_has_calls([call(), call(), call()])
416416

417+
@pytest.mark.parametrize(
418+
"credential_provider",
419+
[
420+
{
421+
"cred_provider_class": EntraIdCredentialsProvider,
422+
"cred_provider_kwargs": {"expiration_refresh_ratio": 0.00005},
423+
"mock_idp": True,
424+
}
425+
],
426+
indirect=True,
427+
)
428+
async def test_re_auth_pub_sub_in_resp3(self, credential_provider):
429+
mock_pubsub_connection = Mock(spec=Connection)
430+
mock_pubsub_connection.get_protocol.return_value = 3
431+
mock_pubsub_connection.credential_provider = credential_provider
432+
mock_pubsub_connection.retry = Retry(NoBackoff(), 3)
433+
mock_another_connection = Mock(spec=Connection)
434+
mock_another_connection.retry = Retry(NoBackoff(), 3)
435+
436+
mock_pool = Mock(spec=ConnectionPool)
437+
mock_pool.connection_kwargs = {
438+
"credential_provider": credential_provider,
439+
}
440+
mock_pool.get_connection.side_effect = [mock_pubsub_connection, mock_another_connection]
441+
mock_pool._available_connections = [mock_another_connection]
442+
mock_pool._lock = AsyncLock()
443+
auth_token = None
444+
445+
async def re_auth_callback(token):
446+
nonlocal auth_token
447+
auth_token = token
448+
async with mock_pool._lock:
449+
for conn in mock_pool._available_connections:
450+
await conn.send_command('AUTH', token.try_get('oid'), token.get_value())
451+
await conn.read_response()
452+
453+
mock_pool.re_auth_callback = re_auth_callback
454+
455+
r = Redis(
456+
connection_pool=mock_pool,
457+
credential_provider=credential_provider,
458+
)
459+
p = r.pubsub()
460+
await p.subscribe('test')
461+
await credential_provider.get_credentials_async()
462+
await async_sleep(0.5)
463+
464+
mock_pubsub_connection.send_command.assert_has_calls([
465+
call('SUBSCRIBE', 'test', check_health=True),
466+
call('AUTH', auth_token.try_get('oid'), auth_token.get_value()),
467+
])
468+
mock_another_connection.send_command.assert_has_calls([
469+
call('AUTH', auth_token.try_get('oid'), auth_token.get_value())
470+
])
471+
472+
@pytest.mark.parametrize(
473+
"credential_provider",
474+
[
475+
{
476+
"cred_provider_class": EntraIdCredentialsProvider,
477+
"cred_provider_kwargs": {"expiration_refresh_ratio": 0.00005},
478+
"mock_idp": True,
479+
}
480+
],
481+
indirect=True,
482+
)
483+
async def test_do_not_re_auth_pub_sub_in_resp2(self, credential_provider):
484+
mock_pubsub_connection = Mock(spec=Connection)
485+
mock_pubsub_connection.get_protocol.return_value = 2
486+
mock_pubsub_connection.credential_provider = credential_provider
487+
mock_pubsub_connection.retry = Retry(NoBackoff(), 3)
488+
mock_another_connection = Mock(spec=Connection)
489+
mock_another_connection.retry = Retry(NoBackoff(), 3)
490+
491+
mock_pool = Mock(spec=ConnectionPool)
492+
mock_pool.connection_kwargs = {
493+
"credential_provider": credential_provider,
494+
}
495+
mock_pool.get_connection.side_effect = [mock_pubsub_connection, mock_another_connection]
496+
mock_pool._available_connections = [mock_another_connection]
497+
mock_pool._lock = AsyncLock()
498+
auth_token = None
499+
500+
async def re_auth_callback(token):
501+
nonlocal auth_token
502+
auth_token = token
503+
async with mock_pool._lock:
504+
for conn in mock_pool._available_connections:
505+
await conn.send_command('AUTH', token.try_get('oid'), token.get_value())
506+
await conn.read_response()
507+
508+
mock_pool.re_auth_callback = re_auth_callback
509+
510+
r = Redis(
511+
connection_pool=mock_pool,
512+
credential_provider=credential_provider,
513+
)
514+
p = r.pubsub()
515+
await p.subscribe('test')
516+
await credential_provider.get_credentials_async()
517+
await async_sleep(0.5)
518+
519+
mock_pubsub_connection.send_command.assert_has_calls([
520+
call('SUBSCRIBE', 'test', check_health=True),
521+
])
522+
mock_another_connection.send_command.assert_has_calls([
523+
call('AUTH', auth_token.try_get('oid'), auth_token.get_value())
524+
])
525+
417526

418527
@pytest.mark.asyncio
419528
@pytest.mark.onlynoncluster

0 commit comments

Comments
 (0)