Skip to content

Commit 5218c48

Browse files
Add Redis readiness verification (#3555)
1 parent b102c3b commit 5218c48

File tree

11 files changed

+476
-127
lines changed

11 files changed

+476
-127
lines changed

redis/asyncio/client.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -230,6 +230,7 @@ def __init__(
230230
encoding: str = "utf-8",
231231
encoding_errors: str = "strict",
232232
decode_responses: bool = False,
233+
check_server_ready: bool = False,
233234
retry_on_timeout: bool = False,
234235
retry: Retry = Retry(
235236
backoff=ExponentialWithJitterBackoff(base=1, cap=10), retries=3
@@ -279,6 +280,10 @@ def __init__(
279280
280281
When 'connection_pool' is provided - the retry configuration of the
281282
provided pool will be used.
283+
284+
Args:
285+
check_server_ready: if `True`, an extra handshake is performed by sending a PING command, since
286+
connect and send operations work even when Redis server is not ready.
282287
"""
283288
kwargs: Dict[str, Any]
284289
if event_dispatcher is None:
@@ -313,6 +318,7 @@ def __init__(
313318
"encoding": encoding,
314319
"encoding_errors": encoding_errors,
315320
"decode_responses": decode_responses,
321+
"check_server_ready": check_server_ready,
316322
"retry_on_error": retry_on_error,
317323
"retry": copy.deepcopy(retry),
318324
"max_connections": max_connections,

redis/asyncio/cluster.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -290,6 +290,7 @@ def __init__(
290290
encoding_errors: str = "strict",
291291
decode_responses: bool = False,
292292
# Connection related kwargs
293+
check_server_ready: bool = False,
293294
health_check_interval: float = 0,
294295
socket_connect_timeout: Optional[float] = None,
295296
socket_keepalive: bool = False,
@@ -345,6 +346,7 @@ def __init__(
345346
"encoding_errors": encoding_errors,
346347
"decode_responses": decode_responses,
347348
# Connection related kwargs
349+
"check_server_ready": check_server_ready,
348350
"health_check_interval": health_check_interval,
349351
"socket_connect_timeout": socket_connect_timeout,
350352
"socket_keepalive": socket_keepalive,

redis/asyncio/connection.py

Lines changed: 35 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,7 @@ def __init__(
149149
encoding_errors: str = "strict",
150150
decode_responses: bool = False,
151151
parser_class: Type[BaseParser] = DefaultParser,
152+
check_server_ready: bool = False,
152153
socket_read_size: int = 65536,
153154
health_check_interval: float = 0,
154155
client_name: Optional[str] = None,
@@ -205,6 +206,7 @@ def __init__(
205206
self.health_check_interval = health_check_interval
206207
self.next_health_check: float = -1
207208
self.encoder = encoder_class(encoding, encoding_errors, decode_responses)
209+
self.check_server_ready = check_server_ready
208210
self.redis_connect_func = redis_connect_func
209211
self._reader: Optional[asyncio.StreamReader] = None
210212
self._writer: Optional[asyncio.StreamWriter] = None
@@ -304,11 +306,13 @@ async def connect_check_health(
304306
try:
305307
if retry_socket_connect:
306308
await self.retry.call_with_retry(
307-
lambda: self._connect(), lambda error: self.disconnect()
309+
lambda: self._connect_check_server_ready(),
310+
lambda error: self.disconnect(),
308311
)
309312
else:
310-
await self._connect()
313+
await self._connect_check_server_ready()
311314
except asyncio.CancelledError:
315+
self._close()
312316
raise # in 3.7 and earlier, this is an Exception, not BaseException
313317
except (socket.timeout, asyncio.TimeoutError):
314318
raise TimeoutError("Timeout connecting to server")
@@ -343,6 +347,33 @@ async def connect_check_health(
343347
if task and inspect.isawaitable(task):
344348
await task
345349

350+
async def _connect_check_server_ready(self):
351+
await self._connect()
352+
353+
# Doing handshake since connect and send operations work even when Redis is not ready
354+
if self.check_server_ready:
355+
try:
356+
await self.send_command("PING", check_health=False)
357+
358+
if self.socket_timeout is not None:
359+
async with async_timeout(self.socket_timeout):
360+
response = str_if_bytes(await self._reader.read(1024))
361+
else:
362+
response = str_if_bytes(await self._reader.read(1024))
363+
364+
if not (response.startswith("+PONG") or response.startswith("-NOAUTH")):
365+
raise ResponseError(f"Invalid PING response: {response}")
366+
except (
367+
socket.timeout,
368+
asyncio.TimeoutError,
369+
ResponseError,
370+
ConnectionResetError,
371+
) as e:
372+
# `socket_keepalive_options` might contain invalid options
373+
# causing an error. Do not leave the connection open.
374+
self._close()
375+
raise ConnectionError(self._error_message(e))
376+
346377
@abstractmethod
347378
async def _connect(self):
348379
pass
@@ -532,8 +563,7 @@ async def send_packed_command(
532563
self._send_packed_command(command), self.socket_timeout
533564
)
534565
else:
535-
self._writer.writelines(command)
536-
await self._writer.drain()
566+
await self._send_packed_command(command)
537567
except asyncio.TimeoutError:
538568
await self.disconnect(nowait=True)
539569
raise TimeoutError("Timeout writing to socket") from None
@@ -776,7 +806,7 @@ async def _connect(self):
776806
except (OSError, TypeError):
777807
# `socket_keepalive_options` might contain invalid options
778808
# causing an error. Do not leave the connection open.
779-
writer.close()
809+
self._close()
780810
raise
781811

782812
def _host_error(self) -> str:
@@ -961,7 +991,6 @@ async def _connect(self):
961991
reader, writer = await asyncio.open_unix_connection(path=self.path)
962992
self._reader = reader
963993
self._writer = writer
964-
await self.on_connect()
965994

966995
def _host_error(self) -> str:
967996
return self.path

redis/client.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -215,6 +215,7 @@ def __init__(
215215
encoding: str = "utf-8",
216216
encoding_errors: str = "strict",
217217
decode_responses: bool = False,
218+
check_server_ready: bool = False,
218219
retry_on_timeout: bool = False,
219220
retry: Retry = Retry(
220221
backoff=ExponentialWithJitterBackoff(base=1, cap=10), retries=3
@@ -274,10 +275,11 @@ def __init__(
274275
provided pool will be used.
275276
276277
Args:
277-
278-
single_connection_client:
279-
if `True`, connection pool is not used. In that case `Redis`
280-
instance use is not thread safe.
278+
check_server_ready: if `True`, an extra handshake is performed by sending a PING command, since
279+
connect and send operations work even when Redis server is not ready.
280+
single_connection_client:
281+
if `True`, connection pool is not used. In that case `Redis`
282+
instance use is not thread safe.
281283
"""
282284
if event_dispatcher is None:
283285
self._event_dispatcher = EventDispatcher()
@@ -294,6 +296,7 @@ def __init__(
294296
"encoding": encoding,
295297
"encoding_errors": encoding_errors,
296298
"decode_responses": decode_responses,
299+
"check_server_ready": check_server_ready,
297300
"retry_on_error": retry_on_error,
298301
"retry": copy.deepcopy(retry),
299302
"max_connections": max_connections,

redis/connection.py

Lines changed: 35 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -334,6 +334,7 @@ def __init__(
334334
encoding: str = "utf-8",
335335
encoding_errors: str = "strict",
336336
decode_responses: bool = False,
337+
check_server_ready: bool = False,
337338
parser_class=DefaultParser,
338339
socket_read_size: int = 65536,
339340
health_check_interval: int = 0,
@@ -412,6 +413,7 @@ def __init__(
412413
self.redis_connect_func = redis_connect_func
413414
self.encoder = Encoder(encoding, encoding_errors, decode_responses)
414415
self.handshake_metadata = None
416+
self.check_server_ready = check_server_ready
415417
self._sock = None
416418
self._socket_read_size = socket_read_size
417419
self._connect_callbacks = []
@@ -575,17 +577,17 @@ def connect_check_health(
575577
return
576578
try:
577579
if retry_socket_connect:
578-
sock = self.retry.call_with_retry(
579-
lambda: self._connect(), lambda error: self.disconnect(error)
580+
self.retry.call_with_retry(
581+
lambda: self._connect_check_server_ready(),
582+
lambda error: self.disconnect(error),
580583
)
581584
else:
582-
sock = self._connect()
585+
self._connect_check_server_ready()
583586
except socket.timeout:
584587
raise TimeoutError("Timeout connecting to server")
585588
except OSError as e:
586589
raise ConnectionError(self._error_message(e))
587590

588-
self._sock = sock
589591
try:
590592
if self.redis_connect_func is None:
591593
# Use the default on_connect function
@@ -607,8 +609,27 @@ def connect_check_health(
607609
if callback:
608610
callback(self)
609611

612+
def _connect_check_server_ready(self):
613+
self._connect()
614+
615+
# Doing handshake since connect and send operations work even when Redis is not ready
616+
if self.check_server_ready:
617+
try:
618+
self.send_command("PING", check_health=False)
619+
620+
response = str_if_bytes(self._sock.recv(1024))
621+
if not (response.startswith("+PONG") or response.startswith("-NOAUTH")):
622+
raise ResponseError(f"Invalid PING response: {response}")
623+
except (ConnectionResetError, ResponseError) as err:
624+
try:
625+
self._sock.shutdown(socket.SHUT_RDWR) # ensure a clean close
626+
except OSError:
627+
pass
628+
self._sock.close()
629+
raise ConnectionError(self._error_message(err))
630+
610631
@abstractmethod
611-
def _connect(self):
632+
def _connect(self) -> None:
612633
pass
613634

614635
@abstractmethod
@@ -1097,7 +1118,7 @@ def repr_pieces(self):
10971118
pieces.append(("client_name", self.client_name))
10981119
return pieces
10991120

1100-
def _connect(self):
1121+
def _connect(self) -> None:
11011122
"Create a TCP socket connection"
11021123
# we want to mimic what socket.create_connection does to support
11031124
# ipv4/ipv6, but we want to set options prior to calling
@@ -1128,7 +1149,8 @@ def _connect(self):
11281149

11291150
# set the socket_timeout now that we're connected
11301151
sock.settimeout(self.socket_timeout)
1131-
return sock
1152+
self._sock = sock
1153+
return
11321154

11331155
except OSError as _:
11341156
err = _
@@ -1448,15 +1470,15 @@ def __init__(
14481470
self.ssl_ciphers = ssl_ciphers
14491471
super().__init__(**kwargs)
14501472

1451-
def _connect(self):
1473+
def _connect(self) -> None:
14521474
"""
14531475
Wrap the socket with SSL support, handling potential errors.
14541476
"""
1455-
sock = super()._connect()
1477+
super()._connect()
14561478
try:
1457-
return self._wrap_socket_with_ssl(sock)
1479+
self._sock = self._wrap_socket_with_ssl(self._sock)
14581480
except (OSError, RedisError):
1459-
sock.close()
1481+
self._sock.close()
14601482
raise
14611483

14621484
def _wrap_socket_with_ssl(self, sock):
@@ -1559,7 +1581,7 @@ def repr_pieces(self):
15591581
pieces.append(("client_name", self.client_name))
15601582
return pieces
15611583

1562-
def _connect(self):
1584+
def _connect(self) -> None:
15631585
"Create a Unix domain socket connection"
15641586
sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
15651587
sock.settimeout(self.socket_connect_timeout)
@@ -1574,7 +1596,7 @@ def _connect(self):
15741596
sock.close()
15751597
raise
15761598
sock.settimeout(self.socket_timeout)
1577-
return sock
1599+
self._sock = sock
15781600

15791601
def _host_error(self):
15801602
return self.path

tests/test_asyncio/test_cluster.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -729,7 +729,7 @@ async def test_reading_with_load_balancing_strategies(
729729
Connection,
730730
send_command=mock.DEFAULT,
731731
read_response=mock.DEFAULT,
732-
_connect=mock.DEFAULT,
732+
_connect_check_server_ready=mock.DEFAULT,
733733
can_read_destructive=mock.DEFAULT,
734734
on_connect=mock.DEFAULT,
735735
) as mocks:
@@ -761,7 +761,7 @@ def execute_command_mock_third(self, *args, **options):
761761
execute_command.side_effect = execute_command_mock_first
762762
mocks["send_command"].return_value = True
763763
mocks["read_response"].return_value = "OK"
764-
mocks["_connect"].return_value = True
764+
mocks["_connect_check_server_ready"].return_value = True
765765
mocks["can_read_destructive"].return_value = False
766766
mocks["on_connect"].return_value = True
767767

@@ -3117,13 +3117,19 @@ async def execute_command(self, *args, **kwargs):
31173117

31183118
return _create_client
31193119

3120+
@pytest.mark.parametrize("check_server_ready", [True, False])
31203121
async def test_ssl_connection_without_ssl(
3121-
self, create_client: Callable[..., Awaitable[RedisCluster]]
3122+
self, create_client: Callable[..., Awaitable[RedisCluster]], check_server_ready
31223123
) -> None:
31233124
with pytest.raises(RedisClusterException) as e:
3124-
await create_client(mocked=False, ssl=False)
3125+
await create_client(
3126+
mocked=False, ssl=False, check_server_ready=check_server_ready
3127+
)
31253128
e = e.value.__cause__
3126-
assert "Connection closed by server" in str(e)
3129+
if check_server_ready:
3130+
assert "Invalid PING response" in str(e)
3131+
else:
3132+
assert "Connection closed by server" in str(e)
31273133

31283134
async def test_ssl_with_invalid_cert(
31293135
self, create_client: Callable[..., Awaitable[RedisCluster]]

0 commit comments

Comments
 (0)