Skip to content

Commit 4978d50

Browse files
committed
more features for ConnectionsINdexer
1 parent e8f46bf commit 4978d50

File tree

3 files changed

+67
-16
lines changed

3 files changed

+67
-16
lines changed

redis/asyncio/sentinel.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -257,7 +257,7 @@ async def get_connection(
257257
# This will connect to the host and port of the replica
258258
else:
259259
await connection.connect_to_address(server_host, server_port)
260-
self.ensure_connection(connection)
260+
await self.ensure_connection(connection)
261261
except BaseException:
262262
# Release the connection back to the pool so that we don't
263263
# leak it

redis/connection.py

Lines changed: 25 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -737,39 +737,53 @@ def _host_error(self):
737737

738738
class ConnectionsIndexer(Iterable):
739739
"""
740-
Data structure that manages a list of available connections which
741-
is also indexed based on the address (ip and port) of the connection
740+
Data structure that simulates a list of available connections.
741+
Instead of list, we keep 2 additional DS to support O(1) operations
742+
on all of the class' methods.
743+
The first DS is indexed on the connection object's ID.
744+
The second DS is indexed on the address (ip and port) of the connection.
742745
"""
743746

744747
def __init__(self):
745-
self._connections = []
748+
# Map the id to the connection object
749+
self._id_to_connection = {}
746750
# Map the address to a dictionary of connections
747751
# The inner dictionary is a map between the object id to the object itself
748-
# This is to support O(1) operations on all of the class' methods
752+
# Both of these DS support O(1) operations on all of the class' methods
749753
self._address_to_connections = defaultdict(dict)
750754

751755
def pop(self):
752-
connection = self._connections.pop()
753-
del self._address_to_connections[(connection.host, connection.port)][
754-
id(connection)
755-
]
756+
try:
757+
_, connection = self._id_to_connection.popitem()
758+
del self._address_to_connections[(connection.host, connection.port)][
759+
id(connection)
760+
]
761+
except KeyError:
762+
# We are simulating a list, hence we raise IndexError
763+
# when there's no item in the dictionary
764+
raise IndexError()
756765
return connection
757766

758767
def append(self, connection: Connection):
759-
self._connections.append(connection)
768+
self._id_to_connection[id(connection)] = connection
760769
self._address_to_connections[(connection.host, connection.port)][
761770
id(connection)
762771
] = connection
763772

764773
def get_connection(self, host: str, port: int):
765774
try:
766-
connection = self._address_to_connections[(host, port)].popitem()
775+
_, connection = self._address_to_connections[(host, port)].popitem()
776+
del self._id_to_connection[id(connection)]
767777
except KeyError:
768778
return None
769779
return connection
770780

771781
def __iter__(self):
772-
return iter(self._connections)
782+
# This is an O(1) operation in python3.7 and later
783+
return iter(self._id_to_connection.values())
784+
785+
def __len__(self):
786+
return len(self._id_to_connection)
773787

774788

775789
class SSLConnection(Connection):

redis/sentinel.py

Lines changed: 41 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -41,9 +41,16 @@ def connect_to(self, address):
4141
if str_if_bytes(self.read_response()) != "PONG":
4242
raise ConnectionError("PING failed")
4343

44-
def _connect_retry(self):
44+
def _connect_retry(self, same_server: bool = False):
4545
if self._sock:
4646
return # already connected
47+
# If same_server is True, it means that the connection
48+
# is not rotating to the next slave (if the connection pool is not master)
49+
if same_server:
50+
self.connect_to(self.host, self.port)
51+
return
52+
# If same_server is False, connnect to master in master mode
53+
# and rotate to the next slave in slave mode
4754
if self.connection_pool.is_master:
4855
self.connect_to(self.connection_pool.get_master_address())
4956
else:
@@ -54,8 +61,11 @@ def _connect_retry(self):
5461
continue
5562
raise SlaveNotFoundError # Never be here
5663

57-
def connect(self):
58-
return self.retry.call_with_retry(self._connect_retry, lambda error: None)
64+
def connect(self, same_server: bool = False):
65+
return self.retry.call_with_retry(
66+
lambda: self._connect_retry(same_server),
67+
lambda error: None
68+
)
5969

6070
def read_response(
6171
self,
@@ -195,6 +205,29 @@ def rotate_slaves(self):
195205
"Round-robin slave balancer"
196206
return self.proxy.rotate_slaves()
197207

208+
def ensure_connection_connected_to_address(self, connection: SentinelManagedConnection):
209+
"""
210+
Ensure the connection is already connected to the server that this connection
211+
object wants to connect to
212+
213+
Similar to self.ensure_connection, but calling connection.connect()
214+
in SentinelManagedConnection (replica mode) will cause the
215+
connection object to connect to the next replica in rotation,
216+
and we don't wnat behavior. Look at get_connection inline docs for details.
217+
218+
Here, we just try to make sure that the connection is already connected
219+
to the replica we wanted it to.
220+
"""
221+
connection.connect(same_address=True)
222+
try:
223+
if connection.can_read(same_address=True) and connection.client_cache is None:
224+
raise ConnectionError("Connection has data")
225+
except (ConnectionError, OSError):
226+
connection.disconnect()
227+
connection.connect(same_address=True)
228+
if connection.can_read():
229+
raise ConnectionError("Connection has data")
230+
198231
def cleanup_scan(self, **options):
199232
"""
200233
Remove the SCAN ITER family command's request id from the dictionary
@@ -248,6 +281,7 @@ def get_connection(
248281
connection = self.make_connection()
249282
assert connection
250283
self._in_use_connections.add(connection)
284+
breakpoint()
251285
try:
252286
# Ensure this connection is connected to Redis
253287
# If this is the first scan request, it will
@@ -259,7 +293,9 @@ def get_connection(
259293
# This will connect to the host and port of the replica
260294
else:
261295
connection.connect_to_address(server_host, server_port)
262-
self.ensure_connection(connection)
296+
breakpoint()
297+
self.ensure_connection_connected_to_address(connection)
298+
breakpoint()
263299
except BaseException:
264300
# Release the connection back to the pool so that we don't
265301
# leak it
@@ -270,6 +306,7 @@ def get_connection(
270306
connection.host,
271307
connection.port,
272308
)
309+
breakpoint()
273310
return connection
274311

275312

0 commit comments

Comments
 (0)