Skip to content

Commit 688c2b0

Browse files
authored
Fix async clients safety when used as an async context manager (#3512)
* Fix async safety when Redis client is used as an async context manager When the async Redis client is used as an async context manager and called from different corotuines, one coroutine can exit, shutting down the client's connection pool, while another coroutine is attempting to use a connection. This results in a connection error, such as: redis.exceptions.ConnectionError: Connection closed by server. Additional locking in `ConnectionPool` resolves the problem but introduces extreme latency due to the locking. Instead, this PR implements a shielded counter that increments as callers enter the async context manager and decrements when they exit. The client then closes its connection pool only after all active contexts exit. Performance is on par with use of the client without a context manager.
1 parent 349d761 commit 688c2b0

File tree

4 files changed

+111
-5
lines changed

4 files changed

+111
-5
lines changed

CONTRIBUTING.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,10 @@ using `invoke standalone-tests`; similarly, RedisCluster tests can be run by usi
8181
Each run of tests starts and stops the various dockers required. Sometimes
8282
things get stuck, an `invoke clean` can help.
8383
84+
## Linting
85+
86+
Call `invoke linters` to run linters without also running tests.
87+
8488
## Documentation
8589
8690
If relevant, update the code documentation, via docstrings, or in `/docs`.

redis/asyncio/client.py

Lines changed: 45 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -387,6 +387,12 @@ def __init__(
387387
# on a set of redis commands
388388
self._single_conn_lock = asyncio.Lock()
389389

390+
# When used as an async context manager, we need to increment and decrement
391+
# a usage counter so that we can close the connection pool when no one is
392+
# using the client.
393+
self._usage_counter = 0
394+
self._usage_lock = asyncio.Lock()
395+
390396
def __repr__(self):
391397
return (
392398
f"<{self.__class__.__module__}.{self.__class__.__name__}"
@@ -594,10 +600,47 @@ def client(self) -> "Redis":
594600
)
595601

596602
async def __aenter__(self: _RedisT) -> _RedisT:
597-
return await self.initialize()
603+
"""
604+
Async context manager entry. Increments a usage counter so that the
605+
connection pool is only closed (via aclose()) when no context is using
606+
the client.
607+
"""
608+
await self._increment_usage()
609+
try:
610+
# Initialize the client (i.e. establish connection, etc.)
611+
return await self.initialize()
612+
except Exception:
613+
# If initialization fails, decrement the counter to keep it in sync
614+
await self._decrement_usage()
615+
raise
616+
617+
async def _increment_usage(self) -> int:
618+
"""
619+
Helper coroutine to increment the usage counter while holding the lock.
620+
Returns the new value of the usage counter.
621+
"""
622+
async with self._usage_lock:
623+
self._usage_counter += 1
624+
return self._usage_counter
625+
626+
async def _decrement_usage(self) -> int:
627+
"""
628+
Helper coroutine to decrement the usage counter while holding the lock.
629+
Returns the new value of the usage counter.
630+
"""
631+
async with self._usage_lock:
632+
self._usage_counter -= 1
633+
return self._usage_counter
598634

599635
async def __aexit__(self, exc_type, exc_value, traceback):
600-
await self.aclose()
636+
"""
637+
Async context manager exit. Decrements a usage counter. If this is the
638+
last exit (counter becomes zero), the client closes its connection pool.
639+
"""
640+
current_usage = await asyncio.shield(self._decrement_usage())
641+
if current_usage == 0:
642+
# This was the last active context, so disconnect the pool.
643+
await asyncio.shield(self.aclose())
601644

602645
_DEL_MESSAGE = "Unclosed Redis client"
603646

redis/asyncio/cluster.py

Lines changed: 46 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -431,6 +431,12 @@ def __init__(
431431
self._initialize = True
432432
self._lock: Optional[asyncio.Lock] = None
433433

434+
# When used as an async context manager, we need to increment and decrement
435+
# a usage counter so that we can close the connection pool when no one is
436+
# using the client.
437+
self._usage_counter = 0
438+
self._usage_lock = asyncio.Lock()
439+
434440
async def initialize(self) -> "RedisCluster":
435441
"""Get all nodes from startup nodes & creates connections if not initialized."""
436442
if self._initialize:
@@ -467,10 +473,47 @@ async def close(self) -> None:
467473
await self.aclose()
468474

469475
async def __aenter__(self) -> "RedisCluster":
470-
return await self.initialize()
476+
"""
477+
Async context manager entry. Increments a usage counter so that the
478+
connection pool is only closed (via aclose()) when no context is using
479+
the client.
480+
"""
481+
await self._increment_usage()
482+
try:
483+
# Initialize the client (i.e. establish connection, etc.)
484+
return await self.initialize()
485+
except Exception:
486+
# If initialization fails, decrement the counter to keep it in sync
487+
await self._decrement_usage()
488+
raise
471489

472-
async def __aexit__(self, exc_type: None, exc_value: None, traceback: None) -> None:
473-
await self.aclose()
490+
async def _increment_usage(self) -> int:
491+
"""
492+
Helper coroutine to increment the usage counter while holding the lock.
493+
Returns the new value of the usage counter.
494+
"""
495+
async with self._usage_lock:
496+
self._usage_counter += 1
497+
return self._usage_counter
498+
499+
async def _decrement_usage(self) -> int:
500+
"""
501+
Helper coroutine to decrement the usage counter while holding the lock.
502+
Returns the new value of the usage counter.
503+
"""
504+
async with self._usage_lock:
505+
self._usage_counter -= 1
506+
return self._usage_counter
507+
508+
async def __aexit__(self, exc_type, exc_value, traceback):
509+
"""
510+
Async context manager exit. Decrements a usage counter. If this is the
511+
last exit (counter becomes zero), the client closes its connection pool.
512+
"""
513+
current_usage = await asyncio.shield(self._decrement_usage())
514+
if current_usage == 0:
515+
# This was the last active context, so disconnect the pool.
516+
await asyncio.shield(self.aclose())
474517

475518
def __await__(self) -> Generator[Any, None, "RedisCluster"]:
476519
return self.initialize().__await__()
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
import asyncio
2+
3+
import pytest
4+
5+
6+
@pytest.mark.asyncio
7+
async def test_usage_counter(r):
8+
async def dummy_task():
9+
async with r:
10+
await asyncio.sleep(0.01)
11+
12+
tasks = [dummy_task() for _ in range(20)]
13+
await asyncio.gather(*tasks)
14+
15+
# After all tasks have completed, the usage counter should be back to zero.
16+
assert r._usage_counter == 0

0 commit comments

Comments
 (0)