Skip to content

Commit 955df70

Browse files
committed
fix test, apply logic to async cluster client
1 parent 34ce3de commit 955df70

File tree

3 files changed

+43
-7
lines changed

3 files changed

+43
-7
lines changed

redis/asyncio/client.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -570,7 +570,8 @@ def client(self) -> "Redis":
570570
async def __aenter__(self: _RedisT) -> _RedisT:
571571
"""
572572
Async context manager entry. Increments a usage counter so that the
573-
connection pool is only closed (via aclose()) when no one is using the client.
573+
connection pool is only closed (via aclose()) when no context is using
574+
the client.
574575
"""
575576
async with self._usage_lock:
576577
self._usage_counter += 1

redis/asyncio/cluster.py

Lines changed: 39 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -379,6 +379,12 @@ def __init__(
379379
self._initialize = True
380380
self._lock: Optional[asyncio.Lock] = None
381381

382+
# When used as an async context manager, we need to increment and decrement
383+
# a usage counter so that we can close the connection pool when no one is
384+
# using the client.
385+
self._usage_counter = 0
386+
self._usage_lock = asyncio.Lock()
387+
382388
async def initialize(self) -> "RedisCluster":
383389
"""Get all nodes from startup nodes & creates connections if not initialized."""
384390
if self._initialize:
@@ -415,10 +421,40 @@ async def close(self) -> None:
415421
await self.aclose()
416422

417423
async def __aenter__(self) -> "RedisCluster":
418-
return await self.initialize()
424+
"""
425+
Async context manager entry. Increments a usage counter so that the
426+
connection pool is only closed (via aclose()) when no context is using
427+
the client.
428+
"""
429+
async with self._usage_lock:
430+
self._usage_counter += 1
431+
try:
432+
# Initialize the client (i.e. establish connection, etc.)
433+
return await self.initialize()
434+
except Exception:
435+
# If initialization fails, decrement the counter to keep it in sync
436+
async with self._usage_lock:
437+
self._usage_counter -= 1
438+
raise
419439

420-
async def __aexit__(self, exc_type: None, exc_value: None, traceback: None) -> None:
421-
await self.aclose()
440+
async def _decrement_usage(self) -> int:
441+
"""
442+
Helper coroutine to decrement the usage counter while holding the lock.
443+
Returns the new value of the usage counter.
444+
"""
445+
async with self._usage_lock:
446+
self._usage_counter -= 1
447+
return self._usage_counter
448+
449+
async def __aexit__(self, exc_type, exc_value, traceback):
450+
"""
451+
Async context manager exit. Decrements a usage counter. If this is the
452+
last exit (counter becomes zero), the client closes its connection pool.
453+
"""
454+
current_usage = await asyncio.shield(self._decrement_usage())
455+
if current_usage == 0:
456+
# This was the last active context, so disconnect the pool.
457+
await asyncio.shield(self.aclose())
422458

423459
def __await__(self) -> Generator[Any, None, "RedisCluster"]:
424460
return self.initialize().__await__()

tests/test_asyncio/test_usage_counter.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,11 @@
11
import asyncio
22

33
import pytest
4+
import redis
45

56

67
@pytest.mark.asyncio
7-
async def test_usage_counter(create_redis):
8-
r = await create_redis(decode_responses=True)
9-
8+
async def test_usage_counter(r):
109
async def dummy_task():
1110
async with r:
1211
await asyncio.sleep(0.01)

0 commit comments

Comments
 (0)