From d2288e991820b72e1cbfbd950ec5a405d56f8ae8 Mon Sep 17 00:00:00 2001 From: Elena Kolevska Date: Fri, 1 Aug 2025 11:03:59 +0100 Subject: [PATCH 1/8] Adds handshake for enabling server maintenance notifications Signed-off-by: Elena Kolevska Cleanup Signed-off-by: Elena Kolevska --- redis/maintenance_events.py | 170 ++++++++++++++++++++++++++++++++++++ test_endpoint_type.py | 123 ++++++++++++++++++++++++++ 2 files changed, 293 insertions(+) create mode 100644 test_endpoint_type.py diff --git a/redis/maintenance_events.py b/redis/maintenance_events.py index f99ad37397..6eebc00dc6 100644 --- a/redis/maintenance_events.py +++ b/redis/maintenance_events.py @@ -1,5 +1,7 @@ import enum +import ipaddress import logging +import re import threading import time from abc import ABC, abstractmethod @@ -15,6 +17,20 @@ class MaintenanceState(enum.Enum): FAILING_OVER = "failing_over" +class EndpointType: + """Constants for valid endpoint types used in CLIENT MAINT_NOTIFICATIONS command.""" + INTERNAL_IP = "internal-ip" + INTERNAL_FQDN = "internal-fqdn" + EXTERNAL_IP = "external-ip" + EXTERNAL_FQDN = "external-fqdn" + NONE = "none" + + @classmethod + def get_valid_types(cls): + """Return a set of all valid endpoint types.""" + return {cls.INTERNAL_IP, cls.INTERNAL_FQDN, cls.EXTERNAL_IP, cls.EXTERNAL_FQDN, cls.NONE} + + if TYPE_CHECKING: from redis.connection import ( BlockingConnectionPool, @@ -361,6 +377,96 @@ def __hash__(self) -> int: return hash((self.__class__, self.id)) +def _is_private_fqdn(host: str) -> bool: + """ + Determine if an FQDN is likely to be internal/private. + + This uses heuristics based on RFC 952 and RFC 1123 standards: + - .local domains (RFC 6762 - Multicast DNS) + - .internal domains (common internal convention) + - Single-label hostnames (no dots) + - Common internal TLDs + + Args: + host (str): The FQDN to check + + Returns: + bool: True if the FQDN appears to be internal/private + """ + host_lower = host.lower().rstrip('.') + + # Single-label hostnames (no dots) are typically internal + if '.' not in host_lower: + return True + + # Common internal/private domain patterns + internal_patterns = [ + r'\.local$', # mDNS/Bonjour domains + r'\.internal$', # Common internal convention + r'\.corp$', # Corporate domains + r'\.lan$', # Local area network + r'\.intranet$', # Intranet domains + r'\.private$', # Private domains + ] + + for pattern in internal_patterns: + if re.search(pattern, host_lower): + return True + + # If none of the internal patterns match, assume it's external + return False + + +def _get_resolved_ip_from_connection(connection: "ConnectionInterface") -> Optional[str]: + """ + Extract the resolved IP address from an established connection. + + First tries to get the actual IP from the socket (most accurate), + then falls back to DNS resolution if needed. + + Args: + connection: The connection object to extract the IP from + + Returns: + str: The resolved IP address, or None if it cannot be determined + """ + import socket + + # Method 1: Try to get the actual IP from the established socket connection + # This is most accurate as it shows the exact IP being used + try: + sock = getattr(connection, '_sock', None) + if sock is not None: + peer_addr = sock.getpeername() + if peer_addr and len(peer_addr) >= 1: + # For TCP sockets, peer_addr is typically (host, port) tuple + # Return just the host part + return peer_addr[0] + except (AttributeError, OSError): + # Socket might not be connected or getpeername() might fail + pass + + # Method 2: Fallback to DNS resolution of the host + # This is less accurate but works when socket is not available + try: + host = getattr(connection, 'host', None) + port = getattr(connection, 'port', 6379) + if host: + # Use getaddrinfo to resolve the hostname to IP + # This mimics what the connection would do during _connect() + addr_info = socket.getaddrinfo(host, port, socket.AF_UNSPEC, socket.SOCK_STREAM) + if addr_info: + # Return the IP from the first result + # addr_info[0] is (family, socktype, proto, canonname, sockaddr) + # sockaddr[0] is the IP address + return addr_info[0][4][0] + except (AttributeError, OSError, socket.gaierror): + # DNS resolution might fail + pass + + return None + + class MaintenanceEventsConfig: """ Configuration class for maintenance events handling behaviour. Events are received through @@ -376,6 +482,7 @@ def __init__( enabled: bool = False, proactive_reconnect: bool = True, relax_timeout: Optional[Number] = 20, + endpoint_type: Optional[str] = None, ): """ Initialize a new MaintenanceEventsConfig. @@ -387,18 +494,32 @@ def __init__( Defaults to True. relax_timeout (Number): The relax timeout to use for the connection during maintenance. If -1 is provided - the relax timeout is disabled. Defaults to 20. + endpoint_type (Optional[str]): Override for the endpoint type to use in CLIENT MAINT_NOTIFICATIONS. + Must be one of: 'internal-ip', 'internal-fqdn', 'external-ip', 'external-fqdn', 'none'. + If None, the endpoint type will be automatically determined based on the host and TLS configuration. + Defaults to None. + Raises: + ValueError: If endpoint_type is provided but is not a valid endpoint type. """ self.enabled = enabled self.relax_timeout = relax_timeout self.proactive_reconnect = proactive_reconnect + # Validate endpoint_type if provided + if endpoint_type is not None and endpoint_type not in EndpointType.get_valid_types(): + valid_types = ', '.join(f"'{t}'" for t in sorted(EndpointType.get_valid_types())) + raise ValueError(f"Invalid endpoint_type '{endpoint_type}'. Must be one of: {valid_types}") + + self.endpoint_type = endpoint_type + def __repr__(self) -> str: return ( f"{self.__class__.__name__}(" f"enabled={self.enabled}, " f"proactive_reconnect={self.proactive_reconnect}, " f"relax_timeout={self.relax_timeout}, " + f"endpoint_type={self.endpoint_type!r}" f")" ) @@ -413,6 +534,53 @@ def is_relax_timeouts_enabled(self) -> bool: """ return self.relax_timeout != -1 + def get_endpoint_type(self, host: str, connection: "ConnectionInterface") -> str: + """ + Determine the appropriate endpoint type for CLIENT MAINT_NOTIFICATIONS command. + + Logic: + 1. If endpoint_type is explicitly set, use it + 3. Otherwise, check the original host from host: + - If host is an IP address, use it directly to determine internal-ip vs external-ip + - If host is an FQDN, get the resolved IP to determine internal-fqdn vs external-fqdn + + Args: + host: User provided hostname to analyze + connection: The connection object to analyze for endpoint type determination + + Returns: + """ + + # If endpoint_type is explicitly set, use it + if self.endpoint_type is not None: + return self.endpoint_type + + # Check if the host is an IP address + try: + ip_addr = ipaddress.ip_address(host) + # Host is an IP address - use it directly + is_private = ip_addr.is_private + return EndpointType.INTERNAL_IP if is_private else EndpointType.EXTERNAL_IP + except ValueError: + # Host is an FQDN - need to check resolved IP to determine internal vs external + pass + + # Host is an FQDN, get the resolved IP to determine if it's internal or external + resolved_ip = _get_resolved_ip_from_connection(connection) + + if resolved_ip: + try: + ip_addr = ipaddress.ip_address(resolved_ip) + is_private = ip_addr.is_private + # Use FQDN types since the original host was an FQDN + return EndpointType.INTERNAL_FQDN if is_private else EndpointType.EXTERNAL_FQDN + except ValueError: + # This shouldn't happen since we got the IP from the socket, but fallback + pass + + # Final fallback: use heuristics on the FQDN itself + is_private = _is_private_fqdn(host) + return EndpointType.INTERNAL_FQDN if is_private else EndpointType.EXTERNAL_FQDN class MaintenanceEventPoolHandler: def __init__( @@ -440,11 +608,13 @@ def handle_event(self, notification: MaintenanceEvent): logging.error(f"Unhandled notification type: {notification}") def handle_node_moving_event(self, event: NodeMovingEvent): + print("Received MOVING event: {event}") if ( not self.config.proactive_reconnect and not self.config.is_relax_timeouts_enabled() ): return + print("Handling MOVING event: {event}") with self._lock: if event in self._processed_events: # nothing to do in the connection pool handling diff --git a/test_endpoint_type.py b/test_endpoint_type.py new file mode 100644 index 0000000000..70b6c6c169 --- /dev/null +++ b/test_endpoint_type.py @@ -0,0 +1,123 @@ +#!/usr/bin/env python3 +""" +Quick test script to verify the MaintenanceEventsConfig endpoint type functionality. +""" + +import sys +import os + +# Add the redis module to the path +sys.path.insert(0, os.path.join(os.path.dirname(__file__))) + +from redis.maintenance_events import MaintenanceEventsConfig, EndpointType + +def test_endpoint_type_constants(): + """Test that the EndpointType constants are correct.""" + print("Testing EndpointType constants...") + + assert EndpointType.INTERNAL_IP == "internal-ip" + assert EndpointType.INTERNAL_FQDN == "internal-fqdn" + assert EndpointType.EXTERNAL_IP == "external-ip" + assert EndpointType.EXTERNAL_FQDN == "external-fqdn" + assert EndpointType.NONE == "none" + + valid_types = EndpointType.get_valid_types() + expected_types = {"internal-ip", "internal-fqdn", "external-ip", "external-fqdn", "none"} + assert valid_types == expected_types + + print("āœ“ EndpointType constants are correct") + +def test_config_validation(): + """Test that MaintenanceEventsConfig validates endpoint_type correctly.""" + print("Testing MaintenanceEventsConfig validation...") + + # Valid endpoint types should work + for endpoint_type in EndpointType.get_valid_types(): + config = MaintenanceEventsConfig(endpoint_type=endpoint_type) + assert config.endpoint_type == endpoint_type + + # Invalid endpoint type should raise ValueError + try: + MaintenanceEventsConfig(endpoint_type="invalid-type") + assert False, "Should have raised ValueError" + except ValueError as e: + assert "Invalid endpoint_type" in str(e) + + # None should be allowed + config = MaintenanceEventsConfig(endpoint_type=None) + assert config.endpoint_type is None + + print("āœ“ MaintenanceEventsConfig validation works correctly") + +def test_endpoint_type_detection(): + """Test the get_endpoint_type method with various inputs.""" + print("Testing endpoint type detection...") + + config = MaintenanceEventsConfig() + + # Test IPv4 addresses + assert config.get_endpoint_type("192.168.1.1") == EndpointType.INTERNAL_IP # Private IPv4 + assert config.get_endpoint_type("10.0.0.1") == EndpointType.INTERNAL_IP # Private IPv4 + assert config.get_endpoint_type("172.16.0.1") == EndpointType.INTERNAL_IP # Private IPv4 + assert config.get_endpoint_type("8.8.8.8") == EndpointType.EXTERNAL_IP # Public IPv4 + assert config.get_endpoint_type("1.1.1.1") == EndpointType.EXTERNAL_IP # Public IPv4 + + # Test IPv6 addresses + result1 = config.get_endpoint_type("::1") + print(f"::1 -> {result1} (expected: {EndpointType.INTERNAL_IP})") + assert result1 == EndpointType.INTERNAL_IP # Loopback IPv6 + + result2 = config.get_endpoint_type("fd00::1") + print(f"fd00::1 -> {result2} (expected: {EndpointType.INTERNAL_IP})") + assert result2 == EndpointType.INTERNAL_IP # Private IPv6 + + result3 = config.get_endpoint_type("2001:db8::1") + print(f"2001:db8::1 -> {result3} (expected: {EndpointType.EXTERNAL_IP})") + assert result3 == EndpointType.EXTERNAL_IP # Public IPv6 + + # Test FQDNs without TLS + assert config.get_endpoint_type("localhost") == EndpointType.INTERNAL_FQDN # Single label + assert config.get_endpoint_type("server.local") == EndpointType.INTERNAL_FQDN # .local domain + assert config.get_endpoint_type("app.internal") == EndpointType.INTERNAL_FQDN # .internal domain + assert config.get_endpoint_type("example.com") == EndpointType.EXTERNAL_FQDN # Public domain + + # Test FQDNs with TLS + assert config.get_endpoint_type("server.local", tls_enabled=True) == EndpointType.INTERNAL_FQDN + assert config.get_endpoint_type("example.com", tls_enabled=True) == EndpointType.EXTERNAL_FQDN + + print("āœ“ Endpoint type detection works correctly") + +def test_override_behavior(): + """Test that explicit endpoint_type overrides automatic detection.""" + print("Testing override behavior...") + + config = MaintenanceEventsConfig(endpoint_type=EndpointType.NONE) + + # Should always return the override value regardless of host + assert config.get_endpoint_type("192.168.1.1") == EndpointType.NONE + assert config.get_endpoint_type("8.8.8.8") == EndpointType.NONE + assert config.get_endpoint_type("example.com") == EndpointType.NONE + assert config.get_endpoint_type("localhost") == EndpointType.NONE + + print("āœ“ Override behavior works correctly") + +def main(): + """Run all tests.""" + print("Running MaintenanceEventsConfig endpoint type tests...\n") + + try: + test_endpoint_type_constants() + test_config_validation() + test_endpoint_type_detection() + test_override_behavior() + + print("\nšŸŽ‰ All tests passed!") + return 0 + except Exception as e: + print(f"\nāŒ Test failed: {e}") + import traceback + traceback.print_exc() + return 1 + +if __name__ == "__main__": + sys.exit(main()) From b294db2f395a1db019127dc6c571e2153b01966c Mon Sep 17 00:00:00 2001 From: Elena Kolevska Date: Thu, 7 Aug 2025 15:40:06 +0300 Subject: [PATCH 2/8] Adds tests Signed-off-by: Elena Kolevska --- test_endpoint_type.py | 123 ------------------------- tests/test_maintenance_events.py | 151 +++++++++++++++++++++++++++++++ 2 files changed, 151 insertions(+), 123 deletions(-) delete mode 100644 test_endpoint_type.py diff --git a/test_endpoint_type.py b/test_endpoint_type.py deleted file mode 100644 index 70b6c6c169..0000000000 --- a/test_endpoint_type.py +++ /dev/null @@ -1,123 +0,0 @@ -#!/usr/bin/env python3 -""" -Quick test script to verify the MaintenanceEventsConfig endpoint type functionality. -""" - -import sys -import os - -# Add the redis module to the path -sys.path.insert(0, os.path.join(os.path.dirname(__file__))) - -from redis.maintenance_events import MaintenanceEventsConfig, EndpointType - -def test_endpoint_type_constants(): - """Test that the EndpointType constants are correct.""" - print("Testing EndpointType constants...") - - assert EndpointType.INTERNAL_IP == "internal-ip" - assert EndpointType.INTERNAL_FQDN == "internal-fqdn" - assert EndpointType.EXTERNAL_IP == "external-ip" - assert EndpointType.EXTERNAL_FQDN == "external-fqdn" - assert EndpointType.NONE == "none" - - valid_types = EndpointType.get_valid_types() - expected_types = {"internal-ip", "internal-fqdn", "external-ip", "external-fqdn", "none"} - assert valid_types == expected_types - - print("āœ“ EndpointType constants are correct") - -def test_config_validation(): - """Test that MaintenanceEventsConfig validates endpoint_type correctly.""" - print("Testing MaintenanceEventsConfig validation...") - - # Valid endpoint types should work - for endpoint_type in EndpointType.get_valid_types(): - config = MaintenanceEventsConfig(endpoint_type=endpoint_type) - assert config.endpoint_type == endpoint_type - - # Invalid endpoint type should raise ValueError - try: - MaintenanceEventsConfig(endpoint_type="invalid-type") - assert False, "Should have raised ValueError" - except ValueError as e: - assert "Invalid endpoint_type" in str(e) - - # None should be allowed - config = MaintenanceEventsConfig(endpoint_type=None) - assert config.endpoint_type is None - - print("āœ“ MaintenanceEventsConfig validation works correctly") - -def test_endpoint_type_detection(): - """Test the get_endpoint_type method with various inputs.""" - print("Testing endpoint type detection...") - - config = MaintenanceEventsConfig() - - # Test IPv4 addresses - assert config.get_endpoint_type("192.168.1.1") == EndpointType.INTERNAL_IP # Private IPv4 - assert config.get_endpoint_type("10.0.0.1") == EndpointType.INTERNAL_IP # Private IPv4 - assert config.get_endpoint_type("172.16.0.1") == EndpointType.INTERNAL_IP # Private IPv4 - assert config.get_endpoint_type("8.8.8.8") == EndpointType.EXTERNAL_IP # Public IPv4 - assert config.get_endpoint_type("1.1.1.1") == EndpointType.EXTERNAL_IP # Public IPv4 - - # Test IPv6 addresses - result1 = config.get_endpoint_type("::1") - print(f"::1 -> {result1} (expected: {EndpointType.INTERNAL_IP})") - assert result1 == EndpointType.INTERNAL_IP # Loopback IPv6 - - result2 = config.get_endpoint_type("fd00::1") - print(f"fd00::1 -> {result2} (expected: {EndpointType.INTERNAL_IP})") - assert result2 == EndpointType.INTERNAL_IP # Private IPv6 - - result3 = config.get_endpoint_type("2001:db8::1") - print(f"2001:db8::1 -> {result3} (expected: {EndpointType.EXTERNAL_IP})") - assert result3 == EndpointType.EXTERNAL_IP # Public IPv6 - - # Test FQDNs without TLS - assert config.get_endpoint_type("localhost") == EndpointType.INTERNAL_FQDN # Single label - assert config.get_endpoint_type("server.local") == EndpointType.INTERNAL_FQDN # .local domain - assert config.get_endpoint_type("app.internal") == EndpointType.INTERNAL_FQDN # .internal domain - assert config.get_endpoint_type("example.com") == EndpointType.EXTERNAL_FQDN # Public domain - - # Test FQDNs with TLS - assert config.get_endpoint_type("server.local", tls_enabled=True) == EndpointType.INTERNAL_FQDN - assert config.get_endpoint_type("example.com", tls_enabled=True) == EndpointType.EXTERNAL_FQDN - - print("āœ“ Endpoint type detection works correctly") - -def test_override_behavior(): - """Test that explicit endpoint_type overrides automatic detection.""" - print("Testing override behavior...") - - config = MaintenanceEventsConfig(endpoint_type=EndpointType.NONE) - - # Should always return the override value regardless of host - assert config.get_endpoint_type("192.168.1.1") == EndpointType.NONE - assert config.get_endpoint_type("8.8.8.8") == EndpointType.NONE - assert config.get_endpoint_type("example.com") == EndpointType.NONE - assert config.get_endpoint_type("localhost") == EndpointType.NONE - - print("āœ“ Override behavior works correctly") - -def main(): - """Run all tests.""" - print("Running MaintenanceEventsConfig endpoint type tests...\n") - - try: - test_endpoint_type_constants() - test_config_validation() - test_endpoint_type_detection() - test_override_behavior() - - print("\nšŸŽ‰ All tests passed!") - return 0 - except Exception as e: - print(f"\nāŒ Test failed: {e}") - import traceback - traceback.print_exc() - return 1 - -if __name__ == "__main__": - sys.exit(main()) diff --git a/tests/test_maintenance_events.py b/tests/test_maintenance_events.py index a59b834a4e..a2c10554a2 100644 --- a/tests/test_maintenance_events.py +++ b/tests/test_maintenance_events.py @@ -13,6 +13,7 @@ MaintenanceEventPoolHandler, MaintenanceEventConnectionHandler, MaintenanceState, + EndpointType, ) @@ -672,3 +673,153 @@ def test_handle_maintenance_completed_event_success(self): self.mock_connection.reset_tmp_settings.assert_called_once_with( reset_relax_timeout=True ) + + +class TestEndpointType: + """Test the EndpointType class functionality.""" + + def test_endpoint_type_constants(self): + """Test that the EndpointType constants are correct.""" + assert EndpointType.INTERNAL_IP == "internal-ip" + assert EndpointType.INTERNAL_FQDN == "internal-fqdn" + assert EndpointType.EXTERNAL_IP == "external-ip" + assert EndpointType.EXTERNAL_FQDN == "external-fqdn" + assert EndpointType.NONE == "none" + + def test_get_valid_types(self): + """Test that get_valid_types returns the expected set.""" + valid_types = EndpointType.get_valid_types() + expected_types = {"internal-ip", "internal-fqdn", "external-ip", "external-fqdn", "none"} + assert valid_types == expected_types + + +class TestMaintenanceEventsConfigEndpointType: + """Test MaintenanceEventsConfig endpoint type functionality.""" + + def test_config_validation_valid_endpoint_types(self): + """Test that MaintenanceEventsConfig accepts valid endpoint types.""" + for endpoint_type in EndpointType.get_valid_types(): + config = MaintenanceEventsConfig(endpoint_type=endpoint_type) + assert config.endpoint_type == endpoint_type + + def test_config_validation_invalid_endpoint_type(self): + """Test that MaintenanceEventsConfig raises ValueError for invalid endpoint type.""" + with pytest.raises(ValueError, match="Invalid endpoint_type"): + MaintenanceEventsConfig(endpoint_type="invalid-type") + + def test_config_validation_none_endpoint_type(self): + """Test that MaintenanceEventsConfig accepts None as endpoint type.""" + config = MaintenanceEventsConfig(endpoint_type=None) + assert config.endpoint_type is None + + def test_endpoint_type_detection_ip_addresses(self): + """Test endpoint type detection for IP addresses.""" + config = MaintenanceEventsConfig() + + # Mock connection and socket classes + class MockSocket: + def __init__(self, resolved_ip): + self.resolved_ip = resolved_ip + + def getpeername(self): + return (self.resolved_ip, 6379) + + class MockConnection: + def __init__(self, host, resolved_ip=None, is_ssl=False): + self.host = host + self._sock = MockSocket(resolved_ip) if resolved_ip else None + self.__class__.__name__ = 'SSLConnection' if is_ssl else 'Connection' + + # Test private IPv4 addresses + conn1 = MockConnection("192.168.1.1", resolved_ip="192.168.1.1") + assert config.get_endpoint_type("192.168.1.1", conn1) == EndpointType.INTERNAL_IP + + # Test public IPv4 addresses + conn2 = MockConnection("8.8.8.8", resolved_ip="8.8.8.8") + assert config.get_endpoint_type("8.8.8.8", conn2) == EndpointType.EXTERNAL_IP + + # Test IPv6 loopback + conn3 = MockConnection("::1") + assert config.get_endpoint_type("::1", conn3) == EndpointType.INTERNAL_IP + + # Test IPv6 public address + conn4 = MockConnection("2001:4860:4860::8888") + assert config.get_endpoint_type("2001:4860:4860::8888", conn4) == EndpointType.EXTERNAL_IP + + def test_endpoint_type_detection_fqdn_with_resolved_ip(self): + """Test endpoint type detection for FQDNs with resolved IP addresses.""" + config = MaintenanceEventsConfig() + + # Mock connection and socket classes + class MockSocket: + def __init__(self, resolved_ip): + self.resolved_ip = resolved_ip + + def getpeername(self): + return (self.resolved_ip, 6379) + + class MockConnection: + def __init__(self, host, resolved_ip=None, is_ssl=False): + self.host = host + self._sock = MockSocket(resolved_ip) if resolved_ip else None + self.__class__.__name__ = 'SSLConnection' if is_ssl else 'Connection' + + # Test FQDN resolving to private IP + conn1 = MockConnection("redis.internal.company.com", resolved_ip="192.168.1.1") + assert config.get_endpoint_type("redis.internal.company.com", conn1) == EndpointType.INTERNAL_FQDN + + # Test FQDN resolving to public IP + conn2 = MockConnection("db123.redis.com", resolved_ip="8.8.8.8") + assert config.get_endpoint_type("db123.redis.com", conn2) == EndpointType.EXTERNAL_FQDN + + # Test internal FQDN resolving to public IP (should use resolved IP) + conn3 = MockConnection("redis.internal.company.com", resolved_ip="10.8.8.8") + assert config.get_endpoint_type("redis.internal.company.com", conn3) == EndpointType.INTERNAL_FQDN + + # Test FQDN with TLS + conn4 = MockConnection("redis.internal.company.com", resolved_ip="192.168.1.1", is_ssl=True) + assert config.get_endpoint_type("redis.internal.company.com", conn4) == EndpointType.INTERNAL_FQDN + + conn5 = MockConnection("db123.redis.com", resolved_ip="8.8.8.8", is_ssl=True) + assert config.get_endpoint_type("db123.redis.com", conn5) == EndpointType.EXTERNAL_FQDN + + def test_endpoint_type_detection_fqdn_heuristics(self): + """Test endpoint type detection using FQDN heuristics when no resolved IP is available.""" + config = MaintenanceEventsConfig() + + # Mock connection class without resolved IP + class MockConnection: + def __init__(self, host): + self.host = host + self._sock = None + self.__class__.__name__ = 'Connection' + + # Test localhost (should be internal) + conn1 = MockConnection("localhost") + assert config.get_endpoint_type("localhost", conn1) == EndpointType.INTERNAL_FQDN + + # Test .local domain (should be internal) + conn2 = MockConnection("server.local") + assert config.get_endpoint_type("server.local", conn2) == EndpointType.INTERNAL_FQDN + + # Test public domain (should be external) + conn3 = MockConnection("example.com") + assert config.get_endpoint_type("example.com", conn3) == EndpointType.EXTERNAL_FQDN + + def test_endpoint_type_override(self): + """Test that configured endpoint_type overrides detection.""" + # Mock connection class + class MockConnection: + def __init__(self, host): + self.host = host + self._sock = None + self.__class__.__name__ = 'Connection' + + # Test with endpoint_type set to NONE + config = MaintenanceEventsConfig(endpoint_type=EndpointType.NONE) + conn = MockConnection("localhost") + assert config.get_endpoint_type("localhost", conn) == EndpointType.NONE + + # Test with endpoint_type set to EXTERNAL_IP + config = MaintenanceEventsConfig(endpoint_type=EndpointType.EXTERNAL_IP) + assert config.get_endpoint_type("localhost", conn) == EndpointType.EXTERNAL_IP From fb487c0fe73ef0bcfa055ca5bd5d07ad9bec2e54 Mon Sep 17 00:00:00 2001 From: Elena Kolevska Date: Fri, 8 Aug 2025 01:03:20 +0300 Subject: [PATCH 3/8] Adds handshake Signed-off-by: Elena Kolevska --- redis/_parsers/base.py | 101 ++++++++++++++++++++++-------------- redis/connection.py | 93 ++++++++++++++++++++++++++------- redis/maintenance_events.py | 6 +-- 3 files changed, 137 insertions(+), 63 deletions(-) diff --git a/redis/_parsers/base.py b/redis/_parsers/base.py index d5e4add661..1595a0d007 100644 --- a/redis/_parsers/base.py +++ b/redis/_parsers/base.py @@ -2,6 +2,7 @@ from abc import ABC from asyncio import IncompleteReadError, StreamReader, TimeoutError from typing import Callable, List, Optional, Protocol, Union +import logging from redis.maintenance_events import ( NodeMigratedEvent, @@ -56,6 +57,7 @@ "Client sent AUTH, but no password is set": AuthenticationError, } +logger = logging.getLogger(__name__) class BaseParser(ABC): EXCEPTION_CLASSES = { @@ -192,6 +194,7 @@ def handle_pubsub_push_response(self, response): raise NotImplementedError() def handle_push_response(self, response, **kwargs): + msg_type = response[0] if msg_type not in ( *_INVALIDATION_MESSAGE, @@ -199,28 +202,36 @@ def handle_push_response(self, response, **kwargs): *_MOVING_MESSAGE, ): return self.pubsub_push_handler_func(response) - if msg_type in _INVALIDATION_MESSAGE and self.invalidation_push_handler_func: - return self.invalidation_push_handler_func(response) - if msg_type in _MOVING_MESSAGE and self.node_moving_push_handler_func: - host, port = response[2].decode().split(":") - ttl = response[1] - id = 1 # Hardcoded value until the notification starts including the id - notification = NodeMovingEvent(id, host, port, ttl) - return self.node_moving_push_handler_func(notification) - if msg_type in _MAINTENANCE_MESSAGES and self.maintenance_push_handler_func: - if msg_type in _MIGRATING_MESSAGE: - ttl = response[1] - id = 2 # Hardcoded value until the notification starts including the id - notification = NodeMigratingEvent(id, ttl) - elif msg_type in _MIGRATED_MESSAGE: - id = 3 # Hardcoded value until the notification starts including the id - notification = NodeMigratedEvent(id) - else: + + try: + if msg_type in _INVALIDATION_MESSAGE and self.invalidation_push_handler_func: + return self.invalidation_push_handler_func(response) + if msg_type in _MOVING_MESSAGE and self.node_moving_push_handler_func: + # Expected message format is: MOVING