3
3
import random
4
4
import socket
5
5
import ssl
6
+ import threading
6
7
import warnings
7
8
from typing import (
8
9
Any ,
29
30
from redis .asyncio .connection import Connection , DefaultParser , SSLConnection , parse_url
30
31
from redis .asyncio .lock import Lock
31
32
from redis .asyncio .retry import Retry
33
+ from redis .auth .token import TokenInterface
32
34
from redis .backoff import default_backoff
33
35
from redis .client import EMPTY_RESPONSE , NEVER_DECODE , AbstractRedis
34
36
from redis .cluster import (
45
47
from redis .commands import READ_COMMANDS , AsyncRedisClusterCommands
46
48
from redis .crc import REDIS_CLUSTER_HASH_SLOTS , key_slot
47
49
from redis .credentials import CredentialProvider
50
+ from redis .event import EventDispatcher , AsyncAfterConnectionReleasedEvent , AfterAsyncClusterInstantiationEvent
48
51
from redis .exceptions import (
49
52
AskError ,
50
53
BusyLoadingError ,
60
63
ResponseError ,
61
64
SlotNotCoveredError ,
62
65
TimeoutError ,
63
- TryAgainError ,
66
+ TryAgainError , RedisError ,
64
67
)
65
68
from redis .typing import AnyKeyT , EncodableT , KeyT
66
69
from redis .utils import (
@@ -270,6 +273,7 @@ def __init__(
270
273
ssl_ciphers : Optional [str ] = None ,
271
274
protocol : Optional [int ] = 2 ,
272
275
address_remap : Optional [Callable [[Tuple [str , int ]], Tuple [str , int ]]] = None ,
276
+ event_dispatcher : Optional [EventDispatcher ] = EventDispatcher (),
273
277
) -> None :
274
278
if db :
275
279
raise RedisClusterException (
@@ -371,6 +375,7 @@ def __init__(
371
375
require_full_coverage ,
372
376
kwargs ,
373
377
address_remap = address_remap ,
378
+ event_dispatcher = event_dispatcher ,
374
379
)
375
380
self .encoder = Encoder (encoding , encoding_errors , decode_responses )
376
381
self .read_from_replicas = read_from_replicas
@@ -929,6 +934,8 @@ class ClusterNode:
929
934
__slots__ = (
930
935
"_connections" ,
931
936
"_free" ,
937
+ "_lock" ,
938
+ "_event_dispatcher" ,
932
939
"connection_class" ,
933
940
"connection_kwargs" ,
934
941
"host" ,
@@ -966,6 +973,9 @@ def __init__(
966
973
967
974
self ._connections : List [Connection ] = []
968
975
self ._free : Deque [Connection ] = collections .deque (maxlen = self .max_connections )
976
+ self ._event_dispatcher = self .connection_kwargs .get ("event_dispatcher" , None )
977
+ if self ._event_dispatcher is None :
978
+ self ._event_dispatcher = EventDispatcher ()
969
979
970
980
def __repr__ (self ) -> str :
971
981
return (
@@ -1082,10 +1092,37 @@ async def execute_pipeline(self, commands: List["PipelineCommand"]) -> bool:
1082
1092
1083
1093
return ret
1084
1094
1095
+ async def re_auth_callback (self , token : TokenInterface ):
1096
+ tmp_queue = collections .deque ()
1097
+ while self ._free :
1098
+ conn = self ._free .popleft ()
1099
+ await conn .retry .call_with_retry (
1100
+ lambda : conn .send_command ('AUTH' , token .try_get ('oid' ), token .get_value ()),
1101
+ lambda error : self ._mock (error )
1102
+ )
1103
+ await conn .retry .call_with_retry (
1104
+ lambda : conn .read_response (),
1105
+ lambda error : self ._mock (error )
1106
+ )
1107
+ tmp_queue .append (conn )
1108
+
1109
+ while tmp_queue :
1110
+ conn = tmp_queue .popleft ()
1111
+ self ._free .append (conn )
1112
+
1113
+ async def _mock (self , error : RedisError ):
1114
+ """
1115
+ Dummy functions, needs to be passed as error callback to retry object.
1116
+ :param error:
1117
+ :return:
1118
+ """
1119
+ pass
1120
+
1085
1121
1086
1122
class NodesManager :
1087
1123
__slots__ = (
1088
1124
"_moved_exception" ,
1125
+ "_event_dispatcher" ,
1089
1126
"connection_kwargs" ,
1090
1127
"default_node" ,
1091
1128
"nodes_cache" ,
@@ -1102,6 +1139,7 @@ def __init__(
1102
1139
require_full_coverage : bool ,
1103
1140
connection_kwargs : Dict [str , Any ],
1104
1141
address_remap : Optional [Callable [[Tuple [str , int ]], Tuple [str , int ]]] = None ,
1142
+ event_dispatcher : Optional [EventDispatcher ] = EventDispatcher (),
1105
1143
) -> None :
1106
1144
self .startup_nodes = {node .name : node for node in startup_nodes }
1107
1145
self .require_full_coverage = require_full_coverage
@@ -1113,6 +1151,7 @@ def __init__(
1113
1151
self .slots_cache : Dict [int , List ["ClusterNode" ]] = {}
1114
1152
self .read_load_balancer = LoadBalancer ()
1115
1153
self ._moved_exception : MovedError = None
1154
+ self ._event_dispatcher = event_dispatcher
1116
1155
1117
1156
def get_node (
1118
1157
self ,
@@ -1230,6 +1269,11 @@ async def initialize(self) -> None:
1230
1269
try :
1231
1270
# Make sure cluster mode is enabled on this node
1232
1271
try :
1272
+ self ._event_dispatcher .dispatch (
1273
+ AfterAsyncClusterInstantiationEvent (
1274
+ self .nodes_cache ,
1275
+ self .connection_kwargs .get ("credential_provider" , None ))
1276
+ )
1233
1277
cluster_slots = await startup_node .execute_command ("CLUSTER SLOTS" )
1234
1278
except ResponseError :
1235
1279
raise RedisClusterException (
0 commit comments