Skip to content

Commit 39aaec6

Browse files
committed
fix scan iter command issued to different replicas
1 parent 07fc339 commit 39aaec6

File tree

2 files changed

+136
-4
lines changed

2 files changed

+136
-4
lines changed

redis/asyncio/sentinel.py

Lines changed: 120 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
import asyncio
22
import random
33
import weakref
4-
from typing import AsyncIterator, Iterable, Mapping, Optional, Sequence, Tuple, Type
4+
import uuid
5+
from typing import AsyncIterator, Iterable, Mapping, Optional, Sequence, Tuple, Type, Any
56

67
from redis.asyncio.client import Redis
78
from redis.asyncio.connection import (
@@ -65,6 +66,22 @@ async def connect(self):
6566
self._connect_retry,
6667
lambda error: asyncio.sleep(0),
6768
)
69+
70+
async def _connect_to_address_retry(self, host: str, port: int) -> None:
    """
    Make a single attempt to connect this connection to ``host``:``port``.

    Nothing is done when ``self._reader`` is already set, because the
    connection is then treated as established.

    :param host: host of the server to connect to.
    :param port: port of the server to connect to.
    :raises SlaveNotFoundError: when ``connect_to`` raises a
        ``ConnectionError``; the original error is attached as the cause.
    """
    if self._reader:
        return  # already connected
    try:
        return await self.connect_to((host, port))
    except ConnectionError as exc:
        # Keep the failing address and the underlying error visible to
        # whoever ends up handling the exception.
        raise SlaveNotFoundError(
            f"Could not connect to slave at {host}:{port}"
        ) from exc
77+
78+
async def connect_to_address(self, host: str, port: int) -> None:
    """
    Connect to the given ``host`` and ``port`` instead of connecting to
    the master / rotated slaves.

    The attempt is driven through ``self.retry.call_with_retry``; the
    failure callback only yields control via ``asyncio.sleep(0)``.
    """

    def attempt():
        # Returns the coroutine for one connection attempt.
        return self._connect_to_address_retry(host, port)

    def on_failure(error):
        # Returns a coroutine that simply yields to the event loop.
        return asyncio.sleep(0)

    return await self.retry.call_with_retry(attempt, on_failure)
6885

