Skip to content

Commit fdb9075

Browse files
Async Connection: Allow PubSub.run() without previous subscribe() (#2148)
1 parent 696d984 commit fdb9075

File tree

2 files changed

+45
-3
lines changed

2 files changed

+45
-3
lines changed

redis/asyncio/client.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -693,16 +693,24 @@ async def execute_command(self, *args: EncodableT):
693693
# legitimate message off the stack if the connection is already
694694
# subscribed to one or more channels
695695

696+
await self.connect()
697+
connection = self.connection
698+
kwargs = {"check_health": not self.subscribed}
699+
await self._execute(connection, connection.send_command, *args, **kwargs)
700+
701+
async def connect(self):
702+
"""
703+
Ensure that the PubSub is connected
704+
"""
696705
if self.connection is None:
697706
self.connection = await self.connection_pool.get_connection(
698707
"pubsub", self.shard_hint
699708
)
700709
# register a callback that re-subscribes to any channels we
701710
# were listening to when we were disconnected
702711
self.connection.register_connect_callback(self.on_connect)
703-
connection = self.connection
704-
kwargs = {"check_health": not self.subscribed}
705-
await self._execute(connection, connection.send_command, *args, **kwargs)
712+
else:
713+
await self.connection.connect()
706714

707715
async def _disconnect_raise_connect(self, conn, error):
708716
"""
@@ -962,6 +970,7 @@ async def run(
962970
if handler is None:
963971
raise PubSubError(f"Pattern: '{pattern}' has no handler registered")
964972

973+
await self.connect()
965974
while True:
966975
try:
967976
await self.get_message(

tests/test_asyncio/test_pubsub.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import sys
33
from typing import Optional
44

5+
import async_timeout
56
import pytest
67

78
if sys.version_info[0:2] == (3, 6):
@@ -658,3 +659,35 @@ def exception_handler_callback(e, pubsub) -> None:
658659
except asyncio.CancelledError:
659660
pass
660661
assert str(e) == "error"
662+
663+
async def test_late_subscribe(self, r: redis.Redis):
664+
def callback(message):
665+
messages.put_nowait(message)
666+
667+
messages = asyncio.Queue()
668+
p = r.pubsub()
669+
task = asyncio.get_event_loop().create_task(p.run())
670+
# wait until loop gets settled. Add a subscription
671+
await asyncio.sleep(0.1)
672+
await p.subscribe(foo=callback)
673+
# wait tof the subscribe to finish. Cannot use _subscribe() because
674+
# p.run() is already accepting messages
675+
await asyncio.sleep(0.1)
676+
await r.publish("foo", "bar")
677+
message = None
678+
try:
679+
async with async_timeout.timeout(0.1):
680+
message = await messages.get()
681+
except asyncio.TimeoutError:
682+
pass
683+
task.cancel()
684+
# we expect a cancelled error, not the Runtime error
685+
# ("did you forget to call subscribe()"")
686+
with pytest.raises(asyncio.CancelledError):
687+
await task
688+
assert message == {
689+
"channel": b"foo",
690+
"data": b"bar",
691+
"pattern": None,
692+
"type": "message",
693+
}

0 commit comments

Comments
 (0)