@@ -810,7 +810,7 @@ class PubSub:
810
810
"""
811
811
812
812
PUBLISH_MESSAGE_TYPES = ("message" , "pmessage" )
813
- UNSUBSCRIBE_MESSAGE_TYPES = ("unsubscribe" , "punsubscribe" )
813
+ UNSUBSCRIBE_MESSAGE_TYPES = ("unsubscribe" , "punsubscribe" , "sunsubscribe" )
814
814
HEALTH_CHECK_MESSAGE = "redis-py-health-check"
815
815
816
816
def __init__ (
@@ -852,6 +852,8 @@ def __init__(
852
852
self .pending_unsubscribe_channels = set ()
853
853
self .patterns = {}
854
854
self .pending_unsubscribe_patterns = set ()
855
+ self .shard_channels = {}
856
+ self .pending_unsubscribe_shard_channels = set ()
855
857
self ._lock = asyncio .Lock ()
856
858
857
859
async def __aenter__ (self ):
@@ -880,6 +882,8 @@ async def aclose(self):
880
882
self .pending_unsubscribe_channels = set ()
881
883
self .patterns = {}
882
884
self .pending_unsubscribe_patterns = set ()
885
+ self .shard_channels = {}
886
+ self .pending_unsubscribe_shard_channels = set ()
883
887
884
888
@deprecated_function (version = "5.0.1" , reason = "Use aclose() instead" , name = "close" )
885
889
async def close (self ) -> None :
@@ -898,6 +902,7 @@ async def on_connect(self, connection: Connection):
898
902
# before passing them to [p]subscribe.
899
903
self .pending_unsubscribe_channels .clear ()
900
904
self .pending_unsubscribe_patterns .clear ()
905
+ self .pending_unsubscribe_shard_channels .clear ()
901
906
if self .channels :
902
907
channels = {}
903
908
for k , v in self .channels .items ():
@@ -908,11 +913,17 @@ async def on_connect(self, connection: Connection):
908
913
for k , v in self .patterns .items ():
909
914
patterns [self .encoder .decode (k , force = True )] = v
910
915
await self .psubscribe (** patterns )
916
+ if self .shard_channels :
917
+ shard_channels = {
918
+ self .encoder .decode (k , force = True ): v
919
+ for k , v in self .shard_channels .items ()
920
+ }
921
+ await self .ssubscribe (** shard_channels )
911
922
912
923
@property
913
924
def subscribed (self ):
914
925
"""Indicates if there are subscriptions to any channels or patterns"""
915
- return bool (self .channels or self .patterns )
926
+ return bool (self .channels or self .patterns or self . shard_channels )
916
927
917
928
async def execute_command (self , * args : EncodableT ):
918
929
"""Execute a publish/subscribe command"""
@@ -1091,6 +1102,40 @@ def unsubscribe(self, *args) -> Awaitable:
1091
1102
self .pending_unsubscribe_channels .update (channels )
1092
1103
return self .execute_command ("UNSUBSCRIBE" , * parsed_args )
1093
1104
1105
+ def ssubscribe (self , * args , target_node = None , ** kwargs ) -> Awaitable :
1106
+ """
1107
+ Subscribes the client to the specified shard channels.
1108
+ Channels supplied as keyword arguments expect a channel name as the key
1109
+ and a callable as the value. A channel's callable will be invoked automatically
1110
+ when a message is received on that channel rather than producing a message via
1111
+ ``listen()`` or ``get_sharded_message()``.
1112
+ """
1113
+ if args :
1114
+ args = list_or_args (args [0 ], args [1 :])
1115
+ new_s_channels = dict .fromkeys (args )
1116
+ new_s_channels .update (kwargs )
1117
+ ret_val = self .execute_command ("SSUBSCRIBE" , * new_s_channels .keys ())
1118
+ # update the s_channels dict AFTER we send the command. we don't want to
1119
+ # subscribe twice to these channels, once for the command and again
1120
+ # for the reconnection.
1121
+ new_s_channels = self ._normalize_keys (new_s_channels )
1122
+ self .shard_channels .update (new_s_channels )
1123
+ self .pending_unsubscribe_shard_channels .difference_update (new_s_channels )
1124
+ return ret_val
1125
+
1126
+ def sunsubscribe (self , * args , target_node = None ) -> Awaitable :
1127
+ """
1128
+ Unsubscribe from the supplied shard_channels. If empty, unsubscribe from
1129
+ all shard_channels
1130
+ """
1131
+ if args :
1132
+ args = list_or_args (args [0 ], args [1 :])
1133
+ s_channels = self ._normalize_keys (dict .fromkeys (args ))
1134
+ else :
1135
+ s_channels = self .shard_channels
1136
+ self .pending_unsubscribe_shard_channels .update (s_channels )
1137
+ return self .execute_command ("SUNSUBSCRIBE" , * args )
1138
+
1094
1139
async def listen (self ) -> AsyncIterator :
1095
1140
"""Listen for messages on channels this client has been subscribed to"""
1096
1141
while self .subscribed :
@@ -1160,6 +1205,11 @@ async def handle_message(self, response, ignore_subscribe_messages=False):
1160
1205
if pattern in self .pending_unsubscribe_patterns :
1161
1206
self .pending_unsubscribe_patterns .remove (pattern )
1162
1207
self .patterns .pop (pattern , None )
1208
+ elif message_type == "sunsubscribe" :
1209
+ s_channel = response [1 ]
1210
+ if s_channel in self .pending_unsubscribe_shard_channels :
1211
+ self .pending_unsubscribe_shard_channels .remove (s_channel )
1212
+ self .shard_channels .pop (s_channel , None )
1163
1213
else :
1164
1214
channel = response [1 ]
1165
1215
if channel in self .pending_unsubscribe_channels :
@@ -1172,6 +1222,8 @@ async def handle_message(self, response, ignore_subscribe_messages=False):
1172
1222
handler = self .patterns .get (message ["pattern" ], None )
1173
1223
else :
1174
1224
handler = self .channels .get (message ["channel" ], None )
1225
+ if handler is None :
1226
+ handler = self .shard_channels .get (message ["channel" ], None )
1175
1227
if handler :
1176
1228
if inspect .iscoroutinefunction (handler ):
1177
1229
await handler (message )
0 commit comments