Skip to content

Commit 963843b

Browse files
Add unittest for PubSub.connect() (#2167)
* Add unittest for PubSub reconnect * fix linting
1 parent 5c99e27 commit 963843b

File tree

1 file changed

+82
-0
lines changed

1 file changed

+82
-0
lines changed

tests/test_asyncio/test_pubsub.py

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import asyncio
2+
import functools
23
import sys
34
from typing import Optional
45

@@ -20,6 +21,18 @@
2021
pytestmark = pytest.mark.asyncio(forbid_global_loop=True)
2122

2223

24+
def with_timeout(t):
25+
def wrapper(corofunc):
26+
@functools.wraps(corofunc)
27+
async def run(*args, **kwargs):
28+
async with async_timeout.timeout(t):
29+
return await corofunc(*args, **kwargs)
30+
31+
return run
32+
33+
return wrapper
34+
35+
2336
async def wait_for_message(pubsub, timeout=0.1, ignore_subscribe_messages=False):
2437
now = asyncio.get_event_loop().time()
2538
timeout = now + timeout
@@ -603,6 +616,75 @@ async def test_get_message_with_timeout_returns_none(self, r: redis.Redis):
603616
assert await p.get_message(timeout=0.01) is None
604617

605618

619+
@pytest.mark.onlynoncluster
620+
class TestPubSubReconnect:
621+
# @pytest.mark.xfail
622+
@with_timeout(2)
623+
async def test_reconnect_listen(self, r: redis.Redis):
624+
"""
625+
Test that a loop processing PubSub messages can survive
626+
a disconnect, by issuing a connect() call.
627+
"""
628+
messages = asyncio.Queue()
629+
pubsub = r.pubsub()
630+
interrupt = False
631+
632+
async def loop():
633+
# must make sure the task exits
634+
async with async_timeout.timeout(2):
635+
nonlocal interrupt
636+
await pubsub.subscribe("foo")
637+
while True:
638+
# print("loop")
639+
try:
640+
try:
641+
await pubsub.connect()
642+
await loop_step()
643+
# print("succ")
644+
except redis.ConnectionError:
645+
await asyncio.sleep(0.1)
646+
except asyncio.CancelledError:
647+
# we use a cancel to interrupt the "listen"
648+
# when we perform a disconnect
649+
# print("cancel", interrupt)
650+
if interrupt:
651+
interrupt = False
652+
else:
653+
raise
654+
655+
async def loop_step():
656+
# get a single message via listen()
657+
async for message in pubsub.listen():
658+
await messages.put(message)
659+
break
660+
661+
task = asyncio.get_event_loop().create_task(loop())
662+
# get the initial connect message
663+
async with async_timeout.timeout(1):
664+
message = await messages.get()
665+
assert message == {
666+
"channel": b"foo",
667+
"data": 1,
668+
"pattern": None,
669+
"type": "subscribe",
670+
}
671+
# now, disconnect the connection.
672+
await pubsub.connection.disconnect()
673+
interrupt = True
674+
task.cancel() # interrupt the listen call
675+
# await another auto-connect message
676+
message = await messages.get()
677+
assert message == {
678+
"channel": b"foo",
679+
"data": 1,
680+
"pattern": None,
681+
"type": "subscribe",
682+
}
683+
task.cancel()
684+
with pytest.raises(asyncio.CancelledError):
685+
await task
686+
687+
606688
@pytest.mark.onlynoncluster
607689
class TestPubSubRun:
608690
async def _subscribe(self, p, *args, **kwargs):

0 commit comments

Comments
 (0)