Skip to content

Commit e14d680

Browse files
committed
Added support for async cluster
1 parent 835ede7 commit e14d680

File tree

2 files changed

+86
-1
lines changed

2 files changed

+86
-1
lines changed

redis/asyncio/cluster.py

Lines changed: 45 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import random
44
import socket
55
import ssl
6+
import threading
67
import warnings
78
from typing import (
89
Any,
@@ -29,6 +30,7 @@
2930
from redis.asyncio.connection import Connection, DefaultParser, SSLConnection, parse_url
3031
from redis.asyncio.lock import Lock
3132
from redis.asyncio.retry import Retry
33+
from redis.auth.token import TokenInterface
3234
from redis.backoff import default_backoff
3335
from redis.client import EMPTY_RESPONSE, NEVER_DECODE, AbstractRedis
3436
from redis.cluster import (
@@ -45,6 +47,7 @@
4547
from redis.commands import READ_COMMANDS, AsyncRedisClusterCommands
4648
from redis.crc import REDIS_CLUSTER_HASH_SLOTS, key_slot
4749
from redis.credentials import CredentialProvider
50+
from redis.event import EventDispatcher, AsyncAfterConnectionReleasedEvent, AfterAsyncClusterInstantiationEvent
4851
from redis.exceptions import (
4952
AskError,
5053
BusyLoadingError,
@@ -60,7 +63,7 @@
6063
ResponseError,
6164
SlotNotCoveredError,
6265
TimeoutError,
63-
TryAgainError,
66+
TryAgainError, RedisError,
6467
)
6568
from redis.typing import AnyKeyT, EncodableT, KeyT
6669
from redis.utils import (
@@ -270,6 +273,7 @@ def __init__(
270273
ssl_ciphers: Optional[str] = None,
271274
protocol: Optional[int] = 2,
272275
address_remap: Optional[Callable[[Tuple[str, int]], Tuple[str, int]]] = None,
276+
event_dispatcher: Optional[EventDispatcher] = EventDispatcher(),
273277
) -> None:
274278
if db:
275279
raise RedisClusterException(
@@ -371,6 +375,7 @@ def __init__(
371375
require_full_coverage,
372376
kwargs,
373377
address_remap=address_remap,
378+
event_dispatcher=event_dispatcher,
374379
)
375380
self.encoder = Encoder(encoding, encoding_errors, decode_responses)
376381
self.read_from_replicas = read_from_replicas
@@ -929,6 +934,8 @@ class ClusterNode:
929934
__slots__ = (
930935
"_connections",
931936
"_free",
937+
"_lock",
938+
"_event_dispatcher",
932939
"connection_class",
933940
"connection_kwargs",
934941
"host",
@@ -966,6 +973,9 @@ def __init__(
966973

967974
self._connections: List[Connection] = []
968975
self._free: Deque[Connection] = collections.deque(maxlen=self.max_connections)
976+
self._event_dispatcher = self.connection_kwargs.get("event_dispatcher", None)
977+
if self._event_dispatcher is None:
978+
self._event_dispatcher = EventDispatcher()
969979

970980
def __repr__(self) -> str:
971981
return (
@@ -1082,10 +1092,37 @@ async def execute_pipeline(self, commands: List["PipelineCommand"]) -> bool:
10821092

10831093
return ret
10841094

1095+
async def re_auth_callback(self, token: TokenInterface):
1096+
tmp_queue = collections.deque()
1097+
while self._free:
1098+
conn = self._free.popleft()
1099+
await conn.retry.call_with_retry(
1100+
lambda: conn.send_command('AUTH', token.try_get('oid'), token.get_value()),
1101+
lambda error: self._mock(error)
1102+
)
1103+
await conn.retry.call_with_retry(
1104+
lambda: conn.read_response(),
1105+
lambda error: self._mock(error)
1106+
)
1107+
tmp_queue.append(conn)
1108+
1109+
while tmp_queue:
1110+
conn = tmp_queue.popleft()
1111+
self._free.append(conn)
1112+
1113+
async def _mock(self, error: RedisError):
1114+
"""
1115+
Dummy functions, needs to be passed as error callback to retry object.
1116+
:param error:
1117+
:return:
1118+
"""
1119+
pass
1120+
10851121

10861122
class NodesManager:
10871123
__slots__ = (
10881124
"_moved_exception",
1125+
"_event_dispatcher",
10891126
"connection_kwargs",
10901127
"default_node",
10911128
"nodes_cache",
@@ -1102,6 +1139,7 @@ def __init__(
11021139
require_full_coverage: bool,
11031140
connection_kwargs: Dict[str, Any],
11041141
address_remap: Optional[Callable[[Tuple[str, int]], Tuple[str, int]]] = None,
1142+
event_dispatcher: Optional[EventDispatcher] = EventDispatcher(),
11051143
) -> None:
11061144
self.startup_nodes = {node.name: node for node in startup_nodes}
11071145
self.require_full_coverage = require_full_coverage
@@ -1113,6 +1151,7 @@ def __init__(
11131151
self.slots_cache: Dict[int, List["ClusterNode"]] = {}
11141152
self.read_load_balancer = LoadBalancer()
11151153
self._moved_exception: MovedError = None
1154+
self._event_dispatcher = event_dispatcher
11161155

11171156
def get_node(
11181157
self,
@@ -1230,6 +1269,11 @@ async def initialize(self) -> None:
12301269
try:
12311270
# Make sure cluster mode is enabled on this node
12321271
try:
1272+
self._event_dispatcher.dispatch(
1273+
AfterAsyncClusterInstantiationEvent(
1274+
self.nodes_cache,
1275+
self.connection_kwargs.get("credential_provider", None))
1276+
)
12331277
cluster_slots = await startup_node.execute_command("CLUSTER SLOTS")
12341278
except ResponseError:
12351279
raise RedisClusterException(

redis/event.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from enum import Enum
55
from typing import List, Union, Optional
66

7+
from redis.auth.token import TokenInterface
78
from redis.credentials import StreamingCredentialProvider, CredentialProvider
89

910

@@ -57,6 +58,9 @@ def __init__(self):
5758
AfterSingleConnectionInstantiationEvent: [
5859
RegisterReAuthForSingleConnection()
5960
],
61+
AfterAsyncClusterInstantiationEvent: [
62+
RegisterReAuthForAsyncClusterNodes()
63+
],
6064
AsyncAfterConnectionReleasedEvent: [
6165
AsyncReAuthConnectionListener(),
6266
],
@@ -154,6 +158,29 @@ def connection_lock(self) -> Union[threading.Lock, asyncio.Lock]:
154158
return self._connection_lock
155159

156160

161+
class AfterAsyncClusterInstantiationEvent:
162+
"""
163+
Event that will be fired after async cluster instance was created.
164+
165+
Async cluster doesn't use connection pools, instead ClusterNode object manages connections.
166+
"""
167+
def __init__(
168+
self,
169+
nodes: dict,
170+
credential_provider: Optional[CredentialProvider] = None,
171+
):
172+
self._nodes = nodes
173+
self._credential_provider = credential_provider
174+
175+
@property
176+
def nodes(self) -> dict:
177+
return self._nodes
178+
179+
@property
180+
def credential_provider(self) -> Union[CredentialProvider, None]:
181+
return self._credential_provider
182+
183+
157184
class ReAuthConnectionListener(EventListenerInterface):
158185
"""
159186
Listener that performs re-authentication of given connection.
@@ -225,3 +252,17 @@ async def _re_auth_async(self, token):
225252
async with self._event.connection_lock:
226253
await self._event.connection.send_command('AUTH', token.try_get('oid'), token.get_value())
227254
await self._event.connection.read_response()
255+
256+
257+
class RegisterReAuthForAsyncClusterNodes(EventListenerInterface):
258+
def __init__(self):
259+
self._event = None
260+
261+
def listen(self, event: AfterAsyncClusterInstantiationEvent):
262+
if isinstance(event.credential_provider, StreamingCredentialProvider):
263+
self._event = event
264+
event.credential_provider.on_next(self._re_auth)
265+
266+
async def _re_auth(self, token: TokenInterface):
267+
for key in self._event.nodes:
268+
await self._event.nodes[key].re_auth_callback(token)

0 commit comments

Comments
 (0)