6363 RedisError ,
6464 TimeoutError ,
6565)
66+ from redis .utils import safe_str
6667
6768# Type alias for channel arguments - can be a string, bytes, or Channel object
6869# This is defined here and the actual types are added after class definitions
@@ -296,7 +297,7 @@ class KeyNotification:
296297 @classmethod
297298 def from_message (
298299 cls ,
299- message : Dict [str , Any ],
300+ message : Optional [ Dict [str , Any ] ],
300301 key_prefix : Optional [Union [str , bytes ]] = None ,
301302 ) -> Optional ["KeyNotification" ]:
302303 """
@@ -340,10 +341,8 @@ def from_message(
340341 return None
341342
342343 # Convert bytes to string if needed
343- if isinstance (channel , bytes ):
344- channel = channel .decode ("utf-8" , errors = "replace" )
345- if isinstance (data , bytes ):
346- data = data .decode ("utf-8" , errors = "replace" )
344+ channel = safe_str (channel )
345+ data = safe_str (data )
347346
348347 return cls ._parse (channel , data , key_prefix )
349348
@@ -368,10 +367,8 @@ def try_parse(
368367 Returns:
369368 A KeyNotification if valid, None otherwise.
370369 """
371- if isinstance (channel , bytes ):
372- channel = channel .decode ("utf-8" , errors = "replace" )
373- if isinstance (data , bytes ):
374- data = data .decode ("utf-8" , errors = "replace" )
370+ channel = safe_str (channel )
371+ data = safe_str (data )
375372
376373 return cls ._parse (channel , data , key_prefix )
377374
@@ -384,8 +381,7 @@ def _parse(
384381 ) -> Optional ["KeyNotification" ]:
385382 """Internal parsing logic."""
386383 # Normalize key_prefix
387- if isinstance (key_prefix , bytes ):
388- key_prefix = key_prefix .decode ("utf-8" , errors = "replace" )
384+ key_prefix = safe_str (key_prefix )
389385
390386 # Try keyspace pattern first: __keyspace@<db>__:<key>
391387 match = cls ._KEYSPACE_PATTERN .match (channel )
@@ -433,8 +429,7 @@ def _parse(
433429
434430 def key_starts_with (self , prefix : Union [str , bytes ]) -> bool :
435431 """Check if the key starts with the given prefix."""
436- if isinstance (prefix , bytes ):
437- prefix = prefix .decode ("utf-8" , errors = "replace" )
432+ prefix = safe_str (prefix )
438433 return self .key .startswith (prefix )
439434
440435
@@ -454,16 +449,17 @@ class KeyspaceChannel:
454449 __str__ to return the channel string.
455450
456451 Attributes:
457- key_or_pattern: The key or pattern to monitor
458- db: The database number (None means all databases )
452+ key_or_pattern: The key or pattern to monitor (use '*' for wildcards)
453+ db: The database number (defaults to 0, the only database in Redis Cluster )
459454 is_pattern: Whether this channel contains wildcards
460455
461456 Examples:
462457 >>> channel = KeyspaceChannel("user:123", db=0)
463458 >>> str(channel)
464459 '__keyspace@0__:user:123'
465460
466- >>> channel = KeyspaceChannel.pattern("user:", db=0)
461+ >>> # Pattern subscription (wildcards are auto-detected)
462+ >>> channel = KeyspaceChannel("user:*", db=0)
467463 >>> str(channel)
468464 '__keyspace@0__:user:*'
469465
@@ -599,15 +595,13 @@ def __hash__(self) -> int:
599595
600596def is_keyspace_channel (channel : Union [str , bytes ]) -> bool :
601597 """Check if a channel is a keyspace notification channel."""
602- if isinstance (channel , bytes ):
603- channel = channel .decode ("utf-8" , errors = "replace" )
598+ channel = safe_str (channel )
604599 return channel .startswith (KeyspaceChannel .PREFIX )
605600
606601
607602def is_keyevent_channel (channel : Union [str , bytes ]) -> bool :
608603 """Check if a channel is a keyevent notification channel."""
609- if isinstance (channel , bytes ):
610- channel = channel .decode ("utf-8" , errors = "replace" )
604+ channel = safe_str (channel )
611605 return channel .startswith (KeyeventChannel .PREFIX )
612606
613607
@@ -616,7 +610,9 @@ def is_keyspace_notification_channel(channel: Union[str, bytes]) -> bool:
616610 return is_keyspace_channel (channel ) or is_keyevent_channel (channel )
617611
618612
619- def _is_pattern (channel : Union [str , bytes , "KeyspaceChannel" , "KeyeventChannel" ]) -> bool :
613+ def _is_pattern (
614+ channel : Union [str , bytes , "KeyspaceChannel" , "KeyeventChannel" ],
615+ ) -> bool :
620616 """
621617 Check if a channel string contains glob-style pattern characters.
622618
@@ -636,8 +632,7 @@ def _is_pattern(channel: Union[str, bytes, "KeyspaceChannel", "KeyeventChannel"]
636632 # (KeyspaceChannel, KeyeventChannel)
637633 if hasattr (channel , "_channel_str" ):
638634 channel = channel ._channel_str
639- if isinstance (channel , bytes ):
640- channel = channel .decode ("utf-8" , errors = "replace" )
635+ channel = safe_str (channel )
641636 # Check for unescaped glob pattern characters
642637 # We look for *, ?, or [ that are not escaped with backslash
643638 i = 0
@@ -714,9 +709,19 @@ def __init__(
714709 self ._topology_check_interval : float = 1.0
715710 self ._last_topology_check : float = 0.0
716711
717- # Enable keyspace notifications on all nodes if requested
712+ # Store the notify-keyspace-events configuration for applying to new nodes
713+ self ._notify_keyspace_events : Optional [str ] = None
718714 if notify_keyspace_events is not None :
719- self ._configure_keyspace_notifications (notify_keyspace_events )
715+ # Normalize to string value for storage
716+ self ._notify_keyspace_events = (
717+ notify_keyspace_events .value
718+ if isinstance (notify_keyspace_events , NotifyKeyspaceEvents )
719+ else notify_keyspace_events
720+ )
721+
722+ # Enable keyspace notifications on all nodes if requested
723+ if self ._notify_keyspace_events is not None :
724+ self ._configure_keyspace_notifications (self ._notify_keyspace_events )
720725
721726 def _configure_keyspace_notifications (
722727 self , events : Union [str , NotifyKeyspaceEvents ]
@@ -729,17 +734,17 @@ def _configure_keyspace_notifications(
729734 (e.g., NotifyKeyspaceEvents.ALL or "KEA")
730735 """
731736 # Get the string value (handles both enum and plain strings)
732- events_str = events .value if isinstance (events , NotifyKeyspaceEvents ) else events
737+ events_str = (
738+ events .value if isinstance (events , NotifyKeyspaceEvents ) else events
739+ )
733740 for node in self ._get_all_primary_nodes ():
734741 redis_conn = self .cluster .get_redis_connection (node )
735742 redis_conn .config_set ("notify-keyspace-events" , events_str )
736743
737744 def _get_all_primary_nodes (self ):
738745 """Get all primary nodes in the cluster."""
739746 return [
740- node
741- for node in self .cluster .get_nodes ()
742- if node .server_type == "primary"
747+ node for node in self .cluster .get_nodes () if node .server_type == "primary"
743748 ]
744749
745750 def _ensure_node_pubsub (self , node ) -> Any :
@@ -752,9 +757,7 @@ def _ensure_node_pubsub(self, node) -> Any:
752757 self ._node_pubsubs [node .name ] = pubsub
753758 return self ._node_pubsubs [node .name ]
754759
755- def _subscribe_to_all_nodes (
756- self , channels : Dict [str , Any ], use_psubscribe : bool
757- ):
760+ def _subscribe_to_all_nodes (self , channels : Dict [str , Any ], use_psubscribe : bool ):
758761 """Subscribe to patterns/channels on all primary nodes."""
759762 primaries = self ._get_all_primary_nodes ()
760763
@@ -771,9 +774,7 @@ def _subscribe_to_all_nodes(
771774 else :
772775 pubsub .subscribe (** channels )
773776
774- def _unsubscribe_from_all_nodes (
775- self , channels : List [str ], use_punsubscribe : bool
776- ):
777+ def _unsubscribe_from_all_nodes (self , channels : List [str ], use_punsubscribe : bool ):
777778 """Unsubscribe from patterns/channels on all nodes."""
778779 for pubsub in self ._node_pubsubs .values ():
779780 if use_punsubscribe :
@@ -815,6 +816,7 @@ def subscribe(
815816 # Wrap the handler to convert raw messages to KeyNotification objects
816817 wrapped_handler = None
817818 if handler is not None :
819+
818820 def wrapped_handler (message ):
819821 notification = KeyNotification .from_message (message )
820822 if notification is not None :
@@ -824,8 +826,12 @@ def wrapped_handler(message):
824826 exact_channels = {}
825827
826828 for channel in channels :
827- # Convert Channel objects to strings for use as dict keys
828- channel_str = str (channel ) if hasattr (channel , "_channel_str" ) else channel
829+ # Convert Channel objects and bytes to strings for use as dict keys
830+ # (keyword arguments in **patterns/**exact_channels must be str)
831+ if hasattr (channel , "_channel_str" ):
832+ channel_str = str (channel )
833+ else :
834+ channel_str = safe_str (channel )
829835 if _is_pattern (channel ):
830836 patterns [channel_str ] = wrapped_handler
831837 else :
@@ -854,8 +860,12 @@ def unsubscribe(self, *channels: ChannelT):
854860 exact_channels = []
855861
856862 for channel in channels :
857- # Convert Channel objects to strings
858- channel_str = str (channel ) if hasattr (channel , "_channel_str" ) else channel
863+ # Convert Channel objects and bytes to strings
864+ # (must match the keys used in subscribe())
865+ if hasattr (channel , "_channel_str" ):
866+ channel_str = str (channel )
867+ else :
868+ channel_str = safe_str (channel )
859869 if _is_pattern (channel ):
860870 self ._subscribed_patterns .pop (channel_str , None )
861871 patterns .append (channel_str )
@@ -934,6 +944,51 @@ def _create_pubsub_iterator(self):
934944 return
935945 yield from pubsubs
936946
947+ def _poll_all_nodes_once (
948+ self , ignore_subscribe_messages : bool
949+ ) -> Optional [KeyNotification ]:
950+ """
951+ Perform a single non-blocking poll over all node pubsubs.
952+
953+ This is used when timeout=0 to match the expected semantics of
954+ PubSub.get_message(timeout=0) - a non-blocking check for messages.
955+
956+ Returns:
957+ A KeyNotification if one is available, None otherwise.
958+ """
959+ # Check for topology changes before polling
960+ self ._check_topology_changed ()
961+
962+ for pubsub in list (self ._node_pubsubs .values ()):
963+ try :
964+ message = pubsub .get_message (
965+ ignore_subscribe_messages = ignore_subscribe_messages ,
966+ timeout = 0.0 ,
967+ )
968+ except (MovedError , AskError ):
969+ self ._refresh_subscriptions_on_error ()
970+ continue
971+ except (ConnectionError , TimeoutError , RedisError ):
972+ self ._refresh_subscriptions_on_error ()
973+ continue
974+
975+ if message is not None :
976+ notification = KeyNotification .from_message (
977+ message , key_prefix = self .key_prefix
978+ )
979+ if notification is not None :
980+ # Call handler if registered
981+ handler = self ._subscribed_patterns .get (
982+ message .get ("pattern" )
983+ ) or self ._subscribed_channels .get (message .get ("channel" ))
984+ if handler :
985+ handler (notification )
986+ # Continue polling remaining nodes
987+ continue
988+ return notification
989+
990+ return None
991+
937992 def get_message (
938993 self ,
939994 ignore_subscribe_messages : bool = False ,
@@ -964,6 +1019,11 @@ def get_message(
9641019 if total_nodes == 0 :
9651020 return None
9661021
1022+ # Handle timeout=0 as a single non-blocking poll over all pubsubs
1023+ # This matches the expected semantics of PubSub.get_message(timeout=0)
1024+ if timeout == 0.0 :
1025+ return self ._poll_all_nodes_once (ignore_subscribe_messages )
1026+
9671027 # Calculate per-node timeout for each poll
9681028 # Use a small timeout per node to allow round-robin polling
9691029 per_node_timeout = min (0.1 , timeout / max (total_nodes , 1 ))
@@ -1116,6 +1176,18 @@ def refresh_subscriptions(self):
11161176 new_nodes = set (current_primaries .keys ()) - set (self ._node_pubsubs .keys ())
11171177 for node_name in new_nodes :
11181178 node = current_primaries [node_name ]
1179+
1180+ # Configure notify-keyspace-events on new node before subscribing
1181+ if self ._notify_keyspace_events is not None :
1182+ try :
1183+ redis_conn = self .cluster .get_redis_connection (node )
1184+ redis_conn .config_set (
1185+ "notify-keyspace-events" , self ._notify_keyspace_events
1186+ )
1187+ except Exception :
1188+ # Log or ignore - node may not be ready yet
1189+ pass
1190+
11191191 pubsub = self ._ensure_node_pubsub (node )
11201192
11211193 if self ._subscribed_patterns :
0 commit comments