Skip to content

Hitless handshake #3735

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 8 commits into
base: ps_add_fail_over_events_handling
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
111 changes: 72 additions & 39 deletions redis/_parsers/base.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import logging
import sys
from abc import ABC
from asyncio import IncompleteReadError, StreamReader, TimeoutError
Expand Down Expand Up @@ -56,6 +57,8 @@
"Client sent AUTH, but no password is set": AuthenticationError,
}

logger = logging.getLogger(__name__)


class BaseParser(ABC):
EXCEPTION_CLASSES = {
Expand Down Expand Up @@ -199,28 +202,41 @@ def handle_push_response(self, response, **kwargs):
*_MOVING_MESSAGE,
):
return self.pubsub_push_handler_func(response)
if msg_type in _INVALIDATION_MESSAGE and self.invalidation_push_handler_func:
return self.invalidation_push_handler_func(response)
if msg_type in _MOVING_MESSAGE and self.node_moving_push_handler_func:
host, port = response[2].decode().split(":")
ttl = response[1]
id = 1 # Hardcoded value until the notification starts including the id
notification = NodeMovingEvent(id, host, port, ttl)
return self.node_moving_push_handler_func(notification)
if msg_type in _MAINTENANCE_MESSAGES and self.maintenance_push_handler_func:
if msg_type in _MIGRATING_MESSAGE:
ttl = response[1]
id = 2 # Hardcoded value until the notification starts including the id
notification = NodeMigratingEvent(id, ttl)
elif msg_type in _MIGRATED_MESSAGE:
id = 3 # Hardcoded value until the notification starts including the id
notification = NodeMigratedEvent(id)
else:

try:
if (
msg_type in _INVALIDATION_MESSAGE
and self.invalidation_push_handler_func
):
return self.invalidation_push_handler_func(response)
if msg_type in _MOVING_MESSAGE and self.node_moving_push_handler_func:
# Expected message format is: MOVING <seq_number> <time> <endpoint>
id = response[1]
ttl = response[2]
host, port = response[3].decode().split(":")
notification = NodeMovingEvent(id, host, port, ttl)
return self.node_moving_push_handler_func(notification)

if msg_type in _MAINTENANCE_MESSAGES and self.maintenance_push_handler_func:
notification = None
if notification is not None:
return self.maintenance_push_handler_func(notification)
else:
return None

if msg_type in _MIGRATING_MESSAGE:
# Expected message format is: MIGRATING <seq_number> <time> <shard_id-s>
id = response[1]
ttl = response[2]
notification = NodeMigratingEvent(id, ttl)
elif msg_type in _MIGRATED_MESSAGE:
id = response[1]
notification = NodeMigratedEvent(id)

if notification is not None:
return self.maintenance_push_handler_func(notification)
except Exception as e:
logger.error(
"Error handling {} message ({}): {}".format(msg_type, response, e)
)

return None

def set_pubsub_push_handler(self, pubsub_push_handler_func):
self.pubsub_push_handler_func = pubsub_push_handler_func
Expand Down Expand Up @@ -249,31 +265,48 @@ async def handle_pubsub_push_response(self, response):

async def handle_push_response(self, response, **kwargs):
"""Handle push responses asynchronously"""

msg_type = response[0]
if msg_type not in (
*_INVALIDATION_MESSAGE,
*_MAINTENANCE_MESSAGES,
*_MOVING_MESSAGE,
):
return await self.pubsub_push_handler_func(response)
if msg_type in _INVALIDATION_MESSAGE and self.invalidation_push_handler_func:
return await self.invalidation_push_handler_func(response)
if msg_type in _MOVING_MESSAGE and self.node_moving_push_handler_func:
# push notification from enterprise cluster for node moving
host, port = response[2].split(":")
ttl = response[1]
id = 1 # Hardcoded value for async parser
notification = NodeMovingEvent(id, host, port, ttl)
return await self.node_moving_push_handler_func(notification)
if msg_type in _MAINTENANCE_MESSAGES and self.maintenance_push_handler_func:
if msg_type in _MIGRATING_MESSAGE:
ttl = response[1]
id = 2 # Hardcoded value for async parser
notification = NodeMigratingEvent(id, ttl)
elif msg_type in _MIGRATED_MESSAGE:
id = 3 # Hardcoded value for async parser
notification = NodeMigratedEvent(id)
return await self.maintenance_push_handler_func(notification)

try:
if (
msg_type in _INVALIDATION_MESSAGE
and self.invalidation_push_handler_func
):
return await self.invalidation_push_handler_func(response)
if msg_type in _MOVING_MESSAGE and self.node_moving_push_handler_func:
# push notification from enterprise cluster for node moving
id = response[1]
ttl = response[2]
host, port = response[3].split(":")
notification = NodeMovingEvent(id, host, port, ttl)
return await self.node_moving_push_handler_func(notification)

