Skip to content

Commit 3ef34b1

Browse files
committed
Close cluster connection on failover
1 parent 063e795 commit 3ef34b1

File tree

4 files changed

+24
-1
lines changed

4 files changed

+24
-1
lines changed

redis/asyncio/connection.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -212,6 +212,7 @@ def __init__(
212212
self._connect_callbacks: List[weakref.WeakMethod[ConnectCallbackT]] = []
213213
self._buffer_cutoff = 6000
214214
self._re_auth_token: Optional[TokenInterface] = None
215+
self._should_reconnect = False
215216

216217
try:
217218
p = int(protocol)
@@ -342,6 +343,12 @@ async def connect_check_health(
342343
if task and inspect.isawaitable(task):
343344
await task
344345

346+
def mark_for_reconnect(self):
347+
self._should_reconnect = True
348+
349+
def should_reconnect(self):
350+
return self._should_reconnect
351+
345352
@abstractmethod
346353
async def _connect(self):
347354
pass
@@ -1198,6 +1205,9 @@ async def release(self, connection: AbstractConnection):
11981205
# Connections should always be returned to the correct pool,
11991206
# not doing so is an error that will cause an exception here.
12001207
self._in_use_connections.remove(connection)
1208+
if connection.should_reconnect():
1209+
await connection.disconnect()
1210+
12011211
self._available_connections.append(connection)
12021212
await self._event_dispatcher.dispatch_async(
12031213
AsyncAfterConnectionReleasedEvent(connection)
@@ -1225,6 +1235,14 @@ async def disconnect(self, inuse_connections: bool = True):
12251235
if exc:
12261236
raise exc
12271237

1238+
async def update_active_connections_for_reconnect(self):
1239+
"""
1240+
Mark all active connections for reconnect.
1241+
"""
1242+
async with self._lock:
1243+
for conn in self._in_use_connections:
1244+
conn.mark_for_reconnect()
1245+
12281246
async def aclose(self) -> None:
12291247
"""Close the pool, disconnecting all connections"""
12301248
await self.disconnect()

redis/asyncio/multidb/event.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ async def listen(self, event: AsyncActiveDatabaseChanged):
6262
await event.old_database.client.aclose()
6363

6464
if isinstance(event.old_database.client, Redis):
65+
await event.old_database.client.connection_pool.update_active_connections_for_reconnect()
6566
await event.old_database.client.connection_pool.disconnect()
6667

6768
class RegisterCommandFailure(AsyncEventListenerInterface):

redis/multidb/event.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,10 @@ def listen(self, event: ActiveDatabaseChanged):
6565
if isinstance(event.old_database.client, Redis):
6666
event.old_database.client.connection_pool.update_active_connections_for_reconnect()
6767
event.old_database.client.connection_pool.disconnect()
68+
else:
69+
for node in event.old_database.client.nodes_manager.nodes_cache.values():
70+
node.redis_connection.connection_pool.update_active_connections_for_reconnect()
71+
node.redis_connection.connection_pool.disconnect()
6872

6973
class RegisterCommandFailure(EventListenerInterface):
7074
"""

tests/test_asyncio/test_scenario/conftest.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ async def r_multi_db(request) -> AsyncGenerator[tuple[MultiDBClient, CheckActive
4545

4646
# Retry configuration different for health checks as initial health check require more time in case
4747
# if infrastructure wasn't restored from the previous test.
48-
health_check_interval = request.param.get('health_check_interval', DEFAULT_HEALTH_CHECK_INTERVAL)
48+
health_check_interval = request.param.get('health_check_interval', 10)
4949
health_checks = request.param.get('health_checks', [])
5050
event_dispatcher = EventDispatcher()
5151
listener = CheckActiveDatabaseChangedListener()

0 commit comments

Comments
 (0)