6986
async def read_response(
7087
self,
@@ -122,6 +139,7 @@ def __init__(self, service_name, sentinel_manager, **kwargs):
122139
self.sentinel_manager = sentinel_manager
123140
self.master_address = None
124141
self.slave_rr_counter = None
142+
self._request_id_to_replica_address = {}
125143

126144
def __repr__(self):
127145
return (
@@ -152,6 +170,11 @@ async def get_master_address(self):
152170

153171
async def rotate_slaves(self) -> AsyncIterator:
154172
"""Round-robin slave balancer"""
173+
(
174+
server_host,
175+
server_port,
176+
) = self._request_id_to_replica_address.get(iter_req_id, (None, None))
177+
155178
slaves = await self.sentinel_manager.discover_slaves(self.service_name)
156179
if slaves:
157180
if self.slave_rr_counter is None:
@@ -167,6 +190,102 @@ async def rotate_slaves(self) -> AsyncIterator:
167190
pass
168191
raise SlaveNotFoundError(f"No slave found for {self.service_name!r}")
169192

193+
async def get_connection(
    self, command_name: str, *keys: Any, **options: Any
) -> SentinelManagedConnection:
    """
    Get a connection from the pool.

    ``xxx_scan_iter`` commands need to be handled specially.
    If the client is created using a connection pool, in replica mode,
    every ``scan`` command issued on behalf of one ``xxx_scan_iter`` call
    needs to be sent to the same Redis replica.

    Each server positions its keys differently, and the cursor acts as
    the 'offset' of the scan. Hence, all scans coming from a single
    ``xxx_scan_iter`` call should go to the same replica.

    The iteration is identified by the ``_iter_req_id`` option. The
    address used for its first scan is recorded in
    ``self._request_id_to_replica_address`` and reused for the following
    scans carrying the same id.

    :param command_name: name of the command the connection is needed for.
    :param keys: forwarded unchanged to the parent implementation.
    :param options: forwarded unchanged to the parent implementation;
        ``_iter_req_id`` selects the replica-pinning behaviour.
    :return: a connection that has been connected and checked for
        leftover unread data.
    :raises ConnectionError: when the connection still has unread data
        after being re-established.
    """
    iter_req_id = options.get("_iter_req_id", None)

    # If not an iter command or in master mode, defer to the parent class.
    # No custom logic for master, because there's only 1 master.
    # Pinning only matters when several replicas can be connected to.
    if not iter_req_id or self.is_master:
        return await super().get_connection(command_name, *keys, **options)  # type: ignore[no-any-return]

    # Check if this iter request has already been directed to a
    # particular server.
    server_host, server_port = self._request_id_to_replica_address.get(
        iter_req_id, (None, None)
    )
    is_first_request = server_host is None or server_port is None

    connection = None
    if is_first_request:
        # First scan request of the iter command: any connection will do.
        try:
            connection = self._available_connections.pop()  # type: ignore[assignment]
        except IndexError:
            connection = self.make_connection()
    else:
        # Reuse an idle connection that already targets the pinned
        # address. Stop at the first hit so that other matching
        # connections stay in the pool instead of being dropped from it.
        for available_connection in reversed(self._available_connections):
            if (
                available_connection.host == server_host
                and available_connection.port == server_port
            ):
                connection = available_connection  # type: ignore[assignment]
                break
        if connection is not None:
            self._available_connections.remove(connection)
        else:
            # No idle connection matches: make a fresh connection object;
            # its target is set by ``connect_to_address`` below.
            connection = self.make_connection()

    async def connect_to_target() -> None:
        # First request: ``connect()`` picks the server to use.
        # Later requests: go to the address recorded for this iteration.
        if is_first_request:
            await connection.connect()
        else:
            await connection.connect_to_address(server_host, server_port)  # type: ignore[arg-type]

    self._in_use_connections.add(connection)
    try:
        # ensure this connection is connected to Redis
        await connect_to_target()
        # connections that the pool provides should be ready to send
        # a command. if not, the connection was either returned to the
        # pool before all data has been read or the socket has been
        # closed. either way, reconnect and verify everything is good.
        try:
            if await connection.can_read_destructive():  # type: ignore[attr-defined]
                raise ConnectionError("Connection has data") from None
        except (ConnectionError, OSError):
            await connection.disconnect()
            # Reconnect to the same target, so a pinned iteration is not
            # silently moved to another replica in the middle of a scan.
            await connect_to_target()
            if await connection.can_read_destructive():  # type: ignore[attr-defined]
                raise ConnectionError("Connection not ready") from None
    except BaseException:
        # release the connection back to the pool so that we don't
        # leak it
        await self.release(connection)
        raise

    # Remember which address serves this iteration.
    # NOTE(review): nothing in this method removes entries from
    # ``_request_id_to_replica_address``; confirm they are cleaned up
    # elsewhere once an iteration finishes, otherwise the dict keeps growing.
    self._request_id_to_replica_address[iter_req_id] = (
        connection.host,
        connection.port,
    )

    return connection
287+
288+
170289

171290
class Sentinel(AsyncSentinelCommands):
172291
"""

redis/commands/core.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# from __future__ import annotations
22

3+
import uuid
34
import datetime
45
import hashlib
56
import warnings
@@ -3059,6 +3060,7 @@ def sscan(
30593060
cursor: int = 0,
30603061
match: Union[PatternT, None] = None,
30613062
count: Union[int, None] = None,
3063+
**kwargs,
30623064
) -> ResponseT:
30633065
"""
30643066
Incrementally return lists of elements in a set. Also return a cursor
@@ -3102,6 +3104,7 @@ def hscan(
31023104
cursor: int = 0,
31033105
match: Union[PatternT, None] = None,
31043106
count: Union[int, None] = None,
3107+
**kwargs,
31053108
) -> ResponseT:
31063109
"""
31073110
Incrementally return key/value slices in a hash. Also return a cursor
@@ -3146,6 +3149,7 @@ def zscan(
31463149
match: Union[PatternT, None] = None,
31473150
count: Union[int, None] = None,
31483151
score_cast_func: Union[type, Callable] = float,
3152+
**kwargs,
31493153
) -> ResponseT:
31503154
"""
31513155
Incrementally return lists of elements in a sorted set. Also return a
@@ -3218,10 +3222,12 @@ async def scan_iter(
32183222
HASH, LIST, SET, STREAM, STRING, ZSET
32193223
Additionally, Redis modules can expose other types as well.
32203224
"""
3225+
# DO NOT inline this statement to the scan call
3226+
iter_req_id = uuid.uuid4() # each iter command should have an ID to maintain connection to the same replica
32213227
cursor = "0"
32223228
while cursor != 0:
32233229
cursor, data = await self.scan(
3224-
cursor=cursor, match=match, count=count, _type=_type, **kwargs
3230+
cursor=cursor, match=match, count=count, _type=_type, _iter_req_id=iter_req_id, **kwargs
32253231
)
32263232
for d in data:
32273233
yield d
@@ -3240,10 +3246,12 @@ async def sscan_iter(
32403246
32413247
``count`` allows for hint the minimum number of returns
32423248
"""
3249+
# DO NOT inline this statement to the scan call
3250+
iter_req_id = uuid.uuid4() # each iter command should have an ID to maintain connection to the same replica
32433251
cursor = "0"
32443252
while cursor != 0:
32453253
cursor, data = await self.sscan(
3246-
name, cursor=cursor, match=match, count=count
3254+
name, cursor=cursor, match=match, count=count, _iter_req_id=iter_req_id
32473255
)
32483256
for d in data:
32493257
yield d
@@ -3262,10 +3270,12 @@ async def hscan_iter(
32623270
32633271
``count`` allows for hint the minimum number of returns
32643272
"""
3273+
# DO NOT inline this statement to the scan call
3274+
iter_req_id = uuid.uuid4() # each iter command should have an ID to maintain connection to the same replica
32653275
cursor = "0"
32663276
while cursor != 0:
32673277
cursor, data = await self.hscan(
3268-
name, cursor=cursor, match=match, count=count
3278+
name, cursor=cursor, match=match, count=count, _iter_req_id=iter_req_id
32693279
)
32703280
for it in data.items():
32713281
yield it
@@ -3287,6 +3297,8 @@ async def zscan_iter(
32873297
32883298
``score_cast_func`` a callable used to cast the score return value
32893299
"""
3300+
# DO NOT inline this statement to the scan call
3301+
iter_req_id = uuid.uuid4() # each iter command should have an ID to maintain connection to the same replica
32903302
cursor = "0"
32913303
while cursor != 0:
32923304
cursor, data = await self.zscan(
@@ -3295,6 +3307,7 @@ async def zscan_iter(
32953307
match=match,
32963308
count=count,
32973309
score_cast_func=score_cast_func,
3310+
_iter_req_id=iter_req_id
32983311
)
32993312
for d in data:
33003313
yield d

0 commit comments

Comments
 (0)