if msg_type in _MAINTENANCE_MESSAGES and self.maintenance_push_handler_func:
notification = None

if msg_type in _MIGRATING_MESSAGE:
id = response[1]
ttl = response[2]
notification = NodeMigratingEvent(id, ttl)
elif msg_type in _MIGRATED_MESSAGE:
id = response[1]
notification = NodeMigratedEvent(id)

if notification is not None:
return self.maintenance_push_handler_func(notification)
except Exception as e:
logger.error(
"Error handling {} message ({}): {}".format(msg_type, response, e)
)

return None

def set_pubsub_push_handler(self, pubsub_push_handler_func):
"""Set the pubsub push handler function"""
Expand Down
96 changes: 75 additions & 21 deletions redis/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -400,28 +400,15 @@ def __init__(
parser_class = _RESP3Parser
self.set_parser(parser_class)

if maintenance_events_config and maintenance_events_config.enabled:
if maintenance_events_pool_handler:
self._parser.set_node_moving_push_handler(
maintenance_events_pool_handler.handle_event
)
self._maintenance_event_connection_handler = (
MaintenanceEventConnectionHandler(self, maintenance_events_config)
)
self._parser.set_maintenance_push_handler(
self._maintenance_event_connection_handler.handle_event
)
self.maintenance_events_config = maintenance_events_config

self.orig_host_address = (
orig_host_address if orig_host_address else self.host
)
self.orig_socket_timeout = (
orig_socket_timeout if orig_socket_timeout else self.socket_timeout
)
self.orig_socket_connect_timeout = (
orig_socket_connect_timeout
if orig_socket_connect_timeout
else self.socket_connect_timeout
# Set up maintenance events if enabled
if maintenance_events_config and maintenance_events_config.enabled:
self._enable_maintenance_events(
maintenance_events_pool_handler,
orig_host_address,
orig_socket_timeout,
orig_socket_connect_timeout,
)
self._should_reconnect = False
self.maintenance_state = maintenance_state
Expand Down Expand Up @@ -481,6 +468,42 @@ def set_parser(self, parser_class):
"""
self._parser = parser_class(socket_read_size=self._socket_read_size)

def _enable_maintenance_events(
self,
maintenance_events_pool_handler=None,
orig_host_address=None,
orig_socket_timeout=None,
orig_socket_connect_timeout=None,
):
"""Enable maintenance events by setting up handlers and storing original connection parameters."""
if not self.maintenance_events_config:
return

# Set up pool handler if available
if maintenance_events_pool_handler:
self._parser.set_node_moving_push_handler(
maintenance_events_pool_handler.handle_event
)

# Set up connection handler
self._maintenance_event_connection_handler = MaintenanceEventConnectionHandler(
self, self.maintenance_events_config
)
self._parser.set_maintenance_push_handler(
self._maintenance_event_connection_handler.handle_event
)

# Store original connection parameters
self.orig_host_address = orig_host_address if orig_host_address else self.host
self.orig_socket_timeout = (
orig_socket_timeout if orig_socket_timeout else self.socket_timeout
)
self.orig_socket_connect_timeout = (
orig_socket_connect_timeout
if orig_socket_connect_timeout
else self.socket_connect_timeout
)

def set_maintenance_event_pool_handler(
self, maintenance_event_pool_handler: MaintenanceEventPoolHandler
):
Expand Down Expand Up @@ -623,6 +646,37 @@ def on_connect_check_health(self, check_health: bool = True):
):
raise ConnectionError("Invalid RESP version")

# Send maintenance notifications handshake if RESP3 is active and maintenance events are enabled
if (
self.protocol not in [2, "2"]
and self.maintenance_events_config
and self.maintenance_events_config.enabled
and hasattr(self, "_maintenance_event_connection_handler")
):
try:
endpoint_type = self.maintenance_events_config.get_endpoint_type(
self.host, self
)
self.send_command(
"CLIENT",
"MAINT_NOTIFICATIONS",
"ON",
"moving-endpoint-type",
endpoint_type,
check_health=check_health,
)
response = self.read_response()
if str_if_bytes(response) != "OK":
raise ConnectionError(
"The server doesn't support maintenance notifications"
)
except Exception as e:
# Log warning but don't fail the connection
import logging

logger = logging.getLogger(__name__)
logger.warning(f"Failed to enable maintenance notifications: {e}")

# if a client_name is given, set it
if self.client_name:
self.send_command(
Expand Down
Loading