1+ from enum import Enum
12import random
23import socket
34import sys
@@ -190,6 +191,27 @@ def cleanup_kwargs(**kwargs):
190191
191192 return connection_kwargs
192193
194+ class ReadFromReplicasMode (Enum ):
195+ ReadFromPrimary = 0
196+ ReadFromPrimaryAndReplica = 1
197+ ReadFromReplicaOnly = 2
198+
199+ @staticmethod
200+ def from_parameters (input : bool | "ReadFromReplicasMode" ):
201+ if input == True :
202+ return ReadFromReplicasMode .ReadFromPrimaryAndReplica
203+ elif input == False :
204+ return ReadFromReplicasMode .ReadFromPrimary
205+ if not input in ReadFromReplicasMode :
206+ raise RedisClusterException ("Argument 'read_from_replicas' must be a boolean or a value of ReadFromReplicasMode" )
207+ return input
208+
209+ def get_replica_mode_for_command (self , command : str ):
210+ if self == ReadFromReplicasMode .ReadFromPrimary :
211+ return ReadFromReplicasMode .ReadFromPrimary
212+ if not command in READ_COMMANDS :
213+ return ReadFromReplicasMode .ReadFromPrimary
214+ return self
193215
194216class ClusterParser (DefaultParser ):
195217 EXCEPTION_CLASSES = dict_merge (
@@ -503,7 +525,7 @@ def __init__(
503525 retry : Optional ["Retry" ] = None ,
504526 require_full_coverage : bool = False ,
505527 reinitialize_steps : int = 5 ,
506- read_from_replicas : bool = False ,
528+ read_from_replicas : bool | ReadFromReplicasMode = False ,
507529 dynamic_startup_nodes : bool = True ,
508530 url : Optional [str ] = None ,
509531 address_remap : Optional [Callable [[str , int ], Tuple [str , int ]]] = None ,
@@ -532,7 +554,9 @@ def __init__(
532554 Enable read from replicas in READONLY mode. You can read possibly
533555 stale data.
534556 When set to true, read commands will be assigned between the
535- primary and its replications in a Round-Robin manner.
557+ primary and its replications in a Round-Robin manner. When set to
558+ ReadFromReplicasMode.ReadFromReplicaOnly, it will only read from
559+ the replicas
536560 :param dynamic_startup_nodes:
537561 Set the RedisCluster's startup nodes to all of the discovered nodes.
538562 If true (default value), the cluster's discovered nodes will be used to
@@ -633,7 +657,7 @@ def __init__(
633657 self .cluster_error_retry_attempts = cluster_error_retry_attempts
634658 self .command_flags = self .__class__ .COMMAND_FLAGS .copy ()
635659 self .node_flags = self .__class__ .NODE_FLAGS .copy ()
636- self .read_from_replicas = read_from_replicas
660+ self .read_from_replicas_mode = ReadFromReplicasMode . from_parameters ( read_from_replicas )
637661 self .reinitialize_counter = 0
638662 self .reinitialize_steps = reinitialize_steps
639663 self .nodes_manager = NodesManager (
@@ -678,7 +702,7 @@ def on_connect(self, connection):
678702 connection .set_parser (ClusterParser )
679703 connection .on_connect ()
680704
681- if self .read_from_replicas :
705+ if self .read_from_replicas != ReadFromReplicasMode . ReadFromPrimary :
682706 # Sending READONLY command to server to configure connection as
683707 # readonly. Since each cluster node may change its server type due
684708 # to a failover, we should establish a READONLY connection
@@ -706,6 +730,13 @@ def get_primaries(self):
706730
707731 def get_replicas (self ):
708732 return self .nodes_manager .get_nodes_by_server_type (REPLICA )
733+
734+ def get_read_from_replica_mode_for_command (self , command : str ):
735+ if (
736+ (self .read_from_replicas_mode == ReadFromReplicasMode .ReadFromPrimary ) or
737+ (not command in READ_COMMANDS )):
738+ return ReadFromReplicasMode .ReadFromPrimary
739+ return self .read_from_replicas_mode
709740
710741 def get_random_node (self ):
711742 return random .choice (list (self .nodes_manager .nodes_cache .values ()))
@@ -804,7 +835,7 @@ def pipeline(self, transaction=None, shard_hint=None):
804835 result_callbacks = self .result_callbacks ,
805836 cluster_response_callbacks = self .cluster_response_callbacks ,
806837 cluster_error_retry_attempts = self .cluster_error_retry_attempts ,
807- read_from_replicas = self .read_from_replicas ,
838+ read_from_replicas_mode = self .read_from_replicas_mode ,
808839 reinitialize_steps = self .reinitialize_steps ,
809840 lock = self ._lock ,
810841 )
@@ -922,7 +953,7 @@ def _determine_nodes(self, *args, **kwargs) -> List["ClusterNode"]:
922953 # get the node that holds the key's slot
923954 slot = self .determine_slot (* args )
924955 node = self .nodes_manager .get_node_from_slot (
925- slot , self .read_from_replicas and command in READ_COMMANDS
956+ slot , self .read_from_replicas_mode . get_replica_mode_for_command ( command )
926957 )
927958 return [node ]
928959
@@ -1144,7 +1175,7 @@ def _execute_command(self, target_node, *args, **kwargs):
11441175 # refresh the target node
11451176 slot = self .determine_slot (* args )
11461177 target_node = self .nodes_manager .get_node_from_slot (
1147- slot , self .read_from_replicas and command in READ_COMMANDS
1178+ slot , self .read_from_replicas_mode . get_replica_mode_for_command ( command )
11481179 )
11491180 moved = False
11501181
@@ -1293,7 +1324,6 @@ def __del__(self):
12931324 if self .redis_connection is not None :
12941325 self .redis_connection .close ()
12951326
1296-
12971327class LoadBalancer :
12981328 """
12991329 Round-Robin Load Balancing
@@ -1302,11 +1332,30 @@ class LoadBalancer:
13021332 def __init__ (self , start_index : int = 0 ) -> None :
13031333 self .primary_to_idx = {}
13041334 self .start_index = start_index
1305-
1306- def get_server_index (self , primary : str , list_size : int ) -> int :
1307- server_index = self .primary_to_idx .setdefault (primary , self .start_index )
1308- # Update the index
1309- self .primary_to_idx [primary ] = (server_index + 1 ) % list_size
1335+
1336+ def get_node_from_slot (self , slot_index : int , slot_nodes : list [ClusterNode ] | None , read_from_replicas_mode : ReadFromReplicasMode ):
1337+ if slot_nodes is None or len (slot_nodes ) == 0 :
1338+ raise SlotNotCoveredError (
1339+ f'Slot "{ slot_index } " not covered by the cluster. '
1340+ )
1341+ if read_from_replicas_mode == ReadFromReplicasMode .ReadFromPrimary :
1342+ node_idx = 0
1343+ else :
1344+ skip_primary = read_from_replicas_mode == ReadFromReplicasMode .ReadFromReplicaOnly
1345+ # get the server index in a Round-Robin manner
1346+ primary_name = slot_nodes [0 ].name
1347+ node_idx = self .read_load_balancer .get_server_index (
1348+ primary_name , len (slot_nodes ), skip_primary
1349+ )
1350+ return slot_nodes [node_idx ]
1351+
1352+ def get_server_index (self , primary : str , list_size : int , skip_primary :bool ) -> int :
1353+ # default to -1 if not found, so after incrementing it will be 0
1354+ server_index = (self .primary_to_idx .get (primary , - 1 ) + 1 ) % list_size
1355+ # If we skip primary, skip the zero-index node.
1356+ if skip_primary and server_index == 0 and list_size > 1 :
1357+ server_index = server_index + 1
1358+ self .primary_to_idx [primary ] = server_index
13101359 return server_index
13111360
13121361 def reset (self ) -> None :
@@ -1401,41 +1450,23 @@ def _update_moved_slots(self):
14011450 # Reset moved_exception
14021451 self ._moved_exception = None
14031452
1404- def get_node_from_slot (self , slot , read_from_replicas = False , server_type = None ):
1453+ def get_node_from_slot (
1454+ self , slot : int , read_from_replicas_mode : ReadFromReplicasMode
1455+ ) -> "ClusterNode" :
14051456 """
14061457 Gets a node that servers this hash slot
14071458 """
14081459 if self ._moved_exception :
14091460 with self ._lock :
14101461 if self ._moved_exception :
14111462 self ._update_moved_slots ()
1412-
1413- if self .slots_cache .get (slot ) is None or len (self .slots_cache [slot ]) == 0 :
1414- raise SlotNotCoveredError (
1415- f'Slot "{ slot } " not covered by the cluster. '
1416- f'"require_full_coverage={ self ._require_full_coverage } "'
1417- )
1418-
1419- if read_from_replicas is True :
1420- # get the server index in a Round-Robin manner
1421- primary_name = self .slots_cache [slot ][0 ].name
1422- node_idx = self .read_load_balancer .get_server_index (
1423- primary_name , len (self .slots_cache [slot ])
1463+
1464+ return self .read_load_balancer .get_node_from_slot (
1465+ slot ,
1466+ self .slots_cache .get (slot , None ),
1467+ read_from_replicas_mode ,
14241468 )
1425- elif (
1426- server_type is None
1427- or server_type == PRIMARY
1428- or len (self .slots_cache [slot ]) == 1
1429- ):
1430- # return a primary
1431- node_idx = 0
1432- else :
1433- # return a replica
1434- # randomly choose one of the replicas
1435- node_idx = random .randint (1 , len (self .slots_cache [slot ]) - 1 )
1436-
1437- return self .slots_cache [slot ][node_idx ]
1438-
1469+
14391470 def get_nodes_by_server_type (self , server_type ):
14401471 """
14411472 Get all nodes with the specified server type
@@ -1775,7 +1806,7 @@ def execute_command(self, *args):
17751806 channel = args [1 ]
17761807 slot = self .cluster .keyslot (channel )
17771808 node = self .cluster .nodes_manager .get_node_from_slot (
1778- slot , self .cluster .read_from_replicas
1809+ slot , self .cluster .read_from_replicas_mode
17791810 )
17801811 else :
17811812 # Get a random node
@@ -1915,7 +1946,7 @@ def __init__(
19151946 result_callbacks : Optional [Dict [str , Callable ]] = None ,
19161947 cluster_response_callbacks : Optional [Dict [str , Callable ]] = None ,
19171948 startup_nodes : Optional [List ["ClusterNode" ]] = None ,
1918- read_from_replicas : bool = False ,
1949+ read_from_replicas_mode : ReadFromReplicasMode = ReadFromReplicasMode . ReadFromPrimary ,
19191950 cluster_error_retry_attempts : int = 3 ,
19201951 reinitialize_steps : int = 5 ,
19211952 lock = None ,
@@ -1930,7 +1961,7 @@ def __init__(
19301961 result_callbacks or self .__class__ .RESULT_CALLBACKS .copy ()
19311962 )
19321963 self .startup_nodes = startup_nodes if startup_nodes else []
1933- self .read_from_replicas = read_from_replicas
1964+ self .read_from_replicas_mode = read_from_replicas_mode
19341965 self .command_flags = self .__class__ .COMMAND_FLAGS .copy ()
19351966 self .cluster_response_callbacks = cluster_response_callbacks
19361967 self .cluster_error_retry_attempts = cluster_error_retry_attempts
0 commit comments