Skip to content

Commit 9fb0376

Browse files
committed
fix scan iter command issued to different replicas
1 parent 2ffcac3 commit 9fb0376

File tree

2 files changed

+142
-6
lines changed

2 files changed

+142
-6
lines changed

redis/asyncio/sentinel.py

Lines changed: 120 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
import asyncio
22
import random
33
import weakref
4-
from typing import AsyncIterator, Iterable, Mapping, Optional, Sequence, Tuple, Type
4+
import uuid
5+
from typing import AsyncIterator, Iterable, Mapping, Optional, Sequence, Tuple, Type, Any
56

67
from redis.asyncio.client import Redis
78
from redis.asyncio.connection import (
@@ -65,6 +66,22 @@ async def connect(self):
6566
self._connect_retry,
6667
lambda error: asyncio.sleep(0),
6768
)
69+
70+
async def _connect_to_address_retry(self, host: str, port: int) -> None:
71+
if self._reader:
72+
return # already connected
73+
try:
74+
return await self.connect_to((host, port))
75+
except ConnectionError as exc:
76+
raise SlaveNotFoundError
77+
78+
async def connect_to_address(self, host: str, port: int) -> None:
79+
# Connect to the specified host and port
80+
# instead of connecting to the master / rotated slaves
81+
return await self.retry.call_with_retry(
82+
lambda: self._connect_to_address_retry(host, port),
83+
lambda error: asyncio.sleep(0),
84+
)
6885

