33import random
44import socket
55import ssl
6+ import threading
67import warnings
78from typing import (
89 Any ,
2930from redis .asyncio .connection import Connection , DefaultParser , SSLConnection , parse_url
3031from redis .asyncio .lock import Lock
3132from redis .asyncio .retry import Retry
33+ from redis .auth .token import TokenInterface
3234from redis .backoff import default_backoff
3335from redis .client import EMPTY_RESPONSE , NEVER_DECODE , AbstractRedis
3436from redis .cluster import (
4547from redis .commands import READ_COMMANDS , AsyncRedisClusterCommands
4648from redis .crc import REDIS_CLUSTER_HASH_SLOTS , key_slot
4749from redis .credentials import CredentialProvider
50+ from redis .event import EventDispatcher , AsyncAfterConnectionReleasedEvent , AfterAsyncClusterInstantiationEvent
4851from redis .exceptions import (
4952 AskError ,
5053 BusyLoadingError ,
6063 ResponseError ,
6164 SlotNotCoveredError ,
6265 TimeoutError ,
63- TryAgainError ,
66+ TryAgainError , RedisError ,
6467)
6568from redis .typing import AnyKeyT , EncodableT , KeyT
6669from redis .utils import (
@@ -270,6 +273,7 @@ def __init__(
270273 ssl_ciphers : Optional [str ] = None ,
271274 protocol : Optional [int ] = 2 ,
272275 address_remap : Optional [Callable [[Tuple [str , int ]], Tuple [str , int ]]] = None ,
276+ event_dispatcher : Optional [EventDispatcher ] = EventDispatcher (),
273277 ) -> None :
274278 if db :
275279 raise RedisClusterException (
@@ -371,6 +375,7 @@ def __init__(
371375 require_full_coverage ,
372376 kwargs ,
373377 address_remap = address_remap ,
378+ event_dispatcher = event_dispatcher ,
374379 )
375380 self .encoder = Encoder (encoding , encoding_errors , decode_responses )
376381 self .read_from_replicas = read_from_replicas
@@ -929,6 +934,8 @@ class ClusterNode:
929934 __slots__ = (
930935 "_connections" ,
931936 "_free" ,
937+ "_lock" ,
938+ "_event_dispatcher" ,
932939 "connection_class" ,
933940 "connection_kwargs" ,
934941 "host" ,
@@ -966,6 +973,9 @@ def __init__(
966973
967974 self ._connections : List [Connection ] = []
968975 self ._free : Deque [Connection ] = collections .deque (maxlen = self .max_connections )
976+ self ._event_dispatcher = self .connection_kwargs .get ("event_dispatcher" , None )
977+ if self ._event_dispatcher is None :
978+ self ._event_dispatcher = EventDispatcher ()
969979
970980 def __repr__ (self ) -> str :
971981 return (
@@ -1082,10 +1092,37 @@ async def execute_pipeline(self, commands: List["PipelineCommand"]) -> bool:
10821092
10831093 return ret
10841094
1095+ async def re_auth_callback (self , token : TokenInterface ):
1096+ tmp_queue = collections .deque ()
1097+ while self ._free :
1098+ conn = self ._free .popleft ()
1099+ await conn .retry .call_with_retry (
1100+ lambda : conn .send_command ('AUTH' , token .try_get ('oid' ), token .get_value ()),
1101+ lambda error : self ._mock (error )
1102+ )
1103+ await conn .retry .call_with_retry (
1104+ lambda : conn .read_response (),
1105+ lambda error : self ._mock (error )
1106+ )
1107+ tmp_queue .append (conn )
1108+
1109+ while tmp_queue :
1110+ conn = tmp_queue .popleft ()
1111+ self ._free .append (conn )
1112+
1113+ async def _mock (self , error : RedisError ):
1114+ """
1115+ Dummy functions, needs to be passed as error callback to retry object.
1116+ :param error:
1117+ :return:
1118+ """
1119+ pass
1120+
10851121
10861122class NodesManager :
10871123 __slots__ = (
10881124 "_moved_exception" ,
1125+ "_event_dispatcher" ,
10891126 "connection_kwargs" ,
10901127 "default_node" ,
10911128 "nodes_cache" ,
@@ -1102,6 +1139,7 @@ def __init__(
11021139 require_full_coverage : bool ,
11031140 connection_kwargs : Dict [str , Any ],
11041141 address_remap : Optional [Callable [[Tuple [str , int ]], Tuple [str , int ]]] = None ,
1142+ event_dispatcher : Optional [EventDispatcher ] = EventDispatcher (),
11051143 ) -> None :
11061144 self .startup_nodes = {node .name : node for node in startup_nodes }
11071145 self .require_full_coverage = require_full_coverage
@@ -1113,6 +1151,7 @@ def __init__(
11131151 self .slots_cache : Dict [int , List ["ClusterNode" ]] = {}
11141152 self .read_load_balancer = LoadBalancer ()
11151153 self ._moved_exception : MovedError = None
1154+ self ._event_dispatcher = event_dispatcher
11161155
11171156 def get_node (
11181157 self ,
@@ -1230,6 +1269,11 @@ async def initialize(self) -> None:
12301269 try :
12311270 # Make sure cluster mode is enabled on this node
12321271 try :
1272+ self ._event_dispatcher .dispatch (
1273+ AfterAsyncClusterInstantiationEvent (
1274+ self .nodes_cache ,
1275+ self .connection_kwargs .get ("credential_provider" , None ))
1276+ )
12331277 cluster_slots = await startup_node .execute_command ("CLUSTER SLOTS" )
12341278 except ResponseError :
12351279 raise RedisClusterException (
0 commit comments