Skip to content

Commit 296962c

Browse files
authored
Merge branch 'master' into ckpy312
2 parents 2a81f8a + d3a3ada commit 296962c

File tree

14 files changed

+180
-30
lines changed

14 files changed

+180
-30
lines changed

.github/CODEOWNERS

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
doctests/* @dmaier-redislabs

README.md

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,20 @@ The Python interface to the Redis key-value store.
1717

1818
---------------------------------------------
1919

20+
## How do I Redis?
21+
22+
[Learn for free at Redis University](https://university.redis.com/)
23+
24+
[Build faster with the Redis Launchpad](https://launchpad.redis.com/)
25+
26+
[Try the Redis Cloud](https://redis.com/try-free/)
27+
28+
[Dive in developer tutorials](https://developer.redis.com/)
29+
30+
[Join the Redis community](https://redis.com/community/)
31+
32+
[Work at Redis](https://redis.com/company/careers/jobs/)
33+
2034
## Installation
2135

2236
Start a redis via docker:

redis/asyncio/client.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -546,6 +546,7 @@ def __del__(
546546
_grl().call_exception_handler(context)
547547
except RuntimeError:
548548
pass
549+
self.connection._close()
549550

550551
async def aclose(self, close_connection_pool: Optional[bool] = None) -> None:
551552
"""

redis/asyncio/connection.py

Lines changed: 43 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import socket
66
import ssl
77
import sys
8+
import warnings
89
import weakref
910
from abc import abstractmethod
1011
from itertools import chain
@@ -204,6 +205,24 @@ def __init__(
204205
raise ConnectionError("protocol must be either 2 or 3")
205206
self.protocol = protocol
206207

208+
def __del__(self, _warnings: Any = warnings):
209+
# For some reason, the individual streams don't get properly garbage
210+
# collected and therefore produce no resource warnings. We add one
211+
# here, in the same style as those from the stdlib.
212+
if getattr(self, "_writer", None):
213+
_warnings.warn(
214+
f"unclosed Connection {self!r}", ResourceWarning, source=self
215+
)
216+
self._close()
217+
218+
def _close(self):
219+
"""
220+
Internal method to silently close the connection without waiting
221+
"""
222+
if self._writer:
223+
self._writer.close()
224+
self._writer = self._reader = None
225+
207226
def __repr__(self):
208227
repr_args = ",".join((f"{k}={v}" for k, v in self.repr_pieces()))
209228
return f"{self.__class__.__name__}<{repr_args}>"
@@ -1017,7 +1036,7 @@ def __repr__(self):
10171036

10181037
def reset(self):
10191038
self._available_connections = []
1020-
self._in_use_connections = set()
1039+
self._in_use_connections = weakref.WeakSet()
10211040

10221041
def can_get_connection(self) -> bool:
10231042
"""Return True if a connection can be retrieved from the pool."""
@@ -1027,21 +1046,25 @@ def can_get_connection(self) -> bool:
10271046
)
10281047

10291048
async def get_connection(self, command_name, *keys, **options):
1030-
"""Get a connection from the pool"""
1049+
"""Get a connected connection from the pool"""
1050+
connection = self.get_available_connection()
1051+
try:
1052+
await self.ensure_connection(connection)
1053+
except BaseException:
1054+
await self.release(connection)
1055+
raise
1056+
1057+
return connection
1058+
1059+
def get_available_connection(self):
1060+
"""Get a connection from the pool, without making sure it is connected"""
10311061
try:
10321062
connection = self._available_connections.pop()
10331063
except IndexError:
10341064
if len(self._in_use_connections) >= self.max_connections:
10351065
raise ConnectionError("Too many connections") from None
10361066
connection = self.make_connection()
10371067
self._in_use_connections.add(connection)
1038-
1039-
try:
1040-
await self.ensure_connection(connection)
1041-
except BaseException:
1042-
await self.release(connection)
1043-
raise
1044-
10451068
return connection
10461069

10471070
def get_encoder(self):
@@ -1166,13 +1189,21 @@ def __init__(
11661189
async def get_connection(self, command_name, *keys, **options):
11671190
"""Gets a connection from the pool, blocking until one is available"""
11681191
try:
1169-
async with async_timeout(self.timeout):
1170-
async with self._condition:
1192+
async with self._condition:
1193+
async with async_timeout(self.timeout):
11711194
await self._condition.wait_for(self.can_get_connection)
1172-
return await super().get_connection(command_name, *keys, **options)
1195+
connection = super().get_available_connection()
11731196
except asyncio.TimeoutError as err:
11741197
raise ConnectionError("No connection available.") from err
11751198

1199+
# We now perform the connection check outside of the lock.
1200+
try:
1201+
await self.ensure_connection(connection)
1202+
return connection
1203+
except BaseException:
1204+
await self.release(connection)
1205+
raise
1206+
11761207
async def release(self, connection: AbstractConnection):
11771208
"""Releases the connection back to the pool."""
11781209
async with self._condition:

redis/client.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ class Redis(RedisModuleCommands, CoreCommands, SentinelCommands):
9494
"""
9595

9696
@classmethod
97-
def from_url(cls, url: str, **kwargs) -> None:
97+
def from_url(cls, url: str, **kwargs) -> "Redis":
9898
"""
9999
Return a Redis client object configured from the given URL
100100
@@ -1100,7 +1100,7 @@ def handle_message(self, response, ignore_subscribe_messages=False):
11001100

11011101
def run_in_thread(
11021102
self,
1103-
sleep_time: int = 0,
1103+
sleep_time: float = 0.0,
11041104
daemon: bool = False,
11051105
exception_handler: Optional[Callable] = None,
11061106
) -> "PubSubWorkerThread":

redis/commands/search/field.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ class Field:
1313
SORTABLE = "SORTABLE"
1414
NOINDEX = "NOINDEX"
1515
AS = "AS"
16+
GEOSHAPE = "GEOSHAPE"
1617

1718
def __init__(
1819
self,
@@ -91,6 +92,21 @@ def __init__(self, name: str, **kwargs):
9192
Field.__init__(self, name, args=[Field.NUMERIC], **kwargs)
9293

9394

95+
class GeoShapeField(Field):
96+
"""
97+
GeoShapeField is used to enable within/contain indexing/searching
98+
"""
99+
100+
SPHERICAL = "SPHERICAL"
101+
FLAT = "FLAT"
102+
103+
def __init__(self, name: str, coord_system=None, **kwargs):
104+
args = [Field.GEOSHAPE]
105+
if coord_system:
106+
args.append(coord_system)
107+
Field.__init__(self, name, args=args, **kwargs)
108+
109+
94110
class GeoField(Field):
95111
"""
96112
GeoField is used to define a geo-indexing field in a schema definition

tests/test_asyncio/test_commands.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -370,10 +370,12 @@ async def test_client_setinfo(self, r: redis.Redis):
370370
info = await r2.client_info()
371371
assert info["lib-name"] == "test2"
372372
assert info["lib-ver"] == "1234"
373+
await r2.aclose()
373374
r3 = redis.asyncio.Redis(lib_name=None, lib_version=None)
374375
info = await r3.client_info()
375376
assert info["lib-name"] == ""
376377
assert info["lib-ver"] == ""
378+
await r3.aclose()
377379

378380
@skip_if_server_version_lt("2.6.9")
379381
@pytest.mark.onlynoncluster

tests/test_asyncio/test_connect.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,8 @@ async def _handler(reader, writer):
7373
try:
7474
return await _redis_request_handler(reader, writer, stop_event)
7575
finally:
76+
writer.close()
77+
await writer.wait_closed()
7678
finished.set()
7779

7880
if isinstance(server_address, str):

tests/test_asyncio/test_connection.py

Lines changed: 55 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,8 @@ async def get_conn(_):
8585

8686
assert init_call_count == 1
8787
assert command_call_count == 2
88+
r.connection = None # it was a Mock
89+
await r.aclose()
8890

8991

9092
@skip_if_server_version_lt("4.0.0")
@@ -143,6 +145,7 @@ async def mock_connect():
143145
conn._connect.side_effect = mock_connect
144146
await conn.connect()
145147
assert conn._connect.call_count == 3
148+
await conn.disconnect()
146149

147150

148151
async def test_connect_without_retry_on_os_error():
@@ -194,6 +197,7 @@ async def test_connection_parse_response_resume(r: redis.Redis):
194197
pytest.fail("didn't receive a response")
195198
assert response
196199
assert i > 0
200+
await conn.disconnect()
197201

198202

199203
@pytest.mark.onlynoncluster
@@ -316,7 +320,8 @@ async def mock_aclose(self):
316320
url: str = request.config.getoption("--redis-url")
317321
r1 = await Redis.from_url(url)
318322
with patch.object(r1, "aclose", mock_aclose):
319-
await r1.close()
323+
with pytest.deprecated_call():
324+
await r1.close()
320325
assert calls == 1
321326

322327
with pytest.deprecated_call():
@@ -436,3 +441,52 @@ async def mock_disconnect(_):
436441

437442
assert called == 0
438443
await pool.disconnect()
444+
445+
446+
async def test_client_garbage_collection(request):
447+
"""
448+
Test that a Redis client will call _close() on any
449+
connection that it holds at time of destruction
450+
"""
451+
452+
url: str = request.config.getoption("--redis-url")
453+
pool = ConnectionPool.from_url(url)
454+
455+
# create a client with a connection from the pool
456+
client = Redis(connection_pool=pool, single_connection_client=True)
457+
await client.initialize()
458+
with mock.patch.object(client, "connection") as a:
459+
# we cannot, in unittests, or from asyncio, reliably trigger garbage collection
460+
# so we must just invoke the handler
461+
with pytest.warns(ResourceWarning):
462+
client.__del__()
463+
assert a._close.called
464+
465+
await client.aclose()
466+
await pool.aclose()
467+
468+
469+
async def test_connection_garbage_collection(request):
470+
"""
471+
Test that a Connection object will call close() on the
472+
stream that it holds.
473+
"""
474+
475+
url: str = request.config.getoption("--redis-url")
476+
pool = ConnectionPool.from_url(url)
477+
478+
# create a client with a connection from the pool
479+
client = Redis(connection_pool=pool, single_connection_client=True)
480+
await client.initialize()
481+
conn = client.connection
482+
483+
with mock.patch.object(conn, "_reader"):
484+
with mock.patch.object(conn, "_writer") as a:
485+
# we cannot, in unittests, or from asyncio, reliably trigger
486+
# garbage collection so we must just invoke the handler
487+
with pytest.warns(ResourceWarning):
488+
conn.__del__()
489+
assert a.close.called
490+
491+
await client.aclose()
492+
await pool.aclose()

tests/test_asyncio/test_cwe_404.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@ def __init__(self, addr, redis_addr, delay: float = 0.0):
1515
self.send_event = asyncio.Event()
1616
self.server = None
1717
self.task = None
18+
self.cond = asyncio.Condition()
19+
self.running = 0
1820

1921
async def __aenter__(self):
2022
await self.start()
@@ -63,10 +65,10 @@ async def stop(self):
6365
except asyncio.CancelledError:
6466
pass
6567
await self.server.wait_closed()
66-
# do we need to close individual connections too?
67-
# prudently close all async generators
68-
loop = self.server.get_loop()
69-
await loop.shutdown_asyncgens()
68+
# Server does not wait for all spawned tasks. We must do that also to ensure
69+
# that all sockets are closed.
70+
async with self.cond:
71+
await self.cond.wait_for(lambda: self.running == 0)
7072

7173
async def pipe(
7274
self,
@@ -75,6 +77,7 @@ async def pipe(
7577
name="",
7678
event: asyncio.Event = None,
7779
):
80+
self.running += 1
7881
try:
7982
while True:
8083
data = await reader.read(1000)
@@ -94,6 +97,10 @@ async def pipe(
9497
# ignore errors on close pertaining to no event loop. Don't want
9598
# to clutter the test output with errors if being garbage collected
9699
pass
100+
async with self.cond:
101+
self.running -= 1
102+
if self.running == 0:
103+
self.cond.notify_all()
97104

98105

99106
@pytest.mark.onlynoncluster

0 commit comments

Comments
 (0)