Skip to content

Commit 698176d

Browse files
committed
fix scan iter command issued to different replicas
1 parent cd92428 commit 698176d

File tree

2 files changed

+142
-6
lines changed

2 files changed

+142
-6
lines changed

redis/asyncio/sentinel.py

Lines changed: 120 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
import asyncio
22
import random
33
import weakref
4-
from typing import AsyncIterator, Iterable, Mapping, Optional, Sequence, Tuple, Type
4+
import uuid
5+
from typing import AsyncIterator, Iterable, Mapping, Optional, Sequence, Tuple, Type, Any
56

67
from redis.asyncio.client import Redis
78
from redis.asyncio.connection import (
@@ -65,6 +66,22 @@ async def connect(self):
6566
self._connect_retry,
6667
lambda error: asyncio.sleep(0),
6768
)
69+
70+
async def _connect_to_address_retry(self, host: str, port: int) -> None:
71+
if self._reader:
72+
return # already connected
73+
try:
74+
return await self.connect_to((host, port))
75+
except ConnectionError as exc:
76+
raise SlaveNotFoundError
77+
78+
async def connect_to_address(self, host: str, port: int) -> None:
79+
# Connect to the specified host and port
80+
# instead of connecting to the master / rotated slaves
81+
return await self.retry.call_with_retry(
82+
lambda: self._connect_to_address_retry(host, port),
83+
lambda error: asyncio.sleep(0),
84+
)
6885

