Skip to content

Commit ebb3e59

Browse files
committed
Support Cluster PubSub in asyncio
1 parent 349d761 commit ebb3e59

File tree

4 files changed

+531
-3
lines changed

4 files changed

+531
-3
lines changed

redis/asyncio/client.py

Lines changed: 54 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -810,7 +810,7 @@ class PubSub:
810810
"""
811811

812812
PUBLISH_MESSAGE_TYPES = ("message", "pmessage")
813-
UNSUBSCRIBE_MESSAGE_TYPES = ("unsubscribe", "punsubscribe")
813+
UNSUBSCRIBE_MESSAGE_TYPES = ("unsubscribe", "punsubscribe", "sunsubscribe")
814814
HEALTH_CHECK_MESSAGE = "redis-py-health-check"
815815

816816
def __init__(
@@ -852,6 +852,8 @@ def __init__(
852852
self.pending_unsubscribe_channels = set()
853853
self.patterns = {}
854854
self.pending_unsubscribe_patterns = set()
855+
self.shard_channels = {}
856+
self.pending_unsubscribe_shard_channels = set()
855857
self._lock = asyncio.Lock()
856858

857859
async def __aenter__(self):
@@ -880,6 +882,8 @@ async def aclose(self):
880882
self.pending_unsubscribe_channels = set()
881883
self.patterns = {}
882884
self.pending_unsubscribe_patterns = set()
885+
self.shard_channels = {}
886+
self.pending_unsubscribe_shard_channels = set()
883887

884888
@deprecated_function(version="5.0.1", reason="Use aclose() instead", name="close")
885889
async def close(self) -> None:
@@ -898,6 +902,7 @@ async def on_connect(self, connection: Connection):
898902
# before passing them to [p]subscribe.
899903
self.pending_unsubscribe_channels.clear()
900904
self.pending_unsubscribe_patterns.clear()
905+
self.pending_unsubscribe_shard_channels.clear()
901906
if self.channels:
902907
channels = {}
903908
for k, v in self.channels.items():
@@ -908,11 +913,17 @@ async def on_connect(self, connection: Connection):
908913
for k, v in self.patterns.items():
909914
patterns[self.encoder.decode(k, force=True)] = v
910915
await self.psubscribe(**patterns)
916+
if self.shard_channels:
917+
shard_channels = {
918+
self.encoder.decode(k, force=True): v
919+
for k, v in self.shard_channels.items()
920+
}
921+
await self.ssubscribe(**shard_channels)
911922

912923
@property
913924
def subscribed(self):
914925
"""Indicates if there are subscriptions to any channels or patterns"""
915-
return bool(self.channels or self.patterns)
926+
return bool(self.channels or self.patterns or self.shard_channels)
916927

917928
async def execute_command(self, *args: EncodableT):
918929
"""Execute a publish/subscribe command"""
@@ -1091,6 +1102,40 @@ def unsubscribe(self, *args) -> Awaitable:
10911102
self.pending_unsubscribe_channels.update(channels)
10921103
return self.execute_command("UNSUBSCRIBE", *parsed_args)
10931104

1105+
def ssubscribe(self, *args, target_node=None, **kwargs) -> Awaitable:
1106+
"""
1107+
Subscribes the client to the specified shard channels.
1108+
Channels supplied as keyword arguments expect a channel name as the key
1109+
and a callable as the value. A channel's callable will be invoked automatically
1110+
when a message is received on that channel rather than producing a message via
1111+
``listen()`` or ``get_sharded_message()``.
1112+
"""
1113+
if args:
1114+
args = list_or_args(args[0], args[1:])
1115+
new_s_channels = dict.fromkeys(args)
1116+
new_s_channels.update(kwargs)
1117+
ret_val = self.execute_command("SSUBSCRIBE", *new_s_channels.keys())
1118+
# update the s_channels dict AFTER we send the command. we don't want to
1119+
# subscribe twice to these channels, once for the command and again
1120+
# for the reconnection.
1121+
new_s_channels = self._normalize_keys(new_s_channels)
1122+
self.shard_channels.update(new_s_channels)
1123+
self.pending_unsubscribe_shard_channels.difference_update(new_s_channels)
1124+
return ret_val
1125+
1126+
def sunsubscribe(self, *args, target_node=None) -> Awaitable:
1127+
"""
1128+
Unsubscribe from the supplied shard_channels. If empty, unsubscribe from
1129+
all shard_channels
1130+
"""
1131+
if args:
1132+
args = list_or_args(args[0], args[1:])
1133+
s_channels = self._normalize_keys(dict.fromkeys(args))
1134+
else:
1135+
s_channels = self.shard_channels
1136+
self.pending_unsubscribe_shard_channels.update(s_channels)
1137+
return self.execute_command("SUNSUBSCRIBE", *args)
1138+
10941139
async def listen(self) -> AsyncIterator:
10951140
"""Listen for messages on channels this client has been subscribed to"""
10961141
while self.subscribed:
@@ -1160,6 +1205,11 @@ async def handle_message(self, response, ignore_subscribe_messages=False):
11601205
if pattern in self.pending_unsubscribe_patterns:
11611206
self.pending_unsubscribe_patterns.remove(pattern)
11621207
self.patterns.pop(pattern, None)
1208+
elif message_type == "sunsubscribe":
1209+
s_channel = response[1]
1210+
if s_channel in self.pending_unsubscribe_shard_channels:
1211+
self.pending_unsubscribe_shard_channels.remove(s_channel)
1212+
self.shard_channels.pop(s_channel, None)
11631213
else:
11641214
channel = response[1]
11651215
if channel in self.pending_unsubscribe_channels:
@@ -1172,6 +1222,8 @@ async def handle_message(self, response, ignore_subscribe_messages=False):
11721222
handler = self.patterns.get(message["pattern"], None)
11731223
else:
11741224
handler = self.channels.get(message["channel"], None)
1225+
if handler is None:
1226+
handler = self.shard_channels.get(message["channel"], None)
11751227
if handler:
11761228
if inspect.iscoroutinefunction(handler):
11771229
await handler(message)

0 commit comments

Comments
 (0)