@@ -833,6 +833,7 @@ class AbstractRedis:
833
833
"QUIT" : bool_ok ,
834
834
"STRALGO" : parse_stralgo ,
835
835
"PUBSUB NUMSUB" : parse_pubsub_numsub ,
836
+ "PUBSUB SHARDNUMSUB" : parse_pubsub_numsub ,
836
837
"RANDOMKEY" : lambda r : r and r or None ,
837
838
"RESET" : str_if_bytes ,
838
839
"SCAN" : parse_scan ,
@@ -1440,8 +1441,8 @@ class PubSub:
1440
1441
will be returned and it's safe to start listening again.
1441
1442
"""
1442
1443
1443
- PUBLISH_MESSAGE_TYPES = ("message" , "pmessage" )
1444
- UNSUBSCRIBE_MESSAGE_TYPES = ("unsubscribe" , "punsubscribe" )
1444
+ PUBLISH_MESSAGE_TYPES = ("message" , "pmessage" , "smessage" )
1445
+ UNSUBSCRIBE_MESSAGE_TYPES = ("unsubscribe" , "punsubscribe" , "sunsubscribe" )
1445
1446
HEALTH_CHECK_MESSAGE = "redis-py-health-check"
1446
1447
1447
1448
def __init__ (
@@ -1493,9 +1494,11 @@ def reset(self):
1493
1494
self .connection .clear_connect_callbacks ()
1494
1495
self .connection_pool .release (self .connection )
1495
1496
self .connection = None
1496
- self .channels = {}
1497
1497
self .health_check_response_counter = 0
1498
+ self .channels = {}
1498
1499
self .pending_unsubscribe_channels = set ()
1500
+ self .shard_channels = {}
1501
+ self .pending_unsubscribe_shard_channels = set ()
1499
1502
self .patterns = {}
1500
1503
self .pending_unsubscribe_patterns = set ()
1501
1504
self .subscribed_event .clear ()
@@ -1510,16 +1513,23 @@ def on_connect(self, connection):
1510
1513
# before passing them to [p]subscribe.
1511
1514
self .pending_unsubscribe_channels .clear ()
1512
1515
self .pending_unsubscribe_patterns .clear ()
1516
+ self .pending_unsubscribe_shard_channels .clear ()
1513
1517
if self .channels :
1514
- channels = {}
1515
- for k , v in self .channels .items ():
1516
- channels [ self . encoder . decode ( k , force = True )] = v
1518
+ channels = {
1519
+ self . encoder . decode ( k , force = True ): v for k , v in self .channels .items ()
1520
+ }
1517
1521
self .subscribe (** channels )
1518
1522
if self .patterns :
1519
- patterns = {}
1520
- for k , v in self .patterns .items ():
1521
- patterns [ self . encoder . decode ( k , force = True )] = v
1523
+ patterns = {
1524
+ self . encoder . decode ( k , force = True ): v for k , v in self .patterns .items ()
1525
+ }
1522
1526
self .psubscribe (** patterns )
1527
+ if self .shard_channels :
1528
+ shard_channels = {
1529
+ self .encoder .decode (k , force = True ): v
1530
+ for k , v in self .shard_channels .items ()
1531
+ }
1532
+ self .ssubscribe (** shard_channels )
1523
1533
1524
1534
@property
1525
1535
def subscribed (self ):
@@ -1728,6 +1738,45 @@ def unsubscribe(self, *args):
1728
1738
self .pending_unsubscribe_channels .update (channels )
1729
1739
return self .execute_command ("UNSUBSCRIBE" , * args )
1730
1740
1741
+ def ssubscribe (self , * args , target_node = None , ** kwargs ):
1742
+ """
1743
+ Subscribes the client to the specified shard channels.
1744
+ Channels supplied as keyword arguments expect a channel name as the key
1745
+ and a callable as the value. A channel's callable will be invoked automatically
1746
+ when a message is received on that channel rather than producing a message via
1747
+ ``listen()`` or ``get_sharded_message()``.
1748
+ """
1749
+ if args :
1750
+ args = list_or_args (args [0 ], args [1 :])
1751
+ new_s_channels = dict .fromkeys (args )
1752
+ new_s_channels .update (kwargs )
1753
+ ret_val = self .execute_command ("SSUBSCRIBE" , * new_s_channels .keys ())
1754
+ # update the s_channels dict AFTER we send the command. we don't want to
1755
+ # subscribe twice to these channels, once for the command and again
1756
+ # for the reconnection.
1757
+ new_s_channels = self ._normalize_keys (new_s_channels )
1758
+ self .shard_channels .update (new_s_channels )
1759
+ if not self .subscribed :
1760
+ # Set the subscribed_event flag to True
1761
+ self .subscribed_event .set ()
1762
+ # Clear the health check counter
1763
+ self .health_check_response_counter = 0
1764
+ self .pending_unsubscribe_shard_channels .difference_update (new_s_channels )
1765
+ return ret_val
1766
+
1767
+ def sunsubscribe (self , * args , target_node = None ):
1768
+ """
1769
+ Unsubscribe from the supplied shard_channels. If empty, unsubscribe from
1770
+ all shard_channels
1771
+ """
1772
+ if args :
1773
+ args = list_or_args (args [0 ], args [1 :])
1774
+ s_channels = self ._normalize_keys (dict .fromkeys (args ))
1775
+ else :
1776
+ s_channels = self .shard_channels
1777
+ self .pending_unsubscribe_shard_channels .update (s_channels )
1778
+ return self .execute_command ("SUNSUBSCRIBE" , * args )
1779
+
1731
1780
def listen (self ):
1732
1781
"Listen for messages on channels this client has been subscribed to"
1733
1782
while self .subscribed :
@@ -1762,6 +1811,8 @@ def get_message(self, ignore_subscribe_messages=False, timeout=0.0):
1762
1811
return self .handle_message (response , ignore_subscribe_messages )
1763
1812
return None
1764
1813
1814
+ get_sharded_message = get_message
1815
+
1765
1816
def ping (self , message = None ):
1766
1817
"""
1767
1818
Ping the Redis server
@@ -1809,12 +1860,17 @@ def handle_message(self, response, ignore_subscribe_messages=False):
1809
1860
if pattern in self .pending_unsubscribe_patterns :
1810
1861
self .pending_unsubscribe_patterns .remove (pattern )
1811
1862
self .patterns .pop (pattern , None )
1863
+ elif message_type == "sunsubscribe" :
1864
+ s_channel = response [1 ]
1865
+ if s_channel in self .pending_unsubscribe_shard_channels :
1866
+ self .pending_unsubscribe_shard_channels .remove (s_channel )
1867
+ self .shard_channels .pop (s_channel , None )
1812
1868
else :
1813
1869
channel = response [1 ]
1814
1870
if channel in self .pending_unsubscribe_channels :
1815
1871
self .pending_unsubscribe_channels .remove (channel )
1816
1872
self .channels .pop (channel , None )
1817
- if not self .channels and not self .patterns :
1873
+ if not self .channels and not self .patterns and not self . shard_channels :
1818
1874
# There are no subscriptions anymore, set subscribed_event flag
1819
1875
# to false
1820
1876
self .subscribed_event .clear ()
@@ -1823,6 +1879,8 @@ def handle_message(self, response, ignore_subscribe_messages=False):
1823
1879
# if there's a message handler, invoke it
1824
1880
if message_type == "pmessage" :
1825
1881
handler = self .patterns .get (message ["pattern" ], None )
1882
+ elif message_type == "smessage" :
1883
+ handler = self .shard_channels .get (message ["channel" ], None )
1826
1884
else :
1827
1885
handler = self .channels .get (message ["channel" ], None )
1828
1886
if handler :
@@ -1843,6 +1901,11 @@ def run_in_thread(self, sleep_time=0, daemon=False, exception_handler=None):
1843
1901
for pattern , handler in self .patterns .items ():
1844
1902
if handler is None :
1845
1903
raise PubSubError (f"Pattern: '{ pattern } ' has no handler registered" )
1904
+ for s_channel , handler in self .shard_channels .items ():
1905
+ if handler is None :
1906
+ raise PubSubError (
1907
+ f"Shard Channel: '{ s_channel } ' has no handler registered"
1908
+ )
1846
1909
1847
1910
thread = PubSubWorkerThread (
1848
1911
self , sleep_time , daemon = daemon , exception_handler = exception_handler
0 commit comments