@@ -379,6 +379,12 @@ def __init__(
379
379
self ._initialize = True
380
380
self ._lock : Optional [asyncio .Lock ] = None
381
381
382
+ # When used as an async context manager, we need to increment and decrement
383
+ # a usage counter so that we can close the connection pool when no one is
384
+ # using the client.
385
+ self ._usage_counter = 0
386
+ self ._usage_lock = asyncio .Lock ()
387
+
382
388
async def initialize (self ) -> "RedisCluster" :
383
389
"""Get all nodes from startup nodes & creates connections if not initialized."""
384
390
if self ._initialize :
@@ -415,10 +421,40 @@ async def close(self) -> None:
415
421
await self .aclose ()
416
422
417
423
async def __aenter__ (self ) -> "RedisCluster" :
418
- return await self .initialize ()
424
+ """
425
+ Async context manager entry. Increments a usage counter so that the
426
+ connection pool is only closed (via aclose()) when no context is using
427
+ the client.
428
+ """
429
+ async with self ._usage_lock :
430
+ self ._usage_counter += 1
431
+ try :
432
+ # Initialize the client (i.e. establish connection, etc.)
433
+ return await self .initialize ()
434
+ except Exception :
435
+ # If initialization fails, decrement the counter to keep it in sync
436
+ async with self ._usage_lock :
437
+ self ._usage_counter -= 1
438
+ raise
419
439
420
- async def __aexit__ (self , exc_type : None , exc_value : None , traceback : None ) -> None :
421
- await self .aclose ()
440
+ async def _decrement_usage (self ) -> int :
441
+ """
442
+ Helper coroutine to decrement the usage counter while holding the lock.
443
+ Returns the new value of the usage counter.
444
+ """
445
+ async with self ._usage_lock :
446
+ self ._usage_counter -= 1
447
+ return self ._usage_counter
448
+
449
+ async def __aexit__ (self , exc_type , exc_value , traceback ):
450
+ """
451
+ Async context manager exit. Decrements a usage counter. If this is the
452
+ last exit (counter becomes zero), the client closes its connection pool.
453
+ """
454
+ current_usage = await asyncio .shield (self ._decrement_usage ())
455
+ if current_usage == 0 :
456
+ # This was the last active context, so disconnect the pool.
457
+ await asyncio .shield (self .aclose ())
422
458
423
459
def __await__ (self ) -> Generator [Any , None , "RedisCluster" ]:
424
460
return self .initialize ().__await__ ()
0 commit comments