6986
async def read_response(
7087
self,
@@ -122,6 +139,7 @@ def __init__(self, service_name, sentinel_manager, **kwargs):
122139
self.sentinel_manager = sentinel_manager
123140
self.master_address = None
124141
self.slave_rr_counter = None
142+
self._request_id_to_replica_address = {}
125143

126144
def __repr__(self):
127145
return (
@@ -152,6 +170,11 @@ async def get_master_address(self):
152170

153171
async def rotate_slaves(self) -> AsyncIterator:
154172
"""Round-robin slave balancer"""
173+
(
174+
server_host,
175+
server_port,
176+
) = self._request_id_to_replica_address.get(iter_req_id, (None, None))
177+
155178
slaves = await self.sentinel_manager.discover_slaves(self.service_name)
156179
if slaves:
157180
if self.slave_rr_counter is None:
@@ -167,6 +190,102 @@ async def rotate_slaves(self) -> AsyncIterator:
167190
pass
168191
raise SlaveNotFoundError(f"No slave found for {self.service_name!r}")
169192

193+
async def get_connection(
194+
self, command_name: str, *keys: Any, **options: Any
195+
) -> SentinelManagedConnection:
196+
"""
197+
Get a connection from the pool.
198+
`xxx_scan_iter` commands needs to be handled specially.
199+
If the client is created using a connection pool, in replica mode,
200+
all `scan` command-equivalent of the `xxx_scan_iter` commands needs
201+
to be issued to the same Redis replica.
202+
203+
The way each server positions each key is different with one another,
204+
and the cursor acts as the 'offset' of the scan.
205+
Hence, all scans coming from a single xxx_scan_iter_channel command
206+
should go to the same replica.
207+
"""
208+
# If not an iter command or in master mode, call super()
209+
# No custom logic for master, because there's only 1 master.
210+
# The bug is only when Redis has the possibility to connect to multiple replicas
211+
if not (iter_req_id := options.get("_iter_req_id", None)) or self.is_master:
212+
return await super().get_connection(command_name, *keys, **options) # type: ignore[no-any-return]
213+
214+
# Check if this iter request has already been directed to a particular server
215+
# Check if this iter request has already been directed to a particular server
216+
(
217+
server_host,
218+
server_port,
219+
) = self._request_id_to_replica_address.get(iter_req_id, (None, None))
220+
connection = None
221+
# If this is the first scan request of the iter command,
222+
# get a connection from the pool
223+
if server_host is None or server_port is None:
224+
try:
225+
connection = self._available_connections.pop() # type: ignore [assignment]
226+
except IndexError:
227+
connection = self.make_connection()
228+
# If this is not the first scan request of the iter command
229+
else:
230+
# Check from the available connections, if any of the connection
231+
# is connected to the host and port that we want
232+
# If yes, use that connection
233+
for available_connection in self._available_connections.copy():
234+
if (
235+
available_connection.host == server_host
236+
and available_connection.port == server_port
237+
):
238+
self._available_connections.remove(available_connection)
239+
connection = available_connection # type: ignore[assignment]
240+
# If not, make a new dummy connection object, and set its host and port
241+
# to the one that we want later in the call to ``connect_to_address``
242+
if not connection:
243+
connection = self.make_connection()
244+
assert connection
245+
self._in_use_connections.add(connection)
246+
try:
247+
# ensure this connection is connected to Redis
248+
# If this is the first scan request,
249+
# just call the SentinelManagedConnection.connect()
250+
# This will call rotate_slaves
251+
# and connect to a random replica
252+
if server_port is None or server_port is None:
253+
await connection.connect()
254+
# If this is not the first scan request,
255+
# connect to the particular address and port
256+
else:
257+
# This will connect to the host and port that we've specified above
258+
await connection.connect_to_address(server_host, server_port) # type: ignore[arg-type]
259+
# connections that the pool provides should be ready to send
260+
# a command. if not, the connection was either returned to the
261+
# pool before all data has been read or the socket has been
262+
# closed. either way, reconnect and verify everything is good.
263+
try:
264+
# type ignore below:
265+
# attr Not defined in redis stubs and
266+
# we don't need to create a subclass to help with this single attr
267+
if await connection.can_read_destructive(): # type: ignore[attr-defined]
268+
raise ConnectionError("Connection has data") from None
269+
except (ConnectionError, OSError):
270+
await connection.disconnect()
271+
await connection.connect()
272+
# type ignore below: similar to above
273+
if await connection.can_read_destructive(): # type: ignore[attr-defined]
274+
raise ConnectionError("Connection not ready") from None
275+
except BaseException:
276+
# release the connection back to the pool so that we don't
277+
# leak it
278+
await self.release(connection)
279+
raise
280+
# Store the connection to the dictionary
281+
self._request_id_to_replica_address[iter_req_id] = (
282+
connection.host,
283+
connection.port,
284+
)
285+
286+
return connection
287+
288+
170289

171290
class Sentinel(AsyncSentinelCommands):
172291
"""

redis/commands/core.py

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# from __future__ import annotations
22

3+
import uuid
34
import datetime
45
import hashlib
56
import warnings
@@ -3063,6 +3064,7 @@ def sscan(
30633064
cursor: int = 0,
30643065
match: Union[PatternT, None] = None,
30653066
count: Union[int, None] = None,
3067+
**kwargs,
30663068
) -> ResponseT:
30673069
"""
30683070
Incrementally return lists of elements in a set. Also return a cursor
@@ -3079,7 +3081,7 @@ def sscan(
30793081
pieces.extend([b"MATCH", match])
30803082
if count is not None:
30813083
pieces.extend([b"COUNT", count])
3082-
return self.execute_command("SSCAN", *pieces)
3084+
return self.execute_command("SSCAN", *pieces,)
30833085

30843086
def sscan_iter(
30853087
self,
@@ -3107,6 +3109,7 @@ def hscan(
31073109
match: Union[PatternT, None] = None,
31083110
count: Union[int, None] = None,
31093111
no_values: Union[bool, None] = None,
3112+
**kwargs,
31103113
) -> ResponseT:
31113114
"""
31123115
Incrementally return key/value slices in a hash. Also return a cursor
@@ -3127,7 +3130,7 @@ def hscan(
31273130
pieces.extend([b"COUNT", count])
31283131
if no_values is not None:
31293132
pieces.extend([b"NOVALUES"])
3130-
return self.execute_command("HSCAN", *pieces, no_values=no_values)
3133+
return self.execute_command("HSCAN", *pieces, no_values=no_values, **kwargs)
31313134

31323135
def hscan_iter(
31333136
self,
@@ -3163,6 +3166,7 @@ def zscan(
31633166
match: Union[PatternT, None] = None,
31643167
count: Union[int, None] = None,
31653168
score_cast_func: Union[type, Callable] = float,
3169+
**kwargs,
31663170
) -> ResponseT:
31673171
"""
31683172
Incrementally return lists of elements in a sorted set. Also return a
@@ -3182,7 +3186,7 @@ def zscan(
31823186
if count is not None:
31833187
pieces.extend([b"COUNT", count])
31843188
options = {"score_cast_func": score_cast_func}
3185-
return self.execute_command("ZSCAN", *pieces, **options)
3189+
return self.execute_command("ZSCAN", *pieces, **options, **kwargs)
31863190

31873191
def zscan_iter(
31883192
self,
@@ -3235,10 +3239,12 @@ async def scan_iter(
32353239
HASH, LIST, SET, STREAM, STRING, ZSET
32363240
Additionally, Redis modules can expose other types as well.
32373241
"""
3242+
# DO NOT inline this statement to the scan call
3243+
iter_req_id = uuid.uuid4() # each iter command should have an ID to maintain connection to the same replica
32383244
cursor = "0"
32393245
while cursor != 0:
32403246
cursor, data = await self.scan(
3241-
cursor=cursor, match=match, count=count, _type=_type, **kwargs
3247+
cursor=cursor, match=match, count=count, _type=_type, _iter_req_id=iter_req_id, **kwargs
32423248
)
32433249
for d in data:
32443250
yield d
@@ -3257,10 +3263,12 @@ async def sscan_iter(
32573263
32583264
``count`` allows for hint the minimum number of returns
32593265
"""
3266+
# DO NOT inline this statement to the scan call
3267+
iter_req_id = uuid.uuid4() # each iter command should have an ID to maintain connection to the same replica
32603268
cursor = "0"
32613269
while cursor != 0:
32623270
cursor, data = await self.sscan(
3263-
name, cursor=cursor, match=match, count=count
3271+
name, cursor=cursor, match=match, count=count, _iter_req_id=iter_req_id
32643272
)
32653273
for d in data:
32663274
yield d
@@ -3282,10 +3290,16 @@ async def hscan_iter(
32823290
32833291
``no_values`` indicates to return only the keys, without values
32843292
"""
3293+
# DO NOT inline this statement to the scan call
3294+
iter_req_id = uuid.uuid4() # each iter command should have an ID to maintain connection to the same replica
32853295
cursor = "0"
32863296
while cursor != 0:
32873297
cursor, data = await self.hscan(
3298+
<<<<<<< HEAD
32883299
name, cursor=cursor, match=match, count=count, no_values=no_values
3300+
=======
3301+
name, cursor=cursor, match=match, count=count, _iter_req_id=iter_req_id
3302+
>>>>>>> 39aaec6 (fix scan iter command issued to different replicas)
32893303
)
32903304
if no_values:
32913305
for it in data:
@@ -3311,6 +3325,8 @@ async def zscan_iter(
33113325
33123326
``score_cast_func`` a callable used to cast the score return value
33133327
"""
3328+
# DO NOT inline this statement to the scan call
3329+
iter_req_id = uuid.uuid4() # each iter command should have an ID to maintain connection to the same replica
33143330
cursor = "0"
33153331
while cursor != 0:
33163332
cursor, data = await self.zscan(
@@ -3319,6 +3335,7 @@ async def zscan_iter(
33193335
match=match,
33203336
count=count,
33213337
score_cast_func=score_cast_func,
3338+
_iter_req_id=iter_req_id
33223339
)
33233340
for d in data:
33243341
yield d

0 commit comments

Comments
 (0)