Skip to content

Commit e8f46bf

Browse files
committed
implement in sync client
1 parent 4b8f6c3 commit e8f46bf

File tree

11 files changed

+540
-97
lines changed

11 files changed

+540
-97
lines changed

redis/asyncio/client.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -651,6 +651,8 @@ async def execute_command(self, *args, **options):
651651
finally:
652652
if not self.connection:
653653
await pool.release(conn)
654+
if "ITER" in command_name.upper():
655+
pool.cleanup_scan(iter_req_id=options.get("_iter_req_id", None))
654656

655657
async def parse_response(
656658
self, connection: Connection, command_name: Union[str, bytes], **options

redis/asyncio/connection.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@
3838

3939
from redis.asyncio.retry import Retry
4040
from redis.backoff import NoBackoff
41-
from redis.connection import DEFAULT_RESP_VERSION
41+
from redis.connection import DEFAULT_RESP_VERSION, ConnectionsIndexer
4242
from redis.credentials import CredentialProvider, UsernamePasswordCredentialProvider
4343
from redis.exceptions import (
4444
AuthenticationError,
@@ -1057,6 +1057,13 @@ class ConnectionPool:
10571057
``connection_class``.
10581058
"""
10591059

1060+
@abstractmethod
1061+
def cleanup_scan(self, **options):
1062+
"""
1063+
Additional cleanup operations that the connection pool might
1064+
need to do after a SCAN ITER family command is executed
1065+
"""
1066+
10601067
@classmethod
10611068
def from_url(cls: Type[_CP], url: str, **kwargs) -> _CP:
10621069
"""
@@ -1118,7 +1125,7 @@ def __init__(
11181125
self.connection_kwargs = connection_kwargs
11191126
self.max_connections = max_connections
11201127

1121-
self._available_connections: List[AbstractConnection] = []
1128+
self._available_connections: ConnectionsIndexer = ConnectionsIndexer()
11221129
self._in_use_connections: Set[AbstractConnection] = set()
11231130
self.encoder_class = self.connection_kwargs.get("encoder_class", Encoder)
11241131

@@ -1324,3 +1331,6 @@ async def release(self, connection: AbstractConnection):
13241331
async with self._condition:
13251332
await super().release(connection)
13261333
self._condition.notify()
1334+
1335+
def cleanup_scan(self, **options):
1336+
pass

redis/asyncio/sentinel.py

Lines changed: 18 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,7 @@ def __init__(self, service_name, sentinel_manager, **kwargs):
147147
self.sentinel_manager = sentinel_manager
148148
self.master_address = None
149149
self.slave_rr_counter = None
150-
self._request_id_to_replica_address = {}
150+
self._iter_req_id_to_replica_address = {}
151151

152152
def __repr__(self):
153153
return (
@@ -193,6 +193,14 @@ async def rotate_slaves(self) -> AsyncIterator:
193193
pass
194194
raise SlaveNotFoundError(f"No slave found for {self.service_name!r}")
195195

196+
def cleanup_scan(self, **options):
197+
"""
198+
Remove the SCAN ITER family command's request id from the dictionary
199+
"""
200+
self._iter_req_id_to_replica_address.pop(
201+
options.get("_iter_req_id", None), None
202+
)
203+
196204
async def get_connection(
197205
self, command_name: str, *keys: Any, **options: Any
198206
) -> SentinelManagedConnection:
@@ -217,7 +225,7 @@ async def get_connection(
217225
(
218226
server_host,
219227
server_port,
220-
) = self._request_id_to_replica_address.get(iter_req_id, (None, None))
228+
) = self._iter_req_id_to_replica_address.get(iter_req_id, (None, None))
221229
connection = None
222230
# If this is the first scan request of the iter command,
223231
# get a connection from the pool
@@ -228,18 +236,12 @@ async def get_connection(
228236
connection = self.make_connection()
229237
# If this is not the first scan request of the iter command
230238
else:
231-
# Check from the available connections, if any of the connection
232-
# is connected to the host and port that we want
233-
for available_connection in self._available_connections.copy():
234-
# if yes, use that connection
235-
if (
236-
available_connection.host == server_host
237-
and available_connection.port == server_port
238-
):
239-
self._available_connections.remove(available_connection)
240-
connection = available_connection
241-
# If not, make a new dummy connection object, and set its host and port
242-
# to the one that we want later in the call to ``connect_to_address``
239+
# Get the connection that has the same host and port
240+
connection = self._available_connections.get_connection(
241+
host=server_host, port=server_port
242+
)
243+
# If not, make a new dummy connection object, and set its host and
244+
# port to the one that we want later in the call to ``connect_to_address``
243245
if not connection:
244246
connection = self.make_connection()
245247
assert connection
@@ -255,25 +257,14 @@ async def get_connection(
255257
# This will connect to the host and port of the replica
256258
else:
257259
await connection.connect_to_address(server_host, server_port)
258-
# Connections that the pool provides should be ready to send
259-
# a command. If not, the connection was either returned to the
260-
# pool before all data has been read or the socket has been
261-
# closed. Either way, reconnect and verify everything is good.
262-
try:
263-
if await connection.can_read_destructive():
264-
raise ConnectionError("Connection has data") from None
265-
except (ConnectionError, OSError):
266-
await connection.disconnect()
267-
await connection.connect()
268-
if await connection.can_read_destructive():
269-
raise ConnectionError("Connection not ready") from None
260+
self.ensure_connection(connection)
270261
except BaseException:
271262
# Release the connection back to the pool so that we don't
272263
# leak it
273264
await self.release(connection)
274265
raise
275266
# Store the connection to the dictionary
276-
self._request_id_to_replica_address[iter_req_id] = (
267+
self._iter_req_id_to_replica_address[iter_req_id] = (
277268
connection.host,
278269
connection.port,
279270
)

redis/commands/core.py

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3041,9 +3041,15 @@ def scan_iter(
30413041
Additionally, Redis modules can expose other types as well.
30423042
"""
30433043
cursor = "0"
3044+
iter_req_id = uuid.uuid4()
30443045
while cursor != 0:
30453046
cursor, data = self.scan(
3046-
cursor=cursor, match=match, count=count, _type=_type, **kwargs
3047+
cursor=cursor,
3048+
match=match,
3049+
count=count,
3050+
_type=_type,
3051+
_iter_req_id=iter_req_id,
3052+
**kwargs,
30473053
)
30483054
yield from data
30493055

@@ -3087,8 +3093,11 @@ def sscan_iter(
30873093
``count`` allows for hint the minimum number of returns
30883094
"""
30893095
cursor = "0"
3096+
iter_req_id = uuid.uuid4()
30903097
while cursor != 0:
3091-
cursor, data = self.sscan(name, cursor=cursor, match=match, count=count)
3098+
cursor, data = self.sscan(
3099+
name, cursor=cursor, match=match, count=count, _iter_req_id=iter_req_id
3100+
)
30923101
yield from data
30933102

30943103
def hscan(
@@ -3139,9 +3148,15 @@ def hscan_iter(
31393148
``no_values`` indicates to return only the keys, without values
31403149
"""
31413150
cursor = "0"
3151+
iter_req_id = uuid.uuid4()
31423152
while cursor != 0:
31433153
cursor, data = self.hscan(
3144-
name, cursor=cursor, match=match, count=count, no_values=no_values
3154+
name,
3155+
cursor=cursor,
3156+
match=match,
3157+
count=count,
3158+
no_values=no_values,
3159+
_iter_req_id=iter_req_id,
31453160
)
31463161
if no_values:
31473162
yield from data
@@ -3195,13 +3210,15 @@ def zscan_iter(
31953210
``score_cast_func`` a callable used to cast the score return value
31963211
"""
31973212
cursor = "0"
3213+
iter_req_id = uuid.uuid4()
31983214
while cursor != 0:
31993215
cursor, data = self.zscan(
32003216
name,
32013217
cursor=cursor,
32023218
match=match,
32033219
count=count,
32043220
score_cast_func=score_cast_func,
3221+
_iter_req_id=iter_req_id,
32053222
)
32063223
yield from data
32073224

redis/connection.py

Lines changed: 61 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,11 @@
66
import threading
77
import weakref
88
from abc import abstractmethod
9+
from collections import defaultdict
910
from itertools import chain
1011
from queue import Empty, Full, LifoQueue
1112
from time import time
12-
from typing import Any, Callable, List, Optional, Sequence, Type, Union
13+
from typing import Any, Callable, Iterable, List, Optional, Sequence, Type, Union
1314
from urllib.parse import parse_qs, unquote, urlparse
1415

1516
from ._cache import (
@@ -734,6 +735,43 @@ def _host_error(self):
734735
return f"{self.host}:{self.port}"
735736

736737

738+
class ConnectionsIndexer(Iterable):
739+
"""
740+
Data structure that manages a list of available connections which
741+
is also indexed based on the address (ip and port) of the connection
742+
"""
743+
744+
def __init__(self):
745+
self._connections = []
746+
# Map the address to a dictionary of connections
747+
# The inner dictionary is a map between the object id to the object itself
748+
# This is to support O(1) operations on all of the class' methods
749+
self._address_to_connections = defaultdict(dict)
750+
751+
def pop(self):
752+
connection = self._connections.pop()
753+
del self._address_to_connections[(connection.host, connection.port)][
754+
id(connection)
755+
]
756+
return connection
757+
758+
def append(self, connection: Connection):
759+
self._connections.append(connection)
760+
self._address_to_connections[(connection.host, connection.port)][
761+
id(connection)
762+
] = connection
763+
764+
def get_connection(self, host: str, port: int):
765+
try:
766+
connection = self._address_to_connections[(host, port)].popitem()
767+
except KeyError:
768+
return None
769+
return connection
770+
771+
def __iter__(self):
772+
return iter(self._connections)
773+
774+
737775
class SSLConnection(Connection):
738776
"""Manages SSL connections to and from the Redis server(s).
739777
This class extends the Connection class, adding SSL functionality, and making
@@ -1107,7 +1145,7 @@ def __repr__(self) -> (str, str):
11071145
def reset(self) -> None:
11081146
self._lock = threading.Lock()
11091147
self._created_connections = 0
1110-
self._available_connections = []
1148+
self._available_connections = ConnectionsIndexer()
11111149
self._in_use_connections = set()
11121150

11131151
# this must be the last operation in this method. while reset() is
@@ -1168,6 +1206,25 @@ def _checkpid(self) -> None:
11681206
finally:
11691207
self._fork_lock.release()
11701208

1209+
def ensure_connection(self, connection: AbstractConnection):
1210+
# ensure this connection is connected to Redis
1211+
connection.connect()
1212+
# if client caching is not enabled connections that the pool
1213+
# provides should be ready to send a command.
1214+
# if not, the connection was either returned to the
1215+
# pool before all data has been read or the socket has been
1216+
# closed. either way, reconnect and verify everything is good.
1217+
# (if caching enabled the connection will not always be ready
1218+
# to send a command because it may contain invalidation messages)
1219+
try:
1220+
if connection.can_read() and connection.client_cache is None:
1221+
raise ConnectionError("Connection has data")
1222+
except (ConnectionError, OSError):
1223+
connection.disconnect()
1224+
connection.connect()
1225+
if connection.can_read():
1226+
raise ConnectionError("Connection not ready")
1227+
11711228
def get_connection(self, command_name: str, *keys, **options) -> "Connection":
11721229
"Get a connection from the pool"
11731230
self._checkpid()
@@ -1179,23 +1236,7 @@ def get_connection(self, command_name: str, *keys, **options) -> "Connection":
11791236
self._in_use_connections.add(connection)
11801237

11811238
try:
1182-
# ensure this connection is connected to Redis
1183-
connection.connect()
1184-
# if client caching is not enabled connections that the pool
1185-
# provides should be ready to send a command.
1186-
# if not, the connection was either returned to the
1187-
# pool before all data has been read or the socket has been
1188-
# closed. either way, reconnect and verify everything is good.
1189-
# (if caching enabled the connection will not always be ready
1190-
# to send a command because it may contain invalidation messages)
1191-
try:
1192-
if connection.can_read() and connection.client_cache is None:
1193-
raise ConnectionError("Connection has data")
1194-
except (ConnectionError, OSError):
1195-
connection.disconnect()
1196-
connection.connect()
1197-
if connection.can_read():
1198-
raise ConnectionError("Connection not ready")
1239+
self.ensure_connection(connection)
11991240
except BaseException:
12001241
# release the connection back to the pool so that we don't
12011242
# leak it
@@ -1408,20 +1449,7 @@ def get_connection(self, command_name, *keys, **options):
14081449
connection = self.make_connection()
14091450

14101451
try:
1411-
# ensure this connection is connected to Redis
1412-
connection.connect()
1413-
# connections that the pool provides should be ready to send
1414-
# a command. if not, the connection was either returned to the
1415-
# pool before all data has been read or the socket has been
1416-
# closed. either way, reconnect and verify everything is good.
1417-
try:
1418-
if connection.can_read():
1419-
raise ConnectionError("Connection has data")
1420-
except (ConnectionError, OSError):
1421-
connection.disconnect()
1422-
connection.connect()
1423-
if connection.can_read():
1424-
raise ConnectionError("Connection not ready")
1452+
self.ensure_connection(connection)
14251453
except BaseException:
14261454
# release the connection back to the pool so that we don't leak it
14271455
self.release(connection)

0 commit comments

Comments
 (0)