@@ -379,6 +379,12 @@ def __init__(
379379 self ._initialize = True
380380 self ._lock : Optional [asyncio .Lock ] = None
381381
382+ # When used as an async context manager, we need to increment and decrement
383+ # a usage counter so that we can close the connection pool when no one is
384+ # using the client.
385+ self ._usage_counter = 0
386+ self ._usage_lock = asyncio .Lock ()
387+
382388 async def initialize (self ) -> "RedisCluster" :
383389 """Get all nodes from startup nodes & creates connections if not initialized."""
384390 if self ._initialize :
@@ -415,10 +421,40 @@ async def close(self) -> None:
415421 await self .aclose ()
416422
417423 async def __aenter__ (self ) -> "RedisCluster" :
418- return await self .initialize ()
424+ """
425+ Async context manager entry. Increments a usage counter so that the
426+ connection pool is only closed (via aclose()) when no context is using
427+ the client.
428+ """
429+ async with self ._usage_lock :
430+ self ._usage_counter += 1
431+ try :
432+ # Initialize the client (i.e. establish connection, etc.)
433+ return await self .initialize ()
434+ except Exception :
435+ # If initialization fails, decrement the counter to keep it in sync
436+ async with self ._usage_lock :
437+ self ._usage_counter -= 1
438+ raise
419439
420- async def __aexit__ (self , exc_type : None , exc_value : None , traceback : None ) -> None :
421- await self .aclose ()
440+ async def _decrement_usage (self ) -> int :
441+ """
442+ Helper coroutine to decrement the usage counter while holding the lock.
443+ Returns the new value of the usage counter.
444+ """
445+ async with self ._usage_lock :
446+ self ._usage_counter -= 1
447+ return self ._usage_counter
448+
449+ async def __aexit__ (self , exc_type , exc_value , traceback ):
450+ """
451+ Async context manager exit. Decrements a usage counter. If this is the
452+ last exit (counter becomes zero), the client closes its connection pool.
453+ """
454+ current_usage = await asyncio .shield (self ._decrement_usage ())
455+ if current_usage == 0 :
456+ # This was the last active context, so disconnect the pool.
457+ await asyncio .shield (self .aclose ())
422458
423459 def __await__ (self ) -> Generator [Any , None , "RedisCluster" ]:
424460 return self .initialize ().__await__ ()
0 commit comments