Skip to content

Commit 0539a5e

Browse files
authored
Merge branch 'master' into anyio
2 parents 3ca4368 + 688c2b0 commit 0539a5e

File tree

4 files changed

+111
-5
lines changed

4 files changed

+111
-5
lines changed

CONTRIBUTING.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,10 @@ using `invoke standalone-tests`; similarly, RedisCluster tests can be run by usi
8181
Each run of tests starts and stops the various dockers required. Sometimes
8282
things get stuck, an `invoke clean` can help.
8383
84+
## Linting
85+
86+
Call `invoke linters` to run linters without also running tests.
87+
8488
## Documentation
8589
8690
If relevant, update the code documentation, via docstrings, or in `/docs`.

redis/asyncio/client.py

Lines changed: 45 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -387,6 +387,12 @@ def __init__(
387387
# on a set of redis commands
388388
self._single_conn_lock = asyncio.Lock()
389389

390+
# When used as an async context manager, we need to increment and decrement
391+
# a usage counter so that we can close the connection pool when no one is
392+
# using the client.
393+
self._usage_counter = 0
394+
self._usage_lock = asyncio.Lock()
395+
390396
def __repr__(self):
391397
return (
392398
f"<{self.__class__.__module__}.{self.__class__.__name__}"
@@ -594,10 +600,47 @@ def client(self) -> "Redis":
594600
)
595601

596602
async def __aenter__(self: _RedisT) -> _RedisT:
597-
return await self.initialize()
603+
"""
604+
Async context manager entry. Increments a usage counter so that the
605+
connection pool is only closed (via aclose()) when no context is using
606+
the client.
607+
"""
608+
await self._increment_usage()
609+
try:
610+
# Initialize the client (i.e. establish connection, etc.)
611+
return await self.initialize()
612+
except Exception:
613+
# If initialization fails, decrement the counter to keep it in sync
614+
await self._decrement_usage()
615+
raise
616+
617+
async def _increment_usage(self) -> int:
618+
"""
619+
Helper coroutine to increment the usage counter while holding the lock.
620+
Returns the new value of the usage counter.
621+
"""
622+
async with self._usage_lock:
623+
self._usage_counter += 1
624+
return self._usage_counter
625+
626+
async def _decrement_usage(self) -> int:
627+
"""
628+
Helper coroutine to decrement the usage counter while holding the lock.
629+
Returns the new value of the usage counter.
630+
"""
631+
async with self._usage_lock:
632+
self._usage_counter -= 1
633+
return self._usage_counter
598634

599635
async def __aexit__(self, exc_type, exc_value, traceback):
600-
await self.aclose()
636+
"""
637+
Async context manager exit. Decrements a usage counter. If this is the
638+
last exit (counter becomes zero), the client closes its connection pool.
639+
"""
640+
current_usage = await asyncio.shield(self._decrement_usage())
641+
if current_usage == 0:
642+
# This was the last active context, so disconnect the pool.
643+
await asyncio.shield(self.aclose())
601644

602645
_DEL_MESSAGE = "Unclosed Redis client"
603646

redis/asyncio/cluster.py

Lines changed: 46 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -431,6 +431,12 @@ def __init__(
431431
self._initialize = True
432432
self._lock: Optional[asyncio.Lock] = None
433433

434+
# When used as an async context manager, we need to increment and decrement
435+
# a usage counter so that we can close the connection pool when no one is
436+
# using the client.
437+
self._usage_counter = 0
438+
self._usage_lock = asyncio.Lock()
439+
434440
async def initialize(self) -> "RedisCluster":
435441
"""Get all nodes from startup nodes & creates connections if not initialized."""
436442
if self._initialize:
@@ -467,10 +473,47 @@ async def close(self) -> None:
467473
await self.aclose()
468474

469475
async def __aenter__(self) -> "RedisCluster":
470-
return await self.initialize()
476+
"""
477+
Async context manager entry. Increments a usage counter so that the
478+
connection pool is only closed (via aclose()) when no context is using
479+
the client.
480+
"""
481+
await self._increment_usage()
482+
try:
483+
# Initialize the client (i.e. establish connection, etc.)
484+
return await self.initialize()
485+
except Exception:
486+
# If initialization fails, decrement the counter to keep it in sync
487+
await self._decrement_usage()
488+
raise
471489

472-
async def __aexit__(self, exc_type: None, exc_value: None, traceback: None) -> None:
473-
await self.aclose()
490+
async def _increment_usage(self) -> int:
491+
"""
492+
Helper coroutine to increment the usage counter while holding the lock.
493+
Returns the new value of the usage counter.
494+
"""
495+
async with self._usage_lock:
496+
self._usage_counter += 1
497+
return self._usage_counter
498+
499+
async def _decrement_usage(self) -> int:
500+
"""
501+
Helper coroutine to decrement the usage counter while holding the lock.
502+
Returns the new value of the usage counter.
503+
"""
504+
async with self._usage_lock:
505+
self._usage_counter -= 1
506+
return self._usage_counter
507+
508+
async def __aexit__(self, exc_type, exc_value, traceback):
509+
"""
510+
Async context manager exit. Decrements a usage counter. If this is the
511+
last exit (counter becomes zero), the client closes its connection pool.
512+
"""
513+
current_usage = await asyncio.shield(self._decrement_usage())
514+
if current_usage == 0:
515+
# This was the last active context, so disconnect the pool.
516+
await asyncio.shield(self.aclose())
474517

475518
def __await__(self) -> Generator[Any, None, "RedisCluster"]:
476519
return self.initialize().__await__()
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
import asyncio
2+
3+
import pytest
4+
5+
6+
@pytest.mark.asyncio
7+
async def test_usage_counter(r):
8+
async def dummy_task():
9+
async with r:
10+
await asyncio.sleep(0.01)
11+
12+
tasks = [dummy_task() for _ in range(20)]
13+
await asyncio.gather(*tasks)
14+
15+
# After all tasks have completed, the usage counter should be back to zero.
16+
assert r._usage_counter == 0

0 commit comments

Comments
 (0)