6986
async def read_response(
7087
self,
@@ -122,6 +139,7 @@ def __init__(self, service_name, sentinel_manager, **kwargs):
122139
self.sentinel_manager = sentinel_manager
123140
self.master_address = None
124141
self.slave_rr_counter = None
142+
self._request_id_to_replica_address = {}
125143

126144
def __repr__(self):
127145
return (
@@ -152,6 +170,11 @@ async def get_master_address(self):
152170

153171
async def rotate_slaves(self) -> AsyncIterator:
154172
"""Round-robin slave balancer"""
173+
(
174+
server_host,
175+
server_port,
176+
) = self._request_id_to_replica_address.get(iter_req_id, (None, None))
177+
155178
slaves = await self.sentinel_manager.discover_slaves(self.service_name)
156179
if slaves:
157180
if self.slave_rr_counter is None:
@@ -167,6 +190,102 @@ async def rotate_slaves(self) -> AsyncIterator:
167190
pass
168191
raise SlaveNotFoundError(f"No slave found for {self.service_name!r}")
169192

193+
async def get_connection(
194+
self, command_name: str, *keys: Any, **options: Any
195+
) -> SentinelManagedConnection:
196+
"""
197+
Get a connection from the pool.
198+
`xxx_scan_iter` commands needs to be handled specially.
199+
If the client is created using a connection pool, in replica mode,
200+
all `scan` command-equivalent of the `xxx_scan_iter` commands needs
201+
to be issued to the same Redis replica.
202+
203+
The way each server positions each key is different with one another,
204+
and the cursor acts as the 'offset' of the scan.
205+
Hence, all scans coming from a single xxx_scan_iter_channel command
206+
should go to the same replica.
207+
"""
208+
# If not an iter command or in master mode, call super()
209+
# No custom logic for master, because there's only 1 master.
210+
# The bug is only when Redis has the possibility to connect to multiple replicas
211+
if not (iter_req_id := options.get("_iter_req_id", None)) or self.is_master:
212+
return await super().get_connection(command_name, *keys, **options) # type: ignore[no-any-return]
213+
214+
# Check if this iter request has already been directed to a particular server
215+
# Check if this iter request has already been directed to a particular server
216+
(
217+
server_host,
218+
server_port,
219+
) = self._request_id_to_replica_address.get(iter_req_id, (None, None))
220+
connection = None
221+
# If this is the first scan request of the iter command,
222+
# get a connection from the pool
223+
if server_host is None or server_port is None:
224+
try:
225+
connection = self._available_connections.pop() # type: ignore [assignment]
226+
except IndexError:
227+
connection = self.make_connection()
228+
# If this is not the first scan request of the iter command
229+
else:
230+
# Check from the available connections, if any of the connection
231+
# is connected to the host and port that we want
232+
# If yes, use that connection
233+
for available_connection in self._available_connections.copy():
234+
if (
235+
available_connection.host == server_host
236+
and available_connection.port == server_port
237+
):
238+
self._available_connections.remove(available_connection)
239+
connection = available_connection # type: ignore[assignment]
240+
# If not, make a new dummy connection object, and set its host and port
241+
# to the one that we want later in the call to ``connect_to_address``
242+
if not connection:
243+
connection = self.make_connection()
244+
assert connection
245+
self._in_use_connections.add(connection)
246+
try:
247+
# ensure this connection is connected to Redis
248+
# If this is the first scan request,
249+
# just call the SentinelManagedConnection.connect()
250+
# This will call rotate_slaves
251+
# and connect to a random replica
252+
if server_port is None or server_port is None:
253+
await connection.connect()
254+
# If this is not the first scan request,
255+
# connect to the particular address and port
256+
else:
257+
# This will connect to the host and port that we've specified above
258+
await connection.connect_to_address(server_host, server_port) # type: ignore[arg-type]
259+
# connections that the pool provides should be ready to send
260+
# a command. if not, the connection was either returned to the
261+
# pool before all data has been read or the socket has been
262+
# closed. either way, reconnect and verify everything is good.
263+
try:
264+
# type ignore below:
265+
# attr Not defined in redis stubs and
266+
# we don't need to create a subclass to help with this single attr
267+
if await connection.can_read_destructive(): # type: ignore[attr-defined]
268+
raise ConnectionError("Connection has data") from None
269+
except (ConnectionError, OSError):
270+
await connection.disconnect()
271+
await connection.connect()
272+
# type ignore below: similar to above
273+
if await connection.can_read_destructive(): # type: ignore[attr-defined]
274+
raise ConnectionError("Connection not ready") from None
275+
except BaseException:
276+
# release the connection back to the pool so that we don't
277+
# leak it
278+
await self.release(connection)
279+
raise
280+
# Store the connection to the dictionary
281+
self._request_id_to_replica_address[iter_req_id] = (
282+
connection.host,
283+
connection.port,
284+
)
285+
286+
return connection
287+
288+
170289

171290
class Sentinel(AsyncSentinelCommands):
172291
"""

redis/commands/core.py

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# from __future__ import annotations
22

3+
import uuid
34
import datetime
45
import hashlib
56
import warnings
@@ -3050,6 +3051,7 @@ def sscan(
30503051
cursor: int = 0,
30513052
match: Union[PatternT, None] = None,
30523053
count: Union[int, None] = None,
3054+
**kwargs,
30533055
) -> ResponseT:
30543056
"""
30553057
Incrementally return lists of elements in a set. Also return a cursor
@@ -3066,7 +3068,7 @@ def sscan(
30663068
pieces.extend([b"MATCH", match])
30673069
if count is not None:
30683070
pieces.extend([b"COUNT", count])
3069-
return self.execute_command("SSCAN", *pieces)
3071+
return self.execute_command("SSCAN", *pieces,)
30703072

30713073
def sscan_iter(
30723074
self,
@@ -3094,6 +3096,7 @@ def hscan(
30943096
match: Union[PatternT, None] = None,
30953097
count: Union[int, None] = None,
30963098
no_values: Union[bool, None] = None,
3099+
**kwargs,
30973100
) -> ResponseT:
30983101
"""
30993102
Incrementally return key/value slices in a hash. Also return a cursor
@@ -3114,7 +3117,7 @@ def hscan(
31143117
pieces.extend([b"COUNT", count])
31153118
if no_values is not None:
31163119
pieces.extend([b"NOVALUES"])
3117-
return self.execute_command("HSCAN", *pieces, no_values=no_values)
3120+
return self.execute_command("HSCAN", *pieces, no_values=no_values, **kwargs)
31183121

31193122
def hscan_iter(
31203123
self,
@@ -3150,6 +3153,7 @@ def zscan(
31503153
match: Union[PatternT, None] = None,
31513154
count: Union[int, None] = None,
31523155
score_cast_func: Union[type, Callable] = float,
3156+
**kwargs,
31533157
) -> ResponseT:
31543158
"""
31553159
Incrementally return lists of elements in a sorted set. Also return a
@@ -3169,7 +3173,7 @@ def zscan(
31693173
if count is not None:
31703174
pieces.extend([b"COUNT", count])
31713175
options = {"score_cast_func": score_cast_func}
3172-
return self.execute_command("ZSCAN", *pieces, **options)
3176+
return self.execute_command("ZSCAN", *pieces, **options, **kwargs)
31733177

31743178
def zscan_iter(
31753179
self,
@@ -3222,10 +3226,12 @@ async def scan_iter(
32223226
HASH, LIST, SET, STREAM, STRING, ZSET
32233227
Additionally, Redis modules can expose other types as well.
32243228
"""
3229+
# DO NOT inline this statement to the scan call
3230+
iter_req_id = uuid.uuid4() # each iter command should have an ID to maintain connection to the same replica
32253231
cursor = "0"
32263232
while cursor != 0:
32273233
cursor, data = await self.scan(
3228-
cursor=cursor, match=match, count=count, _type=_type, **kwargs
3234+
cursor=cursor, match=match, count=count, _type=_type, _iter_req_id=iter_req_id, **kwargs
32293235
)
32303236
for d in data:
32313237
yield d
@@ -3244,10 +3250,12 @@ async def sscan_iter(
32443250
32453251
``count`` allows for hint the minimum number of returns
32463252
"""
3253+
# DO NOT inline this statement to the scan call
3254+
iter_req_id = uuid.uuid4() # each iter command should have an ID to maintain connection to the same replica
32473255
cursor = "0"
32483256
while cursor != 0:
32493257
cursor, data = await self.sscan(
3250-
name, cursor=cursor, match=match, count=count
3258+
name, cursor=cursor, match=match, count=count, _iter_req_id=iter_req_id
32513259
)
32523260
for d in data:
32533261
yield d
@@ -3269,10 +3277,16 @@ async def hscan_iter(
32693277
32703278
``no_values`` indicates to return only the keys, without values
32713279
"""
3280+
# DO NOT inline this statement to the scan call
3281+
iter_req_id = uuid.uuid4() # each iter command should have an ID to maintain connection to the same replica
32723282
cursor = "0"
32733283
while cursor != 0:
32743284
cursor, data = await self.hscan(
3285+
<<<<<<< HEAD
32753286
name, cursor=cursor, match=match, count=count, no_values=no_values
3287+
=======
3288+
name, cursor=cursor, match=match, count=count, _iter_req_id=iter_req_id
3289+
>>>>>>> 39aaec6 (fix scan iter command issued to different replicas)
32763290
)
32773291
if no_values:
32783292
for it in data:
@@ -3298,6 +3312,8 @@ async def zscan_iter(
32983312
32993313
``score_cast_func`` a callable used to cast the score return value
33003314
"""
3315+
# DO NOT inline this statement to the scan call
3316+
iter_req_id = uuid.uuid4() # each iter command should have an ID to maintain connection to the same replica
33013317
cursor = "0"
33023318
while cursor != 0:
33033319
cursor, data = await self.zscan(
@@ -3306,6 +3322,7 @@ async def zscan_iter(
33063322
match=match,
33073323
count=count,
33083324
score_cast_func=score_cast_func,
3325+
_iter_req_id=iter_req_id
33093326
)
33103327
for d in data:
33113328
yield d

0 commit comments

Comments
 (0)