Skip to content

Commit 0548e64

Browse files
Add Redis readiness verification (#3555)
1 parent 8f56b52 commit 0548e64

File tree

12 files changed

+467
-134
lines changed

12 files changed

+467
-134
lines changed

redis/asyncio/client.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -230,6 +230,7 @@ def __init__(
230230
encoding: str = "utf-8",
231231
encoding_errors: str = "strict",
232232
decode_responses: bool = False,
233+
check_server_ready: bool = False,
233234
retry_on_timeout: bool = False,
234235
retry: Retry = Retry(
235236
backoff=ExponentialWithJitterBackoff(base=1, cap=10), retries=3
@@ -279,6 +280,10 @@ def __init__(
279280
280281
When 'connection_pool' is provided - the retry configuration of the
281282
provided pool will be used.
283+
284+
Args:
285+
check_server_ready: if `True`, an extra handshake is performed by sending a PING command, since
286+
connect and send operations work even when Redis server is not ready.
282287
"""
283288
kwargs: Dict[str, Any]
284289
if event_dispatcher is None:
@@ -313,6 +318,7 @@ def __init__(
313318
"encoding": encoding,
314319
"encoding_errors": encoding_errors,
315320
"decode_responses": decode_responses,
321+
"check_server_ready": check_server_ready,
316322
"retry_on_error": retry_on_error,
317323
"retry": copy.deepcopy(retry),
318324
"max_connections": max_connections,

redis/asyncio/cluster.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -290,6 +290,7 @@ def __init__(
290290
encoding_errors: str = "strict",
291291
decode_responses: bool = False,
292292
# Connection related kwargs
293+
check_server_ready: bool = False,
293294
health_check_interval: float = 0,
294295
socket_connect_timeout: Optional[float] = None,
295296
socket_keepalive: bool = False,
@@ -345,6 +346,7 @@ def __init__(
345346
"encoding_errors": encoding_errors,
346347
"decode_responses": decode_responses,
347348
# Connection related kwargs
349+
"check_server_ready": check_server_ready,
348350
"health_check_interval": health_check_interval,
349351
"socket_connect_timeout": socket_connect_timeout,
350352
"socket_keepalive": socket_keepalive,

redis/asyncio/connection.py

Lines changed: 35 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,7 @@ def __init__(
149149
encoding_errors: str = "strict",
150150
decode_responses: bool = False,
151151
parser_class: Type[BaseParser] = DefaultParser,
152+
check_server_ready: bool = False,
152153
socket_read_size: int = 65536,
153154
health_check_interval: float = 0,
154155
client_name: Optional[str] = None,
@@ -205,6 +206,7 @@ def __init__(
205206
self.health_check_interval = health_check_interval
206207
self.next_health_check: float = -1
207208
self.encoder = encoder_class(encoding, encoding_errors, decode_responses)
209+
self.check_server_ready = check_server_ready
208210
self.redis_connect_func = redis_connect_func
209211
self._reader: Optional[asyncio.StreamReader] = None
210212
self._writer: Optional[asyncio.StreamWriter] = None
@@ -304,11 +306,13 @@ async def connect_check_health(
304306
try:
305307
if retry_socket_connect:
306308
await self.retry.call_with_retry(
307-
lambda: self._connect(), lambda error: self.disconnect()
309+
lambda: self._connect_check_server_ready(),
310+
lambda error: self.disconnect(),
308311
)
309312
else:
310-
await self._connect()
313+
await self._connect_check_server_ready()
311314
except asyncio.CancelledError:
315+
self._close()
312316
raise # in 3.7 and earlier, this is an Exception, not BaseException
313317
except (socket.timeout, asyncio.TimeoutError):
314318
raise TimeoutError("Timeout connecting to server")
@@ -343,6 +347,33 @@ async def connect_check_health(
343347
if task and inspect.isawaitable(task):
344348
await task
345349

350+
async def _connect_check_server_ready(self):
351+
await self._connect()
352+
353+
# Doing handshake since connect and send operations work even when Redis is not ready
354+
if self.check_server_ready:
355+
try:
356+
await self.send_command("PING", check_health=False)
357+
358+
if self.socket_timeout is not None:
359+
async with async_timeout(self.socket_timeout):
360+
response = str_if_bytes(await self._reader.read(1024))
361+
else:
362+
response = str_if_bytes(await self._reader.read(1024))
363+
364+
if not (response.startswith("+PONG") or response.startswith("-NOAUTH")):
365+
raise ResponseError(f"Invalid PING response: {response}")
366+
except (
367+
socket.timeout,
368+
asyncio.TimeoutError,
369+
ResponseError,
370+
ConnectionResetError,
371+
) as e:
372+
# `socket_keepalive_options` might contain invalid options
373+
# causing an error. Do not leave the connection open.
374+
self._close()
375+
raise ConnectionError(self._error_message(e))
376+
346377
@abstractmethod
347378
async def _connect(self):
348379
pass
@@ -532,8 +563,7 @@ async def send_packed_command(
532563
self._send_packed_command(command), self.socket_timeout
533564
)
534565
else:
535-
self._writer.writelines(command)
536-
await self._writer.drain()
566+
await self._send_packed_command(command)
537567
except asyncio.TimeoutError:
538568
await self.disconnect(nowait=True)
539569
raise TimeoutError("Timeout writing to socket") from None
@@ -776,7 +806,7 @@ async def _connect(self):
776806
except (OSError, TypeError):
777807
# `socket_keepalive_options` might contain invalid options
778808
# causing an error. Do not leave the connection open.
779-
writer.close()
809+
self._close()
780810
raise
781811

782812
def _host_error(self) -> str:
@@ -961,7 +991,6 @@ async def _connect(self):
961991
reader, writer = await asyncio.open_unix_connection(path=self.path)
962992
self._reader = reader
963993
self._writer = writer
964-
await self.on_connect()
965994

966995
def _host_error(self) -> str:
967996
return self.path

redis/client.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -274,10 +274,9 @@ def __init__(
274274
provided pool will be used.
275275
276276
Args:
277-
278-
single_connection_client:
279-
if `True`, connection pool is not used. In that case `Redis`
280-
instance use is not thread safe.
277+
single_connection_client:
278+
if `True`, connection pool is not used. In that case `Redis`
279+
instance use is not thread safe.
281280
"""
282281
if event_dispatcher is None:
283282
self._event_dispatcher = EventDispatcher()

redis/connection.py

Lines changed: 36 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -575,17 +575,17 @@ def connect_check_health(
575575
return
576576
try:
577577
if retry_socket_connect:
578-
sock = self.retry.call_with_retry(
579-
lambda: self._connect(), lambda error: self.disconnect(error)
578+
self.retry.call_with_retry(
579+
lambda: self._connect(),
580+
lambda error: self.disconnect(error),
580581
)
581582
else:
582-
sock = self._connect()
583+
self._connect()
583584
except socket.timeout:
584585
raise TimeoutError("Timeout connecting to server")
585586
except OSError as e:
586587
raise ConnectionError(self._error_message(e))
587588

588-
self._sock = sock
589589
try:
590590
if self.redis_connect_func is None:
591591
# Use the default on_connect function
@@ -608,7 +608,7 @@ def connect_check_health(
608608
callback(self)
609609

610610
@abstractmethod
611-
def _connect(self):
611+
def _connect(self) -> None:
612612
pass
613613

614614
@abstractmethod
@@ -626,6 +626,12 @@ def on_connect_check_health(self, check_health: bool = True):
626626
self._parser.on_connect(self)
627627
parser = self._parser
628628

629+
if check_health:
630+
self.retry.call_with_retry(
631+
lambda: self._send_ping(),
632+
lambda error: self.disconnect(error),
633+
)
634+
629635
auth_args = None
630636
# if credential provider or username and/or password are set, authenticate
631637
if self.credential_provider or (self.username or self.password):
@@ -680,7 +686,7 @@ def on_connect_check_health(self, check_health: bool = True):
680686
# update cluster exception classes
681687
self._parser.EXCEPTION_CLASSES = parser.EXCEPTION_CLASSES
682688
self._parser.on_connect(self)
683-
self.send_command("HELLO", self.protocol, check_health=check_health)
689+
self.send_command("HELLO", self.protocol, check_health=False)
684690
self.handshake_metadata = self.read_response()
685691
if (
686692
self.handshake_metadata.get(b"proto") != self.protocol
@@ -711,7 +717,7 @@ def on_connect_check_health(self, check_health: bool = True):
711717
"ON",
712718
"moving-endpoint-type",
713719
endpoint_type.value,
714-
check_health=check_health,
720+
check_health=False,
715721
)
716722
response = self.read_response()
717723
if str_if_bytes(response) != "OK":
@@ -737,7 +743,7 @@ def on_connect_check_health(self, check_health: bool = True):
737743
"CLIENT",
738744
"SETNAME",
739745
self.client_name,
740-
check_health=check_health,
746+
check_health=False,
741747
)
742748
if str_if_bytes(self.read_response()) != "OK":
743749
raise ConnectionError("Error setting client name")
@@ -750,7 +756,7 @@ def on_connect_check_health(self, check_health: bool = True):
750756
"SETINFO",
751757
"LIB-NAME",
752758
self.lib_name,
753-
check_health=check_health,
759+
check_health=False,
754760
)
755761
self.read_response()
756762
except ResponseError:
@@ -763,15 +769,15 @@ def on_connect_check_health(self, check_health: bool = True):
763769
"SETINFO",
764770
"LIB-VER",
765771
self.lib_version,
766-
check_health=check_health,
772+
check_health=False,
767773
)
768774
self.read_response()
769775
except ResponseError:
770776
pass
771777

772778
# if a database is specified, switch to it
773779
if self.db:
774-
self.send_command("SELECT", self.db, check_health=check_health)
780+
self.send_command("SELECT", self.db, check_health=False)
775781
if str_if_bytes(self.read_response()) != "OK":
776782
raise ConnectionError("Invalid Database")
777783

@@ -800,8 +806,16 @@ def disconnect(self, *args):
800806
def _send_ping(self):
801807
"""Send PING, expect PONG in return"""
802808
self.send_command("PING", check_health=False)
803-
if str_if_bytes(self.read_response()) != "PONG":
804-
raise ConnectionError("Bad response from PING health check")
809+
try:
810+
# Do not disconnect on error here, since we want to keep the connection in case of AuthenticationError
811+
# since we are raising ConnectionError in all other cases and ping_failed already disconnects,
812+
# connection reload is already handled
813+
if str_if_bytes(self.read_response(disconnect_on_error=False)) != "PONG":
814+
raise ConnectionError("Bad response from PING health check")
815+
except AuthenticationError:
816+
# if we get an authentication error, the server is healthy
817+
self._parser.on_disconnect()
818+
self._parser.on_connect(self)
805819

806820
def _ping_failed(self, error):
807821
"""Function to call when PING fails"""
@@ -1097,7 +1111,7 @@ def repr_pieces(self):
10971111
pieces.append(("client_name", self.client_name))
10981112
return pieces
10991113

1100-
def _connect(self):
1114+
def _connect(self) -> None:
11011115
"Create a TCP socket connection"
11021116
# we want to mimic what socket.create_connection does to support
11031117
# ipv4/ipv6, but we want to set options prior to calling
@@ -1128,7 +1142,8 @@ def _connect(self):
11281142

11291143
# set the socket_timeout now that we're connected
11301144
sock.settimeout(self.socket_timeout)
1131-
return sock
1145+
self._sock = sock
1146+
return
11321147

11331148
except OSError as _:
11341149
err = _
@@ -1448,15 +1463,15 @@ def __init__(
14481463
self.ssl_ciphers = ssl_ciphers
14491464
super().__init__(**kwargs)
14501465

1451-
def _connect(self):
1466+
def _connect(self) -> None:
14521467
"""
14531468
Wrap the socket with SSL support, handling potential errors.
14541469
"""
1455-
sock = super()._connect()
1470+
super()._connect()
14561471
try:
1457-
return self._wrap_socket_with_ssl(sock)
1472+
self._sock = self._wrap_socket_with_ssl(self._sock)
14581473
except (OSError, RedisError):
1459-
sock.close()
1474+
self._sock.close()
14601475
raise
14611476

14621477
def _wrap_socket_with_ssl(self, sock):
@@ -1559,7 +1574,7 @@ def repr_pieces(self):
15591574
pieces.append(("client_name", self.client_name))
15601575
return pieces
15611576

1562-
def _connect(self):
1577+
def _connect(self) -> None:
15631578
"Create a Unix domain socket connection"
15641579
sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
15651580
sock.settimeout(self.socket_connect_timeout)
@@ -1574,7 +1589,7 @@ def _connect(self):
15741589
sock.close()
15751590
raise
15761591
sock.settimeout(self.socket_timeout)
1577-
return sock
1592+
self._sock = sock
15781593

15791594
def _host_error(self):
15801595
return self.path

tests/test_asyncio/test_cluster.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -729,7 +729,7 @@ async def test_reading_with_load_balancing_strategies(
729729
Connection,
730730
send_command=mock.DEFAULT,
731731
read_response=mock.DEFAULT,
732-
_connect=mock.DEFAULT,
732+
_connect_check_server_ready=mock.DEFAULT,
733733
can_read_destructive=mock.DEFAULT,
734734
on_connect=mock.DEFAULT,
735735
) as mocks:
@@ -761,7 +761,7 @@ def execute_command_mock_third(self, *args, **options):
761761
execute_command.side_effect = execute_command_mock_first
762762
mocks["send_command"].return_value = True
763763
mocks["read_response"].return_value = "OK"
764-
mocks["_connect"].return_value = True
764+
mocks["_connect_check_server_ready"].return_value = True
765765
mocks["can_read_destructive"].return_value = False
766766
mocks["on_connect"].return_value = True
767767

@@ -3117,13 +3117,19 @@ async def execute_command(self, *args, **kwargs):
31173117

31183118
return _create_client
31193119

3120+
@pytest.mark.parametrize("check_server_ready", [True, False])
31203121
async def test_ssl_connection_without_ssl(
3121-
self, create_client: Callable[..., Awaitable[RedisCluster]]
3122+
self, create_client: Callable[..., Awaitable[RedisCluster]], check_server_ready
31223123
) -> None:
31233124
with pytest.raises(RedisClusterException) as e:
3124-
await create_client(mocked=False, ssl=False)
3125+
await create_client(
3126+
mocked=False, ssl=False, check_server_ready=check_server_ready
3127+
)
31253128
e = e.value.__cause__
3126-
assert "Connection closed by server" in str(e)
3129+
if check_server_ready:
3130+
assert "Invalid PING response" in str(e)
3131+
else:
3132+
assert "Connection closed by server" in str(e)
31273133

31283134
async def test_ssl_with_invalid_cert(
31293135
self, create_client: Callable[..., Awaitable[RedisCluster]]

0 commit comments

Comments
 (0)