Skip to content

Commit 95693c8

Browse files
committed
Fix async safety when Redis client is used as an async context manager
When the async Redis client is used as an async context manager and called from different corotuines, one coroutine can exit, shutting down the client's connection pool, while another coroutine is attempting to use a connection. This results in a connection error, such as: redis.exceptions.ConnectionError: Connection closed by server. Additional locking in `ConnectionPool` resolves the problem but introduces extreme latency due to the locking. Instead, this PR implements a shielded counter that increments as callers enter the async context manager and decrements when they exit. The client then closes its connection pool only after all active contexts exit. Performance is on par with use of the client without a context manager.
1 parent c87e01f commit 95693c8

File tree

3 files changed

+70
-17
lines changed

3 files changed

+70
-17
lines changed

redis/asyncio/client.py

Lines changed: 39 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -362,6 +362,12 @@ def __init__(
362362
# on a set of redis commands
363363
self._single_conn_lock = asyncio.Lock()
364364

365+
# When used as an async context manager, we need to increment and decrement
366+
# a usage counter so that we can close the connection pool when no one is
367+
# using the client.
368+
self._usage_counter = 0
369+
self._usage_lock = asyncio.Lock()
370+
365371
def __repr__(self):
366372
return (
367373
f"<{self.__class__.__module__}.{self.__class__.__name__}"
@@ -562,10 +568,40 @@ def client(self) -> "Redis":
562568
)
563569

564570
async def __aenter__(self: _RedisT) -> _RedisT:
565-
return await self.initialize()
571+
"""
572+
Async context manager entry. Increments a usage counter so that the
573+
connection pool is only closed (via aclose()) when no one is using the client.
574+
"""
575+
async with self._usage_lock:
576+
self._usage_counter += 1
577+
current_usage = self._usage_counter
578+
try:
579+
# Initialize the client (i.e. establish connection, etc.)
580+
return await self.initialize()
581+
except Exception:
582+
# If initialization fails, decrement the counter to keep it in sync
583+
async with self._usage_lock:
584+
self._usage_counter -= 1
585+
raise
586+
587+
async def _decrement_usage(self) -> int:
588+
"""
589+
Helper coroutine to decrement the usage counter while holding the lock.
590+
Returns the new value of the usage counter.
591+
"""
592+
async with self._usage_lock:
593+
self._usage_counter -= 1
594+
return self._usage_counter
566595

567596
async def __aexit__(self, exc_type, exc_value, traceback):
568-
await self.aclose()
597+
"""
598+
Async context manager exit. Decrements a usage counter. If this is the
599+
last exit (counter becomes zero), the client closes its connection pool.
600+
"""
601+
current_usage = await asyncio.shield(self._decrement_usage())
602+
if current_usage == 0:
603+
# This was the last active context, so disconnect the pool.
604+
await asyncio.shield(self.aclose())
569605

570606
_DEL_MESSAGE = "Unclosed Redis client"
571607

@@ -1347,9 +1383,7 @@ async def _disconnect_reset_raise(self, conn, error):
13471383
# indicates the user should retry this transaction.
13481384
if self.watching:
13491385
await self.aclose()
1350-
raise WatchError(
1351-
"A ConnectionError occurred on while watching one or more keys"
1352-
)
1386+
raise
13531387
# if retry_on_error is not set or the error is not one
13541388
# of the specified error types, raise it
13551389
if (

redis/asyncio/connection.py

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1157,19 +1157,20 @@ async def disconnect(self, inuse_connections: bool = True):
11571157
current in use, potentially by other tasks. Otherwise only disconnect
11581158
connections that are idle in the pool.
11591159
"""
1160-
if inuse_connections:
1161-
connections: Iterable[AbstractConnection] = chain(
1162-
self._available_connections, self._in_use_connections
1160+
async with self._lock:
1161+
if inuse_connections:
1162+
connections: Iterable[AbstractConnection] = chain(
1163+
self._available_connections, self._in_use_connections
1164+
)
1165+
else:
1166+
connections = self._available_connections
1167+
resp = await asyncio.gather(
1168+
*(connection.disconnect() for connection in connections),
1169+
return_exceptions=True,
11631170
)
1164-
else:
1165-
connections = self._available_connections
1166-
resp = await asyncio.gather(
1167-
*(connection.disconnect() for connection in connections),
1168-
return_exceptions=True,
1169-
)
1170-
exc = next((r for r in resp if isinstance(r, BaseException)), None)
1171-
if exc:
1172-
raise exc
1171+
exc = next((r for r in resp if isinstance(r, BaseException)), None)
1172+
if exc:
1173+
raise exc
11731174

11741175
async def aclose(self) -> None:
11751176
"""Close the pool, disconnecting all connections"""
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
import asyncio
2+
3+
import pytest
4+
5+
6+
@pytest.mark.asyncio
7+
async def test_usage_counter(create_redis):
8+
r = await create_redis(decode_responses=True)
9+
10+
async def dummy_task():
11+
async with r:
12+
await asyncio.sleep(0.01)
13+
14+
tasks = [dummy_task() for _ in range(20)]
15+
await asyncio.gather(*tasks)
16+
17+
# After all tasks have completed, the usage counter should be back to zero.
18+
assert r._usage_counter == 0

0 commit comments

Comments
 (0)