Skip to content

Commit 6556e8a

Browse files
committed
Added graceful connection closing, added graceful hc tasks termination
1 parent 115d996 commit 6556e8a

File tree

5 files changed

+40
-12
lines changed

5 files changed

+40
-12
lines changed

redis/asyncio/multidb/client.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,8 @@ def __init__(self, config: MultiDbConfig):
5959
self._hc_lock = asyncio.Lock()
6060
self._bg_scheduler = BackgroundScheduler()
6161
self._config = config
62-
self._hc_task = None
62+
self._recurring_hc_task = None
63+
self._hc_tasks = []
6364
self._half_open_state_task = None
6465

6566
async def __aenter__(self: "MultiDBClient") -> "MultiDBClient":
@@ -68,10 +69,12 @@ async def __aenter__(self: "MultiDBClient") -> "MultiDBClient":
6869
return self
6970

7071
async def __aexit__(self, exc_type, exc_value, traceback):
71-
if self._hc_task:
72-
self._hc_task.cancel()
72+
if self._recurring_hc_task:
73+
self._recurring_hc_task.cancel()
7374
if self._half_open_state_task:
7475
self._half_open_state_task.cancel()
76+
for hc_task in self._hc_tasks:
77+
hc_task.cancel()
7578

7679
async def initialize(self):
7780
"""
@@ -84,7 +87,7 @@ async def raise_exception_on_failed_hc(error):
8487
await self._check_databases_health(on_error=raise_exception_on_failed_hc)
8588

8689
# Starts recurring health checks on the background.
87-
self._hc_task = asyncio.create_task(self._bg_scheduler.run_recurring_async(
90+
self._recurring_hc_task = asyncio.create_task(self._bg_scheduler.run_recurring_async(
8891
self._health_check_interval,
8992
self._check_databases_health,
9093
))
@@ -251,12 +254,10 @@ async def _check_databases_health(
251254
Runs health checks against all databases.
252255
"""
253256
try:
257+
self._hc_tasks = [asyncio.create_task(self._check_db_health(database)) for database, _ in self._databases]
254258
results = await asyncio.wait_for(
255259
asyncio.gather(
256-
*(
257-
asyncio.create_task(self._check_db_health(database))
258-
for database, _ in self._databases
259-
),
260+
*self._hc_tasks,
260261
return_exceptions=True,
261262
),
262263
timeout=self._health_check_interval,

redis/asyncio/multidb/command_executor.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from redis.asyncio.client import PubSub, Pipeline
88
from redis.asyncio.multidb.database import Databases, AsyncDatabase, Database
99
from redis.asyncio.multidb.event import AsyncActiveDatabaseChanged, RegisterCommandFailure, \
10-
ResubscribeOnActiveDatabaseChanged
10+
ResubscribeOnActiveDatabaseChanged, CloseConnectionOnActiveDatabaseChanged
1111
from redis.asyncio.multidb.failover import AsyncFailoverStrategy, FailoverStrategyExecutor, DefaultFailoverStrategyExecutor, \
1212
DEFAULT_FAILOVER_ATTEMPTS, DEFAULT_FAILOVER_DELAY
1313
from redis.asyncio.multidb.failure_detector import AsyncFailureDetector
@@ -286,7 +286,8 @@ def _setup_event_dispatcher(self):
286286
"""
287287
failure_listener = RegisterCommandFailure(self._failure_detectors)
288288
resubscribe_listener = ResubscribeOnActiveDatabaseChanged()
289+
close_connection_listener = CloseConnectionOnActiveDatabaseChanged()
289290
self._event_dispatcher.register_listeners({
290291
AsyncOnCommandsFailEvent: [failure_listener],
291-
AsyncActiveDatabaseChanged: [resubscribe_listener],
292+
AsyncActiveDatabaseChanged: [close_connection_listener, resubscribe_listener],
292293
})

redis/asyncio/multidb/event.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from typing import List
22

3+
from redis.asyncio import Redis
34
from redis.asyncio.multidb.database import AsyncDatabase
45
from redis.asyncio.multidb.failure_detector import AsyncFailureDetector
56
from redis.event import AsyncEventListenerInterface, AsyncOnCommandsFailEvent
@@ -53,6 +54,16 @@ async def listen(self, event: AsyncActiveDatabaseChanged):
5354
event.command_executor.active_pubsub = new_pubsub
5455
await old_pubsub.aclose()
5556

57+
class CloseConnectionOnActiveDatabaseChanged(AsyncEventListenerInterface):
58+
"""
59+
Close connection to the old active database.
60+
"""
61+
async def listen(self, event: AsyncActiveDatabaseChanged):
62+
await event.old_database.client.aclose()
63+
64+
if isinstance(event.old_database.client, Redis):
65+
await event.old_database.client.connection_pool.disconnect()
66+
5667
class RegisterCommandFailure(AsyncEventListenerInterface):
5768
"""
5869
Event listener that registers command failures and passing it to the failure detectors.

redis/multidb/command_executor.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,8 @@
77
from redis.multidb.config import DEFAULT_AUTO_FALLBACK_INTERVAL
88
from redis.multidb.database import Database, Databases, SyncDatabase
99
from redis.multidb.circuit import State as CBState
10-
from redis.multidb.event import RegisterCommandFailure, ActiveDatabaseChanged, ResubscribeOnActiveDatabaseChanged
10+
from redis.multidb.event import RegisterCommandFailure, ActiveDatabaseChanged, ResubscribeOnActiveDatabaseChanged, \
11+
CloseConnectionOnActiveDatabaseChanged
1112
from redis.multidb.failover import FailoverStrategy, FailoverStrategyExecutor, DEFAULT_FAILOVER_ATTEMPTS, \
1213
DEFAULT_FAILOVER_DELAY, DefaultFailoverStrategyExecutor
1314
from redis.multidb.failure_detector import FailureDetector
@@ -303,7 +304,8 @@ def _setup_event_dispatcher(self):
303304
"""
304305
failure_listener = RegisterCommandFailure(self._failure_detectors)
305306
resubscribe_listener = ResubscribeOnActiveDatabaseChanged()
307+
close_connection_listener = CloseConnectionOnActiveDatabaseChanged()
306308
self._event_dispatcher.register_listeners({
307309
OnCommandsFailEvent: [failure_listener],
308-
ActiveDatabaseChanged: [resubscribe_listener],
310+
ActiveDatabaseChanged: [close_connection_listener, resubscribe_listener],
309311
})

redis/multidb/event.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
from typing import List
22

3+
from redis.client import Redis
4+
from sphinx.events import EventListener
5+
36
from redis.event import EventListenerInterface, OnCommandsFailEvent
47
from redis.multidb.database import SyncDatabase
58
from redis.multidb.failure_detector import FailureDetector
@@ -53,6 +56,16 @@ def listen(self, event: ActiveDatabaseChanged):
5356
event.command_executor.active_pubsub = new_pubsub
5457
old_pubsub.close()
5558

59+
class CloseConnectionOnActiveDatabaseChanged(EventListenerInterface):
60+
"""
61+
Close connection to the old active database.
62+
"""
63+
def listen(self, event: ActiveDatabaseChanged):
64+
event.old_database.client.close()
65+
66+
if isinstance(event.old_database.client, Redis):
67+
event.old_database.client.connection_pool.disconnect()
68+
5669
class RegisterCommandFailure(EventListenerInterface):
5770
"""
5871
Event listener that registers command failures and passing it to the failure detectors.

0 commit comments

Comments
 (0)