@@ -59,7 +59,8 @@ def __init__(self, config: MultiDbConfig):
59
59
self ._hc_lock = asyncio .Lock ()
60
60
self ._bg_scheduler = BackgroundScheduler ()
61
61
self ._config = config
62
- self ._hc_task = None
62
+ self ._recurring_hc_task = None
63
+ self ._hc_tasks = []
63
64
self ._half_open_state_task = None
64
65
65
66
async def __aenter__ (self : "MultiDBClient" ) -> "MultiDBClient" :
@@ -68,10 +69,12 @@ async def __aenter__(self: "MultiDBClient") -> "MultiDBClient":
68
69
return self
69
70
70
71
async def __aexit__ (self , exc_type , exc_value , traceback ):
71
- if self ._hc_task :
72
- self ._hc_task .cancel ()
72
+ if self ._recurring_hc_task :
73
+ self ._recurring_hc_task .cancel ()
73
74
if self ._half_open_state_task :
74
75
self ._half_open_state_task .cancel ()
76
+ for hc_task in self ._hc_tasks :
77
+ hc_task .cancel ()
75
78
76
79
async def initialize (self ):
77
80
"""
@@ -84,7 +87,7 @@ async def raise_exception_on_failed_hc(error):
84
87
await self ._check_databases_health (on_error = raise_exception_on_failed_hc )
85
88
86
89
# Starts recurring health checks on the background.
87
- self ._hc_task = asyncio .create_task (self ._bg_scheduler .run_recurring_async (
90
+ self ._recurring_hc_task = asyncio .create_task (self ._bg_scheduler .run_recurring_async (
88
91
self ._health_check_interval ,
89
92
self ._check_databases_health ,
90
93
))
@@ -251,12 +254,10 @@ async def _check_databases_health(
251
254
Runs health checks against all databases.
252
255
"""
253
256
try :
257
+ self ._hc_tasks = [asyncio .create_task (self ._check_db_health (database )) for database , _ in self ._databases ]
254
258
results = await asyncio .wait_for (
255
259
asyncio .gather (
256
- * (
257
- asyncio .create_task (self ._check_db_health (database ))
258
- for database , _ in self ._databases
259
- ),
260
+ * self ._hc_tasks ,
260
261
return_exceptions = True ,
261
262
),
262
263
timeout = self ._health_check_interval ,
0 commit comments