@@ -334,6 +334,7 @@ def __init__(
334
334
encoding : str = "utf-8" ,
335
335
encoding_errors : str = "strict" ,
336
336
decode_responses : bool = False ,
337
+ check_server_ready : bool = False ,
337
338
parser_class = DefaultParser ,
338
339
socket_read_size : int = 65536 ,
339
340
health_check_interval : int = 0 ,
@@ -412,6 +413,7 @@ def __init__(
412
413
self .redis_connect_func = redis_connect_func
413
414
self .encoder = Encoder (encoding , encoding_errors , decode_responses )
414
415
self .handshake_metadata = None
416
+ self .check_server_ready = check_server_ready
415
417
self ._sock = None
416
418
self ._socket_read_size = socket_read_size
417
419
self ._connect_callbacks = []
@@ -575,17 +577,17 @@ def connect_check_health(
575
577
return
576
578
try :
577
579
if retry_socket_connect :
578
- sock = self .retry .call_with_retry (
579
- lambda : self ._connect (), lambda error : self .disconnect (error )
580
+ self .retry .call_with_retry (
581
+ lambda : self ._connect_check_server_ready (),
582
+ lambda error : self .disconnect (error ),
580
583
)
581
584
else :
582
- sock = self ._connect ()
585
+ self ._connect_check_server_ready ()
583
586
except socket .timeout :
584
587
raise TimeoutError ("Timeout connecting to server" )
585
588
except OSError as e :
586
589
raise ConnectionError (self ._error_message (e ))
587
590
588
- self ._sock = sock
589
591
try :
590
592
if self .redis_connect_func is None :
591
593
# Use the default on_connect function
@@ -607,8 +609,27 @@ def connect_check_health(
607
609
if callback :
608
610
callback (self )
609
611
612
+ def _connect_check_server_ready (self ):
613
+ self ._connect ()
614
+
615
+ # Doing handshake since connect and send operations work even when Redis is not ready
616
+ if self .check_server_ready :
617
+ try :
618
+ self .send_command ("PING" , check_health = False )
619
+
620
+ response = str_if_bytes (self ._sock .recv (1024 ))
621
+ if not (response .startswith ("+PONG" ) or response .startswith ("-NOAUTH" )):
622
+ raise ResponseError (f"Invalid PING response: { response } " )
623
+ except (ConnectionResetError , ResponseError ) as err :
624
+ try :
625
+ self ._sock .shutdown (socket .SHUT_RDWR ) # ensure a clean close
626
+ except OSError :
627
+ pass
628
+ self ._sock .close ()
629
+ raise ConnectionError (self ._error_message (err ))
630
+
610
631
@abstractmethod
611
- def _connect (self ):
632
+ def _connect (self ) -> None :
612
633
pass
613
634
614
635
@abstractmethod
@@ -1097,7 +1118,7 @@ def repr_pieces(self):
1097
1118
pieces .append (("client_name" , self .client_name ))
1098
1119
return pieces
1099
1120
1100
- def _connect (self ):
1121
+ def _connect (self ) -> None :
1101
1122
"Create a TCP socket connection"
1102
1123
# we want to mimic what socket.create_connection does to support
1103
1124
# ipv4/ipv6, but we want to set options prior to calling
@@ -1128,7 +1149,8 @@ def _connect(self):
1128
1149
1129
1150
# set the socket_timeout now that we're connected
1130
1151
sock .settimeout (self .socket_timeout )
1131
- return sock
1152
+ self ._sock = sock
1153
+ return
1132
1154
1133
1155
except OSError as _ :
1134
1156
err = _
@@ -1448,15 +1470,15 @@ def __init__(
1448
1470
self .ssl_ciphers = ssl_ciphers
1449
1471
super ().__init__ (** kwargs )
1450
1472
1451
- def _connect (self ):
1473
+ def _connect (self ) -> None :
1452
1474
"""
1453
1475
Wrap the socket with SSL support, handling potential errors.
1454
1476
"""
1455
- sock = super ()._connect ()
1477
+ super ()._connect ()
1456
1478
try :
1457
- return self ._wrap_socket_with_ssl (sock )
1479
+ self . _sock = self ._wrap_socket_with_ssl (self . _sock )
1458
1480
except (OSError , RedisError ):
1459
- sock .close ()
1481
+ self . _sock .close ()
1460
1482
raise
1461
1483
1462
1484
def _wrap_socket_with_ssl (self , sock ):
@@ -1559,7 +1581,7 @@ def repr_pieces(self):
1559
1581
pieces .append (("client_name" , self .client_name ))
1560
1582
return pieces
1561
1583
1562
- def _connect (self ):
1584
+ def _connect (self ) -> None :
1563
1585
"Create a Unix domain socket connection"
1564
1586
sock = socket .socket (socket .AF_UNIX , socket .SOCK_STREAM )
1565
1587
sock .settimeout (self .socket_connect_timeout )
@@ -1574,7 +1596,7 @@ def _connect(self):
1574
1596
sock .close ()
1575
1597
raise
1576
1598
sock .settimeout (self .socket_timeout )
1577
- return sock
1599
+ self . _sock = sock
1578
1600
1579
1601
def _host_error (self ):
1580
1602
return self .path
0 commit comments