@@ -575,17 +575,17 @@ def connect_check_health(
575
575
return
576
576
try :
577
577
if retry_socket_connect :
578
- sock = self .retry .call_with_retry (
579
- lambda : self ._connect (), lambda error : self .disconnect (error )
578
+ self .retry .call_with_retry (
579
+ lambda : self ._connect (),
580
+ lambda error : self .disconnect (error ),
580
581
)
581
582
else :
582
- sock = self ._connect ()
583
+ self ._connect ()
583
584
except socket .timeout :
584
585
raise TimeoutError ("Timeout connecting to server" )
585
586
except OSError as e :
586
587
raise ConnectionError (self ._error_message (e ))
587
588
588
- self ._sock = sock
589
589
try :
590
590
if self .redis_connect_func is None :
591
591
# Use the default on_connect function
@@ -608,7 +608,7 @@ def connect_check_health(
608
608
callback (self )
609
609
610
610
@abstractmethod
611
- def _connect (self ):
611
+ def _connect (self ) -> None :
612
612
pass
613
613
614
614
@abstractmethod
@@ -626,6 +626,12 @@ def on_connect_check_health(self, check_health: bool = True):
626
626
self ._parser .on_connect (self )
627
627
parser = self ._parser
628
628
629
+ if check_health :
630
+ self .retry .call_with_retry (
631
+ lambda : self ._send_ping (),
632
+ lambda error : self .disconnect (error ),
633
+ )
634
+
629
635
auth_args = None
630
636
# if credential provider or username and/or password are set, authenticate
631
637
if self .credential_provider or (self .username or self .password ):
@@ -680,7 +686,7 @@ def on_connect_check_health(self, check_health: bool = True):
680
686
# update cluster exception classes
681
687
self ._parser .EXCEPTION_CLASSES = parser .EXCEPTION_CLASSES
682
688
self ._parser .on_connect (self )
683
- self .send_command ("HELLO" , self .protocol , check_health = check_health )
689
+ self .send_command ("HELLO" , self .protocol , check_health = False )
684
690
self .handshake_metadata = self .read_response ()
685
691
if (
686
692
self .handshake_metadata .get (b"proto" ) != self .protocol
@@ -711,7 +717,7 @@ def on_connect_check_health(self, check_health: bool = True):
711
717
"ON" ,
712
718
"moving-endpoint-type" ,
713
719
endpoint_type .value ,
714
- check_health = check_health ,
720
+ check_health = False ,
715
721
)
716
722
response = self .read_response ()
717
723
if str_if_bytes (response ) != "OK" :
@@ -737,7 +743,7 @@ def on_connect_check_health(self, check_health: bool = True):
737
743
"CLIENT" ,
738
744
"SETNAME" ,
739
745
self .client_name ,
740
- check_health = check_health ,
746
+ check_health = False ,
741
747
)
742
748
if str_if_bytes (self .read_response ()) != "OK" :
743
749
raise ConnectionError ("Error setting client name" )
@@ -750,7 +756,7 @@ def on_connect_check_health(self, check_health: bool = True):
750
756
"SETINFO" ,
751
757
"LIB-NAME" ,
752
758
self .lib_name ,
753
- check_health = check_health ,
759
+ check_health = False ,
754
760
)
755
761
self .read_response ()
756
762
except ResponseError :
@@ -763,15 +769,15 @@ def on_connect_check_health(self, check_health: bool = True):
763
769
"SETINFO" ,
764
770
"LIB-VER" ,
765
771
self .lib_version ,
766
- check_health = check_health ,
772
+ check_health = False ,
767
773
)
768
774
self .read_response ()
769
775
except ResponseError :
770
776
pass
771
777
772
778
# if a database is specified, switch to it
773
779
if self .db :
774
- self .send_command ("SELECT" , self .db , check_health = check_health )
780
+ self .send_command ("SELECT" , self .db , check_health = False )
775
781
if str_if_bytes (self .read_response ()) != "OK" :
776
782
raise ConnectionError ("Invalid Database" )
777
783
@@ -800,8 +806,16 @@ def disconnect(self, *args):
800
806
def _send_ping (self ):
801
807
"""Send PING, expect PONG in return"""
802
808
self .send_command ("PING" , check_health = False )
803
- if str_if_bytes (self .read_response ()) != "PONG" :
804
- raise ConnectionError ("Bad response from PING health check" )
809
+ try :
810
+ # Do not disconnect on error here, since we want to keep the connection in case of AuthenticationError
811
+ # since we are raising ConnectionError in all other cases and ping_failed already disconnects,
812
+ # connection reload is already handled
813
+ if str_if_bytes (self .read_response (disconnect_on_error = False )) != "PONG" :
814
+ raise ConnectionError ("Bad response from PING health check" )
815
+ except AuthenticationError :
816
+ # if we get an authentication error, the server is healthy
817
+ self ._parser .on_disconnect ()
818
+ self ._parser .on_connect (self )
805
819
806
820
def _ping_failed (self , error ):
807
821
"""Function to call when PING fails"""
@@ -1097,7 +1111,7 @@ def repr_pieces(self):
1097
1111
pieces .append (("client_name" , self .client_name ))
1098
1112
return pieces
1099
1113
1100
- def _connect (self ):
1114
+ def _connect (self ) -> None :
1101
1115
"Create a TCP socket connection"
1102
1116
# we want to mimic what socket.create_connection does to support
1103
1117
# ipv4/ipv6, but we want to set options prior to calling
@@ -1128,7 +1142,8 @@ def _connect(self):
1128
1142
1129
1143
# set the socket_timeout now that we're connected
1130
1144
sock .settimeout (self .socket_timeout )
1131
- return sock
1145
+ self ._sock = sock
1146
+ return
1132
1147
1133
1148
except OSError as _ :
1134
1149
err = _
@@ -1448,15 +1463,15 @@ def __init__(
1448
1463
self .ssl_ciphers = ssl_ciphers
1449
1464
super ().__init__ (** kwargs )
1450
1465
1451
- def _connect (self ):
1466
+ def _connect (self ) -> None :
1452
1467
"""
1453
1468
Wrap the socket with SSL support, handling potential errors.
1454
1469
"""
1455
- sock = super ()._connect ()
1470
+ super ()._connect ()
1456
1471
try :
1457
- return self ._wrap_socket_with_ssl (sock )
1472
+ self . _sock = self ._wrap_socket_with_ssl (self . _sock )
1458
1473
except (OSError , RedisError ):
1459
- sock .close ()
1474
+ self . _sock .close ()
1460
1475
raise
1461
1476
1462
1477
def _wrap_socket_with_ssl (self , sock ):
@@ -1559,7 +1574,7 @@ def repr_pieces(self):
1559
1574
pieces .append (("client_name" , self .client_name ))
1560
1575
return pieces
1561
1576
1562
- def _connect (self ):
1577
+ def _connect (self ) -> None :
1563
1578
"Create a Unix domain socket connection"
1564
1579
sock = socket .socket (socket .AF_UNIX , socket .SOCK_STREAM )
1565
1580
sock .settimeout (self .socket_connect_timeout )
@@ -1574,7 +1589,7 @@ def _connect(self):
1574
1589
sock .close ()
1575
1590
raise
1576
1591
sock .settimeout (self .socket_timeout )
1577
- return sock
1592
+ self . _sock = sock
1578
1593
1579
1594
def _host_error (self ):
1580
1595
return self .path
0 commit comments