From e6d90e41475a6b24751606868986bf72f29ff4b7 Mon Sep 17 00:00:00 2001
From: Vladyslav Vildanov <117659936+vladvildanov@users.noreply.github.com>
Date: Mon, 11 Aug 2025 11:12:46 +0300
Subject: [PATCH 01/50] MultiDbClient implementation (#3696)

* Added Database, Healthcheck, CircuitBreaker, FailureDetector
* Added DatabaseSelector, exceptions, refactored existing entities
* Added MultiDbConfig
* Added DatabaseConfig
* Added DatabaseConfig test coverage
* Renamed DatabaseSelector into FailoverStrategy
* Added CommandExecutor
* Updated healthcheck to close circuit on success
* Added thread-safeness
* Added missing thread-safeness
* Added missing thread-safeness for dispatcher
* Refactored client to keep databases in WeightedList
* Added database CRUD operations
* Added on-the-fly configuration
* Added background health checks
* Added background healthcheck + half-open event
* Refactored background scheduling
* Refactored healthchecks
* Removed code repetitions, fixed weight assignment, added loop enhancements, fixed data structure
* Refactored configuration
* Refactored failure detector
* Refactored retry logic
* Added scenario tests
* Added pybreaker optional dependency
* Added pybreaker to dev dependencies
* Renamed tests directory
* Removed redundant checks
* Handled retries if default is not set
* Removed all Sentinel-related code
---
 dev_requirements.txt                         |   1 +
 pyproject.toml                               |   3 +
 redis/background.py                          |  89 +++
 redis/client.py                              |   4 +-
 redis/data_structure.py                      |  75 +++
 redis/event.py                               |  65 ++-
 redis/multidb/__init__.py                    |   0
 redis/multidb/circuit.py                     | 108 ++++
 redis/multidb/client.py                      | 235 ++++++++
 redis/multidb/command_executor.py            | 184 ++++++
 redis/multidb/config.py                      | 140 +++++
 redis/multidb/database.py                    | 118 ++++
 redis/multidb/event.py                       |  16 +
 redis/multidb/exception.py                   |   2 +
 redis/multidb/failover.py                    |  54 ++
 redis/multidb/failure_detector.py            |  76 +++
 redis/multidb/healthcheck.py                 |  57 ++
 tasks.py                                     |   8 +-
 tests/conftest.py                            |   6 +
 tests/test_background.py                     |  60 ++
 tests/test_data_structure.py                 |  79 +++
 tests/test_event.py                          |  55 ++
 tests/test_multidb/__init__.py               |   0
 tests/test_multidb/conftest.py               | 112 ++++
 tests/test_multidb/test_circuit.py           |  52 ++
 tests/test_multidb/test_client.py            | 584 +++++++++++++++++++
 tests/test_multidb/test_command_executor.py  | 160 +++++
 tests/test_multidb/test_config.py            | 124 ++++
 tests/test_multidb/test_failover.py          | 117 ++++
 tests/test_multidb/test_failure_detector.py  | 148 +++++
 tests/test_multidb/test_healthcheck.py       |  41 ++
 tests/test_scenario/__init__.py              |   0
 tests/test_scenario/conftest.py              |  75 +++
 tests/test_scenario/fault_injector_client.py |  71 +++
 tests/test_scenario/test_active_active.py    |  92 +++
 35 files changed, 2995 insertions(+), 16 deletions(-)
 create mode 100644 redis/background.py
 create mode 100644 redis/data_structure.py
 create mode 100644 redis/multidb/__init__.py
 create mode 100644 redis/multidb/circuit.py
 create mode 100644 redis/multidb/client.py
 create mode 100644 redis/multidb/command_executor.py
 create mode 100644 redis/multidb/config.py
 create mode 100644 redis/multidb/database.py
 create mode 100644 redis/multidb/event.py
 create mode 100644 redis/multidb/exception.py
 create mode 100644 redis/multidb/failover.py
 create mode 100644 redis/multidb/failure_detector.py
 create mode 100644 redis/multidb/healthcheck.py
 create mode 100644 tests/test_background.py
 create mode 100644 tests/test_data_structure.py
 create mode 100644 tests/test_event.py
 create mode 100644 tests/test_multidb/__init__.py
 create mode 100644 tests/test_multidb/conftest.py
 create mode 100644
tests/test_multidb/test_circuit.py create mode 100644 tests/test_multidb/test_client.py create mode 100644 tests/test_multidb/test_command_executor.py create mode 100644 tests/test_multidb/test_config.py create mode 100644 tests/test_multidb/test_failover.py create mode 100644 tests/test_multidb/test_failure_detector.py create mode 100644 tests/test_multidb/test_healthcheck.py create mode 100644 tests/test_scenario/__init__.py create mode 100644 tests/test_scenario/conftest.py create mode 100644 tests/test_scenario/fault_injector_client.py create mode 100644 tests/test_scenario/test_active_active.py diff --git a/dev_requirements.txt b/dev_requirements.txt index 848d6207c4..e61f37f101 100644 --- a/dev_requirements.txt +++ b/dev_requirements.txt @@ -14,3 +14,4 @@ uvloop vulture>=2.3.0 numpy>=1.24.0 redis-entraid==1.0.0 +pybreaker>=1.4.0 diff --git a/pyproject.toml b/pyproject.toml index ee061953c5..198ac71a7f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -42,6 +42,9 @@ ocsp = [ jwt = [ "PyJWT>=2.9.0", ] +circuit_breaker = [ + "pybreaker>=1.4.0" +] [project.urls] Changes = "https://github.com/redis/redis-py/releases" diff --git a/redis/background.py b/redis/background.py new file mode 100644 index 0000000000..6466649859 --- /dev/null +++ b/redis/background.py @@ -0,0 +1,89 @@ +import asyncio +import threading +from typing import Callable + +class BackgroundScheduler: + """ + Schedules background tasks execution either in separate thread or in the running event loop. + """ + def __init__(self): + self._next_timer = None + + def __del__(self): + if self._next_timer: + self._next_timer.cancel() + + def run_once(self, delay: float, callback: Callable, *args): + """ + Runs callable task once after certain delay in seconds. + """ + # Run loop in a separate thread to unblock main thread. + loop = asyncio.new_event_loop() + thread = threading.Thread( + target=_start_event_loop_in_thread, + args=(loop, self._call_later, delay, callback, *args), + daemon=True + ) + thread.start() + + def run_recurring( + self, + interval: float, + callback: Callable, + *args + ): + """ + Runs recurring callable task with given interval in seconds. + """ + # Run loop in a separate thread to unblock main thread. + loop = asyncio.new_event_loop() + + thread = threading.Thread( + target=_start_event_loop_in_thread, + args=(loop, self._call_later_recurring, interval, callback, *args), + daemon=True + ) + thread.start() + + def _call_later(self, loop: asyncio.AbstractEventLoop, delay: float, callback: Callable, *args): + self._next_timer = loop.call_later(delay, callback, *args) + + def _call_later_recurring( + self, + loop: asyncio.AbstractEventLoop, + interval: float, + callback: Callable, + *args + ): + self._call_later( + loop, interval, self._execute_recurring, loop, interval, callback, *args + ) + + def _execute_recurring( + self, + loop: asyncio.AbstractEventLoop, + interval: float, + callback: Callable, + *args + ): + """ + Executes recurring callable task with given interval in seconds. + """ + callback(*args) + + self._call_later( + loop, interval, self._execute_recurring, loop, interval, callback, *args + ) + + +def _start_event_loop_in_thread(event_loop: asyncio.AbstractEventLoop, call_soon_cb: Callable, *args): + """ + Starts event loop in a thread and schedule callback as soon as event loop is ready. + Used to be able to schedule tasks using loop.call_later. 
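A minimal usage sketch of the scheduler above, assuming only the run_once/run_recurring API shown in this patch; both callbacks run on daemon threads with their own event loop, so the process must stay alive long enough to observe them:

    from time import sleep
    from redis.background import BackgroundScheduler

    scheduler = BackgroundScheduler()
    scheduler.run_once(0.5, print, "fired once")   # one-shot after 0.5 s
    scheduler.run_recurring(1.0, print, "tick")    # repeats every 1.0 s
    sleep(3)                                       # daemon threads die with the process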
+ + :param event_loop: + :return: + """ + asyncio.set_event_loop(event_loop) + event_loop.call_soon(call_soon_cb, event_loop, *args) + event_loop.run_forever() \ No newline at end of file diff --git a/redis/client.py b/redis/client.py index 0e05b6f542..060fc29493 100755 --- a/redis/client.py +++ b/redis/client.py @@ -603,7 +603,7 @@ def _send_command_parse_response(self, conn, command_name, *args, **options): conn.send_command(*args, **options) return self.parse_response(conn, command_name, **options) - def _close_connection(self, conn) -> None: + def _close_connection(self, conn, error, *args) -> None: """ Close the connection before retrying. @@ -633,7 +633,7 @@ def _execute_command(self, *args, **options): lambda: self._send_command_parse_response( conn, command_name, *args, **options ), - lambda _: self._close_connection(conn), + lambda error: self._close_connection(conn, error, *args), ) finally: if self._single_connection_client: diff --git a/redis/data_structure.py b/redis/data_structure.py new file mode 100644 index 0000000000..5b0df7f017 --- /dev/null +++ b/redis/data_structure.py @@ -0,0 +1,75 @@ +import threading +from typing import List, Any, TypeVar, Generic, Union + +from redis.typing import Number + +T = TypeVar('T') + +class WeightedList(Generic[T]): + """ + Thread-safe weighted list. + """ + def __init__(self): + self._items: List[tuple[Any, Number]] = [] + self._lock = threading.RLock() + + def add(self, item: Any, weight: float) -> None: + """Add item with weight, maintaining sorted order""" + with self._lock: + # Find insertion point using binary search + left, right = 0, len(self._items) + while left < right: + mid = (left + right) // 2 + if self._items[mid][1] < weight: + right = mid + else: + left = mid + 1 + + self._items.insert(left, (item, weight)) + + def remove(self, item): + """Remove first occurrence of item""" + with self._lock: + for i, (stored_item, weight) in enumerate(self._items): + if stored_item == item: + self._items.pop(i) + return weight + raise ValueError("Item not found") + + def get_by_weight_range(self, min_weight: float, max_weight: float) -> List[tuple[Any, Number]]: + """Get all items within weight range""" + with self._lock: + result = [] + for item, weight in self._items: + if min_weight <= weight <= max_weight: + result.append((item, weight)) + return result + + def get_top_n(self, n: int) -> List[tuple[Any, Number]]: + """Get top N the highest weighted items""" + with self._lock: + return [(item, weight) for item, weight in self._items[:n]] + + def update_weight(self, item, new_weight: float): + with self._lock: + """Update weight of an item""" + old_weight = self.remove(item) + self.add(item, new_weight) + return old_weight + + def __iter__(self): + """Iterate in descending weight order""" + with self._lock: + items_copy = self._items.copy() # Create snapshot as lock released after each 'yield' + + for item, weight in items_copy: + yield item, weight + + def __len__(self): + with self._lock: + return len(self._items) + + def __getitem__(self, index) -> tuple[Any, Number]: + with self._lock: + item, weight = self._items[index] + return item, weight \ No newline at end of file diff --git a/redis/event.py b/redis/event.py index b86c66b082..03480364db 100644 --- a/redis/event.py +++ b/redis/event.py @@ -2,7 +2,7 @@ import threading from abc import ABC, abstractmethod from enum import Enum -from typing import List, Optional, Union +from typing import List, Optional, Union, Dict, Type from redis.auth.token import TokenInterface from 
redis.credentials import CredentialProvider, StreamingCredentialProvider @@ -42,6 +42,11 @@ def dispatch(self, event: object): async def dispatch_async(self, event: object): pass + @abstractmethod + def register_listeners(self, mappings: Dict[Type[object], List[EventListenerInterface]]): + """Register additional listeners.""" + pass + class EventException(Exception): """ @@ -56,11 +61,14 @@ def __init__(self, exception: Exception, event: object): class EventDispatcher(EventDispatcherInterface): # TODO: Make dispatcher to accept external mappings. - def __init__(self): + def __init__( + self, + event_listeners: Optional[Dict[Type[object], List[EventListenerInterface]]] = None, + ): """ - Mapping should be extended for any new events or listeners to be added. + Dispatcher that dispatches events to listeners associated with given event. """ - self._event_listeners_mapping = { + self._event_listeners_mapping: Dict[Type[object], List[EventListenerInterface]]= { AfterConnectionReleasedEvent: [ ReAuthConnectionListener(), ], @@ -77,17 +85,35 @@ def __init__(self): ], } + self._lock = threading.Lock() + self._async_lock = asyncio.Lock() + + if event_listeners: + self.register_listeners(event_listeners) + def dispatch(self, event: object): - listeners = self._event_listeners_mapping.get(type(event)) + with self._lock: + listeners = self._event_listeners_mapping.get(type(event), []) - for listener in listeners: - listener.listen(event) + for listener in listeners: + listener.listen(event) async def dispatch_async(self, event: object): - listeners = self._event_listeners_mapping.get(type(event)) + with self._async_lock: + listeners = self._event_listeners_mapping.get(type(event), []) + + for listener in listeners: + await listener.listen(event) - for listener in listeners: - await listener.listen(event) + def register_listeners(self, event_listeners: Dict[Type[object], List[EventListenerInterface]]): + with self._lock: + for event_type in event_listeners: + if event_type in self._event_listeners_mapping: + self._event_listeners_mapping[event_type] = list( + set(self._event_listeners_mapping[event_type] + event_listeners[event_type]) + ) + else: + self._event_listeners_mapping[event_type] = event_listeners[event_type] class AfterConnectionReleasedEvent: @@ -225,6 +251,25 @@ def nodes(self) -> dict: def credential_provider(self) -> Union[CredentialProvider, None]: return self._credential_provider +class OnCommandFailEvent: + """ + Event fired whenever a command fails during the execution. 
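A sketch of how the new register_listeners and OnCommandFailEvent pieces fit together; LogFailureListener is hypothetical, shown only to illustrate the mapping shape:

    from redis.event import (
        EventDispatcher,
        EventListenerInterface,
        OnCommandFailEvent,
    )

    class LogFailureListener(EventListenerInterface):
        def listen(self, event: OnCommandFailEvent) -> None:
            # React to a failed command, e.g. log it somewhere.
            print(f"{event.command!r} failed: {event.exception!r}")

    dispatcher = EventDispatcher()
    dispatcher.register_listeners({OnCommandFailEvent: [LogFailureListener()]})
    dispatcher.dispatch(OnCommandFailEvent(("SET", "k", "v"), ConnectionError("down")))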
+ """ + def __init__( + self, + command: tuple, + exception: Exception, + ): + self._command = command + self._exception = exception + + @property + def command(self) -> tuple: + return self._command + + @property + def exception(self) -> Exception: + return self._exception class ReAuthConnectionListener(EventListenerInterface): """ diff --git a/redis/multidb/__init__.py b/redis/multidb/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/redis/multidb/circuit.py b/redis/multidb/circuit.py new file mode 100644 index 0000000000..9211173c83 --- /dev/null +++ b/redis/multidb/circuit.py @@ -0,0 +1,108 @@ +from abc import abstractmethod, ABC +from enum import Enum +from typing import Callable + +import pybreaker + +class State(Enum): + CLOSED = 'closed' + OPEN = 'open' + HALF_OPEN = 'half-open' + +class CircuitBreaker(ABC): + @property + @abstractmethod + def grace_period(self) -> float: + """The grace period in seconds when the circle should be kept open.""" + pass + + @grace_period.setter + @abstractmethod + def grace_period(self, grace_period: float): + """Set the grace period in seconds.""" + + @property + @abstractmethod + def state(self) -> State: + """The current state of the circuit.""" + pass + + @state.setter + @abstractmethod + def state(self, state: State): + """Set current state of the circuit.""" + pass + + @property + @abstractmethod + def database(self): + """Database associated with this circuit.""" + pass + + @database.setter + @abstractmethod + def database(self, database): + """Set database associated with this circuit.""" + pass + + @abstractmethod + def on_state_changed(self, cb: Callable[["CircuitBreaker", State, State], None]): + """Callback called when the state of the circuit changes.""" + pass + +class PBListener(pybreaker.CircuitBreakerListener): + def __init__( + self, + cb: Callable[[CircuitBreaker, State, State], None], + database, + ): + """Wrapper for callback to be compatible with pybreaker implementation.""" + self._cb = cb + self._database = database + + def state_change(self, cb, old_state, new_state): + cb = PBCircuitBreakerAdapter(cb) + cb.database = self._database + old_state = State(value=old_state.name) + new_state = State(value=new_state.name) + self._cb(cb, old_state, new_state) + + +class PBCircuitBreakerAdapter(CircuitBreaker): + def __init__(self, cb: pybreaker.CircuitBreaker): + """Adapter for pybreaker CircuitBreaker.""" + self._cb = cb + self._state_pb_mapper = { + State.CLOSED: self._cb.close, + State.OPEN: self._cb.open, + State.HALF_OPEN: self._cb.half_open, + } + self._database = None + + @property + def grace_period(self) -> float: + return self._cb.reset_timeout + + @grace_period.setter + def grace_period(self, grace_period: float): + self._cb.reset_timeout = grace_period + + @property + def state(self) -> State: + return State(value=self._cb.state.name) + + @state.setter + def state(self, state: State): + self._state_pb_mapper[state]() + + @property + def database(self): + return self._database + + @database.setter + def database(self, database): + self._database = database + + def on_state_changed(self, cb: Callable[["CircuitBreaker", State, State], None]): + listener = PBListener(cb, self.database) + self._cb.add_listener(listener) \ No newline at end of file diff --git a/redis/multidb/client.py b/redis/multidb/client.py new file mode 100644 index 0000000000..702a7da01d --- /dev/null +++ b/redis/multidb/client.py @@ -0,0 +1,235 @@ +import threading +import socket +from typing import Callable + +from 
redis.background import BackgroundScheduler +from redis.exceptions import ConnectionError, TimeoutError +from redis.commands import RedisModuleCommands, CoreCommands +from redis.multidb.command_executor import DefaultCommandExecutor +from redis.multidb.config import MultiDbConfig, DEFAULT_GRACE_PERIOD +from redis.multidb.circuit import State as CBState, CircuitBreaker +from redis.multidb.database import State as DBState, Database, AbstractDatabase, Databases +from redis.multidb.exception import NoValidDatabaseException +from redis.multidb.failure_detector import FailureDetector +from redis.multidb.healthcheck import HealthCheck + + +class MultiDBClient(RedisModuleCommands, CoreCommands): + """ + Client that operates on multiple logical Redis databases. + Should be used in Active-Active database setups. + """ + def __init__(self, config: MultiDbConfig): + self._databases = config.databases() + self._health_checks = config.default_health_checks() if config.health_checks is None else config.health_checks + self._health_check_interval = config.health_check_interval + self._failure_detectors = config.default_failure_detectors() \ + if config.failure_detectors is None else config.failure_detectors + self._failover_strategy = config.default_failover_strategy() \ + if config.failover_strategy is None else config.failover_strategy + self._failover_strategy.set_databases(self._databases) + self._auto_fallback_interval = config.auto_fallback_interval + self._event_dispatcher = config.event_dispatcher + self._command_executor = DefaultCommandExecutor( + failure_detectors=self._failure_detectors, + databases=self._databases, + command_retry=config.command_retry, + failover_strategy=self._failover_strategy, + event_dispatcher=self._event_dispatcher, + auto_fallback_interval=self._auto_fallback_interval, + ) + + for fd in self._failure_detectors: + fd.set_command_executor(command_executor=self._command_executor) + + self._initialized = False + self._hc_lock = threading.RLock() + self._bg_scheduler = BackgroundScheduler() + + def _initialize(self): + """ + Perform initialization of databases to define their initial state. + """ + + def raise_exception_on_failed_hc(error): + raise error + + # Initial databases check to define initial state + self._check_databases_health(on_error=raise_exception_on_failed_hc) + + # Starts recurring health checks on the background. + self._bg_scheduler.run_recurring( + self._health_check_interval, + self._check_databases_health, + ) + + is_active_db_found = False + + for database, weight in self._databases: + # Set on state changed callback for each circuit. + database.circuit.on_state_changed(self._on_circuit_state_change_callback) + + # Set states according to a weights and circuit state + if database.circuit.state == CBState.CLOSED and not is_active_db_found: + database.state = DBState.ACTIVE + self._command_executor.active_database = database + is_active_db_found = True + elif database.circuit.state == CBState.CLOSED and is_active_db_found: + database.state = DBState.PASSIVE + else: + database.state = DBState.DISCONNECTED + + if not is_active_db_found: + raise NoValidDatabaseException('Initial connection failed - no active database found') + + self._initialized = True + + def get_databases(self) -> Databases: + """ + Returns a sorted (by weight) list of all databases. + """ + return self._databases + + def set_active_database(self, database: AbstractDatabase) -> None: + """ + Promote one of the existing databases to become an active. 
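A minimal client-construction sketch using only configuration fields defined in this patch; the endpoints are placeholders:

    from redis.multidb.client import MultiDBClient
    from redis.multidb.config import MultiDbConfig, DatabaseConfig

    config = MultiDbConfig(
        databases_config=[
            DatabaseConfig(from_url="redis://db-east.example.com:6379", weight=1.0),
            DatabaseConfig(from_url="redis://db-west.example.com:6379", weight=0.5),
        ],
    )
    client = MultiDBClient(config)
    # The first command triggers _initialize(): initial health checks run and
    # the highest-weighted healthy database becomes ACTIVE.
    client.set("key", "value")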
+ """ + exists = None + + for existing_db, _ in self._databases: + if existing_db == database: + exists = True + break + + if not exists: + raise ValueError('Given database is not a member of database list') + + self._check_db_health(database) + + if database.circuit.state == CBState.CLOSED: + highest_weighted_db, _ = self._databases.get_top_n(1)[0] + highest_weighted_db.state = DBState.PASSIVE + database.state = DBState.ACTIVE + self._command_executor.active_database = database + return + + raise NoValidDatabaseException('Cannot set active database, database is unhealthy') + + def add_database(self, database: AbstractDatabase): + """ + Adds a new database to the database list. + """ + for existing_db, _ in self._databases: + if existing_db == database: + raise ValueError('Given database already exists') + + self._check_db_health(database) + + highest_weighted_db, highest_weight = self._databases.get_top_n(1)[0] + self._databases.add(database, database.weight) + self._change_active_database(database, highest_weighted_db) + + def _change_active_database(self, new_database: AbstractDatabase, highest_weight_database: AbstractDatabase): + if new_database.weight > highest_weight_database.weight and new_database.circuit.state == CBState.CLOSED: + new_database.state = DBState.ACTIVE + self._command_executor.active_database = new_database + highest_weight_database.state = DBState.PASSIVE + + def remove_database(self, database: Database): + """ + Removes a database from the database list. + """ + weight = self._databases.remove(database) + highest_weighted_db, highest_weight = self._databases.get_top_n(1)[0] + + if highest_weight <= weight and highest_weighted_db.circuit.state == CBState.CLOSED: + highest_weighted_db.state = DBState.ACTIVE + self._command_executor.active_database = highest_weighted_db + + def update_database_weight(self, database: AbstractDatabase, weight: float): + """ + Updates a database from the database list. + """ + exists = None + + for existing_db, _ in self._databases: + if existing_db == database: + exists = True + break + + if not exists: + raise ValueError('Given database is not a member of database list') + + highest_weighted_db, highest_weight = self._databases.get_top_n(1)[0] + self._databases.update_weight(database, weight) + database.weight = weight + self._change_active_database(database, highest_weighted_db) + + def add_failure_detector(self, failure_detector: FailureDetector): + """ + Adds a new failure detector to the database. + """ + self._failure_detectors.append(failure_detector) + + def add_health_check(self, healthcheck: HealthCheck): + """ + Adds a new health check to the database. + """ + with self._hc_lock: + self._health_checks.append(healthcheck) + + def execute_command(self, *args, **options): + """ + Executes a single command and return its result. + """ + if not self._initialized: + self._initialize() + + return self._command_executor.execute_command(*args, **options) + + def _check_db_health(self, database: AbstractDatabase, on_error: Callable[[Exception], None] = None) -> None: + """ + Runs health checks on the given database until first failure. 
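A sketch of runtime reconfiguration through the CRUD methods above, reusing `client` from the earlier construction sketch:

    databases = client.get_databases()      # WeightedList, highest weight first
    db, weight = databases.get_top_n(1)[0]  # (database, weight) tuple
    client.update_database_weight(db, weight + 0.1)
    client.set_active_database(db)          # health-checked before promotion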
+ """ + is_healthy = True + + with self._hc_lock: + # Health check will setup circuit state + for health_check in self._health_checks: + if not is_healthy: + # If one of the health checks failed, it's considered unhealthy + break + + try: + is_healthy = health_check.check_health(database) + + if not is_healthy and database.circuit.state != CBState.OPEN: + database.circuit.state = CBState.OPEN + elif is_healthy and database.circuit.state != CBState.CLOSED: + database.circuit.state = CBState.CLOSED + except (ConnectionError, TimeoutError, socket.timeout) as e: + if database.circuit.state != CBState.OPEN: + database.circuit.state = CBState.OPEN + is_healthy = False + + if on_error: + on_error(e) + + + def _check_databases_health(self, on_error: Callable[[Exception], None] = None): + """ + Runs health checks as a recurring task. + """ + for database, _ in self._databases: + self._check_db_health(database, on_error) + + def _on_circuit_state_change_callback(self, circuit: CircuitBreaker, old_state: CBState, new_state: CBState): + if new_state == CBState.HALF_OPEN: + self._check_db_health(circuit.database) + return + + if old_state == CBState.CLOSED and new_state == CBState.OPEN: + self._bg_scheduler.run_once(DEFAULT_GRACE_PERIOD, _half_open_circuit, circuit) + +def _half_open_circuit(circuit: CircuitBreaker): + circuit.state = CBState.HALF_OPEN \ No newline at end of file diff --git a/redis/multidb/command_executor.py b/redis/multidb/command_executor.py new file mode 100644 index 0000000000..0783f6da82 --- /dev/null +++ b/redis/multidb/command_executor.py @@ -0,0 +1,184 @@ +from abc import ABC, abstractmethod +from datetime import datetime, timedelta +from typing import List, Union, Optional + +from redis.event import EventDispatcherInterface, OnCommandFailEvent +from redis.multidb.config import DEFAULT_AUTO_FALLBACK_INTERVAL +from redis.multidb.database import Database, AbstractDatabase, Databases +from redis.multidb.circuit import State as CBState +from redis.multidb.event import RegisterCommandFailure +from redis.multidb.failover import FailoverStrategy +from redis.multidb.failure_detector import FailureDetector +from redis.retry import Retry + + +class CommandExecutor(ABC): + + @property + @abstractmethod + def failure_detectors(self) -> List[FailureDetector]: + """Returns a list of failure detectors.""" + pass + + @abstractmethod + def add_failure_detector(self, failure_detector: FailureDetector) -> None: + """Adds new failure detector to the list of failure detectors.""" + pass + + @property + @abstractmethod + def databases(self) -> Databases: + """Returns a list of databases.""" + pass + + @property + @abstractmethod + def active_database(self) -> Union[Database, None]: + """Returns currently active database.""" + pass + + @active_database.setter + @abstractmethod + def active_database(self, database: AbstractDatabase) -> None: + """Sets currently active database.""" + pass + + @property + @abstractmethod + def failover_strategy(self) -> FailoverStrategy: + """Returns failover strategy.""" + pass + + @property + @abstractmethod + def auto_fallback_interval(self) -> float: + """Returns auto-fallback interval.""" + pass + + @auto_fallback_interval.setter + @abstractmethod + def auto_fallback_interval(self, auto_fallback_interval: float) -> None: + """Sets auto-fallback interval.""" + pass + + @property + @abstractmethod + def command_retry(self) -> Retry: + """Returns command retry object.""" + pass + + @abstractmethod + def execute_command(self, *args, **options): + """Executes a command 
and returns the result.""" + pass + + +class DefaultCommandExecutor(CommandExecutor): + + def __init__( + self, + failure_detectors: List[FailureDetector], + databases: Databases, + command_retry: Retry, + failover_strategy: FailoverStrategy, + event_dispatcher: EventDispatcherInterface, + auto_fallback_interval: float = DEFAULT_AUTO_FALLBACK_INTERVAL, + ): + """ + :param failure_detectors: List of failure detectors. + :param databases: List of databases. + :param failover_strategy: Strategy that defines the failover logic. + :param event_dispatcher: Event dispatcher. + :param auto_fallback_interval: Interval between fallback attempts. Fallback to a new database according to + failover_strategy. + """ + self._failure_detectors = failure_detectors + self._databases = databases + self._command_retry = command_retry + self._failover_strategy = failover_strategy + self._event_dispatcher = event_dispatcher + self._auto_fallback_interval = auto_fallback_interval + self._next_fallback_attempt: datetime + self._active_database: Union[Database, None] = None + self._setup_event_dispatcher() + self._schedule_next_fallback() + + @property + def failure_detectors(self) -> List[FailureDetector]: + return self._failure_detectors + + def add_failure_detector(self, failure_detector: FailureDetector) -> None: + self._failure_detectors.append(failure_detector) + + @property + def databases(self) -> Databases: + return self._databases + + @property + def command_retry(self) -> Retry: + return self._command_retry + + @property + def active_database(self) -> Optional[AbstractDatabase]: + return self._active_database + + @active_database.setter + def active_database(self, database: AbstractDatabase) -> None: + self._active_database = database + + @property + def failover_strategy(self) -> FailoverStrategy: + return self._failover_strategy + + @property + def auto_fallback_interval(self) -> float: + return self._auto_fallback_interval + + @auto_fallback_interval.setter + def auto_fallback_interval(self, auto_fallback_interval: int) -> None: + self._auto_fallback_interval = auto_fallback_interval + + def execute_command(self, *args, **options): + self._check_active_database() + + return self._command_retry.call_with_retry( + lambda: self._execute_command(*args, **options), + lambda error: self._on_command_fail(error, *args), + ) + + def _execute_command(self, *args, **options): + self._check_active_database() + return self._active_database.client.execute_command(*args, **options) + + def _on_command_fail(self, error, *args): + self._event_dispatcher.dispatch(OnCommandFailEvent(args, error)) + + def _check_active_database(self): + """ + Checks if active a database needs to be updated. + """ + if ( + self._active_database is None + or self._active_database.circuit.state != CBState.CLOSED + or ( + self._auto_fallback_interval != DEFAULT_AUTO_FALLBACK_INTERVAL + and self._next_fallback_attempt <= datetime.now() + ) + ): + self._active_database = self._failover_strategy.database + self._schedule_next_fallback() + + def _schedule_next_fallback(self) -> None: + if self._auto_fallback_interval == DEFAULT_AUTO_FALLBACK_INTERVAL: + return + + self._next_fallback_attempt = datetime.now() + timedelta(seconds=self._auto_fallback_interval) + + def _setup_event_dispatcher(self): + """ + Registers command failure event listener. 
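A wiring sketch for the executor, assuming `databases` was built by MultiDbConfig.databases(); the retry settings here are arbitrary:

    from redis.backoff import NoBackoff
    from redis.event import EventDispatcher
    from redis.multidb.failover import WeightBasedFailoverStrategy
    from redis.retry import Retry

    strategy = WeightBasedFailoverStrategy(retry=Retry(retries=0, backoff=NoBackoff()))
    strategy.set_databases(databases)
    executor = DefaultCommandExecutor(
        failure_detectors=[],               # empty for brevity
        databases=databases,
        command_retry=Retry(retries=3, backoff=NoBackoff()),
        failover_strategy=strategy,
        event_dispatcher=EventDispatcher(),
    )
    # Retried via command_retry; each failed attempt dispatches OnCommandFailEvent.
    executor.execute_command("PING")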
+ """ + event_listener = RegisterCommandFailure(self._failure_detectors) + self._event_dispatcher.register_listeners({ + OnCommandFailEvent: [event_listener], + }) \ No newline at end of file diff --git a/redis/multidb/config.py b/redis/multidb/config.py new file mode 100644 index 0000000000..64ad7c9052 --- /dev/null +++ b/redis/multidb/config.py @@ -0,0 +1,140 @@ +from dataclasses import dataclass, field +from typing import List, Type, Union + +import pybreaker +from typing_extensions import Optional + +from redis import Redis, ConnectionPool +from redis.asyncio import RedisCluster +from redis.backoff import ExponentialWithJitterBackoff, AbstractBackoff, NoBackoff +from redis.data_structure import WeightedList +from redis.event import EventDispatcher, EventDispatcherInterface +from redis.multidb.circuit import CircuitBreaker, PBCircuitBreakerAdapter +from redis.multidb.database import Database, Databases +from redis.multidb.failure_detector import FailureDetector, CommandFailureDetector +from redis.multidb.healthcheck import HealthCheck, EchoHealthCheck +from redis.multidb.failover import FailoverStrategy, WeightBasedFailoverStrategy +from redis.retry import Retry + +DEFAULT_GRACE_PERIOD = 5.0 +DEFAULT_HEALTH_CHECK_INTERVAL = 5 +DEFAULT_HEALTH_CHECK_RETRIES = 3 +DEFAULT_HEALTH_CHECK_BACKOFF = ExponentialWithJitterBackoff(cap=10) +DEFAULT_FAILURES_THRESHOLD = 3 +DEFAULT_FAILURES_DURATION = 2 +DEFAULT_FAILOVER_RETRIES = 3 +DEFAULT_FAILOVER_BACKOFF = ExponentialWithJitterBackoff(cap=3) +DEFAULT_AUTO_FALLBACK_INTERVAL = -1 + +def default_event_dispatcher() -> EventDispatcherInterface: + return EventDispatcher() + +@dataclass +class DatabaseConfig: + weight: float = 1.0 + client_kwargs: dict = field(default_factory=dict) + from_url: Optional[str] = None + from_pool: Optional[ConnectionPool] = None + circuit: Optional[CircuitBreaker] = None + grace_period: float = DEFAULT_GRACE_PERIOD + + def default_circuit_breaker(self) -> CircuitBreaker: + circuit_breaker = pybreaker.CircuitBreaker(reset_timeout=self.grace_period) + return PBCircuitBreakerAdapter(circuit_breaker) + +@dataclass +class MultiDbConfig: + """ + Configuration class for managing multiple database connections in a resilient and fail-safe manner. + + Attributes: + databases_config: A list of database configurations. + client_class: The client class used to manage database connections. + command_retry: Retry strategy for executing database commands. + failure_detectors: Optional list of failure detectors for monitoring database failures. + failure_threshold: Threshold for determining database failure. + failures_interval: Time interval for tracking database failures. + health_checks: Optional list of health checks performed on databases. + health_check_interval: Time interval for executing health checks. + health_check_retries: Number of retry attempts for performing health checks. + health_check_backoff: Backoff strategy for health check retries. + failover_strategy: Optional strategy for handling database failover scenarios. + failover_retries: Number of retries allowed for failover operations. + failover_backoff: Backoff strategy for failover retries. + auto_fallback_interval: Time interval to trigger automatic fallback. + event_dispatcher: Interface for dispatching events related to database operations. + + Methods: + databases: + Retrieves a collection of database clients managed by weighted configurations. 
+ Initializes database clients based on the provided configuration and removes + redundant retry objects for lower-level clients to rely on global retry logic. + + default_failure_detectors: + Returns the default list of failure detectors used to monitor database failures. + + default_health_checks: + Returns the default list of health checks used to monitor database health + with specific retry and backoff strategies. + + default_failover_strategy: + Provides the default failover strategy used for handling failover scenarios + with defined retry and backoff configurations. + """ + databases_config: List[DatabaseConfig] + client_class: Type[Union[Redis, RedisCluster]] = Redis + command_retry: Retry = Retry( + backoff=ExponentialWithJitterBackoff(base=1, cap=10), retries=3 + ) + failure_detectors: Optional[List[FailureDetector]] = None + failure_threshold: int = DEFAULT_FAILURES_THRESHOLD + failures_interval: float = DEFAULT_FAILURES_DURATION + health_checks: Optional[List[HealthCheck]] = None + health_check_interval: float = DEFAULT_HEALTH_CHECK_INTERVAL + health_check_retries: int = DEFAULT_HEALTH_CHECK_RETRIES + health_check_backoff: AbstractBackoff = DEFAULT_HEALTH_CHECK_BACKOFF + failover_strategy: Optional[FailoverStrategy] = None + failover_retries: int = DEFAULT_FAILOVER_RETRIES + failover_backoff: AbstractBackoff = DEFAULT_FAILOVER_BACKOFF + auto_fallback_interval: float = DEFAULT_AUTO_FALLBACK_INTERVAL + event_dispatcher: EventDispatcherInterface = field(default_factory=default_event_dispatcher) + + def databases(self) -> Databases: + databases = WeightedList() + + for database_config in self.databases_config: + # The retry object is not used in the lower level clients, so we can safely remove it. + # We rely on command_retry in terms of global retries. 
+ database_config.client_kwargs.update({"retry": Retry(retries=0, backoff=NoBackoff())}) + + if database_config.from_url: + client = self.client_class.from_url(database_config.from_url, **database_config.client_kwargs) + elif database_config.from_pool: + database_config.from_pool.set_retry(Retry(retries=0, backoff=NoBackoff())) + client = self.client_class.from_pool(connection_pool=database_config.from_pool) + else: + client = self.client_class(**database_config.client_kwargs) + + circuit = database_config.default_circuit_breaker() \ + if database_config.circuit is None else database_config.circuit + databases.add( + Database(client=client, circuit=circuit, weight=database_config.weight), + database_config.weight + ) + + return databases + + def default_failure_detectors(self) -> List[FailureDetector]: + return [ + CommandFailureDetector(threshold=self.failure_threshold, duration=self.failures_interval), + ] + + def default_health_checks(self) -> List[HealthCheck]: + return [ + EchoHealthCheck(retry=Retry(retries=self.health_check_retries, backoff=self.health_check_backoff)), + ] + + def default_failover_strategy(self) -> FailoverStrategy: + return WeightBasedFailoverStrategy( + retry=Retry(retries=self.failover_retries, backoff=self.failover_backoff), + ) diff --git a/redis/multidb/database.py b/redis/multidb/database.py new file mode 100644 index 0000000000..15db52e909 --- /dev/null +++ b/redis/multidb/database.py @@ -0,0 +1,118 @@ +import redis +from abc import ABC, abstractmethod +from enum import Enum +from typing import Union + +from redis import RedisCluster +from redis.data_structure import WeightedList +from redis.multidb.circuit import CircuitBreaker +from redis.typing import Number + + +class State(Enum): + ACTIVE = 0 + PASSIVE = 1 + DISCONNECTED = 2 + +class AbstractDatabase(ABC): + @property + @abstractmethod + def client(self) -> Union[redis.Redis, RedisCluster]: + """The underlying redis client.""" + pass + + @client.setter + @abstractmethod + def client(self, client: Union[redis.Redis, RedisCluster]): + """Set the underlying redis client.""" + pass + + @property + @abstractmethod + def weight(self) -> float: + """The weight of this database in compare to others. Used to determine the database failover to.""" + pass + + @weight.setter + @abstractmethod + def weight(self, weight: float): + """Set the weight of this database in compare to others.""" + pass + + @property + @abstractmethod + def state(self) -> State: + """The state of the current database.""" + pass + + @state.setter + @abstractmethod + def state(self, state: State): + """Set the state of the current database.""" + pass + + @property + @abstractmethod + def circuit(self) -> CircuitBreaker: + """Circuit breaker for the current database.""" + pass + + @circuit.setter + @abstractmethod + def circuit(self, circuit: CircuitBreaker): + """Set the circuit breaker for the current database.""" + pass + +Databases = WeightedList[tuple[AbstractDatabase, Number]] + +class Database(AbstractDatabase): + def __init__( + self, + client: Union[redis.Redis, RedisCluster], + circuit: CircuitBreaker, + weight: float, + state: State = State.DISCONNECTED, + ): + """ + param: client: Client instance for communication with the database. + param: circuit: Circuit breaker for the current database. + param: weight: Weight of current database. Database with the highest weight becomes Active. + param: state: State of the current database. 
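Constructing a standalone Database entry looks like this (a sketch; the endpoint is a placeholder):

    import pybreaker
    from redis import Redis
    from redis.multidb.circuit import PBCircuitBreakerAdapter
    from redis.multidb.database import Database, State

    db = Database(
        client=Redis.from_url("redis://standby.example.com:6379"),
        circuit=PBCircuitBreakerAdapter(pybreaker.CircuitBreaker(reset_timeout=5)),
        weight=0.8,
    )
    assert db.state == State.DISCONNECTED   # default until a health check passes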
+ """ + self._client = client + self._cb = circuit + self._cb.database = self + self._weight = weight + self._state = state + + @property + def client(self) -> Union[redis.Redis, RedisCluster]: + return self._client + + @client.setter + def client(self, client: Union[redis.Redis, RedisCluster]): + self._client = client + + @property + def weight(self) -> float: + return self._weight + + @weight.setter + def weight(self, weight: float): + self._weight = weight + + @property + def state(self) -> State: + return self._state + + @state.setter + def state(self, state: State): + self._state = state + + @property + def circuit(self) -> CircuitBreaker: + return self._cb + + @circuit.setter + def circuit(self, circuit: CircuitBreaker): + self._cb = circuit diff --git a/redis/multidb/event.py b/redis/multidb/event.py new file mode 100644 index 0000000000..3a5ed3ec24 --- /dev/null +++ b/redis/multidb/event.py @@ -0,0 +1,16 @@ +from typing import List + +from redis.event import EventListenerInterface, OnCommandFailEvent +from redis.multidb.failure_detector import FailureDetector + + +class RegisterCommandFailure(EventListenerInterface): + """ + Event listener that registers command failures and passing it to the failure detectors. + """ + def __init__(self, failure_detectors: List[FailureDetector]): + self._failure_detectors = failure_detectors + + def listen(self, event: OnCommandFailEvent) -> None: + for failure_detector in self._failure_detectors: + failure_detector.register_failure(event.exception, event.command) diff --git a/redis/multidb/exception.py b/redis/multidb/exception.py new file mode 100644 index 0000000000..80fdb9409a --- /dev/null +++ b/redis/multidb/exception.py @@ -0,0 +1,2 @@ +class NoValidDatabaseException(Exception): + pass \ No newline at end of file diff --git a/redis/multidb/failover.py b/redis/multidb/failover.py new file mode 100644 index 0000000000..a4c825aac1 --- /dev/null +++ b/redis/multidb/failover.py @@ -0,0 +1,54 @@ +from abc import ABC, abstractmethod + +from redis.data_structure import WeightedList +from redis.multidb.database import Databases +from redis.multidb.database import AbstractDatabase +from redis.multidb.circuit import State as CBState +from redis.multidb.exception import NoValidDatabaseException +from redis.retry import Retry + + +class FailoverStrategy(ABC): + + @property + @abstractmethod + def database(self) -> AbstractDatabase: + """Select the database according to the strategy.""" + pass + + @abstractmethod + def set_databases(self, databases: Databases) -> None: + """Set the databases strategy operates on.""" + pass + +class WeightBasedFailoverStrategy(FailoverStrategy): + """ + Choose the active database with the highest weight. 
+ """ + def __init__( + self, + retry: Retry + ): + self._retry = retry + self._retry.update_supported_errors([NoValidDatabaseException]) + self._databases = WeightedList() + + @property + def database(self) -> AbstractDatabase: + return self._retry.call_with_retry( + lambda: self._get_active_database(), + lambda _: self._dummy_fail() + ) + + def set_databases(self, databases: Databases) -> None: + self._databases = databases + + def _get_active_database(self) -> AbstractDatabase: + for database, _ in self._databases: + if database.circuit.state == CBState.CLOSED: + return database + + raise NoValidDatabaseException('No valid database available for communication') + + def _dummy_fail(self): + pass diff --git a/redis/multidb/failure_detector.py b/redis/multidb/failure_detector.py new file mode 100644 index 0000000000..50f1c839bd --- /dev/null +++ b/redis/multidb/failure_detector.py @@ -0,0 +1,76 @@ +import threading +from abc import ABC, abstractmethod +from datetime import datetime, timedelta +from typing import List, Type + +from typing_extensions import Optional + +from redis.multidb.circuit import State as CBState + + +class FailureDetector(ABC): + + @abstractmethod + def register_failure(self, exception: Exception, cmd: tuple) -> None: + """Register a failure that occurred during command execution.""" + pass + + @abstractmethod + def set_command_executor(self, command_executor) -> None: + """Set the command executor for this failure.""" + pass + +class CommandFailureDetector(FailureDetector): + """ + Detects a failure based on a threshold of failed commands during a specific period of time. + """ + + def __init__( + self, + threshold: int, + duration: float, + error_types: Optional[List[Type[Exception]]] = None, + ) -> None: + """ + :param threshold: Threshold of failed commands over the duration after which database will be marked as failed. + :param duration: Interval in seconds after which database will be marked as failed if threshold was exceeded. + :param error_types: List of exception that has to be registered. By default, all exceptions are registered. 
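A sketch of the detector's sliding window, reusing `executor` from the earlier wiring sketch: three matching failures inside the 2-second window open the active database's circuit:

    detector = CommandFailureDetector(threshold=3, duration=2.0,
                                      error_types=[ConnectionError])
    detector.set_command_executor(command_executor=executor)
    for _ in range(3):
        detector.register_failure(ConnectionError("reset"), ("GET", "key"))
    # executor.active_database.circuit.state is now OPEN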
+ """ + self._command_executor = None + self._threshold = threshold + self._duration = duration + self._error_types = error_types + self._start_time: datetime = datetime.now() + self._end_time: datetime = self._start_time + timedelta(seconds=self._duration) + self._failures_within_duration: List[tuple[datetime, tuple]] = [] + self._lock = threading.RLock() + + def register_failure(self, exception: Exception, cmd: tuple) -> None: + failure_time = datetime.now() + + if not self._start_time < failure_time < self._end_time: + self._reset() + + with self._lock: + if self._error_types: + if type(exception) in self._error_types: + self._failures_within_duration.append((datetime.now(), cmd)) + else: + self._failures_within_duration.append((datetime.now(), cmd)) + + self._check_threshold() + + def set_command_executor(self, command_executor) -> None: + self._command_executor = command_executor + + def _check_threshold(self): + with self._lock: + if len(self._failures_within_duration) >= self._threshold: + self._command_executor.active_database.circuit.state = CBState.OPEN + self._reset() + + def _reset(self) -> None: + with self._lock: + self._start_time = datetime.now() + self._end_time = self._start_time + timedelta(seconds=self._duration) + self._failures_within_duration = [] \ No newline at end of file diff --git a/redis/multidb/healthcheck.py b/redis/multidb/healthcheck.py new file mode 100644 index 0000000000..a96b9cf815 --- /dev/null +++ b/redis/multidb/healthcheck.py @@ -0,0 +1,57 @@ +from abc import abstractmethod, ABC +from redis.retry import Retry + + +class HealthCheck(ABC): + + @property + @abstractmethod + def retry(self) -> Retry: + """The retry object to use for health checks.""" + pass + + @abstractmethod + def check_health(self, database) -> bool: + """Function to determine the health status.""" + pass + +class AbstractHealthCheck(HealthCheck): + def __init__( + self, + retry: Retry, + ) -> None: + self._retry = retry + + @property + def retry(self) -> Retry: + return self._retry + + @abstractmethod + def check_health(self, database) -> bool: + pass + + +class EchoHealthCheck(AbstractHealthCheck): + def __init__( + self, + retry: Retry, + ) -> None: + """ + Check database healthiness by sending an echo request. 
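Custom checks follow the same shape by subclassing AbstractHealthCheck; a hypothetical PING-based variant, registered on the `client` from the earlier sketch:

    from redis.backoff import NoBackoff
    from redis.retry import Retry

    class PingHealthCheck(AbstractHealthCheck):
        """Hypothetical check: PING instead of ECHO."""
        def check_health(self, database) -> bool:
            return self._retry.call_with_retry(
                lambda: bool(database.client.execute_command("PING")),
                lambda _: None,
            )

    client.add_health_check(PingHealthCheck(retry=Retry(retries=3, backoff=NoBackoff())))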
+ """ + super().__init__( + retry=retry, + ) + def check_health(self, database) -> bool: + return self._retry.call_with_retry( + lambda: self._returns_echoed_message(database), + lambda _: self._dummy_fail() + ) + + def _returns_echoed_message(self, database) -> bool: + expected_message = ["healthcheck", b"healthcheck"] + actual_message = database.client.execute_command('ECHO', "healthcheck") + return actual_message in expected_message + + def _dummy_fail(self): + pass \ No newline at end of file diff --git a/tasks.py b/tasks.py index 52decf08e7..5c39eb0e43 100644 --- a/tasks.py +++ b/tasks.py @@ -58,11 +58,11 @@ def standalone_tests( if uvloop: run( - f"pytest {profile_arg} --protocol={protocol} {redis_mod_url} --cov=./ --cov-report=xml:coverage_resp{protocol}_uvloop.xml -m 'not onlycluster{extra_markers}' --uvloop --junit-xml=standalone-resp{protocol}-uvloop-results.xml" + f"pytest {profile_arg} --protocol={protocol} {redis_mod_url} --ignore=scenario --cov=./ --cov-report=xml:coverage_resp{protocol}_uvloop.xml -m 'not onlycluster{extra_markers}' --uvloop --junit-xml=standalone-resp{protocol}-uvloop-results.xml" ) else: run( - f"pytest {profile_arg} --protocol={protocol} {redis_mod_url} --cov=./ --cov-report=xml:coverage_resp{protocol}.xml -m 'not onlycluster{extra_markers}' --junit-xml=standalone-resp{protocol}-results.xml" + f"pytest {profile_arg} --protocol={protocol} {redis_mod_url} --ignore=scenario --cov=./ --cov-report=xml:coverage_resp{protocol}.xml -m 'not onlycluster{extra_markers}' --junit-xml=standalone-resp{protocol}-results.xml" ) @@ -74,11 +74,11 @@ def cluster_tests(c, uvloop=False, protocol=2, profile=False): cluster_tls_url = "rediss://localhost:27379/0" if uvloop: run( - f"pytest {profile_arg} --protocol={protocol} --cov=./ --cov-report=xml:coverage_cluster_resp{protocol}_uvloop.xml -m 'not onlynoncluster and not redismod' --redis-url={cluster_url} --redis-ssl-url={cluster_tls_url} --junit-xml=cluster-resp{protocol}-uvloop-results.xml --uvloop" + f"pytest {profile_arg} --protocol={protocol} --ignore=scenario --cov=./ --cov-report=xml:coverage_cluster_resp{protocol}_uvloop.xml -m 'not onlynoncluster and not redismod' --redis-url={cluster_url} --redis-ssl-url={cluster_tls_url} --junit-xml=cluster-resp{protocol}-uvloop-results.xml --uvloop" ) else: run( - f"pytest {profile_arg} --protocol={protocol} --cov=./ --cov-report=xml:coverage_cluster_resp{protocol}.xml -m 'not onlynoncluster and not redismod' --redis-url={cluster_url} --redis-ssl-url={cluster_tls_url} --junit-xml=cluster-resp{protocol}-results.xml" + f"pytest {profile_arg} --protocol={protocol} --ignore=scenario --cov=./ --cov-report=xml:coverage_cluster_resp{protocol}.xml -m 'not onlynoncluster and not redismod' --redis-url={cluster_url} --redis-ssl-url={cluster_tls_url} --junit-xml=cluster-resp{protocol}-results.xml" ) diff --git a/tests/conftest.py b/tests/conftest.py index 7eaccb1acb..fc316ea720 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -25,6 +25,7 @@ ) from redis.connection import Connection, ConnectionInterface, SSLConnection, parse_url from redis.credentials import CredentialProvider +from redis.event import EventDispatcherInterface from redis.exceptions import RedisClusterException from redis.retry import Retry from tests.ssl_utils import get_tls_certificates @@ -581,6 +582,11 @@ def mock_connection() -> ConnectionInterface: mock_connection = Mock(spec=ConnectionInterface) return mock_connection +@pytest.fixture() +def mock_ed() -> EventDispatcherInterface: + mock_ed = 
Mock(spec=EventDispatcherInterface) + return mock_ed + @pytest.fixture() def cache_key(request) -> CacheKey: diff --git a/tests/test_background.py b/tests/test_background.py new file mode 100644 index 0000000000..4b3a5377c1 --- /dev/null +++ b/tests/test_background.py @@ -0,0 +1,60 @@ +from time import sleep + +import pytest + +from redis.background import BackgroundScheduler + +class TestBackgroundScheduler: + def test_run_once(self): + execute_counter = 0 + one = 'arg1' + two = 9999 + + def callback(arg1: str, arg2: int): + nonlocal execute_counter + nonlocal one + nonlocal two + + execute_counter += 1 + + assert arg1 == one + assert arg2 == two + + scheduler = BackgroundScheduler() + scheduler.run_once(0.1, callback, one, two) + assert execute_counter == 0 + + sleep(0.15) + + assert execute_counter == 1 + + @pytest.mark.parametrize( + "interval,timeout,call_count", + [ + (0.012, 0.04, 3), + (0.035, 0.04, 1), + (0.045, 0.04, 0), + ] + ) + def test_run_recurring(self, interval, timeout, call_count): + execute_counter = 0 + one = 'arg1' + two = 9999 + + def callback(arg1: str, arg2: int): + nonlocal execute_counter + nonlocal one + nonlocal two + + execute_counter += 1 + + assert arg1 == one + assert arg2 == two + + scheduler = BackgroundScheduler() + scheduler.run_recurring(interval, callback, one, two) + assert execute_counter == 0 + + sleep(timeout) + + assert execute_counter == call_count \ No newline at end of file diff --git a/tests/test_data_structure.py b/tests/test_data_structure.py new file mode 100644 index 0000000000..31ac5c4316 --- /dev/null +++ b/tests/test_data_structure.py @@ -0,0 +1,79 @@ +import concurrent +import random +from concurrent.futures import ThreadPoolExecutor +from time import sleep + +from redis.data_structure import WeightedList + + +class TestWeightedList: + def test_add_items(self): + wlist = WeightedList() + + wlist.add('item1', 3.0) + wlist.add('item2', 2.0) + wlist.add('item3', 4.0) + wlist.add('item4', 4.0) + + assert wlist.get_top_n(4) == [('item3', 4.0), ('item4', 4.0), ('item1', 3.0), ('item2', 2.0)] + + def test_remove_items(self): + wlist = WeightedList() + wlist.add('item1', 3.0) + wlist.add('item2', 2.0) + wlist.add('item3', 4.0) + wlist.add('item4', 4.0) + + assert wlist.remove('item2') == 2.0 + assert wlist.remove('item4') == 4.0 + + assert wlist.get_top_n(4) == [('item3', 4.0), ('item1', 3.0)] + + def test_get_by_weight_range(self): + wlist = WeightedList() + wlist.add('item1', 3.0) + wlist.add('item2', 2.0) + wlist.add('item3', 4.0) + wlist.add('item4', 4.0) + + assert wlist.get_by_weight_range(2.0, 3.0) == [('item1', 3.0), ('item2', 2.0)] + + def test_update_weights(self): + wlist = WeightedList() + wlist.add('item1', 3.0) + wlist.add('item2', 2.0) + wlist.add('item3', 4.0) + wlist.add('item4', 4.0) + + assert wlist.get_top_n(4) == [('item3', 4.0), ('item4', 4.0), ('item1', 3.0), ('item2', 2.0)] + + wlist.update_weight('item2', 5.0) + + assert wlist.get_top_n(4) == [('item2', 5.0), ('item3', 4.0), ('item4', 4.0), ('item1', 3.0)] + + def test_thread_safety(self) -> None: + """Test thread safety with concurrent operations""" + wl = WeightedList() + + def worker(worker_id): + for i in range(100): + # Add items + wl.add(f"item_{worker_id}_{i}", random.uniform(0, 100)) + + # Read operations + try: + length = len(wl) + if length > 0: + top_items = wl.get_top_n(min(5, length)) + items_in_range = wl.get_by_weight_range(20, 80) + except Exception as e: + print(f"Error in worker {worker_id}: {e}") + + sleep(0.001) # Small delay + + # Run multiple 
workers concurrently + with ThreadPoolExecutor(max_workers=5) as executor: + futures = [executor.submit(worker, i) for i in range(5)] + concurrent.futures.wait(futures) + + assert len(wl) == 500 \ No newline at end of file diff --git a/tests/test_event.py b/tests/test_event.py new file mode 100644 index 0000000000..27526abeaf --- /dev/null +++ b/tests/test_event.py @@ -0,0 +1,55 @@ +from unittest.mock import Mock, AsyncMock + +from redis.event import EventListenerInterface, EventDispatcher, AsyncEventListenerInterface + + +class TestEventDispatcher: + def test_register_listeners(self): + mock_event = Mock(spec=object) + mock_event_listener = Mock(spec=EventListenerInterface) + listener_called = 0 + + def callback(event): + nonlocal listener_called + listener_called += 1 + + mock_event_listener.listen = callback + + # Register via constructor + dispatcher = EventDispatcher(event_listeners={type(mock_event): [mock_event_listener]}) + dispatcher.dispatch(mock_event) + + assert listener_called == 1 + + # Register additional listener for the same event + mock_another_event_listener = Mock(spec=EventListenerInterface) + mock_another_event_listener.listen = callback + dispatcher.register_listeners(event_listeners={type(mock_event): [mock_another_event_listener]}) + dispatcher.dispatch(mock_event) + + assert listener_called == 3 + + async def test_register_listeners_async(self): + mock_event = Mock(spec=object) + mock_event_listener = AsyncMock(spec=AsyncEventListenerInterface) + listener_called = 0 + + async def callback(event): + nonlocal listener_called + listener_called += 1 + + mock_event_listener.listen = callback + + # Register via constructor + dispatcher = EventDispatcher(event_listeners={type(mock_event): [mock_event_listener]}) + await dispatcher.dispatch_async(mock_event) + + assert listener_called == 1 + + # Register additional listener for the same event + mock_another_event_listener = Mock(spec=AsyncEventListenerInterface) + mock_another_event_listener.listen = callback + dispatcher.register_listeners(event_listeners={type(mock_event): [mock_another_event_listener]}) + await dispatcher.dispatch_async(mock_event) + + assert listener_called == 3 \ No newline at end of file diff --git a/tests/test_multidb/__init__.py b/tests/test_multidb/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/test_multidb/conftest.py b/tests/test_multidb/conftest.py new file mode 100644 index 0000000000..ad2057a118 --- /dev/null +++ b/tests/test_multidb/conftest.py @@ -0,0 +1,112 @@ +from unittest.mock import Mock + +import pytest + +from redis import Redis +from redis.data_structure import WeightedList +from redis.multidb.circuit import CircuitBreaker, State as CBState +from redis.multidb.config import MultiDbConfig, DatabaseConfig, DEFAULT_HEALTH_CHECK_INTERVAL, \ + DEFAULT_AUTO_FALLBACK_INTERVAL +from redis.multidb.database import Database, State, Databases +from redis.multidb.failover import FailoverStrategy +from redis.multidb.failure_detector import FailureDetector +from redis.multidb.healthcheck import HealthCheck +from tests.conftest import mock_ed + + +@pytest.fixture() +def mock_client() -> Redis: + return Mock(spec=Redis) + +@pytest.fixture() +def mock_cb() -> CircuitBreaker: + return Mock(spec=CircuitBreaker) + +@pytest.fixture() +def mock_fd() -> FailureDetector: + return Mock(spec=FailureDetector) + +@pytest.fixture() +def mock_fs() -> FailoverStrategy: + return Mock(spec=FailoverStrategy) + +@pytest.fixture() +def mock_hc() -> HealthCheck: + return 
Mock(spec=HealthCheck) + +@pytest.fixture() +def mock_db(request) -> Database: + db = Mock(spec=Database) + db.weight = request.param.get("weight", 1.0) + db.state = request.param.get("state", State.ACTIVE) + db.client = Mock(spec=Redis) + + cb = request.param.get("circuit", {}) + mock_cb = Mock(spec=CircuitBreaker) + mock_cb.grace_period = cb.get("grace_period", 1.0) + mock_cb.state = cb.get("state", CBState.CLOSED) + + db.circuit = mock_cb + return db + +@pytest.fixture() +def mock_db1(request) -> Database: + db = Mock(spec=Database) + db.weight = request.param.get("weight", 1.0) + db.state = request.param.get("state", State.ACTIVE) + db.client = Mock(spec=Redis) + + cb = request.param.get("circuit", {}) + mock_cb = Mock(spec=CircuitBreaker) + mock_cb.grace_period = cb.get("grace_period", 1.0) + mock_cb.state = cb.get("state", CBState.CLOSED) + + db.circuit = mock_cb + return db + +@pytest.fixture() +def mock_db2(request) -> Database: + db = Mock(spec=Database) + db.weight = request.param.get("weight", 1.0) + db.state = request.param.get("state", State.ACTIVE) + db.client = Mock(spec=Redis) + + cb = request.param.get("circuit", {}) + mock_cb = Mock(spec=CircuitBreaker) + mock_cb.grace_period = cb.get("grace_period", 1.0) + mock_cb.state = cb.get("state", CBState.CLOSED) + + db.circuit = mock_cb + return db + +@pytest.fixture() +def mock_multi_db_config( + request, mock_fd, mock_fs, mock_hc, mock_ed +) -> MultiDbConfig: + hc_interval = request.param.get('hc_interval', None) + if hc_interval is None: + hc_interval = DEFAULT_HEALTH_CHECK_INTERVAL + + auto_fallback_interval = request.param.get('auto_fallback_interval', None) + if auto_fallback_interval is None: + auto_fallback_interval = DEFAULT_AUTO_FALLBACK_INTERVAL + + config = MultiDbConfig( + databases_config=[Mock(spec=DatabaseConfig)], + failure_detectors=[mock_fd], + health_checks=[mock_hc], + health_check_interval=hc_interval, + failover_strategy=mock_fs, + auto_fallback_interval=auto_fallback_interval, + event_dispatcher=mock_ed + ) + + return config + +def create_weighted_list(*databases) -> Databases: + dbs = WeightedList() + + for db in databases: + dbs.add(db, db.weight) + + return dbs \ No newline at end of file diff --git a/tests/test_multidb/test_circuit.py b/tests/test_multidb/test_circuit.py new file mode 100644 index 0000000000..7dc642373b --- /dev/null +++ b/tests/test_multidb/test_circuit.py @@ -0,0 +1,52 @@ +import pybreaker +import pytest + +from redis.multidb.circuit import PBCircuitBreakerAdapter, State as CbState, CircuitBreaker + + +class TestPBCircuitBreaker: + @pytest.mark.parametrize( + 'mock_db', + [ + {'weight': 0.7, 'circuit': {'state': CbState.CLOSED}}, + ], + indirect=True, + ) + def test_cb_correctly_configured(self, mock_db): + pb_circuit = pybreaker.CircuitBreaker(reset_timeout=5) + adapter = PBCircuitBreakerAdapter(cb=pb_circuit) + assert adapter.state == CbState.CLOSED + + adapter.state = CbState.OPEN + assert adapter.state == CbState.OPEN + + adapter.state = CbState.HALF_OPEN + assert adapter.state == CbState.HALF_OPEN + + adapter.state = CbState.CLOSED + assert adapter.state == CbState.CLOSED + + assert adapter.grace_period == 5 + adapter.grace_period = 10 + + assert adapter.grace_period == 10 + + adapter.database = mock_db + assert adapter.database == mock_db + + def test_cb_executes_callback_on_state_changed(self): + pb_circuit = pybreaker.CircuitBreaker(reset_timeout=5) + adapter = PBCircuitBreakerAdapter(cb=pb_circuit) + called_count = 0 + + def callback(cb: CircuitBreaker, old_state: CbState, 
new_state: CbState): + nonlocal called_count + assert old_state == CbState.CLOSED + assert new_state == CbState.HALF_OPEN + assert isinstance(cb, PBCircuitBreakerAdapter) + called_count += 1 + + adapter.on_state_changed(callback) + adapter.state = CbState.HALF_OPEN + + assert called_count == 1 \ No newline at end of file diff --git a/tests/test_multidb/test_client.py b/tests/test_multidb/test_client.py new file mode 100644 index 0000000000..b94c4ce61e --- /dev/null +++ b/tests/test_multidb/test_client.py @@ -0,0 +1,584 @@ +from time import sleep +from unittest.mock import patch, Mock + +import pybreaker +import pytest + +from redis.event import EventDispatcher, OnCommandFailEvent +from redis.multidb.circuit import State as CBState, PBCircuitBreakerAdapter +from redis.multidb.config import DEFAULT_HEALTH_CHECK_RETRIES, DEFAULT_HEALTH_CHECK_BACKOFF, DEFAULT_FAILOVER_RETRIES, \ + DEFAULT_FAILOVER_BACKOFF +from redis.multidb.database import State as DBState, AbstractDatabase +from redis.multidb.client import MultiDBClient +from redis.multidb.exception import NoValidDatabaseException +from redis.multidb.failover import WeightBasedFailoverStrategy +from redis.multidb.failure_detector import FailureDetector +from redis.multidb.healthcheck import HealthCheck, EchoHealthCheck +from redis.retry import Retry +from tests.test_multidb.conftest import create_weighted_list + + +class TestMultiDbClient: + @pytest.mark.parametrize( + 'mock_multi_db_config,mock_db, mock_db1, mock_db2', + [ + ( + {}, + {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, + ), + ], + indirect=True, + ) + def test_execute_command_against_correct_db_on_successful_initialization( + self, mock_multi_db_config, mock_db, mock_db1, mock_db2 + ): + databases = create_weighted_list(mock_db, mock_db1, mock_db2) + + with patch.object( + mock_multi_db_config, + 'databases', + return_value=databases + ): + mock_db1.client.execute_command.return_value = 'OK1' + + for hc in mock_multi_db_config.health_checks: + hc.check_health.return_value = True + + client = MultiDBClient(mock_multi_db_config) + assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 + assert client.set('key', 'value') == 'OK1' + + for hc in mock_multi_db_config.health_checks: + assert hc.check_health.call_count == 3 + + assert mock_db.circuit.state == CBState.CLOSED + assert mock_db1.circuit.state == CBState.CLOSED + assert mock_db2.circuit.state == CBState.CLOSED + + @pytest.mark.parametrize( + 'mock_multi_db_config,mock_db, mock_db1, mock_db2', + [ + ( + {}, + {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.7, 'circuit': {'state': CBState.OPEN}}, + ), + ], + indirect=True, + ) + def test_execute_command_against_correct_db_and_closed_circuit( + self, mock_multi_db_config, mock_db, mock_db1, mock_db2 + ): + databases = create_weighted_list(mock_db, mock_db1, mock_db2) + + with patch.object( + mock_multi_db_config, + 'databases', + return_value=databases + ): + mock_db1.client.execute_command.return_value = 'OK1' + + for hc in mock_multi_db_config.health_checks: + hc.check_health.side_effect = [False, True, True] + + client = MultiDBClient(mock_multi_db_config) + assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 + assert client.set('key', 'value') == 'OK1' + + for hc in mock_multi_db_config.health_checks: + assert 
hc.check_health.call_count == 3 + + assert mock_db.circuit.state == CBState.CLOSED + assert mock_db1.circuit.state == CBState.CLOSED + assert mock_db2.circuit.state == CBState.OPEN + + @pytest.mark.parametrize( + 'mock_multi_db_config,mock_db, mock_db1, mock_db2', + [ + ( + {}, + {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, + ), + ], + indirect=True, + ) + def test_execute_command_against_correct_db_on_background_health_check_determine_active_db_unhealthy( + self, mock_multi_db_config, mock_db, mock_db1, mock_db2 + ): + cb = PBCircuitBreakerAdapter(pybreaker.CircuitBreaker(reset_timeout=5)) + cb.database = mock_db + mock_db.circuit = cb + + cb1 = PBCircuitBreakerAdapter(pybreaker.CircuitBreaker(reset_timeout=5)) + cb1.database = mock_db1 + mock_db1.circuit = cb1 + + cb2 = PBCircuitBreakerAdapter(pybreaker.CircuitBreaker(reset_timeout=5)) + cb2.database = mock_db2 + mock_db2.circuit = cb2 + + databases = create_weighted_list(mock_db, mock_db1, mock_db2) + + with patch.object( + mock_multi_db_config, + 'databases', + return_value=databases + ): + mock_db.client.execute_command.side_effect = ['healthcheck', 'healthcheck', 'healthcheck', 'OK', 'error'] + mock_db1.client.execute_command.side_effect = ['healthcheck', 'OK1', 'error', 'error', 'healthcheck', 'OK1'] + mock_db2.client.execute_command.side_effect = ['healthcheck', 'healthcheck', 'OK2', 'error', 'error'] + mock_multi_db_config.health_check_interval = 0.1 + mock_multi_db_config.health_checks = [ + EchoHealthCheck( + retry=Retry(retries=DEFAULT_HEALTH_CHECK_RETRIES, backoff=DEFAULT_HEALTH_CHECK_BACKOFF) + ) + ] + mock_multi_db_config.failover_strategy = WeightBasedFailoverStrategy( + retry=Retry(retries=DEFAULT_FAILOVER_RETRIES, backoff=DEFAULT_FAILOVER_BACKOFF) + ) + + client = MultiDBClient(mock_multi_db_config) + assert client.set('key', 'value') == 'OK1' + sleep(0.15) + assert client.set('key', 'value') == 'OK2' + sleep(0.1) + assert client.set('key', 'value') == 'OK' + sleep(0.1) + assert client.set('key', 'value') == 'OK1' + + @pytest.mark.parametrize( + 'mock_multi_db_config,mock_db, mock_db1, mock_db2', + [ + ( + {}, + {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, + ), + ], + indirect=True, + ) + def test_execute_command_auto_fallback_to_highest_weight_db( + self, mock_multi_db_config, mock_db, mock_db1, mock_db2 + ): + databases = create_weighted_list(mock_db, mock_db1, mock_db2) + + with patch.object( + mock_multi_db_config, + 'databases', + return_value=databases + ): + mock_db.client.execute_command.side_effect = ['healthcheck', 'healthcheck', 'healthcheck', 'healthcheck', 'healthcheck'] + mock_db1.client.execute_command.side_effect = ['healthcheck', 'OK1', 'error', 'healthcheck', 'healthcheck', 'OK1'] + mock_db2.client.execute_command.side_effect = ['healthcheck', 'healthcheck', 'OK2', 'healthcheck', 'healthcheck', 'healthcheck'] + mock_multi_db_config.health_check_interval = 0.1 + mock_multi_db_config.auto_fallback_interval = 0.2 + mock_multi_db_config.health_checks = [ + EchoHealthCheck( + retry=Retry(retries=DEFAULT_HEALTH_CHECK_RETRIES, backoff=DEFAULT_HEALTH_CHECK_BACKOFF) + ) + ] + mock_multi_db_config.failover_strategy = WeightBasedFailoverStrategy( + retry=Retry(retries=DEFAULT_FAILOVER_RETRIES, backoff=DEFAULT_FAILOVER_BACKOFF) + ) + + client = 
MultiDBClient(mock_multi_db_config) + assert client.set('key', 'value') == 'OK1' + + assert mock_db.state == DBState.PASSIVE + assert mock_db1.state == DBState.ACTIVE + assert mock_db2.state == DBState.PASSIVE + + sleep(0.15) + + assert client.set('key', 'value') == 'OK2' + + assert mock_db.state == DBState.PASSIVE + assert mock_db1.state == DBState.ACTIVE + assert mock_db2.state == DBState.PASSIVE + + sleep(0.22) + + assert client.set('key', 'value') == 'OK1' + + assert mock_db.state == DBState.PASSIVE + assert mock_db1.state == DBState.ACTIVE + assert mock_db2.state == DBState.PASSIVE + + @pytest.mark.parametrize( + 'mock_multi_db_config,mock_db, mock_db1, mock_db2', + [ + ( + {}, + {'weight': 0.2, 'circuit': {'state': CBState.OPEN}}, + {'weight': 0.7, 'circuit': {'state': CBState.OPEN}}, + {'weight': 0.5, 'circuit': {'state': CBState.OPEN}}, + ), + ], + indirect=True, + ) + def test_execute_command_throws_exception_on_failed_initialization( + self, mock_multi_db_config, mock_db, mock_db1, mock_db2 + ): + databases = create_weighted_list(mock_db, mock_db1, mock_db2) + + with patch.object( + mock_multi_db_config, + 'databases', + return_value=databases + ): + for hc in mock_multi_db_config.health_checks: + hc.check_health.return_value = False + + client = MultiDBClient(mock_multi_db_config) + assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 + + with pytest.raises(NoValidDatabaseException, match='Initial connection failed - no active database found'): + client.set('key', 'value') + + for hc in mock_multi_db_config.health_checks: + assert hc.check_health.call_count == 3 + + assert mock_db.state == DBState.DISCONNECTED + assert mock_db1.state == DBState.DISCONNECTED + assert mock_db2.state == DBState.DISCONNECTED + + @pytest.mark.parametrize( + 'mock_multi_db_config,mock_db, mock_db1, mock_db2', + [ + ( + {}, + {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, + ), + ], + indirect=True, + ) + def test_add_database_throws_exception_on_same_database( + self, mock_multi_db_config, mock_db, mock_db1, mock_db2 + ): + databases = create_weighted_list(mock_db, mock_db1, mock_db2) + + with patch.object( + mock_multi_db_config, + 'databases', + return_value=databases + ): + for hc in mock_multi_db_config.health_checks: + hc.check_health.return_value = False + + client = MultiDBClient(mock_multi_db_config) + assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 + + with pytest.raises(ValueError, match='Given database already exists'): + client.add_database(mock_db) + + for hc in mock_multi_db_config.health_checks: + assert hc.check_health.call_count == 3 + + @pytest.mark.parametrize( + 'mock_multi_db_config,mock_db, mock_db1, mock_db2', + [ + ( + {}, + {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, + ), + ], + indirect=True, + ) + def test_add_database_makes_new_database_active( + self, mock_multi_db_config, mock_db, mock_db1, mock_db2 + ): + databases = create_weighted_list(mock_db, mock_db2) + + with patch.object( + mock_multi_db_config, + 'databases', + return_value=databases + ): + mock_db1.client.execute_command.return_value = 'OK1' + mock_db2.client.execute_command.return_value = 'OK2' + + for hc in mock_multi_db_config.health_checks: + hc.check_health.return_value = True + + client = 
MultiDBClient(mock_multi_db_config) + assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 + + assert client.set('key', 'value') == 'OK2' + + for hc in mock_multi_db_config.health_checks: + assert hc.check_health.call_count == 2 + + assert mock_db.state == DBState.PASSIVE + assert mock_db2.state == DBState.ACTIVE + + client.add_database(mock_db1) + + for hc in mock_multi_db_config.health_checks: + assert hc.check_health.call_count == 3 + + assert client.set('key', 'value') == 'OK1' + + assert mock_db.state == DBState.PASSIVE + assert mock_db1.state == DBState.ACTIVE + assert mock_db2.state == DBState.PASSIVE + + @pytest.mark.parametrize( + 'mock_multi_db_config,mock_db, mock_db1, mock_db2', + [ + ( + {}, + {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, + ), + ], + indirect=True, + ) + def test_remove_highest_weighted_database( + self, mock_multi_db_config, mock_db, mock_db1, mock_db2 + ): + databases = create_weighted_list(mock_db, mock_db1, mock_db2) + + with patch.object( + mock_multi_db_config, + 'databases', + return_value=databases + ): + mock_db1.client.execute_command.return_value = 'OK1' + mock_db2.client.execute_command.return_value = 'OK2' + + for hc in mock_multi_db_config.health_checks: + hc.check_health.return_value = True + + client = MultiDBClient(mock_multi_db_config) + assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 + + assert client.set('key', 'value') == 'OK1' + + for hc in mock_multi_db_config.health_checks: + assert hc.check_health.call_count == 3 + + assert mock_db.state == DBState.PASSIVE + assert mock_db1.state == DBState.ACTIVE + assert mock_db2.state == DBState.PASSIVE + + client.remove_database(mock_db1) + + assert client.set('key', 'value') == 'OK2' + + assert mock_db.state == DBState.PASSIVE + assert mock_db2.state == DBState.ACTIVE + + @pytest.mark.parametrize( + 'mock_multi_db_config,mock_db, mock_db1, mock_db2', + [ + ( + {}, + {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, + ), + ], + indirect=True, + ) + def test_update_database_weight_to_be_highest( + self, mock_multi_db_config, mock_db, mock_db1, mock_db2 + ): + databases = create_weighted_list(mock_db, mock_db1, mock_db2) + + with patch.object( + mock_multi_db_config, + 'databases', + return_value=databases + ): + mock_db1.client.execute_command.return_value = 'OK1' + mock_db2.client.execute_command.return_value = 'OK2' + + for hc in mock_multi_db_config.health_checks: + hc.check_health.return_value = True + + client = MultiDBClient(mock_multi_db_config) + assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 + + assert client.set('key', 'value') == 'OK1' + + for hc in mock_multi_db_config.health_checks: + assert hc.check_health.call_count == 3 + + assert mock_db.state == DBState.PASSIVE + assert mock_db1.state == DBState.ACTIVE + assert mock_db2.state == DBState.PASSIVE + + client.update_database_weight(mock_db2, 0.8) + assert mock_db2.weight == 0.8 + + assert client.set('key', 'value') == 'OK2' + + assert mock_db.state == DBState.PASSIVE + assert mock_db1.state == DBState.PASSIVE + assert mock_db2.state == DBState.ACTIVE + + @pytest.mark.parametrize( + 'mock_multi_db_config,mock_db, mock_db1, mock_db2', + [ + ( + {}, + {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.7, 
'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, + ), + ], + indirect=True, + ) + def test_add_new_failure_detector( + self, mock_multi_db_config, mock_db, mock_db1, mock_db2 + ): + databases = create_weighted_list(mock_db, mock_db1, mock_db2) + + with patch.object( + mock_multi_db_config, + 'databases', + return_value=databases + ): + mock_db1.client.execute_command.return_value = 'OK1' + mock_multi_db_config.event_dispatcher = EventDispatcher() + mock_fd = mock_multi_db_config.failure_detectors[0] + + # Event fired if command against mock_db1 would fail + command_fail_event = OnCommandFailEvent( + command=('SET', 'key', 'value'), + exception=Exception(), + ) + + for hc in mock_multi_db_config.health_checks: + hc.check_health.return_value = True + + client = MultiDBClient(mock_multi_db_config) + assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 + assert client.set('key', 'value') == 'OK1' + + for hc in mock_multi_db_config.health_checks: + assert hc.check_health.call_count == 3 + + # Simulate failing command events that lead to a failure detection + for i in range(5): + mock_multi_db_config.event_dispatcher.dispatch(command_fail_event) + + assert mock_fd.register_failure.call_count == 5 + + another_fd = Mock(spec=FailureDetector) + client.add_failure_detector(another_fd) + + # Simulate failing command events that lead to a failure detection + for i in range(5): + mock_multi_db_config.event_dispatcher.dispatch(command_fail_event) + + assert mock_fd.register_failure.call_count == 10 + assert another_fd.register_failure.call_count == 5 + + @pytest.mark.parametrize( + 'mock_multi_db_config,mock_db, mock_db1, mock_db2', + [ + ( + {}, + {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, + ), + ], + indirect=True, + ) + def test_add_new_health_check( + self, mock_multi_db_config, mock_db, mock_db1, mock_db2 + ): + databases = create_weighted_list(mock_db, mock_db1, mock_db2) + + with patch.object( + mock_multi_db_config, + 'databases', + return_value=databases + ): + mock_db1.client.execute_command.return_value = 'OK1' + + for hc in mock_multi_db_config.health_checks: + hc.check_health.return_value = True + + client = MultiDBClient(mock_multi_db_config) + assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 + assert client.set('key', 'value') == 'OK1' + + for hc in mock_multi_db_config.health_checks: + assert hc.check_health.call_count == 3 + + another_hc = Mock(spec=HealthCheck) + another_hc.check_health.return_value = True + + client.add_health_check(another_hc) + client._check_db_health(mock_db1) + + assert another_hc.check_health.call_count == 1 + + @pytest.mark.parametrize( + 'mock_multi_db_config,mock_db, mock_db1, mock_db2', + [ + ( + {}, + {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, + ), + ], + indirect=True, + ) + def test_set_active_database( + self, mock_multi_db_config, mock_db, mock_db1, mock_db2 + ): + databases = create_weighted_list(mock_db, mock_db1, mock_db2) + + with patch.object( + mock_multi_db_config, + 'databases', + return_value=databases + ): + mock_db1.client.execute_command.return_value = 'OK1' + mock_db.client.execute_command.return_value = 'OK' + + for hc in mock_multi_db_config.health_checks: + hc.check_health.return_value = True + + 
client = MultiDBClient(mock_multi_db_config) + assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 + assert client.set('key', 'value') == 'OK1' + + for hc in mock_multi_db_config.health_checks: + assert hc.check_health.call_count == 3 + + assert mock_db.state == DBState.PASSIVE + assert mock_db1.state == DBState.ACTIVE + assert mock_db2.state == DBState.PASSIVE + + client.set_active_database(mock_db) + assert client.set('key', 'value') == 'OK' + + assert mock_db.state == DBState.ACTIVE + assert mock_db1.state == DBState.PASSIVE + assert mock_db2.state == DBState.PASSIVE + + with pytest.raises(ValueError, match='Given database is not a member of database list'): + client.set_active_database(Mock(spec=AbstractDatabase)) + + for hc in mock_multi_db_config.health_checks: + hc.check_health.return_value = False + + with pytest.raises(NoValidDatabaseException, match='Cannot set active database, database is unhealthy'): + client.set_active_database(mock_db1) \ No newline at end of file diff --git a/tests/test_multidb/test_command_executor.py b/tests/test_multidb/test_command_executor.py new file mode 100644 index 0000000000..675f9d442f --- /dev/null +++ b/tests/test_multidb/test_command_executor.py @@ -0,0 +1,160 @@ +from time import sleep +from unittest.mock import PropertyMock + +import pytest + +from redis.exceptions import ConnectionError +from redis.backoff import NoBackoff +from redis.event import EventDispatcher +from redis.multidb.circuit import State as CBState +from redis.multidb.command_executor import DefaultCommandExecutor +from redis.multidb.failure_detector import CommandFailureDetector +from redis.retry import Retry +from tests.test_multidb.conftest import create_weighted_list + + +class TestDefaultCommandExecutor: + @pytest.mark.parametrize( + 'mock_db,mock_db1,mock_db2', + [ + ( + {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, + ), + ], + indirect=True, + ) + def test_execute_command_on_active_database(self, mock_db, mock_db1, mock_db2, mock_fd, mock_fs, mock_ed): + mock_db1.client.execute_command.return_value = 'OK1' + mock_db2.client.execute_command.return_value = 'OK2' + databases = create_weighted_list(mock_db, mock_db1, mock_db2) + + executor = DefaultCommandExecutor( + failure_detectors=[mock_fd], + databases=databases, + failover_strategy=mock_fs, + event_dispatcher=mock_ed, + command_retry=Retry(NoBackoff(), 0) + ) + + executor.active_database = mock_db1 + assert executor.execute_command('SET', 'key', 'value') == 'OK1' + + executor.active_database = mock_db2 + assert executor.execute_command('SET', 'key', 'value') == 'OK2' + assert mock_ed.register_listeners.call_count == 1 + + @pytest.mark.parametrize( + 'mock_db,mock_db1,mock_db2', + [ + ( + {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, + ), + ], + indirect=True, + ) + def test_execute_command_automatically_select_active_database( + self, mock_db, mock_db1, mock_db2, mock_fd, mock_fs, mock_ed + ): + mock_db1.client.execute_command.return_value = 'OK1' + mock_db2.client.execute_command.return_value = 'OK2' + mock_selector = PropertyMock(side_effect=[mock_db1, mock_db2]) + type(mock_fs).database = mock_selector + databases = create_weighted_list(mock_db, mock_db1, mock_db2) + + executor = DefaultCommandExecutor( + failure_detectors=[mock_fd], + 
databases=databases, + failover_strategy=mock_fs, + event_dispatcher=mock_ed, + command_retry=Retry(NoBackoff(), 0) + ) + + assert executor.execute_command('SET', 'key', 'value') == 'OK1' + mock_db1.circuit.state = CBState.OPEN + + assert executor.execute_command('SET', 'key', 'value') == 'OK2' + assert mock_ed.register_listeners.call_count == 1 + assert mock_selector.call_count == 2 + + @pytest.mark.parametrize( + 'mock_db,mock_db1,mock_db2', + [ + ( + {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, + ), + ], + indirect=True, + ) + def test_execute_command_fallback_to_another_db_after_fallback_interval( + self, mock_db, mock_db1, mock_db2, mock_fd, mock_fs, mock_ed + ): + mock_db1.client.execute_command.return_value = 'OK1' + mock_db2.client.execute_command.return_value = 'OK2' + mock_selector = PropertyMock(side_effect=[mock_db1, mock_db2, mock_db1]) + type(mock_fs).database = mock_selector + databases = create_weighted_list(mock_db, mock_db1, mock_db2) + + executor = DefaultCommandExecutor( + failure_detectors=[mock_fd], + databases=databases, + failover_strategy=mock_fs, + event_dispatcher=mock_ed, + auto_fallback_interval=0.1, + command_retry=Retry(NoBackoff(), 0) + ) + + assert executor.execute_command('SET', 'key', 'value') == 'OK1' + mock_db1.weight = 0.1 + sleep(0.15) + + assert executor.execute_command('SET', 'key', 'value') == 'OK2' + mock_db1.weight = 0.7 + sleep(0.15) + + assert executor.execute_command('SET', 'key', 'value') == 'OK1' + assert mock_ed.register_listeners.call_count == 1 + assert mock_selector.call_count == 3 + + @pytest.mark.parametrize( + 'mock_db,mock_db1,mock_db2', + [ + ( + {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, + ), + ], + indirect=True, + ) + def test_execute_command_fallback_to_another_db_after_failure_detection( + self, mock_db, mock_db1, mock_db2, mock_fs + ): + mock_db1.client.execute_command.side_effect = ['OK1', ConnectionError, ConnectionError, ConnectionError, 'OK1'] + mock_db2.client.execute_command.side_effect = ['OK2', ConnectionError, ConnectionError, ConnectionError] + mock_selector = PropertyMock(side_effect=[mock_db1, mock_db2, mock_db1]) + type(mock_fs).database = mock_selector + threshold = 3 + fd = CommandFailureDetector(threshold, 1) + ed = EventDispatcher() + databases = create_weighted_list(mock_db, mock_db1, mock_db2) + + executor = DefaultCommandExecutor( + failure_detectors=[fd], + databases=databases, + failover_strategy=mock_fs, + event_dispatcher=ed, + auto_fallback_interval=0.1, + command_retry=Retry(NoBackoff(), threshold), + ) + fd.set_command_executor(command_executor=executor) + + assert executor.execute_command('SET', 'key', 'value') == 'OK1' + assert executor.execute_command('SET', 'key', 'value') == 'OK2' + assert executor.execute_command('SET', 'key', 'value') == 'OK1' + assert mock_selector.call_count == 3 \ No newline at end of file diff --git a/tests/test_multidb/test_config.py b/tests/test_multidb/test_config.py new file mode 100644 index 0000000000..87aae701a9 --- /dev/null +++ b/tests/test_multidb/test_config.py @@ -0,0 +1,124 @@ +from unittest.mock import Mock +from redis.connection import ConnectionPool +from redis.multidb.circuit import CircuitBreaker, PBCircuitBreakerAdapter +from redis.multidb.config import MultiDbConfig, DEFAULT_HEALTH_CHECK_INTERVAL, \ + 
DEFAULT_AUTO_FALLBACK_INTERVAL, DatabaseConfig, DEFAULT_GRACE_PERIOD +from redis.multidb.database import Database +from redis.multidb.failure_detector import CommandFailureDetector, FailureDetector +from redis.multidb.healthcheck import EchoHealthCheck, HealthCheck +from redis.multidb.failover import WeightBasedFailoverStrategy, FailoverStrategy +from redis.retry import Retry + + +class TestMultiDbConfig: + def test_default_config(self): + db_configs = [ + DatabaseConfig(client_kwargs={'host': 'host1', 'port': 'port1'}, weight=1.0), + DatabaseConfig(client_kwargs={'host': 'host2', 'port': 'port2'}, weight=0.9), + DatabaseConfig(client_kwargs={'host': 'host3', 'port': 'port3'}, weight=0.8), + ] + + config = MultiDbConfig( + databases_config=db_configs + ) + + assert config.databases_config == db_configs + databases = config.databases() + assert len(databases) == 3 + + i = 0 + for db, weight in databases: + assert isinstance(db, Database) + assert weight == db_configs[i].weight + assert db.circuit.grace_period == DEFAULT_GRACE_PERIOD + assert db.client.get_retry() is not config.command_retry + i+=1 + + assert len(config.default_failure_detectors()) == 1 + assert isinstance(config.default_failure_detectors()[0], CommandFailureDetector) + assert len(config.default_health_checks()) == 1 + assert isinstance(config.default_health_checks()[0], EchoHealthCheck) + assert config.health_check_interval == DEFAULT_HEALTH_CHECK_INTERVAL + assert isinstance(config.default_failover_strategy(), WeightBasedFailoverStrategy) + assert config.auto_fallback_interval == DEFAULT_AUTO_FALLBACK_INTERVAL + assert isinstance(config.command_retry, Retry) + + def test_overridden_config(self): + grace_period = 2 + mock_connection_pools = [Mock(spec=ConnectionPool), Mock(spec=ConnectionPool), Mock(spec=ConnectionPool)] + mock_connection_pools[0].connection_kwargs = {} + mock_connection_pools[1].connection_kwargs = {} + mock_connection_pools[2].connection_kwargs = {} + mock_cb1 = Mock(spec=CircuitBreaker) + mock_cb1.grace_period = grace_period + mock_cb2 = Mock(spec=CircuitBreaker) + mock_cb2.grace_period = grace_period + mock_cb3 = Mock(spec=CircuitBreaker) + mock_cb3.grace_period = grace_period + mock_failure_detectors = [Mock(spec=FailureDetector), Mock(spec=FailureDetector)] + mock_health_checks = [Mock(spec=HealthCheck), Mock(spec=HealthCheck)] + health_check_interval = 10 + mock_failover_strategy = Mock(spec=FailoverStrategy) + auto_fallback_interval = 10 + db_configs = [ + DatabaseConfig( + client_kwargs={"connection_pool": mock_connection_pools[0]}, weight=1.0, circuit=mock_cb1 + ), + DatabaseConfig( + client_kwargs={"connection_pool": mock_connection_pools[1]}, weight=0.9, circuit=mock_cb2 + ), + DatabaseConfig( + client_kwargs={"connection_pool": mock_connection_pools[2]}, weight=0.8, circuit=mock_cb3 + ), + ] + + config = MultiDbConfig( + databases_config=db_configs, + failure_detectors=mock_failure_detectors, + health_checks=mock_health_checks, + health_check_interval=health_check_interval, + failover_strategy=mock_failover_strategy, + auto_fallback_interval=auto_fallback_interval, + ) + + assert config.databases_config == db_configs + databases = config.databases() + assert len(databases) == 3 + + i = 0 + for db, weight in databases: + assert isinstance(db, Database) + assert weight == db_configs[i].weight + assert db.client.connection_pool == mock_connection_pools[i] + assert db.circuit.grace_period == grace_period + i+=1 + + assert len(config.failure_detectors) == 2 + assert config.failure_detectors[0] == 
mock_failure_detectors[0] + assert config.failure_detectors[1] == mock_failure_detectors[1] + assert len(config.health_checks) == 2 + assert config.health_checks[0] == mock_health_checks[0] + assert config.health_checks[1] == mock_health_checks[1] + assert config.health_check_interval == health_check_interval + assert config.failover_strategy == mock_failover_strategy + assert config.auto_fallback_interval == auto_fallback_interval + +class TestDatabaseConfig: + def test_default_config(self): + config = DatabaseConfig(client_kwargs={'host': 'host1', 'port': 'port1'}, weight=1.0) + + assert config.client_kwargs == {'host': 'host1', 'port': 'port1'} + assert config.weight == 1.0 + assert isinstance(config.default_circuit_breaker(), PBCircuitBreakerAdapter) + + def test_overridden_config(self): + mock_connection_pool = Mock(spec=ConnectionPool) + mock_circuit = Mock(spec=CircuitBreaker) + + config = DatabaseConfig( + client_kwargs={'connection_pool': mock_connection_pool}, weight=1.0, circuit=mock_circuit + ) + + assert config.client_kwargs == {'connection_pool': mock_connection_pool} + assert config.weight == 1.0 + assert config.circuit == mock_circuit \ No newline at end of file diff --git a/tests/test_multidb/test_failover.py b/tests/test_multidb/test_failover.py new file mode 100644 index 0000000000..06390c4e2e --- /dev/null +++ b/tests/test_multidb/test_failover.py @@ -0,0 +1,117 @@ +from unittest.mock import PropertyMock + +import pytest + +from redis.backoff import NoBackoff, ExponentialBackoff +from redis.data_structure import WeightedList +from redis.multidb.circuit import State as CBState +from redis.multidb.exception import NoValidDatabaseException +from redis.multidb.failover import WeightBasedFailoverStrategy +from redis.retry import Retry + + +class TestWeightBasedFailoverStrategy: + @pytest.mark.parametrize( + 'mock_db,mock_db1,mock_db2', + [ + ( + {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, + ), + ( + {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.7, 'circuit': {'state': CBState.OPEN}}, + ), + ], + ids=['all closed - highest weight', 'highest weight - open'], + indirect=True, + ) + def test_get_valid_database(self, mock_db, mock_db1, mock_db2): + retry = Retry(NoBackoff(), 0) + databases = WeightedList() + databases.add(mock_db, mock_db.weight) + databases.add(mock_db1, mock_db1.weight) + databases.add(mock_db2, mock_db2.weight) + + failover_strategy = WeightBasedFailoverStrategy(retry=retry) + failover_strategy.set_databases(databases) + + assert failover_strategy.database == mock_db1 + + @pytest.mark.parametrize( + 'mock_db,mock_db1,mock_db2', + [ + ( + {'weight': 0.2, 'circuit': {'state': CBState.OPEN}}, + {'weight': 0.7, 'circuit': {'state': CBState.OPEN}}, + {'weight': 0.5, 'circuit': {'state': CBState.OPEN}}, + ), + ], + indirect=True, + ) + def test_get_valid_database_with_retries(self, mock_db, mock_db1, mock_db2): + state_mock = PropertyMock( + side_effect=[CBState.OPEN, CBState.OPEN, CBState.OPEN, CBState.CLOSED] + ) + type(mock_db.circuit).state = state_mock + + retry = Retry(ExponentialBackoff(cap=1), 3) + databases = WeightedList() + databases.add(mock_db, mock_db.weight) + databases.add(mock_db1, mock_db1.weight) + databases.add(mock_db2, mock_db2.weight) + failover_strategy = WeightBasedFailoverStrategy(retry=retry) + failover_strategy.set_databases(databases) 
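+        # For context, the selection rule these tests exercise, as a minimal
+        # commented sketch (hypothetical helper; assumes the WeightedList yields
+        # (database, weight) pairs in descending weight order, and that the real
+        # strategy wraps this lookup in the injected Retry policy):
+        #
+        #     def select_database(databases):
+        #         for database, _ in databases:
+        #             if database.circuit.state == CBState.CLOSED:
+        #                 return database
+        #         raise NoValidDatabaseException(
+        #             'No valid database available for communication'
+        #         )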
+ + assert failover_strategy.database == mock_db + assert state_mock.call_count == 4 + + @pytest.mark.parametrize( + 'mock_db,mock_db1,mock_db2', + [ + ( + {'weight': 0.2, 'circuit': {'state': CBState.OPEN}}, + {'weight': 0.7, 'circuit': {'state': CBState.OPEN}}, + {'weight': 0.5, 'circuit': {'state': CBState.OPEN}}, + ), + ], + indirect=True, + ) + def test_get_valid_database_throws_exception_with_retries(self, mock_db, mock_db1, mock_db2): + state_mock = PropertyMock( + side_effect=[CBState.OPEN, CBState.OPEN, CBState.OPEN, CBState.OPEN] + ) + type(mock_db.circuit).state = state_mock + + retry = Retry(ExponentialBackoff(cap=1), 3) + databases = WeightedList() + databases.add(mock_db, mock_db.weight) + databases.add(mock_db1, mock_db1.weight) + databases.add(mock_db2, mock_db2.weight) + failover_strategy = WeightBasedFailoverStrategy(retry=retry) + failover_strategy.set_databases(databases) + + with pytest.raises(NoValidDatabaseException, match='No valid database available for communication'): + assert failover_strategy.database + + assert state_mock.call_count == 4 + + @pytest.mark.parametrize( + 'mock_db,mock_db1,mock_db2', + [ + ( + {'weight': 0.2, 'circuit': {'state': CBState.OPEN}}, + {'weight': 0.7, 'circuit': {'state': CBState.OPEN}}, + {'weight': 0.5, 'circuit': {'state': CBState.OPEN}}, + ), + ], + indirect=True, + ) + def test_throws_exception_on_empty_databases(self, mock_db, mock_db1, mock_db2): + retry = Retry(NoBackoff(), 0) + failover_strategy = WeightBasedFailoverStrategy(retry=retry) + + with pytest.raises(NoValidDatabaseException, match='No valid database available for communication'): + assert failover_strategy.database \ No newline at end of file diff --git a/tests/test_multidb/test_failure_detector.py b/tests/test_multidb/test_failure_detector.py new file mode 100644 index 0000000000..86d6e1cd82 --- /dev/null +++ b/tests/test_multidb/test_failure_detector.py @@ -0,0 +1,148 @@ +from time import sleep +from unittest.mock import Mock + +import pytest + +from redis.multidb.command_executor import CommandExecutor +from redis.multidb.failure_detector import CommandFailureDetector +from redis.multidb.circuit import State as CBState +from redis.exceptions import ConnectionError + + +class TestCommandFailureDetector: + @pytest.mark.parametrize( + 'mock_db', + [ + {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, + ], + indirect=True, + ) + def test_failure_detector_open_circuit_on_threshold_exceed_and_interval_not_exceed(self, mock_db): + fd = CommandFailureDetector(5, 1) + mock_ce = Mock(spec=CommandExecutor) + mock_ce.active_database = mock_db + fd.set_command_executor(mock_ce) + assert mock_db.circuit.state == CBState.CLOSED + + fd.register_failure(Exception(), ('SET', 'key1', 'value1')) + fd.register_failure(Exception(), ('SET', 'key1', 'value1')) + fd.register_failure(Exception(), ('SET', 'key1', 'value1')) + fd.register_failure(Exception(), ('SET', 'key1', 'value1')) + fd.register_failure(Exception(), ('SET', 'key1', 'value1')) + + assert mock_db.circuit.state == CBState.OPEN + + @pytest.mark.parametrize( + 'mock_db', + [ + {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, + ], + indirect=True, + ) + def test_failure_detector_do_not_open_circuit_if_threshold_not_exceed_and_interval_not_exceed(self, mock_db): + fd = CommandFailureDetector(5, 1) + mock_ce = Mock(spec=CommandExecutor) + mock_ce.active_database = mock_db + fd.set_command_executor(mock_ce) + assert mock_db.circuit.state == CBState.CLOSED + + fd.register_failure(Exception(), ('SET', 'key1', 
'value1'))
+        fd.register_failure(Exception(), ('SET', 'key1', 'value1'))
+        fd.register_failure(Exception(), ('SET', 'key1', 'value1'))
+        fd.register_failure(Exception(), ('SET', 'key1', 'value1'))
+
+        assert mock_db.circuit.state == CBState.CLOSED
+
+    @pytest.mark.parametrize(
+        'mock_db',
+        [
+            {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}},
+        ],
+        indirect=True,
+    )
+    def test_failure_detector_do_not_open_circuit_on_threshold_exceed_and_interval_exceed(self, mock_db):
+        fd = CommandFailureDetector(5, 0.3)
+        mock_ce = Mock(spec=CommandExecutor)
+        mock_ce.active_database = mock_db
+        fd.set_command_executor(mock_ce)
+        assert mock_db.circuit.state == CBState.CLOSED
+
+        fd.register_failure(Exception(), ('SET', 'key1', 'value1'))
+        sleep(0.1)
+        fd.register_failure(Exception(), ('SET', 'key1', 'value1'))
+        sleep(0.1)
+        fd.register_failure(Exception(), ('SET', 'key1', 'value1'))
+        sleep(0.1)
+        fd.register_failure(Exception(), ('SET', 'key1', 'value1'))
+        sleep(0.1)
+        fd.register_failure(Exception(), ('SET', 'key1', 'value1'))
+
+        assert mock_db.circuit.state == CBState.CLOSED
+
+        # 4 more failures, as the last one already refreshed the timer
+        fd.register_failure(Exception(), ('SET', 'key1', 'value1'))
+        fd.register_failure(Exception(), ('SET', 'key1', 'value1'))
+        fd.register_failure(Exception(), ('SET', 'key1', 'value1'))
+        fd.register_failure(Exception(), ('SET', 'key1', 'value1'))
+
+        assert mock_db.circuit.state == CBState.OPEN
+
+    @pytest.mark.parametrize(
+        'mock_db',
+        [
+            {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}},
+        ],
+        indirect=True,
+    )
+    def test_failure_detector_refresh_timer_on_expired_duration(self, mock_db):
+        fd = CommandFailureDetector(5, 0.3)
+        mock_ce = Mock(spec=CommandExecutor)
+        mock_ce.active_database = mock_db
+        fd.set_command_executor(mock_ce)
+        assert mock_db.circuit.state == CBState.CLOSED
+
+        fd.register_failure(Exception(), ('SET', 'key1', 'value1'))
+        fd.register_failure(Exception(), ('SET', 'key1', 'value1'))
+        fd.register_failure(Exception(), ('SET', 'key1', 'value1'))
+        fd.register_failure(Exception(), ('SET', 'key1', 'value1'))
+        sleep(0.4)
+
+        assert mock_db.circuit.state == CBState.CLOSED
+
+        fd.register_failure(Exception(), ('SET', 'key1', 'value1'))
+        fd.register_failure(Exception(), ('SET', 'key1', 'value1'))
+        fd.register_failure(Exception(), ('SET', 'key1', 'value1'))
+        fd.register_failure(Exception(), ('SET', 'key1', 'value1'))
+
+        assert mock_db.circuit.state == CBState.CLOSED
+        fd.register_failure(Exception(), ('SET', 'key1', 'value1'))
+
+        assert mock_db.circuit.state == CBState.OPEN
+
+    @pytest.mark.parametrize(
+        'mock_db',
+        [
+            {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}},
+        ],
+        indirect=True,
+    )
+    def test_failure_detector_open_circuit_on_specific_exception_threshold_exceed(self, mock_db):
+        fd = CommandFailureDetector(5, 1, error_types=[ConnectionError])
+        mock_ce = Mock(spec=CommandExecutor)
+        mock_ce.active_database = mock_db
+        fd.set_command_executor(mock_ce)
+        assert mock_db.circuit.state == CBState.CLOSED
+
+        fd.register_failure(Exception(), ('SET', 'key1', 'value1'))
+        fd.register_failure(ConnectionError(), ('SET', 'key1', 'value1'))
+        fd.register_failure(ConnectionError(), ('SET', 'key1', 'value1'))
+        fd.register_failure(Exception(), ('SET', 'key1', 'value1'))
+        fd.register_failure(Exception(), ('SET', 'key1', 'value1'))
+
+        assert mock_db.circuit.state == CBState.CLOSED
+
+        fd.register_failure(ConnectionError(), ('SET', 'key1', 'value1'))
+        fd.register_failure(ConnectionError(), ('SET', 'key1', 'value1'))
+        fd.register_failure(ConnectionError(), ('SET', 'key1', 'value1'))
+
+        assert mock_db.circuit.state == CBState.OPEN
\ No newline at end of file
diff --git a/tests/test_multidb/test_healthcheck.py b/tests/test_multidb/test_healthcheck.py
new file mode 100644
index 0000000000..9601638913
--- /dev/null
+++ b/tests/test_multidb/test_healthcheck.py
@@ -0,0 +1,41 @@
+from redis.backoff import ExponentialBackoff
+from redis.multidb.database import Database, State
+from redis.multidb.healthcheck import EchoHealthCheck
+from redis.multidb.circuit import State as CBState
+from redis.exceptions import ConnectionError
+from redis.retry import Retry
+
+
+class TestEchoHealthCheck:
+    def test_database_is_healthy_on_echo_response(self, mock_client, mock_cb):
+        """
+        Mocking responses to mix errors and actual responses, ensuring that the health check retries
+        according to the given configuration.
+        """
+        mock_client.execute_command.side_effect = [ConnectionError, ConnectionError, 'healthcheck']
+        hc = EchoHealthCheck(Retry(backoff=ExponentialBackoff(cap=1.0), retries=3))
+        db = Database(mock_client, mock_cb, 0.9, State.ACTIVE)
+
+        assert hc.check_health(db) == True
+        assert mock_client.execute_command.call_count == 3
+
+    def test_database_is_unhealthy_on_incorrect_echo_response(self, mock_client, mock_cb):
+        """
+        Mocking responses to mix errors and actual responses, ensuring that the health check retries
+        according to the given configuration.
+        """
+        mock_client.execute_command.side_effect = [ConnectionError, ConnectionError, 'wrong']
+        hc = EchoHealthCheck(Retry(backoff=ExponentialBackoff(cap=1.0), retries=3))
+        db = Database(mock_client, mock_cb, 0.9, State.ACTIVE)
+
+        assert hc.check_health(db) == False
+        assert mock_client.execute_command.call_count == 3
+
+    def test_database_close_circuit_on_successful_healthcheck(self, mock_client, mock_cb):
+        mock_client.execute_command.side_effect = [ConnectionError, ConnectionError, 'healthcheck']
+        mock_cb.state = CBState.HALF_OPEN
+        hc = EchoHealthCheck(Retry(backoff=ExponentialBackoff(cap=1.0), retries=3))
+        db = Database(mock_client, mock_cb, 0.9, State.ACTIVE)
+
+        assert hc.check_health(db) == True
+        assert mock_client.execute_command.call_count == 3
\ No newline at end of file
diff --git a/tests/test_scenario/__init__.py b/tests/test_scenario/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/tests/test_scenario/conftest.py b/tests/test_scenario/conftest.py
new file mode 100644
index 0000000000..8ae7441e98
--- /dev/null
+++ b/tests/test_scenario/conftest.py
@@ -0,0 +1,75 @@
+import json
+import os
+
+import pytest
+
+from redis.backoff import NoBackoff
+from redis.multidb.client import MultiDBClient
+from redis.multidb.config import DatabaseConfig, MultiDbConfig, DEFAULT_HEALTH_CHECK_INTERVAL, \
+    DEFAULT_FAILURES_THRESHOLD
+from redis.retry import Retry
+from tests.test_scenario.fault_injector_client import FaultInjectorClient
+
+
+def get_endpoint_config(endpoint_name: str):
+    endpoints_config = os.getenv("REDIS_ENDPOINTS_CONFIG_PATH", None)
+
+    if not (endpoints_config and os.path.exists(endpoints_config)):
+        raise FileNotFoundError(f"Endpoints config file not found: {endpoints_config}")
+
+    try:
+        with open(endpoints_config, "r") as f:
+            data = json.load(f)
+            db = data[endpoint_name]
+            return db
+    except Exception as e:
+        raise ValueError(
+            f"Failed to load endpoints config file: {endpoints_config}"
+        ) from e
+
+@pytest.fixture()
+def fault_injector_client():
+    url = os.getenv("FAULT_INJECTION_API_URL", "http://127.0.0.1:20324")
+    return
FaultInjectorClient(url) + +@pytest.fixture() +def r_multi_db(request) -> MultiDBClient: + endpoint_config = get_endpoint_config('re-active-active') + username = endpoint_config.get('username', None) + password = endpoint_config.get('password', None) + failure_threshold = request.param.get('failure_threshold', DEFAULT_FAILURES_THRESHOLD) + command_retry = request.param.get('command_retry', Retry(NoBackoff(), retries=3)) + health_check_interval = request.param.get('health_check_interval', DEFAULT_HEALTH_CHECK_INTERVAL) + db_configs = [] + + db_config = DatabaseConfig( + weight=1.0, + from_url=endpoint_config['endpoints'][0], + client_kwargs={ + 'username': username, + 'password': password, + 'decode_responses': True, + } + ) + db_configs.append(db_config) + + db_config1 = DatabaseConfig( + weight=0.9, + from_url=endpoint_config['endpoints'][1], + client_kwargs={ + 'username': username, + 'password': password, + 'decode_responses': True, + } + ) + db_configs.append(db_config1) + + + config = MultiDbConfig( + databases_config=db_configs, + command_retry=command_retry, + failure_threshold=failure_threshold, + health_check_interval=health_check_interval, + ) + + return MultiDBClient(config) \ No newline at end of file diff --git a/tests/test_scenario/fault_injector_client.py b/tests/test_scenario/fault_injector_client.py new file mode 100644 index 0000000000..aa8d27819a --- /dev/null +++ b/tests/test_scenario/fault_injector_client.py @@ -0,0 +1,71 @@ +import json +import urllib.request +from typing import Dict, Any, Optional, Union +from enum import Enum + + +class ActionType(str, Enum): + DMC_RESTART = "dmc_restart" + FAILOVER = "failover" + RESHARD = "reshard" + SEQUENCE_OF_ACTIONS = "sequence_of_actions" + NETWORK_FAILURE = "network_failure" + EXECUTE_RLUTIL_COMMAND = "execute_rlutil_command" + + +class RestartDmcParams: + def __init__(self, bdb_id: str): + self.bdb_id = bdb_id + + def to_dict(self) -> Dict[str, str]: + return {"bdb_id": self.bdb_id} + + +class ActionRequest: + def __init__(self, action_type: ActionType, parameters: Union[Dict[str, Any], RestartDmcParams]): + self.type = action_type + self.parameters = parameters + + def to_dict(self) -> Dict[str, Any]: + return { + "type": self.type, + "parameters": self.parameters.to_dict() if isinstance(self.parameters, + RestartDmcParams) else self.parameters + } + + +class FaultInjectorClient: + def __init__(self, base_url: str): + self.base_url = base_url.rstrip('/') + + def _make_request(self, method: str, path: str, data: Optional[Dict] = None) -> Dict[str, Any]: + url = f"{self.base_url}{path}" + headers = {"Content-Type": "application/json"} if data else {} + + request = urllib.request.Request( + url, + method=method, + data=json.dumps(data).encode('utf-8') if data else None, + headers=headers + ) + + try: + with urllib.request.urlopen(request) as response: + return json.loads(response.read().decode('utf-8')) + except urllib.error.HTTPError as e: + if e.code == 422: + error_body = json.loads(e.read().decode('utf-8')) + raise ValueError(f"Validation Error: {error_body}") + raise + + def list_actions(self) -> Dict[str, Any]: + """List all available actions""" + return self._make_request("GET", "/action") + + def trigger_action(self, action_request: ActionRequest) -> Dict[str, Any]: + """Trigger a new action""" + return self._make_request("POST", "/action", action_request.to_dict()) + + def get_action_status(self, action_id: str) -> Dict[str, Any]: + """Get the status of a specific action""" + return self._make_request("GET", 
f"/action/{action_id}") diff --git a/tests/test_scenario/test_active_active.py b/tests/test_scenario/test_active_active.py new file mode 100644 index 0000000000..93e251ed4b --- /dev/null +++ b/tests/test_scenario/test_active_active.py @@ -0,0 +1,92 @@ +import logging +import threading +from time import sleep + +import pytest + +from redis.backoff import NoBackoff +from redis.exceptions import ConnectionError +from redis.retry import Retry +from tests.test_scenario.conftest import get_endpoint_config +from tests.test_scenario.fault_injector_client import ActionRequest, ActionType + +logger = logging.getLogger(__name__) + +def trigger_network_failure_action(fault_injector_client, event: threading.Event = None): + endpoint_config = get_endpoint_config('re-active-active') + action_request = ActionRequest( + action_type=ActionType.NETWORK_FAILURE, + parameters={"bdb_id": endpoint_config['bdb_id'], "delay": 3, "cluster_index": 0} + ) + + result = fault_injector_client.trigger_action(action_request) + status_result = fault_injector_client.get_action_status(result['action_id']) + + while status_result['status'] != "success": + sleep(0.1) + status_result = fault_injector_client.get_action_status(result['action_id']) + logger.info(f"Waiting for action to complete. Status: {status_result['status']}") + + if event: + event.set() + + logger.info(f"Action completed. Status: {status_result['status']}") + +class TestActiveActiveStandalone: + @pytest.mark.parametrize( + "r_multi_db", + [ + {"failure_threshold": 2} + ], + indirect=True + ) + def test_multi_db_client_failover_to_another_db(self, r_multi_db, fault_injector_client): + event = threading.Event() + thread = threading.Thread( + target=trigger_network_failure_action, + daemon=True, + args=(fault_injector_client,event) + ) + thread.start() + + r_multi_db.set('key', 'value') + current_active_db = r_multi_db._command_executor.active_database + + # Execute commands before network failure + while not event.is_set(): + assert r_multi_db.get('key') == 'value' + sleep(0.1) + + # Active db has been changed. + assert current_active_db != r_multi_db._command_executor.active_database + + # Execute commands after network failure + for _ in range(3): + assert r_multi_db.get('key') == 'value' + sleep(0.1) + + @pytest.mark.parametrize( + "r_multi_db", + [ + { + "failure_threshold": 15, + "command_retry": Retry(NoBackoff(), retries=5), + "health_check_interval": 100, + } + ], + indirect=True + ) + def test_multi_db_client_throws_error_on_retry_exceed(self, r_multi_db, fault_injector_client): + event = threading.Event() + thread = threading.Thread( + target=trigger_network_failure_action, + daemon=True, + args=(fault_injector_client,event) + ) + thread.start() + + with pytest.raises(ConnectionError): + # Retries count > failure threshold, so a client gives up earlier. 
+ while not event.is_set(): + assert r_multi_db.get('key') == 'value' + sleep(0.1) \ No newline at end of file From 8c09cbe680a9e21ab507c9b2352d9b8d3e7b4c77 Mon Sep 17 00:00:00 2001 From: Vladyslav Vildanov <117659936+vladvildanov@users.noreply.github.com> Date: Mon, 11 Aug 2025 11:44:26 +0300 Subject: [PATCH 02/50] Added support for Pipeline and transactions (#3707) * Added Database, Healthcheck, CircuitBreaker, FailureDetector * Added DatabaseSelector, exceptions, refactored existing entities * Added MultiDbConfig * Added DatabaseConfig * Added DatabaseConfig test coverage * Renamed DatabaseSelector into FailoverStrategy * Added CommandExecutor * Updated healthcheck to close circuit on success * Added thread-safeness * Added missing thread-safeness * Added missing thread-safenes for dispatcher * Refactored client to keep databases in WeightedList * Added database CRUD operations * Added on-fly configuration * Added background health checks * Added background healthcheck + half-open event * Refactored background scheduling * Added support for Active-Active pipeline * Refactored healthchecks * Added Pipeline testing * Added support for transactions * Removed code repetitions, fixed weight assignment, added loops enhancement, fixed data structure * Added missing doc blocks * Refactored configuration * Refactored failure detector * Refactored retry logic * Added scenario tests * Added pybreaker optional dependency * Added pybreaker to dev dependencies * Rename tests directory * Added scenario tests for Pipeline and Transaction * Added handling of ConnectionRefusedError, added timeouts so cluster could recover * Increased timeouts * Refactored integration tests * Fixed property name * Removed sentinels * Removed unused method --- redis/event.py | 10 +- redis/multidb/client.py | 116 +++++-- redis/multidb/command_executor.py | 56 +++- redis/multidb/event.py | 6 +- tests/test_multidb/test_client.py | 6 +- tests/test_multidb/test_pipeline.py | 352 ++++++++++++++++++++++ tests/test_scenario/test_active_active.py | 160 +++++++++- 7 files changed, 655 insertions(+), 51 deletions(-) create mode 100644 tests/test_multidb/test_pipeline.py diff --git a/redis/event.py b/redis/event.py index 03480364db..1fa66f0587 100644 --- a/redis/event.py +++ b/redis/event.py @@ -251,21 +251,21 @@ def nodes(self) -> dict: def credential_provider(self) -> Union[CredentialProvider, None]: return self._credential_provider -class OnCommandFailEvent: +class OnCommandsFailEvent: """ Event fired whenever a command fails during the execution. 
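+    Carries the tuple of commands that failed (a single command or a pipeline's
+    command stack) together with the exception that caused the failure.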
""" def __init__( self, - command: tuple, + commands: tuple, exception: Exception, ): - self._command = command + self._commands = commands self._exception = exception @property - def command(self) -> tuple: - return self._command + def commands(self) -> tuple: + return self._commands @property def exception(self) -> Exception: diff --git a/redis/multidb/client.py b/redis/multidb/client.py index 702a7da01d..367060bcc3 100644 --- a/redis/multidb/client.py +++ b/redis/multidb/client.py @@ -1,6 +1,6 @@ import threading import socket -from typing import Callable +from typing import List, Any, Callable from redis.background import BackgroundScheduler from redis.exceptions import ConnectionError, TimeoutError @@ -30,23 +30,22 @@ def __init__(self, config: MultiDbConfig): self._failover_strategy.set_databases(self._databases) self._auto_fallback_interval = config.auto_fallback_interval self._event_dispatcher = config.event_dispatcher - self._command_executor = DefaultCommandExecutor( + self._command_retry = config.command_retry + self._command_retry.update_supported_errors((ConnectionRefusedError,)) + self.command_executor = DefaultCommandExecutor( failure_detectors=self._failure_detectors, databases=self._databases, - command_retry=config.command_retry, + command_retry=self._command_retry, failover_strategy=self._failover_strategy, event_dispatcher=self._event_dispatcher, auto_fallback_interval=self._auto_fallback_interval, ) - - for fd in self._failure_detectors: - fd.set_command_executor(command_executor=self._command_executor) - - self._initialized = False + self.initialized = False self._hc_lock = threading.RLock() self._bg_scheduler = BackgroundScheduler() + self._config = config - def _initialize(self): + def initialize(self): """ Perform initialization of databases to define their initial state. 
""" @@ -72,7 +71,7 @@ def raise_exception_on_failed_hc(error): # Set states according to a weights and circuit state if database.circuit.state == CBState.CLOSED and not is_active_db_found: database.state = DBState.ACTIVE - self._command_executor.active_database = database + self.command_executor.active_database = database is_active_db_found = True elif database.circuit.state == CBState.CLOSED and is_active_db_found: database.state = DBState.PASSIVE @@ -82,7 +81,7 @@ def raise_exception_on_failed_hc(error): if not is_active_db_found: raise NoValidDatabaseException('Initial connection failed - no active database found') - self._initialized = True + self.initialized = True def get_databases(self) -> Databases: """ @@ -110,7 +109,7 @@ def set_active_database(self, database: AbstractDatabase) -> None: highest_weighted_db, _ = self._databases.get_top_n(1)[0] highest_weighted_db.state = DBState.PASSIVE database.state = DBState.ACTIVE - self._command_executor.active_database = database + self.command_executor.active_database = database return raise NoValidDatabaseException('Cannot set active database, database is unhealthy') @@ -132,7 +131,7 @@ def add_database(self, database: AbstractDatabase): def _change_active_database(self, new_database: AbstractDatabase, highest_weight_database: AbstractDatabase): if new_database.weight > highest_weight_database.weight and new_database.circuit.state == CBState.CLOSED: new_database.state = DBState.ACTIVE - self._command_executor.active_database = new_database + self.command_executor.active_database = new_database highest_weight_database.state = DBState.PASSIVE def remove_database(self, database: Database): @@ -144,7 +143,7 @@ def remove_database(self, database: Database): if highest_weight <= weight and highest_weighted_db.circuit.state == CBState.CLOSED: highest_weighted_db.state = DBState.ACTIVE - self._command_executor.active_database = highest_weighted_db + self.command_executor.active_database = highest_weighted_db def update_database_weight(self, database: AbstractDatabase, weight: float): """ @@ -182,10 +181,25 @@ def execute_command(self, *args, **options): """ Executes a single command and return its result. """ - if not self._initialized: - self._initialize() + if not self.initialized: + self.initialize() + + return self.command_executor.execute_command(*args, **options) + + def pipeline(self): + """ + Enters into pipeline mode of the client. + """ + return Pipeline(self) - return self._command_executor.execute_command(*args, **options) + def transaction(self, func: Callable[["Pipeline"], None], *watches, **options): + """ + Executes callable as transaction. 
+ """ + if not self.initialized: + self.initialize() + + return self.command_executor.execute_transaction(func, *watches, *options) def _check_db_health(self, database: AbstractDatabase, on_error: Callable[[Exception], None] = None) -> None: """ @@ -207,7 +221,7 @@ def _check_db_health(self, database: AbstractDatabase, on_error: Callable[[Excep database.circuit.state = CBState.OPEN elif is_healthy and database.circuit.state != CBState.CLOSED: database.circuit.state = CBState.CLOSED - except (ConnectionError, TimeoutError, socket.timeout) as e: + except (ConnectionError, TimeoutError, socket.timeout, ConnectionRefusedError) as e: if database.circuit.state != CBState.OPEN: database.circuit.state = CBState.OPEN is_healthy = False @@ -219,7 +233,9 @@ def _check_db_health(self, database: AbstractDatabase, on_error: Callable[[Excep def _check_databases_health(self, on_error: Callable[[Exception], None] = None): """ Runs health checks as a recurring task. + Runs health checks against all databases. """ + for database, _ in self._databases: self._check_db_health(database, on_error) @@ -232,4 +248,66 @@ def _on_circuit_state_change_callback(self, circuit: CircuitBreaker, old_state: self._bg_scheduler.run_once(DEFAULT_GRACE_PERIOD, _half_open_circuit, circuit) def _half_open_circuit(circuit: CircuitBreaker): - circuit.state = CBState.HALF_OPEN \ No newline at end of file + circuit.state = CBState.HALF_OPEN + + +class Pipeline(RedisModuleCommands, CoreCommands): + """ + Pipeline implementation for multiple logical Redis databases. + """ + def __init__(self, client: MultiDBClient): + self._command_stack = [] + self._client = client + + def __enter__(self) -> "Pipeline": + return self + + def __exit__(self, exc_type, exc_value, traceback): + self.reset() + + def __del__(self): + try: + self.reset() + except Exception: + pass + + def __len__(self) -> int: + return len(self._command_stack) + + def __bool__(self) -> bool: + """Pipeline instances should always evaluate to True""" + return True + + def reset(self) -> None: + self._command_stack = [] + + def close(self) -> None: + """Close the pipeline""" + self.reset() + + def pipeline_execute_command(self, *args, **options) -> "Pipeline": + """ + Stage a command to be executed when execute() is next called + + Returns the current Pipeline object back so commands can be + chained together, such as: + + pipe = pipe.set('foo', 'bar').incr('baz').decr('bang') + + At some other point, you can then run: pipe.execute(), + which will execute all commands queued in the pipe. 
+ """ + self._command_stack.append((args, options)) + return self + + def execute_command(self, *args, **kwargs): + return self.pipeline_execute_command(*args, **kwargs) + + def execute(self) -> List[Any]: + if not self._client.initialized: + self._client.initialize() + + try: + return self._client.command_executor.execute_pipeline(tuple(self._command_stack)) + finally: + self.reset() diff --git a/redis/multidb/command_executor.py b/redis/multidb/command_executor.py index 0783f6da82..690ea49a5c 100644 --- a/redis/multidb/command_executor.py +++ b/redis/multidb/command_executor.py @@ -1,8 +1,9 @@ from abc import ABC, abstractmethod from datetime import datetime, timedelta -from typing import List, Union, Optional +from typing import List, Union, Optional, Callable -from redis.event import EventDispatcherInterface, OnCommandFailEvent +from redis.client import Pipeline +from redis.event import EventDispatcherInterface, OnCommandsFailEvent from redis.multidb.config import DEFAULT_AUTO_FALLBACK_INTERVAL from redis.multidb.database import Database, AbstractDatabase, Databases from redis.multidb.circuit import State as CBState @@ -92,6 +93,9 @@ def __init__( :param auto_fallback_interval: Interval between fallback attempts. Fallback to a new database according to failover_strategy. """ + for fd in failure_detectors: + fd.set_command_executor(command_executor=self) + self._failure_detectors = failure_detectors self._databases = databases self._command_retry = command_retry @@ -139,19 +143,49 @@ def auto_fallback_interval(self, auto_fallback_interval: int) -> None: self._auto_fallback_interval = auto_fallback_interval def execute_command(self, *args, **options): - self._check_active_database() + def callback(): + return self._active_database.client.execute_command(*args, **options) + + return self._execute_with_failure_detection(callback, args) + + def execute_pipeline(self, command_stack: tuple): + """ + Executes a stack of commands in pipeline. + """ + def callback(): + with self._active_database.client.pipeline() as pipe: + for command, options in command_stack: + pipe.execute_command(*command, **options) + + return pipe.execute() + + return self._execute_with_failure_detection(callback, command_stack) + + def execute_transaction(self, transaction: Callable[[Pipeline], None], *watches, **options): + """ + Executes a transaction block wrapped in callback. + """ + def callback(): + return self._active_database.client.transaction(transaction, *watches, **options) + + return self._execute_with_failure_detection(callback) + + def _execute_with_failure_detection(self, callback: Callable, cmds: tuple = ()): + """ + Execute a commands execution callback with failure detection. + """ + def wrapper(): + # On each retry we need to check active database as it might change. 
+ self._check_active_database() + return callback() return self._command_retry.call_with_retry( - lambda: self._execute_command(*args, **options), - lambda error: self._on_command_fail(error, *args), + lambda: wrapper(), + lambda error: self._on_command_fail(error, *cmds), ) - def _execute_command(self, *args, **options): - self._check_active_database() - return self._active_database.client.execute_command(*args, **options) - def _on_command_fail(self, error, *args): - self._event_dispatcher.dispatch(OnCommandFailEvent(args, error)) + self._event_dispatcher.dispatch(OnCommandsFailEvent(args, error)) def _check_active_database(self): """ @@ -180,5 +214,5 @@ def _setup_event_dispatcher(self): """ event_listener = RegisterCommandFailure(self._failure_detectors) self._event_dispatcher.register_listeners({ - OnCommandFailEvent: [event_listener], + OnCommandsFailEvent: [event_listener], }) \ No newline at end of file diff --git a/redis/multidb/event.py b/redis/multidb/event.py index 3a5ed3ec24..315802e812 100644 --- a/redis/multidb/event.py +++ b/redis/multidb/event.py @@ -1,6 +1,6 @@ from typing import List -from redis.event import EventListenerInterface, OnCommandFailEvent +from redis.event import EventListenerInterface, OnCommandsFailEvent from redis.multidb.failure_detector import FailureDetector @@ -11,6 +11,6 @@ class RegisterCommandFailure(EventListenerInterface): def __init__(self, failure_detectors: List[FailureDetector]): self._failure_detectors = failure_detectors - def listen(self, event: OnCommandFailEvent) -> None: + def listen(self, event: OnCommandsFailEvent) -> None: for failure_detector in self._failure_detectors: - failure_detector.register_failure(event.exception, event.command) + failure_detector.register_failure(event.exception, event.commands) diff --git a/tests/test_multidb/test_client.py b/tests/test_multidb/test_client.py index b94c4ce61e..cf3877957f 100644 --- a/tests/test_multidb/test_client.py +++ b/tests/test_multidb/test_client.py @@ -4,7 +4,7 @@ import pybreaker import pytest -from redis.event import EventDispatcher, OnCommandFailEvent +from redis.event import EventDispatcher, OnCommandsFailEvent from redis.multidb.circuit import State as CBState, PBCircuitBreakerAdapter from redis.multidb.config import DEFAULT_HEALTH_CHECK_RETRIES, DEFAULT_HEALTH_CHECK_BACKOFF, DEFAULT_FAILOVER_RETRIES, \ DEFAULT_FAILOVER_BACKOFF @@ -455,8 +455,8 @@ def test_add_new_failure_detector( mock_fd = mock_multi_db_config.failure_detectors[0] # Event fired if command against mock_db1 would fail - command_fail_event = OnCommandFailEvent( - command=('SET', 'key', 'value'), + command_fail_event = OnCommandsFailEvent( + commands=('SET', 'key', 'value'), exception=Exception(), ) diff --git a/tests/test_multidb/test_pipeline.py b/tests/test_multidb/test_pipeline.py new file mode 100644 index 0000000000..9caad235df --- /dev/null +++ b/tests/test_multidb/test_pipeline.py @@ -0,0 +1,352 @@ +from time import sleep +from unittest.mock import patch, Mock + +import pybreaker +import pytest + +from redis.event import EventDispatcher +from redis.exceptions import ConnectionError +from redis.client import Pipeline +from redis.multidb.circuit import State as CBState, PBCircuitBreakerAdapter +from redis.multidb.client import MultiDBClient +from redis.multidb.config import DEFAULT_HEALTH_CHECK_RETRIES, DEFAULT_HEALTH_CHECK_BACKOFF, DEFAULT_FAILOVER_RETRIES, \ + DEFAULT_FAILOVER_BACKOFF, DEFAULT_FAILURES_THRESHOLD +from redis.multidb.failover import WeightBasedFailoverStrategy +from 
redis.multidb.healthcheck import EchoHealthCheck +from redis.retry import Retry +from tests.test_multidb.conftest import create_weighted_list + +def mock_pipe() -> Pipeline: + mock_pipe = Mock(spec=Pipeline) + mock_pipe.__enter__ = Mock(return_value=mock_pipe) + mock_pipe.__exit__ = Mock(return_value=None) + return mock_pipe + +class TestPipeline: + @pytest.mark.parametrize( + 'mock_multi_db_config,mock_db, mock_db1, mock_db2', + [ + ( + {}, + {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, + ), + ], + indirect=True, + ) + def test_executes_pipeline_against_correct_db( + self, mock_multi_db_config, mock_db, mock_db1, mock_db2 + ): + databases = create_weighted_list(mock_db, mock_db1, mock_db2) + + with patch.object( + mock_multi_db_config, + 'databases', + return_value=databases + ): + pipe = mock_pipe() + pipe.execute.return_value = ['OK1', 'value1'] + mock_db1.client.pipeline.return_value = pipe + + for hc in mock_multi_db_config.health_checks: + hc.check_health.return_value = True + + client = MultiDBClient(mock_multi_db_config) + assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 + + pipe = client.pipeline() + pipe.set('key1', 'value1') + pipe.get('key1') + + assert pipe.execute() == ['OK1', 'value1'] + + for hc in mock_multi_db_config.health_checks: + assert hc.check_health.call_count == 3 + + @pytest.mark.parametrize( + 'mock_multi_db_config,mock_db, mock_db1, mock_db2', + [ + ( + {}, + {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.7, 'circuit': {'state': CBState.OPEN}}, + ), + ], + indirect=True, + ) + def test_execute_pipeline_against_correct_db_and_closed_circuit( + self, mock_multi_db_config, mock_db, mock_db1, mock_db2 + ): + databases = create_weighted_list(mock_db, mock_db1, mock_db2) + + with patch.object( + mock_multi_db_config, + 'databases', + return_value=databases + ): + pipe = mock_pipe() + pipe.execute.return_value = ['OK1', 'value1'] + mock_db1.client.pipeline.return_value = pipe + + for hc in mock_multi_db_config.health_checks: + hc.check_health.side_effect = [False, True, True] + + client = MultiDBClient(mock_multi_db_config) + assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 + + with client.pipeline() as pipe: + pipe.set('key1', 'value1') + pipe.get('key1') + + assert pipe.execute() == ['OK1', 'value1'] + + for hc in mock_multi_db_config.health_checks: + assert hc.check_health.call_count == 3 + + assert mock_db.circuit.state == CBState.CLOSED + assert mock_db1.circuit.state == CBState.CLOSED + assert mock_db2.circuit.state == CBState.OPEN + + @pytest.mark.parametrize( + 'mock_multi_db_config,mock_db, mock_db1, mock_db2', + [ + ( + {}, + {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, + ), + ], + indirect=True, + ) + def test_execute_pipeline_against_correct_db_on_background_health_check_determine_active_db_unhealthy( + self, mock_multi_db_config, mock_db, mock_db1, mock_db2 + ): + cb = PBCircuitBreakerAdapter(pybreaker.CircuitBreaker(reset_timeout=5)) + cb.database = mock_db + mock_db.circuit = cb + + cb1 = PBCircuitBreakerAdapter(pybreaker.CircuitBreaker(reset_timeout=5)) + cb1.database = mock_db1 + mock_db1.circuit = cb1 + + cb2 = 
PBCircuitBreakerAdapter(pybreaker.CircuitBreaker(reset_timeout=5)) + cb2.database = mock_db2 + mock_db2.circuit = cb2 + + databases = create_weighted_list(mock_db, mock_db1, mock_db2) + + with patch.object( + mock_multi_db_config, + 'databases', + return_value=databases + ): + mock_db.client.execute_command.side_effect = ['healthcheck', 'healthcheck', 'healthcheck', 'error'] + mock_db1.client.execute_command.side_effect = ['healthcheck', 'error', 'error', 'healthcheck'] + mock_db2.client.execute_command.side_effect = ['healthcheck', 'healthcheck', 'error', 'error'] + + pipe = mock_pipe() + pipe.execute.return_value = ['OK', 'value'] + mock_db.client.pipeline.return_value = pipe + + pipe1 = mock_pipe() + pipe1.execute.return_value = ['OK1', 'value'] + mock_db1.client.pipeline.return_value = pipe1 + + pipe2 = mock_pipe() + pipe2.execute.return_value = ['OK2', 'value'] + mock_db2.client.pipeline.return_value = pipe2 + + mock_multi_db_config.health_check_interval = 0.1 + mock_multi_db_config.health_checks = [ + EchoHealthCheck( + retry=Retry(retries=DEFAULT_HEALTH_CHECK_RETRIES, backoff=DEFAULT_HEALTH_CHECK_BACKOFF) + ) + ] + mock_multi_db_config.failover_strategy = WeightBasedFailoverStrategy( + retry=Retry(retries=DEFAULT_FAILOVER_RETRIES, backoff=DEFAULT_FAILOVER_BACKOFF) + ) + + client = MultiDBClient(mock_multi_db_config) + + with client.pipeline() as pipe: + pipe.set('key1', 'value') + pipe.get('key1') + + assert pipe.execute() == ['OK1', 'value'] + + sleep(0.15) + + with client.pipeline() as pipe: + pipe.set('key1', 'value') + pipe.get('key1') + + assert pipe.execute() == ['OK2', 'value'] + + sleep(0.1) + + with client.pipeline() as pipe: + pipe.set('key1', 'value') + pipe.get('key1') + + assert pipe.execute() == ['OK', 'value'] + + sleep(0.1) + + with client.pipeline() as pipe: + pipe.set('key1', 'value') + pipe.get('key1') + + assert pipe.execute() == ['OK1', 'value'] + +class TestTransaction: + + @pytest.mark.parametrize( + 'mock_multi_db_config,mock_db, mock_db1, mock_db2', + [ + ( + {}, + {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, + ), + ], + indirect=True, + ) + def test_executes_transaction_against_correct_db( + self, mock_multi_db_config, mock_db, mock_db1, mock_db2 + ): + databases = create_weighted_list(mock_db, mock_db1, mock_db2) + + with patch.object( + mock_multi_db_config, + 'databases', + return_value=databases + ): + mock_db1.client.transaction.return_value = ['OK1', 'value1'] + + for hc in mock_multi_db_config.health_checks: + hc.check_health.return_value = True + + client = MultiDBClient(mock_multi_db_config) + assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 + + def callback(pipe: Pipeline): + pipe.set('key1', 'value1') + pipe.get('key1') + + assert client.transaction(callback) == ['OK1', 'value1'] + + for hc in mock_multi_db_config.health_checks: + assert hc.check_health.call_count == 3 + + @pytest.mark.parametrize( + 'mock_multi_db_config,mock_db, mock_db1, mock_db2', + [ + ( + {}, + {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.7, 'circuit': {'state': CBState.OPEN}}, + ), + ], + indirect=True, + ) + def test_execute_transaction_against_correct_db_and_closed_circuit( + self, mock_multi_db_config, mock_db, mock_db1, mock_db2 + ): + databases = create_weighted_list(mock_db, mock_db1, mock_db2) + + with patch.object( + 
mock_multi_db_config, + 'databases', + return_value=databases + ): + mock_db1.client.transaction.return_value = ['OK1', 'value1'] + + for hc in mock_multi_db_config.health_checks: + hc.check_health.side_effect = [False, True, True] + + client = MultiDBClient(mock_multi_db_config) + assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 + + def callback(pipe: Pipeline): + pipe.set('key1', 'value1') + pipe.get('key1') + + assert client.transaction(callback) == ['OK1', 'value1'] + + for hc in mock_multi_db_config.health_checks: + assert hc.check_health.call_count == 3 + + assert mock_db.circuit.state == CBState.CLOSED + assert mock_db1.circuit.state == CBState.CLOSED + assert mock_db2.circuit.state == CBState.OPEN + + @pytest.mark.parametrize( + 'mock_multi_db_config,mock_db, mock_db1, mock_db2', + [ + ( + {}, + {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, + ), + ], + indirect=True, + ) + def test_execute_transaction_against_correct_db_on_background_health_check_determine_active_db_unhealthy( + self, mock_multi_db_config, mock_db, mock_db1, mock_db2 + ): + cb = PBCircuitBreakerAdapter(pybreaker.CircuitBreaker(reset_timeout=5)) + cb.database = mock_db + mock_db.circuit = cb + + cb1 = PBCircuitBreakerAdapter(pybreaker.CircuitBreaker(reset_timeout=5)) + cb1.database = mock_db1 + mock_db1.circuit = cb1 + + cb2 = PBCircuitBreakerAdapter(pybreaker.CircuitBreaker(reset_timeout=5)) + cb2.database = mock_db2 + mock_db2.circuit = cb2 + + databases = create_weighted_list(mock_db, mock_db1, mock_db2) + + with patch.object( + mock_multi_db_config, + 'databases', + return_value=databases + ): + mock_db.client.execute_command.side_effect = ['healthcheck', 'healthcheck', 'healthcheck', 'error'] + mock_db1.client.execute_command.side_effect = ['healthcheck', 'error', 'error', 'healthcheck'] + mock_db2.client.execute_command.side_effect = ['healthcheck', 'healthcheck', 'error', 'error'] + + mock_db.client.transaction.return_value = ['OK', 'value'] + mock_db1.client.transaction.return_value = ['OK1', 'value'] + mock_db2.client.transaction.return_value = ['OK2', 'value'] + + mock_multi_db_config.health_check_interval = 0.1 + mock_multi_db_config.health_checks = [ + EchoHealthCheck( + retry=Retry(retries=DEFAULT_HEALTH_CHECK_RETRIES, backoff=DEFAULT_HEALTH_CHECK_BACKOFF) + ) + ] + mock_multi_db_config.failover_strategy = WeightBasedFailoverStrategy( + retry=Retry(retries=DEFAULT_FAILOVER_RETRIES, backoff=DEFAULT_FAILOVER_BACKOFF) + ) + + client = MultiDBClient(mock_multi_db_config) + + def callback(pipe: Pipeline): + pipe.set('key1', 'value1') + pipe.get('key1') + + assert client.transaction(callback) == ['OK1', 'value'] + sleep(0.15) + assert client.transaction(callback) == ['OK2', 'value'] + sleep(0.1) + assert client.transaction(callback) == ['OK', 'value'] + sleep(0.1) + assert client.transaction(callback) == ['OK1', 'value'] \ No newline at end of file diff --git a/tests/test_scenario/test_active_active.py b/tests/test_scenario/test_active_active.py index 93e251ed4b..a8afea4b18 100644 --- a/tests/test_scenario/test_active_active.py +++ b/tests/test_scenario/test_active_active.py @@ -5,6 +5,7 @@ import pytest from redis.backoff import NoBackoff +from redis.client import Pipeline from redis.exceptions import ConnectionError from redis.retry import Retry from tests.test_scenario.conftest import get_endpoint_config @@ -33,6 +34,11 @@ def 
trigger_network_failure_action(fault_injector_client, event: threading.Event logger.info(f"Action completed. Status: {status_result['status']}") class TestActiveActiveStandalone: + + def teardown_method(self, method): + # Timeout so the cluster could recover from network failure. + sleep(3) + @pytest.mark.parametrize( "r_multi_db", [ @@ -47,10 +53,11 @@ def test_multi_db_client_failover_to_another_db(self, r_multi_db, fault_injector daemon=True, args=(fault_injector_client,event) ) - thread.start() + # Client initialized on the first command. r_multi_db.set('key', 'value') - current_active_db = r_multi_db._command_executor.active_database + current_active_db = r_multi_db.command_executor.active_database + thread.start() # Execute commands before network failure while not event.is_set(): @@ -58,7 +65,7 @@ def test_multi_db_client_failover_to_another_db(self, r_multi_db, fault_injector sleep(0.1) # Active db has been changed. - assert current_active_db != r_multi_db._command_executor.active_database + assert current_active_db != r_multi_db.command_executor.active_database # Execute commands after network failure for _ in range(3): @@ -68,11 +75,7 @@ def test_multi_db_client_failover_to_another_db(self, r_multi_db, fault_injector @pytest.mark.parametrize( "r_multi_db", [ - { - "failure_threshold": 15, - "command_retry": Retry(NoBackoff(), retries=5), - "health_check_interval": 100, - } + {"failure_threshold": 2} ], indirect=True ) @@ -81,7 +84,7 @@ def test_multi_db_client_throws_error_on_retry_exceed(self, r_multi_db, fault_in thread = threading.Thread( target=trigger_network_failure_action, daemon=True, - args=(fault_injector_client,event) + args=(fault_injector_client, event) ) thread.start() @@ -89,4 +92,141 @@ def test_multi_db_client_throws_error_on_retry_exceed(self, r_multi_db, fault_in # Retries count > failure threshold, so a client gives up earlier. while not event.is_set(): assert r_multi_db.get('key') == 'value' - sleep(0.1) \ No newline at end of file + sleep(0.1) + + @pytest.mark.parametrize( + "r_multi_db", + [ + {"failure_threshold": 2} + ], + indirect=True + ) + def test_context_manager_pipeline_failover_to_another_db(self, r_multi_db, fault_injector_client): + event = threading.Event() + thread = threading.Thread( + target=trigger_network_failure_action, + daemon=True, + args=(fault_injector_client,event) + ) + + # Client initialized on first pipe execution. 
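# A minimal sketch (helper name assumed, not part of the suite) of the
# choreography these scenario tests share: a daemon thread triggers the fault
# and sets an Event when the injector finishes, while the main thread keeps
# issuing commands so failover happens under live traffic.
import threading
import time

def exercise_until(done: threading.Event, op, interval: float = 0.1) -> None:
    while not done.is_set():
        op()  # e.g. lambda: r_multi_db.get('key') == 'value'
        time.sleep(interval)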
+ with r_multi_db.pipeline() as pipe: + pipe.set('{hash}key1', 'value1') + pipe.set('{hash}key2', 'value2') + pipe.set('{hash}key3', 'value3') + pipe.get('{hash}key1') + pipe.get('{hash}key2') + pipe.get('{hash}key3') + assert pipe.execute() == [True, True, True, 'value1', 'value2', 'value3'] + + thread.start() + + # Execute pipeline before network failure + while not event.is_set(): + with r_multi_db.pipeline() as pipe: + pipe.set('{hash}key1', 'value1') + pipe.set('{hash}key2', 'value2') + pipe.set('{hash}key3', 'value3') + pipe.get('{hash}key1') + pipe.get('{hash}key2') + pipe.get('{hash}key3') + assert pipe.execute() == [True, True, True, 'value1', 'value2', 'value3'] + sleep(0.1) + + # Execute pipeline after network failure + for _ in range(3): + with r_multi_db.pipeline() as pipe: + pipe.set('{hash}key1', 'value1') + pipe.set('{hash}key2', 'value2') + pipe.set('{hash}key3', 'value3') + pipe.get('{hash}key1') + pipe.get('{hash}key2') + pipe.get('{hash}key3') + assert pipe.execute() == [True, True, True, 'value1', 'value2', 'value3'] + sleep(0.1) + + @pytest.mark.parametrize( + "r_multi_db", + [ + {"failure_threshold": 2} + ], + indirect=True + ) + def test_chaining_pipeline_failover_to_another_db(self, r_multi_db, fault_injector_client): + event = threading.Event() + thread = threading.Thread( + target=trigger_network_failure_action, + daemon=True, + args=(fault_injector_client,event) + ) + + # Client initialized on first pipe execution. + pipe = r_multi_db.pipeline() + pipe.set('{hash}key1', 'value1') + pipe.set('{hash}key2', 'value2') + pipe.set('{hash}key3', 'value3') + pipe.get('{hash}key1') + pipe.get('{hash}key2') + pipe.get('{hash}key3') + assert pipe.execute() == [True, True, True, 'value1', 'value2', 'value3'] + + thread.start() + + # Execute pipeline before network failure + while not event.is_set(): + pipe.set('{hash}key1', 'value1') + pipe.set('{hash}key2', 'value2') + pipe.set('{hash}key3', 'value3') + pipe.get('{hash}key1') + pipe.get('{hash}key2') + pipe.get('{hash}key3') + assert pipe.execute() == [True, True, True, 'value1', 'value2', 'value3'] + sleep(0.1) + + # Execute pipeline after network failure + for _ in range(3): + pipe.set('{hash}key1', 'value1') + pipe.set('{hash}key2', 'value2') + pipe.set('{hash}key3', 'value3') + pipe.get('{hash}key1') + pipe.get('{hash}key2') + pipe.get('{hash}key3') + assert pipe.execute() == [True, True, True, 'value1', 'value2', 'value3'] + sleep(0.1) + + @pytest.mark.parametrize( + "r_multi_db", + [ + {"failure_threshold": 2} + ], + indirect=True + ) + def test_transaction_failover_to_another_db(self, r_multi_db, fault_injector_client): + event = threading.Event() + thread = threading.Thread( + target=trigger_network_failure_action, + daemon=True, + args=(fault_injector_client,event) + ) + + def callback(pipe: Pipeline): + pipe.set('{hash}key1', 'value1') + pipe.set('{hash}key2', 'value2') + pipe.set('{hash}key3', 'value3') + pipe.get('{hash}key1') + pipe.get('{hash}key2') + pipe.get('{hash}key3') + + # Client initialized on first transaction execution. 
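# Why every key above carries the '{hash}' tag: Redis hashes only the substring
# between the first '{' and the next '}' when computing a key's slot, so tagged
# keys co-locate and the multi-key pipeline stays valid on a clustered database.
# A sketch of the tag-extraction rule (the real slot is CRC16(tag) mod 16384):
def hash_tag(key: str) -> str:
    start = key.find('{')
    if start != -1:
        end = key.find('}', start + 1)
        if end > start + 1:  # tag must be non-empty
            return key[start + 1:end]
    return key

assert hash_tag('{hash}key1') == hash_tag('{hash}key2') == 'hash'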
+ r_multi_db.transaction(callback) + thread.start() + + # Execute pipeline before network failure + while not event.is_set(): + r_multi_db.transaction(callback) + sleep(0.1) + + # Execute pipeline after network failure + for _ in range(3): + r_multi_db.transaction(callback) + sleep(0.1) \ No newline at end of file From 4a40ee411a31e7a7b65685604376cce4201bf957 Mon Sep 17 00:00:00 2001 From: Vladyslav Vildanov <117659936+vladvildanov@users.noreply.github.com> Date: Wed, 13 Aug 2025 10:00:11 +0300 Subject: [PATCH 03/50] Added support for Pub/Sub mode in MultiDbClient (#3722) * Added Database, Healthcheck, CircuitBreaker, FailureDetector * Added DatabaseSelector, exceptions, refactored existing entities * Added MultiDbConfig * Added DatabaseConfig * Added DatabaseConfig test coverage * Renamed DatabaseSelector into FailoverStrategy * Added CommandExecutor * Updated healthcheck to close circuit on success * Added thread-safeness * Added missing thread-safeness * Added missing thread-safenes for dispatcher * Refactored client to keep databases in WeightedList * Added database CRUD operations * Added on-fly configuration * Added background health checks * Added background healthcheck + half-open event * Refactored background scheduling * Added support for Active-Active pipeline * Refactored healthchecks * Added Pipeline testing * Added support for transactions * Removed code repetitions, fixed weight assignment, added loops enhancement, fixed data structure * Added missing doc blocks * Added support for Pub/Sub in MultiDBClient * Refactored configuration * Refactored failure detector * Refactored retry logic * Added scenario tests * Added pybreaker optional dependency * Added pybreaker to dev dependencies * Rename tests directory * Added scenario tests for Pipeline and Transaction * Added handling of ConnectionRefusedError, added timeouts so cluster could recover * Increased timeouts * Refactored integration tests * Added scenario tests for Pub/Sub * Updated healthcheck retry * Increased timeout to avoid unprepared state before tests * Added backoff retry and changed timeouts * Added retry for healthchecks to avoid fluctuations * Changed retry configuration for healthchecks * Fixed property name * Added check for thread results --- redis/client.py | 4 +- redis/multidb/client.py | 129 +++++++++++++++++++- redis/multidb/command_executor.py | 87 ++++++++++++-- redis/multidb/event.py | 50 ++++++++ redis/multidb/healthcheck.py | 1 + tests/test_scenario/conftest.py | 29 ++++- tests/test_scenario/test_active_active.py | 138 +++++++++++++++++----- 7 files changed, 392 insertions(+), 46 deletions(-) diff --git a/redis/client.py b/redis/client.py index 060fc29493..adb57d404e 100755 --- a/redis/client.py +++ b/redis/client.py @@ -1217,6 +1217,7 @@ def run_in_thread( sleep_time: float = 0.0, daemon: bool = False, exception_handler: Optional[Callable] = None, + pubsub = None ) -> "PubSubWorkerThread": for channel, handler in self.channels.items(): if handler is None: @@ -1230,8 +1231,9 @@ def run_in_thread( f"Shard Channel: '{s_channel}' has no handler registered" ) + pubsub = self if pubsub is None else pubsub thread = PubSubWorkerThread( - self, sleep_time, daemon=daemon, exception_handler=exception_handler + pubsub, sleep_time, daemon=daemon, exception_handler=exception_handler ) thread.start() return thread diff --git a/redis/multidb/client.py b/redis/multidb/client.py index 367060bcc3..172017f036 100644 --- a/redis/multidb/client.py +++ b/redis/multidb/client.py @@ -1,8 +1,9 @@ import threading import 
socket
-from typing import List, Any, Callable
+from typing import List, Any, Callable, Optional

 from redis.background import BackgroundScheduler
+from redis.client import PubSubWorkerThread
 from redis.exceptions import ConnectionError, TimeoutError
 from redis.commands import RedisModuleCommands, CoreCommands
 from redis.multidb.command_executor import DefaultCommandExecutor
@@ -201,6 +202,17 @@ def transaction(self, func: Callable[["Pipeline"], None], *watches, **options):

         return self.command_executor.execute_transaction(func, *watches, **options)

+    def pubsub(self, **kwargs):
+        """
+        Return a Publish/Subscribe object. With this object, you can
+        subscribe to channels and listen for messages that get published to
+        them.
+        """
+        if not self.initialized:
+            self.initialize()
+
+        return PubSub(self, **kwargs)
+
     def _check_db_health(self, database: AbstractDatabase, on_error: Callable[[Exception], None] = None) -> None:
         """
         Runs health checks on the given database until first failure.
@@ -311,3 +323,118 @@ def execute(self) -> List[Any]:
             return self._client.command_executor.execute_pipeline(tuple(self._command_stack))
         finally:
             self.reset()
+
+class PubSub:
+    """
+    PubSub object for multi database client.
+    """
+    def __init__(self, client: MultiDBClient, **kwargs):
+        self._client = client
+        self._client.command_executor.pubsub(**kwargs)
+
+    def __enter__(self) -> "PubSub":
+        return self
+
+    def __exit__(self, exc_type, exc_value, traceback) -> None:
+        self.reset()
+
+    def __del__(self) -> None:
+        try:
+            # if this object went out of scope prior to shutting down
+            # subscriptions, close the connection manually before
+            # returning it to the connection pool
+            self.reset()
+        except Exception:
+            pass
+
+    def reset(self) -> None:
+        pass
+
+    def close(self) -> None:
+        self.reset()
+
+    @property
+    def subscribed(self) -> bool:
+        return self._client.command_executor.active_pubsub.subscribed
+
+    def psubscribe(self, *args, **kwargs):
+        """
+        Subscribe to channel patterns. Patterns supplied as keyword arguments
+        expect a pattern name as the key and a callable as the value. A
+        pattern's callable will be invoked automatically when a message is
+        received on that pattern rather than producing a message via
+        ``listen()``.
+        """
+        return self._client.command_executor.execute_pubsub_method('psubscribe', *args, **kwargs)
+
+    def punsubscribe(self, *args):
+        """
+        Unsubscribe from the supplied patterns. If empty, unsubscribe from
+        all patterns.
+        """
+        return self._client.command_executor.execute_pubsub_method('punsubscribe', *args)
+
+    def subscribe(self, *args, **kwargs):
+        """
+        Subscribe to channels. Channels supplied as keyword arguments expect
+        a channel name as the key and a callable as the value. A channel's
+        callable will be invoked automatically when a message is received on
+        that channel rather than producing a message via ``listen()`` or
+        ``get_message()``.
+        """
+        return self._client.command_executor.execute_pubsub_method('subscribe', *args, **kwargs)
+
+    def unsubscribe(self, *args):
+        """
+        Unsubscribe from the supplied channels. If empty, unsubscribe from
+        all channels
+        """
+        return self._client.command_executor.execute_pubsub_method('unsubscribe', *args)
+
+    def ssubscribe(self, *args, **kwargs):
+        """
+        Subscribes the client to the specified shard channels.
+        Channels supplied as keyword arguments expect a channel name as the key
+        and a callable as the value.
A channel's callable will be invoked automatically + when a message is received on that channel rather than producing a message via + ``listen()`` or ``get_sharded_message()``. + """ + return self._client.command_executor.execute_pubsub_method('ssubscribe', *args, **kwargs) + + def sunsubscribe(self, *args): + """ + Unsubscribe from the supplied shard_channels. If empty, unsubscribe from + all shard_channels + """ + return self._client.command_executor.execute_pubsub_method('sunsubscribe', *args) + + def get_message( + self, ignore_subscribe_messages: bool = False, timeout: float = 0.0 + ): + """ + Get the next message if one is available, otherwise None. + + If timeout is specified, the system will wait for `timeout` seconds + before returning. Timeout should be specified as a floating point + number, or None, to wait indefinitely. + """ + return self._client.command_executor.execute_pubsub_method( + 'get_message', + ignore_subscribe_messages=ignore_subscribe_messages, timeout=timeout + ) + + get_sharded_message = get_message + + def run_in_thread( + self, + sleep_time: float = 0.0, + daemon: bool = False, + exception_handler: Optional[Callable] = None, + ) -> "PubSubWorkerThread": + return self._client.command_executor.execute_pubsub_run_in_thread( + sleep_time=sleep_time, + daemon=daemon, + exception_handler=exception_handler, + pubsub=self + ) + diff --git a/redis/multidb/command_executor.py b/redis/multidb/command_executor.py index 690ea49a5c..795ef8f8b1 100644 --- a/redis/multidb/command_executor.py +++ b/redis/multidb/command_executor.py @@ -1,13 +1,13 @@ from abc import ABC, abstractmethod from datetime import datetime, timedelta -from typing import List, Union, Optional, Callable +from typing import List, Optional, Callable -from redis.client import Pipeline +from redis.client import Pipeline, PubSub, PubSubWorkerThread from redis.event import EventDispatcherInterface, OnCommandsFailEvent from redis.multidb.config import DEFAULT_AUTO_FALLBACK_INTERVAL from redis.multidb.database import Database, AbstractDatabase, Databases from redis.multidb.circuit import State as CBState -from redis.multidb.event import RegisterCommandFailure +from redis.multidb.event import RegisterCommandFailure, ActiveDatabaseChanged, ResubscribeOnActiveDatabaseChanged from redis.multidb.failover import FailoverStrategy from redis.multidb.failure_detector import FailureDetector from redis.retry import Retry @@ -34,7 +34,7 @@ def databases(self) -> Databases: @property @abstractmethod - def active_database(self) -> Union[Database, None]: + def active_database(self) -> Optional[Database]: """Returns currently active database.""" pass @@ -44,6 +44,23 @@ def active_database(self, database: AbstractDatabase) -> None: """Sets currently active database.""" pass + @abstractmethod + def pubsub(self, **kwargs): + """Initializes a PubSub object on a currently active database""" + pass + + @property + @abstractmethod + def active_pubsub(self) -> Optional[PubSub]: + """Returns currently active pubsub.""" + pass + + @active_pubsub.setter + @abstractmethod + def active_pubsub(self, pubsub: PubSub) -> None: + """Sets currently active pubsub.""" + pass + @property @abstractmethod def failover_strategy(self) -> FailoverStrategy: @@ -103,7 +120,9 @@ def __init__( self._event_dispatcher = event_dispatcher self._auto_fallback_interval = auto_fallback_interval self._next_fallback_attempt: datetime - self._active_database: Union[Database, None] = None + self._active_database: Optional[Database] = None + self._active_pubsub: 
Optional[PubSub] = None + self._active_pubsub_kwargs = {} self._setup_event_dispatcher() self._schedule_next_fallback() @@ -128,8 +147,22 @@ def active_database(self) -> Optional[AbstractDatabase]: @active_database.setter def active_database(self, database: AbstractDatabase) -> None: + old_active = self._active_database self._active_database = database + if old_active is not None and old_active is not database: + self._event_dispatcher.dispatch( + ActiveDatabaseChanged(old_active, self._active_database, self, **self._active_pubsub_kwargs) + ) + + @property + def active_pubsub(self) -> Optional[PubSub]: + return self._active_pubsub + + @active_pubsub.setter + def active_pubsub(self, pubsub: PubSub) -> None: + self._active_pubsub = pubsub + @property def failover_strategy(self) -> FailoverStrategy: return self._failover_strategy @@ -143,6 +176,7 @@ def auto_fallback_interval(self, auto_fallback_interval: int) -> None: self._auto_fallback_interval = auto_fallback_interval def execute_command(self, *args, **options): + """Executes a command and returns the result.""" def callback(): return self._active_database.client.execute_command(*args, **options) @@ -170,6 +204,39 @@ def callback(): return self._execute_with_failure_detection(callback) + def pubsub(self, **kwargs): + def callback(): + if self._active_pubsub is None: + self._active_pubsub = self._active_database.client.pubsub(**kwargs) + self._active_pubsub_kwargs = kwargs + return None + + return self._execute_with_failure_detection(callback) + + def execute_pubsub_method(self, method_name: str, *args, **kwargs): + """ + Executes given method on active pub/sub. + """ + def callback(): + method = getattr(self.active_pubsub, method_name) + return method(*args, **kwargs) + + return self._execute_with_failure_detection(callback, *args) + + def execute_pubsub_run_in_thread( + self, + pubsub, + sleep_time: float = 0.0, + daemon: bool = False, + exception_handler: Optional[Callable] = None, + ) -> "PubSubWorkerThread": + def callback(): + return self._active_pubsub.run_in_thread( + sleep_time, daemon=daemon, exception_handler=exception_handler, pubsub=pubsub + ) + + return self._execute_with_failure_detection(callback) + def _execute_with_failure_detection(self, callback: Callable, cmds: tuple = ()): """ Execute a commands execution callback with failure detection. @@ -199,7 +266,7 @@ def _check_active_database(self): and self._next_fallback_attempt <= datetime.now() ) ): - self._active_database = self._failover_strategy.database + self.active_database = self._failover_strategy.database self._schedule_next_fallback() def _schedule_next_fallback(self) -> None: @@ -210,9 +277,11 @@ def _schedule_next_fallback(self) -> None: def _setup_event_dispatcher(self): """ - Registers command failure event listener. + Registers necessary listeners. 
""" - event_listener = RegisterCommandFailure(self._failure_detectors) + failure_listener = RegisterCommandFailure(self._failure_detectors) + resubscribe_listener = ResubscribeOnActiveDatabaseChanged() self._event_dispatcher.register_listeners({ - OnCommandsFailEvent: [event_listener], + OnCommandsFailEvent: [failure_listener], + ActiveDatabaseChanged: [resubscribe_listener], }) \ No newline at end of file diff --git a/redis/multidb/event.py b/redis/multidb/event.py index 315802e812..2598bc4d06 100644 --- a/redis/multidb/event.py +++ b/redis/multidb/event.py @@ -1,8 +1,58 @@ from typing import List from redis.event import EventListenerInterface, OnCommandsFailEvent +from redis.multidb.config import Databases +from redis.multidb.database import AbstractDatabase from redis.multidb.failure_detector import FailureDetector +class ActiveDatabaseChanged: + """ + Event fired when an active database has been changed. + """ + def __init__( + self, + old_database: AbstractDatabase, + new_database: AbstractDatabase, + command_executor, + **kwargs + ): + self._old_database = old_database + self._new_database = new_database + self._command_executor = command_executor + self._kwargs = kwargs + + @property + def old_database(self) -> AbstractDatabase: + return self._old_database + + @property + def new_database(self) -> AbstractDatabase: + return self._new_database + + @property + def command_executor(self): + return self._command_executor + + @property + def kwargs(self): + return self._kwargs + +class ResubscribeOnActiveDatabaseChanged(EventListenerInterface): + """ + Re-subscribe currently active pub/sub to a new active database. + """ + def listen(self, event: ActiveDatabaseChanged): + old_pubsub = event.command_executor.active_pubsub + + if old_pubsub is not None: + # Re-assign old channels and patterns so they will be automatically subscribed on connection. 
+            new_pubsub = event.new_database.client.pubsub(**event.kwargs)
+            new_pubsub.channels = old_pubsub.channels
+            new_pubsub.patterns = old_pubsub.patterns
+            new_pubsub.shard_channels = old_pubsub.shard_channels
+            new_pubsub.on_connect(None)
+            event.command_executor.active_pubsub = new_pubsub
+            old_pubsub.close()

 class RegisterCommandFailure(EventListenerInterface):
     """
diff --git a/redis/multidb/healthcheck.py b/redis/multidb/healthcheck.py
index a96b9cf815..1396a1e997 100644
--- a/redis/multidb/healthcheck.py
+++ b/redis/multidb/healthcheck.py
@@ -21,6 +21,7 @@ def __init__(
         retry: Retry,
     ) -> None:
         self._retry = retry
+        self._retry.update_supported_errors([ConnectionRefusedError])

     @property
     def retry(self) -> Retry:
diff --git a/tests/test_scenario/conftest.py b/tests/test_scenario/conftest.py
index 8ae7441e98..486dc948f1 100644
--- a/tests/test_scenario/conftest.py
+++ b/tests/test_scenario/conftest.py
@@ -3,13 +3,22 @@

 import pytest

-from redis.backoff import NoBackoff
+from redis.backoff import NoBackoff, ExponentialBackoff
+from redis.event import EventDispatcher, EventListenerInterface
 from redis.multidb.client import MultiDBClient
 from redis.multidb.config import DatabaseConfig, MultiDbConfig, DEFAULT_HEALTH_CHECK_INTERVAL, \
     DEFAULT_FAILURES_THRESHOLD
+from redis.multidb.event import ActiveDatabaseChanged
+from redis.multidb.healthcheck import EchoHealthCheck
 from redis.retry import Retry
 from tests.test_scenario.fault_injector_client import FaultInjectorClient

+class CheckActiveDatabaseChangedListener(EventListenerInterface):
+    def __init__(self):
+        self.is_changed_flag = False
+
+    def listen(self, event: ActiveDatabaseChanged):
+        self.is_changed_flag = True

 def get_endpoint_config(endpoint_name: str):
     endpoints_config = os.getenv("REDIS_ENDPOINTS_CONFIG_PATH", None)
@@ -33,13 +42,22 @@ def fault_injector_client():
     return FaultInjectorClient(url)

 @pytest.fixture()
-def r_multi_db(request) -> MultiDBClient:
+def r_multi_db(request) -> tuple[MultiDBClient, CheckActiveDatabaseChangedListener]:
     endpoint_config = get_endpoint_config('re-active-active')
     username = endpoint_config.get('username', None)
     password = endpoint_config.get('password', None)
     failure_threshold = request.param.get('failure_threshold', DEFAULT_FAILURES_THRESHOLD)
-    command_retry = request.param.get('command_retry', Retry(NoBackoff(), retries=3))
+    command_retry = request.param.get('command_retry', Retry(ExponentialBackoff(cap=0.5, base=0.05), retries=3))
+
+    # Retry configuration differs for health checks, as the initial health check may require more
+    # time if the infrastructure wasn't restored from the previous test (see the backoff sketch below).
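# The delay schedule the ExponentialBackoff(cap=5, base=0.5) below yields,
# assuming redis-py's capped-doubling formula min(cap, base * 2**failures):
def backoff_delay(failures: int, base: float = 0.5, cap: float = 5.0) -> float:
    return min(cap, base * (2 ** failures))

# failures 0..4 -> 0.5, 1.0, 2.0, 4.0, 5.0 (capped) seconds between attempts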
+ health_checks = [EchoHealthCheck(Retry(ExponentialBackoff(cap=5, base=0.5), retries=3))] health_check_interval = request.param.get('health_check_interval', DEFAULT_HEALTH_CHECK_INTERVAL) + event_dispatcher = EventDispatcher() + listener = CheckActiveDatabaseChangedListener() + event_dispatcher.register_listeners({ + ActiveDatabaseChanged: [listener], + }) db_configs = [] db_config = DatabaseConfig( @@ -64,12 +82,13 @@ def r_multi_db(request) -> MultiDBClient: ) db_configs.append(db_config1) - config = MultiDbConfig( databases_config=db_configs, + health_checks=health_checks, command_retry=command_retry, failure_threshold=failure_threshold, health_check_interval=health_check_interval, + event_dispatcher=event_dispatcher, ) - return MultiDBClient(config) \ No newline at end of file + return MultiDBClient(config), listener \ No newline at end of file diff --git a/tests/test_scenario/test_active_active.py b/tests/test_scenario/test_active_active.py index a8afea4b18..071babb6c0 100644 --- a/tests/test_scenario/test_active_active.py +++ b/tests/test_scenario/test_active_active.py @@ -1,13 +1,11 @@ +import json import logging import threading from time import sleep import pytest -from redis.backoff import NoBackoff from redis.client import Pipeline -from redis.exceptions import ConnectionError -from redis.retry import Retry from tests.test_scenario.conftest import get_endpoint_config from tests.test_scenario.fault_injector_client import ActionRequest, ActionType @@ -17,7 +15,7 @@ def trigger_network_failure_action(fault_injector_client, event: threading.Event endpoint_config = get_endpoint_config('re-active-active') action_request = ActionRequest( action_type=ActionType.NETWORK_FAILURE, - parameters={"bdb_id": endpoint_config['bdb_id'], "delay": 3, "cluster_index": 0} + parameters={"bdb_id": endpoint_config['bdb_id'], "delay": 2, "cluster_index": 0} ) result = fault_injector_client.trigger_action(action_request) @@ -54,9 +52,10 @@ def test_multi_db_client_failover_to_another_db(self, r_multi_db, fault_injector args=(fault_injector_client,event) ) + r_multi_db, listener = r_multi_db + # Client initialized on the first command. r_multi_db.set('key', 'value') - current_active_db = r_multi_db.command_executor.active_database thread.start() # Execute commands before network failure @@ -64,35 +63,12 @@ def test_multi_db_client_failover_to_another_db(self, r_multi_db, fault_injector assert r_multi_db.get('key') == 'value' sleep(0.1) - # Active db has been changed. - assert current_active_db != r_multi_db.command_executor.active_database - # Execute commands after network failure for _ in range(3): assert r_multi_db.get('key') == 'value' sleep(0.1) - @pytest.mark.parametrize( - "r_multi_db", - [ - {"failure_threshold": 2} - ], - indirect=True - ) - def test_multi_db_client_throws_error_on_retry_exceed(self, r_multi_db, fault_injector_client): - event = threading.Event() - thread = threading.Thread( - target=trigger_network_failure_action, - daemon=True, - args=(fault_injector_client, event) - ) - thread.start() - - with pytest.raises(ConnectionError): - # Retries count > failure threshold, so a client gives up earlier. - while not event.is_set(): - assert r_multi_db.get('key') == 'value' - sleep(0.1) + assert listener.is_changed_flag == True @pytest.mark.parametrize( "r_multi_db", @@ -109,6 +85,8 @@ def test_context_manager_pipeline_failover_to_another_db(self, r_multi_db, fault args=(fault_injector_client,event) ) + r_multi_db, listener = r_multi_db + # Client initialized on first pipe execution. 
with r_multi_db.pipeline() as pipe: pipe.set('{hash}key1', 'value1') @@ -145,6 +123,8 @@ def test_context_manager_pipeline_failover_to_another_db(self, r_multi_db, fault assert pipe.execute() == [True, True, True, 'value1', 'value2', 'value3'] sleep(0.1) + assert listener.is_changed_flag == True + @pytest.mark.parametrize( "r_multi_db", [ @@ -160,6 +140,8 @@ def test_chaining_pipeline_failover_to_another_db(self, r_multi_db, fault_inject args=(fault_injector_client,event) ) + r_multi_db, listener = r_multi_db + # Client initialized on first pipe execution. pipe = r_multi_db.pipeline() pipe.set('{hash}key1', 'value1') @@ -194,6 +176,8 @@ def test_chaining_pipeline_failover_to_another_db(self, r_multi_db, fault_inject assert pipe.execute() == [True, True, True, 'value1', 'value2', 'value3'] sleep(0.1) + assert listener.is_changed_flag == True + @pytest.mark.parametrize( "r_multi_db", [ @@ -209,6 +193,8 @@ def test_transaction_failover_to_another_db(self, r_multi_db, fault_injector_cli args=(fault_injector_client,event) ) + r_multi_db, listener = r_multi_db + def callback(pipe: Pipeline): pipe.set('{hash}key1', 'value1') pipe.set('{hash}key2', 'value2') @@ -229,4 +215,96 @@ def callback(pipe: Pipeline): # Execute pipeline after network failure for _ in range(3): r_multi_db.transaction(callback) - sleep(0.1) \ No newline at end of file + sleep(0.1) + + assert listener.is_changed_flag == True + + @pytest.mark.parametrize( + "r_multi_db", + [ + {"failure_threshold": 2} + ], + indirect=True + ) + def test_pubsub_failover_to_another_db(self, r_multi_db, fault_injector_client): + event = threading.Event() + thread = threading.Thread( + target=trigger_network_failure_action, + daemon=True, + args=(fault_injector_client,event) + ) + + r_multi_db, listener = r_multi_db + data = json.dumps({'message': 'test'}) + messages_count = 0 + + def handler(message): + nonlocal messages_count + messages_count += 1 + + pubsub = r_multi_db.pubsub() + + # Assign a handler and run in a separate thread. + pubsub.subscribe(**{'test-channel': handler}) + pubsub_thread = pubsub.run_in_thread(sleep_time=0.1, daemon=True) + thread.start() + + # Execute pipeline before network failure + while not event.is_set(): + r_multi_db.publish('test-channel', data) + sleep(0.1) + + # Execute pipeline after network failure + for _ in range(3): + r_multi_db.publish('test-channel', data) + sleep(0.1) + + pubsub_thread.stop() + + assert listener.is_changed_flag == True + assert messages_count > 5 + + @pytest.mark.parametrize( + "r_multi_db", + [ + {"failure_threshold": 2} + ], + indirect=True + ) + def test_sharded_pubsub_failover_to_another_db(self, r_multi_db, fault_injector_client): + event = threading.Event() + thread = threading.Thread( + target=trigger_network_failure_action, + daemon=True, + args=(fault_injector_client,event) + ) + + r_multi_db, listener = r_multi_db + data = json.dumps({'message': 'test'}) + messages_count = 0 + + def handler(message): + nonlocal messages_count + messages_count += 1 + + pubsub = r_multi_db.pubsub() + + # Assign a handler and run in a separate thread. 
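# The handler-based Pub/Sub pattern this test exercises, sketched against a
# plain redis.Redis client (assumes a reachable local server; the multi-DB
# wrapper proxies the same surface): subscribing with a dict maps channel ->
# callable, and run_in_thread polls get_message() on a worker thread, calling
# the handler once per message.
import json

import redis

r = redis.Redis()
received = []

def on_message(message):
    received.append(json.loads(message['data']))  # message['data'] is bytes

p = r.pubsub()
p.subscribe(**{'test-channel': on_message})
worker = p.run_in_thread(sleep_time=0.1, daemon=True)
r.publish('test-channel', json.dumps({'message': 'test'}))
# ... once finished:
worker.stop()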
+ pubsub.ssubscribe(**{'test-channel': handler}) + pubsub_thread = pubsub.run_in_thread(sleep_time=0.1, daemon=True) + thread.start() + + # Execute pipeline before network failure + while not event.is_set(): + r_multi_db.spublish('test-channel', data) + sleep(0.1) + + # Execute pipeline after network failure + for _ in range(3): + r_multi_db.spublish('test-channel', data) + sleep(0.1) + + pubsub_thread.stop() + + assert listener.is_changed_flag == True + assert messages_count > 5 \ No newline at end of file From fb500f6c87954a7ed14f9637dfbe7b6b62635d5b Mon Sep 17 00:00:00 2001 From: Vladyslav Vildanov <117659936+vladvildanov@users.noreply.github.com> Date: Mon, 18 Aug 2025 15:13:03 +0300 Subject: [PATCH 04/50] Refactored docblocks (#3744) --- redis/multidb/circuit.py | 20 ++++++++++++++++++-- redis/multidb/client.py | 12 ++++++++++-- redis/multidb/command_executor.py | 18 ++++++++++++------ redis/multidb/database.py | 11 +++++++---- redis/multidb/event.py | 2 +- redis/multidb/failover.py | 2 +- redis/multidb/failure_detector.py | 12 +++++++++--- 7 files changed, 58 insertions(+), 19 deletions(-) diff --git a/redis/multidb/circuit.py b/redis/multidb/circuit.py index 9211173c83..79c8a5f379 100644 --- a/redis/multidb/circuit.py +++ b/redis/multidb/circuit.py @@ -51,12 +51,20 @@ def on_state_changed(self, cb: Callable[["CircuitBreaker", State, State], None]) pass class PBListener(pybreaker.CircuitBreakerListener): + """Wrapper for callback to be compatible with pybreaker implementation.""" def __init__( self, cb: Callable[[CircuitBreaker, State, State], None], database, ): - """Wrapper for callback to be compatible with pybreaker implementation.""" + """ + Initialize a PBListener instance. + + Args: + cb: Callback function that will be called when the circuit breaker state changes. + database: Database instance associated with this circuit breaker. + """ + self._cb = cb self._database = database @@ -70,7 +78,15 @@ def state_change(self, cb, old_state, new_state): class PBCircuitBreakerAdapter(CircuitBreaker): def __init__(self, cb: pybreaker.CircuitBreaker): - """Adapter for pybreaker CircuitBreaker.""" + """ + Initialize a PBCircuitBreakerAdapter instance. + + This adapter wraps pybreaker's CircuitBreaker implementation to make it compatible + with our CircuitBreaker interface. + + Args: + cb: A pybreaker CircuitBreaker instance to be adapted. + """ self._cb = cb self._state_pb_mapper = { State.CLOSED: self._cb.close, diff --git a/redis/multidb/client.py b/redis/multidb/client.py index 172017f036..1073ea8168 100644 --- a/redis/multidb/client.py +++ b/redis/multidb/client.py @@ -247,7 +247,6 @@ def _check_databases_health(self, on_error: Callable[[Exception], None] = None): Runs health checks as a recurring task. Runs health checks against all databases. """ - for database, _ in self._databases: self._check_db_health(database, on_error) @@ -313,9 +312,11 @@ def pipeline_execute_command(self, *args, **options) -> "Pipeline": return self def execute_command(self, *args, **kwargs): + """Adds a command to the stack""" return self.pipeline_execute_command(*args, **kwargs) def execute(self) -> List[Any]: + """Execute all the commands in the current pipeline""" if not self._client.initialized: self._client.initialize() @@ -326,9 +327,16 @@ def execute(self) -> List[Any]: class PubSub: """ - PubSub object for multi database client. + PubSub object for multi-database client. """ def __init__(self, client: MultiDBClient, **kwargs): + """Initialize the PubSub object for a multi-database client. 
+ + Args: + client: MultiDBClient instance to use for pub/sub operations + **kwargs: Additional keyword arguments to pass to the underlying pubsub implementation + """ + self._client = client self._client.command_executor.pubsub(**kwargs) diff --git a/redis/multidb/command_executor.py b/redis/multidb/command_executor.py index 795ef8f8b1..40370c2e18 100644 --- a/redis/multidb/command_executor.py +++ b/redis/multidb/command_executor.py @@ -103,12 +103,15 @@ def __init__( auto_fallback_interval: float = DEFAULT_AUTO_FALLBACK_INTERVAL, ): """ - :param failure_detectors: List of failure detectors. - :param databases: List of databases. - :param failover_strategy: Strategy that defines the failover logic. - :param event_dispatcher: Event dispatcher. - :param auto_fallback_interval: Interval between fallback attempts. Fallback to a new database according to - failover_strategy. + Initialize the DefaultCommandExecutor instance. + + Args: + failure_detectors: List of failure detector instances to monitor database health + databases: Collection of available databases to execute commands on + command_retry: Retry policy for failed command execution + failover_strategy: Strategy for handling database failover + event_dispatcher: Interface for dispatching events + auto_fallback_interval: Time interval in seconds between attempts to fall back to a primary database """ for fd in failure_detectors: fd.set_command_executor(command_executor=self) @@ -205,6 +208,9 @@ def callback(): return self._execute_with_failure_detection(callback) def pubsub(self, **kwargs): + """ + Initializes a PubSub object on a currently active database. + """ def callback(): if self._active_pubsub is None: self._active_pubsub = self._active_database.client.pubsub(**kwargs) diff --git a/redis/multidb/database.py b/redis/multidb/database.py index 15db52e909..204b7c91f3 100644 --- a/redis/multidb/database.py +++ b/redis/multidb/database.py @@ -74,10 +74,13 @@ def __init__( state: State = State.DISCONNECTED, ): """ - param: client: Client instance for communication with the database. - param: circuit: Circuit breaker for the current database. - param: weight: Weight of current database. Database with the highest weight becomes Active. - param: state: State of the current database. + Initialize a new Database instance. + + Args: + client: Underlying Redis client instance for database operations + circuit: Circuit breaker for handling database failures + weight: Weight value used for database failover prioritization + state: Initial database state, defaults to DISCONNECTED """ self._client = client self._cb = circuit diff --git a/redis/multidb/event.py b/redis/multidb/event.py index 2598bc4d06..7b16d4ba88 100644 --- a/redis/multidb/event.py +++ b/redis/multidb/event.py @@ -39,7 +39,7 @@ def kwargs(self): class ResubscribeOnActiveDatabaseChanged(EventListenerInterface): """ - Re-subscribe currently active pub/sub to a new active database. + Re-subscribe the currently active pub / sub to a new active database. """ def listen(self, event: ActiveDatabaseChanged): old_pubsub = event.command_executor.active_pubsub diff --git a/redis/multidb/failover.py b/redis/multidb/failover.py index a4c825aac1..541f3413dc 100644 --- a/redis/multidb/failover.py +++ b/redis/multidb/failover.py @@ -23,7 +23,7 @@ def set_databases(self, databases: Databases) -> None: class WeightBasedFailoverStrategy(FailoverStrategy): """ - Choose the active database with the highest weight. + Failover strategy based on database weights. 
""" def __init__( self, diff --git a/redis/multidb/failure_detector.py b/redis/multidb/failure_detector.py index 50f1c839bd..3280fa6c32 100644 --- a/redis/multidb/failure_detector.py +++ b/redis/multidb/failure_detector.py @@ -32,9 +32,15 @@ def __init__( error_types: Optional[List[Type[Exception]]] = None, ) -> None: """ - :param threshold: Threshold of failed commands over the duration after which database will be marked as failed. - :param duration: Interval in seconds after which database will be marked as failed if threshold was exceeded. - :param error_types: List of exception that has to be registered. By default, all exceptions are registered. + Initialize a new CommandFailureDetector instance. + + Args: + threshold: The number of failures that must occur within the duration to trigger failure detection. + duration: The time window in seconds during which failures are counted. + error_types: Optional list of exception types to trigger failover. If None, all exceptions are counted. + + The detector tracks command failures within a sliding time window. When the number of failures + exceeds the threshold within the specified duration, it triggers failure detection. """ self._command_executor = None self._threshold = threshold From d6cdaeb3f49ae78402f4a4fd8f6ebf58a698c0b1 Mon Sep 17 00:00:00 2001 From: Vladyslav Vildanov <117659936+vladvildanov@users.noreply.github.com> Date: Wed, 20 Aug 2025 15:10:47 +0300 Subject: [PATCH 05/50] Refactored healthcheck and failure detector to extend default one (#3747) --- redis/multidb/client.py | 13 +- redis/multidb/config.py | 4 +- tests/test_multidb/conftest.py | 1 - tests/test_multidb/test_client.py | 195 +++++++++------------------- tests/test_multidb/test_pipeline.py | 94 +++++--------- tests/test_scenario/conftest.py | 4 +- 6 files changed, 104 insertions(+), 207 deletions(-) diff --git a/redis/multidb/client.py b/redis/multidb/client.py index 1073ea8168..8183d11293 100644 --- a/redis/multidb/client.py +++ b/redis/multidb/client.py @@ -22,10 +22,17 @@ class MultiDBClient(RedisModuleCommands, CoreCommands): """ def __init__(self, config: MultiDbConfig): self._databases = config.databases() - self._health_checks = config.default_health_checks() if config.health_checks is None else config.health_checks + self._health_checks = config.default_health_checks() + + if config.health_checks is not None: + self._health_checks.extend(config.health_checks) + self._health_check_interval = config.health_check_interval - self._failure_detectors = config.default_failure_detectors() \ - if config.failure_detectors is None else config.failure_detectors + self._failure_detectors = config.default_failure_detectors() + + if config.failure_detectors is not None: + self._failure_detectors.extend(config.failure_detectors) + self._failover_strategy = config.default_failover_strategy() \ if config.failover_strategy is None else config.failover_strategy self._failover_strategy.set_databases(self._databases) diff --git a/redis/multidb/config.py b/redis/multidb/config.py index 64ad7c9052..4bacc2c680 100644 --- a/redis/multidb/config.py +++ b/redis/multidb/config.py @@ -51,10 +51,10 @@ class MultiDbConfig: databases_config: A list of database configurations. client_class: The client class used to manage database connections. command_retry: Retry strategy for executing database commands. - failure_detectors: Optional list of failure detectors for monitoring database failures. + failure_detectors: Optional list of additional failure detectors for monitoring database failures. 
failure_threshold: Threshold for determining database failure. failures_interval: Time interval for tracking database failures. - health_checks: Optional list of health checks performed on databases. + health_checks: Optional list of additional health checks performed on databases. health_check_interval: Time interval for executing health checks. health_check_retries: Number of retry attempts for performing health checks. health_check_backoff: Backoff strategy for health check retries. diff --git a/tests/test_multidb/conftest.py b/tests/test_multidb/conftest.py index ad2057a118..f85e0a6fd7 100644 --- a/tests/test_multidb/conftest.py +++ b/tests/test_multidb/conftest.py @@ -94,7 +94,6 @@ def mock_multi_db_config( config = MultiDbConfig( databases_config=[Mock(spec=DatabaseConfig)], failure_detectors=[mock_fd], - health_checks=[mock_hc], health_check_interval=hc_interval, failover_strategy=mock_fs, auto_fallback_interval=auto_fallback_interval, diff --git a/tests/test_multidb/test_client.py b/tests/test_multidb/test_client.py index cf3877957f..c14f605c2a 100644 --- a/tests/test_multidb/test_client.py +++ b/tests/test_multidb/test_client.py @@ -32,26 +32,20 @@ class TestMultiDbClient: indirect=True, ) def test_execute_command_against_correct_db_on_successful_initialization( - self, mock_multi_db_config, mock_db, mock_db1, mock_db2 + self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc ): databases = create_weighted_list(mock_db, mock_db1, mock_db2) - with patch.object( - mock_multi_db_config, - 'databases', - return_value=databases - ): + with patch.object(mock_multi_db_config,'databases',return_value=databases), \ + patch.object(mock_multi_db_config,'default_health_checks', return_value=[mock_hc]): mock_db1.client.execute_command.return_value = 'OK1' - for hc in mock_multi_db_config.health_checks: - hc.check_health.return_value = True + mock_hc.check_health.return_value = True client = MultiDBClient(mock_multi_db_config) assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 assert client.set('key', 'value') == 'OK1' - - for hc in mock_multi_db_config.health_checks: - assert hc.check_health.call_count == 3 + assert mock_hc.check_health.call_count == 3 assert mock_db.circuit.state == CBState.CLOSED assert mock_db1.circuit.state == CBState.CLOSED @@ -70,26 +64,20 @@ def test_execute_command_against_correct_db_on_successful_initialization( indirect=True, ) def test_execute_command_against_correct_db_and_closed_circuit( - self, mock_multi_db_config, mock_db, mock_db1, mock_db2 + self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc ): databases = create_weighted_list(mock_db, mock_db1, mock_db2) - with patch.object( - mock_multi_db_config, - 'databases', - return_value=databases - ): + with patch.object(mock_multi_db_config,'databases',return_value=databases), \ + patch.object(mock_multi_db_config,'default_health_checks', return_value=[mock_hc]): mock_db1.client.execute_command.return_value = 'OK1' - for hc in mock_multi_db_config.health_checks: - hc.check_health.side_effect = [False, True, True] + mock_hc.check_health.side_effect = [False, True, True] client = MultiDBClient(mock_multi_db_config) assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 assert client.set('key', 'value') == 'OK1' - - for hc in mock_multi_db_config.health_checks: - assert hc.check_health.call_count == 3 + assert mock_hc.check_health.call_count == 3 assert mock_db.circuit.state == CBState.CLOSED assert mock_db1.circuit.state == CBState.CLOSED @@ -124,20 
+112,14 @@ def test_execute_command_against_correct_db_on_background_health_check_determine databases = create_weighted_list(mock_db, mock_db1, mock_db2) - with patch.object( - mock_multi_db_config, - 'databases', - return_value=databases - ): + with patch.object(mock_multi_db_config,'databases',return_value=databases), \ + patch.object(mock_multi_db_config,'default_health_checks', return_value=[EchoHealthCheck( + retry=Retry(retries=DEFAULT_HEALTH_CHECK_RETRIES, backoff=DEFAULT_HEALTH_CHECK_BACKOFF) + )]): mock_db.client.execute_command.side_effect = ['healthcheck', 'healthcheck', 'healthcheck', 'OK', 'error'] mock_db1.client.execute_command.side_effect = ['healthcheck', 'OK1', 'error', 'error', 'healthcheck', 'OK1'] mock_db2.client.execute_command.side_effect = ['healthcheck', 'healthcheck', 'OK2', 'error', 'error'] mock_multi_db_config.health_check_interval = 0.1 - mock_multi_db_config.health_checks = [ - EchoHealthCheck( - retry=Retry(retries=DEFAULT_HEALTH_CHECK_RETRIES, backoff=DEFAULT_HEALTH_CHECK_BACKOFF) - ) - ] mock_multi_db_config.failover_strategy = WeightBasedFailoverStrategy( retry=Retry(retries=DEFAULT_FAILOVER_RETRIES, backoff=DEFAULT_FAILOVER_BACKOFF) ) @@ -168,21 +150,15 @@ def test_execute_command_auto_fallback_to_highest_weight_db( ): databases = create_weighted_list(mock_db, mock_db1, mock_db2) - with patch.object( - mock_multi_db_config, - 'databases', - return_value=databases - ): + with patch.object(mock_multi_db_config,'databases',return_value=databases), \ + patch.object(mock_multi_db_config,'default_health_checks', return_value=[EchoHealthCheck( + retry=Retry(retries=DEFAULT_HEALTH_CHECK_RETRIES, backoff=DEFAULT_HEALTH_CHECK_BACKOFF) + )]): mock_db.client.execute_command.side_effect = ['healthcheck', 'healthcheck', 'healthcheck', 'healthcheck', 'healthcheck'] mock_db1.client.execute_command.side_effect = ['healthcheck', 'OK1', 'error', 'healthcheck', 'healthcheck', 'OK1'] mock_db2.client.execute_command.side_effect = ['healthcheck', 'healthcheck', 'OK2', 'healthcheck', 'healthcheck', 'healthcheck'] mock_multi_db_config.health_check_interval = 0.1 mock_multi_db_config.auto_fallback_interval = 0.2 - mock_multi_db_config.health_checks = [ - EchoHealthCheck( - retry=Retry(retries=DEFAULT_HEALTH_CHECK_RETRIES, backoff=DEFAULT_HEALTH_CHECK_BACKOFF) - ) - ] mock_multi_db_config.failover_strategy = WeightBasedFailoverStrategy( retry=Retry(retries=DEFAULT_FAILOVER_RETRIES, backoff=DEFAULT_FAILOVER_BACKOFF) ) @@ -223,17 +199,13 @@ def test_execute_command_auto_fallback_to_highest_weight_db( indirect=True, ) def test_execute_command_throws_exception_on_failed_initialization( - self, mock_multi_db_config, mock_db, mock_db1, mock_db2 + self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc ): databases = create_weighted_list(mock_db, mock_db1, mock_db2) - with patch.object( - mock_multi_db_config, - 'databases', - return_value=databases - ): - for hc in mock_multi_db_config.health_checks: - hc.check_health.return_value = False + with patch.object(mock_multi_db_config,'databases',return_value=databases), \ + patch.object(mock_multi_db_config,'default_health_checks', return_value=[mock_hc]): + mock_hc.check_health.return_value = False client = MultiDBClient(mock_multi_db_config) assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 @@ -241,8 +213,7 @@ def test_execute_command_throws_exception_on_failed_initialization( with pytest.raises(NoValidDatabaseException, match='Initial connection failed - no active database found'): client.set('key', 
'value') - for hc in mock_multi_db_config.health_checks: - assert hc.check_health.call_count == 3 + assert mock_hc.check_health.call_count == 3 assert mock_db.state == DBState.DISCONNECTED assert mock_db1.state == DBState.DISCONNECTED @@ -261,26 +232,20 @@ def test_execute_command_throws_exception_on_failed_initialization( indirect=True, ) def test_add_database_throws_exception_on_same_database( - self, mock_multi_db_config, mock_db, mock_db1, mock_db2 + self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc ): databases = create_weighted_list(mock_db, mock_db1, mock_db2) - with patch.object( - mock_multi_db_config, - 'databases', - return_value=databases - ): - for hc in mock_multi_db_config.health_checks: - hc.check_health.return_value = False + with patch.object(mock_multi_db_config,'databases',return_value=databases), \ + patch.object(mock_multi_db_config,'default_health_checks', return_value=[mock_hc]): + mock_hc.check_health.return_value = False client = MultiDBClient(mock_multi_db_config) assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 with pytest.raises(ValueError, match='Given database already exists'): client.add_database(mock_db) - - for hc in mock_multi_db_config.health_checks: - assert hc.check_health.call_count == 3 + assert mock_hc.check_health.call_count == 3 @pytest.mark.parametrize( 'mock_multi_db_config,mock_db, mock_db1, mock_db2', @@ -295,36 +260,28 @@ def test_add_database_throws_exception_on_same_database( indirect=True, ) def test_add_database_makes_new_database_active( - self, mock_multi_db_config, mock_db, mock_db1, mock_db2 + self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc ): databases = create_weighted_list(mock_db, mock_db2) - with patch.object( - mock_multi_db_config, - 'databases', - return_value=databases - ): + with patch.object(mock_multi_db_config,'databases',return_value=databases), \ + patch.object(mock_multi_db_config,'default_health_checks', return_value=[mock_hc]): mock_db1.client.execute_command.return_value = 'OK1' mock_db2.client.execute_command.return_value = 'OK2' - for hc in mock_multi_db_config.health_checks: - hc.check_health.return_value = True + mock_hc.check_health.return_value = True client = MultiDBClient(mock_multi_db_config) assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 assert client.set('key', 'value') == 'OK2' - - for hc in mock_multi_db_config.health_checks: - assert hc.check_health.call_count == 2 + assert mock_hc.check_health.call_count == 2 assert mock_db.state == DBState.PASSIVE assert mock_db2.state == DBState.ACTIVE client.add_database(mock_db1) - - for hc in mock_multi_db_config.health_checks: - assert hc.check_health.call_count == 3 + assert mock_hc.check_health.call_count == 3 assert client.set('key', 'value') == 'OK1' @@ -345,28 +302,22 @@ def test_add_database_makes_new_database_active( indirect=True, ) def test_remove_highest_weighted_database( - self, mock_multi_db_config, mock_db, mock_db1, mock_db2 + self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc ): databases = create_weighted_list(mock_db, mock_db1, mock_db2) - with patch.object( - mock_multi_db_config, - 'databases', - return_value=databases - ): + with patch.object(mock_multi_db_config,'databases',return_value=databases), \ + patch.object(mock_multi_db_config,'default_health_checks', return_value=[mock_hc]): mock_db1.client.execute_command.return_value = 'OK1' mock_db2.client.execute_command.return_value = 'OK2' - for hc in mock_multi_db_config.health_checks: - 
hc.check_health.return_value = True + mock_hc.check_health.return_value = True client = MultiDBClient(mock_multi_db_config) assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 assert client.set('key', 'value') == 'OK1' - - for hc in mock_multi_db_config.health_checks: - assert hc.check_health.call_count == 3 + assert mock_hc.check_health.call_count == 3 assert mock_db.state == DBState.PASSIVE assert mock_db1.state == DBState.ACTIVE @@ -392,28 +343,22 @@ def test_remove_highest_weighted_database( indirect=True, ) def test_update_database_weight_to_be_highest( - self, mock_multi_db_config, mock_db, mock_db1, mock_db2 + self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc ): databases = create_weighted_list(mock_db, mock_db1, mock_db2) - with patch.object( - mock_multi_db_config, - 'databases', - return_value=databases - ): + with patch.object(mock_multi_db_config,'databases',return_value=databases), \ + patch.object(mock_multi_db_config,'default_health_checks', return_value=[mock_hc]): mock_db1.client.execute_command.return_value = 'OK1' mock_db2.client.execute_command.return_value = 'OK2' - for hc in mock_multi_db_config.health_checks: - hc.check_health.return_value = True + mock_hc.check_health.return_value = True client = MultiDBClient(mock_multi_db_config) assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 assert client.set('key', 'value') == 'OK1' - - for hc in mock_multi_db_config.health_checks: - assert hc.check_health.call_count == 3 + assert mock_hc.check_health.call_count == 3 assert mock_db.state == DBState.PASSIVE assert mock_db1.state == DBState.ACTIVE @@ -441,15 +386,12 @@ def test_update_database_weight_to_be_highest( indirect=True, ) def test_add_new_failure_detector( - self, mock_multi_db_config, mock_db, mock_db1, mock_db2 + self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc ): databases = create_weighted_list(mock_db, mock_db1, mock_db2) - with patch.object( - mock_multi_db_config, - 'databases', - return_value=databases - ): + with patch.object(mock_multi_db_config,'databases',return_value=databases), \ + patch.object(mock_multi_db_config,'default_health_checks', return_value=[mock_hc]): mock_db1.client.execute_command.return_value = 'OK1' mock_multi_db_config.event_dispatcher = EventDispatcher() mock_fd = mock_multi_db_config.failure_detectors[0] @@ -460,15 +402,12 @@ def test_add_new_failure_detector( exception=Exception(), ) - for hc in mock_multi_db_config.health_checks: - hc.check_health.return_value = True + mock_hc.check_health.return_value = True client = MultiDBClient(mock_multi_db_config) assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 assert client.set('key', 'value') == 'OK1' - - for hc in mock_multi_db_config.health_checks: - assert hc.check_health.call_count == 3 + assert mock_hc.check_health.call_count == 3 # Simulate failing command events that lead to a failure detection for i in range(5): @@ -499,26 +438,20 @@ def test_add_new_failure_detector( indirect=True, ) def test_add_new_health_check( - self, mock_multi_db_config, mock_db, mock_db1, mock_db2 + self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc ): databases = create_weighted_list(mock_db, mock_db1, mock_db2) - with patch.object( - mock_multi_db_config, - 'databases', - return_value=databases - ): + with patch.object(mock_multi_db_config,'databases',return_value=databases), \ + patch.object(mock_multi_db_config,'default_health_checks', return_value=[mock_hc]): 
mock_db1.client.execute_command.return_value = 'OK1' - for hc in mock_multi_db_config.health_checks: - hc.check_health.return_value = True + mock_hc.check_health.return_value = True client = MultiDBClient(mock_multi_db_config) assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 assert client.set('key', 'value') == 'OK1' - - for hc in mock_multi_db_config.health_checks: - assert hc.check_health.call_count == 3 + assert mock_hc.check_health.call_count == 3 another_hc = Mock(spec=HealthCheck) another_hc.check_health.return_value = True @@ -526,6 +459,7 @@ def test_add_new_health_check( client.add_health_check(another_hc) client._check_db_health(mock_db1) + assert mock_hc.check_health.call_count == 4 assert another_hc.check_health.call_count == 1 @pytest.mark.parametrize( @@ -541,27 +475,21 @@ def test_add_new_health_check( indirect=True, ) def test_set_active_database( - self, mock_multi_db_config, mock_db, mock_db1, mock_db2 + self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc ): databases = create_weighted_list(mock_db, mock_db1, mock_db2) - with patch.object( - mock_multi_db_config, - 'databases', - return_value=databases - ): + with patch.object(mock_multi_db_config,'databases',return_value=databases), \ + patch.object(mock_multi_db_config,'default_health_checks', return_value=[mock_hc]): mock_db1.client.execute_command.return_value = 'OK1' mock_db.client.execute_command.return_value = 'OK' - for hc in mock_multi_db_config.health_checks: - hc.check_health.return_value = True + mock_hc.check_health.return_value = True client = MultiDBClient(mock_multi_db_config) assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 assert client.set('key', 'value') == 'OK1' - - for hc in mock_multi_db_config.health_checks: - assert hc.check_health.call_count == 3 + assert mock_hc.check_health.call_count == 3 assert mock_db.state == DBState.PASSIVE assert mock_db1.state == DBState.ACTIVE @@ -577,8 +505,7 @@ def test_set_active_database( with pytest.raises(ValueError, match='Given database is not a member of database list'): client.set_active_database(Mock(spec=AbstractDatabase)) - for hc in mock_multi_db_config.health_checks: - hc.check_health.return_value = False + mock_hc.check_health.return_value = False with pytest.raises(NoValidDatabaseException, match='Cannot set active database, database is unhealthy'): client.set_active_database(mock_db1) \ No newline at end of file diff --git a/tests/test_multidb/test_pipeline.py b/tests/test_multidb/test_pipeline.py index 9caad235df..f0d2a0dbe3 100644 --- a/tests/test_multidb/test_pipeline.py +++ b/tests/test_multidb/test_pipeline.py @@ -36,21 +36,17 @@ class TestPipeline: indirect=True, ) def test_executes_pipeline_against_correct_db( - self, mock_multi_db_config, mock_db, mock_db1, mock_db2 + self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc ): databases = create_weighted_list(mock_db, mock_db1, mock_db2) - with patch.object( - mock_multi_db_config, - 'databases', - return_value=databases - ): + with patch.object(mock_multi_db_config,'databases',return_value=databases), \ + patch.object(mock_multi_db_config,'default_health_checks', return_value=[mock_hc]): pipe = mock_pipe() pipe.execute.return_value = ['OK1', 'value1'] mock_db1.client.pipeline.return_value = pipe - for hc in mock_multi_db_config.health_checks: - hc.check_health.return_value = True + mock_hc.check_health.return_value = True client = MultiDBClient(mock_multi_db_config) assert 
mock_multi_db_config.failover_strategy.set_databases.call_count == 1 @@ -60,9 +56,7 @@ def test_executes_pipeline_against_correct_db( pipe.get('key1') assert pipe.execute() == ['OK1', 'value1'] - - for hc in mock_multi_db_config.health_checks: - assert hc.check_health.call_count == 3 + assert mock_hc.check_health.call_count == 3 @pytest.mark.parametrize( 'mock_multi_db_config,mock_db, mock_db1, mock_db2', @@ -77,21 +71,17 @@ def test_executes_pipeline_against_correct_db( indirect=True, ) def test_execute_pipeline_against_correct_db_and_closed_circuit( - self, mock_multi_db_config, mock_db, mock_db1, mock_db2 + self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc ): databases = create_weighted_list(mock_db, mock_db1, mock_db2) - with patch.object( - mock_multi_db_config, - 'databases', - return_value=databases - ): + with patch.object(mock_multi_db_config,'databases',return_value=databases), \ + patch.object(mock_multi_db_config,'default_health_checks', return_value=[mock_hc]): pipe = mock_pipe() pipe.execute.return_value = ['OK1', 'value1'] mock_db1.client.pipeline.return_value = pipe - for hc in mock_multi_db_config.health_checks: - hc.check_health.side_effect = [False, True, True] + mock_hc.check_health.side_effect = [False, True, True] client = MultiDBClient(mock_multi_db_config) assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 @@ -101,9 +91,7 @@ def test_execute_pipeline_against_correct_db_and_closed_circuit( pipe.get('key1') assert pipe.execute() == ['OK1', 'value1'] - - for hc in mock_multi_db_config.health_checks: - assert hc.check_health.call_count == 3 + assert mock_hc.check_health.call_count == 3 assert mock_db.circuit.state == CBState.CLOSED assert mock_db1.circuit.state == CBState.CLOSED @@ -122,7 +110,7 @@ def test_execute_pipeline_against_correct_db_and_closed_circuit( indirect=True, ) def test_execute_pipeline_against_correct_db_on_background_health_check_determine_active_db_unhealthy( - self, mock_multi_db_config, mock_db, mock_db1, mock_db2 + self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc ): cb = PBCircuitBreakerAdapter(pybreaker.CircuitBreaker(reset_timeout=5)) cb.database = mock_db @@ -138,11 +126,10 @@ def test_execute_pipeline_against_correct_db_on_background_health_check_determin databases = create_weighted_list(mock_db, mock_db1, mock_db2) - with patch.object( - mock_multi_db_config, - 'databases', - return_value=databases - ): + with patch.object(mock_multi_db_config,'databases',return_value=databases), \ + patch.object(mock_multi_db_config,'default_health_checks', return_value=[EchoHealthCheck( + retry=Retry(retries=DEFAULT_HEALTH_CHECK_RETRIES, backoff=DEFAULT_HEALTH_CHECK_BACKOFF) + )]): mock_db.client.execute_command.side_effect = ['healthcheck', 'healthcheck', 'healthcheck', 'error'] mock_db1.client.execute_command.side_effect = ['healthcheck', 'error', 'error', 'healthcheck'] mock_db2.client.execute_command.side_effect = ['healthcheck', 'healthcheck', 'error', 'error'] @@ -160,11 +147,6 @@ def test_execute_pipeline_against_correct_db_on_background_health_check_determin mock_db2.client.pipeline.return_value = pipe2 mock_multi_db_config.health_check_interval = 0.1 - mock_multi_db_config.health_checks = [ - EchoHealthCheck( - retry=Retry(retries=DEFAULT_HEALTH_CHECK_RETRIES, backoff=DEFAULT_HEALTH_CHECK_BACKOFF) - ) - ] mock_multi_db_config.failover_strategy = WeightBasedFailoverStrategy( retry=Retry(retries=DEFAULT_FAILOVER_RETRIES, backoff=DEFAULT_FAILOVER_BACKOFF) ) @@ -216,19 +198,15 @@ class 
TestTransaction: indirect=True, ) def test_executes_transaction_against_correct_db( - self, mock_multi_db_config, mock_db, mock_db1, mock_db2 + self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc ): databases = create_weighted_list(mock_db, mock_db1, mock_db2) - with patch.object( - mock_multi_db_config, - 'databases', - return_value=databases - ): + with patch.object(mock_multi_db_config,'databases',return_value=databases), \ + patch.object(mock_multi_db_config,'default_health_checks', return_value=[mock_hc]): mock_db1.client.transaction.return_value = ['OK1', 'value1'] - for hc in mock_multi_db_config.health_checks: - hc.check_health.return_value = True + mock_hc.check_health.return_value = True client = MultiDBClient(mock_multi_db_config) assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 @@ -238,9 +216,7 @@ def callback(pipe: Pipeline): pipe.get('key1') assert client.transaction(callback) == ['OK1', 'value1'] - - for hc in mock_multi_db_config.health_checks: - assert hc.check_health.call_count == 3 + assert mock_hc.check_health.call_count == 3 @pytest.mark.parametrize( 'mock_multi_db_config,mock_db, mock_db1, mock_db2', @@ -255,19 +231,15 @@ def callback(pipe: Pipeline): indirect=True, ) def test_execute_transaction_against_correct_db_and_closed_circuit( - self, mock_multi_db_config, mock_db, mock_db1, mock_db2 + self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc ): databases = create_weighted_list(mock_db, mock_db1, mock_db2) - with patch.object( - mock_multi_db_config, - 'databases', - return_value=databases - ): + with patch.object(mock_multi_db_config,'databases',return_value=databases), \ + patch.object(mock_multi_db_config,'default_health_checks', return_value=[mock_hc]): mock_db1.client.transaction.return_value = ['OK1', 'value1'] - for hc in mock_multi_db_config.health_checks: - hc.check_health.side_effect = [False, True, True] + mock_hc.check_health.side_effect = [False, True, True] client = MultiDBClient(mock_multi_db_config) assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 @@ -277,9 +249,7 @@ def callback(pipe: Pipeline): pipe.get('key1') assert client.transaction(callback) == ['OK1', 'value1'] - - for hc in mock_multi_db_config.health_checks: - assert hc.check_health.call_count == 3 + assert mock_hc.check_health.call_count == 3 assert mock_db.circuit.state == CBState.CLOSED assert mock_db1.circuit.state == CBState.CLOSED @@ -314,11 +284,10 @@ def test_execute_transaction_against_correct_db_on_background_health_check_deter databases = create_weighted_list(mock_db, mock_db1, mock_db2) - with patch.object( - mock_multi_db_config, - 'databases', - return_value=databases - ): + with patch.object(mock_multi_db_config,'databases',return_value=databases), \ + patch.object(mock_multi_db_config,'default_health_checks', return_value=[EchoHealthCheck( + retry=Retry(retries=DEFAULT_HEALTH_CHECK_RETRIES, backoff=DEFAULT_HEALTH_CHECK_BACKOFF) + )]): mock_db.client.execute_command.side_effect = ['healthcheck', 'healthcheck', 'healthcheck', 'error'] mock_db1.client.execute_command.side_effect = ['healthcheck', 'error', 'error', 'healthcheck'] mock_db2.client.execute_command.side_effect = ['healthcheck', 'healthcheck', 'error', 'error'] @@ -328,11 +297,6 @@ def test_execute_transaction_against_correct_db_on_background_health_check_deter mock_db2.client.transaction.return_value = ['OK2', 'value'] mock_multi_db_config.health_check_interval = 0.1 - mock_multi_db_config.health_checks = [ - EchoHealthCheck( - 
retry=Retry(retries=DEFAULT_HEALTH_CHECK_RETRIES, backoff=DEFAULT_HEALTH_CHECK_BACKOFF) - ) - ] mock_multi_db_config.failover_strategy = WeightBasedFailoverStrategy( retry=Retry(retries=DEFAULT_FAILOVER_RETRIES, backoff=DEFAULT_FAILOVER_BACKOFF) ) diff --git a/tests/test_scenario/conftest.py b/tests/test_scenario/conftest.py index 486dc948f1..b347fe50ba 100644 --- a/tests/test_scenario/conftest.py +++ b/tests/test_scenario/conftest.py @@ -51,7 +51,6 @@ def r_multi_db(request) -> tuple[MultiDBClient, CheckActiveDatabaseChangedListen # Retry configuration different for health checks as initial health check require more time in case # if infrastructure wasn't restored from the previous test. - health_checks = [EchoHealthCheck(Retry(ExponentialBackoff(cap=5, base=0.5), retries=3))] health_check_interval = request.param.get('health_check_interval', DEFAULT_HEALTH_CHECK_INTERVAL) event_dispatcher = EventDispatcher() listener = CheckActiveDatabaseChangedListener() @@ -84,10 +83,11 @@ def r_multi_db(request) -> tuple[MultiDBClient, CheckActiveDatabaseChangedListen config = MultiDbConfig( databases_config=db_configs, - health_checks=health_checks, command_retry=command_retry, failure_threshold=failure_threshold, health_check_interval=health_check_interval, + health_check_backoff=ExponentialBackoff(cap=0.5, base=0.05), + health_check_retries=3, event_dispatcher=event_dispatcher, ) From 68fe530f3a6cf4e237b1564f7888f18f586d87ac Mon Sep 17 00:00:00 2001 From: Vladyslav Vildanov <117659936+vladvildanov@users.noreply.github.com> Date: Wed, 20 Aug 2025 16:21:21 +0300 Subject: [PATCH 06/50] Added MultiDbClient support with OSS Cluster API (#3734) * Added Database, Healthcheck, CircuitBreaker, FailureDetector * Added DatabaseSelector, exceptions, refactored existing entities * Added MultiDbConfig * Added DatabaseConfig * Added DatabaseConfig test coverage * Renamed DatabaseSelector into FailoverStrategy * Added CommandExecutor * Updated healthcheck to close circuit on success * Added thread-safeness * Added missing thread-safeness * Added missing thread-safenes for dispatcher * Refactored client to keep databases in WeightedList * Added database CRUD operations * Added on-fly configuration * Added background health checks * Added background healthcheck + half-open event * Refactored background scheduling * Added support for Active-Active pipeline * Refactored healthchecks * Added Pipeline testing * Added support for transactions * Removed code repetitions, fixed weight assignment, added loops enhancement, fixed data structure * Added missing doc blocks * Added support for Pub/Sub in MultiDBClient * Refactored configuration * Refactored failure detector * Refactored retry logic * Added scenario tests * Added pybreaker optional dependency * Added pybreaker to dev dependencies * Rename tests directory * Added scenario tests for Pipeline and Transaction * Added handling of ConnectionRefusedError, added timeouts so cluster could recover * Increased timeouts * Refactored integration tests * Added scenario tests for Pub/Sub * Updated healthcheck retry * Increased timeout to avoid unprepared state before tests * Added backoff retry and changed timeouts * Added retry for healthchecks to avoid fluctuations * Changed retry configuration for healthchecks * Fixed property name * Added check for thread results * Added MultiDbClient support with OSS Cluster API * Removed database statuses * Increased test timeouts * Increased retry timeout * Increased timeout retries * Updated base threshold for retries * Fixed flacky 
tests * Added missing docblocks --- redis/client.py | 12 +- redis/cluster.py | 3 +- redis/multidb/client.py | 35 ++--- redis/multidb/command_executor.py | 7 +- redis/multidb/database.py | 30 +---- redis/multidb/event.py | 2 +- redis/multidb/healthcheck.py | 19 ++- tests/test_multidb/conftest.py | 5 +- tests/test_multidb/test_client.py | 48 +------ tests/test_multidb/test_healthcheck.py | 8 +- tests/test_scenario/conftest.py | 20 ++- tests/test_scenario/test_active_active.py | 151 ++++++++++++---------- 12 files changed, 157 insertions(+), 183 deletions(-) diff --git a/redis/client.py b/redis/client.py index adb57d404e..e22ca3d73d 100755 --- a/redis/client.py +++ b/redis/client.py @@ -1217,7 +1217,8 @@ def run_in_thread( sleep_time: float = 0.0, daemon: bool = False, exception_handler: Optional[Callable] = None, - pubsub = None + pubsub = None, + sharded_pubsub: bool = False, ) -> "PubSubWorkerThread": for channel, handler in self.channels.items(): if handler is None: @@ -1233,7 +1234,7 @@ def run_in_thread( pubsub = self if pubsub is None else pubsub thread = PubSubWorkerThread( - pubsub, sleep_time, daemon=daemon, exception_handler=exception_handler + pubsub, sleep_time, daemon=daemon, exception_handler=exception_handler, sharded_pubsub=sharded_pubsub ) thread.start() return thread @@ -1248,12 +1249,14 @@ def __init__( exception_handler: Union[ Callable[[Exception, "PubSub", "PubSubWorkerThread"], None], None ] = None, + sharded_pubsub: bool = False, ): super().__init__() self.daemon = daemon self.pubsub = pubsub self.sleep_time = sleep_time self.exception_handler = exception_handler + self.sharded_pubsub = sharded_pubsub self._running = threading.Event() def run(self) -> None: @@ -1264,7 +1267,10 @@ def run(self) -> None: sleep_time = self.sleep_time while self._running.is_set(): try: - pubsub.get_message(ignore_subscribe_messages=True, timeout=sleep_time) + if not self.sharded_pubsub: + pubsub.get_message(ignore_subscribe_messages=True, timeout=sleep_time) + else: + pubsub.get_sharded_message(ignore_subscribe_messages=True, timeout=sleep_time) except BaseException as e: if self.exception_handler is None: raise diff --git a/redis/cluster.py b/redis/cluster.py index dbcf5cc2b7..2fd4625e6b 100644 --- a/redis/cluster.py +++ b/redis/cluster.py @@ -3154,7 +3154,8 @@ def _reinitialize_on_error(self, error): self._nodes_manager.initialize() self.reinitialize_counter = 0 else: - self._nodes_manager.update_moved_exception(error) + if type(error) == MovedError: + self._nodes_manager.update_moved_exception(error) self._executing = False diff --git a/redis/multidb/client.py b/redis/multidb/client.py index 8183d11293..2f87024f20 100644 --- a/redis/multidb/client.py +++ b/redis/multidb/client.py @@ -9,7 +9,7 @@ from redis.multidb.command_executor import DefaultCommandExecutor from redis.multidb.config import MultiDbConfig, DEFAULT_GRACE_PERIOD from redis.multidb.circuit import State as CBState, CircuitBreaker -from redis.multidb.database import State as DBState, Database, AbstractDatabase, Databases +from redis.multidb.database import Database, AbstractDatabase, Databases from redis.multidb.exception import NoValidDatabaseException from redis.multidb.failure_detector import FailureDetector from redis.multidb.healthcheck import HealthCheck @@ -78,13 +78,8 @@ def raise_exception_on_failed_hc(error): # Set states according to a weights and circuit state if database.circuit.state == CBState.CLOSED and not is_active_db_found: - database.state = DBState.ACTIVE self.command_executor.active_database = 
database is_active_db_found = True - elif database.circuit.state == CBState.CLOSED and is_active_db_found: - database.state = DBState.PASSIVE - else: - database.state = DBState.DISCONNECTED if not is_active_db_found: raise NoValidDatabaseException('Initial connection failed - no active database found') @@ -115,8 +110,6 @@ def set_active_database(self, database: AbstractDatabase) -> None: if database.circuit.state == CBState.CLOSED: highest_weighted_db, _ = self._databases.get_top_n(1)[0] - highest_weighted_db.state = DBState.PASSIVE - database.state = DBState.ACTIVE self.command_executor.active_database = database return @@ -138,9 +131,7 @@ def add_database(self, database: AbstractDatabase): def _change_active_database(self, new_database: AbstractDatabase, highest_weight_database: AbstractDatabase): if new_database.weight > highest_weight_database.weight and new_database.circuit.state == CBState.CLOSED: - new_database.state = DBState.ACTIVE self.command_executor.active_database = new_database - highest_weight_database.state = DBState.PASSIVE def remove_database(self, database: Database): """ @@ -150,7 +141,6 @@ def remove_database(self, database: Database): highest_weighted_db, highest_weight = self._databases.get_top_n(1)[0] if highest_weight <= weight and highest_weighted_db.circuit.state == CBState.CLOSED: - highest_weighted_db.state = DBState.ACTIVE self.command_executor.active_database = highest_weighted_db def update_database_weight(self, database: AbstractDatabase, weight: float): @@ -240,7 +230,7 @@ def _check_db_health(self, database: AbstractDatabase, on_error: Callable[[Excep database.circuit.state = CBState.OPEN elif is_healthy and database.circuit.state != CBState.CLOSED: database.circuit.state = CBState.CLOSED - except (ConnectionError, TimeoutError, socket.timeout, ConnectionRefusedError) as e: + except (ConnectionError, TimeoutError, socket.timeout, ConnectionRefusedError, ValueError) as e: if database.circuit.state != CBState.OPEN: database.circuit.state = CBState.OPEN is_healthy = False @@ -334,7 +324,7 @@ def execute(self) -> List[Any]: class PubSub: """ - PubSub object for multi-database client. + PubSub object for multi database client. """ def __init__(self, client: MultiDBClient, **kwargs): """Initialize the PubSub object for a multi-database client. @@ -438,18 +428,33 @@ def get_message( ignore_subscribe_messages=ignore_subscribe_messages, timeout=timeout ) - get_sharded_message = get_message + def get_sharded_message( + self, ignore_subscribe_messages: bool = False, timeout: float = 0.0 + ): + """ + Get the next message if one is available in a sharded channel, otherwise None. + + If timeout is specified, the system will wait for `timeout` seconds + before returning. Timeout should be specified as a floating point + number, or None, to wait indefinitely. 
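+
+        Usage sketch, mirroring the sharded pub/sub scenario test later in
+        this series (client is an initialized MultiDBClient, handler any
+        callable accepting a message dict):
+
+            pubsub = client.pubsub()
+            pubsub.ssubscribe(**{'test-channel': handler})
+            thread = pubsub.run_in_thread(
+                sleep_time=0.1, daemon=True, sharded_pubsub=True
+            )
+            ...  # the worker thread polls get_sharded_message() instead of get_message()
+            thread.stop()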
+ """ + return self._client.command_executor.execute_pubsub_method( + 'get_sharded_message', + ignore_subscribe_messages=ignore_subscribe_messages, timeout=timeout + ) def run_in_thread( self, sleep_time: float = 0.0, daemon: bool = False, exception_handler: Optional[Callable] = None, + sharded_pubsub: bool = False, ) -> "PubSubWorkerThread": return self._client.command_executor.execute_pubsub_run_in_thread( sleep_time=sleep_time, daemon=daemon, exception_handler=exception_handler, - pubsub=self + pubsub=self, + sharded_pubsub=sharded_pubsub, ) diff --git a/redis/multidb/command_executor.py b/redis/multidb/command_executor.py index 40370c2e18..094230a31d 100644 --- a/redis/multidb/command_executor.py +++ b/redis/multidb/command_executor.py @@ -235,10 +235,15 @@ def execute_pubsub_run_in_thread( sleep_time: float = 0.0, daemon: bool = False, exception_handler: Optional[Callable] = None, + sharded_pubsub: bool = False, ) -> "PubSubWorkerThread": def callback(): return self._active_pubsub.run_in_thread( - sleep_time, daemon=daemon, exception_handler=exception_handler, pubsub=pubsub + sleep_time, + daemon=daemon, + exception_handler=exception_handler, + pubsub=pubsub, + sharded_pubsub=sharded_pubsub ) return self._execute_with_failure_detection(callback) diff --git a/redis/multidb/database.py b/redis/multidb/database.py index 204b7c91f3..3253ffa093 100644 --- a/redis/multidb/database.py +++ b/redis/multidb/database.py @@ -8,12 +8,6 @@ from redis.multidb.circuit import CircuitBreaker from redis.typing import Number - -class State(Enum): - ACTIVE = 0 - PASSIVE = 1 - DISCONNECTED = 2 - class AbstractDatabase(ABC): @property @abstractmethod @@ -39,18 +33,6 @@ def weight(self, weight: float): """Set the weight of this database in compare to others.""" pass - @property - @abstractmethod - def state(self) -> State: - """The state of the current database.""" - pass - - @state.setter - @abstractmethod - def state(self, state: State): - """Set the state of the current database.""" - pass - @property @abstractmethod def circuit(self) -> CircuitBreaker: @@ -70,8 +52,7 @@ def __init__( self, client: Union[redis.Redis, RedisCluster], circuit: CircuitBreaker, - weight: float, - state: State = State.DISCONNECTED, + weight: float ): """ Initialize a new Database instance. @@ -86,7 +67,6 @@ def __init__( self._cb = circuit self._cb.database = self self._weight = weight - self._state = state @property def client(self) -> Union[redis.Redis, RedisCluster]: @@ -104,14 +84,6 @@ def weight(self) -> float: def weight(self, weight: float): self._weight = weight - @property - def state(self) -> State: - return self._state - - @state.setter - def state(self, state: State): - self._state = state - @property def circuit(self) -> CircuitBreaker: return self._cb diff --git a/redis/multidb/event.py b/redis/multidb/event.py index 7b16d4ba88..2598bc4d06 100644 --- a/redis/multidb/event.py +++ b/redis/multidb/event.py @@ -39,7 +39,7 @@ def kwargs(self): class ResubscribeOnActiveDatabaseChanged(EventListenerInterface): """ - Re-subscribe the currently active pub / sub to a new active database. + Re-subscribe currently active pub/sub to a new active database. 
""" def listen(self, event: ActiveDatabaseChanged): old_pubsub = event.command_executor.active_pubsub diff --git a/redis/multidb/healthcheck.py b/redis/multidb/healthcheck.py index 1396a1e997..cca220dc3f 100644 --- a/redis/multidb/healthcheck.py +++ b/redis/multidb/healthcheck.py @@ -1,4 +1,7 @@ from abc import abstractmethod, ABC + +import redis +from redis import Redis from redis.retry import Retry @@ -51,8 +54,20 @@ def check_health(self, database) -> bool: def _returns_echoed_message(self, database) -> bool: expected_message = ["healthcheck", b"healthcheck"] - actual_message = database.client.execute_command('ECHO', "healthcheck") - return actual_message in expected_message + + if isinstance(database.client, Redis): + actual_message = database.client.execute_command("ECHO" ,"healthcheck") + return actual_message in expected_message + else: + # For a cluster checks if all nodes are healthy. + all_nodes = database.client.get_nodes() + for node in all_nodes: + actual_message = node.redis_connection.execute_command("ECHO" ,"healthcheck") + + if actual_message not in expected_message: + return False + + return True def _dummy_fail(self): pass \ No newline at end of file diff --git a/tests/test_multidb/conftest.py b/tests/test_multidb/conftest.py index f85e0a6fd7..a34ef01476 100644 --- a/tests/test_multidb/conftest.py +++ b/tests/test_multidb/conftest.py @@ -7,7 +7,7 @@ from redis.multidb.circuit import CircuitBreaker, State as CBState from redis.multidb.config import MultiDbConfig, DatabaseConfig, DEFAULT_HEALTH_CHECK_INTERVAL, \ DEFAULT_AUTO_FALLBACK_INTERVAL -from redis.multidb.database import Database, State, Databases +from redis.multidb.database import Database, Databases from redis.multidb.failover import FailoverStrategy from redis.multidb.failure_detector import FailureDetector from redis.multidb.healthcheck import HealthCheck @@ -38,7 +38,6 @@ def mock_hc() -> HealthCheck: def mock_db(request) -> Database: db = Mock(spec=Database) db.weight = request.param.get("weight", 1.0) - db.state = request.param.get("state", State.ACTIVE) db.client = Mock(spec=Redis) cb = request.param.get("circuit", {}) @@ -53,7 +52,6 @@ def mock_db(request) -> Database: def mock_db1(request) -> Database: db = Mock(spec=Database) db.weight = request.param.get("weight", 1.0) - db.state = request.param.get("state", State.ACTIVE) db.client = Mock(spec=Redis) cb = request.param.get("circuit", {}) @@ -68,7 +66,6 @@ def mock_db1(request) -> Database: def mock_db2(request) -> Database: db = Mock(spec=Database) db.weight = request.param.get("weight", 1.0) - db.state = request.param.get("state", State.ACTIVE) db.client = Mock(spec=Redis) cb = request.param.get("circuit", {}) diff --git a/tests/test_multidb/test_client.py b/tests/test_multidb/test_client.py index c14f605c2a..37ee9b3fd3 100644 --- a/tests/test_multidb/test_client.py +++ b/tests/test_multidb/test_client.py @@ -8,7 +8,7 @@ from redis.multidb.circuit import State as CBState, PBCircuitBreakerAdapter from redis.multidb.config import DEFAULT_HEALTH_CHECK_RETRIES, DEFAULT_HEALTH_CHECK_BACKOFF, DEFAULT_FAILOVER_RETRIES, \ DEFAULT_FAILOVER_BACKOFF -from redis.multidb.database import State as DBState, AbstractDatabase +from redis.multidb.database import AbstractDatabase from redis.multidb.client import MultiDBClient from redis.multidb.exception import NoValidDatabaseException from redis.multidb.failover import WeightBasedFailoverStrategy @@ -166,26 +166,14 @@ def test_execute_command_auto_fallback_to_highest_weight_db( client = 
MultiDBClient(mock_multi_db_config) assert client.set('key', 'value') == 'OK1' - assert mock_db.state == DBState.PASSIVE - assert mock_db1.state == DBState.ACTIVE - assert mock_db2.state == DBState.PASSIVE - sleep(0.15) assert client.set('key', 'value') == 'OK2' - assert mock_db.state == DBState.PASSIVE - assert mock_db1.state == DBState.ACTIVE - assert mock_db2.state == DBState.PASSIVE - sleep(0.22) assert client.set('key', 'value') == 'OK1' - assert mock_db.state == DBState.PASSIVE - assert mock_db1.state == DBState.ACTIVE - assert mock_db2.state == DBState.PASSIVE - @pytest.mark.parametrize( 'mock_multi_db_config,mock_db, mock_db1, mock_db2', [ @@ -215,10 +203,6 @@ def test_execute_command_throws_exception_on_failed_initialization( assert mock_hc.check_health.call_count == 3 - assert mock_db.state == DBState.DISCONNECTED - assert mock_db1.state == DBState.DISCONNECTED - assert mock_db2.state == DBState.DISCONNECTED - @pytest.mark.parametrize( 'mock_multi_db_config,mock_db, mock_db1, mock_db2', [ @@ -277,18 +261,11 @@ def test_add_database_makes_new_database_active( assert client.set('key', 'value') == 'OK2' assert mock_hc.check_health.call_count == 2 - assert mock_db.state == DBState.PASSIVE - assert mock_db2.state == DBState.ACTIVE - client.add_database(mock_db1) assert mock_hc.check_health.call_count == 3 assert client.set('key', 'value') == 'OK1' - assert mock_db.state == DBState.PASSIVE - assert mock_db1.state == DBState.ACTIVE - assert mock_db2.state == DBState.PASSIVE - @pytest.mark.parametrize( 'mock_multi_db_config,mock_db, mock_db1, mock_db2', [ @@ -319,17 +296,10 @@ def test_remove_highest_weighted_database( assert client.set('key', 'value') == 'OK1' assert mock_hc.check_health.call_count == 3 - assert mock_db.state == DBState.PASSIVE - assert mock_db1.state == DBState.ACTIVE - assert mock_db2.state == DBState.PASSIVE - client.remove_database(mock_db1) assert client.set('key', 'value') == 'OK2' - assert mock_db.state == DBState.PASSIVE - assert mock_db2.state == DBState.ACTIVE - @pytest.mark.parametrize( 'mock_multi_db_config,mock_db, mock_db1, mock_db2', [ @@ -360,19 +330,11 @@ def test_update_database_weight_to_be_highest( assert client.set('key', 'value') == 'OK1' assert mock_hc.check_health.call_count == 3 - assert mock_db.state == DBState.PASSIVE - assert mock_db1.state == DBState.ACTIVE - assert mock_db2.state == DBState.PASSIVE - client.update_database_weight(mock_db2, 0.8) assert mock_db2.weight == 0.8 assert client.set('key', 'value') == 'OK2' - assert mock_db.state == DBState.PASSIVE - assert mock_db1.state == DBState.PASSIVE - assert mock_db2.state == DBState.ACTIVE - @pytest.mark.parametrize( 'mock_multi_db_config,mock_db, mock_db1, mock_db2', [ @@ -491,17 +453,9 @@ def test_set_active_database( assert client.set('key', 'value') == 'OK1' assert mock_hc.check_health.call_count == 3 - assert mock_db.state == DBState.PASSIVE - assert mock_db1.state == DBState.ACTIVE - assert mock_db2.state == DBState.PASSIVE - client.set_active_database(mock_db) assert client.set('key', 'value') == 'OK' - assert mock_db.state == DBState.ACTIVE - assert mock_db1.state == DBState.PASSIVE - assert mock_db2.state == DBState.PASSIVE - with pytest.raises(ValueError, match='Given database is not a member of database list'): client.set_active_database(Mock(spec=AbstractDatabase)) diff --git a/tests/test_multidb/test_healthcheck.py b/tests/test_multidb/test_healthcheck.py index 9601638913..08bd8ab0c4 100644 --- a/tests/test_multidb/test_healthcheck.py +++ 
b/tests/test_multidb/test_healthcheck.py @@ -1,5 +1,5 @@ from redis.backoff import ExponentialBackoff -from redis.multidb.database import Database, State +from redis.multidb.database import Database from redis.multidb.healthcheck import EchoHealthCheck from redis.multidb.circuit import State as CBState from redis.exceptions import ConnectionError @@ -14,7 +14,7 @@ def test_database_is_healthy_on_echo_response(self, mock_client, mock_cb): """ mock_client.execute_command.side_effect = [ConnectionError, ConnectionError, 'healthcheck'] hc = EchoHealthCheck(Retry(backoff=ExponentialBackoff(cap=1.0), retries=3)) - db = Database(mock_client, mock_cb, 0.9, State.ACTIVE) + db = Database(mock_client, mock_cb, 0.9) assert hc.check_health(db) == True assert mock_client.execute_command.call_count == 3 @@ -26,7 +26,7 @@ def test_database_is_unhealthy_on_incorrect_echo_response(self, mock_client, moc """ mock_client.execute_command.side_effect = [ConnectionError, ConnectionError, 'wrong'] hc = EchoHealthCheck(Retry(backoff=ExponentialBackoff(cap=1.0), retries=3)) - db = Database(mock_client, mock_cb, 0.9, State.ACTIVE) + db = Database(mock_client, mock_cb, 0.9) assert hc.check_health(db) == False assert mock_client.execute_command.call_count == 3 @@ -35,7 +35,7 @@ def test_database_close_circuit_on_successful_healthcheck(self, mock_client, moc mock_client.execute_command.side_effect = [ConnectionError, ConnectionError, 'healthcheck'] mock_cb.state = CBState.HALF_OPEN hc = EchoHealthCheck(Retry(backoff=ExponentialBackoff(cap=1.0), retries=3)) - db = Database(mock_client, mock_cb, 0.9, State.ACTIVE) + db = Database(mock_client, mock_cb, 0.9) assert hc.check_health(db) == True assert mock_client.execute_command.call_count == 3 \ No newline at end of file diff --git a/tests/test_scenario/conftest.py b/tests/test_scenario/conftest.py index b347fe50ba..4182962fb1 100644 --- a/tests/test_scenario/conftest.py +++ b/tests/test_scenario/conftest.py @@ -3,6 +3,7 @@ import pytest +from redis import Redis from redis.backoff import NoBackoff, ExponentialBackoff from redis.event import EventDispatcher, EventListenerInterface from redis.multidb.client import MultiDBClient @@ -42,12 +43,18 @@ def fault_injector_client(): return FaultInjectorClient(url) @pytest.fixture() -def r_multi_db(request) -> tuple[MultiDBClient, CheckActiveDatabaseChangedListener]: - endpoint_config = get_endpoint_config('re-active-active') +def r_multi_db(request) -> tuple[MultiDBClient, CheckActiveDatabaseChangedListener, dict]: + client_class = request.param.get('client_class', Redis) + + if client_class == Redis: + endpoint_config = get_endpoint_config('re-active-active') + else: + endpoint_config = get_endpoint_config('re-active-active-oss-cluster') + username = endpoint_config.get('username', None) password = endpoint_config.get('password', None) failure_threshold = request.param.get('failure_threshold', DEFAULT_FAILURES_THRESHOLD) - command_retry = request.param.get('command_retry', Retry(ExponentialBackoff(cap=0.5, base=0.05), retries=3)) + command_retry = request.param.get('command_retry', Retry(ExponentialBackoff(cap=2, base=0.05), retries=10)) # Retry configuration different for health checks as initial health check require more time in case # if infrastructure wasn't restored from the previous test. 
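The hunk below relaxes the health check retry budget for the scenario tests.
Internally the config's health_check_retries and health_check_backoff are
presumably combined into a per-check Retry; a sketch with the new values:

    from redis.backoff import ExponentialBackoff
    from redis.retry import Retry

    # Up to 3 retries with exponentially growing sleeps, capped at 5 seconds -
    # generous enough for an initial health check against infrastructure that
    # is still recovering from the previous test's fault injection.
    health_check_retry = Retry(ExponentialBackoff(cap=5, base=0.5), retries=3)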
@@ -82,13 +89,14 @@ def r_multi_db(request) -> tuple[MultiDBClient, CheckActiveDatabaseChangedListen db_configs.append(db_config1) config = MultiDbConfig( + client_class=client_class, databases_config=db_configs, command_retry=command_retry, failure_threshold=failure_threshold, health_check_interval=health_check_interval, - health_check_backoff=ExponentialBackoff(cap=0.5, base=0.05), - health_check_retries=3, event_dispatcher=event_dispatcher, + health_check_backoff=ExponentialBackoff(cap=5, base=0.5), + health_check_retries=3, ) - return MultiDBClient(config), listener \ No newline at end of file + return MultiDBClient(config), listener, endpoint_config \ No newline at end of file diff --git a/tests/test_scenario/test_active_active.py b/tests/test_scenario/test_active_active.py index 071babb6c0..967fa43cdb 100644 --- a/tests/test_scenario/test_active_active.py +++ b/tests/test_scenario/test_active_active.py @@ -5,17 +5,16 @@ import pytest +from redis import Redis, RedisCluster from redis.client import Pipeline -from tests.test_scenario.conftest import get_endpoint_config from tests.test_scenario.fault_injector_client import ActionRequest, ActionType logger = logging.getLogger(__name__) -def trigger_network_failure_action(fault_injector_client, event: threading.Event = None): - endpoint_config = get_endpoint_config('re-active-active') +def trigger_network_failure_action(fault_injector_client, config, event: threading.Event = None): action_request = ActionRequest( action_type=ActionType.NETWORK_FAILURE, - parameters={"bdb_id": endpoint_config['bdb_id'], "delay": 2, "cluster_index": 0} + parameters={"bdb_id": config['bdb_id'], "delay": 2, "cluster_index": 0} ) result = fault_injector_client.trigger_action(action_request) @@ -31,29 +30,32 @@ def trigger_network_failure_action(fault_injector_client, event: threading.Event logger.info(f"Action completed. Status: {status_result['status']}") -class TestActiveActiveStandalone: +class TestActiveActive: def teardown_method(self, method): # Timeout so the cluster could recover from network failure. - sleep(3) + sleep(5) @pytest.mark.parametrize( "r_multi_db", [ - {"failure_threshold": 2} + {"client_class": Redis, "failure_threshold": 2}, + {"client_class": RedisCluster, "failure_threshold": 2}, ], + ids=["standalone", "cluster"], indirect=True ) + @pytest.mark.timeout(50) def test_multi_db_client_failover_to_another_db(self, r_multi_db, fault_injector_client): + r_multi_db, listener, config = r_multi_db + event = threading.Event() thread = threading.Thread( target=trigger_network_failure_action, daemon=True, - args=(fault_injector_client,event) + args=(fault_injector_client,config,event) ) - r_multi_db, listener = r_multi_db - # Client initialized on the first command. 
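        # (Initialization is lazy: this first SET runs the initial health
        # checks and selects the active database before the command executes.)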
r_multi_db.set('key', 'value') thread.start() @@ -61,32 +63,33 @@ def test_multi_db_client_failover_to_another_db(self, r_multi_db, fault_injector # Execute commands before network failure while not event.is_set(): assert r_multi_db.get('key') == 'value' - sleep(0.1) + sleep(0.5) - # Execute commands after network failure - for _ in range(3): + # Execute commands until database failover + while not listener.is_changed_flag: assert r_multi_db.get('key') == 'value' - sleep(0.1) - - assert listener.is_changed_flag == True + sleep(0.5) @pytest.mark.parametrize( "r_multi_db", [ - {"failure_threshold": 2} + {"client_class": Redis, "failure_threshold": 2}, + {"client_class": RedisCluster, "failure_threshold": 2}, ], + ids=["standalone", "cluster"], indirect=True ) + @pytest.mark.timeout(50) def test_context_manager_pipeline_failover_to_another_db(self, r_multi_db, fault_injector_client): + r_multi_db, listener, config = r_multi_db + event = threading.Event() thread = threading.Thread( target=trigger_network_failure_action, daemon=True, - args=(fault_injector_client,event) + args=(fault_injector_client,config,event) ) - r_multi_db, listener = r_multi_db - # Client initialized on first pipe execution. with r_multi_db.pipeline() as pipe: pipe.set('{hash}key1', 'value1') @@ -109,10 +112,10 @@ def test_context_manager_pipeline_failover_to_another_db(self, r_multi_db, fault pipe.get('{hash}key2') pipe.get('{hash}key3') assert pipe.execute() == [True, True, True, 'value1', 'value2', 'value3'] - sleep(0.1) + sleep(0.5) - # Execute pipeline after network failure - for _ in range(3): + # Execute pipeline until database failover + for _ in range(5): with r_multi_db.pipeline() as pipe: pipe.set('{hash}key1', 'value1') pipe.set('{hash}key2', 'value2') @@ -121,27 +124,28 @@ def test_context_manager_pipeline_failover_to_another_db(self, r_multi_db, fault pipe.get('{hash}key2') pipe.get('{hash}key3') assert pipe.execute() == [True, True, True, 'value1', 'value2', 'value3'] - sleep(0.1) - - assert listener.is_changed_flag == True + sleep(0.5) @pytest.mark.parametrize( "r_multi_db", [ - {"failure_threshold": 2} + {"client_class": Redis, "failure_threshold": 2}, + {"client_class": RedisCluster, "failure_threshold": 2}, ], + ids=["standalone", "cluster"], indirect=True ) + @pytest.mark.timeout(50) def test_chaining_pipeline_failover_to_another_db(self, r_multi_db, fault_injector_client): + r_multi_db, listener, config = r_multi_db + event = threading.Event() thread = threading.Thread( target=trigger_network_failure_action, daemon=True, - args=(fault_injector_client,event) + args=(fault_injector_client,config,event) ) - r_multi_db, listener = r_multi_db - # Client initialized on first pipe execution. 
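        # (The loops below deliberately build a fresh pipeline per iteration,
        # presumably so each batch is created against whichever database is
        # active after a failover.)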
pipe = r_multi_db.pipeline() pipe.set('{hash}key1', 'value1') @@ -156,6 +160,7 @@ def test_chaining_pipeline_failover_to_another_db(self, r_multi_db, fault_inject # Execute pipeline before network failure while not event.is_set(): + pipe = r_multi_db.pipeline() pipe.set('{hash}key1', 'value1') pipe.set('{hash}key2', 'value2') pipe.set('{hash}key3', 'value3') @@ -163,10 +168,11 @@ def test_chaining_pipeline_failover_to_another_db(self, r_multi_db, fault_inject pipe.get('{hash}key2') pipe.get('{hash}key3') assert pipe.execute() == [True, True, True, 'value1', 'value2', 'value3'] - sleep(0.1) + sleep(0.5) - # Execute pipeline after network failure - for _ in range(3): + # Execute pipeline until database failover + for _ in range(5): + pipe = r_multi_db.pipeline() pipe.set('{hash}key1', 'value1') pipe.set('{hash}key2', 'value2') pipe.set('{hash}key3', 'value3') @@ -174,27 +180,28 @@ def test_chaining_pipeline_failover_to_another_db(self, r_multi_db, fault_inject pipe.get('{hash}key2') pipe.get('{hash}key3') assert pipe.execute() == [True, True, True, 'value1', 'value2', 'value3'] - sleep(0.1) - - assert listener.is_changed_flag == True + sleep(0.5) @pytest.mark.parametrize( "r_multi_db", [ - {"failure_threshold": 2} + {"client_class": Redis, "failure_threshold": 2}, + {"client_class": RedisCluster, "failure_threshold": 2}, ], + ids=["standalone", "cluster"], indirect=True ) + @pytest.mark.timeout(50) def test_transaction_failover_to_another_db(self, r_multi_db, fault_injector_client): + r_multi_db, listener, config = r_multi_db + event = threading.Event() thread = threading.Thread( target=trigger_network_failure_action, daemon=True, - args=(fault_injector_client,event) + args=(fault_injector_client,config,event) ) - r_multi_db, listener = r_multi_db - def callback(pipe: Pipeline): pipe.set('{hash}key1', 'value1') pipe.set('{hash}key2', 'value2') @@ -207,34 +214,35 @@ def callback(pipe: Pipeline): r_multi_db.transaction(callback) thread.start() - # Execute pipeline before network failure + # Execute transaction before network failure while not event.is_set(): r_multi_db.transaction(callback) - sleep(0.1) + sleep(0.5) - # Execute pipeline after network failure - for _ in range(3): + # Execute transaction until database failover + while not listener.is_changed_flag: r_multi_db.transaction(callback) - sleep(0.1) - - assert listener.is_changed_flag == True + sleep(0.5) @pytest.mark.parametrize( "r_multi_db", [ - {"failure_threshold": 2} + {"client_class": Redis, "failure_threshold": 2}, + {"client_class": RedisCluster, "failure_threshold": 2}, ], + ids=["standalone", "cluster"], indirect=True ) + @pytest.mark.timeout(50) def test_pubsub_failover_to_another_db(self, r_multi_db, fault_injector_client): + r_multi_db, listener, config = r_multi_db + event = threading.Event() thread = threading.Thread( target=trigger_network_failure_action, daemon=True, - args=(fault_injector_client,event) + args=(fault_injector_client,config,event) ) - - r_multi_db, listener = r_multi_db data = json.dumps({'message': 'test'}) messages_count = 0 @@ -249,37 +257,38 @@ def handler(message): pubsub_thread = pubsub.run_in_thread(sleep_time=0.1, daemon=True) thread.start() - # Execute pipeline before network failure + # Execute publish before network failure while not event.is_set(): r_multi_db.publish('test-channel', data) - sleep(0.1) + sleep(0.5) - # Execute pipeline after network failure - for _ in range(3): + # Execute publish until database failover + while not listener.is_changed_flag: 
            r_multi_db.publish('test-channel', data)
-            sleep(0.1)
+            sleep(0.5)
 
         pubsub_thread.stop()
-
-        assert listener.is_changed_flag == True
         assert messages_count > 5
 
     @pytest.mark.parametrize(
         "r_multi_db",
         [
-            {"failure_threshold": 2}
+            {"client_class": Redis, "failure_threshold": 2},
+            {"client_class": RedisCluster, "failure_threshold": 2},
         ],
+        ids=["standalone", "cluster"],
         indirect=True
     )
+    @pytest.mark.timeout(50)
     def test_sharded_pubsub_failover_to_another_db(self, r_multi_db, fault_injector_client):
+        r_multi_db, listener, config = r_multi_db
+
         event = threading.Event()
         thread = threading.Thread(
             target=trigger_network_failure_action,
             daemon=True,
-            args=(fault_injector_client,event)
+            args=(fault_injector_client,config,event)
         )
-
-        r_multi_db, listener = r_multi_db
 
         data = json.dumps({'message': 'test'})
         messages_count = 0
@@ -291,20 +300,22 @@ def handler(message):
 
         # Assign a handler and run in a separate thread.
         pubsub.ssubscribe(**{'test-channel': handler})
-        pubsub_thread = pubsub.run_in_thread(sleep_time=0.1, daemon=True)
+        pubsub_thread = pubsub.run_in_thread(
+            sleep_time=0.1,
+            daemon=True,
+            sharded_pubsub=True
+        )
         thread.start()
 
-        # Execute pipeline before network failure
+        # Execute publish before network failure
         while not event.is_set():
             r_multi_db.spublish('test-channel', data)
-            sleep(0.1)
+            sleep(0.5)
 
-        # Execute pipeline after network failure
-        for _ in range(3):
+        # Execute publish until database failover
+        while not listener.is_changed_flag:
             r_multi_db.spublish('test-channel', data)
-            sleep(0.1)
+            sleep(0.5)
 
         pubsub_thread.stop()
-
-        assert listener.is_changed_flag == True
         assert messages_count > 5
\ No newline at end of file

From 8daa53169349b418329a6e1512002e1826dbc3af Mon Sep 17 00:00:00 2001
From: Vladyslav Vildanov <117659936+vladvildanov@users.noreply.github.com>
Date: Fri, 22 Aug 2025 13:57:21 +0300
Subject: [PATCH 07/50] Added LagAwareHealthCheck for MultiDBClient (#3737)

* Added LagAwareHealthCheck
* Added testing for LagAwareHealthCheck
* Fixed timeouts
* Added lag tolerance parameter
* Decreased messages_count due to increased timeouts
* Added docblocks
* Added missing type hints
* Fixed url
* Refactored tests, URL and cluster support
* Use primary node to send an API request
* Added comment about RE bug
* Moved None type to the beginning
* Added health_check_url property to Database class
---
 redis/cluster.py                          |   1 +
 redis/http/__init__.py                    |   0
 redis/http/http_client.py                 | 412 ++++++++++++++++++++++
 redis/multidb/client.py                   |   2 +-
 redis/multidb/config.py                   |  36 +-
 redis/multidb/database.py                 |  28 +-
 redis/multidb/failover.py                 |   6 +-
 redis/multidb/healthcheck.py              | 111 +++++-
 redis/retry.py                            |   5 +-
 redis/utils.py                            |   6 +
 tests/test_http/__init__.py               |   0
 tests/test_http/test_http_client.py       | 324 +++++++++++++++++
 tests/test_multidb/test_client.py         |   5 +-
 tests/test_multidb/test_healthcheck.py    | 141 +++++++-
 tests/test_multidb/test_pipeline.py       |   8 +-
 tests/test_scenario/conftest.py           |  31 +-
 tests/test_scenario/test_active_active.py |  48 ++-
 17 files changed, 1130 insertions(+), 34 deletions(-)
 create mode 100644 redis/http/__init__.py
 create mode 100644 redis/http/http_client.py
 create mode 100644 tests/test_http/__init__.py
 create mode 100644 tests/test_http/test_http_client.py

diff --git a/redis/cluster.py b/redis/cluster.py
index 2fd4625e6b..dc91209ed2 100644
--- a/redis/cluster.py
+++ b/redis/cluster.py
@@ -691,6 +691,7 @@ def __init__(
             self._event_dispatcher = EventDispatcher()
         else:
             self._event_dispatcher = event_dispatcher
+        self.startup_nodes = startup_nodes
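The one-line `redis/cluster.py` change above retains `startup_nodes` on the client instance so the new health check can resolve a cluster's host without reaching into `NodesManager`. A minimal sketch of that lookup, mirroring the `check_health` logic that appears later in this patch (`resolve_db_host` is a hypothetical helper name):

```python
from redis import Redis

def resolve_db_host(client) -> str:
    """Return the host to match against Redis Enterprise bdb endpoints."""
    if isinstance(client, Redis):
        # Standalone client: the host lives in the connection kwargs.
        return client.get_connection_kwargs()["host"]
    # RedisCluster: works because cluster.py now retains startup_nodes.
    return client.startup_nodes[0].host
```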
self.nodes_manager = NodesManager( startup_nodes=startup_nodes, from_url=from_url, diff --git a/redis/http/__init__.py b/redis/http/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/redis/http/http_client.py b/redis/http/http_client.py new file mode 100644 index 0000000000..0a2de2e44c --- /dev/null +++ b/redis/http/http_client.py @@ -0,0 +1,412 @@ +from __future__ import annotations + +import base64 +import json +import ssl +import gzip +import zlib +from dataclasses import dataclass +from typing import Any, Dict, Mapping, Optional, Tuple, Union +from urllib.parse import urlencode, urljoin +from urllib.request import Request, urlopen +from urllib.error import URLError, HTTPError + + +__all__ = [ + "HttpClient", + "HttpResponse", + "HttpError", + "DEFAULT_TIMEOUT" +] + +from redis.backoff import ExponentialWithJitterBackoff +from redis.retry import Retry +from redis.utils import dummy_fail + +DEFAULT_USER_AGENT = "HttpClient/1.0 (+https://example.invalid)" +DEFAULT_TIMEOUT = 30.0 +RETRY_STATUS_CODES = {429, 500, 502, 503, 504} + + +@dataclass +class HttpResponse: + status: int + headers: Dict[str, str] + url: str + content: bytes + + def text(self, encoding: Optional[str] = None) -> str: + enc = encoding or self._get_encoding() + return self.content.decode(enc, errors="replace") + + def json(self) -> Any: + return json.loads(self.text(encoding=self._get_encoding())) + + def _get_encoding(self) -> str: + # Try to infer encoding from headers; default to utf-8 + ctype = self.headers.get("content-type", "") + # Example: application/json; charset=utf-8 + for part in ctype.split(";"): + p = part.strip() + if p.lower().startswith("charset="): + return p.split("=", 1)[1].strip() or "utf-8" + return "utf-8" + + +class HttpError(Exception): + def __init__(self, status: int, url: str, message: Optional[str] = None): + self.status = status + self.url = url + self.message = message or f"HTTP {status} for {url}" + super().__init__(self.message) + + +class HttpClient: + """ + A lightweight HTTP client for REST API calls. + """ + def __init__( + self, + base_url: str = "", + *, + headers: Optional[Mapping[str, str]] = None, + timeout: float = DEFAULT_TIMEOUT, + retry: Retry = Retry( + backoff=ExponentialWithJitterBackoff(base=1, cap=10), retries=3 + ), + verify_tls: bool = True, + # TLS verification (server) options + ca_file: Optional[str] = None, + ca_path: Optional[str] = None, + ca_data: Optional[Union[str, bytes]] = None, + # Mutual TLS (client cert) options + client_cert_file: Optional[str] = None, + client_key_file: Optional[str] = None, + client_key_password: Optional[str] = None, + auth_basic: Optional[Tuple[str, str]] = None, # (username, password) + user_agent: str = DEFAULT_USER_AGENT, + ) -> None: + """ + Initialize a new HTTP client instance. + + Args: + base_url: Base URL for all requests. Will be prefixed to all paths. + headers: Default headers to include in all requests. + timeout: Default timeout in seconds for requests. + retry: Retry configuration for failed requests. + verify_tls: Whether to verify TLS certificates. + ca_file: Path to CA certificate file for TLS verification. + ca_path: Path to a directory containing CA certificates. + ca_data: CA certificate data as string or bytes. + client_cert_file: Path to client certificate for mutual TLS. + client_key_file: Path to a client private key for mutual TLS. + client_key_password: Password for an encrypted client private key. + auth_basic: Tuple of (username, password) for HTTP basic auth. 
+ user_agent: User-Agent header value for requests. + + The client supports both regular HTTPS with server verification and mutual TLS + authentication. For server verification, provide CA certificate information via + ca_file, ca_path or ca_data. For mutual TLS, additionally provide a client + certificate and key via client_cert_file and client_key_file. + """ + self.base_url = base_url.rstrip() + "/" if base_url and not base_url.endswith("/") else base_url + self._default_headers = {k.lower(): v for k, v in (headers or {}).items()} + self.timeout = timeout + self.retry = retry + self.retry.update_supported_errors((HTTPError, URLError, ssl.SSLError)) + self.verify_tls = verify_tls + + # TLS settings + self.ca_file = ca_file + self.ca_path = ca_path + self.ca_data = ca_data + self.client_cert_file = client_cert_file + self.client_key_file = client_key_file + self.client_key_password = client_key_password + + self.auth_basic = auth_basic + self.user_agent = user_agent + + # Public JSON-centric helpers + def get( + self, + path: str, + *, + params: Optional[Mapping[str, Union[None, str, int, float, bool, list, tuple]]] = None, + headers: Optional[Mapping[str, str]] = None, + timeout: Optional[float] = None, + expect_json: bool = True + ) -> Union[HttpResponse, Any]: + return self._json_call( + "GET", + path, + params=params, + headers=headers, + timeout=timeout, + body=None, + expect_json=expect_json + ) + + def delete( + self, + path: str, + *, + params: Optional[Mapping[str, Union[None, str, int, float, bool, list, tuple]]] = None, + headers: Optional[Mapping[str, str]] = None, + timeout: Optional[float] = None, + expect_json: bool = True + ) -> Union[HttpResponse, Any]: + return self._json_call( + "DELETE", + path, + params=params, + headers=headers, + timeout=timeout, + body=None, + expect_json=expect_json + ) + + def post( + self, + path: str, + *, + json_body: Optional[Any] = None, + data: Optional[Union[bytes, str]] = None, + params: Optional[Mapping[str, Union[None, str, int, float, bool, list, tuple]]] = None, + headers: Optional[Mapping[str, str]] = None, + timeout: Optional[float] = None, + expect_json: bool = True + ) -> Union[HttpResponse, Any]: + return self._json_call( + "POST", + path, + params=params, + headers=headers, + timeout=timeout, + body=self._prepare_body(json_body=json_body, data=data), + expect_json=expect_json + ) + + def put( + self, + path: str, + *, + json_body: Optional[Any] = None, + data: Optional[Union[bytes, str]] = None, + params: Optional[Mapping[str, Union[None, str, int, float, bool, list, tuple]]] = None, + headers: Optional[Mapping[str, str]] = None, + timeout: Optional[float] = None, + expect_json: bool = True + ) -> Union[HttpResponse, Any]: + return self._json_call( + "PUT", + path, + params=params, + headers=headers, + timeout=timeout, + body=self._prepare_body(json_body=json_body, data=data), + expect_json=expect_json + ) + + def patch( + self, + path: str, + *, + json_body: Optional[Any] = None, + data: Optional[Union[bytes, str]] = None, + params: Optional[Mapping[str, Union[None, str, int, float, bool, list, tuple]]] = None, + headers: Optional[Mapping[str, str]] = None, + timeout: Optional[float] = None, + expect_json: bool = True + ) -> Union[HttpResponse, Any]: + return self._json_call( + "PATCH", + path, + params=params, + headers=headers, + timeout=timeout, + body=self._prepare_body(json_body=json_body, data=data), + expect_json=expect_json + ) + + # Low-level request + def request( + self, + method: str, + path: str, + *, + params: 
Optional[Mapping[str, Union[None, str, int, float, bool, list, tuple]]] = None, + headers: Optional[Mapping[str, str]] = None, + body: Optional[Union[bytes, str]] = None, + timeout: Optional[float] = None, + ) -> HttpResponse: + url = self._build_url(path, params) + all_headers = self._prepare_headers(headers, body) + data = body.encode("utf-8") if isinstance(body, str) else body + + req = Request(url=url, method=method.upper(), data=data, headers=all_headers) + + context: Optional[ssl.SSLContext] = None + if url.lower().startswith("https"): + if self.verify_tls: + # Use provided CA material if any; fall back to system defaults + context = ssl.create_default_context( + cafile=self.ca_file, + capath=self.ca_path, + cadata=self.ca_data, + ) + # Load client certificate for mTLS if configured + if self.client_cert_file: + context.load_cert_chain( + certfile=self.client_cert_file, + keyfile=self.client_key_file, + password=self.client_key_password, + ) + else: + # Verification disabled + context = ssl.create_default_context() + context.check_hostname = False + context.verify_mode = ssl.CERT_NONE + + try: + return self.retry.call_with_retry( + lambda: self._make_request(req, context=context, timeout=timeout), + lambda _: dummy_fail(), + lambda error: self._is_retryable_http_error(error), + ) + except HTTPError as e: + # Read error body, build response, and decide on retry + err_body = b"" + try: + err_body = e.read() + except Exception: + pass + headers_map = {k.lower(): v for k, v in (e.headers or {}).items()} + err_body = self._maybe_decompress(err_body, headers_map) + status = getattr(e, "code", 0) or 0 + response = HttpResponse( + status=status, + headers=headers_map, + url=url, + content=err_body, + ) + return response + + def _make_request( + self, + request: Request, + context: Optional[ssl.SSLContext] = None, + timeout: Optional[float] = None, + ): + with urlopen(request, timeout=timeout or self.timeout, context=context) as resp: + raw = resp.read() + headers_map = {k.lower(): v for k, v in resp.headers.items()} + raw = self._maybe_decompress(raw, headers_map) + return HttpResponse( + status=resp.status, + headers=headers_map, + url=resp.geturl(), + content=raw, + ) + + def _is_retryable_http_error(self, error: Exception) -> bool: + if isinstance(error, HTTPError): + return self._should_retry_status(error.code) + return False + + # Internal utilities + def _json_call( + self, + method: str, + path: str, + *, + params: Optional[Mapping[str, Union[None, str, int, float, bool, list, tuple]]] = None, + headers: Optional[Mapping[str, str]] = None, + timeout: Optional[float] = None, + body: Optional[Union[bytes, str]] = None, + expect_json: bool = True, + ) -> Union[HttpResponse, Any]: + resp = self.request( + method=method, + path=path, + params=params, + headers=headers, + body=body, + timeout=timeout, + ) + if not (200 <= resp.status < 400): + raise HttpError(resp.status, resp.url, resp.text()) + if expect_json: + return resp.json() + return resp + + def _prepare_body(self, *, json_body: Optional[Any] = None, data: Optional[Union[bytes, str]] = None) -> Optional[Union[bytes, str]]: + if json_body is not None and data is not None: + raise ValueError("Provide either json_body or data, not both.") + if json_body is not None: + return json.dumps(json_body, ensure_ascii=False, separators=(",", ":")) + return data + + def _build_url( + self, + path: str, + params: Optional[Mapping[str, Union[None, str, int, float, bool, list, tuple]]] = None, + ) -> str: + url = urljoin(self.base_url or "", path) 
+ if params: + # urlencode with doseq=True supports list/tuple values + query = urlencode({k: v for k, v in params.items() if v is not None}, doseq=True) + separator = "&" if ("?" in url) else "?" + url = f"{url}{separator}{query}" if query else url + return url + + def _prepare_headers(self, headers: Optional[Mapping[str, str]], body: Optional[Union[bytes, str]]) -> Dict[str, str]: + # Start with defaults + prepared: Dict[str, str] = {} + prepared.update(self._default_headers) + + # Standard defaults for JSON REST usage + prepared.setdefault("accept", "application/json") + prepared.setdefault("user-agent", self.user_agent) + # We will send gzip accept-encoding; handle decompression manually + prepared.setdefault("accept-encoding", "gzip, deflate") + + # If we have a string body and content-type not specified, assume JSON + if body is not None and isinstance(body, str): + prepared.setdefault("content-type", "application/json; charset=utf-8") + + # Basic authentication if provided and not overridden + if self.auth_basic and "authorization" not in prepared: + user, pwd = self.auth_basic + token = base64.b64encode(f"{user}:{pwd}".encode("utf-8")).decode("ascii") + prepared["authorization"] = f"Basic {token}" + + # Merge per-call headers (case-insensitive) + if headers: + for k, v in headers.items(): + prepared[k.lower()] = v + + # urllib expects header keys in canonical capitalization sometimes; but it’s tolerant. + # We'll return as provided; urllib will handle it. + return prepared + + def _should_retry_status(self, status: int) -> bool: + return status in RETRY_STATUS_CODES + + def _maybe_decompress(self, content: bytes, headers: Mapping[str, str]) -> bytes: + if not content: + return content + encoding = (headers.get("content-encoding") or "").lower() + try: + if "gzip" in encoding: + return gzip.decompress(content) + if "deflate" in encoding: + # Try raw deflate, then zlib-wrapped + try: + return zlib.decompress(content, -zlib.MAX_WBITS) + except zlib.error: + return zlib.decompress(content) + except Exception: + # If decompression fails, return original bytes + return content + return content \ No newline at end of file diff --git a/redis/multidb/client.py b/redis/multidb/client.py index 2f87024f20..56342a7a53 100644 --- a/redis/multidb/client.py +++ b/redis/multidb/client.py @@ -230,7 +230,7 @@ def _check_db_health(self, database: AbstractDatabase, on_error: Callable[[Excep database.circuit.state = CBState.OPEN elif is_healthy and database.circuit.state != CBState.CLOSED: database.circuit.state = CBState.CLOSED - except (ConnectionError, TimeoutError, socket.timeout, ConnectionRefusedError, ValueError) as e: + except Exception as e: if database.circuit.state != CBState.OPEN: database.circuit.state = CBState.OPEN is_healthy = False diff --git a/redis/multidb/config.py b/redis/multidb/config.py index 4bacc2c680..5555baec44 100644 --- a/redis/multidb/config.py +++ b/redis/multidb/config.py @@ -12,14 +12,13 @@ from redis.multidb.circuit import CircuitBreaker, PBCircuitBreakerAdapter from redis.multidb.database import Database, Databases from redis.multidb.failure_detector import FailureDetector, CommandFailureDetector -from redis.multidb.healthcheck import HealthCheck, EchoHealthCheck +from redis.multidb.healthcheck import HealthCheck, EchoHealthCheck, DEFAULT_HEALTH_CHECK_RETRIES, \ + DEFAULT_HEALTH_CHECK_BACKOFF from redis.multidb.failover import FailoverStrategy, WeightBasedFailoverStrategy from redis.retry import Retry DEFAULT_GRACE_PERIOD = 5.0 DEFAULT_HEALTH_CHECK_INTERVAL = 5 
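Before the `multidb` config changes continue below, a usage sketch of the `HttpClient` defined above; the endpoint and credentials are hypothetical:

```python
from redis.http.http_client import HttpClient, HttpError

client = HttpClient(
    base_url="https://cluster.example.com:9443/",  # hypothetical endpoint
    auth_basic=("admin", "secret"),                # hypothetical credentials
    verify_tls=False,                              # e.g. a lab setup with self-signed certs
)
try:
    bdbs = client.get("/v1/bdbs")   # parsed JSON by default (expect_json=True)
except HttpError as err:            # raised for any non-2xx/3xx status
    print(f"request failed: {err.status} for {err.url}")
```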
-DEFAULT_HEALTH_CHECK_RETRIES = 3
-DEFAULT_HEALTH_CHECK_BACKOFF = ExponentialWithJitterBackoff(cap=10)
 DEFAULT_FAILURES_THRESHOLD = 3
 DEFAULT_FAILURES_DURATION = 2
 DEFAULT_FAILOVER_RETRIES = 3
@@ -31,12 +30,36 @@ def default_event_dispatcher() -> EventDispatcherInterface:
 
 @dataclass
 class DatabaseConfig:
+    """
+    Dataclass representing the configuration for a database connection.
+
+    This class is used to store configuration settings for a database connection,
+    including client options, connection sourcing details, circuit breaker settings,
+    and cluster-specific properties. It provides a structure for defining these
+    attributes and allows for the creation of customized configurations for various
+    database setups.
+
+    Attributes:
+        weight (float): Weight of the database, used to determine the active one.
+        client_kwargs (dict): Additional parameters for the database client connection.
+        from_url (Optional[str]): Redis URL for connecting to the database.
+        from_pool (Optional[ConnectionPool]): A pre-configured connection pool to use.
+        circuit (Optional[CircuitBreaker]): Custom circuit breaker implementation.
+        grace_period (float): Grace period after which the circuit is re-checked to see
+            if it can be closed again.
+        health_check_url (Optional[str]): URL for health checks. Cluster FQDN is typically used
+            on public Redis Enterprise endpoints.
+
+    Methods:
+        default_circuit_breaker:
+            Generates and returns a default CircuitBreaker instance adapted for use.
+    """
     weight: float = 1.0
     client_kwargs: dict = field(default_factory=dict)
     from_url: Optional[str] = None
     from_pool: Optional[ConnectionPool] = None
     circuit: Optional[CircuitBreaker] = None
     grace_period: float = DEFAULT_GRACE_PERIOD
+    health_check_url: Optional[str] = None
 
     def default_circuit_breaker(self) -> CircuitBreaker:
         circuit_breaker = pybreaker.CircuitBreaker(reset_timeout=self.grace_period)
@@ -118,7 +141,12 @@ def databases(self) -> Databases:
             circuit = database_config.default_circuit_breaker() \
                 if database_config.circuit is None else database_config.circuit
             databases.add(
-                Database(client=client, circuit=circuit, weight=database_config.weight),
+                Database(
+                    client=client,
+                    circuit=circuit,
+                    weight=database_config.weight,
+                    health_check_url=database_config.health_check_url
+                ),
                 database_config.weight
             )
 
diff --git a/redis/multidb/database.py b/redis/multidb/database.py
index 3253ffa093..b03e77bd70 100644
--- a/redis/multidb/database.py
+++ b/redis/multidb/database.py
@@ -1,7 +1,7 @@
 import redis
 from abc import ABC, abstractmethod
 from enum import Enum
-from typing import Union
+from typing import Union, Optional
 
 from redis import RedisCluster
 from redis.data_structure import WeightedList
@@ -45,6 +45,18 @@ def circuit(self, circuit: CircuitBreaker):
         """Set the circuit breaker for the current database."""
         pass
 
+    @property
+    @abstractmethod
+    def health_check_url(self) -> Optional[str]:
+        """Health check URL associated with the current database."""
+        pass
+
+    @health_check_url.setter
+    @abstractmethod
+    def health_check_url(self, health_check_url: Optional[str]):
+        """Set the health check URL associated with the current database."""
+        pass
+
 Databases = WeightedList[tuple[AbstractDatabase, Number]]
 
 class Database(AbstractDatabase):
@@ -52,7 +64,8 @@ def __init__(
         self,
         client: Union[redis.Redis, RedisCluster],
         circuit: CircuitBreaker,
-        weight: float
+        weight: float,
+        health_check_url: Optional[str] = None,
     ):
         """
         Initialize a new Database instance.
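A short sketch of how a database entry might be declared with the new `health_check_url` field; the host names are hypothetical:

```python
from redis.multidb.config import DatabaseConfig

cfg = DatabaseConfig(
    weight=1.0,
    client_kwargs={"host": "db1.example.com", "port": 6379},
    health_check_url="https://cluster1.example.com",  # cluster FQDN, hypothetical
)

# When no custom circuit is supplied, MultiDbConfig.databases() falls back to
# default_circuit_breaker(): a pybreaker breaker whose reset_timeout is grace_period.
circuit = cfg.default_circuit_breaker()
```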
@@ -61,12 +74,13 @@ def __init__( client: Underlying Redis client instance for database operations circuit: Circuit breaker for handling database failures weight: Weight value used for database failover prioritization - state: Initial database state, defaults to DISCONNECTED + health_check_url: Health check URL associated with the current database """ self._client = client self._cb = circuit self._cb.database = self self._weight = weight + self._health_check_url = health_check_url @property def client(self) -> Union[redis.Redis, RedisCluster]: @@ -91,3 +105,11 @@ def circuit(self) -> CircuitBreaker: @circuit.setter def circuit(self, circuit: CircuitBreaker): self._cb = circuit + + @property + def health_check_url(self) -> Optional[str]: + return self._health_check_url + + @health_check_url.setter + def health_check_url(self, health_check_url: Optional[str]): + self._health_check_url = health_check_url diff --git a/redis/multidb/failover.py b/redis/multidb/failover.py index 541f3413dc..d6cf198678 100644 --- a/redis/multidb/failover.py +++ b/redis/multidb/failover.py @@ -6,6 +6,7 @@ from redis.multidb.circuit import State as CBState from redis.multidb.exception import NoValidDatabaseException from redis.retry import Retry +from redis.utils import dummy_fail class FailoverStrategy(ABC): @@ -37,7 +38,7 @@ def __init__( def database(self) -> AbstractDatabase: return self._retry.call_with_retry( lambda: self._get_active_database(), - lambda _: self._dummy_fail() + lambda _: dummy_fail() ) def set_databases(self, databases: Databases) -> None: @@ -49,6 +50,3 @@ def _get_active_database(self) -> AbstractDatabase: return database raise NoValidDatabaseException('No valid database available for communication') - - def _dummy_fail(self): - pass diff --git a/redis/multidb/healthcheck.py b/redis/multidb/healthcheck.py index cca220dc3f..63ba415334 100644 --- a/redis/multidb/healthcheck.py +++ b/redis/multidb/healthcheck.py @@ -1,9 +1,17 @@ +import logging from abc import abstractmethod, ABC +from typing import Optional, Tuple, Union -import redis from redis import Redis +from redis.backoff import ExponentialWithJitterBackoff +from redis.http.http_client import DEFAULT_TIMEOUT, HttpClient from redis.retry import Retry +from redis.utils import dummy_fail +DEFAULT_HEALTH_CHECK_RETRIES = 3 +DEFAULT_HEALTH_CHECK_BACKOFF = ExponentialWithJitterBackoff(cap=10) + +logger = logging.getLogger(__name__) class HealthCheck(ABC): @@ -21,7 +29,7 @@ def check_health(self, database) -> bool: class AbstractHealthCheck(HealthCheck): def __init__( self, - retry: Retry, + retry: Retry = Retry(retries=DEFAULT_HEALTH_CHECK_RETRIES, backoff=DEFAULT_HEALTH_CHECK_BACKOFF) ) -> None: self._retry = retry self._retry.update_supported_errors([ConnectionRefusedError]) @@ -37,8 +45,8 @@ def check_health(self, database) -> bool: class EchoHealthCheck(AbstractHealthCheck): def __init__( - self, - retry: Retry, + self, + retry: Retry = Retry(retries=DEFAULT_HEALTH_CHECK_RETRIES, backoff=DEFAULT_HEALTH_CHECK_BACKOFF) ) -> None: """ Check database healthiness by sending an echo request. 
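With the retry defaults now living in `healthcheck.py`, a health check can be constructed standalone. A minimal sketch with an explicit retry policy (the values are illustrative):

```python
from redis.backoff import ExponentialWithJitterBackoff
from redis.multidb.healthcheck import EchoHealthCheck
from redis.retry import Retry

# The module defaults apply when no retry is passed; an explicit policy for illustration.
hc = EchoHealthCheck(
    retry=Retry(retries=5, backoff=ExponentialWithJitterBackoff(base=1, cap=10))
)
# hc.check_health(database) returns True once the ECHO round-trip succeeds,
# retrying transient failures according to the policy above.
```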
@@ -49,7 +57,7 @@ def __init__(
     def check_health(self, database) -> bool:
         return self._retry.call_with_retry(
             lambda: self._returns_echoed_message(database),
-            lambda _: self._dummy_fail()
+            lambda _: dummy_fail()
         )
 
     def _returns_echoed_message(self, database) -> bool:
@@ -69,5 +77,94 @@ def _returns_echoed_message(self, database) -> bool:
 
         return True
 
-    def _dummy_fail(self):
-        pass
\ No newline at end of file
+class LagAwareHealthCheck(AbstractHealthCheck):
+    """
+    Health check available for Redis Enterprise deployments.
+    Verifies via the REST API that the database is healthy, based on replication lag.
+    """
+    def __init__(
+        self,
+        retry: Retry = Retry(retries=DEFAULT_HEALTH_CHECK_RETRIES, backoff=DEFAULT_HEALTH_CHECK_BACKOFF),
+        rest_api_port: int = 9443,
+        timeout: float = DEFAULT_TIMEOUT,
+        auth_basic: Optional[Tuple[str, str]] = None,
+        verify_tls: bool = True,
+        # TLS verification (server) options
+        ca_file: Optional[str] = None,
+        ca_path: Optional[str] = None,
+        ca_data: Optional[Union[str, bytes]] = None,
+        # Mutual TLS (client cert) options
+        client_cert_file: Optional[str] = None,
+        client_key_file: Optional[str] = None,
+        client_key_password: Optional[str] = None,
+    ):
+        """
+        Initialize LagAwareHealthCheck with the specified parameters.
+
+        Args:
+            retry: Retry configuration for health checks
+            rest_api_port: Port number for Redis Enterprise REST API (default: 9443)
+            timeout: Request timeout in seconds (default: DEFAULT_TIMEOUT)
+            auth_basic: Tuple of (username, password) for basic authentication
+            verify_tls: Whether to verify TLS certificates (default: True)
+            ca_file: Path to CA certificate file for TLS verification
+            ca_path: Path to CA certificates directory for TLS verification
+            ca_data: CA certificate data as string or bytes
+            client_cert_file: Path to client certificate file for mutual TLS
+            client_key_file: Path to client private key file for mutual TLS
+            client_key_password: Password for encrypted client private key
+        """
+        super().__init__(
+            retry=retry,
+        )
+        self._http_client = HttpClient(
+            timeout=timeout,
+            auth_basic=auth_basic,
+            retry=self.retry,
+            verify_tls=verify_tls,
+            ca_file=ca_file,
+            ca_path=ca_path,
+            ca_data=ca_data,
+            client_cert_file=client_cert_file,
+            client_key_file=client_key_file,
+            client_key_password=client_key_password
+        )
+        self._rest_api_port = rest_api_port
+
+    def check_health(self, database) -> bool:
+        if database.health_check_url is None:
+            raise ValueError(
+                "Database health check URL is not set. Please check DatabaseConfig for the current database."
+            )
+
+        if isinstance(database.client, Redis):
+            db_host = database.client.get_connection_kwargs()["host"]
+        else:
+            db_host = database.client.startup_nodes[0].host
+
+        base_url = f"{database.health_check_url}:{self._rest_api_port}"
+        self._http_client.base_url = base_url
+
+        # Find the bdb matching the current database host
+        matching_bdb = None
+        for bdb in self._http_client.get("/v1/bdbs"):
+            for endpoint in bdb["endpoints"]:
+                if endpoint['dns_name'] == db_host:
+                    matching_bdb = bdb
+                    break
+
+                # In case the host was set as a public IP
+                for addr in endpoint['addr']:
+                    if addr == db_host:
+                        matching_bdb = bdb
+                        break
+
+        if matching_bdb is None:
+            logger.warning("LagAwareHealthCheck failed: Couldn't find a matching bdb")
+            raise ValueError("Could not find a matching bdb")
+
+        url = f"/v1/local/bdbs/{matching_bdb['uid']}/endpoint/availability"
+        self._http_client.get(url, expect_json=False)
+
+        # Status is validated inside the http client; a failed request raises HttpError
+        return True
diff --git a/redis/retry.py b/redis/retry.py
index c93f34e65f..7989b41742 100644
--- a/redis/retry.py
+++ b/redis/retry.py
@@ -1,6 +1,6 @@
 import socket
 from time import sleep
-from typing import TYPE_CHECKING, Any, Callable, Iterable, Tuple, Type, TypeVar
+from typing import TYPE_CHECKING, Any, Callable, Iterable, Tuple, Type, TypeVar, Optional
 
 from redis.exceptions import ConnectionError, TimeoutError
 
@@ -73,6 +73,7 @@ def call_with_retry(
         self,
         do: Callable[[], T],
         fail: Callable[[Exception], Any],
+        is_retryable: Optional[Callable[[Exception], bool]] = None
     ) -> T:
         """
         Execute an operation that might fail and returns its result, or
@@ -86,6 +87,8 @@ def call_with_retry(
             try:
                 return do()
             except self._supported_errors as error:
+                if is_retryable and not is_retryable(error):
+                    raise
                 failures += 1
                 fail(error)
                 if self._retries >= 0 and failures > self._retries:
diff --git a/redis/utils.py b/redis/utils.py
index 715913e914..94bfab61bb 100644
--- a/redis/utils.py
+++ b/redis/utils.py
@@ -308,3 +308,9 @@ def truncate_text(txt, max_length=100):
     return textwrap.shorten(
         text=txt, width=max_length, placeholder="...", break_long_words=True
     )
+
+def dummy_fail():
+    """
+    No-op failure callback for a Retry object, for callers that don't need to handle each failure.
+ """ + pass diff --git a/tests/test_http/__init__.py b/tests/test_http/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/test_http/test_http_client.py b/tests/test_http/test_http_client.py new file mode 100644 index 0000000000..9a6d28ecd4 --- /dev/null +++ b/tests/test_http/test_http_client.py @@ -0,0 +1,324 @@ +import json +import gzip +from io import BytesIO +from typing import Any, Dict +from urllib.error import HTTPError +from urllib.parse import urlparse, parse_qs + +import pytest + +from redis.backoff import ExponentialWithJitterBackoff +from redis.http.http_client import HttpClient, HttpError +from redis.retry import Retry + + +class FakeResponse: + def __init__(self, *, status: int, headers: Dict[str, str], url: str, content: bytes): + self.status = status + self.headers = headers + self._url = url + self._content = content + + def read(self) -> bytes: + return self._content + + def geturl(self) -> str: + return self._url + + # Support context manager used by urlopen + def __enter__(self) -> "FakeResponse": + return self + + def __exit__(self, exc_type, exc, tb) -> None: + return None + +class TestHttpClient: + def test_get_returns_parsed_json_and_uses_timeout(self, monkeypatch: pytest.MonkeyPatch) -> None: + # Arrange + base_url = "https://api.example.com/" + path = "v1/items" + params = {"limit": 5, "q": "hello world"} + expected_url = f"{base_url}{path}?limit=5&q=hello+world" + payload: Dict[str, Any] = {"items": [1, 2, 3], "ok": True} + content = json.dumps(payload).encode("utf-8") + + captured_kwargs = {} + + def fake_urlopen(request, *, timeout=None, context=None): + # Capture call details for assertions + captured_kwargs["timeout"] = timeout + captured_kwargs["context"] = context + # Assert the request was constructed correctly + assert getattr(request, "method", "").upper() == "GET" + assert request.full_url == expected_url + # Return a successful response + return FakeResponse( + status=200, + headers={"Content-Type": "application/json; charset=utf-8"}, + url=expected_url, + content=content, + ) + + # Patch the urlopen used inside HttpClient + monkeypatch.setattr("redis.http.http_client.urlopen", fake_urlopen) + + client = HttpClient(base_url=base_url) + + # Act + result = client.get(path, params=params, timeout=12.34) # default expect_json=True + + # Assert + assert result == payload + assert pytest.approx(captured_kwargs["timeout"], rel=1e-6) == 12.34 + # HTTPS -> a context should be provided (created by ssl.create_default_context) + assert captured_kwargs["context"] is not None + + def test_get_handles_gzip_response(self, monkeypatch: pytest.MonkeyPatch) -> None: + # Arrange + base_url = "https://api.example.com/" + path = "gzip-endpoint" + expected_url = f"{base_url}{path}" + payload = {"message": "compressed ok"} + raw = json.dumps(payload).encode("utf-8") + gzipped = gzip.compress(raw) + + def fake_urlopen(request, *, timeout=None, context=None): + # Return gzipped content with appropriate header + return FakeResponse( + status=200, + headers={ + "Content-Type": "application/json; charset=utf-8", + "Content-Encoding": "gzip", + }, + url=expected_url, + content=gzipped, + ) + + monkeypatch.setattr("redis.http.http_client.urlopen", fake_urlopen) + + client = HttpClient(base_url=base_url) + + # Act + result = client.get(path) # expect_json=True by default + + # Assert + assert result == payload + + def test_get_retries_on_retryable_http_errors_and_succeeds(self, monkeypatch: pytest.MonkeyPatch) -> None: + # Arrange: configure limited 
retries so we can assert attempts + retry_policy = Retry(backoff=ExponentialWithJitterBackoff(base=0, cap=0), + retries=2) # 2 retries -> up to 3 attempts + base_url = "https://api.example.com/" + path = "sometimes-busy" + expected_url = f"{base_url}{path}" + payload = {"ok": True} + success_content = json.dumps(payload).encode("utf-8") + + call_count = {"n": 0} + + def make_http_error(url: str, code: int, body: bytes = b"busy"): + # Provide a file-like object for .read() when HttpClient tries to read error content + fp = BytesIO(body) + return HTTPError(url=url, code=code, msg="Service Unavailable", hdrs={"Content-Type": "text/plain"}, fp=fp) + + def flaky_urlopen(request, *, timeout=None, context=None): + call_count["n"] += 1 + # Fail with a retryable status (503) for the first two calls, then succeed + if call_count["n"] <= 2: + raise make_http_error(expected_url, 503) + return FakeResponse( + status=200, + headers={"Content-Type": "application/json; charset=utf-8"}, + url=expected_url, + content=success_content, + ) + + monkeypatch.setattr("redis.http.http_client.urlopen", flaky_urlopen) + + client = HttpClient(base_url=base_url, retry=retry_policy) + + # Act + result = client.get(path) + + # Assert: should have retried twice (total 3 attempts) and finally returned parsed JSON + assert result == payload + assert call_count["n"] == retry_policy.get_retries() + 1 + + def test_post_sends_json_body_and_parses_response(self, monkeypatch: pytest.MonkeyPatch) -> None: + # Arrange + base_url = "https://api.example.com/" + path = "v1/create" + expected_url = f"{base_url}{path}" + send_payload = {"a": 1, "b": "x"} + recv_payload = {"id": 10, "ok": True} + recv_content = json.dumps(recv_payload, separators=(",", ":")).encode("utf-8") + + def fake_urlopen(request, *, timeout=None, context=None): + # Verify method, URL and headers + assert getattr(request, "method", "").upper() == "POST" + assert request.full_url == expected_url + # Content-Type should be auto-set for string JSON body + assert request.headers.get("Content-type") == "application/json; charset=utf-8" + # Body should be already UTF-8 encoded JSON with no spaces + assert request.data == json.dumps(send_payload, ensure_ascii=False, separators=(",", ":")).encode("utf-8") + return FakeResponse( + status=200, + headers={"Content-Type": "application/json; charset=utf-8"}, + url=expected_url, + content=recv_content, + ) + + monkeypatch.setattr("redis.http.http_client.urlopen", fake_urlopen) + + client = HttpClient(base_url=base_url) + + # Act + result = client.post(path, json_body=send_payload) + + # Assert + assert result == recv_payload + + def test_post_with_raw_data_and_custom_headers(self, monkeypatch: pytest.MonkeyPatch) -> None: + # Arrange + base_url = "https://api.example.com/" + path = "upload" + expected_url = f"{base_url}{path}" + raw_data = b"\x00\x01BINARY" + custom_headers = {"Content-type": "application/octet-stream", "X-extra": "1"} + recv_payload = {"status": "ok"} + + def fake_urlopen(request, *, timeout=None, context=None): + assert getattr(request, "method", "").upper() == "POST" + assert request.full_url == expected_url + # Ensure our provided headers are present + assert request.headers.get("Content-type") == "application/octet-stream" + assert request.headers.get("X-extra") == "1" + assert request.data == raw_data + return FakeResponse( + status=200, + headers={"Content-Type": "application/json"}, + url=expected_url, + content=json.dumps(recv_payload).encode("utf-8"), + ) + + 
monkeypatch.setattr("redis.http.http_client.urlopen", fake_urlopen) + + client = HttpClient(base_url=base_url) + # Act + result = client.post(path, data=raw_data, headers=custom_headers) + + # Assert + assert result == recv_payload + + def test_delete_returns_http_response_when_expect_json_false(self, monkeypatch: pytest.MonkeyPatch) -> None: + # Arrange + base_url = "https://api.example.com/" + path = "v1/resource/42" + expected_url = f"{base_url}{path}" + body = b"deleted" + + def fake_urlopen(request, *, timeout=None, context=None): + assert getattr(request, "method", "").upper() == "DELETE" + assert request.full_url == expected_url + return FakeResponse( + status=204, + headers={"Content-Type": "text/plain"}, + url=expected_url, + content=body, + ) + + monkeypatch.setattr("redis.http.http_client.urlopen", fake_urlopen) + client = HttpClient(base_url=base_url) + + # Act + resp = client.delete(path, expect_json=False) + + # Assert + assert resp.status == 204 + assert resp.url == expected_url + assert resp.content == body + + def test_put_raises_http_error_on_non_success(self, monkeypatch: pytest.MonkeyPatch) -> None: + # Arrange + base_url = "https://api.example.com/" + path = "v1/update/1" + expected_url = f"{base_url}{path}" + + def make_http_error(url: str, code: int, body: bytes = b"not found"): + fp = BytesIO(body) + return HTTPError(url=url, code=code, msg="Not Found", hdrs={"Content-Type": "text/plain"}, fp=fp) + + def fake_urlopen(request, *, timeout=None, context=None): + raise make_http_error(expected_url, 404) + + monkeypatch.setattr("redis.http.http_client.urlopen", fake_urlopen) + client = HttpClient(base_url=base_url) + + # Act / Assert + with pytest.raises(HttpError) as exc: + client.put(path, json_body={"x": 1}) + assert exc.value.status == 404 + assert exc.value.url == expected_url + + def test_patch_with_params_encodes_query(self, monkeypatch: pytest.MonkeyPatch) -> None: + # Arrange + base_url = "https://api.example.com/" + path = "v1/edit" + params = {"tag": ["a", "b"], "q": "hello world"} + + captured_url = {"u": None} + + def fake_urlopen(request, *, timeout=None, context=None): + captured_url["u"] = request.full_url + return FakeResponse( + status=200, + headers={"Content-Type": "application/json"}, + url=request.full_url, + content=json.dumps({"ok": True}).encode("utf-8"), + ) + + monkeypatch.setattr("redis.http.http_client.urlopen", fake_urlopen) + + client = HttpClient(base_url=base_url) + client.patch(path, params=params) # We don't care about response here + + # Assert query parameters regardless of ordering + parsed = urlparse(captured_url["u"]) + qs = parse_qs(parsed.query) + assert qs["q"] == ["hello world"] + assert qs["tag"] == ["a", "b"] + + def test_request_low_level_headers_auth_and_timeout_default(self, monkeypatch: pytest.MonkeyPatch) -> None: + # Arrange: use plain HTTP to verify no TLS context, and check default timeout used + base_url = "http://example.com/" + path = "ping" + captured = {"timeout": None, "context": "unset", "headers": None, "method": None} + + def fake_urlopen(request, *, timeout=None, context=None): + captured["timeout"] = timeout + captured["context"] = context + captured["headers"] = dict(request.headers) + captured["method"] = getattr(request, "method", "").upper() + return FakeResponse( + status=200, + headers={"Content-Type": "application/json"}, + url=request.full_url, + content=json.dumps({"pong": True}).encode("utf-8"), + ) + + monkeypatch.setattr("redis.http.http_client.urlopen", fake_urlopen) + + client = 
HttpClient(base_url=base_url, auth_basic=("user", "pass")) + resp = client.request("GET", path) + + # Assert + assert resp.status == 200 + assert captured["method"] == "GET" + assert captured["context"] is None # no TLS for http + assert pytest.approx(captured["timeout"], rel=1e-6) == client.timeout # default used + # Check some default headers and Authorization presence + headers = {k.lower(): v for k, v in captured["headers"].items()} + assert "authorization" in headers and headers["authorization"].startswith("Basic ") + assert headers.get("accept") == "application/json" + assert "gzip" in headers.get("accept-encoding", "").lower() + assert "user-agent" in headers \ No newline at end of file diff --git a/tests/test_multidb/test_client.py b/tests/test_multidb/test_client.py index 37ee9b3fd3..193980d37c 100644 --- a/tests/test_multidb/test_client.py +++ b/tests/test_multidb/test_client.py @@ -6,14 +6,15 @@ from redis.event import EventDispatcher, OnCommandsFailEvent from redis.multidb.circuit import State as CBState, PBCircuitBreakerAdapter -from redis.multidb.config import DEFAULT_HEALTH_CHECK_RETRIES, DEFAULT_HEALTH_CHECK_BACKOFF, DEFAULT_FAILOVER_RETRIES, \ +from redis.multidb.config import DEFAULT_FAILOVER_RETRIES, \ DEFAULT_FAILOVER_BACKOFF from redis.multidb.database import AbstractDatabase from redis.multidb.client import MultiDBClient from redis.multidb.exception import NoValidDatabaseException from redis.multidb.failover import WeightBasedFailoverStrategy from redis.multidb.failure_detector import FailureDetector -from redis.multidb.healthcheck import HealthCheck, EchoHealthCheck +from redis.multidb.healthcheck import HealthCheck, EchoHealthCheck, DEFAULT_HEALTH_CHECK_RETRIES, \ + DEFAULT_HEALTH_CHECK_BACKOFF from redis.retry import Retry from tests.test_multidb.conftest import create_weighted_list diff --git a/tests/test_multidb/test_healthcheck.py b/tests/test_multidb/test_healthcheck.py index 08bd8ab0c4..bc71fdb57d 100644 --- a/tests/test_multidb/test_healthcheck.py +++ b/tests/test_multidb/test_healthcheck.py @@ -1,6 +1,12 @@ +from unittest.mock import MagicMock + +import pytest + from redis.backoff import ExponentialBackoff from redis.multidb.database import Database from redis.multidb.healthcheck import EchoHealthCheck +from redis.http.http_client import HttpError +from redis.multidb.healthcheck import EchoHealthCheck, LagAwareHealthCheck from redis.multidb.circuit import State as CBState from redis.exceptions import ConnectionError from redis.retry import Retry @@ -38,4 +44,137 @@ def test_database_close_circuit_on_successful_healthcheck(self, mock_client, moc db = Database(mock_client, mock_cb, 0.9) assert hc.check_health(db) == True - assert mock_client.execute_command.call_count == 3 \ No newline at end of file + assert mock_client.execute_command.call_count == 3 + + +class TestLagAwareHealthCheck: + def test_database_is_healthy_when_bdb_matches_by_dns_name(self, mock_client, mock_cb): + """ + Ensures health check succeeds when /v1/bdbs contains an endpoint whose dns_name + matches database host, and availability endpoint returns success. 
+ """ + host = "db1.example.com" + mock_client.get_connection_kwargs.return_value = {"host": host} + + # Mock HttpClient used inside LagAwareHealthCheck + mock_http = MagicMock() + mock_http.get.side_effect = [ + # First call: list of bdbs + [ + { + "uid": "bdb-1", + "endpoints": [ + {"dns_name": host, "addr": ["10.0.0.1", "10.0.0.2"]}, + ], + } + ], + # Second call: availability check (no JSON expected) + None, + ] + + hc = LagAwareHealthCheck( + retry=Retry(backoff=ExponentialBackoff(cap=1.0), retries=3), + rest_api_port=1234, + ) + # Inject our mocked http client + hc._http_client = mock_http + + db = Database(mock_client, mock_cb, 1.0, "https://healthcheck.example.com") + + assert hc.check_health(db) is True + # Base URL must be set correctly + assert hc._http_client.base_url == f"https://healthcheck.example.com:1234" + # Calls: first to list bdbs, then to availability + assert mock_http.get.call_count == 2 + first_call = mock_http.get.call_args_list[0] + second_call = mock_http.get.call_args_list[1] + assert first_call.args[0] == "/v1/bdbs" + assert second_call.args[0] == "/v1/local/bdbs/bdb-1/endpoint/availability" + assert second_call.kwargs.get("expect_json") is False + + def test_database_is_healthy_when_bdb_matches_by_addr(self, mock_client, mock_cb): + """ + Ensures health check succeeds when endpoint addr list contains the database host. + """ + host_ip = "203.0.113.5" + mock_client.get_connection_kwargs.return_value = {"host": host_ip} + + mock_http = MagicMock() + mock_http.get.side_effect = [ + [ + { + "uid": "bdb-42", + "endpoints": [ + {"dns_name": "not-matching.example.com", "addr": [host_ip]}, + ], + } + ], + None, + ] + + hc = LagAwareHealthCheck( + retry=Retry(backoff=ExponentialBackoff(cap=1.0), retries=3), + ) + hc._http_client = mock_http + + db = Database(mock_client, mock_cb, 1.0, "https://healthcheck.example.com") + + assert hc.check_health(db) is True + assert mock_http.get.call_count == 2 + assert mock_http.get.call_args_list[1].args[0] == "/v1/local/bdbs/bdb-42/endpoint/availability" + + def test_raises_value_error_when_no_matching_bdb(self, mock_client, mock_cb): + """ + Ensures health check raises ValueError when there's no bdb matching the database host. + """ + host = "db2.example.com" + mock_client.get_connection_kwargs.return_value = {"host": host} + + mock_http = MagicMock() + # Return bdbs that do not match host by dns_name nor addr + mock_http.get.return_value = [ + {"uid": "a", "endpoints": [{"dns_name": "other.example.com", "addr": ["10.0.0.9"]}]}, + {"uid": "b", "endpoints": [{"dns_name": "another.example.com", "addr": ["10.0.0.10"]}]}, + ] + + hc = LagAwareHealthCheck( + retry=Retry(backoff=ExponentialBackoff(cap=1.0), retries=3), + ) + hc._http_client = mock_http + + db = Database(mock_client, mock_cb, 1.0, "https://healthcheck.example.com") + + with pytest.raises(ValueError, match="Could not find a matching bdb"): + hc.check_health(db) + + # Only the listing call should have happened + mock_http.get.assert_called_once_with("/v1/bdbs") + + def test_propagates_http_error_from_availability(self, mock_client, mock_cb): + """ + Ensures that any HTTP error raised by the availability endpoint is propagated. 
+ """ + host = "db3.example.com" + mock_client.get_connection_kwargs.return_value = {"host": host} + + mock_http = MagicMock() + # First: list bdbs -> match by dns_name + mock_http.get.side_effect = [ + [{"uid": "bdb-err", "endpoints": [{"dns_name": host, "addr": []}]}], + # Second: availability -> raise HttpError + HttpError(url=f"https://{host}:9443/v1/bdbs/bdb-err/availability", status=503, message="busy"), + ] + + hc = LagAwareHealthCheck( + retry=Retry(backoff=ExponentialBackoff(cap=1.0), retries=3), + ) + hc._http_client = mock_http + + db = Database(mock_client, mock_cb, 1.0, "https://healthcheck.example.com") + + with pytest.raises(HttpError, match="busy") as e: + hc.check_health(db) + assert e.status == 503 + + # Ensure both calls were attempted + assert mock_http.get.call_count == 2 \ No newline at end of file diff --git a/tests/test_multidb/test_pipeline.py b/tests/test_multidb/test_pipeline.py index f0d2a0dbe3..6e7c344d85 100644 --- a/tests/test_multidb/test_pipeline.py +++ b/tests/test_multidb/test_pipeline.py @@ -4,15 +4,13 @@ import pybreaker import pytest -from redis.event import EventDispatcher -from redis.exceptions import ConnectionError from redis.client import Pipeline from redis.multidb.circuit import State as CBState, PBCircuitBreakerAdapter from redis.multidb.client import MultiDBClient -from redis.multidb.config import DEFAULT_HEALTH_CHECK_RETRIES, DEFAULT_HEALTH_CHECK_BACKOFF, DEFAULT_FAILOVER_RETRIES, \ - DEFAULT_FAILOVER_BACKOFF, DEFAULT_FAILURES_THRESHOLD +from redis.multidb.config import DEFAULT_FAILOVER_RETRIES, \ + DEFAULT_FAILOVER_BACKOFF from redis.multidb.failover import WeightBasedFailoverStrategy -from redis.multidb.healthcheck import EchoHealthCheck +from redis.multidb.healthcheck import EchoHealthCheck, DEFAULT_HEALTH_CHECK_RETRIES, DEFAULT_HEALTH_CHECK_BACKOFF from redis.retry import Retry from tests.test_multidb.conftest import create_weighted_list diff --git a/tests/test_scenario/conftest.py b/tests/test_scenario/conftest.py index 4182962fb1..a0f19e1a87 100644 --- a/tests/test_scenario/conftest.py +++ b/tests/test_scenario/conftest.py @@ -1,5 +1,7 @@ import json import os +import re +from urllib.parse import urlparse import pytest @@ -73,7 +75,8 @@ def r_multi_db(request) -> tuple[MultiDBClient, CheckActiveDatabaseChangedListen 'username': username, 'password': password, 'decode_responses': True, - } + }, + health_check_url=extract_cluster_fqdn(endpoint_config['endpoints'][0]) ) db_configs.append(db_config) @@ -84,7 +87,8 @@ def r_multi_db(request) -> tuple[MultiDBClient, CheckActiveDatabaseChangedListen 'username': username, 'password': password, 'decode_responses': True, - } + }, + health_check_url=extract_cluster_fqdn(endpoint_config['endpoints'][1]) ) db_configs.append(db_config1) @@ -93,10 +97,29 @@ def r_multi_db(request) -> tuple[MultiDBClient, CheckActiveDatabaseChangedListen databases_config=db_configs, command_retry=command_retry, failure_threshold=failure_threshold, + health_check_retries=3, health_check_interval=health_check_interval, event_dispatcher=event_dispatcher, health_check_backoff=ExponentialBackoff(cap=5, base=0.5), - health_check_retries=3, ) - return MultiDBClient(config), listener, endpoint_config \ No newline at end of file + return MultiDBClient(config), listener, endpoint_config + + +def extract_cluster_fqdn(url): + """ + Extract Cluster FQDN from Redis URL + """ + # Parse the URL + parsed = urlparse(url) + + # Extract hostname and port + hostname = parsed.hostname + port = parsed.port + + # Remove the 'redis-XXXX.' 
prefix using regex + # This pattern matches 'redis-' followed by digits and a dot + cleaned_hostname = re.sub(r'^redis-\d+\.', '', hostname) + + # Reconstruct the URL + return f"https://{cleaned_hostname}" \ No newline at end of file diff --git a/tests/test_scenario/test_active_active.py b/tests/test_scenario/test_active_active.py index 967fa43cdb..44c57e6b99 100644 --- a/tests/test_scenario/test_active_active.py +++ b/tests/test_scenario/test_active_active.py @@ -1,5 +1,6 @@ import json import logging +import os import threading from time import sleep @@ -7,6 +8,7 @@ from redis import Redis, RedisCluster from redis.client import Pipeline +from redis.multidb.healthcheck import LagAwareHealthCheck from tests.test_scenario.fault_injector_client import ActionRequest, ActionType logger = logging.getLogger(__name__) @@ -70,6 +72,48 @@ def test_multi_db_client_failover_to_another_db(self, r_multi_db, fault_injector assert r_multi_db.get('key') == 'value' sleep(0.5) + @pytest.mark.parametrize( + "r_multi_db", + [ + {"client_class": Redis, "failure_threshold": 2}, + {"client_class": RedisCluster, "failure_threshold": 2}, + ], + ids=["standalone", "cluster"], + indirect=True + ) + @pytest.mark.timeout(50) + def test_multi_db_client_uses_lag_aware_health_check(self, r_multi_db, fault_injector_client): + r_multi_db, listener, config = r_multi_db + + event = threading.Event() + thread = threading.Thread( + target=trigger_network_failure_action, + daemon=True, + args=(fault_injector_client,config,event) + ) + + env0_username = os.getenv('ENV0_USERNAME') + env0_password = os.getenv('ENV0_PASSWORD') + + # Adding additional health check to the client. + r_multi_db.add_health_check( + LagAwareHealthCheck(verify_tls=False, auth_basic=(env0_username,env0_password)) + ) + + # Client initialized on the first command. 
+        r_multi_db.set('key', 'value')
+        thread.start()
+
+        # Execute commands before network failure
+        while not event.is_set():
+            assert r_multi_db.get('key') == 'value'
+            sleep(0.5)
+
+        # Execute commands until database failover
+        while not listener.is_changed_flag:
+            assert r_multi_db.get('key') == 'value'
+            sleep(0.5)
+
     @pytest.mark.parametrize(
         "r_multi_db",
         [
@@ -268,7 +312,7 @@ def handler(message):
             sleep(0.5)
 
         pubsub_thread.stop()
-        assert messages_count > 5
+        assert messages_count > 2
 
     @pytest.mark.parametrize(
         "r_multi_db",
@@ -318,4 +362,4 @@ def handler(message):
             sleep(0.5)
 
         pubsub_thread.stop()
-        assert messages_count > 5
\ No newline at end of file
+        assert messages_count > 2
\ No newline at end of file

From f9fdc994796c194a935103b0e718d3cefaa5f2ad Mon Sep 17 00:00:00 2001
From: Vladyslav Vildanov <117659936+vladvildanov@users.noreply.github.com>
Date: Tue, 26 Aug 2025 12:34:45 +0300
Subject: [PATCH 08/50] Added lag_aware_tolerance parameter to LagAwareHealthCheck (#3752)

---
 redis/multidb/healthcheck.py           | 6 +++++-
 tests/test_multidb/test_healthcheck.py | 7 +++----
 2 files changed, 8 insertions(+), 5 deletions(-)

diff --git a/redis/multidb/healthcheck.py b/redis/multidb/healthcheck.py
index 63ba415334..9818d06e28 100644
--- a/redis/multidb/healthcheck.py
+++ b/redis/multidb/healthcheck.py
@@ -86,6 +86,7 @@ def __init__(
         self,
         retry: Retry = Retry(retries=DEFAULT_HEALTH_CHECK_RETRIES, backoff=DEFAULT_HEALTH_CHECK_BACKOFF),
         rest_api_port: int = 9443,
+        lag_aware_tolerance: int = 100,
         timeout: float = DEFAULT_TIMEOUT,
         auth_basic: Optional[Tuple[str, str]] = None,
         verify_tls: bool = True,
@@ -104,6 +105,7 @@ def __init__(
         Args:
             retry: Retry configuration for health checks
             rest_api_port: Port number for Redis Enterprise REST API (default: 9443)
+            lag_aware_tolerance: Tolerated lag between databases, in milliseconds (default: 100)
             timeout: Request timeout in seconds (default: DEFAULT_TIMEOUT)
             auth_basic: Tuple of (username, password) for basic authentication
             verify_tls: Whether to verify TLS certificates (default: True)
@@ -130,6 +132,7 @@ def __init__(
             client_key_password=client_key_password
         )
         self._rest_api_port = rest_api_port
+        self._lag_aware_tolerance = lag_aware_tolerance
 
     def check_health(self, database) -> bool:
         if database.health_check_url is None:
@@ -163,7 +166,8 @@ def check_health(self, database) -> bool:
             logger.warning("LagAwareHealthCheck failed: Couldn't find a matching bdb")
             raise ValueError("Could not find a matching bdb")
 
-        url = f"/v1/local/bdbs/{matching_bdb['uid']}/endpoint/availability"
+        url = (f"/v1/local/bdbs/{matching_bdb['uid']}/endpoint/availability"
+               f"?extend_check=lag&availability_lag_tolerance_ms={self._lag_aware_tolerance}")
         self._http_client.get(url, expect_json=False)
 
         # Status is validated inside the http client; a failed request raises HttpError
diff --git a/tests/test_multidb/test_healthcheck.py b/tests/test_multidb/test_healthcheck.py
index bc71fdb57d..18bfe5f23b 100644
--- a/tests/test_multidb/test_healthcheck.py
+++ b/tests/test_multidb/test_healthcheck.py
@@ -4,7 +4,6 @@
 
 from redis.backoff import ExponentialBackoff
 from redis.multidb.database import Database
-from redis.multidb.healthcheck import EchoHealthCheck
 from redis.http.http_client import HttpError
 from redis.multidb.healthcheck import EchoHealthCheck, LagAwareHealthCheck
 from redis.multidb.circuit import State as CBState
 from redis.exceptions import ConnectionError
 from redis.retry import Retry
@@ -74,7 +73,7 @@ def test_database_is_healthy_when_bdb_matches_by_dns_name(self, mock_client, moc
         hc = LagAwareHealthCheck(
             retry=Retry(backoff=ExponentialBackoff(cap=1.0), retries=3),
-
rest_api_port=1234, + rest_api_port=1234, lag_aware_tolerance=150 ) # Inject our mocked http client hc._http_client = mock_http @@ -89,7 +88,7 @@ def test_database_is_healthy_when_bdb_matches_by_dns_name(self, mock_client, moc first_call = mock_http.get.call_args_list[0] second_call = mock_http.get.call_args_list[1] assert first_call.args[0] == "/v1/bdbs" - assert second_call.args[0] == "/v1/local/bdbs/bdb-1/endpoint/availability" + assert second_call.args[0] == "/v1/local/bdbs/bdb-1/endpoint/availability?extend_check=lag&availability_lag_tolerance_ms=150" assert second_call.kwargs.get("expect_json") is False def test_database_is_healthy_when_bdb_matches_by_addr(self, mock_client, mock_cb): @@ -121,7 +120,7 @@ def test_database_is_healthy_when_bdb_matches_by_addr(self, mock_client, mock_cb assert hc.check_health(db) is True assert mock_http.get.call_count == 2 - assert mock_http.get.call_args_list[1].args[0] == "/v1/local/bdbs/bdb-42/endpoint/availability" + assert mock_http.get.call_args_list[1].args[0] == "/v1/local/bdbs/bdb-42/endpoint/availability?extend_check=lag&availability_lag_tolerance_ms=100" def test_raises_value_error_when_no_matching_bdb(self, mock_client, mock_cb): """ From 866003bc0abcc07db79d39c69b55153645df9448 Mon Sep 17 00:00:00 2001 From: Vladyslav Vildanov <117659936+vladvildanov@users.noreply.github.com> Date: Thu, 28 Aug 2025 13:56:10 +0300 Subject: [PATCH 09/50] Extract additional interfaces and abstract classes (#3754) --- redis/multidb/circuit.py | 82 ++++++----- redis/multidb/client.py | 25 ++-- redis/multidb/command_executor.py | 152 ++++++++++---------- redis/multidb/config.py | 8 +- redis/multidb/database.py | 100 +++++++------ redis/multidb/event.py | 13 +- redis/multidb/failover.py | 11 +- redis/multidb/failure_detector.py | 1 - tests/test_multidb/conftest.py | 12 +- tests/test_multidb/test_circuit.py | 4 +- tests/test_multidb/test_client.py | 4 +- tests/test_multidb/test_config.py | 10 +- tests/test_multidb/test_failure_detector.py | 12 +- 13 files changed, 225 insertions(+), 209 deletions(-) diff --git a/redis/multidb/circuit.py b/redis/multidb/circuit.py index 79c8a5f379..221dc556a3 100644 --- a/redis/multidb/circuit.py +++ b/redis/multidb/circuit.py @@ -45,8 +45,49 @@ def database(self, database): """Set database associated with this circuit.""" pass +class BaseCircuitBreaker(CircuitBreaker): + """ + Base implementation of Circuit Breaker interface. + """ + def __init__(self, cb: pybreaker.CircuitBreaker): + self._cb = cb + self._state_pb_mapper = { + State.CLOSED: self._cb.close, + State.OPEN: self._cb.open, + State.HALF_OPEN: self._cb.half_open, + } + self._database = None + + @property + def grace_period(self) -> float: + return self._cb.reset_timeout + + @grace_period.setter + def grace_period(self, grace_period: float): + self._cb.reset_timeout = grace_period + + @property + def state(self) -> State: + return State(value=self._cb.state.name) + + @state.setter + def state(self, state: State): + self._state_pb_mapper[state]() + + @property + def database(self): + return self._database + + @database.setter + def database(self, database): + self._database = database + +class SyncCircuitBreaker(CircuitBreaker): + """ + Synchronous implementation of Circuit Breaker interface. 
+ """ @abstractmethod - def on_state_changed(self, cb: Callable[["CircuitBreaker", State, State], None]): + def on_state_changed(self, cb: Callable[["SyncCircuitBreaker", State, State], None]): """Callback called when the state of the circuit changes.""" pass @@ -54,7 +95,7 @@ class PBListener(pybreaker.CircuitBreakerListener): """Wrapper for callback to be compatible with pybreaker implementation.""" def __init__( self, - cb: Callable[[CircuitBreaker, State, State], None], + cb: Callable[[SyncCircuitBreaker, State, State], None], database, ): """ @@ -75,8 +116,7 @@ def state_change(self, cb, old_state, new_state): new_state = State(value=new_state.name) self._cb(cb, old_state, new_state) - -class PBCircuitBreakerAdapter(CircuitBreaker): +class PBCircuitBreakerAdapter(SyncCircuitBreaker, BaseCircuitBreaker): def __init__(self, cb: pybreaker.CircuitBreaker): """ Initialize a PBCircuitBreakerAdapter instance. @@ -87,38 +127,8 @@ def __init__(self, cb: pybreaker.CircuitBreaker): Args: cb: A pybreaker CircuitBreaker instance to be adapted. """ - self._cb = cb - self._state_pb_mapper = { - State.CLOSED: self._cb.close, - State.OPEN: self._cb.open, - State.HALF_OPEN: self._cb.half_open, - } - self._database = None - - @property - def grace_period(self) -> float: - return self._cb.reset_timeout - - @grace_period.setter - def grace_period(self, grace_period: float): - self._cb.reset_timeout = grace_period - - @property - def state(self) -> State: - return State(value=self._cb.state.name) - - @state.setter - def state(self, state: State): - self._state_pb_mapper[state]() - - @property - def database(self): - return self._database - - @database.setter - def database(self, database): - self._database = database + super().__init__(cb) - def on_state_changed(self, cb: Callable[["CircuitBreaker", State, State], None]): + def on_state_changed(self, cb: Callable[["SyncCircuitBreaker", State, State], None]): listener = PBListener(cb, self.database) self._cb.add_listener(listener) \ No newline at end of file diff --git a/redis/multidb/client.py b/redis/multidb/client.py index 56342a7a53..8a0e006977 100644 --- a/redis/multidb/client.py +++ b/redis/multidb/client.py @@ -1,15 +1,12 @@ import threading -import socket from typing import List, Any, Callable, Optional from redis.background import BackgroundScheduler -from redis.client import PubSubWorkerThread -from redis.exceptions import ConnectionError, TimeoutError from redis.commands import RedisModuleCommands, CoreCommands from redis.multidb.command_executor import DefaultCommandExecutor from redis.multidb.config import MultiDbConfig, DEFAULT_GRACE_PERIOD -from redis.multidb.circuit import State as CBState, CircuitBreaker -from redis.multidb.database import Database, AbstractDatabase, Databases +from redis.multidb.circuit import State as CBState, SyncCircuitBreaker +from redis.multidb.database import Database, Databases, SyncDatabase from redis.multidb.exception import NoValidDatabaseException from redis.multidb.failure_detector import FailureDetector from redis.multidb.healthcheck import HealthCheck @@ -92,7 +89,7 @@ def get_databases(self) -> Databases: """ return self._databases - def set_active_database(self, database: AbstractDatabase) -> None: + def set_active_database(self, database: SyncDatabase) -> None: """ Promote one of the existing databases to become an active. 
""" @@ -115,7 +112,7 @@ def set_active_database(self, database: AbstractDatabase) -> None: raise NoValidDatabaseException('Cannot set active database, database is unhealthy') - def add_database(self, database: AbstractDatabase): + def add_database(self, database: SyncDatabase): """ Adds a new database to the database list. """ @@ -129,7 +126,7 @@ def add_database(self, database: AbstractDatabase): self._databases.add(database, database.weight) self._change_active_database(database, highest_weighted_db) - def _change_active_database(self, new_database: AbstractDatabase, highest_weight_database: AbstractDatabase): + def _change_active_database(self, new_database: SyncDatabase, highest_weight_database: SyncDatabase): if new_database.weight > highest_weight_database.weight and new_database.circuit.state == CBState.CLOSED: self.command_executor.active_database = new_database @@ -143,7 +140,7 @@ def remove_database(self, database: Database): if highest_weight <= weight and highest_weighted_db.circuit.state == CBState.CLOSED: self.command_executor.active_database = highest_weighted_db - def update_database_weight(self, database: AbstractDatabase, weight: float): + def update_database_weight(self, database: SyncDatabase, weight: float): """ Updates a database from the database list. """ @@ -210,7 +207,7 @@ def pubsub(self, **kwargs): return PubSub(self, **kwargs) - def _check_db_health(self, database: AbstractDatabase, on_error: Callable[[Exception], None] = None) -> None: + def _check_db_health(self, database: SyncDatabase, on_error: Callable[[Exception], None] = None) -> None: """ Runs health checks on the given database until first failure. """ @@ -247,7 +244,7 @@ def _check_databases_health(self, on_error: Callable[[Exception], None] = None): for database, _ in self._databases: self._check_db_health(database, on_error) - def _on_circuit_state_change_callback(self, circuit: CircuitBreaker, old_state: CBState, new_state: CBState): + def _on_circuit_state_change_callback(self, circuit: SyncCircuitBreaker, old_state: CBState, new_state: CBState): if new_state == CBState.HALF_OPEN: self._check_db_health(circuit.database) return @@ -255,7 +252,7 @@ def _on_circuit_state_change_callback(self, circuit: CircuitBreaker, old_state: if old_state == CBState.CLOSED and new_state == CBState.OPEN: self._bg_scheduler.run_once(DEFAULT_GRACE_PERIOD, _half_open_circuit, circuit) -def _half_open_circuit(circuit: CircuitBreaker): +def _half_open_circuit(circuit: SyncCircuitBreaker): circuit.state = CBState.HALF_OPEN @@ -450,8 +447,8 @@ def run_in_thread( exception_handler: Optional[Callable] = None, sharded_pubsub: bool = False, ) -> "PubSubWorkerThread": - return self._client.command_executor.execute_pubsub_run_in_thread( - sleep_time=sleep_time, + return self._client.command_executor.execute_pubsub_run( + sleep_time, daemon=daemon, exception_handler=exception_handler, pubsub=self, diff --git a/redis/multidb/command_executor.py b/redis/multidb/command_executor.py index 094230a31d..364c0a07ea 100644 --- a/redis/multidb/command_executor.py +++ b/redis/multidb/command_executor.py @@ -1,11 +1,11 @@ from abc import ABC, abstractmethod from datetime import datetime, timedelta -from typing import List, Optional, Callable +from typing import List, Optional, Callable, Any from redis.client import Pipeline, PubSub, PubSubWorkerThread from redis.event import EventDispatcherInterface, OnCommandsFailEvent from redis.multidb.config import DEFAULT_AUTO_FALLBACK_INTERVAL -from redis.multidb.database import Database, 
AbstractDatabase, Databases +from redis.multidb.database import Database, Databases, SyncDatabase from redis.multidb.circuit import State as CBState from redis.multidb.event import RegisterCommandFailure, ActiveDatabaseChanged, ResubscribeOnActiveDatabaseChanged from redis.multidb.failover import FailoverStrategy @@ -17,15 +17,40 @@ class CommandExecutor(ABC): @property @abstractmethod - def failure_detectors(self) -> List[FailureDetector]: - """Returns a list of failure detectors.""" + def auto_fallback_interval(self) -> float: + """Returns auto-fallback interval.""" pass + @auto_fallback_interval.setter @abstractmethod - def add_failure_detector(self, failure_detector: FailureDetector) -> None: - """Adds new failure detector to the list of failure detectors.""" + def auto_fallback_interval(self, auto_fallback_interval: float) -> None: + """Sets auto-fallback interval.""" pass +class BaseCommandExecutor(CommandExecutor): + def __init__( + self, + auto_fallback_interval: float = DEFAULT_AUTO_FALLBACK_INTERVAL, + ): + self._auto_fallback_interval = auto_fallback_interval + self._next_fallback_attempt: datetime + + @property + def auto_fallback_interval(self) -> float: + return self._auto_fallback_interval + + @auto_fallback_interval.setter + def auto_fallback_interval(self, auto_fallback_interval: int) -> None: + self._auto_fallback_interval = auto_fallback_interval + + def _schedule_next_fallback(self) -> None: + if self._auto_fallback_interval == DEFAULT_AUTO_FALLBACK_INTERVAL: + return + + self._next_fallback_attempt = datetime.now() + timedelta(seconds=self._auto_fallback_interval) + +class SyncCommandExecutor(CommandExecutor): + @property @abstractmethod def databases(self) -> Databases: @@ -34,19 +59,25 @@ def databases(self) -> Databases: @property @abstractmethod - def active_database(self) -> Optional[Database]: - """Returns currently active database.""" + def failure_detectors(self) -> List[FailureDetector]: + """Returns a list of failure detectors.""" pass - @active_database.setter @abstractmethod - def active_database(self, database: AbstractDatabase) -> None: - """Sets currently active database.""" + def add_failure_detector(self, failure_detector: FailureDetector) -> None: + """Adds a new failure detector to the list of failure detectors.""" pass + @property @abstractmethod - def pubsub(self, **kwargs): - """Initializes a PubSub object on a currently active database""" + def active_database(self) -> Optional[Database]: + """Returns currently active database.""" + pass + + @active_database.setter + @abstractmethod + def active_database(self, database: SyncDatabase) -> None: + """Sets the currently active database.""" pass @property @@ -69,30 +100,41 @@ def failover_strategy(self) -> FailoverStrategy: @property @abstractmethod - def auto_fallback_interval(self) -> float: - """Returns auto-fallback interval.""" + def command_retry(self) -> Retry: + """Returns command retry object.""" pass - @auto_fallback_interval.setter @abstractmethod - def auto_fallback_interval(self, auto_fallback_interval: float) -> None: - """Sets auto-fallback interval.""" + def pubsub(self, **kwargs): + """Initializes a PubSub object on a currently active database""" pass - @property @abstractmethod - def command_retry(self) -> Retry: - """Returns command retry object.""" + def execute_command(self, *args, **options): + """Executes a command and returns the result.""" pass @abstractmethod - def execute_command(self, *args, **options): - """Executes a command and returns the result.""" + def 
execute_pipeline(self, command_stack: tuple): + """Executes a stack of commands in pipeline.""" pass + @abstractmethod + def execute_transaction(self, transaction: Callable[[Pipeline], None], *watches, **options): + """Executes a transaction block wrapped in callback.""" + pass -class DefaultCommandExecutor(CommandExecutor): + @abstractmethod + def execute_pubsub_method(self, method_name: str, *args, **kwargs): + """Executes a given method on active pub/sub.""" + pass + @abstractmethod + def execute_pubsub_run(self, sleep_time: float, **kwargs) -> Any: + """Executes pub/sub run in a thread.""" + pass + +class DefaultCommandExecutor(SyncCommandExecutor, BaseCommandExecutor): def __init__( self, failure_detectors: List[FailureDetector], @@ -113,22 +155,26 @@ def __init__( event_dispatcher: Interface for dispatching events auto_fallback_interval: Time interval in seconds between attempts to fall back to a primary database """ + super().__init__(auto_fallback_interval) + for fd in failure_detectors: fd.set_command_executor(command_executor=self) - self._failure_detectors = failure_detectors self._databases = databases + self._failure_detectors = failure_detectors self._command_retry = command_retry self._failover_strategy = failover_strategy self._event_dispatcher = event_dispatcher - self._auto_fallback_interval = auto_fallback_interval - self._next_fallback_attempt: datetime self._active_database: Optional[Database] = None self._active_pubsub: Optional[PubSub] = None self._active_pubsub_kwargs = {} self._setup_event_dispatcher() self._schedule_next_fallback() + @property + def databases(self) -> Databases: + return self._databases + @property def failure_detectors(self) -> List[FailureDetector]: return self._failure_detectors @@ -136,20 +182,16 @@ def failure_detectors(self) -> List[FailureDetector]: def add_failure_detector(self, failure_detector: FailureDetector) -> None: self._failure_detectors.append(failure_detector) - @property - def databases(self) -> Databases: - return self._databases - @property def command_retry(self) -> Retry: return self._command_retry @property - def active_database(self) -> Optional[AbstractDatabase]: + def active_database(self) -> Optional[SyncDatabase]: return self._active_database @active_database.setter - def active_database(self, database: AbstractDatabase) -> None: + def active_database(self, database: SyncDatabase) -> None: old_active = self._active_database self._active_database = database @@ -170,25 +212,13 @@ def active_pubsub(self, pubsub: PubSub) -> None: def failover_strategy(self) -> FailoverStrategy: return self._failover_strategy - @property - def auto_fallback_interval(self) -> float: - return self._auto_fallback_interval - - @auto_fallback_interval.setter - def auto_fallback_interval(self, auto_fallback_interval: int) -> None: - self._auto_fallback_interval = auto_fallback_interval - def execute_command(self, *args, **options): - """Executes a command and returns the result.""" def callback(): return self._active_database.client.execute_command(*args, **options) return self._execute_with_failure_detection(callback, args) def execute_pipeline(self, command_stack: tuple): - """ - Executes a stack of commands in pipeline. 
- """ def callback(): with self._active_database.client.pipeline() as pipe: for command, options in command_stack: @@ -199,18 +229,12 @@ def callback(): return self._execute_with_failure_detection(callback, command_stack) def execute_transaction(self, transaction: Callable[[Pipeline], None], *watches, **options): - """ - Executes a transaction block wrapped in callback. - """ def callback(): return self._active_database.client.transaction(transaction, *watches, **options) return self._execute_with_failure_detection(callback) def pubsub(self, **kwargs): - """ - Initializes a PubSub object on a currently active database. - """ def callback(): if self._active_pubsub is None: self._active_pubsub = self._active_database.client.pubsub(**kwargs) @@ -220,31 +244,15 @@ def callback(): return self._execute_with_failure_detection(callback) def execute_pubsub_method(self, method_name: str, *args, **kwargs): - """ - Executes given method on active pub/sub. - """ def callback(): method = getattr(self.active_pubsub, method_name) return method(*args, **kwargs) return self._execute_with_failure_detection(callback, *args) - def execute_pubsub_run_in_thread( - self, - pubsub, - sleep_time: float = 0.0, - daemon: bool = False, - exception_handler: Optional[Callable] = None, - sharded_pubsub: bool = False, - ) -> "PubSubWorkerThread": + def execute_pubsub_run(self, sleep_time, **kwargs) -> "PubSubWorkerThread": def callback(): - return self._active_pubsub.run_in_thread( - sleep_time, - daemon=daemon, - exception_handler=exception_handler, - pubsub=pubsub, - sharded_pubsub=sharded_pubsub - ) + return self._active_pubsub.run_in_thread(sleep_time, **kwargs) return self._execute_with_failure_detection(callback) @@ -280,12 +288,6 @@ def _check_active_database(self): self.active_database = self._failover_strategy.database self._schedule_next_fallback() - def _schedule_next_fallback(self) -> None: - if self._auto_fallback_interval == DEFAULT_AUTO_FALLBACK_INTERVAL: - return - - self._next_fallback_attempt = datetime.now() + timedelta(seconds=self._auto_fallback_interval) - def _setup_event_dispatcher(self): """ Registers necessary listeners. diff --git a/redis/multidb/config.py b/redis/multidb/config.py index 5555baec44..a966ec329a 100644 --- a/redis/multidb/config.py +++ b/redis/multidb/config.py @@ -9,7 +9,7 @@ from redis.backoff import ExponentialWithJitterBackoff, AbstractBackoff, NoBackoff from redis.data_structure import WeightedList from redis.event import EventDispatcher, EventDispatcherInterface -from redis.multidb.circuit import CircuitBreaker, PBCircuitBreakerAdapter +from redis.multidb.circuit import PBCircuitBreakerAdapter, SyncCircuitBreaker from redis.multidb.database import Database, Databases from redis.multidb.failure_detector import FailureDetector, CommandFailureDetector from redis.multidb.healthcheck import HealthCheck, EchoHealthCheck, DEFAULT_HEALTH_CHECK_RETRIES, \ @@ -44,7 +44,7 @@ class DatabaseConfig: client_kwargs (dict): Additional parameters for the database client connection. from_url (Optional[str]): Redis URL way of connecting to the database. from_pool (Optional[ConnectionPool]): A pre-configured connection pool to use. - circuit (Optional[CircuitBreaker]): Custom circuit breaker implementation. + circuit (Optional[SyncCircuitBreaker]): Custom circuit breaker implementation. grace_period (float): Grace period after which we need to check if the circuit could be closed again. health_check_url (Optional[str]): URL for health checks. 
Cluster FQDN is typically used on public Redis Enterprise endpoints. @@ -57,11 +57,11 @@ class DatabaseConfig: client_kwargs: dict = field(default_factory=dict) from_url: Optional[str] = None from_pool: Optional[ConnectionPool] = None - circuit: Optional[CircuitBreaker] = None + circuit: Optional[SyncCircuitBreaker] = None grace_period: float = DEFAULT_GRACE_PERIOD health_check_url: Optional[str] = None - def default_circuit_breaker(self) -> CircuitBreaker: + def default_circuit_breaker(self) -> SyncCircuitBreaker: circuit_breaker = pybreaker.CircuitBreaker(reset_timeout=self.grace_period) return PBCircuitBreakerAdapter(circuit_breaker) diff --git a/redis/multidb/database.py b/redis/multidb/database.py index b03e77bd70..75a662d904 100644 --- a/redis/multidb/database.py +++ b/redis/multidb/database.py @@ -5,65 +5,92 @@ from redis import RedisCluster from redis.data_structure import WeightedList -from redis.multidb.circuit import CircuitBreaker +from redis.multidb.circuit import SyncCircuitBreaker from redis.typing import Number class AbstractDatabase(ABC): @property @abstractmethod - def client(self) -> Union[redis.Redis, RedisCluster]: - """The underlying redis client.""" + def weight(self) -> float: + """The weight of this database in compare to others. Used to determine the database failover to.""" pass - @client.setter + @weight.setter @abstractmethod - def client(self, client: Union[redis.Redis, RedisCluster]): - """Set the underlying redis client.""" + def weight(self, weight: float): + """Set the weight of this database in compare to others.""" pass @property @abstractmethod - def weight(self) -> float: - """The weight of this database in compare to others. Used to determine the database failover to.""" + def health_check_url(self) -> Optional[str]: + """Health check URL associated with the current database.""" pass - @weight.setter + @health_check_url.setter @abstractmethod - def weight(self, weight: float): - """Set the weight of this database in compare to others.""" + def health_check_url(self, health_check_url: Optional[str]): + """Set the health check URL associated with the current database.""" pass +class BaseDatabase(AbstractDatabase): + def __init__( + self, + weight: float, + health_check_url: Optional[str] = None, + ): + self._weight = weight + self._health_check_url = health_check_url + + @property + def weight(self) -> float: + return self._weight + + @weight.setter + def weight(self, weight: float): + self._weight = weight + + @property + def health_check_url(self) -> Optional[str]: + return self._health_check_url + + @health_check_url.setter + def health_check_url(self, health_check_url: Optional[str]): + self._health_check_url = health_check_url + +class SyncDatabase(AbstractDatabase): + """Database with an underlying synchronous redis client.""" @property @abstractmethod - def circuit(self) -> CircuitBreaker: - """Circuit breaker for the current database.""" + def client(self) -> Union[redis.Redis, RedisCluster]: + """The underlying redis client.""" pass - @circuit.setter + @client.setter @abstractmethod - def circuit(self, circuit: CircuitBreaker): - """Set the circuit breaker for the current database.""" + def client(self, client: Union[redis.Redis, RedisCluster]): + """Set the underlying redis client.""" pass @property @abstractmethod - def health_check_url(self) -> Optional[str]: - """Health check URL associated with the current database.""" + def circuit(self) -> SyncCircuitBreaker: + """Circuit breaker for the current database.""" pass - 
@health_check_url.setter + @circuit.setter @abstractmethod - def health_check_url(self, health_check_url: Optional[str]): - """Set the health check URL associated with the current database.""" + def circuit(self, circuit: SyncCircuitBreaker): + """Set the circuit breaker for the current database.""" pass -Databases = WeightedList[tuple[AbstractDatabase, Number]] +Databases = WeightedList[tuple[SyncDatabase, Number]] -class Database(AbstractDatabase): +class Database(BaseDatabase, SyncDatabase): def __init__( self, client: Union[redis.Redis, RedisCluster], - circuit: CircuitBreaker, + circuit: SyncCircuitBreaker, weight: float, health_check_url: Optional[str] = None, ): @@ -79,8 +106,7 @@ def __init__( self._client = client self._cb = circuit self._cb.database = self - self._weight = weight - self._health_check_url = health_check_url + super().__init__(weight, health_check_url) @property def client(self) -> Union[redis.Redis, RedisCluster]: @@ -91,25 +117,9 @@ def client(self, client: Union[redis.Redis, RedisCluster]): self._client = client @property - def weight(self) -> float: - return self._weight - - @weight.setter - def weight(self, weight: float): - self._weight = weight - - @property - def circuit(self) -> CircuitBreaker: + def circuit(self) -> SyncCircuitBreaker: return self._cb @circuit.setter - def circuit(self, circuit: CircuitBreaker): - self._cb = circuit - - @property - def health_check_url(self) -> Optional[str]: - return self._health_check_url - - @health_check_url.setter - def health_check_url(self, health_check_url: Optional[str]): - self._health_check_url = health_check_url + def circuit(self, circuit: SyncCircuitBreaker): + self._cb = circuit \ No newline at end of file diff --git a/redis/multidb/event.py b/redis/multidb/event.py index 2598bc4d06..bca9482347 100644 --- a/redis/multidb/event.py +++ b/redis/multidb/event.py @@ -1,8 +1,7 @@ from typing import List from redis.event import EventListenerInterface, OnCommandsFailEvent -from redis.multidb.config import Databases -from redis.multidb.database import AbstractDatabase +from redis.multidb.database import SyncDatabase from redis.multidb.failure_detector import FailureDetector class ActiveDatabaseChanged: @@ -11,8 +10,8 @@ class ActiveDatabaseChanged: """ def __init__( self, - old_database: AbstractDatabase, - new_database: AbstractDatabase, + old_database: SyncDatabase, + new_database: SyncDatabase, command_executor, **kwargs ): @@ -22,11 +21,11 @@ def __init__( self._kwargs = kwargs @property - def old_database(self) -> AbstractDatabase: + def old_database(self) -> SyncDatabase: return self._old_database @property - def new_database(self) -> AbstractDatabase: + def new_database(self) -> SyncDatabase: return self._new_database @property @@ -39,7 +38,7 @@ def kwargs(self): class ResubscribeOnActiveDatabaseChanged(EventListenerInterface): """ - Re-subscribe currently active pub/sub to a new active database. + Re-subscribe the currently active pub / sub to a new active database. 
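To make the new split concrete, a sketch of assembling the concrete sync Database defined above; the hostname, weight, and FQDN are illustrative:

    import pybreaker
    import redis
    from redis.multidb.circuit import PBCircuitBreakerAdapter
    from redis.multidb.database import Database

    db = Database(
        client=redis.Redis(host="localhost", port=6379),
        circuit=PBCircuitBreakerAdapter(pybreaker.CircuitBreaker(reset_timeout=5.0)),
        weight=1.0,
        health_check_url="https://cluster.example.com",  # placeholder FQDN
    )
    assert db.circuit.database is db  # the constructor links the circuit back to its database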
""" def listen(self, event: ActiveDatabaseChanged): old_pubsub = event.command_executor.active_pubsub diff --git a/redis/multidb/failover.py b/redis/multidb/failover.py index d6cf198678..fd08b77ecd 100644 --- a/redis/multidb/failover.py +++ b/redis/multidb/failover.py @@ -1,8 +1,7 @@ from abc import ABC, abstractmethod from redis.data_structure import WeightedList -from redis.multidb.database import Databases -from redis.multidb.database import AbstractDatabase +from redis.multidb.database import Databases, SyncDatabase from redis.multidb.circuit import State as CBState from redis.multidb.exception import NoValidDatabaseException from redis.retry import Retry @@ -13,13 +12,13 @@ class FailoverStrategy(ABC): @property @abstractmethod - def database(self) -> AbstractDatabase: + def database(self) -> SyncDatabase: """Select the database according to the strategy.""" pass @abstractmethod def set_databases(self, databases: Databases) -> None: - """Set the databases strategy operates on.""" + """Set the database strategy operates on.""" pass class WeightBasedFailoverStrategy(FailoverStrategy): @@ -35,7 +34,7 @@ def __init__( self._databases = WeightedList() @property - def database(self) -> AbstractDatabase: + def database(self) -> SyncDatabase: return self._retry.call_with_retry( lambda: self._get_active_database(), lambda _: dummy_fail() @@ -44,7 +43,7 @@ def database(self) -> AbstractDatabase: def set_databases(self, databases: Databases) -> None: self._databases = databases - def _get_active_database(self) -> AbstractDatabase: + def _get_active_database(self) -> SyncDatabase: for database, _ in self._databases: if database.circuit.state == CBState.CLOSED: return database diff --git a/redis/multidb/failure_detector.py b/redis/multidb/failure_detector.py index 3280fa6c32..ef4bd35f69 100644 --- a/redis/multidb/failure_detector.py +++ b/redis/multidb/failure_detector.py @@ -24,7 +24,6 @@ class CommandFailureDetector(FailureDetector): """ Detects a failure based on a threshold of failed commands during a specific period of time. 
""" - def __init__( self, threshold: int, diff --git a/tests/test_multidb/conftest.py b/tests/test_multidb/conftest.py index a34ef01476..9503d79d9b 100644 --- a/tests/test_multidb/conftest.py +++ b/tests/test_multidb/conftest.py @@ -4,7 +4,7 @@ from redis import Redis from redis.data_structure import WeightedList -from redis.multidb.circuit import CircuitBreaker, State as CBState +from redis.multidb.circuit import State as CBState, SyncCircuitBreaker from redis.multidb.config import MultiDbConfig, DatabaseConfig, DEFAULT_HEALTH_CHECK_INTERVAL, \ DEFAULT_AUTO_FALLBACK_INTERVAL from redis.multidb.database import Database, Databases @@ -19,8 +19,8 @@ def mock_client() -> Redis: return Mock(spec=Redis) @pytest.fixture() -def mock_cb() -> CircuitBreaker: - return Mock(spec=CircuitBreaker) +def mock_cb() -> SyncCircuitBreaker: + return Mock(spec=SyncCircuitBreaker) @pytest.fixture() def mock_fd() -> FailureDetector: @@ -41,7 +41,7 @@ def mock_db(request) -> Database: db.client = Mock(spec=Redis) cb = request.param.get("circuit", {}) - mock_cb = Mock(spec=CircuitBreaker) + mock_cb = Mock(spec=SyncCircuitBreaker) mock_cb.grace_period = cb.get("grace_period", 1.0) mock_cb.state = cb.get("state", CBState.CLOSED) @@ -55,7 +55,7 @@ def mock_db1(request) -> Database: db.client = Mock(spec=Redis) cb = request.param.get("circuit", {}) - mock_cb = Mock(spec=CircuitBreaker) + mock_cb = Mock(spec=SyncCircuitBreaker) mock_cb.grace_period = cb.get("grace_period", 1.0) mock_cb.state = cb.get("state", CBState.CLOSED) @@ -69,7 +69,7 @@ def mock_db2(request) -> Database: db.client = Mock(spec=Redis) cb = request.param.get("circuit", {}) - mock_cb = Mock(spec=CircuitBreaker) + mock_cb = Mock(spec=SyncCircuitBreaker) mock_cb.grace_period = cb.get("grace_period", 1.0) mock_cb.state = cb.get("state", CBState.CLOSED) diff --git a/tests/test_multidb/test_circuit.py b/tests/test_multidb/test_circuit.py index 7dc642373b..f5f39c3f6b 100644 --- a/tests/test_multidb/test_circuit.py +++ b/tests/test_multidb/test_circuit.py @@ -1,7 +1,7 @@ import pybreaker import pytest -from redis.multidb.circuit import PBCircuitBreakerAdapter, State as CbState, CircuitBreaker +from redis.multidb.circuit import PBCircuitBreakerAdapter, State as CbState, CircuitBreaker, SyncCircuitBreaker class TestPBCircuitBreaker: @@ -39,7 +39,7 @@ def test_cb_executes_callback_on_state_changed(self): adapter = PBCircuitBreakerAdapter(cb=pb_circuit) called_count = 0 - def callback(cb: CircuitBreaker, old_state: CbState, new_state: CbState): + def callback(cb: SyncCircuitBreaker, old_state: CbState, new_state: CbState): nonlocal called_count assert old_state == CbState.CLOSED assert new_state == CbState.HALF_OPEN diff --git a/tests/test_multidb/test_client.py b/tests/test_multidb/test_client.py index 193980d37c..c7c15fe684 100644 --- a/tests/test_multidb/test_client.py +++ b/tests/test_multidb/test_client.py @@ -8,7 +8,7 @@ from redis.multidb.circuit import State as CBState, PBCircuitBreakerAdapter from redis.multidb.config import DEFAULT_FAILOVER_RETRIES, \ DEFAULT_FAILOVER_BACKOFF -from redis.multidb.database import AbstractDatabase +from redis.multidb.database import SyncDatabase from redis.multidb.client import MultiDBClient from redis.multidb.exception import NoValidDatabaseException from redis.multidb.failover import WeightBasedFailoverStrategy @@ -458,7 +458,7 @@ def test_set_active_database( assert client.set('key', 'value') == 'OK' with pytest.raises(ValueError, match='Given database is not a member of database list'): - 
client.set_active_database(Mock(spec=AbstractDatabase)) + client.set_active_database(Mock(spec=SyncDatabase)) mock_hc.check_health.return_value = False diff --git a/tests/test_multidb/test_config.py b/tests/test_multidb/test_config.py index 87aae701a9..e428b3ce7a 100644 --- a/tests/test_multidb/test_config.py +++ b/tests/test_multidb/test_config.py @@ -1,6 +1,6 @@ from unittest.mock import Mock from redis.connection import ConnectionPool -from redis.multidb.circuit import CircuitBreaker, PBCircuitBreakerAdapter +from redis.multidb.circuit import PBCircuitBreakerAdapter, SyncCircuitBreaker from redis.multidb.config import MultiDbConfig, DEFAULT_HEALTH_CHECK_INTERVAL, \ DEFAULT_AUTO_FALLBACK_INTERVAL, DatabaseConfig, DEFAULT_GRACE_PERIOD from redis.multidb.database import Database @@ -49,11 +49,11 @@ def test_overridden_config(self): mock_connection_pools[0].connection_kwargs = {} mock_connection_pools[1].connection_kwargs = {} mock_connection_pools[2].connection_kwargs = {} - mock_cb1 = Mock(spec=CircuitBreaker) + mock_cb1 = Mock(spec=SyncCircuitBreaker) mock_cb1.grace_period = grace_period - mock_cb2 = Mock(spec=CircuitBreaker) + mock_cb2 = Mock(spec=SyncCircuitBreaker) mock_cb2.grace_period = grace_period - mock_cb3 = Mock(spec=CircuitBreaker) + mock_cb3 = Mock(spec=SyncCircuitBreaker) mock_cb3.grace_period = grace_period mock_failure_detectors = [Mock(spec=FailureDetector), Mock(spec=FailureDetector)] mock_health_checks = [Mock(spec=HealthCheck), Mock(spec=HealthCheck)] @@ -113,7 +113,7 @@ def test_default_config(self): def test_overridden_config(self): mock_connection_pool = Mock(spec=ConnectionPool) - mock_circuit = Mock(spec=CircuitBreaker) + mock_circuit = Mock(spec=SyncCircuitBreaker) config = DatabaseConfig( client_kwargs={'connection_pool': mock_connection_pool}, weight=1.0, circuit=mock_circuit diff --git a/tests/test_multidb/test_failure_detector.py b/tests/test_multidb/test_failure_detector.py index 86d6e1cd82..28687f2a11 100644 --- a/tests/test_multidb/test_failure_detector.py +++ b/tests/test_multidb/test_failure_detector.py @@ -3,7 +3,7 @@ import pytest -from redis.multidb.command_executor import CommandExecutor +from redis.multidb.command_executor import SyncCommandExecutor from redis.multidb.failure_detector import CommandFailureDetector from redis.multidb.circuit import State as CBState from redis.exceptions import ConnectionError @@ -19,7 +19,7 @@ class TestCommandFailureDetector: ) def test_failure_detector_open_circuit_on_threshold_exceed_and_interval_not_exceed(self, mock_db): fd = CommandFailureDetector(5, 1) - mock_ce = Mock(spec=CommandExecutor) + mock_ce = Mock(spec=SyncCommandExecutor) mock_ce.active_database = mock_db fd.set_command_executor(mock_ce) assert mock_db.circuit.state == CBState.CLOSED @@ -41,7 +41,7 @@ def test_failure_detector_open_circuit_on_threshold_exceed_and_interval_not_exce ) def test_failure_detector_do_not_open_circuit_if_threshold_not_exceed_and_interval_not_exceed(self, mock_db): fd = CommandFailureDetector(5, 1) - mock_ce = Mock(spec=CommandExecutor) + mock_ce = Mock(spec=SyncCommandExecutor) mock_ce.active_database = mock_db fd.set_command_executor(mock_ce) assert mock_db.circuit.state == CBState.CLOSED @@ -62,7 +62,7 @@ def test_failure_detector_do_not_open_circuit_if_threshold_not_exceed_and_interv ) def test_failure_detector_do_not_open_circuit_on_threshold_exceed_and_interval_exceed(self, mock_db): fd = CommandFailureDetector(5, 0.3) - mock_ce = Mock(spec=CommandExecutor) + mock_ce = Mock(spec=SyncCommandExecutor) 
mock_ce.active_database = mock_db fd.set_command_executor(mock_ce) assert mock_db.circuit.state == CBState.CLOSED @@ -96,7 +96,7 @@ def test_failure_detector_do_not_open_circuit_on_threshold_exceed_and_interval_e ) def test_failure_detector_refresh_timer_on_expired_duration(self, mock_db): fd = CommandFailureDetector(5, 0.3) - mock_ce = Mock(spec=CommandExecutor) + mock_ce = Mock(spec=SyncCommandExecutor) mock_ce.active_database = mock_db fd.set_command_executor(mock_ce) assert mock_db.circuit.state == CBState.CLOSED @@ -128,7 +128,7 @@ def test_failure_detector_refresh_timer_on_expired_duration(self, mock_db): ) def test_failure_detector_open_circuit_on_specific_exception_threshold_exceed(self, mock_db): fd = CommandFailureDetector(5, 1, error_types=[ConnectionError]) - mock_ce = Mock(spec=CommandExecutor) + mock_ce = Mock(spec=SyncCommandExecutor) mock_ce.active_database = mock_db fd.set_command_executor(mock_ce) assert mock_db.circuit.state == CBState.CLOSED From ec8113b38441181e6a41a67df39abe5663a4817b Mon Sep 17 00:00:00 2001 From: Vladyslav Vildanov <117659936+vladvildanov@users.noreply.github.com> Date: Wed, 10 Sep 2025 10:24:24 +0300 Subject: [PATCH 10/50] Added async implementation of MultiDBClient (#3762) * Extract additional interfaces and abstract classes * Added base async components * Added command executor * Added recurring background tasks with event loop only * Added MultiDBClient * Added scenario and config tests * Update redis/asyncio/multidb/healthcheck.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update tests/test_asyncio/test_scenario/test_active_active.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- redis/asyncio/multidb/__init__.py | 0 redis/asyncio/multidb/client.py | 237 +++++++++ redis/asyncio/multidb/command_executor.py | 265 ++++++++++ redis/asyncio/multidb/config.py | 169 +++++++ redis/asyncio/multidb/database.py | 67 +++ redis/asyncio/multidb/event.py | 65 +++ redis/asyncio/multidb/failover.py | 49 ++ redis/asyncio/multidb/failure_detector.py | 29 ++ redis/asyncio/multidb/healthcheck.py | 75 +++ redis/background.py | 52 +- redis/event.py | 15 +- redis/multidb/circuit.py | 17 +- redis/multidb/client.py | 6 +- redis/multidb/config.py | 8 +- redis/multidb/database.py | 12 +- redis/utils.py | 6 + tests/test_asyncio/test_multidb/__init__.py | 0 tests/test_asyncio/test_multidb/conftest.py | 108 ++++ .../test_asyncio/test_multidb/test_client.py | 471 ++++++++++++++++++ .../test_multidb/test_command_executor.py | 165 ++++++ .../test_asyncio/test_multidb/test_config.py | 125 +++++ .../test_multidb/test_failover.py | 121 +++++ .../test_multidb/test_failure_detector.py | 153 ++++++ .../test_multidb/test_healthcheck.py | 48 ++ tests/test_asyncio/test_scenario/__init__.py | 0 tests/test_asyncio/test_scenario/conftest.py | 88 ++++ .../test_scenario/test_active_active.py | 59 +++ tests/test_background.py | 33 ++ tests/test_multidb/conftest.py | 12 +- tests/test_multidb/test_circuit.py | 4 +- tests/test_multidb/test_client.py | 4 - tests/test_multidb/test_config.py | 10 +- 32 files changed, 2429 insertions(+), 44 deletions(-) create mode 100644 redis/asyncio/multidb/__init__.py create mode 100644 redis/asyncio/multidb/client.py create mode 100644 redis/asyncio/multidb/command_executor.py create mode 100644 redis/asyncio/multidb/config.py create mode 100644 redis/asyncio/multidb/database.py create mode 100644 redis/asyncio/multidb/event.py 
create mode 100644 redis/asyncio/multidb/failover.py create mode 100644 redis/asyncio/multidb/failure_detector.py create mode 100644 redis/asyncio/multidb/healthcheck.py create mode 100644 tests/test_asyncio/test_multidb/__init__.py create mode 100644 tests/test_asyncio/test_multidb/conftest.py create mode 100644 tests/test_asyncio/test_multidb/test_client.py create mode 100644 tests/test_asyncio/test_multidb/test_command_executor.py create mode 100644 tests/test_asyncio/test_multidb/test_config.py create mode 100644 tests/test_asyncio/test_multidb/test_failover.py create mode 100644 tests/test_asyncio/test_multidb/test_failure_detector.py create mode 100644 tests/test_asyncio/test_multidb/test_healthcheck.py create mode 100644 tests/test_asyncio/test_scenario/__init__.py create mode 100644 tests/test_asyncio/test_scenario/conftest.py create mode 100644 tests/test_asyncio/test_scenario/test_active_active.py diff --git a/redis/asyncio/multidb/__init__.py b/redis/asyncio/multidb/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/redis/asyncio/multidb/client.py b/redis/asyncio/multidb/client.py new file mode 100644 index 0000000000..73eafd9026 --- /dev/null +++ b/redis/asyncio/multidb/client.py @@ -0,0 +1,237 @@ +import asyncio +from typing import Callable, Optional, Coroutine, Any + +from redis.asyncio.multidb.command_executor import DefaultCommandExecutor +from redis.asyncio.multidb.database import AsyncDatabase, Databases +from redis.asyncio.multidb.failure_detector import AsyncFailureDetector +from redis.asyncio.multidb.healthcheck import HealthCheck +from redis.multidb.circuit import State as CBState, CircuitBreaker +from redis.asyncio.multidb.config import MultiDbConfig, DEFAULT_GRACE_PERIOD +from redis.background import BackgroundScheduler +from redis.commands import AsyncRedisModuleCommands, AsyncCoreCommands +from redis.multidb.exception import NoValidDatabaseException + + +class MultiDBClient(AsyncRedisModuleCommands, AsyncCoreCommands): + """ + Client that operates on multiple logical Redis databases. + Should be used in Active-Active database setups. + """ + def __init__(self, config: MultiDbConfig): + self._databases = config.databases() + self._health_checks = config.default_health_checks() + + if config.health_checks is not None: + self._health_checks.extend(config.health_checks) + + self._health_check_interval = config.health_check_interval + self._failure_detectors = config.default_failure_detectors() + + if config.failure_detectors is not None: + self._failure_detectors.extend(config.failure_detectors) + + self._failover_strategy = config.default_failover_strategy() \ + if config.failover_strategy is None else config.failover_strategy + self._failover_strategy.set_databases(self._databases) + self._auto_fallback_interval = config.auto_fallback_interval + self._event_dispatcher = config.event_dispatcher + self._command_retry = config.command_retry + self._command_retry.update_supported_errors([ConnectionRefusedError]) + self.command_executor = DefaultCommandExecutor( + failure_detectors=self._failure_detectors, + databases=self._databases, + command_retry=self._command_retry, + failover_strategy=self._failover_strategy, + event_dispatcher=self._event_dispatcher, + auto_fallback_interval=self._auto_fallback_interval, + ) + self.initialized = False + self._hc_lock = asyncio.Lock() + self._bg_scheduler = BackgroundScheduler() + self._config = config + + async def initialize(self): + """ + Perform initialization of databases to define their initial state. 
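A hedged end-to-end sketch of the async client introduced in this commit; the URLs are placeholders and both endpoints are assumed reachable:

    import asyncio

    from redis.asyncio.multidb.client import MultiDBClient
    from redis.asyncio.multidb.config import DatabaseConfig, MultiDbConfig

    async def main():
        config = MultiDbConfig(databases_config=[
            DatabaseConfig(from_url="redis://db-east:6379", weight=1.0),
            DatabaseConfig(from_url="redis://db-west:6379", weight=0.5),
        ])
        client = MultiDBClient(config)
        # The first command triggers initialize(), which health-checks all
        # databases and picks the highest-weighted healthy one.
        await client.set("key", "value")
        print(await client.get("key"))

    asyncio.run(main())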
+ """ + async def raise_exception_on_failed_hc(error): + raise error + + # Initial databases check to define initial state + await self._check_databases_health(on_error=raise_exception_on_failed_hc) + + # Starts recurring health checks on the background. + asyncio.create_task(self._bg_scheduler.run_recurring_async( + self._health_check_interval, + self._check_databases_health, + )) + + is_active_db_found = False + + for database, weight in self._databases: + # Set on state changed callback for each circuit. + database.circuit.on_state_changed(self._on_circuit_state_change_callback) + + # Set states according to a weights and circuit state + if database.circuit.state == CBState.CLOSED and not is_active_db_found: + await self.command_executor.set_active_database(database) + is_active_db_found = True + + if not is_active_db_found: + raise NoValidDatabaseException('Initial connection failed - no active database found') + + self.initialized = True + + def get_databases(self) -> Databases: + """ + Returns a sorted (by weight) list of all databases. + """ + return self._databases + + async def set_active_database(self, database: AsyncDatabase) -> None: + """ + Promote one of the existing databases to become an active. + """ + exists = None + + for existing_db, _ in self._databases: + if existing_db == database: + exists = True + break + + if not exists: + raise ValueError('Given database is not a member of database list') + + await self._check_db_health(database) + + if database.circuit.state == CBState.CLOSED: + highest_weighted_db, _ = self._databases.get_top_n(1)[0] + await self.command_executor.set_active_database(database) + return + + raise NoValidDatabaseException('Cannot set active database, database is unhealthy') + + async def add_database(self, database: AsyncDatabase): + """ + Adds a new database to the database list. + """ + for existing_db, _ in self._databases: + if existing_db == database: + raise ValueError('Given database already exists') + + await self._check_db_health(database) + + highest_weighted_db, highest_weight = self._databases.get_top_n(1)[0] + self._databases.add(database, database.weight) + await self._change_active_database(database, highest_weighted_db) + + async def _change_active_database(self, new_database: AsyncDatabase, highest_weight_database: AsyncDatabase): + if new_database.weight > highest_weight_database.weight and new_database.circuit.state == CBState.CLOSED: + await self.command_executor.set_active_database(new_database) + + async def remove_database(self, database: AsyncDatabase): + """ + Removes a database from the database list. + """ + weight = self._databases.remove(database) + highest_weighted_db, highest_weight = self._databases.get_top_n(1)[0] + + if highest_weight <= weight and highest_weighted_db.circuit.state == CBState.CLOSED: + await self.command_executor.set_active_database(highest_weighted_db) + + async def update_database_weight(self, database: AsyncDatabase, weight: float): + """ + Updates a database from the database list. 
+ """ + exists = None + + for existing_db, _ in self._databases: + if existing_db == database: + exists = True + break + + if not exists: + raise ValueError('Given database is not a member of database list') + + highest_weighted_db, highest_weight = self._databases.get_top_n(1)[0] + self._databases.update_weight(database, weight) + database.weight = weight + await self._change_active_database(database, highest_weighted_db) + + def add_failure_detector(self, failure_detector: AsyncFailureDetector): + """ + Adds a new failure detector to the database. + """ + self._failure_detectors.append(failure_detector) + + async def add_health_check(self, healthcheck: HealthCheck): + """ + Adds a new health check to the database. + """ + async with self._hc_lock: + self._health_checks.append(healthcheck) + + async def execute_command(self, *args, **options): + """ + Executes a single command and return its result. + """ + if not self.initialized: + await self.initialize() + + return await self.command_executor.execute_command(*args, **options) + + async def _check_databases_health( + self, + on_error: Optional[Callable[[Exception], Coroutine[Any, Any, None]]] = None, + ): + """ + Runs health checks as a recurring task. + Runs health checks against all databases. + """ + for database, _ in self._databases: + async with self._hc_lock: + await self._check_db_health(database, on_error) + + async def _check_db_health( + self, + database: AsyncDatabase, + on_error: Optional[Callable[[Exception], Coroutine[Any, Any, None]]] = None, + ) -> None: + """ + Runs health checks on the given database until first failure. + """ + is_healthy = True + + # Health check will setup circuit state + for health_check in self._health_checks: + if not is_healthy: + # If one of the health checks failed, it's considered unhealthy + break + + try: + is_healthy = await health_check.check_health(database) + + if not is_healthy and database.circuit.state != CBState.OPEN: + database.circuit.state = CBState.OPEN + elif is_healthy and database.circuit.state != CBState.CLOSED: + database.circuit.state = CBState.CLOSED + except Exception as e: + if database.circuit.state != CBState.OPEN: + database.circuit.state = CBState.OPEN + is_healthy = False + + if on_error: + await on_error(e) + + def _on_circuit_state_change_callback(self, circuit: CircuitBreaker, old_state: CBState, new_state: CBState): + loop = asyncio.get_running_loop() + + if new_state == CBState.HALF_OPEN: + asyncio.create_task(self._check_db_health(circuit.database)) + return + + if old_state == CBState.CLOSED and new_state == CBState.OPEN: + loop.call_later(DEFAULT_GRACE_PERIOD, _half_open_circuit, circuit) + +def _half_open_circuit(circuit: CircuitBreaker): + circuit.state = CBState.HALF_OPEN \ No newline at end of file diff --git a/redis/asyncio/multidb/command_executor.py b/redis/asyncio/multidb/command_executor.py new file mode 100644 index 0000000000..af10a00988 --- /dev/null +++ b/redis/asyncio/multidb/command_executor.py @@ -0,0 +1,265 @@ +from abc import abstractmethod +from datetime import datetime +from typing import List, Optional, Callable, Any + +from redis.asyncio.client import PubSub, Pipeline +from redis.asyncio.multidb.database import Databases, AsyncDatabase, Database +from redis.asyncio.multidb.event import AsyncActiveDatabaseChanged, RegisterCommandFailure, \ + ResubscribeOnActiveDatabaseChanged +from redis.asyncio.multidb.failover import AsyncFailoverStrategy +from redis.asyncio.multidb.failure_detector import AsyncFailureDetector +from 
redis.multidb.circuit import State as CBState +from redis.asyncio.retry import Retry +from redis.event import EventDispatcherInterface, AsyncOnCommandsFailEvent +from redis.multidb.command_executor import CommandExecutor, BaseCommandExecutor +from redis.multidb.config import DEFAULT_AUTO_FALLBACK_INTERVAL + + +class AsyncCommandExecutor(CommandExecutor): + + @property + @abstractmethod + def databases(self) -> Databases: + """Returns a list of databases.""" + pass + + @property + @abstractmethod + def failure_detectors(self) -> List[AsyncFailureDetector]: + """Returns a list of failure detectors.""" + pass + + @abstractmethod + def add_failure_detector(self, failure_detector: AsyncFailureDetector) -> None: + """Adds a new failure detector to the list of failure detectors.""" + pass + + @property + @abstractmethod + def active_database(self) -> Optional[AsyncDatabase]: + """Returns currently active database.""" + pass + + @abstractmethod + async def set_active_database(self, database: AsyncDatabase) -> None: + """Sets the currently active database.""" + pass + + @property + @abstractmethod + def active_pubsub(self) -> Optional[PubSub]: + """Returns currently active pubsub.""" + pass + + @active_pubsub.setter + @abstractmethod + def active_pubsub(self, pubsub: PubSub) -> None: + """Sets currently active pubsub.""" + pass + + @property + @abstractmethod + def failover_strategy(self) -> AsyncFailoverStrategy: + """Returns failover strategy.""" + pass + + @property + @abstractmethod + def command_retry(self) -> Retry: + """Returns command retry object.""" + pass + + @abstractmethod + async def pubsub(self, **kwargs): + """Initializes a PubSub object on a currently active database""" + pass + + @abstractmethod + async def execute_command(self, *args, **options): + """Executes a command and returns the result.""" + pass + + @abstractmethod + async def execute_pipeline(self, command_stack: tuple): + """Executes a stack of commands in pipeline.""" + pass + + @abstractmethod + async def execute_transaction(self, transaction: Callable[[Pipeline], None], *watches, **options): + """Executes a transaction block wrapped in callback.""" + pass + + @abstractmethod + async def execute_pubsub_method(self, method_name: str, *args, **kwargs): + """Executes a given method on active pub/sub.""" + pass + + @abstractmethod + async def execute_pubsub_run(self, sleep_time: float, **kwargs) -> Any: + """Executes pub/sub run in a thread.""" + pass + + +class DefaultCommandExecutor(BaseCommandExecutor, AsyncCommandExecutor): + def __init__( + self, + failure_detectors: List[AsyncFailureDetector], + databases: Databases, + command_retry: Retry, + failover_strategy: AsyncFailoverStrategy, + event_dispatcher: EventDispatcherInterface, + auto_fallback_interval: float = DEFAULT_AUTO_FALLBACK_INTERVAL, + ): + """ + Initialize the DefaultCommandExecutor instance. 
+ + Args: + failure_detectors: List of failure detector instances to monitor database health + databases: Collection of available databases to execute commands on + command_retry: Retry policy for failed command execution + failover_strategy: Strategy for handling database failover + event_dispatcher: Interface for dispatching events + auto_fallback_interval: Time interval in seconds between attempts to fall back to a primary database + """ + super().__init__(auto_fallback_interval) + + for fd in failure_detectors: + fd.set_command_executor(command_executor=self) + + self._databases = databases + self._failure_detectors = failure_detectors + self._command_retry = command_retry + self._failover_strategy = failover_strategy + self._event_dispatcher = event_dispatcher + self._active_database: Optional[Database] = None + self._active_pubsub: Optional[PubSub] = None + self._active_pubsub_kwargs = {} + self._setup_event_dispatcher() + self._schedule_next_fallback() + + @property + def databases(self) -> Databases: + return self._databases + + @property + def failure_detectors(self) -> List[AsyncFailureDetector]: + return self._failure_detectors + + def add_failure_detector(self, failure_detector: AsyncFailureDetector) -> None: + self._failure_detectors.append(failure_detector) + + @property + def active_database(self) -> Optional[AsyncDatabase]: + return self._active_database + + async def set_active_database(self, database: AsyncDatabase) -> None: + old_active = self._active_database + self._active_database = database + + if old_active is not None and old_active is not database: + await self._event_dispatcher.dispatch_async( + AsyncActiveDatabaseChanged(old_active, self._active_database, self, **self._active_pubsub_kwargs) + ) + + @property + def active_pubsub(self) -> Optional[PubSub]: + return self._active_pubsub + + @active_pubsub.setter + def active_pubsub(self, pubsub: PubSub) -> None: + self._active_pubsub = pubsub + + @property + def failover_strategy(self) -> AsyncFailoverStrategy: + return self._failover_strategy + + @property + def command_retry(self) -> Retry: + return self._command_retry + + async def pubsub(self, **kwargs): + async def callback(): + if self._active_pubsub is None: + self._active_pubsub = self._active_database.client.pubsub(**kwargs) + self._active_pubsub_kwargs = kwargs + return None + + return await self._execute_with_failure_detection(callback) + + async def execute_command(self, *args, **options): + async def callback(): + return await self._active_database.client.execute_command(*args, **options) + + return await self._execute_with_failure_detection(callback, args) + + async def execute_pipeline(self, command_stack: tuple): + async def callback(): + with self._active_database.client.pipeline() as pipe: + for command, options in command_stack: + await pipe.execute_command(*command, **options) + + return await pipe.execute() + + return await self._execute_with_failure_detection(callback, command_stack) + + async def execute_transaction(self, transaction: Callable[[Pipeline], None], *watches, **options): + async def callback(): + return await self._active_database.client.transaction(transaction, *watches, **options) + + return await self._execute_with_failure_detection(callback) + + async def execute_pubsub_method(self, method_name: str, *args, **kwargs): + async def callback(): + method = getattr(self.active_pubsub, method_name) + return await method(*args, **kwargs) + + return await self._execute_with_failure_detection(callback, *args) + + async def 
execute_pubsub_run(self, sleep_time: float, **kwargs) -> Any: + async def callback(): + return await self._active_pubsub.run(poll_timeout=sleep_time, **kwargs) + + return await self._execute_with_failure_detection(callback) + + async def _execute_with_failure_detection(self, callback: Callable, cmds: tuple = ()): + """ + Execute a command execution callback with failure detection. + """ + async def wrapper(): + # On each retry we need to check the active database as it might change. + await self._check_active_database() + return await callback() + + return await self._command_retry.call_with_retry( + lambda: wrapper(), + lambda error: self._on_command_fail(error, *cmds), + ) + + async def _check_active_database(self): + """ + Checks if the active database needs to be updated. + """ + if ( + self._active_database is None + or self._active_database.circuit.state != CBState.CLOSED + or ( + self._auto_fallback_interval != DEFAULT_AUTO_FALLBACK_INTERVAL + and self._next_fallback_attempt <= datetime.now() + ) + ): + await self.set_active_database(await self._failover_strategy.database()) + self._schedule_next_fallback() + + async def _on_command_fail(self, error, *args): + await self._event_dispatcher.dispatch_async(AsyncOnCommandsFailEvent(args, error)) + + def _setup_event_dispatcher(self): + """ + Registers necessary listeners. + """ + failure_listener = RegisterCommandFailure(self._failure_detectors) + resubscribe_listener = ResubscribeOnActiveDatabaseChanged() + self._event_dispatcher.register_listeners({ + AsyncOnCommandsFailEvent: [failure_listener], + AsyncActiveDatabaseChanged: [resubscribe_listener], + }) \ No newline at end of file diff --git a/redis/asyncio/multidb/config.py b/redis/asyncio/multidb/config.py new file mode 100644 index 0000000000..b5f4a0658d --- /dev/null +++ b/redis/asyncio/multidb/config.py @@ -0,0 +1,169 @@ +from dataclasses import dataclass, field +from typing import Optional, List, Type, Union + +import pybreaker + +from redis.asyncio import ConnectionPool, Redis, RedisCluster +from redis.asyncio.multidb.database import Databases, Database +from redis.asyncio.multidb.failover import AsyncFailoverStrategy, WeightBasedFailoverStrategy +from redis.asyncio.multidb.failure_detector import AsyncFailureDetector, FailureDetectorAsyncWrapper +from redis.asyncio.multidb.healthcheck import HealthCheck, DEFAULT_HEALTH_CHECK_RETRIES, DEFAULT_HEALTH_CHECK_BACKOFF, \ + EchoHealthCheck +from redis.asyncio.retry import Retry +from redis.backoff import ExponentialWithJitterBackoff, AbstractBackoff, NoBackoff +from redis.data_structure import WeightedList +from redis.event import EventDispatcherInterface, EventDispatcher +from redis.multidb.circuit import CircuitBreaker, PBCircuitBreakerAdapter +from redis.multidb.failure_detector import CommandFailureDetector + +DEFAULT_GRACE_PERIOD = 5.0 +DEFAULT_HEALTH_CHECK_INTERVAL = 5 +DEFAULT_FAILURES_THRESHOLD = 3 +DEFAULT_FAILURES_DURATION = 2 +DEFAULT_FAILOVER_RETRIES = 3 +DEFAULT_FAILOVER_BACKOFF = ExponentialWithJitterBackoff(cap=3) +DEFAULT_AUTO_FALLBACK_INTERVAL = -1 + +def default_event_dispatcher() -> EventDispatcherInterface: + return EventDispatcher() + +@dataclass +class DatabaseConfig: + """ + Dataclass representing the configuration for a database connection. + + This class is used to store configuration settings for a database connection, + including client options, connection sourcing details, circuit breaker settings, + and cluster-specific properties.
It provides a structure for defining these + attributes and allows for the creation of customized configurations for various + database setups. + + Attributes: + weight (float): Weight of the database to define the active one. + client_kwargs (dict): Additional parameters for the database client connection. + from_url (Optional[str]): Redis URL way of connecting to the database. + from_pool (Optional[ConnectionPool]): A pre-configured connection pool to use. + circuit (Optional[CircuitBreaker]): Custom circuit breaker implementation. + grace_period (float): Grace period after which we need to check if the circuit could be closed again. + health_check_url (Optional[str]): URL for health checks. Cluster FQDN is typically used + on public Redis Enterprise endpoints. + + Methods: + default_circuit_breaker: + Generates and returns a default CircuitBreaker instance adapted for use. + """ + weight: float = 1.0 + client_kwargs: dict = field(default_factory=dict) + from_url: Optional[str] = None + from_pool: Optional[ConnectionPool] = None + circuit: Optional[CircuitBreaker] = None + grace_period: float = DEFAULT_GRACE_PERIOD + health_check_url: Optional[str] = None + + def default_circuit_breaker(self) -> CircuitBreaker: + circuit_breaker = pybreaker.CircuitBreaker(reset_timeout=self.grace_period) + return PBCircuitBreakerAdapter(circuit_breaker) + +@dataclass +class MultiDbConfig: + """ + Configuration class for managing multiple database connections in a resilient and fail-safe manner. + + Attributes: + databases_config: A list of database configurations. + client_class: The client class used to manage database connections. + command_retry: Retry strategy for executing database commands. + failure_detectors: Optional list of additional failure detectors for monitoring database failures. + failure_threshold: Threshold for determining database failure. + failures_interval: Time interval for tracking database failures. + health_checks: Optional list of additional health checks performed on databases. + health_check_interval: Time interval for executing health checks. + health_check_retries: Number of retry attempts for performing health checks. + health_check_backoff: Backoff strategy for health check retries. + failover_strategy: Optional strategy for handling database failover scenarios. + failover_retries: Number of retries allowed for failover operations. + failover_backoff: Backoff strategy for failover retries. + auto_fallback_interval: Time interval to trigger automatic fallback. + event_dispatcher: Interface for dispatching events related to database operations. + + Methods: + databases: + Retrieves a collection of database clients managed by weighted configurations. + Initializes database clients based on the provided configuration and removes + redundant retry objects for lower-level clients to rely on global retry logic. + + default_failure_detectors: + Returns the default list of failure detectors used to monitor database failures. + + default_health_checks: + Returns the default list of health checks used to monitor database health + with specific retry and backoff strategies. + + default_failover_strategy: + Provides the default failover strategy used for handling failover scenarios + with defined retry and backoff configurations. 
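A sketch of overriding these defaults; the endpoints and numbers are illustrative only:

    from redis.asyncio.multidb.config import DatabaseConfig, MultiDbConfig
    from redis.backoff import ExponentialBackoff

    config = MultiDbConfig(
        databases_config=[
            DatabaseConfig(from_url="redis://db-east:6379", weight=1.0),  # placeholder URL
            DatabaseConfig(from_url="redis://db-west:6379", weight=0.5),
        ],
        health_check_interval=10,        # seconds between background checks
        health_check_retries=2,
        health_check_backoff=ExponentialBackoff(cap=1.0),
        auto_fallback_interval=120,      # retry falling back to the heaviest db every 2 minutes
    )
    databases = config.databases()       # WeightedList of Database wrappers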
+ """ + databases_config: List[DatabaseConfig] + client_class: Type[Union[Redis, RedisCluster]] = Redis + command_retry: Retry = Retry( + backoff=ExponentialWithJitterBackoff(base=1, cap=10), retries=3 + ) + failure_detectors: Optional[List[AsyncFailureDetector]] = None + failure_threshold: int = DEFAULT_FAILURES_THRESHOLD + failures_interval: float = DEFAULT_FAILURES_DURATION + health_checks: Optional[List[HealthCheck]] = None + health_check_interval: float = DEFAULT_HEALTH_CHECK_INTERVAL + health_check_retries: int = DEFAULT_HEALTH_CHECK_RETRIES + health_check_backoff: AbstractBackoff = DEFAULT_HEALTH_CHECK_BACKOFF + failover_strategy: Optional[AsyncFailoverStrategy] = None + failover_retries: int = DEFAULT_FAILOVER_RETRIES + failover_backoff: AbstractBackoff = DEFAULT_FAILOVER_BACKOFF + auto_fallback_interval: float = DEFAULT_AUTO_FALLBACK_INTERVAL + event_dispatcher: EventDispatcherInterface = field(default_factory=default_event_dispatcher) + + def databases(self) -> Databases: + databases = WeightedList() + + for database_config in self.databases_config: + # The retry object is not used in the lower level clients, so we can safely remove it. + # We rely on command_retry in terms of global retries. + database_config.client_kwargs.update({"retry": Retry(retries=0, backoff=NoBackoff())}) + + if database_config.from_url: + client = self.client_class.from_url(database_config.from_url, **database_config.client_kwargs) + elif database_config.from_pool: + database_config.from_pool.set_retry(Retry(retries=0, backoff=NoBackoff())) + client = self.client_class.from_pool(connection_pool=database_config.from_pool) + else: + client = self.client_class(**database_config.client_kwargs) + + circuit = database_config.default_circuit_breaker() \ + if database_config.circuit is None else database_config.circuit + databases.add( + Database( + client=client, + circuit=circuit, + weight=database_config.weight, + health_check_url=database_config.health_check_url + ), + database_config.weight + ) + + return databases + + def default_failure_detectors(self) -> List[AsyncFailureDetector]: + return [ + FailureDetectorAsyncWrapper( + CommandFailureDetector(threshold=self.failure_threshold, duration=self.failures_interval) + ), + ] + + def default_health_checks(self) -> List[HealthCheck]: + return [ + EchoHealthCheck(retry=Retry(retries=self.health_check_retries, backoff=self.health_check_backoff)), + ] + + def default_failover_strategy(self) -> AsyncFailoverStrategy: + return WeightBasedFailoverStrategy( + retry=Retry(retries=self.failover_retries, backoff=self.failover_backoff), + ) \ No newline at end of file diff --git a/redis/asyncio/multidb/database.py b/redis/asyncio/multidb/database.py new file mode 100644 index 0000000000..6afbbbf5ea --- /dev/null +++ b/redis/asyncio/multidb/database.py @@ -0,0 +1,67 @@ +from abc import abstractmethod +from typing import Union, Optional + +from redis.asyncio import Redis, RedisCluster +from redis.data_structure import WeightedList +from redis.multidb.circuit import CircuitBreaker +from redis.multidb.database import AbstractDatabase, BaseDatabase +from redis.typing import Number + + +class AsyncDatabase(AbstractDatabase): + """Database with an underlying asynchronous redis client.""" + @property + @abstractmethod + def client(self) -> Union[Redis, RedisCluster]: + """The underlying redis client.""" + pass + + @client.setter + @abstractmethod + def client(self, client: Union[Redis, RedisCluster]): + """Set the underlying redis client.""" + pass + + @property + 
@abstractmethod + def circuit(self) -> CircuitBreaker: + """Circuit breaker for the current database.""" + pass + + @circuit.setter + @abstractmethod + def circuit(self, circuit: CircuitBreaker): + """Set the circuit breaker for the current database.""" + pass + +Databases = WeightedList[tuple[AsyncDatabase, Number]] + +class Database(BaseDatabase, AsyncDatabase): + def __init__( + self, + client: Union[Redis, RedisCluster], + circuit: CircuitBreaker, + weight: float, + health_check_url: Optional[str] = None, + ): + self._client = client + self._cb = circuit + self._cb.database = self + super().__init__(weight, health_check_url) + + @property + def client(self) -> Union[Redis, RedisCluster]: + return self._client + + @client.setter + def client(self, client: Union[Redis, RedisCluster]): + self._client = client + + @property + def circuit(self) -> CircuitBreaker: + return self._cb + + @circuit.setter + def circuit(self, circuit: CircuitBreaker): + self._cb = circuit + diff --git a/redis/asyncio/multidb/event.py b/redis/asyncio/multidb/event.py new file mode 100644 index 0000000000..ea5534ce86 --- /dev/null +++ b/redis/asyncio/multidb/event.py @@ -0,0 +1,65 @@ +from typing import List + +from redis.asyncio.multidb.database import AsyncDatabase +from redis.asyncio.multidb.failure_detector import AsyncFailureDetector +from redis.event import AsyncEventListenerInterface, AsyncOnCommandsFailEvent + + +class AsyncActiveDatabaseChanged: + """ + Event fired when the active async database has changed. + """ + def __init__( + self, + old_database: AsyncDatabase, + new_database: AsyncDatabase, + command_executor, + **kwargs + ): + self._old_database = old_database + self._new_database = new_database + self._command_executor = command_executor + self._kwargs = kwargs + + @property + def old_database(self) -> AsyncDatabase: + return self._old_database + + @property + def new_database(self) -> AsyncDatabase: + return self._new_database + + @property + def command_executor(self): + return self._command_executor + + @property + def kwargs(self): + return self._kwargs + +class ResubscribeOnActiveDatabaseChanged(AsyncEventListenerInterface): + """ + Re-subscribes the currently active pub/sub to the new active database. + """ + async def listen(self, event: AsyncActiveDatabaseChanged): + old_pubsub = event.command_executor.active_pubsub + + if old_pubsub is not None: + # Re-assign old channels and patterns so they will be automatically subscribed on connection. + new_pubsub = event.new_database.client.pubsub(**event.kwargs) + new_pubsub.channels = old_pubsub.channels + new_pubsub.patterns = old_pubsub.patterns + await new_pubsub.on_connect(None) + event.command_executor.active_pubsub = new_pubsub + await old_pubsub.close() + +class RegisterCommandFailure(AsyncEventListenerInterface): + """ + Event listener that registers command failures and passes them to the failure detectors.
+ """ + def __init__(self, failure_detectors: List[AsyncFailureDetector]): + self._failure_detectors = failure_detectors + + async def listen(self, event: AsyncOnCommandsFailEvent) -> None: + for failure_detector in self._failure_detectors: + await failure_detector.register_failure(event.exception, event.commands) \ No newline at end of file diff --git a/redis/asyncio/multidb/failover.py b/redis/asyncio/multidb/failover.py new file mode 100644 index 0000000000..a2ed427e05 --- /dev/null +++ b/redis/asyncio/multidb/failover.py @@ -0,0 +1,49 @@ +from abc import abstractmethod, ABC + +from redis.asyncio.multidb.database import AsyncDatabase, Databases +from redis.multidb.circuit import State as CBState +from redis.asyncio.retry import Retry +from redis.data_structure import WeightedList +from redis.multidb.exception import NoValidDatabaseException +from redis.utils import dummy_fail_async + + +class AsyncFailoverStrategy(ABC): + + @abstractmethod + async def database(self) -> AsyncDatabase: + """Select the database according to the strategy.""" + pass + + @abstractmethod + def set_databases(self, databases: Databases) -> None: + """Set the database strategy operates on.""" + pass + +class WeightBasedFailoverStrategy(AsyncFailoverStrategy): + """ + Failover strategy based on database weights. + """ + def __init__( + self, + retry: Retry + ): + self._retry = retry + self._retry.update_supported_errors([NoValidDatabaseException]) + self._databases = WeightedList() + + async def database(self) -> AsyncDatabase: + return await self._retry.call_with_retry( + lambda: self._get_active_database(), + lambda _: dummy_fail_async() + ) + + def set_databases(self, databases: Databases) -> None: + self._databases = databases + + async def _get_active_database(self) -> AsyncDatabase: + for database, _ in self._databases: + if database.circuit.state == CBState.CLOSED: + return database + + raise NoValidDatabaseException('No valid database available for communication') \ No newline at end of file diff --git a/redis/asyncio/multidb/failure_detector.py b/redis/asyncio/multidb/failure_detector.py new file mode 100644 index 0000000000..8aa4752924 --- /dev/null +++ b/redis/asyncio/multidb/failure_detector.py @@ -0,0 +1,29 @@ +from abc import ABC, abstractmethod + +from redis.multidb.failure_detector import FailureDetector + + +class AsyncFailureDetector(ABC): + + @abstractmethod + async def register_failure(self, exception: Exception, cmd: tuple) -> None: + """Register a failure that occurred during command execution.""" + pass + + @abstractmethod + def set_command_executor(self, command_executor) -> None: + """Set the command executor for this failure.""" + pass + +class FailureDetectorAsyncWrapper(AsyncFailureDetector): + """ + Async wrapper for the failure detector. 
+ """ + def __init__(self, failure_detector: FailureDetector) -> None: + self._failure_detector = failure_detector + + async def register_failure(self, exception: Exception, cmd: tuple) -> None: + self._failure_detector.register_failure(exception, cmd) + + def set_command_executor(self, command_executor) -> None: + self._failure_detector.set_command_executor(command_executor) \ No newline at end of file diff --git a/redis/asyncio/multidb/healthcheck.py b/redis/asyncio/multidb/healthcheck.py new file mode 100644 index 0000000000..ccaf285ade --- /dev/null +++ b/redis/asyncio/multidb/healthcheck.py @@ -0,0 +1,75 @@ +import logging +from abc import ABC, abstractmethod + +from redis.asyncio import Redis +from redis.asyncio.retry import Retry +from redis.backoff import ExponentialWithJitterBackoff +from redis.utils import dummy_fail_async + +DEFAULT_HEALTH_CHECK_RETRIES = 3 +DEFAULT_HEALTH_CHECK_BACKOFF = ExponentialWithJitterBackoff(cap=10) + +logger = logging.getLogger(__name__) + +class HealthCheck(ABC): + + @property + @abstractmethod + def retry(self) -> Retry: + """The retry object to use for health checks.""" + pass + + @abstractmethod + async def check_health(self, database) -> bool: + """Function to determine the health status.""" + pass + +class AbstractHealthCheck(HealthCheck): + def __init__( + self, + retry: Retry = Retry(retries=DEFAULT_HEALTH_CHECK_RETRIES, backoff=DEFAULT_HEALTH_CHECK_BACKOFF) + ) -> None: + self._retry = retry + self._retry.update_supported_errors([ConnectionRefusedError]) + + @property + def retry(self) -> Retry: + return self._retry + + @abstractmethod + async def check_health(self, database) -> bool: + pass + +class EchoHealthCheck(AbstractHealthCheck): + def __init__( + self, + retry: Retry = Retry(retries=DEFAULT_HEALTH_CHECK_RETRIES, backoff=DEFAULT_HEALTH_CHECK_BACKOFF) + ) -> None: + """ + Check database healthiness by sending an echo request. + """ + super().__init__( + retry=retry, + ) + async def check_health(self, database) -> bool: + return await self._retry.call_with_retry( + lambda: self._returns_echoed_message(database), + lambda _: dummy_fail_async() + ) + + async def _returns_echoed_message(self, database) -> bool: + expected_message = ["healthcheck", b"healthcheck"] + + if isinstance(database.client, Redis): + actual_message = await database.client.execute_command("ECHO", "healthcheck") + return actual_message in expected_message + else: + # For a cluster checks if all nodes are healthy. + all_nodes = database.client.get_nodes() + for node in all_nodes: + actual_message = await node.redis_connection.execute_command("ECHO", "healthcheck") + + if actual_message not in expected_message: + return False + + return True \ No newline at end of file diff --git a/redis/background.py b/redis/background.py index 6466649859..ce43cbfa7a 100644 --- a/redis/background.py +++ b/redis/background.py @@ -1,6 +1,7 @@ import asyncio import threading -from typing import Callable +from typing import Callable, Coroutine, Any + class BackgroundScheduler: """ @@ -45,7 +46,35 @@ def run_recurring( ) thread.start() - def _call_later(self, loop: asyncio.AbstractEventLoop, delay: float, callback: Callable, *args): + async def run_recurring_async( + self, + interval: float, + coro: Callable[..., Coroutine[Any, Any, Any]], + *args + ): + """ + Runs recurring coroutine with given interval in seconds in the current event loop. + To be used only from an async context. No additional threads are created. 
+ """ + loop = asyncio.get_running_loop() + wrapped = _async_to_sync_wrapper(loop, coro, *args) + + def tick(): + # Schedule the coroutine + wrapped() + # Schedule next tick + self._next_timer = loop.call_later(interval, tick) + + # Schedule first tick + self._next_timer = loop.call_later(interval, tick) + + def _call_later( + self, + loop: asyncio.AbstractEventLoop, + delay: float, + callback: Callable, + *args + ): self._next_timer = loop.call_later(delay, callback, *args) def _call_later_recurring( @@ -86,4 +115,21 @@ def _start_event_loop_in_thread(event_loop: asyncio.AbstractEventLoop, call_soon """ asyncio.set_event_loop(event_loop) event_loop.call_soon(call_soon_cb, event_loop, *args) - event_loop.run_forever() \ No newline at end of file + event_loop.run_forever() + +def _async_to_sync_wrapper(loop, coro_func, *args, **kwargs): + """ + Wraps an asynchronous function so it can be used with loop.call_later. + + :param loop: The event loop in which the coroutine will be executed. + :param coro_func: The coroutine function to wrap. + :param args: Positional arguments to pass to the coroutine function. + :param kwargs: Keyword arguments to pass to the coroutine function. + :return: A regular function suitable for loop.call_later. + """ + + def wrapped(): + # Schedule the coroutine in the event loop + asyncio.ensure_future(coro_func(*args, **kwargs), loop=loop) + + return wrapped \ No newline at end of file diff --git a/redis/event.py b/redis/event.py index 1fa66f0587..de38e1a069 100644 --- a/redis/event.py +++ b/redis/event.py @@ -43,7 +43,10 @@ async def dispatch_async(self, event: object): pass @abstractmethod - def register_listeners(self, mappings: Dict[Type[object], List[EventListenerInterface]]): + def register_listeners( + self, + mappings: Dict[Type[object], List[Union[EventListenerInterface, AsyncEventListenerInterface]]] + ): """Register additional listeners.""" pass @@ -99,13 +102,16 @@ def dispatch(self, event: object): listener.listen(event) async def dispatch_async(self, event: object): - with self._async_lock: + async with self._async_lock: listeners = self._event_listeners_mapping.get(type(event), []) for listener in listeners: await listener.listen(event) - def register_listeners(self, event_listeners: Dict[Type[object], List[EventListenerInterface]]): + def register_listeners( + self, + event_listeners: Dict[Type[object], List[Union[EventListenerInterface, AsyncEventListenerInterface]]] + ): with self._lock: for event_type in event_listeners: if event_type in self._event_listeners_mapping: @@ -271,6 +277,9 @@ def commands(self) -> tuple: def exception(self) -> Exception: return self._exception +class AsyncOnCommandsFailEvent(OnCommandsFailEvent): + pass + class ReAuthConnectionListener(EventListenerInterface): """ Listener that performs re-authentication of given connection. diff --git a/redis/multidb/circuit.py b/redis/multidb/circuit.py index 221dc556a3..8f904c0e4b 100644 --- a/redis/multidb/circuit.py +++ b/redis/multidb/circuit.py @@ -45,6 +45,11 @@ def database(self, database): """Set database associated with this circuit.""" pass + @abstractmethod + def on_state_changed(self, cb: Callable[["CircuitBreaker", State, State], None]): + """Callback called when the state of the circuit changes.""" + pass + class BaseCircuitBreaker(CircuitBreaker): """ Base implementation of Circuit Breaker interface. 
@@ -82,12 +87,8 @@ def database(self): def database(self, database): self._database = database -class SyncCircuitBreaker(CircuitBreaker): - """ - Synchronous implementation of Circuit Breaker interface. - """ @abstractmethod - def on_state_changed(self, cb: Callable[["SyncCircuitBreaker", State, State], None]): + def on_state_changed(self, cb: Callable[["CircuitBreaker", State, State], None]): """Callback called when the state of the circuit changes.""" pass @@ -95,7 +96,7 @@ class PBListener(pybreaker.CircuitBreakerListener): """Wrapper for callback to be compatible with pybreaker implementation.""" def __init__( self, - cb: Callable[[SyncCircuitBreaker, State, State], None], + cb: Callable[[CircuitBreaker, State, State], None], database, ): """ @@ -116,7 +117,7 @@ def state_change(self, cb, old_state, new_state): new_state = State(value=new_state.name) self._cb(cb, old_state, new_state) -class PBCircuitBreakerAdapter(SyncCircuitBreaker, BaseCircuitBreaker): +class PBCircuitBreakerAdapter(BaseCircuitBreaker): def __init__(self, cb: pybreaker.CircuitBreaker): """ Initialize a PBCircuitBreakerAdapter instance. @@ -129,6 +130,6 @@ def __init__(self, cb: pybreaker.CircuitBreaker): """ super().__init__(cb) - def on_state_changed(self, cb: Callable[["SyncCircuitBreaker", State, State], None]): + def on_state_changed(self, cb: Callable[["CircuitBreaker", State, State], None]): listener = PBListener(cb, self.database) self._cb.add_listener(listener) \ No newline at end of file diff --git a/redis/multidb/client.py b/redis/multidb/client.py index 8a0e006977..71e079346a 100644 --- a/redis/multidb/client.py +++ b/redis/multidb/client.py @@ -5,7 +5,7 @@ from redis.commands import RedisModuleCommands, CoreCommands from redis.multidb.command_executor import DefaultCommandExecutor from redis.multidb.config import MultiDbConfig, DEFAULT_GRACE_PERIOD -from redis.multidb.circuit import State as CBState, SyncCircuitBreaker +from redis.multidb.circuit import State as CBState, CircuitBreaker from redis.multidb.database import Database, Databases, SyncDatabase from redis.multidb.exception import NoValidDatabaseException from redis.multidb.failure_detector import FailureDetector @@ -244,7 +244,7 @@ def _check_databases_health(self, on_error: Callable[[Exception], None] = None): for database, _ in self._databases: self._check_db_health(database, on_error) - def _on_circuit_state_change_callback(self, circuit: SyncCircuitBreaker, old_state: CBState, new_state: CBState): + def _on_circuit_state_change_callback(self, circuit: CircuitBreaker, old_state: CBState, new_state: CBState): if new_state == CBState.HALF_OPEN: self._check_db_health(circuit.database) return @@ -252,7 +252,7 @@ def _on_circuit_state_change_callback(self, circuit: SyncCircuitBreaker, old_sta if old_state == CBState.CLOSED and new_state == CBState.OPEN: self._bg_scheduler.run_once(DEFAULT_GRACE_PERIOD, _half_open_circuit, circuit) -def _half_open_circuit(circuit: SyncCircuitBreaker): +def _half_open_circuit(circuit: CircuitBreaker): circuit.state = CBState.HALF_OPEN diff --git a/redis/multidb/config.py b/redis/multidb/config.py index a966ec329a..fc349ed04b 100644 --- a/redis/multidb/config.py +++ b/redis/multidb/config.py @@ -9,7 +9,7 @@ from redis.backoff import ExponentialWithJitterBackoff, AbstractBackoff, NoBackoff from redis.data_structure import WeightedList from redis.event import EventDispatcher, EventDispatcherInterface -from redis.multidb.circuit import PBCircuitBreakerAdapter, SyncCircuitBreaker +from redis.multidb.circuit import 
PBCircuitBreakerAdapter, CircuitBreaker from redis.multidb.database import Database, Databases from redis.multidb.failure_detector import FailureDetector, CommandFailureDetector from redis.multidb.healthcheck import HealthCheck, EchoHealthCheck, DEFAULT_HEALTH_CHECK_RETRIES, \ @@ -44,7 +44,7 @@ class DatabaseConfig: client_kwargs (dict): Additional parameters for the database client connection. from_url (Optional[str]): Redis URL way of connecting to the database. from_pool (Optional[ConnectionPool]): A pre-configured connection pool to use. - circuit (Optional[SyncCircuitBreaker]): Custom circuit breaker implementation. + circuit (Optional[CircuitBreaker]): Custom circuit breaker implementation. grace_period (float): Grace period after which we need to check if the circuit could be closed again. health_check_url (Optional[str]): URL for health checks. Cluster FQDN is typically used on public Redis Enterprise endpoints. @@ -57,11 +57,11 @@ class DatabaseConfig: client_kwargs: dict = field(default_factory=dict) from_url: Optional[str] = None from_pool: Optional[ConnectionPool] = None - circuit: Optional[SyncCircuitBreaker] = None + circuit: Optional[CircuitBreaker] = None grace_period: float = DEFAULT_GRACE_PERIOD health_check_url: Optional[str] = None - def default_circuit_breaker(self) -> SyncCircuitBreaker: + def default_circuit_breaker(self) -> CircuitBreaker: circuit_breaker = pybreaker.CircuitBreaker(reset_timeout=self.grace_period) return PBCircuitBreakerAdapter(circuit_breaker) diff --git a/redis/multidb/database.py b/redis/multidb/database.py index 75a662d904..9c2ffe3552 100644 --- a/redis/multidb/database.py +++ b/redis/multidb/database.py @@ -5,7 +5,7 @@ from redis import RedisCluster from redis.data_structure import WeightedList -from redis.multidb.circuit import SyncCircuitBreaker +from redis.multidb.circuit import CircuitBreaker from redis.typing import Number class AbstractDatabase(ABC): @@ -74,13 +74,13 @@ def client(self, client: Union[redis.Redis, RedisCluster]): @property @abstractmethod - def circuit(self) -> SyncCircuitBreaker: + def circuit(self) -> CircuitBreaker: """Circuit breaker for the current database.""" pass @circuit.setter @abstractmethod - def circuit(self, circuit: SyncCircuitBreaker): + def circuit(self, circuit: CircuitBreaker): """Set the circuit breaker for the current database.""" pass @@ -90,7 +90,7 @@ class Database(BaseDatabase, SyncDatabase): def __init__( self, client: Union[redis.Redis, RedisCluster], - circuit: SyncCircuitBreaker, + circuit: CircuitBreaker, weight: float, health_check_url: Optional[str] = None, ): @@ -117,9 +117,9 @@ def client(self, client: Union[redis.Redis, RedisCluster]): self._client = client @property - def circuit(self) -> SyncCircuitBreaker: + def circuit(self) -> CircuitBreaker: return self._cb @circuit.setter - def circuit(self, circuit: SyncCircuitBreaker): + def circuit(self, circuit: CircuitBreaker): self._cb = circuit \ No newline at end of file diff --git a/redis/utils.py b/redis/utils.py index 94bfab61bb..1800582e46 100644 --- a/redis/utils.py +++ b/redis/utils.py @@ -314,3 +314,9 @@ def dummy_fail(): Fake function for a Retry object if you don't need to handle each failure. """ pass + +async def dummy_fail_async(): + """ + Async fake function for a Retry object if you don't need to handle each failure. 
+ """ + pass \ No newline at end of file diff --git a/tests/test_asyncio/test_multidb/__init__.py b/tests/test_asyncio/test_multidb/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/test_asyncio/test_multidb/conftest.py b/tests/test_asyncio/test_multidb/conftest.py new file mode 100644 index 0000000000..0ac231cf52 --- /dev/null +++ b/tests/test_asyncio/test_multidb/conftest.py @@ -0,0 +1,108 @@ +from unittest.mock import Mock + +import pytest + +from redis.asyncio.multidb.config import MultiDbConfig, DEFAULT_HEALTH_CHECK_INTERVAL, DEFAULT_AUTO_FALLBACK_INTERVAL, \ + DatabaseConfig +from redis.asyncio.multidb.failover import AsyncFailoverStrategy +from redis.asyncio.multidb.failure_detector import AsyncFailureDetector +from redis.asyncio.multidb.healthcheck import HealthCheck +from redis.data_structure import WeightedList +from redis.multidb.circuit import State as CBState, CircuitBreaker +from redis.asyncio import Redis +from redis.asyncio.multidb.database import Database, Databases + + +@pytest.fixture() +def mock_client() -> Redis: + return Mock(spec=Redis) + +@pytest.fixture() +def mock_cb() -> CircuitBreaker: + return Mock(spec=CircuitBreaker) + +@pytest.fixture() +def mock_fd() -> AsyncFailureDetector: + return Mock(spec=AsyncFailureDetector) + +@pytest.fixture() +def mock_fs() -> AsyncFailoverStrategy: + return Mock(spec=AsyncFailoverStrategy) + +@pytest.fixture() +def mock_hc() -> HealthCheck: + return Mock(spec=HealthCheck) + +@pytest.fixture() +def mock_db(request) -> Database: + db = Mock(spec=Database) + db.weight = request.param.get("weight", 1.0) + db.client = Mock(spec=Redis) + + cb = request.param.get("circuit", {}) + mock_cb = Mock(spec=CircuitBreaker) + mock_cb.grace_period = cb.get("grace_period", 1.0) + mock_cb.state = cb.get("state", CBState.CLOSED) + + db.circuit = mock_cb + return db + +@pytest.fixture() +def mock_db1(request) -> Database: + db = Mock(spec=Database) + db.weight = request.param.get("weight", 1.0) + db.client = Mock(spec=Redis) + + cb = request.param.get("circuit", {}) + mock_cb = Mock(spec=CircuitBreaker) + mock_cb.grace_period = cb.get("grace_period", 1.0) + mock_cb.state = cb.get("state", CBState.CLOSED) + + db.circuit = mock_cb + return db + +@pytest.fixture() +def mock_db2(request) -> Database: + db = Mock(spec=Database) + db.weight = request.param.get("weight", 1.0) + db.client = Mock(spec=Redis) + + cb = request.param.get("circuit", {}) + mock_cb = Mock(spec=CircuitBreaker) + mock_cb.grace_period = cb.get("grace_period", 1.0) + mock_cb.state = cb.get("state", CBState.CLOSED) + + db.circuit = mock_cb + return db + +@pytest.fixture() +def mock_multi_db_config( + request, mock_fd, mock_fs, mock_hc, mock_ed +) -> MultiDbConfig: + hc_interval = request.param.get('hc_interval', None) + if hc_interval is None: + hc_interval = DEFAULT_HEALTH_CHECK_INTERVAL + + auto_fallback_interval = request.param.get('auto_fallback_interval', None) + if auto_fallback_interval is None: + auto_fallback_interval = DEFAULT_AUTO_FALLBACK_INTERVAL + + config = MultiDbConfig( + databases_config=[Mock(spec=DatabaseConfig)], + failure_detectors=[mock_fd], + health_check_interval=hc_interval, + failover_strategy=mock_fs, + auto_fallback_interval=auto_fallback_interval, + event_dispatcher=mock_ed + ) + + return config + + +def create_weighted_list(*databases) -> Databases: + dbs = WeightedList() + + for db in databases: + dbs.add(db, db.weight) + + return dbs \ No newline at end of file diff --git a/tests/test_asyncio/test_multidb/test_client.py 
b/tests/test_asyncio/test_multidb/test_client.py new file mode 100644 index 0000000000..c2fe914e9f --- /dev/null +++ b/tests/test_asyncio/test_multidb/test_client.py @@ -0,0 +1,471 @@ +import asyncio +from unittest.mock import patch, AsyncMock, Mock + +import pybreaker +import pytest + +from redis.asyncio.multidb.client import MultiDBClient +from redis.asyncio.multidb.config import DEFAULT_FAILOVER_RETRIES, DEFAULT_FAILOVER_BACKOFF +from redis.asyncio.multidb.database import AsyncDatabase +from redis.asyncio.multidb.failover import WeightBasedFailoverStrategy +from redis.asyncio.multidb.failure_detector import AsyncFailureDetector +from redis.asyncio.multidb.healthcheck import EchoHealthCheck, DEFAULT_HEALTH_CHECK_RETRIES, \ + DEFAULT_HEALTH_CHECK_BACKOFF, HealthCheck +from redis.asyncio.retry import Retry +from redis.event import EventDispatcher, AsyncOnCommandsFailEvent +from redis.multidb.circuit import State as CBState, PBCircuitBreakerAdapter +from redis.multidb.exception import NoValidDatabaseException +from tests.test_asyncio.test_multidb.conftest import create_weighted_list + + +class TestMultiDbClient: + @pytest.mark.asyncio + @pytest.mark.parametrize( + 'mock_multi_db_config,mock_db, mock_db1, mock_db2', + [ + ( + {}, + {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, + ), + ], + indirect=True, + ) + async def test_execute_command_against_correct_db_on_successful_initialization( + self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc + ): + databases = create_weighted_list(mock_db, mock_db1, mock_db2) + + with patch.object(mock_multi_db_config,'databases',return_value=databases), \ + patch.object(mock_multi_db_config,'default_health_checks', return_value=[mock_hc]): + mock_db1.client.execute_command = AsyncMock(return_value='OK1') + + mock_hc.check_health.return_value = True + + client = MultiDBClient(mock_multi_db_config) + assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 + assert await client.set('key', 'value') == 'OK1' + assert mock_hc.check_health.call_count == 3 + + assert mock_db.circuit.state == CBState.CLOSED + assert mock_db1.circuit.state == CBState.CLOSED + assert mock_db2.circuit.state == CBState.CLOSED + + @pytest.mark.asyncio + @pytest.mark.parametrize( + 'mock_multi_db_config,mock_db, mock_db1, mock_db2', + [ + ( + {}, + {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.7, 'circuit': {'state': CBState.OPEN}}, + ), + ], + indirect=True, + ) + async def test_execute_command_against_correct_db_and_closed_circuit( + self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc + ): + databases = create_weighted_list(mock_db, mock_db1, mock_db2) + + with patch.object(mock_multi_db_config,'databases',return_value=databases), \ + patch.object(mock_multi_db_config,'default_health_checks', return_value=[mock_hc]): + mock_db1.client.execute_command = AsyncMock(return_value='OK1') + + mock_hc.check_health.side_effect = [False, True, True] + + client = MultiDBClient(mock_multi_db_config) + assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 + assert await client.set('key', 'value') == 'OK1' + assert mock_hc.check_health.call_count == 3 + + assert mock_db.circuit.state == CBState.CLOSED + assert mock_db1.circuit.state == CBState.CLOSED + assert mock_db2.circuit.state == CBState.OPEN + + @pytest.mark.asyncio + 
@pytest.mark.parametrize( + 'mock_multi_db_config,mock_db, mock_db1, mock_db2', + [ + ( + {}, + {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, + ), + ], + indirect=True, + ) + async def test_execute_command_against_correct_db_on_background_health_check_determine_active_db_unhealthy( + self, mock_multi_db_config, mock_db, mock_db1, mock_db2 + ): + cb = PBCircuitBreakerAdapter(pybreaker.CircuitBreaker(reset_timeout=5)) + cb.database = mock_db + mock_db.circuit = cb + + cb1 = PBCircuitBreakerAdapter(pybreaker.CircuitBreaker(reset_timeout=5)) + cb1.database = mock_db1 + mock_db1.circuit = cb1 + + cb2 = PBCircuitBreakerAdapter(pybreaker.CircuitBreaker(reset_timeout=5)) + cb2.database = mock_db2 + mock_db2.circuit = cb2 + + databases = create_weighted_list(mock_db, mock_db1, mock_db2) + + with patch.object(mock_multi_db_config,'databases',return_value=databases), \ + patch.object(mock_multi_db_config,'default_health_checks', return_value=[EchoHealthCheck( + retry=Retry(retries=DEFAULT_HEALTH_CHECK_RETRIES, backoff=DEFAULT_HEALTH_CHECK_BACKOFF) + )]): + mock_db.client.execute_command.side_effect = ['healthcheck', 'healthcheck', 'healthcheck', 'OK', 'error'] + mock_db1.client.execute_command.side_effect = ['healthcheck', 'OK1', 'error', 'error', 'healthcheck', 'OK1'] + mock_db2.client.execute_command.side_effect = ['healthcheck', 'healthcheck', 'OK2', 'error', 'error'] + mock_multi_db_config.health_check_interval = 0.1 + mock_multi_db_config.failover_strategy = WeightBasedFailoverStrategy( + retry=Retry(retries=DEFAULT_FAILOVER_RETRIES, backoff=DEFAULT_FAILOVER_BACKOFF) + ) + + client = MultiDBClient(mock_multi_db_config) + assert await client.set('key', 'value') == 'OK1' + await asyncio.sleep(0.15) + assert await client.set('key', 'value') == 'OK2' + await asyncio.sleep(0.1) + assert await client.set('key', 'value') == 'OK' + await asyncio.sleep(0.1) + assert await client.set('key', 'value') == 'OK1' + + @pytest.mark.asyncio + @pytest.mark.parametrize( + 'mock_multi_db_config,mock_db, mock_db1, mock_db2', + [ + ( + {}, + {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, + ), + ], + indirect=True, + ) + async def test_execute_command_auto_fallback_to_highest_weight_db( + self, mock_multi_db_config, mock_db, mock_db1, mock_db2 + ): + databases = create_weighted_list(mock_db, mock_db1, mock_db2) + + with patch.object(mock_multi_db_config,'databases',return_value=databases), \ + patch.object(mock_multi_db_config,'default_health_checks', return_value=[EchoHealthCheck( + retry=Retry(retries=DEFAULT_HEALTH_CHECK_RETRIES, backoff=DEFAULT_HEALTH_CHECK_BACKOFF) + )]): + mock_db.client.execute_command.side_effect = ['healthcheck', 'healthcheck', 'healthcheck', 'healthcheck', 'healthcheck'] + mock_db1.client.execute_command.side_effect = ['healthcheck', 'OK1', 'error', 'healthcheck', 'healthcheck', 'OK1'] + mock_db2.client.execute_command.side_effect = ['healthcheck', 'healthcheck', 'OK2', 'healthcheck', 'healthcheck', 'healthcheck'] + mock_multi_db_config.health_check_interval = 0.1 + mock_multi_db_config.auto_fallback_interval = 0.2 + mock_multi_db_config.failover_strategy = WeightBasedFailoverStrategy( + retry=Retry(retries=DEFAULT_FAILOVER_RETRIES, backoff=DEFAULT_FAILOVER_BACKOFF) + ) + + client = MultiDBClient(mock_multi_db_config) + assert await client.set('key', 
'value') == 'OK1' + await asyncio.sleep(0.15) + assert await client.set('key', 'value') == 'OK2' + await asyncio.sleep(0.22) + assert await client.set('key', 'value') == 'OK1' + + @pytest.mark.asyncio + @pytest.mark.parametrize( + 'mock_multi_db_config,mock_db, mock_db1, mock_db2', + [ + ( + {}, + {'weight': 0.2, 'circuit': {'state': CBState.OPEN}}, + {'weight': 0.7, 'circuit': {'state': CBState.OPEN}}, + {'weight': 0.5, 'circuit': {'state': CBState.OPEN}}, + ), + ], + indirect=True, + ) + async def test_execute_command_throws_exception_on_failed_initialization( + self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc + ): + databases = create_weighted_list(mock_db, mock_db1, mock_db2) + + with patch.object(mock_multi_db_config,'databases',return_value=databases), \ + patch.object(mock_multi_db_config,'default_health_checks', return_value=[mock_hc]): + mock_hc.check_health.return_value = False + + client = MultiDBClient(mock_multi_db_config) + assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 + + with pytest.raises(NoValidDatabaseException, match='Initial connection failed - no active database found'): + await client.set('key', 'value') + assert mock_hc.check_health.call_count == 3 + + @pytest.mark.asyncio + @pytest.mark.parametrize( + 'mock_multi_db_config,mock_db, mock_db1, mock_db2', + [ + ( + {}, + {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, + ), + ], + indirect=True, + ) + async def test_add_database_throws_exception_on_same_database( + self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc + ): + databases = create_weighted_list(mock_db, mock_db1, mock_db2) + + with patch.object(mock_multi_db_config,'databases',return_value=databases), \ + patch.object(mock_multi_db_config,'default_health_checks', return_value=[mock_hc]): + mock_hc.check_health.return_value = False + + client = MultiDBClient(mock_multi_db_config) + assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 + + with pytest.raises(ValueError, match='Given database already exists'): + await client.add_database(mock_db) + assert mock_hc.check_health.call_count == 3 + + @pytest.mark.asyncio + @pytest.mark.parametrize( + 'mock_multi_db_config,mock_db, mock_db1, mock_db2', + [ + ( + {}, + {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, + ), + ], + indirect=True, + ) + async def test_add_database_makes_new_database_active( + self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc + ): + databases = create_weighted_list(mock_db, mock_db2) + + with patch.object(mock_multi_db_config,'databases',return_value=databases), \ + patch.object(mock_multi_db_config,'default_health_checks', return_value=[mock_hc]): + mock_db1.client.execute_command.return_value = 'OK1' + mock_db2.client.execute_command.return_value = 'OK2' + + mock_hc.check_health.return_value = True + + client = MultiDBClient(mock_multi_db_config) + assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 + + assert await client.set('key', 'value') == 'OK2' + assert mock_hc.check_health.call_count == 2 + + await client.add_database(mock_db1) + assert mock_hc.check_health.call_count == 3 + + assert await client.set('key', 'value') == 'OK1' + + @pytest.mark.asyncio + @pytest.mark.parametrize( + 'mock_multi_db_config,mock_db, mock_db1, 
mock_db2', + [ + ( + {}, + {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, + ), + ], + indirect=True, + ) + async def test_remove_highest_weighted_database( + self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc + ): + databases = create_weighted_list(mock_db, mock_db1, mock_db2) + + with patch.object(mock_multi_db_config,'databases',return_value=databases), \ + patch.object(mock_multi_db_config,'default_health_checks', return_value=[mock_hc]): + mock_db1.client.execute_command.return_value = 'OK1' + mock_db2.client.execute_command.return_value = 'OK2' + + mock_hc.check_health.return_value = True + + client = MultiDBClient(mock_multi_db_config) + assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 + + assert await client.set('key', 'value') == 'OK1' + assert mock_hc.check_health.call_count == 3 + + await client.remove_database(mock_db1) + assert await client.set('key', 'value') == 'OK2' + + @pytest.mark.asyncio + @pytest.mark.parametrize( + 'mock_multi_db_config,mock_db, mock_db1, mock_db2', + [ + ( + {}, + {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, + ), + ], + indirect=True, + ) + async def test_update_database_weight_to_be_highest( + self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc + ): + databases = create_weighted_list(mock_db, mock_db1, mock_db2) + + with patch.object(mock_multi_db_config,'databases',return_value=databases), \ + patch.object(mock_multi_db_config,'default_health_checks', return_value=[mock_hc]): + mock_db1.client.execute_command.return_value = 'OK1' + mock_db2.client.execute_command.return_value = 'OK2' + + mock_hc.check_health.return_value = True + + client = MultiDBClient(mock_multi_db_config) + assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 + + assert await client.set('key', 'value') == 'OK1' + assert mock_hc.check_health.call_count == 3 + + await client.update_database_weight(mock_db2, 0.8) + assert mock_db2.weight == 0.8 + + assert await client.set('key', 'value') == 'OK2' + + @pytest.mark.asyncio + @pytest.mark.parametrize( + 'mock_multi_db_config,mock_db, mock_db1, mock_db2', + [ + ( + {}, + {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, + ), + ], + indirect=True, + ) + async def test_add_new_failure_detector( + self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc + ): + databases = create_weighted_list(mock_db, mock_db1, mock_db2) + + with patch.object(mock_multi_db_config,'databases',return_value=databases), \ + patch.object(mock_multi_db_config,'default_health_checks', return_value=[mock_hc]): + mock_db1.client.execute_command.return_value = 'OK1' + mock_multi_db_config.event_dispatcher = EventDispatcher() + mock_fd = mock_multi_db_config.failure_detectors[0] + + # Event fired if command against mock_db1 would fail + command_fail_event = AsyncOnCommandsFailEvent( + commands=('SET', 'key', 'value'), + exception=Exception(), + ) + + mock_hc.check_health.return_value = True + + client = MultiDBClient(mock_multi_db_config) + assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 + assert await client.set('key', 'value') == 'OK1' + assert mock_hc.check_health.call_count == 3 + + # Simulate 
failing command events that lead to a failure detection + for i in range(5): + await mock_multi_db_config.event_dispatcher.dispatch_async(command_fail_event) + + assert mock_fd.register_failure.call_count == 5 + + another_fd = Mock(spec=AsyncFailureDetector) + client.add_failure_detector(another_fd) + + # Simulate failing command events that lead to a failure detection + for i in range(5): + await mock_multi_db_config.event_dispatcher.dispatch_async(command_fail_event) + + assert mock_fd.register_failure.call_count == 10 + assert another_fd.register_failure.call_count == 5 + + @pytest.mark.asyncio + @pytest.mark.parametrize( + 'mock_multi_db_config,mock_db, mock_db1, mock_db2', + [ + ( + {}, + {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, + ), + ], + indirect=True, + ) + async def test_add_new_health_check( + self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc + ): + databases = create_weighted_list(mock_db, mock_db1, mock_db2) + + with patch.object(mock_multi_db_config,'databases',return_value=databases), \ + patch.object(mock_multi_db_config,'default_health_checks', return_value=[mock_hc]): + mock_db1.client.execute_command.return_value = 'OK1' + + mock_hc.check_health.return_value = True + + client = MultiDBClient(mock_multi_db_config) + assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 + assert await client.set('key', 'value') == 'OK1' + assert mock_hc.check_health.call_count == 3 + + another_hc = Mock(spec=HealthCheck) + another_hc.check_health.return_value = True + + await client.add_health_check(another_hc) + await client._check_db_health(mock_db1) + + assert mock_hc.check_health.call_count == 4 + assert another_hc.check_health.call_count == 1 + + @pytest.mark.asyncio + @pytest.mark.parametrize( + 'mock_multi_db_config,mock_db, mock_db1, mock_db2', + [ + ( + {}, + {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, + ), + ], + indirect=True, + ) + async def test_set_active_database( + self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc + ): + databases = create_weighted_list(mock_db, mock_db1, mock_db2) + + with patch.object(mock_multi_db_config,'databases',return_value=databases), \ + patch.object(mock_multi_db_config,'default_health_checks', return_value=[mock_hc]): + mock_db1.client.execute_command.return_value = 'OK1' + mock_db.client.execute_command.return_value = 'OK' + + mock_hc.check_health.return_value = True + + client = MultiDBClient(mock_multi_db_config) + assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 + assert await client.set('key', 'value') == 'OK1' + assert mock_hc.check_health.call_count == 3 + + await client.set_active_database(mock_db) + assert await client.set('key', 'value') == 'OK' + + with pytest.raises(ValueError, match='Given database is not a member of database list'): + await client.set_active_database(Mock(spec=AsyncDatabase)) + + mock_hc.check_health.return_value = False + + with pytest.raises(NoValidDatabaseException, match='Cannot set active database, database is unhealthy'): + await client.set_active_database(mock_db1) \ No newline at end of file diff --git a/tests/test_asyncio/test_multidb/test_command_executor.py b/tests/test_asyncio/test_multidb/test_command_executor.py new file mode 100644 index 0000000000..3f64e6aa0b --- /dev/null 
+++ b/tests/test_asyncio/test_multidb/test_command_executor.py @@ -0,0 +1,165 @@ +import asyncio +from unittest.mock import AsyncMock + +import pytest + +from redis.asyncio.multidb.failure_detector import FailureDetectorAsyncWrapper +from redis.event import EventDispatcher +from redis.exceptions import ConnectionError +from redis.asyncio.multidb.command_executor import DefaultCommandExecutor +from redis.asyncio.retry import Retry +from redis.backoff import NoBackoff +from redis.multidb.circuit import State as CBState +from redis.multidb.failure_detector import CommandFailureDetector +from tests.test_asyncio.test_multidb.conftest import create_weighted_list + + +class TestDefaultCommandExecutor: + @pytest.mark.asyncio + @pytest.mark.parametrize( + 'mock_db,mock_db1,mock_db2', + [ + ( + {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, + ), + ], + indirect=True, + ) + async def test_execute_command_on_active_database(self, mock_db, mock_db1, mock_db2, mock_fd, mock_fs, mock_ed): + mock_db1.client.execute_command = AsyncMock(return_value='OK1') + mock_db2.client.execute_command = AsyncMock(return_value='OK2') + databases = create_weighted_list(mock_db, mock_db1, mock_db2) + + executor = DefaultCommandExecutor( + failure_detectors=[mock_fd], + databases=databases, + failover_strategy=mock_fs, + event_dispatcher=mock_ed, + command_retry=Retry(NoBackoff(), 0) + ) + + await executor.set_active_database(mock_db1) + assert await executor.execute_command('SET', 'key', 'value') == 'OK1' + + await executor.set_active_database(mock_db2) + assert await executor.execute_command('SET', 'key', 'value') == 'OK2' + assert mock_ed.register_listeners.call_count == 1 + + @pytest.mark.asyncio + @pytest.mark.parametrize( + 'mock_db,mock_db1,mock_db2', + [ + ( + {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, + ), + ], + indirect=True, + ) + async def test_execute_command_automatically_select_active_database( + self, mock_db, mock_db1, mock_db2, mock_fd, mock_fs, mock_ed + ): + mock_db1.client.execute_command = AsyncMock(return_value='OK1') + mock_db2.client.execute_command = AsyncMock(return_value='OK2') + mock_selector = AsyncMock(side_effect=[mock_db1, mock_db2]) + type(mock_fs).database = mock_selector + databases = create_weighted_list(mock_db, mock_db1, mock_db2) + + executor = DefaultCommandExecutor( + failure_detectors=[mock_fd], + databases=databases, + failover_strategy=mock_fs, + event_dispatcher=mock_ed, + command_retry=Retry(NoBackoff(), 0) + ) + + assert await executor.execute_command('SET', 'key', 'value') == 'OK1' + mock_db1.circuit.state = CBState.OPEN + + assert await executor.execute_command('SET', 'key', 'value') == 'OK2' + assert mock_ed.register_listeners.call_count == 1 + assert mock_selector.call_count == 2 + + @pytest.mark.asyncio + @pytest.mark.parametrize( + 'mock_db,mock_db1,mock_db2', + [ + ( + {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, + ), + ], + indirect=True, + ) + async def test_execute_command_fallback_to_another_db_after_fallback_interval( + self, mock_db, mock_db1, mock_db2, mock_fd, mock_fs, mock_ed + ): + mock_db1.client.execute_command = AsyncMock(return_value='OK1') + mock_db2.client.execute_command = 
AsyncMock(return_value='OK2') + mock_selector = AsyncMock(side_effect=[mock_db1, mock_db2, mock_db1]) + type(mock_fs).database = mock_selector + databases = create_weighted_list(mock_db, mock_db1, mock_db2) + + executor = DefaultCommandExecutor( + failure_detectors=[mock_fd], + databases=databases, + failover_strategy=mock_fs, + event_dispatcher=mock_ed, + auto_fallback_interval=0.1, + command_retry=Retry(NoBackoff(), 0) + ) + + assert await executor.execute_command('SET', 'key', 'value') == 'OK1' + mock_db1.weight = 0.1 + await asyncio.sleep(0.15) + + assert await executor.execute_command('SET', 'key', 'value') == 'OK2' + mock_db1.weight = 0.7 + await asyncio.sleep(0.15) + + assert await executor.execute_command('SET', 'key', 'value') == 'OK1' + assert mock_ed.register_listeners.call_count == 1 + assert mock_selector.call_count == 3 + + @pytest.mark.asyncio + @pytest.mark.parametrize( + 'mock_db,mock_db1,mock_db2', + [ + ( + {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, + ), + ], + indirect=True, + ) + async def test_execute_command_fallback_to_another_db_after_failure_detection( + self, mock_db, mock_db1, mock_db2, mock_fs + ): + mock_db1.client.execute_command = AsyncMock(side_effect=['OK1', ConnectionError, ConnectionError, ConnectionError, 'OK1']) + mock_db2.client.execute_command = AsyncMock(side_effect=['OK2', ConnectionError, ConnectionError, ConnectionError]) + mock_selector = AsyncMock(side_effect=[mock_db1, mock_db2, mock_db1]) + type(mock_fs).database = mock_selector + threshold = 3 + fd = FailureDetectorAsyncWrapper(CommandFailureDetector(threshold, 1)) + ed = EventDispatcher() + databases = create_weighted_list(mock_db, mock_db1, mock_db2) + + executor = DefaultCommandExecutor( + failure_detectors=[fd], + databases=databases, + failover_strategy=mock_fs, + event_dispatcher=ed, + auto_fallback_interval=0.1, + command_retry=Retry(NoBackoff(), threshold), + ) + fd.set_command_executor(command_executor=executor) + + assert await executor.execute_command('SET', 'key', 'value') == 'OK1' + assert await executor.execute_command('SET', 'key', 'value') == 'OK2' + assert await executor.execute_command('SET', 'key', 'value') == 'OK1' + assert mock_selector.call_count == 3 \ No newline at end of file diff --git a/tests/test_asyncio/test_multidb/test_config.py b/tests/test_asyncio/test_multidb/test_config.py new file mode 100644 index 0000000000..64760740a1 --- /dev/null +++ b/tests/test_asyncio/test_multidb/test_config.py @@ -0,0 +1,125 @@ +from unittest.mock import Mock + +from redis.asyncio import ConnectionPool +from redis.asyncio.multidb.config import DatabaseConfig, MultiDbConfig, DEFAULT_GRACE_PERIOD, \ + DEFAULT_HEALTH_CHECK_INTERVAL, DEFAULT_AUTO_FALLBACK_INTERVAL +from redis.asyncio.multidb.database import Database +from redis.asyncio.multidb.failover import WeightBasedFailoverStrategy, AsyncFailoverStrategy +from redis.asyncio.multidb.failure_detector import FailureDetectorAsyncWrapper, AsyncFailureDetector +from redis.asyncio.multidb.healthcheck import EchoHealthCheck, HealthCheck +from redis.asyncio.retry import Retry +from redis.multidb.circuit import CircuitBreaker + + +class TestMultiDbConfig: + def test_default_config(self): + db_configs = [ + DatabaseConfig(client_kwargs={'host': 'host1', 'port': 'port1'}, weight=1.0), + DatabaseConfig(client_kwargs={'host': 'host2', 'port': 'port2'}, weight=0.9), + DatabaseConfig(client_kwargs={'host': 'host3', 
'port': 'port3'}, weight=0.8), + ] + + config = MultiDbConfig( + databases_config=db_configs + ) + + assert config.databases_config == db_configs + databases = config.databases() + assert len(databases) == 3 + + i = 0 + for db, weight in databases: + assert isinstance(db, Database) + assert weight == db_configs[i].weight + assert db.circuit.grace_period == DEFAULT_GRACE_PERIOD + assert db.client.get_retry() is not config.command_retry + i+=1 + + assert len(config.default_failure_detectors()) == 1 + assert isinstance(config.default_failure_detectors()[0], FailureDetectorAsyncWrapper) + assert len(config.default_health_checks()) == 1 + assert isinstance(config.default_health_checks()[0], EchoHealthCheck) + assert config.health_check_interval == DEFAULT_HEALTH_CHECK_INTERVAL + assert isinstance(config.default_failover_strategy(), WeightBasedFailoverStrategy) + assert config.auto_fallback_interval == DEFAULT_AUTO_FALLBACK_INTERVAL + assert isinstance(config.command_retry, Retry) + + def test_overridden_config(self): + grace_period = 2 + mock_connection_pools = [Mock(spec=ConnectionPool), Mock(spec=ConnectionPool), Mock(spec=ConnectionPool)] + mock_connection_pools[0].connection_kwargs = {} + mock_connection_pools[1].connection_kwargs = {} + mock_connection_pools[2].connection_kwargs = {} + mock_cb1 = Mock(spec=CircuitBreaker) + mock_cb1.grace_period = grace_period + mock_cb2 = Mock(spec=CircuitBreaker) + mock_cb2.grace_period = grace_period + mock_cb3 = Mock(spec=CircuitBreaker) + mock_cb3.grace_period = grace_period + mock_failure_detectors = [Mock(spec=AsyncFailureDetector), Mock(spec=AsyncFailureDetector)] + mock_health_checks = [Mock(spec=HealthCheck), Mock(spec=HealthCheck)] + health_check_interval = 10 + mock_failover_strategy = Mock(spec=AsyncFailoverStrategy) + auto_fallback_interval = 10 + db_configs = [ + DatabaseConfig( + client_kwargs={"connection_pool": mock_connection_pools[0]}, weight=1.0, circuit=mock_cb1 + ), + DatabaseConfig( + client_kwargs={"connection_pool": mock_connection_pools[1]}, weight=0.9, circuit=mock_cb2 + ), + DatabaseConfig( + client_kwargs={"connection_pool": mock_connection_pools[2]}, weight=0.8, circuit=mock_cb3 + ), + ] + + config = MultiDbConfig( + databases_config=db_configs, + failure_detectors=mock_failure_detectors, + health_checks=mock_health_checks, + health_check_interval=health_check_interval, + failover_strategy=mock_failover_strategy, + auto_fallback_interval=auto_fallback_interval, + ) + + assert config.databases_config == db_configs + databases = config.databases() + assert len(databases) == 3 + + i = 0 + for db, weight in databases: + assert isinstance(db, Database) + assert weight == db_configs[i].weight + assert db.client.connection_pool == mock_connection_pools[i] + assert db.circuit.grace_period == grace_period + i+=1 + + assert len(config.failure_detectors) == 2 + assert config.failure_detectors[0] == mock_failure_detectors[0] + assert config.failure_detectors[1] == mock_failure_detectors[1] + assert len(config.health_checks) == 2 + assert config.health_checks[0] == mock_health_checks[0] + assert config.health_checks[1] == mock_health_checks[1] + assert config.health_check_interval == health_check_interval + assert config.failover_strategy == mock_failover_strategy + assert config.auto_fallback_interval == auto_fallback_interval + +class TestDatabaseConfig: + def test_default_config(self): + config = DatabaseConfig(client_kwargs={'host': 'host1', 'port': 'port1'}, weight=1.0) + + assert config.client_kwargs == {'host': 'host1', 'port': 
'port1'} + assert config.weight == 1.0 + assert isinstance(config.default_circuit_breaker(), CircuitBreaker) + + def test_overridden_config(self): + mock_connection_pool = Mock(spec=ConnectionPool) + mock_circuit = Mock(spec=CircuitBreaker) + + config = DatabaseConfig( + client_kwargs={'connection_pool': mock_connection_pool}, weight=1.0, circuit=mock_circuit + ) + + assert config.client_kwargs == {'connection_pool': mock_connection_pool} + assert config.weight == 1.0 + assert config.circuit == mock_circuit \ No newline at end of file diff --git a/tests/test_asyncio/test_multidb/test_failover.py b/tests/test_asyncio/test_multidb/test_failover.py new file mode 100644 index 0000000000..f692c40643 --- /dev/null +++ b/tests/test_asyncio/test_multidb/test_failover.py @@ -0,0 +1,121 @@ +from unittest.mock import PropertyMock + +import pytest + +from redis.backoff import NoBackoff, ExponentialBackoff +from redis.data_structure import WeightedList +from redis.multidb.circuit import State as CBState +from redis.multidb.exception import NoValidDatabaseException +from redis.asyncio.multidb.failover import WeightBasedFailoverStrategy +from redis.asyncio.retry import Retry + + +class TestAsyncWeightBasedFailoverStrategy: + @pytest.mark.asyncio + @pytest.mark.parametrize( + 'mock_db,mock_db1,mock_db2', + [ + ( + {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, + ), + ( + {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.7, 'circuit': {'state': CBState.OPEN}}, + ), + ], + ids=['all closed - highest weight', 'highest weight - open'], + indirect=True, + ) + async def test_get_valid_database(self, mock_db, mock_db1, mock_db2): + retry = Retry(NoBackoff(), 0) + databases = WeightedList() + databases.add(mock_db, mock_db.weight) + databases.add(mock_db1, mock_db1.weight) + databases.add(mock_db2, mock_db2.weight) + + strategy = WeightBasedFailoverStrategy(retry=retry) + strategy.set_databases(databases) + + assert await strategy.database() == mock_db1 + + @pytest.mark.asyncio + @pytest.mark.parametrize( + 'mock_db,mock_db1,mock_db2', + [ + ( + {'weight': 0.2, 'circuit': {'state': CBState.OPEN}}, + {'weight': 0.7, 'circuit': {'state': CBState.OPEN}}, + {'weight': 0.5, 'circuit': {'state': CBState.OPEN}}, + ), + ], + indirect=True, + ) + async def test_get_valid_database_with_retries(self, mock_db, mock_db1, mock_db2): + state_mock = PropertyMock( + side_effect=[CBState.OPEN, CBState.OPEN, CBState.OPEN, CBState.CLOSED] + ) + type(mock_db.circuit).state = state_mock + + retry = Retry(ExponentialBackoff(cap=1), 3) + databases = WeightedList() + databases.add(mock_db, mock_db.weight) + databases.add(mock_db1, mock_db1.weight) + databases.add(mock_db2, mock_db2.weight) + failover_strategy = WeightBasedFailoverStrategy(retry=retry) + failover_strategy.set_databases(databases) + + assert await failover_strategy.database() == mock_db + assert state_mock.call_count == 4 + + @pytest.mark.asyncio + @pytest.mark.parametrize( + 'mock_db,mock_db1,mock_db2', + [ + ( + {'weight': 0.2, 'circuit': {'state': CBState.OPEN}}, + {'weight': 0.7, 'circuit': {'state': CBState.OPEN}}, + {'weight': 0.5, 'circuit': {'state': CBState.OPEN}}, + ), + ], + indirect=True, + ) + async def test_get_valid_database_throws_exception_with_retries(self, mock_db, mock_db1, mock_db2): + state_mock = PropertyMock( + side_effect=[CBState.OPEN, CBState.OPEN, 
CBState.OPEN, CBState.OPEN] + ) + type(mock_db.circuit).state = state_mock + + retry = Retry(ExponentialBackoff(cap=1), 3) + databases = WeightedList() + databases.add(mock_db, mock_db.weight) + databases.add(mock_db1, mock_db1.weight) + databases.add(mock_db2, mock_db2.weight) + failover_strategy = WeightBasedFailoverStrategy(retry=retry) + failover_strategy.set_databases(databases) + + with pytest.raises(NoValidDatabaseException, match='No valid database available for communication'): + assert await failover_strategy.database() + + assert state_mock.call_count == 4 + + @pytest.mark.asyncio + @pytest.mark.parametrize( + 'mock_db,mock_db1,mock_db2', + [ + ( + {'weight': 0.2, 'circuit': {'state': CBState.OPEN}}, + {'weight': 0.7, 'circuit': {'state': CBState.OPEN}}, + {'weight': 0.5, 'circuit': {'state': CBState.OPEN}}, + ), + ], + indirect=True, + ) + async def test_throws_exception_on_empty_databases(self, mock_db, mock_db1, mock_db2): + retry = Retry(NoBackoff(), 0) + failover_strategy = WeightBasedFailoverStrategy(retry=retry) + + with pytest.raises(NoValidDatabaseException, match='No valid database available for communication'): + assert await failover_strategy.database() \ No newline at end of file diff --git a/tests/test_asyncio/test_multidb/test_failure_detector.py b/tests/test_asyncio/test_multidb/test_failure_detector.py new file mode 100644 index 0000000000..3c1eb4fabd --- /dev/null +++ b/tests/test_asyncio/test_multidb/test_failure_detector.py @@ -0,0 +1,153 @@ +import asyncio +from unittest.mock import Mock + +import pytest + +from redis.asyncio.multidb.command_executor import AsyncCommandExecutor +from redis.asyncio.multidb.failure_detector import FailureDetectorAsyncWrapper +from redis.multidb.circuit import State as CBState +from redis.multidb.failure_detector import CommandFailureDetector + + +class TestFailureDetectorAsyncWrapper: + @pytest.mark.asyncio + @pytest.mark.parametrize( + 'mock_db', + [ + {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, + ], + indirect=True, + ) + async def test_failure_detector_open_circuit_on_threshold_exceed_and_interval_not_exceed(self, mock_db): + fd = FailureDetectorAsyncWrapper(CommandFailureDetector(5, 1)) + mock_ce = Mock(spec=AsyncCommandExecutor) + mock_ce.active_database = mock_db + fd.set_command_executor(mock_ce) + assert mock_db.circuit.state == CBState.CLOSED + + await fd.register_failure(Exception(), ('SET', 'key1', 'value1')) + await fd.register_failure(Exception(), ('SET', 'key1', 'value1')) + await fd.register_failure(Exception(), ('SET', 'key1', 'value1')) + await fd.register_failure(Exception(), ('SET', 'key1', 'value1')) + await fd.register_failure(Exception(), ('SET', 'key1', 'value1')) + + assert mock_db.circuit.state == CBState.OPEN + + @pytest.mark.asyncio + @pytest.mark.parametrize( + 'mock_db', + [ + {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, + ], + indirect=True, + ) + async def test_failure_detector_do_not_open_circuit_if_threshold_not_exceed_and_interval_not_exceed(self, mock_db): + fd = FailureDetectorAsyncWrapper(CommandFailureDetector(5, 1)) + mock_ce = Mock(spec=AsyncCommandExecutor) + mock_ce.active_database = mock_db + fd.set_command_executor(mock_ce) + assert mock_db.circuit.state == CBState.CLOSED + + await fd.register_failure(Exception(), ('SET', 'key1', 'value1')) + await fd.register_failure(Exception(), ('SET', 'key1', 'value1')) + await fd.register_failure(Exception(), ('SET', 'key1', 'value1')) + await fd.register_failure(Exception(), ('SET', 'key1', 'value1')) + + assert 
mock_db.circuit.state == CBState.CLOSED + + @pytest.mark.asyncio + @pytest.mark.parametrize( + 'mock_db', + [ + {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, + ], + indirect=True, + ) + async def test_failure_detector_do_not_open_circuit_on_threshold_exceed_and_interval_exceed(self, mock_db): + fd = FailureDetectorAsyncWrapper(CommandFailureDetector(5, 0.3)) + mock_ce = Mock(spec=AsyncCommandExecutor) + mock_ce.active_database = mock_db + fd.set_command_executor(mock_ce) + assert mock_db.circuit.state == CBState.CLOSED + + await fd.register_failure(Exception(), ('SET', 'key1', 'value1')) + await asyncio.sleep(0.1) + await fd.register_failure(Exception(), ('SET', 'key1', 'value1')) + await asyncio.sleep(0.1) + await fd.register_failure(Exception(), ('SET', 'key1', 'value1')) + await asyncio.sleep(0.1) + await fd.register_failure(Exception(), ('SET', 'key1', 'value1')) + await asyncio.sleep(0.1) + await fd.register_failure(Exception(), ('SET', 'key1', 'value1')) + + assert mock_db.circuit.state == CBState.CLOSED + + # 4 more failures as the last one already refreshed timer + await fd.register_failure(Exception(), ('SET', 'key1', 'value1')) + await fd.register_failure(Exception(), ('SET', 'key1', 'value1')) + await fd.register_failure(Exception(), ('SET', 'key1', 'value1')) + await fd.register_failure(Exception(), ('SET', 'key1', 'value1')) + + assert mock_db.circuit.state == CBState.OPEN + + @pytest.mark.asyncio + @pytest.mark.parametrize( + 'mock_db', + [ + {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, + ], + indirect=True, + ) + async def test_failure_detector_refresh_timer_on_expired_duration(self, mock_db): + fd = FailureDetectorAsyncWrapper(CommandFailureDetector(5, 0.3)) + mock_ce = Mock(spec=AsyncCommandExecutor) + mock_ce.active_database = mock_db + fd.set_command_executor(mock_ce) + assert mock_db.circuit.state == CBState.CLOSED + + await fd.register_failure(Exception(), ('SET', 'key1', 'value1')) + await fd.register_failure(Exception(), ('SET', 'key1', 'value1')) + await fd.register_failure(Exception(), ('SET', 'key1', 'value1')) + await fd.register_failure(Exception(), ('SET', 'key1', 'value1')) + await asyncio.sleep(0.4) + + assert mock_db.circuit.state == CBState.CLOSED + + await fd.register_failure(Exception(), ('SET', 'key1', 'value1')) + await fd.register_failure(Exception(), ('SET', 'key1', 'value1')) + await fd.register_failure(Exception(), ('SET', 'key1', 'value1')) + await fd.register_failure(Exception(), ('SET', 'key1', 'value1')) + + assert mock_db.circuit.state == CBState.CLOSED + await fd.register_failure(Exception(), ('SET', 'key1', 'value1')) + + assert mock_db.circuit.state == CBState.OPEN + + @pytest.mark.asyncio + @pytest.mark.parametrize( + 'mock_db', + [ + {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, + ], + indirect=True, + ) + async def test_failure_detector_open_circuit_on_specific_exception_threshold_exceed(self, mock_db): + fd = FailureDetectorAsyncWrapper(CommandFailureDetector(5, 1, error_types=[ConnectionError])) + mock_ce = Mock(spec=AsyncCommandExecutor) + mock_ce.active_database = mock_db + fd.set_command_executor(mock_ce) + assert mock_db.circuit.state == CBState.CLOSED + + await fd.register_failure(Exception(), ('SET', 'key1', 'value1')) + await fd.register_failure(ConnectionError(), ('SET', 'key1', 'value1')) + await fd.register_failure(ConnectionError(), ('SET', 'key1', 'value1')) + await fd.register_failure(Exception(), ('SET', 'key1', 'value1')) + await fd.register_failure(Exception(), ('SET', 'key1', 'value1')) + + 
assert mock_db.circuit.state == CBState.CLOSED
+
+        await fd.register_failure(ConnectionError(), ('SET', 'key1', 'value1'))
+        await fd.register_failure(ConnectionError(), ('SET', 'key1', 'value1'))
+        await fd.register_failure(ConnectionError(), ('SET', 'key1', 'value1'))
+
+        assert mock_db.circuit.state == CBState.OPEN
\ No newline at end of file
diff --git a/tests/test_asyncio/test_multidb/test_healthcheck.py b/tests/test_asyncio/test_multidb/test_healthcheck.py
new file mode 100644
index 0000000000..fd5c8ec3f0
--- /dev/null
+++ b/tests/test_asyncio/test_multidb/test_healthcheck.py
@@ -0,0 +1,48 @@
+import pytest
+from mock.mock import AsyncMock
+
+from redis.asyncio.multidb.database import Database
+from redis.asyncio.multidb.healthcheck import EchoHealthCheck
+from redis.asyncio.retry import Retry
+from redis.backoff import ExponentialBackoff
+from redis.multidb.circuit import State as CBState
+from redis.exceptions import ConnectionError
+
+
+class TestEchoHealthCheck:
+
+    @pytest.mark.asyncio
+    async def test_database_is_healthy_on_echo_response(self, mock_client, mock_cb):
+        """
+        Mock responses that mix errors with actual responses to ensure that the health
+        check retries according to the given configuration.
+        """
+        mock_client.execute_command = AsyncMock(side_effect=[ConnectionError, ConnectionError, 'healthcheck'])
+        hc = EchoHealthCheck(Retry(backoff=ExponentialBackoff(cap=1.0), retries=3))
+        db = Database(mock_client, mock_cb, 0.9)
+
+        assert await hc.check_health(db) is True
+        assert mock_client.execute_command.call_count == 3
+
+    @pytest.mark.asyncio
+    async def test_database_is_unhealthy_on_incorrect_echo_response(self, mock_client, mock_cb):
+        """
+        Mock responses that mix errors with actual responses to ensure that the health
+        check retries according to the given configuration.
+ """ + mock_client.execute_command = AsyncMock(side_effect=[ConnectionError, ConnectionError, 'wrong']) + hc = EchoHealthCheck(Retry(backoff=ExponentialBackoff(cap=1.0), retries=3)) + db = Database(mock_client, mock_cb, 0.9) + + assert await hc.check_health(db) == False + assert mock_client.execute_command.call_count == 3 + + @pytest.mark.asyncio + async def test_database_close_circuit_on_successful_healthcheck(self, mock_client, mock_cb): + mock_client.execute_command = AsyncMock(side_effect=[ConnectionError, ConnectionError, 'healthcheck']) + mock_cb.state = CBState.HALF_OPEN + hc = EchoHealthCheck(Retry(backoff=ExponentialBackoff(cap=1.0), retries=3)) + db = Database(mock_client, mock_cb, 0.9) + + assert await hc.check_health(db) == True + assert mock_client.execute_command.call_count == 3 \ No newline at end of file diff --git a/tests/test_asyncio/test_scenario/__init__.py b/tests/test_asyncio/test_scenario/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/test_asyncio/test_scenario/conftest.py b/tests/test_asyncio/test_scenario/conftest.py new file mode 100644 index 0000000000..312712ba05 --- /dev/null +++ b/tests/test_asyncio/test_scenario/conftest.py @@ -0,0 +1,88 @@ +import os + +import pytest + +from redis.asyncio import Redis +from redis.asyncio.multidb.client import MultiDBClient +from redis.asyncio.multidb.config import DEFAULT_FAILURES_THRESHOLD, DEFAULT_HEALTH_CHECK_INTERVAL, DatabaseConfig, \ + MultiDbConfig +from redis.asyncio.multidb.event import AsyncActiveDatabaseChanged +from redis.asyncio.retry import Retry +from redis.backoff import ExponentialBackoff +from redis.event import AsyncEventListenerInterface, EventDispatcher +from tests.test_scenario.conftest import get_endpoint_config, extract_cluster_fqdn +from tests.test_scenario.fault_injector_client import FaultInjectorClient + + +class CheckActiveDatabaseChangedListener(AsyncEventListenerInterface): + def __init__(self): + self.is_changed_flag = False + + async def listen(self, event: AsyncActiveDatabaseChanged): + self.is_changed_flag = True + +@pytest.fixture() +def fault_injector_client(): + url = os.getenv("FAULT_INJECTION_API_URL", "http://127.0.0.1:20324") + return FaultInjectorClient(url) + +@pytest.fixture() +def r_multi_db(request) -> tuple[MultiDBClient, CheckActiveDatabaseChangedListener, dict]: + client_class = request.param.get('client_class', Redis) + + if client_class == Redis: + endpoint_config = get_endpoint_config('re-active-active') + else: + endpoint_config = get_endpoint_config('re-active-active-oss-cluster') + + username = endpoint_config.get('username', None) + password = endpoint_config.get('password', None) + failure_threshold = request.param.get('failure_threshold', DEFAULT_FAILURES_THRESHOLD) + command_retry = request.param.get('command_retry', Retry(ExponentialBackoff(cap=2, base=0.05), retries=10)) + + # Retry configuration different for health checks as initial health check require more time in case + # if infrastructure wasn't restored from the previous test. 
+ health_check_interval = request.param.get('health_check_interval', DEFAULT_HEALTH_CHECK_INTERVAL) + event_dispatcher = EventDispatcher() + listener = CheckActiveDatabaseChangedListener() + event_dispatcher.register_listeners({ + AsyncActiveDatabaseChanged: [listener], + }) + db_configs = [] + + db_config = DatabaseConfig( + weight=1.0, + from_url=endpoint_config['endpoints'][0], + client_kwargs={ + 'username': username, + 'password': password, + 'decode_responses': True, + }, + health_check_url=extract_cluster_fqdn(endpoint_config['endpoints'][0]) + ) + db_configs.append(db_config) + + db_config1 = DatabaseConfig( + weight=0.9, + from_url=endpoint_config['endpoints'][1], + client_kwargs={ + 'username': username, + 'password': password, + 'decode_responses': True, + }, + health_check_url=extract_cluster_fqdn(endpoint_config['endpoints'][1]) + ) + db_configs.append(db_config1) + + config = MultiDbConfig( + client_class=client_class, + databases_config=db_configs, + command_retry=command_retry, + failure_threshold=failure_threshold, + health_check_retries=3, + health_check_interval=health_check_interval, + event_dispatcher=event_dispatcher, + health_check_backoff=ExponentialBackoff(cap=5, base=0.5), + ) + + return MultiDBClient(config), listener, endpoint_config \ No newline at end of file diff --git a/tests/test_asyncio/test_scenario/test_active_active.py b/tests/test_asyncio/test_scenario/test_active_active.py new file mode 100644 index 0000000000..518e9561d9 --- /dev/null +++ b/tests/test_asyncio/test_scenario/test_active_active.py @@ -0,0 +1,59 @@ +import asyncio +import logging +from time import sleep + +import pytest + +from tests.test_scenario.fault_injector_client import ActionRequest, ActionType + +logger = logging.getLogger(__name__) + +async def trigger_network_failure_action(fault_injector_client, config, event: asyncio.Event = None): + action_request = ActionRequest( + action_type=ActionType.NETWORK_FAILURE, + parameters={"bdb_id": config['bdb_id'], "delay": 2, "cluster_index": 0} + ) + + result = fault_injector_client.trigger_action(action_request) + status_result = fault_injector_client.get_action_status(result['action_id']) + + while status_result['status'] != "success": + await asyncio.sleep(0.1) + status_result = fault_injector_client.get_action_status(result['action_id']) + logger.info(f"Waiting for action to complete. Status: {status_result['status']}") + + if event: + event.set() + + logger.info(f"Action completed. Status: {status_result['status']}") + +class TestActiveActive: + + def teardown_method(self, method): + # Timeout so the cluster could recover from network failure. + sleep(5) + + @pytest.mark.parametrize( + "r_multi_db", + [{"failure_threshold": 2}], + indirect=True + ) + @pytest.mark.timeout(50) + async def test_multi_db_client_failover_to_another_db(self, r_multi_db, fault_injector_client): + r_multi_db, listener, config = r_multi_db + + event = asyncio.Event() + asyncio.create_task(trigger_network_failure_action(fault_injector_client, config, event)) + + # Client initialized on the first command. 
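+        # (The first command awaits initialize(), which runs the initial health
+        # checks and starts the recurring background health-check task before
+        # any command is dispatched.)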
+ await r_multi_db.set('key', 'value') + + # Execute commands before network failure + while not event.is_set(): + assert await r_multi_db.get('key') == 'value' + await asyncio.sleep(0.5) + + # Execute commands until database failover + while not listener.is_changed_flag: + assert await r_multi_db.get('key') == 'value' + await asyncio.sleep(0.5) \ No newline at end of file diff --git a/tests/test_background.py b/tests/test_background.py index 4b3a5377c1..ba62e5bdd9 100644 --- a/tests/test_background.py +++ b/tests/test_background.py @@ -1,3 +1,4 @@ +import asyncio from time import sleep import pytest @@ -57,4 +58,36 @@ def callback(arg1: str, arg2: int): sleep(timeout) + assert execute_counter == call_count + + @pytest.mark.asyncio + @pytest.mark.parametrize( + "interval,timeout,call_count", + [ + (0.012, 0.04, 3), + (0.035, 0.04, 1), + (0.045, 0.04, 0), + ] + ) + async def test_run_recurring_async(self, interval, timeout, call_count): + execute_counter = 0 + one = 'arg1' + two = 9999 + + async def callback(arg1: str, arg2: int): + nonlocal execute_counter + nonlocal one + nonlocal two + + execute_counter += 1 + + assert arg1 == one + assert arg2 == two + + scheduler = BackgroundScheduler() + await scheduler.run_recurring_async(interval, callback, one, two) + assert execute_counter == 0 + + await asyncio.sleep(timeout) + assert execute_counter == call_count \ No newline at end of file diff --git a/tests/test_multidb/conftest.py b/tests/test_multidb/conftest.py index 9503d79d9b..0c082f0f17 100644 --- a/tests/test_multidb/conftest.py +++ b/tests/test_multidb/conftest.py @@ -4,7 +4,7 @@ from redis import Redis from redis.data_structure import WeightedList -from redis.multidb.circuit import State as CBState, SyncCircuitBreaker +from redis.multidb.circuit import State as CBState, CircuitBreaker from redis.multidb.config import MultiDbConfig, DatabaseConfig, DEFAULT_HEALTH_CHECK_INTERVAL, \ DEFAULT_AUTO_FALLBACK_INTERVAL from redis.multidb.database import Database, Databases @@ -19,8 +19,8 @@ def mock_client() -> Redis: return Mock(spec=Redis) @pytest.fixture() -def mock_cb() -> SyncCircuitBreaker: - return Mock(spec=SyncCircuitBreaker) +def mock_cb() -> CircuitBreaker: + return Mock(spec=CircuitBreaker) @pytest.fixture() def mock_fd() -> FailureDetector: @@ -41,7 +41,7 @@ def mock_db(request) -> Database: db.client = Mock(spec=Redis) cb = request.param.get("circuit", {}) - mock_cb = Mock(spec=SyncCircuitBreaker) + mock_cb = Mock(spec=CircuitBreaker) mock_cb.grace_period = cb.get("grace_period", 1.0) mock_cb.state = cb.get("state", CBState.CLOSED) @@ -55,7 +55,7 @@ def mock_db1(request) -> Database: db.client = Mock(spec=Redis) cb = request.param.get("circuit", {}) - mock_cb = Mock(spec=SyncCircuitBreaker) + mock_cb = Mock(spec=CircuitBreaker) mock_cb.grace_period = cb.get("grace_period", 1.0) mock_cb.state = cb.get("state", CBState.CLOSED) @@ -69,7 +69,7 @@ def mock_db2(request) -> Database: db.client = Mock(spec=Redis) cb = request.param.get("circuit", {}) - mock_cb = Mock(spec=SyncCircuitBreaker) + mock_cb = Mock(spec=CircuitBreaker) mock_cb.grace_period = cb.get("grace_period", 1.0) mock_cb.state = cb.get("state", CBState.CLOSED) diff --git a/tests/test_multidb/test_circuit.py b/tests/test_multidb/test_circuit.py index f5f39c3f6b..7dc642373b 100644 --- a/tests/test_multidb/test_circuit.py +++ b/tests/test_multidb/test_circuit.py @@ -1,7 +1,7 @@ import pybreaker import pytest -from redis.multidb.circuit import PBCircuitBreakerAdapter, State as CbState, CircuitBreaker, SyncCircuitBreaker +from 
redis.multidb.circuit import PBCircuitBreakerAdapter, State as CbState, CircuitBreaker class TestPBCircuitBreaker: @@ -39,7 +39,7 @@ def test_cb_executes_callback_on_state_changed(self): adapter = PBCircuitBreakerAdapter(cb=pb_circuit) called_count = 0 - def callback(cb: SyncCircuitBreaker, old_state: CbState, new_state: CbState): + def callback(cb: CircuitBreaker, old_state: CbState, new_state: CbState): nonlocal called_count assert old_state == CbState.CLOSED assert new_state == CbState.HALF_OPEN diff --git a/tests/test_multidb/test_client.py b/tests/test_multidb/test_client.py index c7c15fe684..d352c1da92 100644 --- a/tests/test_multidb/test_client.py +++ b/tests/test_multidb/test_client.py @@ -166,13 +166,9 @@ def test_execute_command_auto_fallback_to_highest_weight_db( client = MultiDBClient(mock_multi_db_config) assert client.set('key', 'value') == 'OK1' - sleep(0.15) - assert client.set('key', 'value') == 'OK2' - sleep(0.22) - assert client.set('key', 'value') == 'OK1' @pytest.mark.parametrize( diff --git a/tests/test_multidb/test_config.py b/tests/test_multidb/test_config.py index e428b3ce7a..1ea63a0e14 100644 --- a/tests/test_multidb/test_config.py +++ b/tests/test_multidb/test_config.py @@ -1,6 +1,6 @@ from unittest.mock import Mock from redis.connection import ConnectionPool -from redis.multidb.circuit import PBCircuitBreakerAdapter, SyncCircuitBreaker +from redis.multidb.circuit import PBCircuitBreakerAdapter, CircuitBreaker from redis.multidb.config import MultiDbConfig, DEFAULT_HEALTH_CHECK_INTERVAL, \ DEFAULT_AUTO_FALLBACK_INTERVAL, DatabaseConfig, DEFAULT_GRACE_PERIOD from redis.multidb.database import Database @@ -49,11 +49,11 @@ def test_overridden_config(self): mock_connection_pools[0].connection_kwargs = {} mock_connection_pools[1].connection_kwargs = {} mock_connection_pools[2].connection_kwargs = {} - mock_cb1 = Mock(spec=SyncCircuitBreaker) + mock_cb1 = Mock(spec=CircuitBreaker) mock_cb1.grace_period = grace_period - mock_cb2 = Mock(spec=SyncCircuitBreaker) + mock_cb2 = Mock(spec=CircuitBreaker) mock_cb2.grace_period = grace_period - mock_cb3 = Mock(spec=SyncCircuitBreaker) + mock_cb3 = Mock(spec=CircuitBreaker) mock_cb3.grace_period = grace_period mock_failure_detectors = [Mock(spec=FailureDetector), Mock(spec=FailureDetector)] mock_health_checks = [Mock(spec=HealthCheck), Mock(spec=HealthCheck)] @@ -113,7 +113,7 @@ def test_default_config(self): def test_overridden_config(self): mock_connection_pool = Mock(spec=ConnectionPool) - mock_circuit = Mock(spec=SyncCircuitBreaker) + mock_circuit = Mock(spec=CircuitBreaker) config = DatabaseConfig( client_kwargs={'connection_pool': mock_connection_pool}, weight=1.0, circuit=mock_circuit From 1dfffd201a98311f8a9d328d97b85ae76c175cef Mon Sep 17 00:00:00 2001 From: Vladyslav Vildanov <117659936+vladvildanov@users.noreply.github.com> Date: Wed, 10 Sep 2025 15:41:57 +0300 Subject: [PATCH 11/50] Added pipeline and transaction support for MultiDBClient (#3763) * Extract additional interfaces and abstract classes * Added base async components * Added command executor * Added recurring background tasks with event loop only * Added MultiDBClient * Added scenario and config tests * Added pipeline and transaction support for MultiDBClient * Updated scenario tests to check failover --- redis/asyncio/multidb/client.py | 114 ++++++- redis/asyncio/multidb/command_executor.py | 24 +- .../test_multidb/test_pipeline.py | 321 ++++++++++++++++++ tests/test_asyncio/test_scenario/conftest.py | 8 +- .../test_scenario/test_active_active.py | 130 
+++++++ 5 files changed, 585 insertions(+), 12 deletions(-) create mode 100644 tests/test_asyncio/test_multidb/test_pipeline.py diff --git a/redis/asyncio/multidb/client.py b/redis/asyncio/multidb/client.py index 73eafd9026..1025c4b37b 100644 --- a/redis/asyncio/multidb/client.py +++ b/redis/asyncio/multidb/client.py @@ -1,5 +1,5 @@ import asyncio -from typing import Callable, Optional, Coroutine, Any +from typing import Callable, Optional, Coroutine, Any, List, Union, Awaitable from redis.asyncio.multidb.command_executor import DefaultCommandExecutor from redis.asyncio.multidb.database import AsyncDatabase, Databases @@ -10,6 +10,7 @@ from redis.background import BackgroundScheduler from redis.commands import AsyncRedisModuleCommands, AsyncCoreCommands from redis.multidb.exception import NoValidDatabaseException +from redis.typing import KeyT class MultiDBClient(AsyncRedisModuleCommands, AsyncCoreCommands): @@ -49,6 +50,19 @@ def __init__(self, config: MultiDbConfig): self._hc_lock = asyncio.Lock() self._bg_scheduler = BackgroundScheduler() self._config = config + self._hc_task = None + self._half_open_state_task = None + + async def __aenter__(self: "MultiDBClient") -> "MultiDBClient": + if not self.initialized: + await self.initialize() + return self + + async def __aexit__(self, exc_type, exc_value, traceback): + if self._hc_task: + self._hc_task.cancel() + if self._half_open_state_task: + self._half_open_state_task.cancel() async def initialize(self): """ @@ -61,7 +75,7 @@ async def raise_exception_on_failed_hc(error): await self._check_databases_health(on_error=raise_exception_on_failed_hc) # Starts recurring health checks on the background. - asyncio.create_task(self._bg_scheduler.run_recurring_async( + self._hc_task = asyncio.create_task(self._bg_scheduler.run_recurring_async( self._health_check_interval, self._check_databases_health, )) @@ -180,6 +194,34 @@ async def execute_command(self, *args, **options): return await self.command_executor.execute_command(*args, **options) + def pipeline(self): + """ + Enters into pipeline mode of the client. + """ + return Pipeline(self) + + async def transaction( + self, + func: Callable[["Pipeline"], Union[Any, Awaitable[Any]]], + *watches: KeyT, + shard_hint: Optional[str] = None, + value_from_callable: bool = False, + watch_delay: Optional[float] = None, + ): + """ + Executes callable as transaction. + """ + if not self.initialized: + await self.initialize() + + return await self.command_executor.execute_transaction( + func, + *watches, + shard_hint=shard_hint, + value_from_callable=value_from_callable, + watch_delay=watch_delay, + ) + async def _check_databases_health( self, on_error: Optional[Callable[[Exception], Coroutine[Any, Any, None]]] = None, @@ -227,11 +269,75 @@ def _on_circuit_state_change_callback(self, circuit: CircuitBreaker, old_state: loop = asyncio.get_running_loop() if new_state == CBState.HALF_OPEN: - asyncio.create_task(self._check_db_health(circuit.database)) + self._half_open_state_task = asyncio.create_task(self._check_db_health(circuit.database)) return if old_state == CBState.CLOSED and new_state == CBState.OPEN: loop.call_later(DEFAULT_GRACE_PERIOD, _half_open_circuit, circuit) def _half_open_circuit(circuit: CircuitBreaker): - circuit.state = CBState.HALF_OPEN \ No newline at end of file + circuit.state = CBState.HALF_OPEN + +class Pipeline(AsyncRedisModuleCommands, AsyncCoreCommands): + """ + Pipeline implementation for multiple logical Redis databases. 
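+
+    Commands are buffered locally and only sent, as a single batch, to the
+    currently active database when ``execute()`` is awaited. A minimal usage
+    sketch (``client`` is assumed to be a ``MultiDBClient``):
+
+    >>> async with client.pipeline() as pipe:
+    ...     pipe.set('foo', 'bar')
+    ...     pipe.get('foo')
+    ...     results = await pipe.execute()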
+ """ + def __init__(self, client: MultiDBClient): + self._command_stack = [] + self._client = client + + async def __aenter__(self: "Pipeline") -> "Pipeline": + return self + + async def __aexit__(self, exc_type, exc_value, traceback): + await self.reset() + await self._client.__aexit__(exc_type, exc_value, traceback) + + def __await__(self): + return self._async_self().__await__() + + async def _async_self(self): + return self + + def __len__(self) -> int: + return len(self._command_stack) + + def __bool__(self) -> bool: + """Pipeline instances should always evaluate to True""" + return True + + async def reset(self) -> None: + self._command_stack = [] + + async def aclose(self) -> None: + """Close the pipeline""" + await self.reset() + + def pipeline_execute_command(self, *args, **options) -> "Pipeline": + """ + Stage a command to be executed when execute() is next called + + Returns the current Pipeline object back so commands can be + chained together, such as: + + pipe = pipe.set('foo', 'bar').incr('baz').decr('bang') + + At some other point, you can then run: pipe.execute(), + which will execute all commands queued in the pipe. + """ + self._command_stack.append((args, options)) + return self + + def execute_command(self, *args, **kwargs): + """Adds a command to the stack""" + return self.pipeline_execute_command(*args, **kwargs) + + async def execute(self) -> List[Any]: + """Execute all the commands in the current pipeline""" + if not self._client.initialized: + await self._client.initialize() + + try: + return await self._client.command_executor.execute_pipeline(tuple(self._command_stack)) + finally: + await self.reset() \ No newline at end of file diff --git a/redis/asyncio/multidb/command_executor.py b/redis/asyncio/multidb/command_executor.py index af10a00988..4133dba394 100644 --- a/redis/asyncio/multidb/command_executor.py +++ b/redis/asyncio/multidb/command_executor.py @@ -1,6 +1,6 @@ from abc import abstractmethod from datetime import datetime -from typing import List, Optional, Callable, Any +from typing import List, Optional, Callable, Any, Union, Awaitable from redis.asyncio.client import PubSub, Pipeline from redis.asyncio.multidb.database import Databases, AsyncDatabase, Database @@ -13,6 +13,7 @@ from redis.event import EventDispatcherInterface, AsyncOnCommandsFailEvent from redis.multidb.command_executor import CommandExecutor, BaseCommandExecutor from redis.multidb.config import DEFAULT_AUTO_FALLBACK_INTERVAL +from redis.typing import KeyT class AsyncCommandExecutor(CommandExecutor): @@ -194,17 +195,30 @@ async def callback(): async def execute_pipeline(self, command_stack: tuple): async def callback(): - with self._active_database.client.pipeline() as pipe: + async with self._active_database.client.pipeline() as pipe: for command, options in command_stack: - await pipe.execute_command(*command, **options) + pipe.execute_command(*command, **options) return await pipe.execute() return await self._execute_with_failure_detection(callback, command_stack) - async def execute_transaction(self, transaction: Callable[[Pipeline], None], *watches, **options): + async def execute_transaction( + self, + func: Callable[["Pipeline"], Union[Any, Awaitable[Any]]], + *watches: KeyT, + shard_hint: Optional[str] = None, + value_from_callable: bool = False, + watch_delay: Optional[float] = None, + ): async def callback(): - return await self._active_database.client.transaction(transaction, *watches, **options) + return await self._active_database.client.transaction( + func, + *watches, 
+ shard_hint=shard_hint, + value_from_callable=value_from_callable, + watch_delay=watch_delay + ) return await self._execute_with_failure_detection(callback) diff --git a/tests/test_asyncio/test_multidb/test_pipeline.py b/tests/test_asyncio/test_multidb/test_pipeline.py new file mode 100644 index 0000000000..5af2e3e864 --- /dev/null +++ b/tests/test_asyncio/test_multidb/test_pipeline.py @@ -0,0 +1,321 @@ +import asyncio +from unittest.mock import Mock, AsyncMock, patch + +import pybreaker +import pytest + +from redis.asyncio.client import Pipeline +from redis.asyncio.multidb.client import MultiDBClient +from redis.asyncio.multidb.config import DEFAULT_FAILOVER_RETRIES +from redis.asyncio.multidb.failover import WeightBasedFailoverStrategy +from redis.asyncio.multidb.healthcheck import EchoHealthCheck, DEFAULT_HEALTH_CHECK_RETRIES, \ + DEFAULT_HEALTH_CHECK_BACKOFF +from redis.asyncio.retry import Retry +from redis.multidb.circuit import State as CBState, PBCircuitBreakerAdapter +from redis.multidb.config import DEFAULT_FAILOVER_BACKOFF +from tests.test_asyncio.test_multidb.conftest import create_weighted_list + + +def mock_pipe() -> Pipeline: + mock_pipe = Mock(spec=Pipeline) + mock_pipe.__aenter__ = AsyncMock(return_value=mock_pipe) + mock_pipe.__aexit__ = AsyncMock(return_value=None) + return mock_pipe + +class TestPipeline: + @pytest.mark.asyncio + @pytest.mark.parametrize( + 'mock_multi_db_config,mock_db, mock_db1, mock_db2', + [ + ( + {}, + {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, + ), + ], + indirect=True, + ) + async def test_executes_pipeline_against_correct_db( + self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc + ): + databases = create_weighted_list(mock_db, mock_db1, mock_db2) + + with patch.object(mock_multi_db_config,'databases',return_value=databases), \ + patch.object(mock_multi_db_config,'default_health_checks', return_value=[mock_hc]): + pipe = mock_pipe() + pipe.execute.return_value = ['OK1', 'value1'] + mock_db1.client.pipeline.return_value = pipe + + mock_hc.check_health.return_value = True + + client = MultiDBClient(mock_multi_db_config) + assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 + + pipe = client.pipeline() + pipe.set('key1', 'value1') + pipe.get('key1') + + assert await pipe.execute() == ['OK1', 'value1'] + assert mock_hc.check_health.call_count == 3 + + @pytest.mark.asyncio + @pytest.mark.parametrize( + 'mock_multi_db_config,mock_db, mock_db1, mock_db2', + [ + ( + {}, + {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.7, 'circuit': {'state': CBState.OPEN}}, + ), + ], + indirect=True, + ) + async def test_execute_pipeline_against_correct_db_and_closed_circuit( + self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc + ): + databases = create_weighted_list(mock_db, mock_db1, mock_db2) + + with patch.object(mock_multi_db_config,'databases',return_value=databases), \ + patch.object(mock_multi_db_config,'default_health_checks', return_value=[mock_hc]): + pipe = mock_pipe() + pipe.execute.return_value = ['OK1', 'value1'] + mock_db1.client.pipeline.return_value = pipe + + mock_hc.check_health.side_effect = [False, True, True] + + client = MultiDBClient(mock_multi_db_config) + assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 + + async with client.pipeline() as pipe: + pipe.set('key1', 
'value1') + pipe.get('key1') + + assert await pipe.execute() == ['OK1', 'value1'] + assert mock_hc.check_health.call_count == 3 + + assert mock_db.circuit.state == CBState.CLOSED + assert mock_db1.circuit.state == CBState.CLOSED + assert mock_db2.circuit.state == CBState.OPEN + + @pytest.mark.asyncio + @pytest.mark.parametrize( + 'mock_multi_db_config,mock_db, mock_db1, mock_db2', + [ + ( + {}, + {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, + ), + ], + indirect=True, + ) + async def test_execute_pipeline_against_correct_db_on_background_health_check_determine_active_db_unhealthy( + self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc + ): + cb = PBCircuitBreakerAdapter(pybreaker.CircuitBreaker(reset_timeout=5)) + cb.database = mock_db + mock_db.circuit = cb + + cb1 = PBCircuitBreakerAdapter(pybreaker.CircuitBreaker(reset_timeout=5)) + cb1.database = mock_db1 + mock_db1.circuit = cb1 + + cb2 = PBCircuitBreakerAdapter(pybreaker.CircuitBreaker(reset_timeout=5)) + cb2.database = mock_db2 + mock_db2.circuit = cb2 + + databases = create_weighted_list(mock_db, mock_db1, mock_db2) + + with patch.object(mock_multi_db_config,'databases',return_value=databases), \ + patch.object(mock_multi_db_config,'default_health_checks', return_value=[EchoHealthCheck( + retry=Retry(retries=DEFAULT_HEALTH_CHECK_RETRIES, backoff=DEFAULT_HEALTH_CHECK_BACKOFF) + )]): + mock_db.client.execute_command.side_effect = ['healthcheck', 'healthcheck', 'healthcheck', 'error'] + mock_db1.client.execute_command.side_effect = ['healthcheck', 'error', 'error', 'healthcheck'] + mock_db2.client.execute_command.side_effect = ['healthcheck', 'healthcheck', 'error', 'error'] + + pipe = mock_pipe() + pipe.execute.return_value = ['OK', 'value'] + mock_db.client.pipeline.return_value = pipe + + pipe1 = mock_pipe() + pipe1.execute.return_value = ['OK1', 'value'] + mock_db1.client.pipeline.return_value = pipe1 + + pipe2 = mock_pipe() + pipe2.execute.return_value = ['OK2', 'value'] + mock_db2.client.pipeline.return_value = pipe2 + + mock_multi_db_config.health_check_interval = 0.1 + mock_multi_db_config.failover_strategy = WeightBasedFailoverStrategy( + retry=Retry(retries=DEFAULT_FAILOVER_RETRIES, backoff=DEFAULT_FAILOVER_BACKOFF) + ) + + client = MultiDBClient(mock_multi_db_config) + + async with client.pipeline() as pipe: + pipe.set('key1', 'value') + pipe.get('key1') + + assert await pipe.execute() == ['OK1', 'value'] + + await asyncio.sleep(0.15) + + async with client.pipeline() as pipe: + pipe.set('key1', 'value') + pipe.get('key1') + + assert await pipe.execute() == ['OK2', 'value'] + + await asyncio.sleep(0.1) + + async with client.pipeline() as pipe: + pipe.set('key1', 'value') + pipe.get('key1') + + assert await pipe.execute() == ['OK', 'value'] + + await asyncio.sleep(0.1) + + async with client.pipeline() as pipe: + pipe.set('key1', 'value') + pipe.get('key1') + + assert await pipe.execute() == ['OK1', 'value'] + +class TestTransaction: + @pytest.mark.asyncio + @pytest.mark.parametrize( + 'mock_multi_db_config,mock_db, mock_db1, mock_db2', + [ + ( + {}, + {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, + ), + ], + indirect=True, + ) + async def test_executes_transaction_against_correct_db( + self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc + ): + databases = 
create_weighted_list(mock_db, mock_db1, mock_db2) + + with patch.object(mock_multi_db_config,'databases',return_value=databases), \ + patch.object(mock_multi_db_config,'default_health_checks', return_value=[mock_hc]): + mock_db1.client.transaction.return_value = ['OK1', 'value1'] + + mock_hc.check_health.return_value = True + + client = MultiDBClient(mock_multi_db_config) + assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 + + async def callback(pipe: Pipeline): + pipe.set('key1', 'value1') + pipe.get('key1') + + assert await client.transaction(callback) == ['OK1', 'value1'] + assert mock_hc.check_health.call_count == 3 + + @pytest.mark.asyncio + @pytest.mark.parametrize( + 'mock_multi_db_config,mock_db, mock_db1, mock_db2', + [ + ( + {}, + {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.7, 'circuit': {'state': CBState.OPEN}}, + ), + ], + indirect=True, + ) + async def test_execute_transaction_against_correct_db_and_closed_circuit( + self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc + ): + databases = create_weighted_list(mock_db, mock_db1, mock_db2) + + with patch.object(mock_multi_db_config,'databases',return_value=databases), \ + patch.object(mock_multi_db_config,'default_health_checks', return_value=[mock_hc]): + mock_db1.client.transaction.return_value = ['OK1', 'value1'] + + mock_hc.check_health.side_effect = [False, True, True] + + client = MultiDBClient(mock_multi_db_config) + assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 + + async def callback(pipe: Pipeline): + pipe.set('key1', 'value1') + pipe.get('key1') + + assert await client.transaction(callback) == ['OK1', 'value1'] + assert mock_hc.check_health.call_count == 3 + + assert mock_db.circuit.state == CBState.CLOSED + assert mock_db1.circuit.state == CBState.CLOSED + assert mock_db2.circuit.state == CBState.OPEN + + @pytest.mark.asyncio + @pytest.mark.parametrize( + 'mock_multi_db_config,mock_db, mock_db1, mock_db2', + [ + ( + {}, + {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, + ), + ], + indirect=True, + ) + async def test_execute_transaction_against_correct_db_on_background_health_check_determine_active_db_unhealthy( + self, mock_multi_db_config, mock_db, mock_db1, mock_db2 + ): + cb = PBCircuitBreakerAdapter(pybreaker.CircuitBreaker(reset_timeout=5)) + cb.database = mock_db + mock_db.circuit = cb + + cb1 = PBCircuitBreakerAdapter(pybreaker.CircuitBreaker(reset_timeout=5)) + cb1.database = mock_db1 + mock_db1.circuit = cb1 + + cb2 = PBCircuitBreakerAdapter(pybreaker.CircuitBreaker(reset_timeout=5)) + cb2.database = mock_db2 + mock_db2.circuit = cb2 + + databases = create_weighted_list(mock_db, mock_db1, mock_db2) + + with patch.object(mock_multi_db_config,'databases',return_value=databases), \ + patch.object(mock_multi_db_config,'default_health_checks', return_value=[EchoHealthCheck( + retry=Retry(retries=DEFAULT_HEALTH_CHECK_RETRIES, backoff=DEFAULT_HEALTH_CHECK_BACKOFF) + )]): + mock_db.client.execute_command.side_effect = ['healthcheck', 'healthcheck', 'healthcheck', 'error'] + mock_db1.client.execute_command.side_effect = ['healthcheck', 'error', 'error', 'healthcheck'] + mock_db2.client.execute_command.side_effect = ['healthcheck', 'healthcheck', 'error', 'error'] + + mock_db.client.transaction.return_value = ['OK', 'value'] + 
mock_db1.client.transaction.return_value = ['OK1', 'value'] + mock_db2.client.transaction.return_value = ['OK2', 'value'] + + mock_multi_db_config.health_check_interval = 0.1 + mock_multi_db_config.failover_strategy = WeightBasedFailoverStrategy( + retry=Retry(retries=DEFAULT_FAILOVER_RETRIES, backoff=DEFAULT_FAILOVER_BACKOFF) + ) + + client = MultiDBClient(mock_multi_db_config) + + async def callback(pipe: Pipeline): + pipe.set('key1', 'value1') + pipe.get('key1') + + assert await client.transaction(callback) == ['OK1', 'value'] + await asyncio.sleep(0.15) + assert await client.transaction(callback) == ['OK2', 'value'] + await asyncio.sleep(0.1) + assert await client.transaction(callback) == ['OK', 'value'] + await asyncio.sleep(0.1) + assert await client.transaction(callback) == ['OK1', 'value'] \ No newline at end of file diff --git a/tests/test_asyncio/test_scenario/conftest.py b/tests/test_asyncio/test_scenario/conftest.py index 312712ba05..18bc8f1417 100644 --- a/tests/test_asyncio/test_scenario/conftest.py +++ b/tests/test_asyncio/test_scenario/conftest.py @@ -1,6 +1,7 @@ import os import pytest +import pytest_asyncio from redis.asyncio import Redis from redis.asyncio.multidb.client import MultiDBClient @@ -26,8 +27,8 @@ def fault_injector_client(): url = os.getenv("FAULT_INJECTION_API_URL", "http://127.0.0.1:20324") return FaultInjectorClient(url) -@pytest.fixture() -def r_multi_db(request) -> tuple[MultiDBClient, CheckActiveDatabaseChangedListener, dict]: +@pytest_asyncio.fixture() +async def r_multi_db(request) -> tuple[MultiDBClient, CheckActiveDatabaseChangedListener, dict]: client_class = request.param.get('client_class', Redis) if client_class == Redis: @@ -85,4 +86,5 @@ def r_multi_db(request) -> tuple[MultiDBClient, CheckActiveDatabaseChangedListen health_check_backoff=ExponentialBackoff(cap=5, base=0.5), ) - return MultiDBClient(config), listener, endpoint_config \ No newline at end of file + async with MultiDBClient(config) as client: + return client, listener, endpoint_config \ No newline at end of file diff --git a/tests/test_asyncio/test_scenario/test_active_active.py b/tests/test_asyncio/test_scenario/test_active_active.py index 518e9561d9..f11bfafee3 100644 --- a/tests/test_asyncio/test_scenario/test_active_active.py +++ b/tests/test_asyncio/test_scenario/test_active_active.py @@ -4,6 +4,7 @@ import pytest +from redis.asyncio.client import Pipeline from tests.test_scenario.fault_injector_client import ActionRequest, ActionType logger = logging.getLogger(__name__) @@ -33,6 +34,7 @@ def teardown_method(self, method): # Timeout so the cluster could recover from network failure. sleep(5) + @pytest.mark.asyncio @pytest.mark.parametrize( "r_multi_db", [{"failure_threshold": 2}], @@ -56,4 +58,132 @@ async def test_multi_db_client_failover_to_another_db(self, r_multi_db, fault_in # Execute commands until database failover while not listener.is_changed_flag: assert await r_multi_db.get('key') == 'value' + await asyncio.sleep(0.5) + + @pytest.mark.asyncio + @pytest.mark.parametrize( + "r_multi_db", + [{"failure_threshold": 2}], + indirect=True + ) + @pytest.mark.timeout(50) + async def test_context_manager_pipeline_failover_to_another_db(self, r_multi_db, fault_injector_client): + r_multi_db, listener, config = r_multi_db + + event = asyncio.Event() + asyncio.create_task(trigger_network_failure_action(fault_injector_client,config,event)) + + # Client initialized on first pipe execution. 
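+        # (Entering the `async with` block performs no I/O; commands are queued
+        # locally and the whole stack is only sent on pipe.execute().)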
+ async with r_multi_db.pipeline() as pipe: + pipe.set('{hash}key1', 'value1') + pipe.set('{hash}key2', 'value2') + pipe.set('{hash}key3', 'value3') + pipe.get('{hash}key1') + pipe.get('{hash}key2') + pipe.get('{hash}key3') + assert await pipe.execute() == [True, True, True, 'value1', 'value2', 'value3'] + + # Execute pipeline before network failure + while not event.is_set(): + async with r_multi_db.pipeline() as pipe: + pipe.set('{hash}key1', 'value1') + pipe.set('{hash}key2', 'value2') + pipe.set('{hash}key3', 'value3') + pipe.get('{hash}key1') + pipe.get('{hash}key2') + pipe.get('{hash}key3') + assert await pipe.execute() == [True, True, True, 'value1', 'value2', 'value3'] + await asyncio.sleep(0.5) + + # Execute commands until database failover + while not listener.is_changed_flag: + async with r_multi_db.pipeline() as pipe: + pipe.set('{hash}key1', 'value1') + pipe.set('{hash}key2', 'value2') + pipe.set('{hash}key3', 'value3') + pipe.get('{hash}key1') + pipe.get('{hash}key2') + pipe.get('{hash}key3') + assert await pipe.execute() == [True, True, True, 'value1', 'value2', 'value3'] + await asyncio.sleep(0.5) + + @pytest.mark.asyncio + @pytest.mark.parametrize( + "r_multi_db", + [{"failure_threshold": 2}], + indirect=True + ) + @pytest.mark.timeout(50) + async def test_chaining_pipeline_failover_to_another_db(self, r_multi_db, fault_injector_client): + r_multi_db, listener, config = r_multi_db + + event = asyncio.Event() + asyncio.create_task(trigger_network_failure_action(fault_injector_client,config,event)) + + # Client initialized on first pipe execution. + pipe = r_multi_db.pipeline() + pipe.set('{hash}key1', 'value1') + pipe.set('{hash}key2', 'value2') + pipe.set('{hash}key3', 'value3') + pipe.get('{hash}key1') + pipe.get('{hash}key2') + pipe.get('{hash}key3') + assert await pipe.execute() == [True, True, True, 'value1', 'value2', 'value3'] + + # Execute pipeline before network failure + while not event.is_set(): + pipe = r_multi_db.pipeline() + pipe.set('{hash}key1', 'value1') + pipe.set('{hash}key2', 'value2') + pipe.set('{hash}key3', 'value3') + pipe.get('{hash}key1') + pipe.get('{hash}key2') + pipe.get('{hash}key3') + assert await pipe.execute() == [True, True, True, 'value1', 'value2', 'value3'] + await asyncio.sleep(0.5) + + # Execute pipeline until database failover + while not listener.is_changed_flag: + pipe = r_multi_db.pipeline() + pipe.set('{hash}key1', 'value1') + pipe.set('{hash}key2', 'value2') + pipe.set('{hash}key3', 'value3') + pipe.get('{hash}key1') + pipe.get('{hash}key2') + pipe.get('{hash}key3') + assert await pipe.execute() == [True, True, True, 'value1', 'value2', 'value3'] + await asyncio.sleep(0.5) + + @pytest.mark.asyncio + @pytest.mark.parametrize( + "r_multi_db", + [{"failure_threshold": 2}], + indirect=True + ) + @pytest.mark.timeout(50) + async def test_transaction_failover_to_another_db(self, r_multi_db, fault_injector_client): + r_multi_db, listener, config = r_multi_db + + event = asyncio.Event() + asyncio.create_task(trigger_network_failure_action(fault_injector_client,config,event)) + + async def callback(pipe: Pipeline): + pipe.set('{hash}key1', 'value1') + pipe.set('{hash}key2', 'value2') + pipe.set('{hash}key3', 'value3') + pipe.get('{hash}key1') + pipe.get('{hash}key2') + pipe.get('{hash}key3') + + # Client initialized on first transaction execution. 
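+        # (The callback only queues commands; transaction() runs them under
+        # MULTI/EXEC on the active database and, with value_from_callable left
+        # as False, returns the EXEC reply list asserted below.)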
+        assert await r_multi_db.transaction(callback) == [True, True, True, 'value1', 'value2', 'value3']
+
+        # Execute transaction before network failure
+        while not event.is_set():
+            await r_multi_db.transaction(callback)
+            await asyncio.sleep(0.5)
+
+        # Execute transaction until database failover
+        while not listener.is_changed_flag:
+            assert await r_multi_db.transaction(callback) == [True, True, True, 'value1', 'value2', 'value3']
             await asyncio.sleep(0.5)
\ No newline at end of file

From 4817a262b5c7125478b4b155a85e18b7ae5fd5d7 Mon Sep 17 00:00:00 2001
From: Vladyslav Vildanov <117659936+vladvildanov@users.noreply.github.com>
Date: Thu, 11 Sep 2025 09:53:35 +0300
Subject: [PATCH 12/50] Added pub/sub support for MultiDBClient (#3764)

* Extract additional interfaces and abstract classes
* Added base async components
* Added command executor
* Added recurring background tasks with event loop only
* Added MultiDBClient
* Added scenario and config tests
* Added pipeline and transaction support for MultiDBClient
* Added pub/sub support for MultiDBClient
* Added check for coroutine methods for pub/sub
---
 redis/asyncio/client.py | 12 +-
 redis/asyncio/multidb/client.py | 135 +++++++++++++++++-
 redis/asyncio/multidb/command_executor.py | 18 +--
 redis/multidb/client.py | 8 +-
 .../test_scenario/test_active_active.py | 43 +++++-
 5 files changed, 197 insertions(+), 19 deletions(-)

diff --git a/redis/asyncio/client.py b/redis/asyncio/client.py
index aac409073f..4c000bd2e7 100644
--- a/redis/asyncio/client.py
+++ b/redis/asyncio/client.py
@@ -1191,6 +1191,7 @@ async def run(
         *,
         exception_handler: Optional["PSWorkerThreadExcHandlerT"] = None,
         poll_timeout: float = 1.0,
+        pubsub=None
     ) -> None:
         """Process pub/sub messages using registered callbacks.
@@ -1215,9 +1216,14 @@ async def run(
             await self.connect()
         while True:
             try:
-                await self.get_message(
-                    ignore_subscribe_messages=True, timeout=poll_timeout
-                )
+                if pubsub is None:
+                    await self.get_message(
+                        ignore_subscribe_messages=True, timeout=poll_timeout
+                    )
+                else:
+                    await pubsub.get_message(
+                        ignore_subscribe_messages=True, timeout=poll_timeout
+                    )
             except asyncio.CancelledError:
                 raise
             except BaseException as e:
diff --git a/redis/asyncio/multidb/client.py b/redis/asyncio/multidb/client.py
index 1025c4b37b..7c0bef4f6e 100644
--- a/redis/asyncio/multidb/client.py
+++ b/redis/asyncio/multidb/client.py
@@ -1,6 +1,7 @@
 import asyncio
 from typing import Callable, Optional, Coroutine, Any, List, Union, Awaitable
+from redis.asyncio.client import PubSubHandler
 from redis.asyncio.multidb.command_executor import DefaultCommandExecutor
 from redis.asyncio.multidb.database import AsyncDatabase, Databases
 from redis.asyncio.multidb.failure_detector import AsyncFailureDetector
@@ -10,7 +11,7 @@
 from redis.background import BackgroundScheduler
 from redis.commands import AsyncRedisModuleCommands, AsyncCoreCommands
 from redis.multidb.exception import NoValidDatabaseException
-from redis.typing import KeyT
+from redis.typing import KeyT, EncodableT, ChannelT

 class MultiDBClient(AsyncRedisModuleCommands, AsyncCoreCommands):
@@ -222,6 +223,17 @@ async def transaction(
             watch_delay=watch_delay,
         )

+    async def pubsub(self, **kwargs):
+        """
+        Return a Publish/Subscribe object. With this object, you can
+        subscribe to channels and listen for messages that get published to
+        them.
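+
+        Illustrative flow (``client`` is assumed to be an initialized
+        ``MultiDBClient``):
+
+        >>> pubsub = await client.pubsub()
+        >>> await pubsub.subscribe('my-channel')
+        >>> message = await pubsub.get_message(timeout=1.0)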
+ """ + if not self.initialized: + await self.initialize() + + return PubSub(self, **kwargs) + async def _check_databases_health( self, on_error: Optional[Callable[[Exception], Coroutine[Any, Any, None]]] = None, @@ -340,4 +352,123 @@ async def execute(self) -> List[Any]: try: return await self._client.command_executor.execute_pipeline(tuple(self._command_stack)) finally: - await self.reset() \ No newline at end of file + await self.reset() + +class PubSub: + """ + PubSub object for multi database client. + """ + def __init__(self, client: MultiDBClient, **kwargs): + """Initialize the PubSub object for a multi-database client. + + Args: + client: MultiDBClient instance to use for pub/sub operations + **kwargs: Additional keyword arguments to pass to the underlying pubsub implementation + """ + + self._client = client + self._client.command_executor.pubsub(**kwargs) + + async def __aenter__(self) -> "PubSub": + return self + + async def __aexit__(self, exc_type, exc_value, traceback) -> None: + await self.aclose() + + async def aclose(self): + return await self._client.command_executor.execute_pubsub_method('aclose') + + @property + def subscribed(self) -> bool: + return self._client.command_executor.active_pubsub.subscribed + + async def execute_command(self, *args: EncodableT): + return await self._client.command_executor.execute_pubsub_method('execute_command', *args) + + async def psubscribe(self, *args: ChannelT, **kwargs: PubSubHandler): + """ + Subscribe to channel patterns. Patterns supplied as keyword arguments + expect a pattern name as the key and a callable as the value. A + pattern's callable will be invoked automatically when a message is + received on that pattern rather than producing a message via + ``listen()``. + """ + return await self._client.command_executor.execute_pubsub_method( + 'psubscribe', + *args, + **kwargs + ) + + async def punsubscribe(self, *args: ChannelT): + """ + Unsubscribe from the supplied patterns. If empty, unsubscribe from + all patterns. + """ + return await self._client.command_executor.execute_pubsub_method( + 'punsubscribe', + *args + ) + + async def subscribe(self, *args: ChannelT, **kwargs: Callable): + """ + Subscribe to channels. Channels supplied as keyword arguments expect + a channel name as the key and a callable as the value. A channel's + callable will be invoked automatically when a message is received on + that channel rather than producing a message via ``listen()`` or + ``get_message()``. + """ + return await self._client.command_executor.execute_pubsub_method( + 'subscribe', + *args, + **kwargs + ) + + async def unsubscribe(self, *args): + """ + Unsubscribe from the supplied channels. If empty, unsubscribe from + all channels + """ + return await self._client.command_executor.execute_pubsub_method( + 'unsubscribe', + *args + ) + + async def get_message( + self, ignore_subscribe_messages: bool = False, timeout: Optional[float] = 0.0 + ): + """ + Get the next message if one is available, otherwise None. + + If timeout is specified, the system will wait for `timeout` seconds + before returning. Timeout should be specified as a floating point + number or None to wait indefinitely. 
+ """ + return await self._client.command_executor.execute_pubsub_method( + 'get_message', + ignore_subscribe_messages=ignore_subscribe_messages, timeout=timeout + ) + + async def run( + self, + *, + exception_handler: Optional["PSWorkerThreadExcHandlerT"] = None, + poll_timeout: float = 1.0, + ) -> None: + """Process pub/sub messages using registered callbacks. + + This is the equivalent of :py:meth:`redis.PubSub.run_in_thread` in + redis-py, but it is a coroutine. To launch it as a separate task, use + ``asyncio.create_task``: + + >>> task = asyncio.create_task(pubsub.run()) + + To shut it down, use asyncio cancellation: + + >>> task.cancel() + >>> await task + """ + return await self._client.command_executor.execute_pubsub_run( + exception_handler=exception_handler, + sleep_time=poll_timeout, + pubsub=self + ) \ No newline at end of file diff --git a/redis/asyncio/multidb/command_executor.py b/redis/asyncio/multidb/command_executor.py index 4133dba394..7133955740 100644 --- a/redis/asyncio/multidb/command_executor.py +++ b/redis/asyncio/multidb/command_executor.py @@ -1,4 +1,5 @@ from abc import abstractmethod +from asyncio import iscoroutinefunction from datetime import datetime from typing import List, Optional, Callable, Any, Union, Awaitable @@ -178,14 +179,10 @@ def failover_strategy(self) -> AsyncFailoverStrategy: def command_retry(self) -> Retry: return self._command_retry - async def pubsub(self, **kwargs): - async def callback(): - if self._active_pubsub is None: - self._active_pubsub = self._active_database.client.pubsub(**kwargs) - self._active_pubsub_kwargs = kwargs - return None - - return await self._execute_with_failure_detection(callback) + def pubsub(self, **kwargs): + if self._active_pubsub is None: + self._active_pubsub = self._active_database.client.pubsub(**kwargs) + self._active_pubsub_kwargs = kwargs async def execute_command(self, *args, **options): async def callback(): @@ -225,7 +222,10 @@ async def callback(): async def execute_pubsub_method(self, method_name: str, *args, **kwargs): async def callback(): method = getattr(self.active_pubsub, method_name) - return await method(*args, **kwargs) + if iscoroutinefunction(method): + return await method(*args, **kwargs) + else: + return method(*args, **kwargs) return await self._execute_with_failure_detection(callback, *args) diff --git a/redis/multidb/client.py b/redis/multidb/client.py index 71e079346a..e6b815c76f 100644 --- a/redis/multidb/client.py +++ b/redis/multidb/client.py @@ -337,9 +337,6 @@ def __init__(self, client: MultiDBClient, **kwargs): def __enter__(self) -> "PubSub": return self - def __exit__(self, exc_type, exc_value, traceback) -> None: - self.reset() - def __del__(self) -> None: try: # if this object went out of scope prior to shutting down @@ -350,7 +347,7 @@ def __del__(self) -> None: pass def reset(self) -> None: - pass + return self._client.command_executor.execute_pubsub_method('reset') def close(self) -> None: self.reset() @@ -359,6 +356,9 @@ def close(self) -> None: def subscribed(self) -> bool: return self._client.command_executor.active_pubsub.subscribed + def execute_command(self, *args): + return self._client.command_executor.execute_pubsub_method('execute_command', *args) + def psubscribe(self, *args, **kwargs): """ Subscribe to channel patterns. 
Patterns supplied as keyword arguments
diff --git a/tests/test_asyncio/test_scenario/test_active_active.py b/tests/test_asyncio/test_scenario/test_active_active.py
index f11bfafee3..3204d53fe9 100644
--- a/tests/test_asyncio/test_scenario/test_active_active.py
+++ b/tests/test_asyncio/test_scenario/test_active_active.py
@@ -1,4 +1,5 @@
 import asyncio
+import json
 import logging
 from time import sleep
@@ -186,4 +187,44 @@ async def callback(pipe: Pipeline):
         # Execute transaction until database failover
         while not listener.is_changed_flag:
             assert await r_multi_db.transaction(callback) == [True, True, True, 'value1', 'value2', 'value3']
-            await asyncio.sleep(0.5)
\ No newline at end of file
+            await asyncio.sleep(0.5)
+
+    @pytest.mark.asyncio
+    @pytest.mark.parametrize(
+        "r_multi_db",
+        [{"failure_threshold": 2}],
+        indirect=True
+    )
+    @pytest.mark.timeout(50)
+    async def test_pubsub_failover_to_another_db(self, r_multi_db, fault_injector_client):
+        r_multi_db, listener, config = r_multi_db
+
+        event = asyncio.Event()
+        asyncio.create_task(trigger_network_failure_action(fault_injector_client,config,event))
+
+        data = json.dumps({'message': 'test'})
+        messages_count = 0
+
+        async def handler(message):
+            nonlocal messages_count
+            messages_count += 1
+
+        pubsub = await r_multi_db.pubsub()
+
+        # Assign a handler and run the message loop as a background task.
+        await pubsub.subscribe(**{'test-channel': handler})
+        task = asyncio.create_task(pubsub.run(poll_timeout=0.1))
+
+        # Execute publish before network failure
+        while not event.is_set():
+            await r_multi_db.publish('test-channel', data)
+            await asyncio.sleep(0.5)
+
+        # Execute publish until database failover
+        while not listener.is_changed_flag:
+            await r_multi_db.publish('test-channel', data)
+            await asyncio.sleep(0.5)
+
+        task.cancel()
+        await pubsub.unsubscribe('test-channel')
+        assert messages_count > 1
\ No newline at end of file

From 481d89edac9601904a33fcc90b20b0788bf5f12f Mon Sep 17 00:00:00 2001
From: Vladyslav Vildanov <117659936+vladvildanov@users.noreply.github.com>
Date: Wed, 17 Sep 2025 09:59:44 +0300
Subject: [PATCH 13/50] Added support for Lag-Aware Healthcheck and OSS Cluster API (#3768)

* Extract additional interfaces and abstract classes
* Added base async components
* Added command executor
* Added recurring background tasks with event loop only
* Added MultiDBClient
* Added scenario and config tests
* Added pipeline and transaction support for MultiDBClient
* Added pub/sub support for MultiDBClient
* Added check for coroutine methods for pub/sub
* Added OSS Cluster API support for MultiDBClient
* Added support for Lag-Aware Healthcheck and OSS Cluster API
* Increased timeouts between tests
* Fixed space
---
 redis/asyncio/cluster.py | 4 +-
 redis/asyncio/http/__init__.py | 0
 redis/asyncio/http/http_client.py | 216 +++++++++++++++
 redis/asyncio/multidb/client.py | 4 +
 redis/asyncio/multidb/command_executor.py | 4 +
 redis/asyncio/multidb/healthcheck.py | 106 +++++++-
 redis/http/http_client.py | 10 +-
 redis/multidb/client.py | 4 +
 redis/multidb/healthcheck.py | 2 +-
 .../test_multidb/test_healthcheck.py | 143 +++++++++-
 tests/test_asyncio/test_scenario/conftest.py | 7 +-
 .../test_scenario/test_active_active.py | 245 ++++++++++--------
 tests/test_multidb/test_healthcheck.py | 4 +-
 tests/test_scenario/test_active_active.py | 2 +-
 14 files changed, 627 insertions(+), 124 deletions(-)
 create mode 100644 redis/asyncio/http/__init__.py
 create mode 100644 redis/asyncio/http/http_client.py

diff --git a/redis/asyncio/cluster.py
b/redis/asyncio/cluster.py index 956262696a..f957baa319 100644 --- a/redis/asyncio/cluster.py +++ b/redis/asyncio/cluster.py @@ -404,6 +404,7 @@ def __init__( else: self._event_dispatcher = event_dispatcher + self.startup_nodes = startup_nodes self.nodes_manager = NodesManager( startup_nodes, require_full_coverage, @@ -2199,7 +2200,8 @@ async def _reinitialize_on_error(self, error): await self._pipe.cluster_client.nodes_manager.initialize() self.reinitialize_counter = 0 else: - self._pipe.cluster_client.nodes_manager.update_moved_exception(error) + if type(error) == MovedError: + self._pipe.cluster_client.nodes_manager.update_moved_exception(error) self._executing = False diff --git a/redis/asyncio/http/__init__.py b/redis/asyncio/http/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/redis/asyncio/http/http_client.py b/redis/asyncio/http/http_client.py new file mode 100644 index 0000000000..8f746b0a8b --- /dev/null +++ b/redis/asyncio/http/http_client.py @@ -0,0 +1,216 @@ +import asyncio +from abc import ABC, abstractmethod +from concurrent.futures import ThreadPoolExecutor +from typing import Optional, Mapping, Union, Any +from redis.http.http_client import HttpResponse, HttpClient + +DEFAULT_USER_AGENT = "HttpClient/1.0 (+https://example.invalid)" +DEFAULT_TIMEOUT = 30.0 +RETRY_STATUS_CODES = {429, 500, 502, 503, 504} + +class AsyncHTTPClient(ABC): + @abstractmethod + async def get( + self, + path: str, + params: Optional[Mapping[str, Union[None, str, int, float, bool, list, tuple]]] = None, + headers: Optional[Mapping[str, str]] = None, + timeout: Optional[float] = None, + expect_json: bool = True + ) -> Union[HttpResponse, Any]: + """ + Invoke HTTP GET request.""" + pass + + @abstractmethod + async def delete( + self, + path: str, + params: Optional[Mapping[str, Union[None, str, int, float, bool, list, tuple]]] = None, + headers: Optional[Mapping[str, str]] = None, + timeout: Optional[float] = None, + expect_json: bool = True + ) -> Union[HttpResponse, Any]: + """ + Invoke HTTP DELETE request.""" + pass + + @abstractmethod + async def post( + self, + path: str, + json_body: Optional[Any] = None, + data: Optional[Union[bytes, str]] = None, + params: Optional[Mapping[str, Union[None, str, int, float, bool, list, tuple]]] = None, + headers: Optional[Mapping[str, str]] = None, + timeout: Optional[float] = None, + expect_json: bool = True + ) -> Union[HttpResponse, Any]: + """ + Invoke HTTP POST request.""" + pass + + @abstractmethod + async def put( + self, + path: str, + json_body: Optional[Any] = None, + data: Optional[Union[bytes, str]] = None, + params: Optional[Mapping[str, Union[None, str, int, float, bool, list, tuple]]] = None, + headers: Optional[Mapping[str, str]] = None, + timeout: Optional[float] = None, + expect_json: bool = True + ) -> Union[HttpResponse, Any]: + """ + Invoke HTTP PUT request.""" + pass + + @abstractmethod + async def patch( + self, + path: str, + json_body: Optional[Any] = None, + data: Optional[Union[bytes, str]] = None, + params: Optional[Mapping[str, Union[None, str, int, float, bool, list, tuple]]] = None, + headers: Optional[Mapping[str, str]] = None, + timeout: Optional[float] = None, + expect_json: bool = True + ) -> Union[HttpResponse, Any]: + """ + Invoke HTTP PATCH request.""" + pass + + @abstractmethod + async def request( + self, + method: str, + path: str, + params: Optional[Mapping[str, Union[None, str, int, float, bool, list, tuple]]] = None, + headers: Optional[Mapping[str, str]] = None, + body: Optional[Union[bytes, 
str]] = None, + timeout: Optional[float] = None, + ) -> HttpResponse: + """ + Invoke HTTP request with given method.""" + pass + +class AsyncHTTPClientWrapper(AsyncHTTPClient): + """ + An async wrapper around sync HTTP client with thread pool execution. + """ + def __init__( + self, + client: HttpClient, + max_workers: int = 10 + ) -> None: + """ + Initialize a new HTTP client instance. + + Args: + client: Sync HTTP client instance. + max_workers: Maximum number of concurrent requests. + + The client supports both regular HTTPS with server verification and mutual TLS + authentication. For server verification, provide CA certificate information via + ca_file, ca_path or ca_data. For mutual TLS, additionally provide a client + certificate and key via client_cert_file and client_key_file. + """ + self.client = client + self._executor = ThreadPoolExecutor(max_workers=max_workers) + + async def get( + self, + path: str, + params: Optional[Mapping[str, Union[None, str, int, float, bool, list, tuple]]] = None, + headers: Optional[Mapping[str, str]] = None, + timeout: Optional[float] = None, + expect_json: bool = True + ) -> Union[HttpResponse, Any]: + loop = asyncio.get_event_loop() + return await loop.run_in_executor( + self._executor, + self.client.get, + path, params, headers, timeout, expect_json + ) + + async def delete( + self, + path: str, + params: Optional[Mapping[str, Union[None, str, int, float, bool, list, tuple]]] = None, + headers: Optional[Mapping[str, str]] = None, + timeout: Optional[float] = None, + expect_json: bool = True + ) -> Union[HttpResponse, Any]: + loop = asyncio.get_event_loop() + return await loop.run_in_executor( + self._executor, + self.client.delete, + path, params, headers, timeout, expect_json + ) + + async def post( + self, + path: str, + json_body: Optional[Any] = None, + data: Optional[Union[bytes, str]] = None, + params: Optional[Mapping[str, Union[None, str, int, float, bool, list, tuple]]] = None, + headers: Optional[Mapping[str, str]] = None, + timeout: Optional[float] = None, + expect_json: bool = True + ) -> Union[HttpResponse, Any]: + loop = asyncio.get_event_loop() + return await loop.run_in_executor( + self._executor, + self.client.post, + path, json_body, data, params, headers, timeout, expect_json + ) + + async def put( + self, + path: str, + json_body: Optional[Any] = None, + data: Optional[Union[bytes, str]] = None, + params: Optional[Mapping[str, Union[None, str, int, float, bool, list, tuple]]] = None, + headers: Optional[Mapping[str, str]] = None, + timeout: Optional[float] = None, + expect_json: bool = True + ) -> Union[HttpResponse, Any]: + loop = asyncio.get_event_loop() + return await loop.run_in_executor( + self._executor, + self.client.put, + path, json_body, data, params, headers, timeout, expect_json + ) + + async def patch( + self, + path: str, + json_body: Optional[Any] = None, + data: Optional[Union[bytes, str]] = None, + params: Optional[Mapping[str, Union[None, str, int, float, bool, list, tuple]]] = None, + headers: Optional[Mapping[str, str]] = None, + timeout: Optional[float] = None, + expect_json: bool = True + ) -> Union[HttpResponse, Any]: + loop = asyncio.get_event_loop() + return await loop.run_in_executor( + self._executor, + self.client.patch, + path, json_body, data, params, headers, timeout, expect_json + ) + + async def request( + self, + method: str, + path: str, + params: Optional[Mapping[str, Union[None, str, int, float, bool, list, tuple]]] = None, + headers: Optional[Mapping[str, str]] = None, + body: 
Optional[Union[bytes, str]] = None, + timeout: Optional[float] = None, + ) -> HttpResponse: + loop = asyncio.get_event_loop() + return await loop.run_in_executor( + self._executor, + self.client.request, + method, path, params, headers, body, timeout + ) diff --git a/redis/asyncio/multidb/client.py b/redis/asyncio/multidb/client.py index 7c0bef4f6e..e098a4723b 100644 --- a/redis/asyncio/multidb/client.py +++ b/redis/asyncio/multidb/client.py @@ -1,4 +1,5 @@ import asyncio +import logging from typing import Callable, Optional, Coroutine, Any, List, Union, Awaitable from redis.asyncio.client import PubSubHandler @@ -13,6 +14,7 @@ from redis.multidb.exception import NoValidDatabaseException from redis.typing import KeyT, EncodableT, ChannelT +logger = logging.getLogger(__name__) class MultiDBClient(AsyncRedisModuleCommands, AsyncCoreCommands): """ @@ -274,6 +276,8 @@ async def _check_db_health( database.circuit.state = CBState.OPEN is_healthy = False + logger.exception('Health check failed due to exception', exc_info=e) + if on_error: await on_error(e) diff --git a/redis/asyncio/multidb/command_executor.py b/redis/asyncio/multidb/command_executor.py index 7133955740..d63b19269d 100644 --- a/redis/asyncio/multidb/command_executor.py +++ b/redis/asyncio/multidb/command_executor.py @@ -3,6 +3,7 @@ from datetime import datetime from typing import List, Optional, Callable, Any, Union, Awaitable +from redis.asyncio import RedisCluster from redis.asyncio.client import PubSub, Pipeline from redis.asyncio.multidb.database import Databases, AsyncDatabase, Database from redis.asyncio.multidb.event import AsyncActiveDatabaseChanged, RegisterCommandFailure, \ ResubscribeOnActiveDatabaseChanged from redis.asyncio.multidb.failover import AsyncFailoverStrategy @@ -181,6 +182,9 @@ def command_retry(self) -> Retry: def pubsub(self, **kwargs): if self._active_pubsub is None: + if isinstance(self._active_database.client, RedisCluster): + raise ValueError("PubSub is not supported for RedisCluster") + self._active_pubsub = self._active_database.client.pubsub(**kwargs) self._active_pubsub_kwargs = kwargs diff --git a/redis/asyncio/multidb/healthcheck.py b/redis/asyncio/multidb/healthcheck.py index ccaf285ade..7e5f5a1ec7 100644 --- a/redis/asyncio/multidb/healthcheck.py +++ b/redis/asyncio/multidb/healthcheck.py @@ -1,9 +1,13 @@ import logging from abc import ABC, abstractmethod +from typing import Optional, Tuple, Union from redis.asyncio import Redis +from redis.asyncio.http.http_client import AsyncHTTPClientWrapper, DEFAULT_TIMEOUT from redis.asyncio.retry import Retry +from redis.retry import Retry as SyncRetry from redis.backoff import ExponentialWithJitterBackoff +from redis.http.http_client import HttpClient from redis.utils import dummy_fail_async DEFAULT_HEALTH_CHECK_RETRIES = 3 DEFAULT_HEALTH_CHECK_BACKOFF = ExponentialWithJitterBackoff(cap=10) @@ -67,9 +71,107 @@ async def _returns_echoed_message(self, database) -> bool: # For a cluster, checks if all nodes are healthy. all_nodes = database.client.get_nodes() for node in all_nodes: - actual_message = await node.redis_connection.execute_command("ECHO", "healthcheck") + actual_message = await node.execute_command("ECHO", "healthcheck") if actual_message not in expected_message: return False - return True \ No newline at end of file + return True + +class LagAwareHealthCheck(AbstractHealthCheck): + """ + Health check available for Redis Enterprise deployments. + Verifies via the REST API that the database is healthy, based on the configured lag tolerance.
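+ The check queries the Redis Enterprise REST API (default port 9443) for the bdb whose endpoint matches the database host, then calls that bdb's availability endpoint with the configured lag tolerance.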
+ """ + def __init__( + self, + retry: SyncRetry = SyncRetry(retries=DEFAULT_HEALTH_CHECK_RETRIES, backoff=DEFAULT_HEALTH_CHECK_BACKOFF), + rest_api_port: int = 9443, + lag_aware_tolerance: int = 100, + timeout: float = DEFAULT_TIMEOUT, + auth_basic: Optional[Tuple[str, str]] = None, + verify_tls: bool = True, + # TLS verification (server) options + ca_file: Optional[str] = None, + ca_path: Optional[str] = None, + ca_data: Optional[Union[str, bytes]] = None, + # Mutual TLS (client cert) options + client_cert_file: Optional[str] = None, + client_key_file: Optional[str] = None, + client_key_password: Optional[str] = None, + ): + """ + Initialize LagAwareHealthCheck with the specified parameters. + + Args: + retry: Retry configuration for health checks + rest_api_port: Port number for Redis Enterprise REST API (default: 9443) + lag_aware_tolerance: Tolerance in lag between databases in MS (default: 100) + timeout: Request timeout in seconds (default: DEFAULT_TIMEOUT) + auth_basic: Tuple of (username, password) for basic authentication + verify_tls: Whether to verify TLS certificates (default: True) + ca_file: Path to CA certificate file for TLS verification + ca_path: Path to CA certificates directory for TLS verification + ca_data: CA certificate data as string or bytes + client_cert_file: Path to client certificate file for mutual TLS + client_key_file: Path to client private key file for mutual TLS + client_key_password: Password for encrypted client private key + """ + super().__init__( + retry=retry, + ) + self._http_client = AsyncHTTPClientWrapper( + HttpClient( + timeout=timeout, + auth_basic=auth_basic, + retry=self.retry, + verify_tls=verify_tls, + ca_file=ca_file, + ca_path=ca_path, + ca_data=ca_data, + client_cert_file=client_cert_file, + client_key_file=client_key_file, + client_key_password=client_key_password + ) + ) + self._rest_api_port = rest_api_port + self._lag_aware_tolerance = lag_aware_tolerance + + async def check_health(self, database) -> bool: + if database.health_check_url is None: + raise ValueError( + "Database health check url is not set. Please check DatabaseConfig for the current database." 
+ ) + + if isinstance(database.client, Redis): + db_host = database.client.get_connection_kwargs()["host"] + else: + db_host = database.client.startup_nodes[0].host + + base_url = f"{database.health_check_url}:{self._rest_api_port}" + self._http_client.client.base_url = base_url + + # Find the bdb matching the current database host + matching_bdb = None + for bdb in await self._http_client.get("/v1/bdbs"): + for endpoint in bdb["endpoints"]: + if endpoint['dns_name'] == db_host: + matching_bdb = bdb + break + + # In case the host was set as a public IP + for addr in endpoint['addr']: + if addr == db_host: + matching_bdb = bdb + break + + if matching_bdb is None: + logger.warning("LagAwareHealthCheck failed: Couldn't find a matching bdb") + raise ValueError("Could not find a matching bdb") + + url = (f"/v1/bdbs/{matching_bdb['uid']}/availability" + f"?extend_check=lag&availability_lag_tolerance_ms={self._lag_aware_tolerance}") + await self._http_client.get(url, expect_json=False) + + # Status is checked in the http client; otherwise an HttpError is raised + return True \ No newline at end of file diff --git a/redis/http/http_client.py b/redis/http/http_client.py index 0a2de2e44c..986e773915 100644 --- a/redis/http/http_client.py +++ b/redis/http/http_client.py @@ -68,7 +68,6 @@ class HttpClient: def __init__( self, base_url: str = "", - *, headers: Optional[Mapping[str, str]] = None, timeout: float = DEFAULT_TIMEOUT, retry: Retry = Retry( @@ -131,7 +130,6 @@ def __init__( def get( self, path: str, - *, params: Optional[Mapping[str, Union[None, str, int, float, bool, list, tuple]]] = None, headers: Optional[Mapping[str, str]] = None, timeout: Optional[float] = None, @@ -150,7 +148,6 @@ def get( def delete( self, path: str, - *, params: Optional[Mapping[str, Union[None, str, int, float, bool, list, tuple]]] = None, headers: Optional[Mapping[str, str]] = None, timeout: Optional[float] = None, @@ -169,7 +166,6 @@ def delete( def post( self, path: str, - *, json_body: Optional[Any] = None, data: Optional[Union[bytes, str]] = None, params: Optional[Mapping[str, Union[None, str, int, float, bool, list, tuple]]] = None, @@ -190,7 +186,6 @@ def post( def put( self, path: str, - *, json_body: Optional[Any] = None, data: Optional[Union[bytes, str]] = None, params: Optional[Mapping[str, Union[None, str, int, float, bool, list, tuple]]] = None, @@ -211,7 +206,6 @@ def put( def patch( self, path: str, - *, json_body: Optional[Any] = None, data: Optional[Union[bytes, str]] = None, params: Optional[Mapping[str, Union[None, str, int, float, bool, list, tuple]]] = None, @@ -234,7 +228,6 @@ def request( self, method: str, path: str, - *, params: Optional[Mapping[str, Union[None, str, int, float, bool, list, tuple]]] = None, headers: Optional[Mapping[str, str]] = None, body: Optional[Union[bytes, str]] = None, @@ -319,7 +312,6 @@ def _json_call( self, method: str, path: str, - *, params: Optional[Mapping[str, Union[None, str, int, float, bool, list, tuple]]] = None, headers: Optional[Mapping[str, str]] = None, timeout: Optional[float] = None, @@ -340,7 +332,7 @@ def _json_call( return resp.json() return resp - def _prepare_body(self, *, json_body: Optional[Any] = None, data: Optional[Union[bytes, str]] = None) -> Optional[Union[bytes, str]]: + def _prepare_body(self, json_body: Optional[Any] = None, data: Optional[Union[bytes, str]] = None) -> Optional[Union[bytes, str]]: if json_body is not None and data is not None: raise ValueError("Provide either json_body or data, not both.") if json_body is not None: diff
--git a/redis/multidb/client.py b/redis/multidb/client.py index e6b815c76f..7ed1935c1c 100644 --- a/redis/multidb/client.py +++ b/redis/multidb/client.py @@ -1,3 +1,4 @@ +import logging import threading from typing import List, Any, Callable, Optional @@ -11,6 +12,7 @@ from redis.multidb.failure_detector import FailureDetector from redis.multidb.healthcheck import HealthCheck +logger = logging.getLogger(__name__) class MultiDBClient(RedisModuleCommands, CoreCommands): """ @@ -232,6 +234,8 @@ def _check_db_health(self, database: SyncDatabase, on_error: Callable[[Exception database.circuit.state = CBState.OPEN is_healthy = False + logger.exception('Health check failed due to exception', exc_info=e) + if on_error: on_error(e) diff --git a/redis/multidb/healthcheck.py b/redis/multidb/healthcheck.py index 9818d06e28..5a21918513 100644 --- a/redis/multidb/healthcheck.py +++ b/redis/multidb/healthcheck.py @@ -166,7 +166,7 @@ def check_health(self, database) -> bool: logger.warning("LagAwareHealthCheck failed: Couldn't find a matching bdb") raise ValueError("Could not find a matching bdb") - url = (f"/v1/local/bdbs/{matching_bdb['uid']}/endpoint/availability" + url = (f"/v1/bdbs/{matching_bdb['uid']}/availability" f"?extend_check=lag&availability_lag_tolerance_ms={self._lag_aware_tolerance}") self._http_client.get(url, expect_json=False) diff --git a/tests/test_asyncio/test_multidb/test_healthcheck.py b/tests/test_asyncio/test_multidb/test_healthcheck.py index fd5c8ec3f0..ba6e8c2b7c 100644 --- a/tests/test_asyncio/test_multidb/test_healthcheck.py +++ b/tests/test_asyncio/test_multidb/test_healthcheck.py @@ -1,10 +1,11 @@ import pytest -from mock.mock import AsyncMock +from mock.mock import AsyncMock, MagicMock from redis.asyncio.multidb.database import Database -from redis.asyncio.multidb.healthcheck import EchoHealthCheck +from redis.asyncio.multidb.healthcheck import EchoHealthCheck, LagAwareHealthCheck from redis.asyncio.retry import Retry from redis.backoff import ExponentialBackoff +from redis.http.http_client import HttpError from redis.multidb.circuit import State as CBState from redis.exceptions import ConnectionError @@ -45,4 +46,140 @@ async def test_database_close_circuit_on_successful_healthcheck(self, mock_clien db = Database(mock_client, mock_cb, 0.9) assert await hc.check_health(db) == True - assert mock_client.execute_command.call_count == 3 \ No newline at end of file + assert mock_client.execute_command.call_count == 3 + +class TestLagAwareHealthCheck: + @pytest.mark.asyncio + async def test_database_is_healthy_when_bdb_matches_by_dns_name(self, mock_client, mock_cb): + """ + Ensures health check succeeds when /v1/bdbs contains an endpoint whose dns_name + matches database host, and availability endpoint returns success.
+ """ + host = "db1.example.com" + mock_client.get_connection_kwargs.return_value = {"host": host} + + # Mock HttpClient used inside LagAwareHealthCheck + mock_http = AsyncMock() + mock_http.get.side_effect = [ + # First call: list of bdbs + [ + { + "uid": "bdb-1", + "endpoints": [ + {"dns_name": host, "addr": ["10.0.0.1", "10.0.0.2"]}, + ], + } + ], + # Second call: availability check (no JSON expected) + None, + ] + + hc = LagAwareHealthCheck( + retry=Retry(backoff=ExponentialBackoff(cap=1.0), retries=3), + rest_api_port=1234, lag_aware_tolerance=150 + ) + # Inject our mocked http client + hc._http_client = mock_http + + db = Database(mock_client, mock_cb, 1.0, "https://healthcheck.example.com") + + assert await hc.check_health(db) is True + # Base URL must be set correctly + assert hc._http_client.client.base_url == f"https://healthcheck.example.com:1234" + # Calls: first to list bdbs, then to availability + assert mock_http.get.call_count == 2 + first_call = mock_http.get.call_args_list[0] + second_call = mock_http.get.call_args_list[1] + assert first_call.args[0] == "/v1/bdbs" + assert second_call.args[0] == "/v1/bdbs/bdb-1/availability?extend_check=lag&availability_lag_tolerance_ms=150" + assert second_call.kwargs.get("expect_json") is False + + @pytest.mark.asyncio + async def test_database_is_healthy_when_bdb_matches_by_addr(self, mock_client, mock_cb): + """ + Ensures health check succeeds when endpoint addr list contains the database host. + """ + host_ip = "203.0.113.5" + mock_client.get_connection_kwargs.return_value = {"host": host_ip} + + mock_http = AsyncMock() + mock_http.get.side_effect = [ + [ + { + "uid": "bdb-42", + "endpoints": [ + {"dns_name": "not-matching.example.com", "addr": [host_ip]}, + ], + } + ], + None, + ] + + hc = LagAwareHealthCheck( + retry=Retry(backoff=ExponentialBackoff(cap=1.0), retries=3), + ) + hc._http_client = mock_http + + db = Database(mock_client, mock_cb, 1.0, "https://healthcheck.example.com") + + assert await hc.check_health(db) is True + assert mock_http.get.call_count == 2 + assert mock_http.get.call_args_list[1].args[0] == "/v1/bdbs/bdb-42/availability?extend_check=lag&availability_lag_tolerance_ms=100" + + @pytest.mark.asyncio + async def test_raises_value_error_when_no_matching_bdb(self, mock_client, mock_cb): + """ + Ensures health check raises ValueError when there's no bdb matching the database host. + """ + host = "db2.example.com" + mock_client.get_connection_kwargs.return_value = {"host": host} + + mock_http = AsyncMock() + # Return bdbs that do not match host by dns_name nor addr + mock_http.get.return_value = [ + {"uid": "a", "endpoints": [{"dns_name": "other.example.com", "addr": ["10.0.0.9"]}]}, + {"uid": "b", "endpoints": [{"dns_name": "another.example.com", "addr": ["10.0.0.10"]}]}, + ] + + hc = LagAwareHealthCheck( + retry=Retry(backoff=ExponentialBackoff(cap=1.0), retries=3), + ) + hc._http_client = mock_http + + db = Database(mock_client, mock_cb, 1.0, "https://healthcheck.example.com") + + with pytest.raises(ValueError, match="Could not find a matching bdb"): + await hc.check_health(db) + + # Only the listing call should have happened + mock_http.get.assert_called_once_with("/v1/bdbs") + + @pytest.mark.asyncio + async def test_propagates_http_error_from_availability(self, mock_client, mock_cb): + """ + Ensures that any HTTP error raised by the availability endpoint is propagated. 
+ """ + host = "db3.example.com" + mock_client.get_connection_kwargs.return_value = {"host": host} + + mock_http = AsyncMock() + # First: list bdbs -> match by dns_name + mock_http.get.side_effect = [ + [{"uid": "bdb-err", "endpoints": [{"dns_name": host, "addr": []}]}], + # Second: availability -> raise HttpError + HttpError(url=f"https://{host}:9443/v1/bdbs/bdb-err/availability", status=503, message="busy"), + ] + + hc = LagAwareHealthCheck( + retry=Retry(backoff=ExponentialBackoff(cap=1.0), retries=3), + ) + hc._http_client = mock_http + + db = Database(mock_client, mock_cb, 1.0, "https://healthcheck.example.com") + + with pytest.raises(HttpError, match="busy") as e: + await hc.check_health(db) + assert e.status == 503 + + # Ensure both calls were attempted + assert mock_http.get.call_count == 2 \ No newline at end of file diff --git a/tests/test_asyncio/test_scenario/conftest.py b/tests/test_asyncio/test_scenario/conftest.py index 18bc8f1417..735af7fed6 100644 --- a/tests/test_asyncio/test_scenario/conftest.py +++ b/tests/test_asyncio/test_scenario/conftest.py @@ -28,7 +28,7 @@ def fault_injector_client(): return FaultInjectorClient(url) @pytest_asyncio.fixture() -async def r_multi_db(request) -> tuple[MultiDBClient, CheckActiveDatabaseChangedListener, dict]: +async def r_multi_db(request) -> tuple[MultiDbConfig, CheckActiveDatabaseChangedListener, dict]: client_class = request.param.get('client_class', Redis) if client_class == Redis: @@ -44,6 +44,7 @@ async def r_multi_db(request) -> tuple[MultiDBClient, CheckActiveDatabaseChanged # Retry configuration different for health checks as initial health check require more time in case # if infrastructure wasn't restored from the previous test. health_check_interval = request.param.get('health_check_interval', DEFAULT_HEALTH_CHECK_INTERVAL) + health_checks = request.param.get('health_checks', []) event_dispatcher = EventDispatcher() listener = CheckActiveDatabaseChangedListener() event_dispatcher.register_listeners({ @@ -80,11 +81,11 @@ async def r_multi_db(request) -> tuple[MultiDBClient, CheckActiveDatabaseChanged databases_config=db_configs, command_retry=command_retry, failure_threshold=failure_threshold, + health_checks=health_checks, health_check_retries=3, health_check_interval=health_check_interval, event_dispatcher=event_dispatcher, health_check_backoff=ExponentialBackoff(cap=5, base=0.5), ) - async with MultiDBClient(config) as client: - return client, listener, endpoint_config \ No newline at end of file + return config, listener, endpoint_config \ No newline at end of file diff --git a/tests/test_asyncio/test_scenario/test_active_active.py b/tests/test_asyncio/test_scenario/test_active_active.py index 3204d53fe9..c054d17dc2 100644 --- a/tests/test_asyncio/test_scenario/test_active_active.py +++ b/tests/test_asyncio/test_scenario/test_active_active.py @@ -1,11 +1,15 @@ import asyncio import json import logging +import os from time import sleep import pytest -from redis.asyncio.client import Pipeline +from redis.asyncio import RedisCluster +from redis.asyncio.client import Pipeline, Redis +from redis.asyncio.multidb.client import MultiDBClient +from redis.asyncio.multidb.healthcheck import LagAwareHealthCheck from tests.test_scenario.fault_injector_client import ActionRequest, ActionType logger = logging.getLogger(__name__) @@ -33,68 +37,101 @@ class TestActiveActive: def teardown_method(self, method): # Timeout so the cluster could recover from network failure. 
- sleep(5) + sleep(10) @pytest.mark.asyncio @pytest.mark.parametrize( "r_multi_db", - [{"failure_threshold": 2}], + [ + {"client_class": Redis, "failure_threshold": 2}, + {"client_class": RedisCluster, "failure_threshold": 2}, + ], + ids=["standalone", "cluster"], indirect=True ) @pytest.mark.timeout(50) async def test_multi_db_client_failover_to_another_db(self, r_multi_db, fault_injector_client): - r_multi_db, listener, config = r_multi_db + client_config, listener, endpoint_config = r_multi_db - event = asyncio.Event() - asyncio.create_task(trigger_network_failure_action(fault_injector_client, config, event)) + async with MultiDBClient(client_config) as r_multi_db: + event = asyncio.Event() + asyncio.create_task(trigger_network_failure_action(fault_injector_client, endpoint_config, event)) - # Client initialized on the first command. - await r_multi_db.set('key', 'value') + await r_multi_db.set('key', 'value') - # Execute commands before network failure - while not event.is_set(): - assert await r_multi_db.get('key') == 'value' - await asyncio.sleep(0.5) + # Execute commands before network failure + while not event.is_set(): + assert await r_multi_db.get('key') == 'value' + await asyncio.sleep(0.5) - # Execute commands until database failover - while not listener.is_changed_flag: - assert await r_multi_db.get('key') == 'value' - await asyncio.sleep(0.5) + # Execute commands until database failover + while not listener.is_changed_flag: + assert await r_multi_db.get('key') == 'value' + await asyncio.sleep(0.5) @pytest.mark.asyncio @pytest.mark.parametrize( "r_multi_db", - [{"failure_threshold": 2}], + [ + {"client_class": Redis, "failure_threshold": 2, "health_checks": + [LagAwareHealthCheck(verify_tls=False, auth_basic=(os.getenv('ENV0_USERNAME'),os.getenv('ENV0_PASSWORD')))] + }, + {"client_class": RedisCluster, "failure_threshold": 2, "health_checks": + [LagAwareHealthCheck(verify_tls=False, auth_basic=(os.getenv('ENV0_USERNAME'),os.getenv('ENV0_PASSWORD')))] + }, + ], + ids=["standalone", "cluster"], indirect=True ) @pytest.mark.timeout(50) - async def test_context_manager_pipeline_failover_to_another_db(self, r_multi_db, fault_injector_client): - r_multi_db, listener, config = r_multi_db + async def test_multi_db_client_uses_lag_aware_health_check(self, r_multi_db, fault_injector_client): + client_config, listener, endpoint_config = r_multi_db - event = asyncio.Event() - asyncio.create_task(trigger_network_failure_action(fault_injector_client,config,event)) + async with MultiDBClient(client_config) as r_multi_db: + event = asyncio.Event() + asyncio.create_task(trigger_network_failure_action(fault_injector_client, endpoint_config, event)) - # Client initialized on first pipe execution. 
- async with r_multi_db.pipeline() as pipe: - pipe.set('{hash}key1', 'value1') - pipe.set('{hash}key2', 'value2') - pipe.set('{hash}key3', 'value3') - pipe.get('{hash}key1') - pipe.get('{hash}key2') - pipe.get('{hash}key3') - assert await pipe.execute() == [True, True, True, 'value1', 'value2', 'value3'] + await r_multi_db.set('key', 'value') - # Execute pipeline before network failure - while not event.is_set(): - async with r_multi_db.pipeline() as pipe: - pipe.set('{hash}key1', 'value1') - pipe.set('{hash}key2', 'value2') - pipe.set('{hash}key3', 'value3') - pipe.get('{hash}key1') - pipe.get('{hash}key2') - pipe.get('{hash}key3') - assert await pipe.execute() == [True, True, True, 'value1', 'value2', 'value3'] - await asyncio.sleep(0.5) + # Execute commands before network failure + while not event.is_set(): + assert await r_multi_db.get('key') == 'value' + await asyncio.sleep(0.5) + + # Execute commands after network failure + while not listener.is_changed_flag: + assert await r_multi_db.get('key') == 'value' + await asyncio.sleep(0.5) + + @pytest.mark.asyncio + @pytest.mark.parametrize( + "r_multi_db", + [ + {"client_class": Redis, "failure_threshold": 2}, + {"client_class": RedisCluster, "failure_threshold": 2}, + ], + ids=["standalone", "cluster"], + indirect=True + ) + @pytest.mark.timeout(50) + async def test_context_manager_pipeline_failover_to_another_db(self, r_multi_db, fault_injector_client): + client_config, listener, endpoint_config = r_multi_db + + async with MultiDBClient(client_config) as r_multi_db: + event = asyncio.Event() + asyncio.create_task(trigger_network_failure_action(fault_injector_client, endpoint_config, event)) + + # Execute pipeline before network failure + while not event.is_set(): + async with r_multi_db.pipeline() as pipe: + pipe.set('{hash}key1', 'value1') + pipe.set('{hash}key2', 'value2') + pipe.set('{hash}key3', 'value3') + pipe.get('{hash}key1') + pipe.get('{hash}key2') + pipe.get('{hash}key3') + assert await pipe.execute() == [True, True, True, 'value1', 'value2', 'value3'] + await asyncio.sleep(0.5) # Execute commands until database failover while not listener.is_changed_flag: @@ -111,37 +148,32 @@ async def test_context_manager_pipeline_failover_to_another_db(self, r_multi_db, @pytest.mark.asyncio @pytest.mark.parametrize( "r_multi_db", - [{"failure_threshold": 2}], + [ + {"client_class": Redis, "failure_threshold": 2}, + {"client_class": RedisCluster, "failure_threshold": 2}, + ], + ids=["standalone", "cluster"], indirect=True ) @pytest.mark.timeout(50) async def test_chaining_pipeline_failover_to_another_db(self, r_multi_db, fault_injector_client): - r_multi_db, listener, config = r_multi_db - - event = asyncio.Event() - asyncio.create_task(trigger_network_failure_action(fault_injector_client,config,event)) - - # Client initialized on first pipe execution. 
- pipe = r_multi_db.pipeline() - pipe.set('{hash}key1', 'value1') - pipe.set('{hash}key2', 'value2') - pipe.set('{hash}key3', 'value3') - pipe.get('{hash}key1') - pipe.get('{hash}key2') - pipe.get('{hash}key3') - assert await pipe.execute() == [True, True, True, 'value1', 'value2', 'value3'] - - # Execute pipeline before network failure - while not event.is_set(): - pipe = r_multi_db.pipeline() - pipe.set('{hash}key1', 'value1') - pipe.set('{hash}key2', 'value2') - pipe.set('{hash}key3', 'value3') - pipe.get('{hash}key1') - pipe.get('{hash}key2') - pipe.get('{hash}key3') - assert await pipe.execute() == [True, True, True, 'value1', 'value2', 'value3'] - await asyncio.sleep(0.5) + client_config, listener, endpoint_config = r_multi_db + + async with MultiDBClient(client_config) as r_multi_db: + event = asyncio.Event() + asyncio.create_task(trigger_network_failure_action(fault_injector_client, endpoint_config, event)) + + # Execute pipeline before network failure + while not event.is_set(): + pipe = r_multi_db.pipeline() + pipe.set('{hash}key1', 'value1') + pipe.set('{hash}key2', 'value2') + pipe.set('{hash}key3', 'value3') + pipe.get('{hash}key1') + pipe.get('{hash}key2') + pipe.get('{hash}key3') + assert await pipe.execute() == [True, True, True, 'value1', 'value2', 'value3'] + await asyncio.sleep(0.5) # Execute pipeline until database failover while not listener.is_changed_flag: @@ -158,15 +190,16 @@ async def test_chaining_pipeline_failover_to_another_db(self, r_multi_db, fault_ @pytest.mark.asyncio @pytest.mark.parametrize( "r_multi_db", - [{"failure_threshold": 2}], + [ + {"client_class": Redis, "failure_threshold": 2}, + {"client_class": RedisCluster, "failure_threshold": 2}, + ], + ids=["standalone", "cluster"], indirect=True ) @pytest.mark.timeout(50) async def test_transaction_failover_to_another_db(self, r_multi_db, fault_injector_client): - r_multi_db, listener, config = r_multi_db - - event = asyncio.Event() - asyncio.create_task(trigger_network_failure_action(fault_injector_client,config,event)) + client_config, listener, endpoint_config = r_multi_db async def callback(pipe: Pipeline): pipe.set('{hash}key1', 'value1') @@ -176,18 +209,19 @@ async def callback(pipe: Pipeline): pipe.get('{hash}key2') pipe.get('{hash}key3') - # Client initialized on first transaction execution. 
- await r_multi_db.transaction(callback) == [True, True, True, 'value1', 'value2', 'value3'] + async with MultiDBClient(client_config) as r_multi_db: + event = asyncio.Event() + asyncio.create_task(trigger_network_failure_action(fault_injector_client, endpoint_config, event)) - # Execute transaction before network failure - while not event.is_set(): - await r_multi_db.transaction(callback) - await asyncio.sleep(0.5) + # Execute transaction before network failure + while not event.is_set(): + await r_multi_db.transaction(callback) + await asyncio.sleep(0.5) - # Execute transaction until database failover - while not listener.is_changed_flag: - await r_multi_db.transaction(callback) == [True, True, True, 'value1', 'value2', 'value3'] - await asyncio.sleep(0.5) + # Execute transaction until database failover + while not listener.is_changed_flag: + assert await r_multi_db.transaction(callback) == [True, True, True, 'value1', 'value2', 'value3'] + await asyncio.sleep(0.5) @pytest.mark.asyncio @pytest.mark.parametrize( "r_multi_db", @@ -197,10 +231,7 @@ async def callback(pipe: Pipeline): ) @pytest.mark.timeout(50) async def test_pubsub_failover_to_another_db(self, r_multi_db, fault_injector_client): - r_multi_db, listener, config = r_multi_db - - event = asyncio.Event() - asyncio.create_task(trigger_network_failure_action(fault_injector_client,config,event)) + client_config, listener, endpoint_config = r_multi_db data = json.dumps({'message': 'test'}) messages_count = 0 @@ -209,22 +240,32 @@ async def handler(message): nonlocal messages_count messages_count += 1 - pubsub = await r_multi_db.pubsub() + async with MultiDBClient(client_config) as r_multi_db: + event = asyncio.Event() + asyncio.create_task(trigger_network_failure_action(fault_injector_client, endpoint_config, event)) - # Assign a handler and run in a separate thread. - await pubsub.subscribe(**{'test-channel': handler}) - task = asyncio.create_task(pubsub.run(poll_timeout=0.1)) + pubsub = await r_multi_db.pubsub() - # Execute publish before network failure - while not event.is_set(): - await r_multi_db.publish('test-channel', data) - await asyncio.sleep(0.5) + # Assign a handler and run in a separate task. + await pubsub.subscribe(**{'test-channel': handler}) + task = asyncio.create_task(pubsub.run(poll_timeout=0.1)) - # Execute publish until database failover - while not listener.is_changed_flag: - await r_multi_db.publish('test-channel', data) - await asyncio.sleep(0.5) + # Execute publish before network failure + while not event.is_set(): + await r_multi_db.publish('test-channel', data) + await asyncio.sleep(0.5) + + # Execute publish until database failover + while not listener.is_changed_flag: + await r_multi_db.publish('test-channel', data) + await asyncio.sleep(0.5) + + # After the database changed, still generate some traffic. + for _ in range(5): + await r_multi_db.publish('test-channel', data) - task.cancel() - await pubsub.unsubscribe('test-channel') is True - assert messages_count > 1 \ No newline at end of file + # A timeout to ensure that an async handler will handle all previous messages.
+ await asyncio.sleep(0.1) + task.cancel() + await pubsub.unsubscribe('test-channel') + assert messages_count >= 5 \ No newline at end of file diff --git a/tests/test_multidb/test_healthcheck.py b/tests/test_multidb/test_healthcheck.py index 18bfe5f23b..77886832e7 100644 --- a/tests/test_multidb/test_healthcheck.py +++ b/tests/test_multidb/test_healthcheck.py @@ -88,7 +88,7 @@ def test_database_is_healthy_when_bdb_matches_by_dns_name(self, mock_client, moc first_call = mock_http.get.call_args_list[0] second_call = mock_http.get.call_args_list[1] assert first_call.args[0] == "/v1/bdbs" - assert second_call.args[0] == "/v1/local/bdbs/bdb-1/endpoint/availability?extend_check=lag&availability_lag_tolerance_ms=150" + assert second_call.args[0] == "/v1/bdbs/bdb-1/availability?extend_check=lag&availability_lag_tolerance_ms=150" assert second_call.kwargs.get("expect_json") is False def test_database_is_healthy_when_bdb_matches_by_addr(self, mock_client, mock_cb @@ -120,7 +120,7 @@ def test_database_is_healthy_when_bdb_matches_by_addr(self, mock_client, mock_cb assert hc.check_health(db) is True assert mock_http.get.call_count == 2 - assert mock_http.get.call_args_list[1].args[0] == "/v1/local/bdbs/bdb-42/endpoint/availability?extend_check=lag&availability_lag_tolerance_ms=100" + assert mock_http.get.call_args_list[1].args[0] == "/v1/bdbs/bdb-42/availability?extend_check=lag&availability_lag_tolerance_ms=100" def test_raises_value_error_when_no_matching_bdb(self, mock_client, mock_cb): """ diff --git a/tests/test_scenario/test_active_active.py b/tests/test_scenario/test_active_active.py index 44c57e6b99..c87ad903b1 100644 --- a/tests/test_scenario/test_active_active.py +++ b/tests/test_scenario/test_active_active.py @@ -36,7 +36,7 @@ class TestActiveActive: def teardown_method(self, method): # Timeout so the cluster could recover from network failure.
- sleep(5) + sleep(10) @pytest.mark.parametrize( "r_multi_db", From f81206bbddbad6108f0df6ab5c170e4553a5fbbd Mon Sep 17 00:00:00 2001 From: Vladyslav Vildanov <117659936+vladvildanov@users.noreply.github.com> Date: Thu, 18 Sep 2025 12:11:43 +0300 Subject: [PATCH 14/50] Refactored Healthcheck and Failover strategy logic (#3771) * Extract additional interfaces and abstract classes * Added base async components * Added command executor * Added recurring background tasks with event loop only * Added MultiDBClient * Added scenario and config tests * Added pipeline and transaction support for MultiDBClient * Added pub/sub support for MultiDBClient * Added check for coroutine methods for pub/sub * Added OSS Cluster API support for MultiDBClient * Added support for Lag-Aware Healthcheck and OSS Cluster API * Increased timeouts between tests * [Sync] Refactored healthcheck * [Async] Refactored healthcheck * [Sync] Refactored Failover Strategy * [Async] Refactored Failover Strategy * Changed default values according to a design doc * [Async] Added Strategy Executor * [Sync] Added Strategy Executor * Apply comments --- redis/asyncio/multidb/client.py | 71 +++--- redis/asyncio/multidb/command_executor.py | 23 +- redis/asyncio/multidb/config.py | 45 ++-- redis/asyncio/multidb/failover.py | 105 +++++++-- redis/asyncio/multidb/failure_detector.py | 2 + redis/asyncio/multidb/healthcheck.py | 190 +++++++++++---- redis/multidb/circuit.py | 2 + redis/multidb/client.py | 70 +++--- redis/multidb/command_executor.py | 23 +- redis/multidb/config.py | 48 ++-- redis/multidb/exception.py | 12 + redis/multidb/failover.py | 108 +++++++-- redis/multidb/failure_detector.py | 6 +- redis/multidb/healthcheck.py | 193 +++++++++++---- tests/test_asyncio/test_multidb/conftest.py | 21 +- .../test_asyncio/test_multidb/test_client.py | 53 ++--- .../test_multidb/test_failover.py | 150 +++++++----- .../test_multidb/test_healthcheck.py | 207 ++++++++++++++-- .../test_multidb/test_pipeline.py | 37 ++- tests/test_asyncio/test_scenario/conftest.py | 3 +- .../test_scenario/test_active_active.py | 186 ++++++++++----- tests/test_multidb/conftest.py | 18 +- tests/test_multidb/test_client.py | 67 +++--- tests/test_multidb/test_command_executor.py | 15 +- tests/test_multidb/test_config.py | 4 +- tests/test_multidb/test_failover.py | 151 +++++++----- tests/test_multidb/test_healthcheck.py | 199 ++++++++++++++-- tests/test_multidb/test_pipeline.py | 40 ++-- tests/test_scenario/conftest.py | 11 +- tests/test_scenario/test_active_active.py | 220 ++++++++++++------ 30 files changed, 1577 insertions(+), 703 deletions(-) diff --git a/redis/asyncio/multidb/client.py b/redis/asyncio/multidb/client.py index e098a4723b..b9925ea928 100644 --- a/redis/asyncio/multidb/client.py +++ b/redis/asyncio/multidb/client.py @@ -6,12 +6,12 @@ from redis.asyncio.multidb.command_executor import DefaultCommandExecutor from redis.asyncio.multidb.database import AsyncDatabase, Databases from redis.asyncio.multidb.failure_detector import AsyncFailureDetector -from redis.asyncio.multidb.healthcheck import HealthCheck +from redis.asyncio.multidb.healthcheck import HealthCheck, HealthCheckPolicy from redis.multidb.circuit import State as CBState, CircuitBreaker from redis.asyncio.multidb.config import MultiDbConfig, DEFAULT_GRACE_PERIOD from redis.background import BackgroundScheduler from redis.commands import AsyncRedisModuleCommands, AsyncCoreCommands -from redis.multidb.exception import NoValidDatabaseException +from redis.multidb.exception import
NoValidDatabaseException, UnhealthyDatabaseException from redis.typing import KeyT, EncodableT, ChannelT logger = logging.getLogger(__name__) @@ -29,6 +29,10 @@ def __init__(self, config: MultiDbConfig): self._health_checks.extend(config.health_checks) self._health_check_interval = config.health_check_interval + self._health_check_policy: HealthCheckPolicy = config.health_check_policy.value( + config.health_check_probes, + config.health_check_delay + ) self._failure_detectors = config.default_failure_detectors() if config.failure_detectors is not None: @@ -46,6 +50,8 @@ def __init__(self, config: MultiDbConfig): databases=self._databases, command_retry=self._command_retry, failover_strategy=self._failover_strategy, + failover_attempts=config.failover_attempts, + failover_delay=config.failover_delay, event_dispatcher=self._event_dispatcher, auto_fallback_interval=self._auto_fallback_interval, ) @@ -244,42 +250,45 @@ async def _check_databases_health( Runs health checks as a recurring task. Runs health checks against all databases. """ - for database, _ in self._databases: - async with self._hc_lock: - await self._check_db_health(database, on_error) + results = await asyncio.wait_for( + asyncio.gather( + *( + asyncio.create_task(self._check_db_health(database)) + for database, _ in self._databases + ), + return_exceptions=True, + ), + timeout=self._health_check_interval, + ) - async def _check_db_health( - self, - database: AsyncDatabase, - on_error: Optional[Callable[[Exception], Coroutine[Any, Any, None]]] = None, - ) -> None: + for result in results: + if isinstance(result, UnhealthyDatabaseException): + unhealthy_db = result.database + unhealthy_db.circuit.state = CBState.OPEN + + logger.exception( + 'Health check failed due to exception', + exc_info=result.original_exception + ) + + if on_error: + await on_error(result.original_exception) + + async def _check_db_health(self, database: AsyncDatabase) -> bool: """ Runs health checks on the given database until first failure.
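Probing and aggregation of results are delegated to the configured health check policy.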
""" - is_healthy = True - # Health check will setup circuit state - for health_check in self._health_checks: - if not is_healthy: - # If one of the health checks failed, it's considered unhealthy - break + is_healthy = await self._health_check_policy.execute(self._health_checks, database) - try: - is_healthy = await health_check.check_health(database) + if not is_healthy: + if database.circuit.state != CBState.OPEN: + database.circuit.state = CBState.OPEN + return is_healthy + elif is_healthy and database.circuit.state != CBState.CLOSED: + database.circuit.state = CBState.CLOSED - if not is_healthy and database.circuit.state != CBState.OPEN: - database.circuit.state = CBState.OPEN - elif is_healthy and database.circuit.state != CBState.CLOSED: - database.circuit.state = CBState.CLOSED - except Exception as e: - if database.circuit.state != CBState.OPEN: - database.circuit.state = CBState.OPEN - is_healthy = False - - logger.exception('Health check failed, due to exception', exc_info=e) - - if on_error: - await on_error(e) + return is_healthy def _on_circuit_state_change_callback(self, circuit: CircuitBreaker, old_state: CBState, new_state: CBState): loop = asyncio.get_running_loop() diff --git a/redis/asyncio/multidb/command_executor.py b/redis/asyncio/multidb/command_executor.py index d63b19269d..7e622d6260 100644 --- a/redis/asyncio/multidb/command_executor.py +++ b/redis/asyncio/multidb/command_executor.py @@ -8,7 +8,8 @@ from redis.asyncio.multidb.database import Databases, AsyncDatabase, Database from redis.asyncio.multidb.event import AsyncActiveDatabaseChanged, RegisterCommandFailure, \ ResubscribeOnActiveDatabaseChanged -from redis.asyncio.multidb.failover import AsyncFailoverStrategy +from redis.asyncio.multidb.failover import AsyncFailoverStrategy, FailoverStrategyExecutor, DefaultFailoverStrategyExecutor, \ + DEFAULT_FAILOVER_ATTEMPTS, DEFAULT_FAILOVER_DELAY from redis.asyncio.multidb.failure_detector import AsyncFailureDetector from redis.multidb.circuit import State as CBState from redis.asyncio.retry import Retry @@ -62,8 +63,8 @@ def active_pubsub(self, pubsub: PubSub) -> None: @property @abstractmethod - def failover_strategy(self) -> AsyncFailoverStrategy: - """Returns failover strategy.""" + def failover_strategy_executor(self) -> FailoverStrategyExecutor: + """Returns failover strategy executor.""" pass @property @@ -111,6 +112,8 @@ def __init__( command_retry: Retry, failover_strategy: AsyncFailoverStrategy, event_dispatcher: EventDispatcherInterface, + failover_attempts: int = DEFAULT_FAILOVER_ATTEMPTS, + failover_delay: float = DEFAULT_FAILOVER_DELAY, auto_fallback_interval: float = DEFAULT_AUTO_FALLBACK_INTERVAL, ): """ @@ -122,6 +125,8 @@ def __init__( command_retry: Retry policy for failed command execution failover_strategy: Strategy for handling database failover event_dispatcher: Interface for dispatching events + failover_attempts: Number of failover attempts + failover_delay: Delay between failover attempts auto_fallback_interval: Time interval in seconds between attempts to fall back to a primary database """ super().__init__(auto_fallback_interval) @@ -132,7 +137,11 @@ def __init__( self._databases = databases self._failure_detectors = failure_detectors self._command_retry = command_retry - self._failover_strategy = failover_strategy + self._failover_strategy_executor = DefaultFailoverStrategyExecutor( + failover_strategy, + failover_attempts, + failover_delay + ) self._event_dispatcher = event_dispatcher self._active_database: Optional[Database] = None 
self._active_pubsub: Optional[PubSub] = None @@ -173,8 +182,8 @@ def active_pubsub(self, pubsub: PubSub) -> None: self._active_pubsub = pubsub @property - def failover_strategy(self) -> AsyncFailoverStrategy: - return self._failover_strategy + def failover_strategy_executor(self) -> FailoverStrategyExecutor: + return self._failover_strategy_executor @property def command_retry(self) -> Retry: @@ -265,7 +274,7 @@ async def _check_active_database(self): and self._next_fallback_attempt <= datetime.now() ) ): - await self.set_active_database(await self._failover_strategy.database()) + await self.set_active_database(await self._failover_strategy_executor.execute()) self._schedule_next_fallback() async def _on_command_fail(self, error, *args): diff --git a/redis/asyncio/multidb/config.py b/redis/asyncio/multidb/config.py index b5f4a0658d..354bbcf5c7 100644 --- a/redis/asyncio/multidb/config.py +++ b/redis/asyncio/multidb/config.py @@ -5,24 +5,20 @@ from redis.asyncio import ConnectionPool, Redis, RedisCluster from redis.asyncio.multidb.database import Databases, Database -from redis.asyncio.multidb.failover import AsyncFailoverStrategy, WeightBasedFailoverStrategy -from redis.asyncio.multidb.failure_detector import AsyncFailureDetector, FailureDetectorAsyncWrapper -from redis.asyncio.multidb.healthcheck import HealthCheck, DEFAULT_HEALTH_CHECK_RETRIES, DEFAULT_HEALTH_CHECK_BACKOFF, \ - EchoHealthCheck +from redis.asyncio.multidb.failover import AsyncFailoverStrategy, WeightBasedFailoverStrategy, DEFAULT_FAILOVER_DELAY, \ + DEFAULT_FAILOVER_ATTEMPTS +from redis.asyncio.multidb.failure_detector import AsyncFailureDetector, FailureDetectorAsyncWrapper, \ + DEFAULT_FAILURES_THRESHOLD, DEFAULT_FAILURES_DURATION +from redis.asyncio.multidb.healthcheck import HealthCheck, EchoHealthCheck, DEFAULT_HEALTH_CHECK_INTERVAL, \ + DEFAULT_HEALTH_CHECK_PROBES, DEFAULT_HEALTH_CHECK_DELAY, HealthCheckPolicies, DEFAULT_HEALTH_CHECK_POLICY from redis.asyncio.retry import Retry -from redis.backoff import ExponentialWithJitterBackoff, AbstractBackoff, NoBackoff +from redis.backoff import ExponentialWithJitterBackoff, NoBackoff from redis.data_structure import WeightedList from redis.event import EventDispatcherInterface, EventDispatcher -from redis.multidb.circuit import CircuitBreaker, PBCircuitBreakerAdapter +from redis.multidb.circuit import CircuitBreaker, PBCircuitBreakerAdapter, DEFAULT_GRACE_PERIOD from redis.multidb.failure_detector import CommandFailureDetector -DEFAULT_GRACE_PERIOD = 5.0 -DEFAULT_HEALTH_CHECK_INTERVAL = 5 -DEFAULT_FAILURES_THRESHOLD = 3 -DEFAULT_FAILURES_DURATION = 2 -DEFAULT_FAILOVER_RETRIES = 3 -DEFAULT_FAILOVER_BACKOFF = ExponentialWithJitterBackoff(cap=3) -DEFAULT_AUTO_FALLBACK_INTERVAL = -1 +DEFAULT_AUTO_FALLBACK_INTERVAL = 120 def default_event_dispatcher() -> EventDispatcherInterface: return EventDispatcher() @@ -78,11 +74,11 @@ class MultiDbConfig: failures_interval: Time interval for tracking database failures. health_checks: Optional list of additional health checks performed on databases. health_check_interval: Time interval for executing health checks. - health_check_retries: Number of retry attempts for performing health checks. - health_check_backoff: Backoff strategy for health check retries. + health_check_probes: Number of attempts to evaluate the health of a database. + health_check_delay: Delay between health check attempts. failover_strategy: Optional strategy for handling database failover scenarios. - failover_retries: Number of retries allowed for failover operations. 
- failover_backoff: Backoff strategy for failover retries. + failover_attempts: Number of retries allowed for failover operations. + failover_delay: Delay between failover attempts. auto_fallback_interval: Time interval to trigger automatic fallback. event_dispatcher: Interface for dispatching events related to database operations. @@ -113,11 +109,12 @@ class MultiDbConfig: failures_interval: float = DEFAULT_FAILURES_DURATION health_checks: Optional[List[HealthCheck]] = None health_check_interval: float = DEFAULT_HEALTH_CHECK_INTERVAL - health_check_retries: int = DEFAULT_HEALTH_CHECK_RETRIES - health_check_backoff: AbstractBackoff = DEFAULT_HEALTH_CHECK_BACKOFF + health_check_probes: int = DEFAULT_HEALTH_CHECK_PROBES + health_check_delay: float = DEFAULT_HEALTH_CHECK_DELAY + health_check_policy: HealthCheckPolicies = DEFAULT_HEALTH_CHECK_POLICY failover_strategy: Optional[AsyncFailoverStrategy] = None - failover_retries: int = DEFAULT_FAILOVER_RETRIES - failover_backoff: AbstractBackoff = DEFAULT_FAILOVER_BACKOFF + failover_attempts: int = DEFAULT_FAILOVER_ATTEMPTS + failover_delay: float = DEFAULT_FAILOVER_DELAY auto_fallback_interval: float = DEFAULT_AUTO_FALLBACK_INTERVAL event_dispatcher: EventDispatcherInterface = field(default_factory=default_event_dispatcher) @@ -160,10 +157,8 @@ def default_failure_detectors(self) -> List[AsyncFailureDetector]: def default_health_checks(self) -> List[HealthCheck]: return [ - EchoHealthCheck(retry=Retry(retries=self.health_check_retries, backoff=self.health_check_backoff)), + EchoHealthCheck(), ] def default_failover_strategy(self) -> AsyncFailoverStrategy: - return WeightBasedFailoverStrategy( - retry=Retry(retries=self.failover_retries, backoff=self.failover_backoff), - ) \ No newline at end of file + return WeightBasedFailoverStrategy() \ No newline at end of file diff --git a/redis/asyncio/multidb/failover.py b/redis/asyncio/multidb/failover.py index a2ed427e05..997b7941c4 100644 --- a/redis/asyncio/multidb/failover.py +++ b/redis/asyncio/multidb/failover.py @@ -1,12 +1,13 @@ +import time from abc import abstractmethod, ABC from redis.asyncio.multidb.database import AsyncDatabase, Databases from redis.multidb.circuit import State as CBState -from redis.asyncio.retry import Retry from redis.data_structure import WeightedList -from redis.multidb.exception import NoValidDatabaseException -from redis.utils import dummy_fail_async +from redis.multidb.exception import NoValidDatabaseException, TemporaryUnavailableException +DEFAULT_FAILOVER_ATTEMPTS = 10 +DEFAULT_FAILOVER_DELAY = 12 class AsyncFailoverStrategy(ABC): @@ -20,30 +21,98 @@ def set_databases(self, databases: Databases) -> None: """Set the database strategy operates on.""" pass +class FailoverStrategyExecutor(ABC): + + @property + @abstractmethod + def failover_attempts(self) -> int: + """The number of failover attempts.""" + pass + + @property + @abstractmethod + def failover_delay(self) -> float: + """The delay between failover attempts.""" + pass + + @property + @abstractmethod + def strategy(self) -> AsyncFailoverStrategy: + """The strategy to execute.""" + pass + + @abstractmethod + async def execute(self) -> AsyncDatabase: + """Execute the failover strategy.""" + pass + class WeightBasedFailoverStrategy(AsyncFailoverStrategy): """ Failover strategy based on database weights. 
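Databases are iterated from highest to lowest weight; the first one whose circuit breaker is closed is selected.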
""" - def __init__( - self, - retry: Retry - ): - self._retry = retry - self._retry.update_supported_errors([NoValidDatabaseException]) + def __init__(self): self._databases = WeightedList() async def database(self) -> AsyncDatabase: - return await self._retry.call_with_retry( - lambda: self._get_active_database(), - lambda _: dummy_fail_async() - ) + for database, _ in self._databases: + if database.circuit.state == CBState.CLOSED: + return database + + raise NoValidDatabaseException('No valid database available for communication') def set_databases(self, databases: Databases) -> None: self._databases = databases - async def _get_active_database(self) -> AsyncDatabase: - for database, _ in self._databases: - if database.circuit.state == CBState.CLOSED: - return database +class DefaultFailoverStrategyExecutor(FailoverStrategyExecutor): + """ + Executes given failover strategy. + """ + def __init__( + self, + strategy: AsyncFailoverStrategy, + failover_attempts: int = DEFAULT_FAILOVER_ATTEMPTS, + failover_delay: float = DEFAULT_FAILOVER_DELAY, + ): + self._strategy = strategy + self._failover_attempts = failover_attempts + self._failover_delay = failover_delay + self._next_attempt_ts: int = 0 + self._failover_counter: int = 0 + + @property + def failover_attempts(self) -> int: + return self._failover_attempts + + @property + def failover_delay(self) -> float: + return self._failover_delay + + @property + def strategy(self) -> AsyncFailoverStrategy: + return self._strategy + + async def execute(self) -> AsyncDatabase: + try: + database = await self._strategy.database() + self._reset() + return database + except NoValidDatabaseException as e: + if self._next_attempt_ts == 0: + self._next_attempt_ts = time.time() + self._failover_delay + self._failover_counter += 1 + elif time.time() >= self._next_attempt_ts: + self._next_attempt_ts += self._failover_delay + self._failover_counter += 1 + + if self._failover_counter > self._failover_attempts: + self._reset() + raise e + else: + raise TemporaryUnavailableException( + "No database connections currently available. " + "This is a temporary condition - please retry the operation." 
+            )

-        raise NoValidDatabaseException('No valid database available for communication')
\ No newline at end of file
+    def _reset(self) -> None:
+        self._next_attempt_ts = 0
+        self._failover_counter = 0
\ No newline at end of file
diff --git a/redis/asyncio/multidb/failure_detector.py b/redis/asyncio/multidb/failure_detector.py
index 8aa4752924..cdfcc6ff1e 100644
--- a/redis/asyncio/multidb/failure_detector.py
+++ b/redis/asyncio/multidb/failure_detector.py
@@ -2,6 +2,8 @@

 from redis.multidb.failure_detector import FailureDetector

+DEFAULT_FAILURES_THRESHOLD = 1000
+DEFAULT_FAILURES_DURATION = 2

 class AsyncFailureDetector(ABC):

diff --git a/redis/asyncio/multidb/healthcheck.py b/redis/asyncio/multidb/healthcheck.py
index 7e5f5a1ec7..b5bf695380 100644
--- a/redis/asyncio/multidb/healthcheck.py
+++ b/redis/asyncio/multidb/healthcheck.py
@@ -1,67 +1,170 @@
+import asyncio
 import logging
 from abc import ABC, abstractmethod
-from typing import Optional, Tuple, Union
+from enum import Enum
+from typing import Optional, Tuple, Union, List

 from redis.asyncio import Redis
 from redis.asyncio.http.http_client import AsyncHTTPClientWrapper, DEFAULT_TIMEOUT
 from redis.asyncio.retry import Retry
-from redis.retry import Retry as SyncRetry
-from redis.backoff import ExponentialWithJitterBackoff
+from redis.backoff import NoBackoff
 from redis.http.http_client import HttpClient
-from redis.utils import dummy_fail_async
+from redis.multidb.exception import UnhealthyDatabaseException

-DEFAULT_HEALTH_CHECK_RETRIES = 3
-DEFAULT_HEALTH_CHECK_BACKOFF = ExponentialWithJitterBackoff(cap=10)
+DEFAULT_HEALTH_CHECK_PROBES = 3
+DEFAULT_HEALTH_CHECK_INTERVAL = 5
+DEFAULT_HEALTH_CHECK_DELAY = 0.5
+DEFAULT_LAG_AWARE_TOLERANCE = 5000

 logger = logging.getLogger(__name__)

 class HealthCheck(ABC):
+    @abstractmethod
+    async def check_health(self, database) -> bool:
+        """Function to determine the health status."""
+        pass
+
+class HealthCheckPolicy(ABC):
+    """
+    Health checks execution policy.
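`HealthCheck` is now a single-coroutine ABC; retry handling moved into the policies defined next, so a custom probe stays small. A minimal sketch, assuming only the `database.client` attribute that `EchoHealthCheck` also relies on; the PING criterion itself is illustrative, not part of this patch.

```python
from redis.asyncio.multidb.healthcheck import HealthCheck

class PingHealthCheck(HealthCheck):
    """Illustrative probe: treat the database as healthy if PING succeeds."""

    async def check_health(self, database) -> bool:
        # Exceptions propagate on purpose: the executing policy wraps them
        # in UnhealthyDatabaseException together with the failing database.
        return bool(await database.client.ping())
```

Such a check can be registered alongside the default `EchoHealthCheck` through `MultiDbConfig(health_checks=[...])` or `MultiDBClient.add_health_check(...)`.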
+ """ @property @abstractmethod - def retry(self) -> Retry: - """The retry object to use for health checks.""" + def health_check_probes(self) -> int: + """Number of probes to execute health checks.""" pass + @property @abstractmethod - async def check_health(self, database) -> bool: - """Function to determine the health status.""" + def health_check_delay(self) -> float: + """Delay between health check probes.""" pass -class AbstractHealthCheck(HealthCheck): - def __init__( - self, - retry: Retry = Retry(retries=DEFAULT_HEALTH_CHECK_RETRIES, backoff=DEFAULT_HEALTH_CHECK_BACKOFF) - ) -> None: - self._retry = retry - self._retry.update_supported_errors([ConnectionRefusedError]) + @abstractmethod + async def execute(self, health_checks: List[HealthCheck], database) -> bool: + """Execute health checks and return database health status.""" + pass + +class AbstractHealthCheckPolicy(HealthCheckPolicy): + def __init__(self, health_check_probes: int, health_check_delay: float): + if health_check_probes < 1: + raise ValueError("health_check_probes must be greater than 0") + self._health_check_probes = health_check_probes + self._health_check_delay = health_check_delay + + @property + def health_check_probes(self) -> int: + return self._health_check_probes @property - def retry(self) -> Retry: - return self._retry + def health_check_delay(self) -> float: + return self._health_check_delay @abstractmethod - async def check_health(self, database) -> bool: + async def execute(self, health_checks: List[HealthCheck], database) -> bool: pass -class EchoHealthCheck(AbstractHealthCheck): - def __init__( - self, - retry: Retry = Retry(retries=DEFAULT_HEALTH_CHECK_RETRIES, backoff=DEFAULT_HEALTH_CHECK_BACKOFF) - ) -> None: - """ - Check database healthiness by sending an echo request. - """ - super().__init__( - retry=retry, - ) - async def check_health(self, database) -> bool: - return await self._retry.call_with_retry( - lambda: self._returns_echoed_message(database), - lambda _: dummy_fail_async() - ) +class HealthyAllPolicy(AbstractHealthCheckPolicy): + """ + Policy that returns True if all health check probes are successful. + """ + def __init__(self, health_check_probes: int, health_check_delay: float): + super().__init__(health_check_probes, health_check_delay) + + async def execute(self, health_checks: List[HealthCheck], database) -> bool: + for health_check in health_checks: + for attempt in range(self.health_check_probes): + try: + if not await health_check.check_health(database): + return False + except Exception as e: + raise UnhealthyDatabaseException( + f"Unhealthy database", database, e + ) + + if attempt < self.health_check_probes - 1: + await asyncio.sleep(self._health_check_delay) + return True + +class HealthyMajorityPolicy(AbstractHealthCheckPolicy): + """ + Policy that returns True if a majority of health check probes are successful. 
+ """ + def __init__(self, health_check_probes: int, health_check_delay: float): + super().__init__(health_check_probes, health_check_delay) + + async def execute(self, health_checks: List[HealthCheck], database) -> bool: + for health_check in health_checks: + if self.health_check_probes % 2 == 0: + allowed_unsuccessful_probes = self.health_check_probes / 2 + else: + allowed_unsuccessful_probes = (self.health_check_probes + 1) / 2 + + for attempt in range(self.health_check_probes): + try: + if not await health_check.check_health(database): + allowed_unsuccessful_probes -= 1 + if allowed_unsuccessful_probes <= 0: + return False + except Exception as e: + allowed_unsuccessful_probes -= 1 + if allowed_unsuccessful_probes <= 0: + raise UnhealthyDatabaseException( + f"Unhealthy database", database, e + ) + + if attempt < self.health_check_probes - 1: + await asyncio.sleep(self._health_check_delay) + return True + +class HealthyAnyPolicy(AbstractHealthCheckPolicy): + """ + Policy that returns True if at least one health check probe is successful. + """ + def __init__(self, health_check_probes: int, health_check_delay: float): + super().__init__(health_check_probes, health_check_delay) + + async def execute(self, health_checks: List[HealthCheck], database) -> bool: + is_healthy = False - async def _returns_echoed_message(self, database) -> bool: + for health_check in health_checks: + exception = None + + for attempt in range(self.health_check_probes): + try: + if await health_check.check_health(database): + is_healthy = True + break + else: + is_healthy = False + except Exception as e: + exception = UnhealthyDatabaseException( + f"Unhealthy database", database, e + ) + + if attempt < self.health_check_probes - 1: + await asyncio.sleep(self._health_check_delay) + + if not is_healthy and not exception: + return is_healthy + elif not is_healthy and exception: + raise exception + + return is_healthy + +class HealthCheckPolicies(Enum): + HEALTHY_ALL = HealthyAllPolicy + HEALTHY_MAJORITY = HealthyMajorityPolicy + HEALTHY_ANY = HealthyAnyPolicy + +DEFAULT_HEALTH_CHECK_POLICY: HealthCheckPolicies = HealthCheckPolicies.HEALTHY_ALL + +class EchoHealthCheck(HealthCheck): + """ + Health check based on ECHO command. + """ + async def check_health(self, database) -> bool: expected_message = ["healthcheck", b"healthcheck"] if isinstance(database.client, Redis): @@ -78,16 +183,15 @@ async def _returns_echoed_message(self, database) -> bool: return True -class LagAwareHealthCheck(AbstractHealthCheck): +class LagAwareHealthCheck(HealthCheck): """ Health check available for Redis Enterprise deployments. Verify via REST API that the database is healthy based on different lags. """ def __init__( self, - retry: SyncRetry = SyncRetry(retries=DEFAULT_HEALTH_CHECK_RETRIES, backoff=DEFAULT_HEALTH_CHECK_BACKOFF), rest_api_port: int = 9443, - lag_aware_tolerance: int = 100, + lag_aware_tolerance: int = DEFAULT_LAG_AWARE_TOLERANCE, timeout: float = DEFAULT_TIMEOUT, auth_basic: Optional[Tuple[str, str]] = None, verify_tls: bool = True, @@ -104,7 +208,6 @@ def __init__( Initialize LagAwareHealthCheck with the specified parameters. 

         Args:
-            retry: Retry configuration for health checks
             rest_api_port: Port number for Redis Enterprise REST API (default: 9443)
-            lag_aware_tolerance: Tolerance in lag between databases in MS (default: 100)
+            lag_aware_tolerance: Tolerance in lag between databases in ms (default: DEFAULT_LAG_AWARE_TOLERANCE)
             timeout: Request timeout in seconds (default: DEFAULT_TIMEOUT)
@@ -117,14 +220,11 @@ def __init__(
             client_key_file: Path to client private key file for mutual TLS
             client_key_password: Password for encrypted client private key
         """
-        super().__init__(
-            retry=retry,
-        )
         self._http_client = AsyncHTTPClientWrapper(
             HttpClient(
                 timeout=timeout,
                 auth_basic=auth_basic,
-                retry=self.retry,
+                retry=Retry(NoBackoff(), retries=0),
                 verify_tls=verify_tls,
                 ca_file=ca_file,
                 ca_path=ca_path,
diff --git a/redis/multidb/circuit.py b/redis/multidb/circuit.py
index 8f904c0e4b..5757f3e6d9 100644
--- a/redis/multidb/circuit.py
+++ b/redis/multidb/circuit.py
@@ -4,6 +4,8 @@

 import pybreaker

+DEFAULT_GRACE_PERIOD = 60
+
 class State(Enum):
     CLOSED = 'closed'
     OPEN = 'open'
diff --git a/redis/multidb/client.py b/redis/multidb/client.py
index 7ed1935c1c..19f846bd29 100644
--- a/redis/multidb/client.py
+++ b/redis/multidb/client.py
@@ -1,5 +1,7 @@
 import logging
 import threading
+from concurrent.futures import as_completed
+from concurrent.futures.thread import ThreadPoolExecutor
 from typing import List, Any, Callable, Optional

 from redis.background import BackgroundScheduler
@@ -8,9 +10,9 @@
 from redis.multidb.config import MultiDbConfig, DEFAULT_GRACE_PERIOD
 from redis.multidb.circuit import State as CBState, CircuitBreaker
 from redis.multidb.database import Database, Databases, SyncDatabase
-from redis.multidb.exception import NoValidDatabaseException
+from redis.multidb.exception import NoValidDatabaseException, UnhealthyDatabaseException
 from redis.multidb.failure_detector import FailureDetector
-from redis.multidb.healthcheck import HealthCheck
+from redis.multidb.healthcheck import HealthCheck, HealthCheckPolicy

 logger = logging.getLogger(__name__)

@@ -27,6 +29,10 @@ def __init__(self, config: MultiDbConfig):
             self._health_checks.extend(config.health_checks)

         self._health_check_interval = config.health_check_interval
+        self._health_check_policy: HealthCheckPolicy = config.health_check_policy.value(
+            config.health_check_probes,
+            config.health_check_delay
+        )

         self._failure_detectors = config.default_failure_detectors()

         if config.failure_detectors is not None:
@@ -44,6 +50,8 @@ def __init__(self, config: MultiDbConfig):
             databases=self._databases,
             command_retry=self._command_retry,
             failover_strategy=self._failover_strategy,
+            failover_attempts=config.failover_attempts,
+            failover_delay=config.failover_delay,
             event_dispatcher=self._event_dispatcher,
             auto_fallback_interval=self._auto_fallback_interval,
         )
@@ -209,44 +217,48 @@ def pubsub(self, **kwargs):

         return PubSub(self, **kwargs)

-    def _check_db_health(self, database: SyncDatabase, on_error: Callable[[Exception], None] = None) -> None:
+    def _check_db_health(self, database: SyncDatabase) -> bool:
         """
         Runs health checks on the given database until first failure.
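`_check_db_health` now returns a bool and lets probe errors surface as `UnhealthyDatabaseException`, which carries both the failing database and the root cause. A condensed sketch of that contract; `evaluate` is a hypothetical helper, not part of the patch.

```python
from redis.multidb.circuit import State as CBState
from redis.multidb.exception import UnhealthyDatabaseException

def evaluate(policy, health_checks, database) -> bool:
    try:
        return policy.execute(health_checks, database)
    except UnhealthyDatabaseException as e:
        # Only the circuit of the database that failed is opened; the
        # original exception stays available for logging.
        e.database.circuit.state = CBState.OPEN
        raise
```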
""" - is_healthy = True - - with self._hc_lock: - # Health check will setup circuit state - for health_check in self._health_checks: - if not is_healthy: - # If one of the health checks failed, it's considered unhealthy - break + # Health check will setup circuit state + is_healthy = self._health_check_policy.execute(self._health_checks, database) - try: - is_healthy = health_check.check_health(database) - - if not is_healthy and database.circuit.state != CBState.OPEN: - database.circuit.state = CBState.OPEN - elif is_healthy and database.circuit.state != CBState.CLOSED: - database.circuit.state = CBState.CLOSED - except Exception as e: - if database.circuit.state != CBState.OPEN: - database.circuit.state = CBState.OPEN - is_healthy = False - - logger.exception('Health check failed, due to exception', exc_info=e) - - if on_error: - on_error(e) + if not is_healthy: + if database.circuit.state != CBState.OPEN: + database.circuit.state = CBState.OPEN + return is_healthy + elif is_healthy and database.circuit.state != CBState.CLOSED: + database.circuit.state = CBState.CLOSED + return is_healthy def _check_databases_health(self, on_error: Callable[[Exception], None] = None): """ Runs health checks as a recurring task. Runs health checks against all databases. """ - for database, _ in self._databases: - self._check_db_health(database, on_error) + with ThreadPoolExecutor(max_workers=len(self._databases)) as executor: + # Submit all health checks + futures = { + executor.submit(self._check_db_health, database) + for database, _ in self._databases + } + + for future in as_completed(futures, timeout=self._health_check_interval): + try: + future.result() + except UnhealthyDatabaseException as e: + unhealthy_db = e.database + unhealthy_db.circuit.state = CBState.OPEN + + logger.exception( + 'Health check failed, due to exception', + exc_info=e.original_exception + ) + + if on_error: + on_error(e.original_exception) def _on_circuit_state_change_callback(self, circuit: CircuitBreaker, old_state: CBState, new_state: CBState): if new_state == CBState.HALF_OPEN: diff --git a/redis/multidb/command_executor.py b/redis/multidb/command_executor.py index 364c0a07ea..7ca7d2ec52 100644 --- a/redis/multidb/command_executor.py +++ b/redis/multidb/command_executor.py @@ -8,7 +8,8 @@ from redis.multidb.database import Database, Databases, SyncDatabase from redis.multidb.circuit import State as CBState from redis.multidb.event import RegisterCommandFailure, ActiveDatabaseChanged, ResubscribeOnActiveDatabaseChanged -from redis.multidb.failover import FailoverStrategy +from redis.multidb.failover import FailoverStrategy, FailoverStrategyExecutor, DEFAULT_FAILOVER_ATTEMPTS, \ + DEFAULT_FAILOVER_DELAY, DefaultFailoverStrategyExecutor from redis.multidb.failure_detector import FailureDetector from redis.retry import Retry @@ -94,8 +95,8 @@ def active_pubsub(self, pubsub: PubSub) -> None: @property @abstractmethod - def failover_strategy(self) -> FailoverStrategy: - """Returns failover strategy.""" + def failover_strategy_executor(self) -> FailoverStrategyExecutor: + """Returns failover strategy executor.""" pass @property @@ -142,6 +143,8 @@ def __init__( command_retry: Retry, failover_strategy: FailoverStrategy, event_dispatcher: EventDispatcherInterface, + failover_attempts: int = DEFAULT_FAILOVER_ATTEMPTS, + failover_delay: float = DEFAULT_FAILOVER_DELAY, auto_fallback_interval: float = DEFAULT_AUTO_FALLBACK_INTERVAL, ): """ @@ -153,6 +156,8 @@ def __init__( command_retry: Retry policy for failed command execution 
failover_strategy: Strategy for handling database failover event_dispatcher: Interface for dispatching events + failover_attempts: Number of failover attempts + failover_delay: Delay between failover attempts auto_fallback_interval: Time interval in seconds between attempts to fall back to a primary database """ super().__init__(auto_fallback_interval) @@ -163,7 +168,11 @@ def __init__( self._databases = databases self._failure_detectors = failure_detectors self._command_retry = command_retry - self._failover_strategy = failover_strategy + self._failover_strategy_executor = DefaultFailoverStrategyExecutor( + failover_strategy, + failover_attempts, + failover_delay + ) self._event_dispatcher = event_dispatcher self._active_database: Optional[Database] = None self._active_pubsub: Optional[PubSub] = None @@ -209,8 +218,8 @@ def active_pubsub(self, pubsub: PubSub) -> None: self._active_pubsub = pubsub @property - def failover_strategy(self) -> FailoverStrategy: - return self._failover_strategy + def failover_strategy_executor(self) -> FailoverStrategyExecutor: + return self._failover_strategy_executor def execute_command(self, *args, **options): def callback(): @@ -285,7 +294,7 @@ def _check_active_database(self): and self._next_fallback_attempt <= datetime.now() ) ): - self.active_database = self._failover_strategy.database + self.active_database = self._failover_strategy_executor.execute() self._schedule_next_fallback() def _setup_event_dispatcher(self): diff --git a/redis/multidb/config.py b/redis/multidb/config.py index fc349ed04b..ff9872ffd4 100644 --- a/redis/multidb/config.py +++ b/redis/multidb/config.py @@ -5,25 +5,21 @@ from typing_extensions import Optional from redis import Redis, ConnectionPool -from redis.asyncio import RedisCluster -from redis.backoff import ExponentialWithJitterBackoff, AbstractBackoff, NoBackoff +from redis import RedisCluster +from redis.backoff import ExponentialWithJitterBackoff, NoBackoff from redis.data_structure import WeightedList from redis.event import EventDispatcher, EventDispatcherInterface -from redis.multidb.circuit import PBCircuitBreakerAdapter, CircuitBreaker +from redis.multidb.circuit import PBCircuitBreakerAdapter, CircuitBreaker, DEFAULT_GRACE_PERIOD from redis.multidb.database import Database, Databases -from redis.multidb.failure_detector import FailureDetector, CommandFailureDetector -from redis.multidb.healthcheck import HealthCheck, EchoHealthCheck, DEFAULT_HEALTH_CHECK_RETRIES, \ - DEFAULT_HEALTH_CHECK_BACKOFF -from redis.multidb.failover import FailoverStrategy, WeightBasedFailoverStrategy +from redis.multidb.failure_detector import FailureDetector, CommandFailureDetector, DEFAULT_FAILURES_THRESHOLD, \ + DEFAULT_FAILURES_DURATION +from redis.multidb.healthcheck import HealthCheck, EchoHealthCheck, DEFAULT_HEALTH_CHECK_PROBES, \ + DEFAULT_HEALTH_CHECK_INTERVAL, DEFAULT_HEALTH_CHECK_DELAY, HealthCheckPolicies, DEFAULT_HEALTH_CHECK_POLICY +from redis.multidb.failover import FailoverStrategy, WeightBasedFailoverStrategy, DEFAULT_FAILOVER_ATTEMPTS, \ + DEFAULT_FAILOVER_DELAY from redis.retry import Retry -DEFAULT_GRACE_PERIOD = 5.0 -DEFAULT_HEALTH_CHECK_INTERVAL = 5 -DEFAULT_FAILURES_THRESHOLD = 3 -DEFAULT_FAILURES_DURATION = 2 -DEFAULT_FAILOVER_RETRIES = 3 -DEFAULT_FAILOVER_BACKOFF = ExponentialWithJitterBackoff(cap=3) -DEFAULT_AUTO_FALLBACK_INTERVAL = -1 +DEFAULT_AUTO_FALLBACK_INTERVAL = 120 def default_event_dispatcher() -> EventDispatcherInterface: return EventDispatcher() @@ -79,11 +75,12 @@ class MultiDbConfig: failures_interval: 
Time interval for tracking database failures.
         health_checks: Optional list of additional health checks performed on databases.
         health_check_interval: Time interval for executing health checks.
-        health_check_retries: Number of retry attempts for performing health checks.
-        health_check_backoff: Backoff strategy for health check retries.
+        health_check_probes: Number of attempts to evaluate the health of a database.
+        health_check_delay: Delay between health check attempts.
+        health_check_policy: Policy for determining database health based on health checks.
         failover_strategy: Optional strategy for handling database failover scenarios.
-        failover_retries: Number of retries allowed for failover operations.
-        failover_backoff: Backoff strategy for failover retries.
+        failover_attempts: Number of retries allowed for failover operations.
+        failover_delay: Delay between failover attempts.
         auto_fallback_interval: Time interval to trigger automatic fallback.
         event_dispatcher: Interface for dispatching events related to database operations.

@@ -114,11 +111,12 @@ class MultiDbConfig:
     failures_interval: float = DEFAULT_FAILURES_DURATION
     health_checks: Optional[List[HealthCheck]] = None
     health_check_interval: float = DEFAULT_HEALTH_CHECK_INTERVAL
-    health_check_retries: int = DEFAULT_HEALTH_CHECK_RETRIES
-    health_check_backoff: AbstractBackoff = DEFAULT_HEALTH_CHECK_BACKOFF
+    health_check_probes: int = DEFAULT_HEALTH_CHECK_PROBES
+    health_check_delay: float = DEFAULT_HEALTH_CHECK_DELAY
+    health_check_policy: HealthCheckPolicies = DEFAULT_HEALTH_CHECK_POLICY
     failover_strategy: Optional[FailoverStrategy] = None
-    failover_retries: int = DEFAULT_FAILOVER_RETRIES
-    failover_backoff: AbstractBackoff = DEFAULT_FAILOVER_BACKOFF
+    failover_attempts: int = DEFAULT_FAILOVER_ATTEMPTS
+    failover_delay: float = DEFAULT_FAILOVER_DELAY
     auto_fallback_interval: float = DEFAULT_AUTO_FALLBACK_INTERVAL
     event_dispatcher: EventDispatcherInterface = field(default_factory=default_event_dispatcher)

@@ -159,10 +157,8 @@ def default_failure_detectors(self) -> List[FailureDetector]:

     def default_health_checks(self) -> List[HealthCheck]:
         return [
-            EchoHealthCheck(retry=Retry(retries=self.health_check_retries, backoff=self.health_check_backoff)),
+            EchoHealthCheck(),
         ]

     def default_failover_strategy(self) -> FailoverStrategy:
-        return WeightBasedFailoverStrategy(
-            retry=Retry(retries=self.failover_retries, backoff=self.failover_backoff),
-        )
+        return WeightBasedFailoverStrategy()
diff --git a/redis/multidb/exception.py b/redis/multidb/exception.py
index 80fdb9409a..f54632cae7 100644
--- a/redis/multidb/exception.py
+++ b/redis/multidb/exception.py
@@ -1,2 +1,14 @@
 class NoValidDatabaseException(Exception):
+    pass
+
+class UnhealthyDatabaseException(Exception):
+    """Exception raised when a database is unhealthy due to an underlying exception."""
+
+    def __init__(self, message, database, original_exception):
+        super().__init__(message)
+        self.database = database
+        self.original_exception = original_exception
+
+class TemporaryUnavailableException(Exception):
+    """Exception raised when all databases in the setup are temporarily unavailable."""
     pass
\ No newline at end of file
diff --git a/redis/multidb/failover.py b/redis/multidb/failover.py
index fd08b77ecd..fbbd254252 100644
--- a/redis/multidb/failover.py
+++ b/redis/multidb/failover.py
@@ -1,16 +1,16 @@
+import time
 from abc import ABC, abstractmethod

 from redis.data_structure import WeightedList
 from redis.multidb.database import Databases, SyncDatabase
 from redis.multidb.circuit
import State as CBState -from redis.multidb.exception import NoValidDatabaseException -from redis.retry import Retry -from redis.utils import dummy_fail +from redis.multidb.exception import NoValidDatabaseException, TemporaryUnavailableException +DEFAULT_FAILOVER_ATTEMPTS = 10 +DEFAULT_FAILOVER_DELAY = 12 class FailoverStrategy(ABC): - @property @abstractmethod def database(self) -> SyncDatabase: """Select the database according to the strategy.""" @@ -21,31 +21,99 @@ def set_databases(self, databases: Databases) -> None: """Set the database strategy operates on.""" pass +class FailoverStrategyExecutor(ABC): + + @property + @abstractmethod + def failover_attempts(self) -> int: + """The number of failover attempts.""" + pass + + @property + @abstractmethod + def failover_delay(self) -> float: + """The delay between failover attempts.""" + pass + + @property + @abstractmethod + def strategy(self) -> FailoverStrategy: + """The strategy to execute.""" + pass + + @abstractmethod + def execute(self) -> SyncDatabase: + """Execute the failover strategy.""" + pass + class WeightBasedFailoverStrategy(FailoverStrategy): """ Failover strategy based on database weights. """ - def __init__( - self, - retry: Retry - ): - self._retry = retry - self._retry.update_supported_errors([NoValidDatabaseException]) + def __init__(self) -> None: self._databases = WeightedList() - @property def database(self) -> SyncDatabase: - return self._retry.call_with_retry( - lambda: self._get_active_database(), - lambda _: dummy_fail() - ) - - def set_databases(self, databases: Databases) -> None: - self._databases = databases - - def _get_active_database(self) -> SyncDatabase: for database, _ in self._databases: if database.circuit.state == CBState.CLOSED: return database raise NoValidDatabaseException('No valid database available for communication') + + def set_databases(self, databases: Databases) -> None: + self._databases = databases + +class DefaultFailoverStrategyExecutor(FailoverStrategyExecutor): + """ + Executes given failover strategy. + """ + def __init__( + self, + strategy: FailoverStrategy, + failover_attempts: int = DEFAULT_FAILOVER_ATTEMPTS, + failover_delay: float = DEFAULT_FAILOVER_DELAY, + ): + self._strategy = strategy + self._failover_attempts = failover_attempts + self._failover_delay = failover_delay + self._next_attempt_ts: int = 0 + self._failover_counter: int = 0 + + @property + def failover_attempts(self) -> int: + return self._failover_attempts + + @property + def failover_delay(self) -> float: + return self._failover_delay + + @property + def strategy(self) -> FailoverStrategy: + return self._strategy + + def execute(self) -> SyncDatabase: + try: + database = self._strategy.database() + self._reset() + return database + except NoValidDatabaseException as e: + if self._next_attempt_ts == 0: + self._next_attempt_ts = time.time() + self._failover_delay + self._failover_counter += 1 + elif time.time() >= self._next_attempt_ts: + self._next_attempt_ts += self._failover_delay + self._failover_counter += 1 + + if self._failover_counter > self._failover_attempts: + self._reset() + raise e + else: + raise TemporaryUnavailableException( + "No database connections currently available. " + "This is a temporary condition - please retry the operation." 
+            )
+
+    def _reset(self) -> None:
+        self._next_attempt_ts = 0
+        self._failover_counter = 0
+
diff --git a/redis/multidb/failure_detector.py b/redis/multidb/failure_detector.py
index ef4bd35f69..6b918b152a 100644
--- a/redis/multidb/failure_detector.py
+++ b/redis/multidb/failure_detector.py
@@ -7,6 +7,8 @@

 from redis.multidb.circuit import State as CBState

+DEFAULT_FAILURES_THRESHOLD = 1000
+DEFAULT_FAILURES_DURATION = 2

 class FailureDetector(ABC):

@@ -26,8 +28,8 @@ class CommandFailureDetector(FailureDetector):
     """
     def __init__(
             self,
-            threshold: int,
-            duration: float,
+            threshold: int = DEFAULT_FAILURES_THRESHOLD,
+            duration: float = DEFAULT_FAILURES_DURATION,
             error_types: Optional[List[Type[Exception]]] = None,
     ) -> None:
         """
diff --git a/redis/multidb/healthcheck.py b/redis/multidb/healthcheck.py
index 5a21918513..fcfd7e44a8 100644
--- a/redis/multidb/healthcheck.py
+++ b/redis/multidb/healthcheck.py
@@ -1,92 +1,195 @@
 import logging
 from abc import abstractmethod, ABC
-from typing import Optional, Tuple, Union
+from enum import Enum
+from time import sleep
+from typing import Optional, Tuple, Union, List

 from redis import Redis
-from redis.backoff import ExponentialWithJitterBackoff
+from redis.backoff import NoBackoff
 from redis.http.http_client import DEFAULT_TIMEOUT, HttpClient
+from redis.multidb.exception import UnhealthyDatabaseException
 from redis.retry import Retry
-from redis.utils import dummy_fail

-DEFAULT_HEALTH_CHECK_RETRIES = 3
-DEFAULT_HEALTH_CHECK_BACKOFF = ExponentialWithJitterBackoff(cap=10)
+DEFAULT_HEALTH_CHECK_PROBES = 3
+DEFAULT_HEALTH_CHECK_INTERVAL = 5
+DEFAULT_HEALTH_CHECK_DELAY = 0.5
+DEFAULT_LAG_AWARE_TOLERANCE = 5000

 logger = logging.getLogger(__name__)

 class HealthCheck(ABC):
+    @abstractmethod
+    def check_health(self, database) -> bool:
+        """Function to determine the health status."""
+        pass
+
+class HealthCheckPolicy(ABC):
+    """
+    Health checks execution policy.
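Because the executor above signals an in-progress attempt budget with `TemporaryUnavailableException` instead of sleeping in a retry/backoff loop, waiting is now the caller's job. A minimal sketch, assuming `client` is an initialized sync `MultiDBClient`; `set_with_wait` is a hypothetical helper, not part of the patch.

```python
import time

from redis.multidb.exception import TemporaryUnavailableException

def set_with_wait(client, key, value, pause=1.0, max_wait=120.0):
    # Retry while failover is still within its attempt budget;
    # NoValidDatabaseException propagates once the budget is exhausted.
    deadline = time.monotonic() + max_wait
    while True:
        try:
            return client.set(key, value)
        except TemporaryUnavailableException:
            if time.monotonic() >= deadline:
                raise
            time.sleep(pause)
```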
+ """ @property @abstractmethod - def retry(self) -> Retry: - """The retry object to use for health checks.""" + def health_check_probes(self) -> int: + """Number of probes to execute health checks.""" pass + @property @abstractmethod - def check_health(self, database) -> bool: - """Function to determine the health status.""" + def health_check_delay(self) -> float: + """Delay between health check probes.""" pass -class AbstractHealthCheck(HealthCheck): - def __init__( - self, - retry: Retry = Retry(retries=DEFAULT_HEALTH_CHECK_RETRIES, backoff=DEFAULT_HEALTH_CHECK_BACKOFF) - ) -> None: - self._retry = retry - self._retry.update_supported_errors([ConnectionRefusedError]) + @abstractmethod + def execute(self, health_checks: List[HealthCheck], database) -> bool: + """Execute health checks and return database health status.""" + pass + +class AbstractHealthCheckPolicy(HealthCheckPolicy): + def __init__(self, health_check_probes: int, health_check_delay: float): + if health_check_probes < 1: + raise ValueError("health_check_probes must be greater than 0") + self._health_check_probes = health_check_probes + self._health_check_delay = health_check_delay + + @property + def health_check_probes(self) -> int: + return self._health_check_probes @property - def retry(self) -> Retry: - return self._retry + def health_check_delay(self) -> float: + return self._health_check_delay @abstractmethod - def check_health(self, database) -> bool: + def execute(self, health_checks: List[HealthCheck], database) -> bool: pass +class HealthyAllPolicy(AbstractHealthCheckPolicy): + """ + Policy that returns True if all health check probes are successful. + """ + def __init__(self, health_check_probes: int, health_check_delay: float): + super().__init__(health_check_probes, health_check_delay) -class EchoHealthCheck(AbstractHealthCheck): - def __init__( - self, - retry: Retry = Retry(retries=DEFAULT_HEALTH_CHECK_RETRIES, backoff=DEFAULT_HEALTH_CHECK_BACKOFF) - ) -> None: - """ - Check database healthiness by sending an echo request. - """ - super().__init__( - retry=retry, - ) - def check_health(self, database) -> bool: - return self._retry.call_with_retry( - lambda: self._returns_echoed_message(database), - lambda _: dummy_fail() - ) + def execute(self, health_checks: List[HealthCheck], database) -> bool: + for health_check in health_checks: + for attempt in range(self.health_check_probes): + try: + if not health_check.check_health(database): + return False + except Exception as e: + raise UnhealthyDatabaseException( + f"Unhealthy database", database, e + ) - def _returns_echoed_message(self, database) -> bool: + if attempt < self.health_check_probes - 1: + sleep(self._health_check_delay) + return True + +class HealthyMajorityPolicy(AbstractHealthCheckPolicy): + """ + Policy that returns True if a majority of health check probes are successful. 
+ """ + def __init__(self, health_check_probes: int, health_check_delay: float): + super().__init__(health_check_probes, health_check_delay) + + def execute(self, health_checks: List[HealthCheck], database) -> bool: + for health_check in health_checks: + if self.health_check_probes % 2 == 0: + allowed_unsuccessful_probes = self.health_check_probes / 2 + else: + allowed_unsuccessful_probes = (self.health_check_probes + 1) / 2 + + for attempt in range(self.health_check_probes): + try: + if not health_check.check_health(database): + allowed_unsuccessful_probes -= 1 + if allowed_unsuccessful_probes <= 0: + return False + except Exception as e: + allowed_unsuccessful_probes -= 1 + if allowed_unsuccessful_probes <= 0: + raise UnhealthyDatabaseException( + f"Unhealthy database", database, e + ) + + if attempt < self.health_check_probes - 1: + sleep(self._health_check_delay) + return True + +class HealthyAnyPolicy(AbstractHealthCheckPolicy): + """ + Policy that returns True if at least one health check probe is successful. + """ + def __init__(self, health_check_probes: int, health_check_delay: float): + super().__init__(health_check_probes, health_check_delay) + + def execute(self, health_checks: List[HealthCheck], database) -> bool: + is_healthy = False + + for health_check in health_checks: + exception = None + + for attempt in range(self.health_check_probes): + try: + if health_check.check_health(database): + is_healthy = True + break + else: + is_healthy = False + except Exception as e: + exception = UnhealthyDatabaseException( + f"Unhealthy database", database, e + ) + + if attempt < self.health_check_probes - 1: + sleep(self._health_check_delay) + + if not is_healthy and not exception: + return is_healthy + elif not is_healthy and exception: + raise exception + + return is_healthy + +class HealthCheckPolicies(Enum): + HEALTHY_ALL = HealthyAllPolicy + HEALTHY_MAJORITY = HealthyMajorityPolicy + HEALTHY_ANY = HealthyAnyPolicy + +DEFAULT_HEALTH_CHECK_POLICY: HealthCheckPolicies = HealthCheckPolicies.HEALTHY_ALL + +class EchoHealthCheck(HealthCheck): + """ + Health check based on ECHO command. + """ + def check_health(self, database) -> bool: expected_message = ["healthcheck", b"healthcheck"] if isinstance(database.client, Redis): - actual_message = database.client.execute_command("ECHO" ,"healthcheck") + actual_message = database.client.execute_command("ECHO", "healthcheck") return actual_message in expected_message else: # For a cluster checks if all nodes are healthy. all_nodes = database.client.get_nodes() for node in all_nodes: - actual_message = node.redis_connection.execute_command("ECHO" ,"healthcheck") + actual_message = node.redis_connection.execute_command("ECHO", "healthcheck") if actual_message not in expected_message: return False return True -class LagAwareHealthCheck(AbstractHealthCheck): + +class LagAwareHealthCheck(HealthCheck): """ Health check available for Redis Enterprise deployments. Verify via REST API that the database is healthy based on different lags. """ def __init__( self, - retry: Retry = Retry(retries=DEFAULT_HEALTH_CHECK_RETRIES, backoff=DEFAULT_HEALTH_CHECK_BACKOFF), rest_api_port: int = 9443, - lag_aware_tolerance: int = 100, + lag_aware_tolerance: int = DEFAULT_LAG_AWARE_TOLERANCE, timeout: float = DEFAULT_TIMEOUT, auth_basic: Optional[Tuple[str, str]] = None, verify_tls: bool = True, @@ -103,7 +208,6 @@ def __init__( Initialize LagAwareHealthCheck with the specified parameters. 

         Args:
-            retry: Retry configuration for health checks
             rest_api_port: Port number for Redis Enterprise REST API (default: 9443)
-            lag_aware_tolerance: Tolerance in lag between databases in MS (default: 100)
+            lag_aware_tolerance: Tolerance in lag between databases in ms (default: DEFAULT_LAG_AWARE_TOLERANCE)
             timeout: Request timeout in seconds (default: DEFAULT_TIMEOUT)
@@ -116,13 +220,10 @@ def __init__(
             client_key_file: Path to client private key file for mutual TLS
             client_key_password: Password for encrypted client private key
         """
-        super().__init__(
-            retry=retry,
-        )
         self._http_client = HttpClient(
             timeout=timeout,
             auth_basic=auth_basic,
-            retry=self.retry,
+            retry=Retry(NoBackoff(), retries=0),
             verify_tls=verify_tls,
             ca_file=ca_file,
             ca_path=ca_path,
diff --git a/tests/test_asyncio/test_multidb/conftest.py b/tests/test_asyncio/test_multidb/conftest.py
index 0ac231cf52..f5ea12d9b0 100644
--- a/tests/test_asyncio/test_multidb/conftest.py
+++ b/tests/test_asyncio/test_multidb/conftest.py
@@ -2,17 +2,16 @@

 import pytest

-from redis.asyncio.multidb.config import MultiDbConfig, DEFAULT_HEALTH_CHECK_INTERVAL, DEFAULT_AUTO_FALLBACK_INTERVAL, \
-    DatabaseConfig
+from redis.asyncio.multidb.config import MultiDbConfig, DatabaseConfig, DEFAULT_AUTO_FALLBACK_INTERVAL
 from redis.asyncio.multidb.failover import AsyncFailoverStrategy
 from redis.asyncio.multidb.failure_detector import AsyncFailureDetector
-from redis.asyncio.multidb.healthcheck import HealthCheck
+from redis.asyncio.multidb.healthcheck import HealthCheck, DEFAULT_HEALTH_CHECK_PROBES, DEFAULT_HEALTH_CHECK_INTERVAL, \
+    DEFAULT_HEALTH_CHECK_POLICY
 from redis.data_structure import WeightedList
 from redis.multidb.circuit import State as CBState, CircuitBreaker
 from redis.asyncio import Redis
 from redis.asyncio.multidb.database import Database, Databases

-
 @pytest.fixture()
 def mock_client() -> Redis:
     return Mock(spec=Redis)
@@ -79,18 +78,18 @@ def mock_db2(request) -> Database:
 def mock_multi_db_config(
         request, mock_fd, mock_fs, mock_hc, mock_ed
 ) -> MultiDbConfig:
-    hc_interval = request.param.get('hc_interval', None)
-    if hc_interval is None:
-        hc_interval = DEFAULT_HEALTH_CHECK_INTERVAL
-
-    auto_fallback_interval = request.param.get('auto_fallback_interval', None)
-    if auto_fallback_interval is None:
-        auto_fallback_interval = DEFAULT_AUTO_FALLBACK_INTERVAL
+    hc_interval = request.param.get('hc_interval', DEFAULT_HEALTH_CHECK_INTERVAL)
+    auto_fallback_interval = request.param.get('auto_fallback_interval', DEFAULT_AUTO_FALLBACK_INTERVAL)
+    health_check_policy = request.param.get('health_check_policy', DEFAULT_HEALTH_CHECK_POLICY)
+    health_check_probes = request.param.get('health_check_probes', DEFAULT_HEALTH_CHECK_PROBES)

     config = MultiDbConfig(
         databases_config=[Mock(spec=DatabaseConfig)],
         failure_detectors=[mock_fd],
         health_check_interval=hc_interval,
+        health_check_delay=0.05,
+        health_check_policy=health_check_policy,
+        health_check_probes=health_check_probes,
         failover_strategy=mock_fs,
         auto_fallback_interval=auto_fallback_interval,
         event_dispatcher=mock_ed
diff --git a/tests/test_asyncio/test_multidb/test_client.py b/tests/test_asyncio/test_multidb/test_client.py
index c2fe914e9f..e2ebb89bca 100644
--- a/tests/test_asyncio/test_multidb/test_client.py
+++ b/tests/test_asyncio/test_multidb/test_client.py
@@ -5,13 +5,10 @@
 import pytest

 from redis.asyncio.multidb.client import MultiDBClient
-from redis.asyncio.multidb.config import DEFAULT_FAILOVER_RETRIES, DEFAULT_FAILOVER_BACKOFF
 from redis.asyncio.multidb.database import AsyncDatabase
 from redis.asyncio.multidb.failover import WeightBasedFailoverStrategy
 from
redis.asyncio.multidb.failure_detector import AsyncFailureDetector -from redis.asyncio.multidb.healthcheck import EchoHealthCheck, DEFAULT_HEALTH_CHECK_RETRIES, \ - DEFAULT_HEALTH_CHECK_BACKOFF, HealthCheck -from redis.asyncio.retry import Retry +from redis.asyncio.multidb.healthcheck import EchoHealthCheck, HealthCheck from redis.event import EventDispatcher, AsyncOnCommandsFailEvent from redis.multidb.circuit import State as CBState, PBCircuitBreakerAdapter from redis.multidb.exception import NoValidDatabaseException @@ -46,7 +43,7 @@ async def test_execute_command_against_correct_db_on_successful_initialization( client = MultiDBClient(mock_multi_db_config) assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 assert await client.set('key', 'value') == 'OK1' - assert mock_hc.check_health.call_count == 3 + assert mock_hc.check_health.call_count == 9 assert mock_db.circuit.state == CBState.CLOSED assert mock_db1.circuit.state == CBState.CLOSED @@ -74,12 +71,12 @@ async def test_execute_command_against_correct_db_and_closed_circuit( patch.object(mock_multi_db_config,'default_health_checks', return_value=[mock_hc]): mock_db1.client.execute_command = AsyncMock(return_value='OK1') - mock_hc.check_health.side_effect = [False, True, True] + mock_hc.check_health.side_effect = [False, True, True, True, True, True, True] client = MultiDBClient(mock_multi_db_config) assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 assert await client.set('key', 'value') == 'OK1' - assert mock_hc.check_health.call_count == 3 + assert mock_hc.check_health.call_count == 7 assert mock_db.circuit.state == CBState.CLOSED assert mock_db1.circuit.state == CBState.CLOSED @@ -90,7 +87,7 @@ async def test_execute_command_against_correct_db_and_closed_circuit( 'mock_multi_db_config,mock_db, mock_db1, mock_db2', [ ( - {}, + {"health_check_probes" : 1}, {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, @@ -116,16 +113,12 @@ async def test_execute_command_against_correct_db_on_background_health_check_det databases = create_weighted_list(mock_db, mock_db1, mock_db2) with patch.object(mock_multi_db_config,'databases',return_value=databases), \ - patch.object(mock_multi_db_config,'default_health_checks', return_value=[EchoHealthCheck( - retry=Retry(retries=DEFAULT_HEALTH_CHECK_RETRIES, backoff=DEFAULT_HEALTH_CHECK_BACKOFF) - )]): + patch.object(mock_multi_db_config,'default_health_checks', return_value=[EchoHealthCheck()]): mock_db.client.execute_command.side_effect = ['healthcheck', 'healthcheck', 'healthcheck', 'OK', 'error'] mock_db1.client.execute_command.side_effect = ['healthcheck', 'OK1', 'error', 'error', 'healthcheck', 'OK1'] mock_db2.client.execute_command.side_effect = ['healthcheck', 'healthcheck', 'OK2', 'error', 'error'] mock_multi_db_config.health_check_interval = 0.1 - mock_multi_db_config.failover_strategy = WeightBasedFailoverStrategy( - retry=Retry(retries=DEFAULT_FAILOVER_RETRIES, backoff=DEFAULT_FAILOVER_BACKOFF) - ) + mock_multi_db_config.failover_strategy = WeightBasedFailoverStrategy() client = MultiDBClient(mock_multi_db_config) assert await client.set('key', 'value') == 'OK1' @@ -141,7 +134,7 @@ async def test_execute_command_against_correct_db_on_background_health_check_det 'mock_multi_db_config,mock_db, mock_db1, mock_db2', [ ( - {}, + {"health_check_probes" : 1}, {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, {'weight': 0.7, 'circuit': 
{'state': CBState.CLOSED}}, {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, @@ -155,17 +148,13 @@ async def test_execute_command_auto_fallback_to_highest_weight_db( databases = create_weighted_list(mock_db, mock_db1, mock_db2) with patch.object(mock_multi_db_config,'databases',return_value=databases), \ - patch.object(mock_multi_db_config,'default_health_checks', return_value=[EchoHealthCheck( - retry=Retry(retries=DEFAULT_HEALTH_CHECK_RETRIES, backoff=DEFAULT_HEALTH_CHECK_BACKOFF) - )]): + patch.object(mock_multi_db_config,'default_health_checks', return_value=[EchoHealthCheck()]): mock_db.client.execute_command.side_effect = ['healthcheck', 'healthcheck', 'healthcheck', 'healthcheck', 'healthcheck'] mock_db1.client.execute_command.side_effect = ['healthcheck', 'OK1', 'error', 'healthcheck', 'healthcheck', 'OK1'] mock_db2.client.execute_command.side_effect = ['healthcheck', 'healthcheck', 'OK2', 'healthcheck', 'healthcheck', 'healthcheck'] mock_multi_db_config.health_check_interval = 0.1 mock_multi_db_config.auto_fallback_interval = 0.2 - mock_multi_db_config.failover_strategy = WeightBasedFailoverStrategy( - retry=Retry(retries=DEFAULT_FAILOVER_RETRIES, backoff=DEFAULT_FAILOVER_BACKOFF) - ) + mock_multi_db_config.failover_strategy = WeightBasedFailoverStrategy() client = MultiDBClient(mock_multi_db_config) assert await client.set('key', 'value') == 'OK1' @@ -201,7 +190,7 @@ async def test_execute_command_throws_exception_on_failed_initialization( with pytest.raises(NoValidDatabaseException, match='Initial connection failed - no active database found'): await client.set('key', 'value') - assert mock_hc.check_health.call_count == 3 + assert mock_hc.check_health.call_count == 9 @pytest.mark.asyncio @pytest.mark.parametrize( @@ -230,7 +219,7 @@ async def test_add_database_throws_exception_on_same_database( with pytest.raises(ValueError, match='Given database already exists'): await client.add_database(mock_db) - assert mock_hc.check_health.call_count == 3 + assert mock_hc.check_health.call_count == 9 @pytest.mark.asyncio @pytest.mark.parametrize( @@ -261,10 +250,10 @@ async def test_add_database_makes_new_database_active( assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 assert await client.set('key', 'value') == 'OK2' - assert mock_hc.check_health.call_count == 2 + assert mock_hc.check_health.call_count == 6 await client.add_database(mock_db1) - assert mock_hc.check_health.call_count == 3 + assert mock_hc.check_health.call_count == 9 assert await client.set('key', 'value') == 'OK1' @@ -297,7 +286,7 @@ async def test_remove_highest_weighted_database( assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 assert await client.set('key', 'value') == 'OK1' - assert mock_hc.check_health.call_count == 3 + assert mock_hc.check_health.call_count == 9 await client.remove_database(mock_db1) assert await client.set('key', 'value') == 'OK2' @@ -331,7 +320,7 @@ async def test_update_database_weight_to_be_highest( assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 assert await client.set('key', 'value') == 'OK1' - assert mock_hc.check_health.call_count == 3 + assert mock_hc.check_health.call_count == 9 await client.update_database_weight(mock_db2, 0.8) assert mock_db2.weight == 0.8 @@ -373,7 +362,7 @@ async def test_add_new_failure_detector( client = MultiDBClient(mock_multi_db_config) assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 assert await client.set('key', 'value') == 'OK1' - assert 
mock_hc.check_health.call_count == 3 + assert mock_hc.check_health.call_count == 9 # Simulate failing command events that lead to a failure detection for i in range(5): @@ -418,7 +407,7 @@ async def test_add_new_health_check( client = MultiDBClient(mock_multi_db_config) assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 assert await client.set('key', 'value') == 'OK1' - assert mock_hc.check_health.call_count == 3 + assert mock_hc.check_health.call_count == 9 another_hc = Mock(spec=HealthCheck) another_hc.check_health.return_value = True @@ -426,8 +415,8 @@ async def test_add_new_health_check( await client.add_health_check(another_hc) await client._check_db_health(mock_db1) - assert mock_hc.check_health.call_count == 4 - assert another_hc.check_health.call_count == 1 + assert mock_hc.check_health.call_count == 12 + assert another_hc.check_health.call_count == 3 @pytest.mark.asyncio @pytest.mark.parametrize( @@ -457,7 +446,7 @@ async def test_set_active_database( client = MultiDBClient(mock_multi_db_config) assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 assert await client.set('key', 'value') == 'OK1' - assert mock_hc.check_health.call_count == 3 + assert mock_hc.check_health.call_count == 9 await client.set_active_database(mock_db) assert await client.set('key', 'value') == 'OK' diff --git a/tests/test_asyncio/test_multidb/test_failover.py b/tests/test_asyncio/test_multidb/test_failover.py index f692c40643..0275969d03 100644 --- a/tests/test_asyncio/test_multidb/test_failover.py +++ b/tests/test_asyncio/test_multidb/test_failover.py @@ -1,13 +1,11 @@ -from unittest.mock import PropertyMock +import asyncio import pytest -from redis.backoff import NoBackoff, ExponentialBackoff from redis.data_structure import WeightedList from redis.multidb.circuit import State as CBState -from redis.multidb.exception import NoValidDatabaseException -from redis.asyncio.multidb.failover import WeightBasedFailoverStrategy -from redis.asyncio.retry import Retry +from redis.multidb.exception import NoValidDatabaseException, TemporaryUnavailableException +from redis.asyncio.multidb.failover import WeightBasedFailoverStrategy, DefaultFailoverStrategyExecutor class TestAsyncWeightBasedFailoverStrategy: @@ -30,13 +28,12 @@ class TestAsyncWeightBasedFailoverStrategy: indirect=True, ) async def test_get_valid_database(self, mock_db, mock_db1, mock_db2): - retry = Retry(NoBackoff(), 0) databases = WeightedList() databases.add(mock_db, mock_db.weight) databases.add(mock_db1, mock_db1.weight) databases.add(mock_db2, mock_db2.weight) - strategy = WeightBasedFailoverStrategy(retry=retry) + strategy = WeightBasedFailoverStrategy() strategy.set_databases(databases) assert await strategy.database() == mock_db1 @@ -53,69 +50,106 @@ async def test_get_valid_database(self, mock_db, mock_db1, mock_db2): ], indirect=True, ) - async def test_get_valid_database_with_retries(self, mock_db, mock_db1, mock_db2): - state_mock = PropertyMock( - side_effect=[CBState.OPEN, CBState.OPEN, CBState.OPEN, CBState.CLOSED] - ) - type(mock_db.circuit).state = state_mock - - retry = Retry(ExponentialBackoff(cap=1), 3) - databases = WeightedList() - databases.add(mock_db, mock_db.weight) - databases.add(mock_db1, mock_db1.weight) - databases.add(mock_db2, mock_db2.weight) - failover_strategy = WeightBasedFailoverStrategy(retry=retry) - failover_strategy.set_databases(databases) + async def test_throws_exception_on_empty_databases(self, mock_db, mock_db1, mock_db2): + failover_strategy = 
WeightBasedFailoverStrategy() - assert await failover_strategy.database() == mock_db - assert state_mock.call_count == 4 + with pytest.raises(NoValidDatabaseException, match='No valid database available for communication'): + assert await failover_strategy.database() +class TestDefaultStrategyExecutor: @pytest.mark.asyncio @pytest.mark.parametrize( - 'mock_db,mock_db1,mock_db2', + 'mock_db', [ - ( - {'weight': 0.2, 'circuit': {'state': CBState.OPEN}}, - {'weight': 0.7, 'circuit': {'state': CBState.OPEN}}, - {'weight': 0.5, 'circuit': {'state': CBState.OPEN}}, - ), + {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, ], indirect=True, ) - async def test_get_valid_database_throws_exception_with_retries(self, mock_db, mock_db1, mock_db2): - state_mock = PropertyMock( - side_effect=[CBState.OPEN, CBState.OPEN, CBState.OPEN, CBState.OPEN] + async def test_execute_returns_valid_database_with_failover_attempts(self, mock_db, mock_fs): + failover_attempts = 3 + mock_fs.database.side_effect = [ + NoValidDatabaseException, + NoValidDatabaseException, + NoValidDatabaseException, + mock_db + ] + executor = DefaultFailoverStrategyExecutor( + mock_fs, + failover_attempts=failover_attempts, + failover_delay=0.1 ) - type(mock_db.circuit).state = state_mock - retry = Retry(ExponentialBackoff(cap=1), 3) - databases = WeightedList() - databases.add(mock_db, mock_db.weight) - databases.add(mock_db1, mock_db1.weight) - databases.add(mock_db2, mock_db2.weight) - failover_strategy = WeightBasedFailoverStrategy(retry=retry) - failover_strategy.set_databases(databases) + for i in range(failover_attempts + 1): + try: + database = await executor.execute() + assert database == mock_db + except TemporaryUnavailableException as e: + assert e.args[0] == ( + "No database connections currently available. " + "This is a temporary condition - please retry the operation." + ) + await asyncio.sleep(0.11) + pass - with pytest.raises(NoValidDatabaseException, match='No valid database available for communication'): - assert await failover_strategy.database() + assert mock_fs.database.call_count == 4 + + @pytest.mark.asyncio + async def test_execute_throws_exception_on_attempts_exceed(self, mock_fs): + failover_attempts = 3 + mock_fs.database.side_effect = [ + NoValidDatabaseException, + NoValidDatabaseException, + NoValidDatabaseException, + NoValidDatabaseException + ] + executor = DefaultFailoverStrategyExecutor( + mock_fs, + failover_attempts=failover_attempts, + failover_delay=0.1 + ) + + with pytest.raises(NoValidDatabaseException): + for i in range(failover_attempts + 1): + try: + await executor.execute() + except TemporaryUnavailableException as e: + assert e.args[0] == ( + "No database connections currently available. " + "This is a temporary condition - please retry the operation." 
+ ) + await asyncio.sleep(0.11) + pass - assert state_mock.call_count == 4 + assert mock_fs.database.call_count == 4 @pytest.mark.asyncio - @pytest.mark.parametrize( - 'mock_db,mock_db1,mock_db2', - [ - ( - {'weight': 0.2, 'circuit': {'state': CBState.OPEN}}, - {'weight': 0.7, 'circuit': {'state': CBState.OPEN}}, - {'weight': 0.5, 'circuit': {'state': CBState.OPEN}}, - ), - ], - indirect=True, - ) - async def test_throws_exception_on_empty_databases(self, mock_db, mock_db1, mock_db2): - retry = Retry(NoBackoff(), 0) - failover_strategy = WeightBasedFailoverStrategy(retry=retry) + async def test_execute_throws_exception_on_attempts_does_not_exceed_delay(self, mock_fs): + failover_attempts = 3 + mock_fs.database.side_effect = [ + NoValidDatabaseException, + NoValidDatabaseException, + NoValidDatabaseException, + NoValidDatabaseException + ] + executor = DefaultFailoverStrategyExecutor( + mock_fs, + failover_attempts=failover_attempts, + failover_delay=0.1 + ) - with pytest.raises(NoValidDatabaseException, match='No valid database available for communication'): - assert await failover_strategy.database() \ No newline at end of file + with pytest.raises(TemporaryUnavailableException, match=( + "No database connections currently available. " + "This is a temporary condition - please retry the operation." + )): + for i in range(failover_attempts + 1): + try: + await executor.execute() + except TemporaryUnavailableException as e: + assert e.args[0] == ( + "No database connections currently available. " + "This is a temporary condition - please retry the operation." + ) + if i == failover_attempts: + raise e + + assert mock_fs.database.call_count == 4 \ No newline at end of file diff --git a/tests/test_asyncio/test_multidb/test_healthcheck.py b/tests/test_asyncio/test_multidb/test_healthcheck.py index ba6e8c2b7c..72da0ef737 100644 --- a/tests/test_asyncio/test_multidb/test_healthcheck.py +++ b/tests/test_asyncio/test_multidb/test_healthcheck.py @@ -1,15 +1,181 @@ import pytest -from mock.mock import AsyncMock, MagicMock +from mock.mock import AsyncMock, Mock from redis.asyncio.multidb.database import Database -from redis.asyncio.multidb.healthcheck import EchoHealthCheck, LagAwareHealthCheck -from redis.asyncio.retry import Retry -from redis.backoff import ExponentialBackoff +from redis.asyncio.multidb.healthcheck import EchoHealthCheck, LagAwareHealthCheck, HealthCheck, HealthyAllPolicy, \ + HealthyMajorityPolicy, HealthyAnyPolicy from redis.http.http_client import HttpError from redis.multidb.circuit import State as CBState from redis.exceptions import ConnectionError +from redis.multidb.exception import UnhealthyDatabaseException +class TestHealthyAllPolicy: + @pytest.mark.asyncio + async def test_policy_returns_true_for_all_successful_probes(self): + mock_hc1 = Mock(spec=HealthCheck) + mock_hc2 = Mock(spec=HealthCheck) + mock_hc1.check_health.return_value = True + mock_hc2.check_health.return_value = True + mock_db = Mock(spec=Database) + + policy = HealthyAllPolicy(3, 0.01) + assert await policy.execute([mock_hc1, mock_hc2], mock_db) == True + assert mock_hc1.check_health.call_count == 3 + assert mock_hc2.check_health.call_count == 3 + + @pytest.mark.asyncio + async def test_policy_returns_false_on_first_failed_probe(self): + mock_hc1 = Mock(spec=HealthCheck) + mock_hc2 = Mock(spec=HealthCheck) + mock_hc1.check_health.side_effect = [True, True, False] + mock_hc2.check_health.return_value = True + mock_db = Mock(spec=Database) + + policy = HealthyAllPolicy(3, 0.01) + assert await 
policy.execute([mock_hc1, mock_hc2], mock_db) == False
+        assert mock_hc1.check_health.call_count == 3
+        assert mock_hc2.check_health.call_count == 0
+
+    @pytest.mark.asyncio
+    async def test_policy_raise_unhealthy_database_exception(self):
+        mock_hc1 = Mock(spec=HealthCheck)
+        mock_hc2 = Mock(spec=HealthCheck)
+        mock_hc1.check_health.side_effect = [True, True, ConnectionError]
+        mock_hc2.check_health.return_value = True
+        mock_db = Mock(spec=Database)
+
+        policy = HealthyAllPolicy(3, 0.01)
+        with pytest.raises(UnhealthyDatabaseException, match='Unhealthy database'):
+            await policy.execute([mock_hc1, mock_hc2], mock_db)
+        assert mock_hc1.check_health.call_count == 3
+        assert mock_hc2.check_health.call_count == 0
+
+class TestHealthyMajorityPolicy:
+    @pytest.mark.asyncio
+    @pytest.mark.parametrize(
+        "probes,hc1_side_effect,hc2_side_effect,hc1_call_count,hc2_call_count,expected_result",
+        [
+            (3, [True, False, False], [True, True, True], 3, 0, False),
+            (3, [True, True, True], [True, False, False], 3, 3, False),
+            (3, [True, False, True], [True, True, True], 3, 3, True),
+            (3, [True, True, True], [True, False, True], 3, 3, True),
+            (3, [True, True, False], [True, False, True], 3, 3, True),
+            (4, [True, True, False, False], [True, True, True, True], 4, 0, False),
+            (4, [True, True, True, True], [True, True, False, False], 4, 4, False),
+            (4, [False, True, True, True], [True, True, True, True], 4, 4, True),
+            (4, [True, True, True, True], [True, False, True, True], 4, 4, True),
+            (4, [False, True, True, True], [True, True, False, True], 4, 4, True),
+        ],
+        ids=[
+            'HC1 - no majority - odd', 'HC2 - no majority - odd', 'HC1 - majority - odd',
+            'HC2 - majority - odd', 'HC1 + HC2 - majority - odd', 'HC1 - no majority - even',
+            'HC2 - no majority - even', 'HC1 - majority - even', 'HC2 - majority - even',
+            'HC1 + HC2 - majority - even'
+        ]
+    )
+    async def test_policy_returns_true_for_majority_successful_probes(
+            self,
+            probes,
+            hc1_side_effect,
+            hc2_side_effect,
+            hc1_call_count,
+            hc2_call_count,
+            expected_result
+    ):
+        mock_hc1 = Mock(spec=HealthCheck)
+        mock_hc2 = Mock(spec=HealthCheck)
+        mock_hc1.check_health.side_effect = hc1_side_effect
+        mock_hc2.check_health.side_effect = hc2_side_effect
+        mock_db = Mock(spec=Database)
+
+        policy = HealthyMajorityPolicy(probes, 0.01)
+        assert await policy.execute([mock_hc1, mock_hc2], mock_db) == expected_result
+        assert mock_hc1.check_health.call_count == hc1_call_count
+        assert mock_hc2.check_health.call_count == hc2_call_count
+
+    @pytest.mark.asyncio
+    @pytest.mark.parametrize(
+        "probes,hc1_side_effect,hc2_side_effect,hc1_call_count,hc2_call_count",
+        [
+            (3, [True, ConnectionError, ConnectionError], [True, True, True], 3, 0),
+            (3, [True, True, True], [True, ConnectionError, ConnectionError], 3, 3),
+            (4, [True, ConnectionError, ConnectionError, True], [True, True, True, True], 3, 0),
+            (4, [True, True, True, True], [True, ConnectionError, ConnectionError, False], 4, 3),
+        ],
+        ids=[
+            'HC1 - majority - odd', 'HC2 - majority - odd',
+            'HC1 - majority - even', 'HC2 - majority - even',
+        ]
+    )
+    async def test_policy_raise_unhealthy_database_exception_on_majority_probes_exceptions(
+            self,
+            probes,
+            hc1_side_effect,
+            hc2_side_effect,
+            hc1_call_count,
+            hc2_call_count
+    ):
+        mock_hc1 = Mock(spec=HealthCheck)
+        mock_hc2 = Mock(spec=HealthCheck)
+        mock_hc1.check_health.side_effect = hc1_side_effect
+        mock_hc2.check_health.side_effect = hc2_side_effect
+        mock_db = Mock(spec=Database)
+
+        policy = HealthyMajorityPolicy(probes, 0.01)
+        with
pytest.raises(UnhealthyDatabaseException, match='Unhealthy database'): + await policy.execute([mock_hc1, mock_hc2], mock_db) + assert mock_hc1.check_health.call_count == hc1_call_count + assert mock_hc2.check_health.call_count == hc2_call_count + +class TestHealthyAnyPolicy: + @pytest.mark.asyncio + @pytest.mark.parametrize( + "hc1_side_effect,hc2_side_effect,hc1_call_count,hc2_call_count,expected_result", + [ + ([False, False, False], [True, True, True], 3, 0, False), + ([False, False, True], [False, False, False], 3, 3, False), + ([False, True, True], [False, False, True], 2, 3, True), + ([True, True, True], [False, True, False], 1, 2, True), + ], + ids=[ + 'HC1 - no successful', 'HC2 - no successful', + 'HC1 - successful', 'HC2 - successful', + ] + ) + async def test_policy_returns_true_for_any_successful_probe( + self, + hc1_side_effect, + hc2_side_effect, + hc1_call_count, + hc2_call_count, + expected_result + ): + mock_hc1 = Mock(spec=HealthCheck) + mock_hc2 = Mock(spec=HealthCheck) + mock_hc1.check_health.side_effect = hc1_side_effect + mock_hc2.check_health.side_effect = hc2_side_effect + mock_db = Mock(spec=Database) + + policy = HealthyAnyPolicy(3, 0.01) + assert await policy.execute([mock_hc1, mock_hc2], mock_db) == expected_result + assert mock_hc1.check_health.call_count == hc1_call_count + assert mock_hc2.check_health.call_count == hc2_call_count + + @pytest.mark.asyncio + async def test_policy_raise_unhealthy_database_exception_if_exception_occurs_on_failed_health_check(self): + mock_hc1 = Mock(spec=HealthCheck) + mock_hc2 = Mock(spec=HealthCheck) + mock_hc1.check_health.side_effect = [False, False, ConnectionError] + mock_hc2.check_health.side_effect = [True, True, True] + mock_db = Mock(spec=Database) + + policy = HealthyAnyPolicy(3, 0.01) + with pytest.raises(UnhealthyDatabaseException, match='Unhealthy database'): + await policy.execute([mock_hc1, mock_hc2], mock_db) + assert mock_hc1.check_health.call_count == 3 + assert mock_hc2.check_health.call_count == 0 + class TestEchoHealthCheck: @pytest.mark.asyncio @@ -18,12 +184,12 @@ async def test_database_is_healthy_on_echo_response(self, mock_client, mock_cb): Mocking responses to mix error and actual responses to ensure that health check retry according to given configuration. """ - mock_client.execute_command = AsyncMock(side_effect=[ConnectionError, ConnectionError, 'healthcheck']) - hc = EchoHealthCheck(Retry(backoff=ExponentialBackoff(cap=1.0), retries=3)) + mock_client.execute_command = AsyncMock(side_effect=['healthcheck']) + hc = EchoHealthCheck() db = Database(mock_client, mock_cb, 0.9) assert await hc.check_health(db) == True - assert mock_client.execute_command.call_count == 3 + assert mock_client.execute_command.call_count == 1 @pytest.mark.asyncio async def test_database_is_unhealthy_on_incorrect_echo_response(self, mock_client, mock_cb): @@ -31,22 +197,22 @@ async def test_database_is_unhealthy_on_incorrect_echo_response(self, mock_clien Mocking responses to mix error and actual responses to ensure that health check retry according to given configuration. 
""" - mock_client.execute_command = AsyncMock(side_effect=[ConnectionError, ConnectionError, 'wrong']) - hc = EchoHealthCheck(Retry(backoff=ExponentialBackoff(cap=1.0), retries=3)) + mock_client.execute_command = AsyncMock(side_effect=['wrong']) + hc = EchoHealthCheck() db = Database(mock_client, mock_cb, 0.9) assert await hc.check_health(db) == False - assert mock_client.execute_command.call_count == 3 + assert mock_client.execute_command.call_count == 1 @pytest.mark.asyncio async def test_database_close_circuit_on_successful_healthcheck(self, mock_client, mock_cb): - mock_client.execute_command = AsyncMock(side_effect=[ConnectionError, ConnectionError, 'healthcheck']) + mock_client.execute_command = AsyncMock(side_effect=['healthcheck']) mock_cb.state = CBState.HALF_OPEN - hc = EchoHealthCheck(Retry(backoff=ExponentialBackoff(cap=1.0), retries=3)) + hc = EchoHealthCheck() db = Database(mock_client, mock_cb, 0.9) assert await hc.check_health(db) == True - assert mock_client.execute_command.call_count == 3 + assert mock_client.execute_command.call_count == 1 class TestLagAwareHealthCheck: @pytest.mark.asyncio @@ -75,7 +241,6 @@ async def test_database_is_healthy_when_bdb_matches_by_dns_name(self, mock_clien ] hc = LagAwareHealthCheck( - retry=Retry(backoff=ExponentialBackoff(cap=1.0), retries=3), rest_api_port=1234, lag_aware_tolerance=150 ) # Inject our mocked http client @@ -115,16 +280,14 @@ async def test_database_is_healthy_when_bdb_matches_by_addr(self, mock_client, m None, ] - hc = LagAwareHealthCheck( - retry=Retry(backoff=ExponentialBackoff(cap=1.0), retries=3), - ) + hc = LagAwareHealthCheck() hc._http_client = mock_http db = Database(mock_client, mock_cb, 1.0, "https://healthcheck.example.com") assert await hc.check_health(db) is True assert mock_http.get.call_count == 2 - assert mock_http.get.call_args_list[1].args[0] == "/v1/bdbs/bdb-42/availability?extend_check=lag&availability_lag_tolerance_ms=100" + assert mock_http.get.call_args_list[1].args[0] == "/v1/bdbs/bdb-42/availability?extend_check=lag&availability_lag_tolerance_ms=5000" @pytest.mark.asyncio async def test_raises_value_error_when_no_matching_bdb(self, mock_client, mock_cb): @@ -141,9 +304,7 @@ async def test_raises_value_error_when_no_matching_bdb(self, mock_client, mock_c {"uid": "b", "endpoints": [{"dns_name": "another.example.com", "addr": ["10.0.0.10"]}]}, ] - hc = LagAwareHealthCheck( - retry=Retry(backoff=ExponentialBackoff(cap=1.0), retries=3), - ) + hc = LagAwareHealthCheck() hc._http_client = mock_http db = Database(mock_client, mock_cb, 1.0, "https://healthcheck.example.com") @@ -170,9 +331,7 @@ async def test_propagates_http_error_from_availability(self, mock_client, mock_c HttpError(url=f"https://{host}:9443/v1/bdbs/bdb-err/availability", status=503, message="busy"), ] - hc = LagAwareHealthCheck( - retry=Retry(backoff=ExponentialBackoff(cap=1.0), retries=3), - ) + hc = LagAwareHealthCheck() hc._http_client = mock_http db = Database(mock_client, mock_cb, 1.0, "https://healthcheck.example.com") diff --git a/tests/test_asyncio/test_multidb/test_pipeline.py b/tests/test_asyncio/test_multidb/test_pipeline.py index 5af2e3e864..492919cdac 100644 --- a/tests/test_asyncio/test_multidb/test_pipeline.py +++ b/tests/test_asyncio/test_multidb/test_pipeline.py @@ -6,13 +6,10 @@ from redis.asyncio.client import Pipeline from redis.asyncio.multidb.client import MultiDBClient -from redis.asyncio.multidb.config import DEFAULT_FAILOVER_RETRIES from redis.asyncio.multidb.failover import WeightBasedFailoverStrategy -from 
redis.asyncio.multidb.healthcheck import EchoHealthCheck, DEFAULT_HEALTH_CHECK_RETRIES, \ - DEFAULT_HEALTH_CHECK_BACKOFF +from redis.asyncio.multidb.healthcheck import EchoHealthCheck from redis.asyncio.retry import Retry from redis.multidb.circuit import State as CBState, PBCircuitBreakerAdapter -from redis.multidb.config import DEFAULT_FAILOVER_BACKOFF from tests.test_asyncio.test_multidb.conftest import create_weighted_list @@ -57,7 +54,7 @@ async def test_executes_pipeline_against_correct_db( pipe.get('key1') assert await pipe.execute() == ['OK1', 'value1'] - assert mock_hc.check_health.call_count == 3 + assert mock_hc.check_health.call_count == 9 @pytest.mark.asyncio @pytest.mark.parametrize( @@ -83,7 +80,7 @@ async def test_execute_pipeline_against_correct_db_and_closed_circuit( pipe.execute.return_value = ['OK1', 'value1'] mock_db1.client.pipeline.return_value = pipe - mock_hc.check_health.side_effect = [False, True, True] + mock_hc.check_health.side_effect = [False, True, True, True, True, True, True] client = MultiDBClient(mock_multi_db_config) assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 @@ -93,7 +90,7 @@ async def test_execute_pipeline_against_correct_db_and_closed_circuit( pipe.get('key1') assert await pipe.execute() == ['OK1', 'value1'] - assert mock_hc.check_health.call_count == 3 + assert mock_hc.check_health.call_count == 7 assert mock_db.circuit.state == CBState.CLOSED assert mock_db1.circuit.state == CBState.CLOSED @@ -104,7 +101,7 @@ async def test_execute_pipeline_against_correct_db_and_closed_circuit( 'mock_multi_db_config,mock_db, mock_db1, mock_db2', [ ( - {}, + {"health_check_probes" : 1}, {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, @@ -130,9 +127,7 @@ async def test_execute_pipeline_against_correct_db_on_background_health_check_de databases = create_weighted_list(mock_db, mock_db1, mock_db2) with patch.object(mock_multi_db_config,'databases',return_value=databases), \ - patch.object(mock_multi_db_config,'default_health_checks', return_value=[EchoHealthCheck( - retry=Retry(retries=DEFAULT_HEALTH_CHECK_RETRIES, backoff=DEFAULT_HEALTH_CHECK_BACKOFF) - )]): + patch.object(mock_multi_db_config,'default_health_checks', return_value=[EchoHealthCheck()]): mock_db.client.execute_command.side_effect = ['healthcheck', 'healthcheck', 'healthcheck', 'error'] mock_db1.client.execute_command.side_effect = ['healthcheck', 'error', 'error', 'healthcheck'] mock_db2.client.execute_command.side_effect = ['healthcheck', 'healthcheck', 'error', 'error'] @@ -150,9 +145,7 @@ async def test_execute_pipeline_against_correct_db_on_background_health_check_de mock_db2.client.pipeline.return_value = pipe2 mock_multi_db_config.health_check_interval = 0.1 - mock_multi_db_config.failover_strategy = WeightBasedFailoverStrategy( - retry=Retry(retries=DEFAULT_FAILOVER_RETRIES, backoff=DEFAULT_FAILOVER_BACKOFF) - ) + mock_multi_db_config.failover_strategy = WeightBasedFailoverStrategy() client = MultiDBClient(mock_multi_db_config) @@ -219,7 +212,7 @@ async def callback(pipe: Pipeline): pipe.get('key1') assert await client.transaction(callback) == ['OK1', 'value1'] - assert mock_hc.check_health.call_count == 3 + assert mock_hc.check_health.call_count == 9 @pytest.mark.asyncio @pytest.mark.parametrize( @@ -243,7 +236,7 @@ async def test_execute_transaction_against_correct_db_and_closed_circuit( patch.object(mock_multi_db_config,'default_health_checks', 
return_value=[mock_hc]): mock_db1.client.transaction.return_value = ['OK1', 'value1'] - mock_hc.check_health.side_effect = [False, True, True] + mock_hc.check_health.side_effect = [False, True, True, True, True, True, True] client = MultiDBClient(mock_multi_db_config) assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 @@ -253,7 +246,7 @@ async def callback(pipe: Pipeline): pipe.get('key1') assert await client.transaction(callback) == ['OK1', 'value1'] - assert mock_hc.check_health.call_count == 3 + assert mock_hc.check_health.call_count == 7 assert mock_db.circuit.state == CBState.CLOSED assert mock_db1.circuit.state == CBState.CLOSED @@ -264,7 +257,7 @@ async def callback(pipe: Pipeline): 'mock_multi_db_config,mock_db, mock_db1, mock_db2', [ ( - {}, + {"health_check_probes" : 1}, {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, @@ -290,9 +283,7 @@ async def test_execute_transaction_against_correct_db_on_background_health_check databases = create_weighted_list(mock_db, mock_db1, mock_db2) with patch.object(mock_multi_db_config,'databases',return_value=databases), \ - patch.object(mock_multi_db_config,'default_health_checks', return_value=[EchoHealthCheck( - retry=Retry(retries=DEFAULT_HEALTH_CHECK_RETRIES, backoff=DEFAULT_HEALTH_CHECK_BACKOFF) - )]): + patch.object(mock_multi_db_config,'default_health_checks', return_value=[EchoHealthCheck()]): mock_db.client.execute_command.side_effect = ['healthcheck', 'healthcheck', 'healthcheck', 'error'] mock_db1.client.execute_command.side_effect = ['healthcheck', 'error', 'error', 'healthcheck'] mock_db2.client.execute_command.side_effect = ['healthcheck', 'healthcheck', 'error', 'error'] @@ -302,9 +293,7 @@ async def test_execute_transaction_against_correct_db_on_background_health_check mock_db2.client.transaction.return_value = ['OK2', 'value'] mock_multi_db_config.health_check_interval = 0.1 - mock_multi_db_config.failover_strategy = WeightBasedFailoverStrategy( - retry=Retry(retries=DEFAULT_FAILOVER_RETRIES, backoff=DEFAULT_FAILOVER_BACKOFF) - ) + mock_multi_db_config.failover_strategy = WeightBasedFailoverStrategy() client = MultiDBClient(mock_multi_db_config) diff --git a/tests/test_asyncio/test_scenario/conftest.py b/tests/test_asyncio/test_scenario/conftest.py index 735af7fed6..97e66f715b 100644 --- a/tests/test_asyncio/test_scenario/conftest.py +++ b/tests/test_asyncio/test_scenario/conftest.py @@ -82,10 +82,9 @@ async def r_multi_db(request) -> tuple[MultiDbConfig, CheckActiveDatabaseChanged command_retry=command_retry, failure_threshold=failure_threshold, health_checks=health_checks, - health_check_retries=3, + health_check_probes=3, health_check_interval=health_check_interval, event_dispatcher=event_dispatcher, - health_check_backoff=ExponentialBackoff(cap=5, base=0.5), ) return config, listener, endpoint_config \ No newline at end of file diff --git a/tests/test_asyncio/test_scenario/test_active_active.py b/tests/test_asyncio/test_scenario/test_active_active.py index c054d17dc2..2540c8a99d 100644 --- a/tests/test_asyncio/test_scenario/test_active_active.py +++ b/tests/test_asyncio/test_scenario/test_active_active.py @@ -9,7 +9,12 @@ from redis.asyncio import RedisCluster from redis.asyncio.client import Pipeline, Redis from redis.asyncio.multidb.client import MultiDBClient +from redis.asyncio.multidb.failover import DEFAULT_FAILOVER_ATTEMPTS, DEFAULT_FAILOVER_DELAY from 
redis.asyncio.multidb.healthcheck import LagAwareHealthCheck +from redis.asyncio.retry import Retry +from redis.backoff import ConstantBackoff +from redis.multidb.exception import TemporaryUnavailableException +from redis.utils import dummy_fail_async from tests.test_scenario.fault_injector_client import ActionRequest, ActionType logger = logging.getLogger(__name__) @@ -37,7 +42,7 @@ class TestActiveActive: def teardown_method(self, method): # Timeout so the cluster could recover from network failure. - sleep(10) + sleep(15) @pytest.mark.asyncio @pytest.mark.parametrize( @@ -49,24 +54,40 @@ def teardown_method(self, method): ids=["standalone", "cluster"], indirect=True ) - @pytest.mark.timeout(50) + @pytest.mark.timeout(100) async def test_multi_db_client_failover_to_another_db(self, r_multi_db, fault_injector_client): client_config, listener, endpoint_config = r_multi_db + # Handle unavailable databases from previous test. + retry = Retry( + supported_errors=(TemporaryUnavailableException,), + retries=DEFAULT_FAILOVER_ATTEMPTS, + backoff=ConstantBackoff(backoff=DEFAULT_FAILOVER_DELAY) + ) + async with MultiDBClient(client_config) as r_multi_db: event = asyncio.Event() asyncio.create_task(trigger_network_failure_action(fault_injector_client, endpoint_config, event)) - await r_multi_db.set('key', 'value') + await retry.call_with_retry( + lambda : r_multi_db.set('key', 'value'), + lambda _: dummy_fail_async() + ) # Execute commands before network failure while not event.is_set(): - assert await r_multi_db.get('key') == 'value' + assert await retry.call_with_retry( + lambda: r_multi_db.get('key') , + lambda _: dummy_fail_async() + ) == 'value' await asyncio.sleep(0.5) # Execute commands until database failover while not listener.is_changed_flag: - assert await r_multi_db.get('key') == 'value' + assert await retry.call_with_retry( + lambda: r_multi_db.get('key'), + lambda _: dummy_fail_async() + ) == 'value' await asyncio.sleep(0.5) @pytest.mark.asyncio @@ -83,24 +104,38 @@ async def test_multi_db_client_failover_to_another_db(self, r_multi_db, fault_in ids=["standalone", "cluster"], indirect=True ) - @pytest.mark.timeout(50) + @pytest.mark.timeout(100) async def test_multi_db_client_uses_lag_aware_health_check(self, r_multi_db, fault_injector_client): client_config, listener, endpoint_config = r_multi_db + retry = Retry( + supported_errors=(TemporaryUnavailableException,), + retries=DEFAULT_FAILOVER_ATTEMPTS, + backoff=ConstantBackoff(backoff=DEFAULT_FAILOVER_DELAY) + ) async with MultiDBClient(client_config) as r_multi_db: event = asyncio.Event() asyncio.create_task(trigger_network_failure_action(fault_injector_client, endpoint_config, event)) - await r_multi_db.set('key', 'value') + await retry.call_with_retry( + lambda: r_multi_db.set('key', 'value'), + lambda _: dummy_fail_async() + ) # Execute commands before network failure while not event.is_set(): - assert await r_multi_db.get('key') == 'value' + assert await retry.call_with_retry( + lambda: r_multi_db.get('key'), + lambda _: dummy_fail_async() + ) == 'value' await asyncio.sleep(0.5) # Execute commands after network failure while not listener.is_changed_flag: - assert await r_multi_db.get('key') == 'value' + assert await retry.call_with_retry( + lambda: r_multi_db.get('key'), + lambda _: dummy_fail_async() + ) == 'value' await asyncio.sleep(0.5) @pytest.mark.asyncio @@ -113,9 +148,24 @@ async def test_multi_db_client_uses_lag_aware_health_check(self, r_multi_db, fau ids=["standalone", "cluster"], indirect=True ) - 
@pytest.mark.timeout(50) + @pytest.mark.timeout(100) async def test_context_manager_pipeline_failover_to_another_db(self, r_multi_db, fault_injector_client): client_config, listener, endpoint_config = r_multi_db + retry = Retry( + supported_errors=(TemporaryUnavailableException,), + retries=DEFAULT_FAILOVER_ATTEMPTS, + backoff=ConstantBackoff(backoff=DEFAULT_FAILOVER_DELAY) + ) + + async def callback(): + async with r_multi_db.pipeline() as pipe: + pipe.set('{hash}key1', 'value1') + pipe.set('{hash}key2', 'value2') + pipe.set('{hash}key3', 'value3') + pipe.get('{hash}key1') + pipe.get('{hash}key2') + pipe.get('{hash}key3') + assert await pipe.execute() == [True, True, True, 'value1', 'value2', 'value3'] async with MultiDBClient(client_config) as r_multi_db: event = asyncio.Event() @@ -123,26 +173,18 @@ async def test_context_manager_pipeline_failover_to_another_db(self, r_multi_db, # Execute pipeline before network failure while not event.is_set(): - async with r_multi_db.pipeline() as pipe: - pipe.set('{hash}key1', 'value1') - pipe.set('{hash}key2', 'value2') - pipe.set('{hash}key3', 'value3') - pipe.get('{hash}key1') - pipe.get('{hash}key2') - pipe.get('{hash}key3') - assert await pipe.execute() == [True, True, True, 'value1', 'value2', 'value3'] - await asyncio.sleep(0.5) + await retry.call_with_retry( + lambda: callback(), + lambda _: dummy_fail_async() + ) + await asyncio.sleep(0.5) # Execute commands until database failover while not listener.is_changed_flag: - async with r_multi_db.pipeline() as pipe: - pipe.set('{hash}key1', 'value1') - pipe.set('{hash}key2', 'value2') - pipe.set('{hash}key3', 'value3') - pipe.get('{hash}key1') - pipe.get('{hash}key2') - pipe.get('{hash}key3') - assert await pipe.execute() == [True, True, True, 'value1', 'value2', 'value3'] + await retry.call_with_retry( + lambda: callback(), + lambda _: dummy_fail_async() + ) await asyncio.sleep(0.5) @pytest.mark.asyncio @@ -155,9 +197,24 @@ async def test_context_manager_pipeline_failover_to_another_db(self, r_multi_db, ids=["standalone", "cluster"], indirect=True ) - @pytest.mark.timeout(50) + @pytest.mark.timeout(100) async def test_chaining_pipeline_failover_to_another_db(self, r_multi_db, fault_injector_client): client_config, listener, endpoint_config = r_multi_db + retry = Retry( + supported_errors=(TemporaryUnavailableException,), + retries=DEFAULT_FAILOVER_ATTEMPTS, + backoff=ConstantBackoff(backoff=DEFAULT_FAILOVER_DELAY) + ) + + async def callback(): + pipe = r_multi_db.pipeline() + pipe.set('{hash}key1', 'value1') + pipe.set('{hash}key2', 'value2') + pipe.set('{hash}key3', 'value3') + pipe.get('{hash}key1') + pipe.get('{hash}key2') + pipe.get('{hash}key3') + assert await pipe.execute() == [True, True, True, 'value1', 'value2', 'value3'] async with MultiDBClient(client_config) as r_multi_db: event = asyncio.Event() @@ -165,27 +222,19 @@ async def test_chaining_pipeline_failover_to_another_db(self, r_multi_db, fault_ # Execute pipeline before network failure while not event.is_set(): - pipe = r_multi_db.pipeline() - pipe.set('{hash}key1', 'value1') - pipe.set('{hash}key2', 'value2') - pipe.set('{hash}key3', 'value3') - pipe.get('{hash}key1') - pipe.get('{hash}key2') - pipe.get('{hash}key3') - assert await pipe.execute() == [True, True, True, 'value1', 'value2', 'value3'] - await asyncio.sleep(0.5) + await retry.call_with_retry( + lambda: callback(), + lambda _: dummy_fail_async() + ) + await asyncio.sleep(0.5) # Execute pipeline until database failover while not listener.is_changed_flag: - pipe = 
r_multi_db.pipeline() - pipe.set('{hash}key1', 'value1') - pipe.set('{hash}key2', 'value2') - pipe.set('{hash}key3', 'value3') - pipe.get('{hash}key1') - pipe.get('{hash}key2') - pipe.get('{hash}key3') - assert await pipe.execute() == [True, True, True, 'value1', 'value2', 'value3'] - await asyncio.sleep(0.5) + await retry.call_with_retry( + lambda: callback(), + lambda _: dummy_fail_async() + ) + await asyncio.sleep(0.5) @pytest.mark.asyncio @pytest.mark.parametrize( @@ -197,10 +246,16 @@ async def test_chaining_pipeline_failover_to_another_db(self, r_multi_db, fault_ ids=["standalone", "cluster"], indirect=True ) - @pytest.mark.timeout(50) + @pytest.mark.timeout(100) async def test_transaction_failover_to_another_db(self, r_multi_db, fault_injector_client): client_config, listener, endpoint_config = r_multi_db + retry = Retry( + supported_errors=(TemporaryUnavailableException,), + retries=DEFAULT_FAILOVER_ATTEMPTS, + backoff=ConstantBackoff(backoff=DEFAULT_FAILOVER_DELAY) + ) + async def callback(pipe: Pipeline): pipe.set('{hash}key1', 'value1') pipe.set('{hash}key2', 'value2') @@ -215,12 +270,18 @@ async def callback(pipe: Pipeline): # Execute transaction before network failure while not event.is_set(): - await r_multi_db.transaction(callback) + await retry.call_with_retry( + lambda: r_multi_db.transaction(callback), + lambda _: dummy_fail_async() + ) await asyncio.sleep(0.5) # Execute transaction until database failover while not listener.is_changed_flag: - await r_multi_db.transaction(callback) == [True, True, True, 'value1', 'value2', 'value3'] + assert await retry.call_with_retry( + lambda: r_multi_db.transaction(callback), + lambda _: dummy_fail_async() + ) == [True, True, True, 'value1', 'value2', 'value3'] await asyncio.sleep(0.5) @pytest.mark.asyncio @@ -229,9 +290,14 @@ async def callback(pipe: Pipeline): [{"failure_threshold": 2}], indirect=True ) - @pytest.mark.timeout(50) + @pytest.mark.timeout(60) async def test_pubsub_failover_to_another_db(self, r_multi_db, fault_injector_client): client_config, listener, endpoint_config = r_multi_db + retry = Retry( + supported_errors=(TemporaryUnavailableException,), + retries=DEFAULT_FAILOVER_ATTEMPTS, + backoff=ConstantBackoff(backoff=DEFAULT_FAILOVER_DELAY) + ) data = json.dumps({'message': 'test'}) messages_count = 0 @@ -247,22 +313,34 @@ async def handler(message): pubsub = await r_multi_db.pubsub() # Assign a handler and run in a separate thread. - await pubsub.subscribe(**{'test-channel': handler}) + await retry.call_with_retry( + lambda: pubsub.subscribe(**{'test-channel': handler}), + lambda _: dummy_fail_async() + ) task = asyncio.create_task(pubsub.run(poll_timeout=0.1)) # Execute publish before network failure while not event.is_set(): - await r_multi_db.publish('test-channel', data) + await retry.call_with_retry( + lambda: r_multi_db.publish('test-channel', data), + lambda _: dummy_fail_async() + ) await asyncio.sleep(0.5) # Execute publish until database failover while not listener.is_changed_flag: - await r_multi_db.publish('test-channel', data) + await retry.call_with_retry( + lambda: r_multi_db.publish('test-channel', data), + lambda _: dummy_fail_async() + ) await asyncio.sleep(0.5) # After db changed still generates some traffic. for _ in range(5): - await r_multi_db.publish('test-channel', data) + await retry.call_with_retry( + lambda: r_multi_db.publish('test-channel', data), + lambda _: dummy_fail_async() + ) # A timeout to ensure that an async handler will handle all previous messages. 
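+        # The 0.1s pause is an assumed upper bound for the handler registered above
+        # to drain the remaining messages; raise it if delivery is slow in CI.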
await asyncio.sleep(0.1) diff --git a/tests/test_multidb/conftest.py b/tests/test_multidb/conftest.py index 0c082f0f17..3b1f7f369b 100644 --- a/tests/test_multidb/conftest.py +++ b/tests/test_multidb/conftest.py @@ -6,11 +6,11 @@ from redis.data_structure import WeightedList from redis.multidb.circuit import State as CBState, CircuitBreaker from redis.multidb.config import MultiDbConfig, DatabaseConfig, DEFAULT_HEALTH_CHECK_INTERVAL, \ - DEFAULT_AUTO_FALLBACK_INTERVAL + DEFAULT_AUTO_FALLBACK_INTERVAL from redis.multidb.database import Database, Databases from redis.multidb.failover import FailoverStrategy from redis.multidb.failure_detector import FailureDetector -from redis.multidb.healthcheck import HealthCheck +from redis.multidb.healthcheck import HealthCheck, DEFAULT_HEALTH_CHECK_PROBES, DEFAULT_HEALTH_CHECK_POLICY from tests.conftest import mock_ed @@ -80,18 +80,18 @@ def mock_db2(request) -> Database: def mock_multi_db_config( request, mock_fd, mock_fs, mock_hc, mock_ed ) -> MultiDbConfig: - hc_interval = request.param.get('hc_interval', None) - if hc_interval is None: - hc_interval = DEFAULT_HEALTH_CHECK_INTERVAL - - auto_fallback_interval = request.param.get('auto_fallback_interval', None) - if auto_fallback_interval is None: - auto_fallback_interval = DEFAULT_AUTO_FALLBACK_INTERVAL + hc_interval = request.param.get('hc_interval', DEFAULT_HEALTH_CHECK_INTERVAL) + auto_fallback_interval = request.param.get('auto_fallback_interval', DEFAULT_AUTO_FALLBACK_INTERVAL) + health_check_policy = request.param.get('health_check_policy', DEFAULT_HEALTH_CHECK_POLICY) + health_check_probes = request.param.get('health_check_probes', DEFAULT_HEALTH_CHECK_PROBES) config = MultiDbConfig( databases_config=[Mock(spec=DatabaseConfig)], failure_detectors=[mock_fd], health_check_interval=hc_interval, + health_check_delay=0.05, + health_check_policy=health_check_policy, + health_check_probes=health_check_probes, failover_strategy=mock_fs, auto_fallback_interval=auto_fallback_interval, event_dispatcher=mock_ed diff --git a/tests/test_multidb/test_client.py b/tests/test_multidb/test_client.py index d352c1da92..5e710f23c2 100644 --- a/tests/test_multidb/test_client.py +++ b/tests/test_multidb/test_client.py @@ -6,16 +6,12 @@ from redis.event import EventDispatcher, OnCommandsFailEvent from redis.multidb.circuit import State as CBState, PBCircuitBreakerAdapter -from redis.multidb.config import DEFAULT_FAILOVER_RETRIES, \ - DEFAULT_FAILOVER_BACKOFF from redis.multidb.database import SyncDatabase from redis.multidb.client import MultiDBClient from redis.multidb.exception import NoValidDatabaseException from redis.multidb.failover import WeightBasedFailoverStrategy from redis.multidb.failure_detector import FailureDetector -from redis.multidb.healthcheck import HealthCheck, EchoHealthCheck, DEFAULT_HEALTH_CHECK_RETRIES, \ - DEFAULT_HEALTH_CHECK_BACKOFF -from redis.retry import Retry +from redis.multidb.healthcheck import HealthCheck, EchoHealthCheck from tests.test_multidb.conftest import create_weighted_list @@ -46,7 +42,7 @@ def test_execute_command_against_correct_db_on_successful_initialization( client = MultiDBClient(mock_multi_db_config) assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 assert client.set('key', 'value') == 'OK1' - assert mock_hc.check_health.call_count == 3 + assert mock_hc.check_health.call_count == 9 assert mock_db.circuit.state == CBState.CLOSED assert mock_db1.circuit.state == CBState.CLOSED @@ -73,12 +69,12 @@ def 
test_execute_command_against_correct_db_and_closed_circuit( patch.object(mock_multi_db_config,'default_health_checks', return_value=[mock_hc]): mock_db1.client.execute_command.return_value = 'OK1' - mock_hc.check_health.side_effect = [False, True, True] + mock_hc.check_health.side_effect = [False, True, True, True, True, True, True] client = MultiDBClient(mock_multi_db_config) assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 assert client.set('key', 'value') == 'OK1' - assert mock_hc.check_health.call_count == 3 + assert mock_hc.check_health.call_count == 7 assert mock_db.circuit.state == CBState.CLOSED assert mock_db1.circuit.state == CBState.CLOSED @@ -88,7 +84,7 @@ def test_execute_command_against_correct_db_and_closed_circuit( 'mock_multi_db_config,mock_db, mock_db1, mock_db2', [ ( - {}, + {"health_check_probes" : 1}, {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, @@ -114,31 +110,28 @@ def test_execute_command_against_correct_db_on_background_health_check_determine databases = create_weighted_list(mock_db, mock_db1, mock_db2) with patch.object(mock_multi_db_config,'databases',return_value=databases), \ - patch.object(mock_multi_db_config,'default_health_checks', return_value=[EchoHealthCheck( - retry=Retry(retries=DEFAULT_HEALTH_CHECK_RETRIES, backoff=DEFAULT_HEALTH_CHECK_BACKOFF) - )]): + patch.object(mock_multi_db_config,'default_health_checks', return_value=[EchoHealthCheck()]): + mock_db.client.execute_command.side_effect = ['healthcheck', 'healthcheck', 'healthcheck', 'OK', 'error'] mock_db1.client.execute_command.side_effect = ['healthcheck', 'OK1', 'error', 'error', 'healthcheck', 'OK1'] mock_db2.client.execute_command.side_effect = ['healthcheck', 'healthcheck', 'OK2', 'error', 'error'] - mock_multi_db_config.health_check_interval = 0.1 - mock_multi_db_config.failover_strategy = WeightBasedFailoverStrategy( - retry=Retry(retries=DEFAULT_FAILOVER_RETRIES, backoff=DEFAULT_FAILOVER_BACKOFF) - ) + mock_multi_db_config.health_check_interval = 0.2 + mock_multi_db_config.failover_strategy = WeightBasedFailoverStrategy() client = MultiDBClient(mock_multi_db_config) assert client.set('key', 'value') == 'OK1' - sleep(0.15) + sleep(0.3) assert client.set('key', 'value') == 'OK2' - sleep(0.1) + sleep(0.2) assert client.set('key', 'value') == 'OK' - sleep(0.1) + sleep(0.2) assert client.set('key', 'value') == 'OK1' @pytest.mark.parametrize( 'mock_multi_db_config,mock_db, mock_db1, mock_db2', [ ( - {}, + {"health_check_probes" : 1}, {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, @@ -152,23 +145,19 @@ def test_execute_command_auto_fallback_to_highest_weight_db( databases = create_weighted_list(mock_db, mock_db1, mock_db2) with patch.object(mock_multi_db_config,'databases',return_value=databases), \ - patch.object(mock_multi_db_config,'default_health_checks', return_value=[EchoHealthCheck( - retry=Retry(retries=DEFAULT_HEALTH_CHECK_RETRIES, backoff=DEFAULT_HEALTH_CHECK_BACKOFF) - )]): + patch.object(mock_multi_db_config,'default_health_checks', return_value=[EchoHealthCheck()]): mock_db.client.execute_command.side_effect = ['healthcheck', 'healthcheck', 'healthcheck', 'healthcheck', 'healthcheck'] mock_db1.client.execute_command.side_effect = ['healthcheck', 'OK1', 'error', 'healthcheck', 'healthcheck', 'OK1'] 
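+        # Assumed mock choreography: with health_check_probes=1, each background
+        # check cycle consumes one 'healthcheck' echo per database, and the
+        # remaining 'OK1'/'OK2'/'error' entries are consumed by the commands
+        # issued between cycles.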
mock_db2.client.execute_command.side_effect = ['healthcheck', 'healthcheck', 'OK2', 'healthcheck', 'healthcheck', 'healthcheck'] - mock_multi_db_config.health_check_interval = 0.1 - mock_multi_db_config.auto_fallback_interval = 0.2 - mock_multi_db_config.failover_strategy = WeightBasedFailoverStrategy( - retry=Retry(retries=DEFAULT_FAILOVER_RETRIES, backoff=DEFAULT_FAILOVER_BACKOFF) - ) + mock_multi_db_config.health_check_interval = 0.2 + mock_multi_db_config.auto_fallback_interval = 0.4 + mock_multi_db_config.failover_strategy = WeightBasedFailoverStrategy() client = MultiDBClient(mock_multi_db_config) assert client.set('key', 'value') == 'OK1' - sleep(0.15) + sleep(0.30) assert client.set('key', 'value') == 'OK2' - sleep(0.22) + sleep(0.44) assert client.set('key', 'value') == 'OK1' @pytest.mark.parametrize( @@ -256,10 +245,10 @@ def test_add_database_makes_new_database_active( assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 assert client.set('key', 'value') == 'OK2' - assert mock_hc.check_health.call_count == 2 + assert mock_hc.check_health.call_count == 6 client.add_database(mock_db1) - assert mock_hc.check_health.call_count == 3 + assert mock_hc.check_health.call_count == 9 assert client.set('key', 'value') == 'OK1' @@ -291,7 +280,7 @@ def test_remove_highest_weighted_database( assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 assert client.set('key', 'value') == 'OK1' - assert mock_hc.check_health.call_count == 3 + assert mock_hc.check_health.call_count == 9 client.remove_database(mock_db1) @@ -325,7 +314,7 @@ def test_update_database_weight_to_be_highest( assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 assert client.set('key', 'value') == 'OK1' - assert mock_hc.check_health.call_count == 3 + assert mock_hc.check_health.call_count == 9 client.update_database_weight(mock_db2, 0.8) assert mock_db2.weight == 0.8 @@ -366,7 +355,7 @@ def test_add_new_failure_detector( client = MultiDBClient(mock_multi_db_config) assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 assert client.set('key', 'value') == 'OK1' - assert mock_hc.check_health.call_count == 3 + assert mock_hc.check_health.call_count == 9 # Simulate failing command events that lead to a failure detection for i in range(5): @@ -410,7 +399,7 @@ def test_add_new_health_check( client = MultiDBClient(mock_multi_db_config) assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 assert client.set('key', 'value') == 'OK1' - assert mock_hc.check_health.call_count == 3 + assert mock_hc.check_health.call_count == 9 another_hc = Mock(spec=HealthCheck) another_hc.check_health.return_value = True @@ -418,8 +407,8 @@ def test_add_new_health_check( client.add_health_check(another_hc) client._check_db_health(mock_db1) - assert mock_hc.check_health.call_count == 4 - assert another_hc.check_health.call_count == 1 + assert mock_hc.check_health.call_count == 12 + assert another_hc.check_health.call_count == 3 @pytest.mark.parametrize( 'mock_multi_db_config,mock_db, mock_db1, mock_db2', @@ -448,7 +437,7 @@ def test_set_active_database( client = MultiDBClient(mock_multi_db_config) assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 assert client.set('key', 'value') == 'OK1' - assert mock_hc.check_health.call_count == 3 + assert mock_hc.check_health.call_count == 9 client.set_active_database(mock_db) assert client.set('key', 'value') == 'OK' diff --git a/tests/test_multidb/test_command_executor.py 
b/tests/test_multidb/test_command_executor.py index 675f9d442f..044fef0f8c 100644 --- a/tests/test_multidb/test_command_executor.py +++ b/tests/test_multidb/test_command_executor.py @@ -61,8 +61,7 @@ def test_execute_command_automatically_select_active_database( ): mock_db1.client.execute_command.return_value = 'OK1' mock_db2.client.execute_command.return_value = 'OK2' - mock_selector = PropertyMock(side_effect=[mock_db1, mock_db2]) - type(mock_fs).database = mock_selector + mock_fs.database.side_effect = [mock_db1, mock_db2] databases = create_weighted_list(mock_db, mock_db1, mock_db2) executor = DefaultCommandExecutor( @@ -78,7 +77,7 @@ def test_execute_command_automatically_select_active_database( assert executor.execute_command('SET', 'key', 'value') == 'OK2' assert mock_ed.register_listeners.call_count == 1 - assert mock_selector.call_count == 2 + assert mock_fs.database.call_count == 2 @pytest.mark.parametrize( 'mock_db,mock_db1,mock_db2', @@ -96,8 +95,7 @@ def test_execute_command_fallback_to_another_db_after_fallback_interval( ): mock_db1.client.execute_command.return_value = 'OK1' mock_db2.client.execute_command.return_value = 'OK2' - mock_selector = PropertyMock(side_effect=[mock_db1, mock_db2, mock_db1]) - type(mock_fs).database = mock_selector + mock_fs.database.side_effect = [mock_db1, mock_db2, mock_db1] databases = create_weighted_list(mock_db, mock_db1, mock_db2) executor = DefaultCommandExecutor( @@ -119,7 +117,7 @@ def test_execute_command_fallback_to_another_db_after_fallback_interval( assert executor.execute_command('SET', 'key', 'value') == 'OK1' assert mock_ed.register_listeners.call_count == 1 - assert mock_selector.call_count == 3 + assert mock_fs.database.call_count == 3 @pytest.mark.parametrize( 'mock_db,mock_db1,mock_db2', @@ -137,8 +135,7 @@ def test_execute_command_fallback_to_another_db_after_failure_detection( ): mock_db1.client.execute_command.side_effect = ['OK1', ConnectionError, ConnectionError, ConnectionError, 'OK1'] mock_db2.client.execute_command.side_effect = ['OK2', ConnectionError, ConnectionError, ConnectionError] - mock_selector = PropertyMock(side_effect=[mock_db1, mock_db2, mock_db1]) - type(mock_fs).database = mock_selector + mock_fs.database.side_effect = [mock_db1, mock_db2, mock_db1] threshold = 3 fd = CommandFailureDetector(threshold, 1) ed = EventDispatcher() @@ -157,4 +154,4 @@ def test_execute_command_fallback_to_another_db_after_failure_detection( assert executor.execute_command('SET', 'key', 'value') == 'OK1' assert executor.execute_command('SET', 'key', 'value') == 'OK2' assert executor.execute_command('SET', 'key', 'value') == 'OK1' - assert mock_selector.call_count == 3 \ No newline at end of file + assert mock_fs.database.call_count == 3 \ No newline at end of file diff --git a/tests/test_multidb/test_config.py b/tests/test_multidb/test_config.py index 1ea63a0e14..abed8ec2fa 100644 --- a/tests/test_multidb/test_config.py +++ b/tests/test_multidb/test_config.py @@ -1,8 +1,8 @@ from unittest.mock import Mock from redis.connection import ConnectionPool -from redis.multidb.circuit import PBCircuitBreakerAdapter, CircuitBreaker +from redis.multidb.circuit import PBCircuitBreakerAdapter, CircuitBreaker, DEFAULT_GRACE_PERIOD from redis.multidb.config import MultiDbConfig, DEFAULT_HEALTH_CHECK_INTERVAL, \ - DEFAULT_AUTO_FALLBACK_INTERVAL, DatabaseConfig, DEFAULT_GRACE_PERIOD + DEFAULT_AUTO_FALLBACK_INTERVAL, DatabaseConfig from redis.multidb.database import Database from redis.multidb.failure_detector import CommandFailureDetector, 
FailureDetector from redis.multidb.healthcheck import EchoHealthCheck, HealthCheck diff --git a/tests/test_multidb/test_failover.py b/tests/test_multidb/test_failover.py index 06390c4e2e..6ae6a9610c 100644 --- a/tests/test_multidb/test_failover.py +++ b/tests/test_multidb/test_failover.py @@ -1,13 +1,11 @@ -from unittest.mock import PropertyMock +from time import sleep import pytest -from redis.backoff import NoBackoff, ExponentialBackoff from redis.data_structure import WeightedList from redis.multidb.circuit import State as CBState -from redis.multidb.exception import NoValidDatabaseException -from redis.multidb.failover import WeightBasedFailoverStrategy -from redis.retry import Retry +from redis.multidb.exception import NoValidDatabaseException, TemporaryUnavailableException +from redis.multidb.failover import WeightBasedFailoverStrategy, DefaultFailoverStrategyExecutor class TestWeightBasedFailoverStrategy: @@ -29,16 +27,15 @@ class TestWeightBasedFailoverStrategy: indirect=True, ) def test_get_valid_database(self, mock_db, mock_db1, mock_db2): - retry = Retry(NoBackoff(), 0) databases = WeightedList() databases.add(mock_db, mock_db.weight) databases.add(mock_db1, mock_db1.weight) databases.add(mock_db2, mock_db2.weight) - failover_strategy = WeightBasedFailoverStrategy(retry=retry) + failover_strategy = WeightBasedFailoverStrategy() failover_strategy.set_databases(databases) - assert failover_strategy.database == mock_db1 + assert failover_strategy.database() == mock_db1 @pytest.mark.parametrize( 'mock_db,mock_db1,mock_db2', @@ -51,67 +48,103 @@ def test_get_valid_database(self, mock_db, mock_db1, mock_db2): ], indirect=True, ) - def test_get_valid_database_with_retries(self, mock_db, mock_db1, mock_db2): - state_mock = PropertyMock( - side_effect=[CBState.OPEN, CBState.OPEN, CBState.OPEN, CBState.CLOSED] - ) - type(mock_db.circuit).state = state_mock - - retry = Retry(ExponentialBackoff(cap=1), 3) - databases = WeightedList() - databases.add(mock_db, mock_db.weight) - databases.add(mock_db1, mock_db1.weight) - databases.add(mock_db2, mock_db2.weight) - failover_strategy = WeightBasedFailoverStrategy(retry=retry) - failover_strategy.set_databases(databases) + def test_throws_exception_on_empty_databases(self, mock_db, mock_db1, mock_db2): + failover_strategy = WeightBasedFailoverStrategy() - assert failover_strategy.database == mock_db - assert state_mock.call_count == 4 + with pytest.raises(NoValidDatabaseException, match='No valid database available for communication'): + assert failover_strategy.database() +class TestDefaultStrategyExecutor: @pytest.mark.parametrize( - 'mock_db,mock_db1,mock_db2', + 'mock_db', [ - ( - {'weight': 0.2, 'circuit': {'state': CBState.OPEN}}, - {'weight': 0.7, 'circuit': {'state': CBState.OPEN}}, - {'weight': 0.5, 'circuit': {'state': CBState.OPEN}}, - ), + {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, ], indirect=True, ) - def test_get_valid_database_throws_exception_with_retries(self, mock_db, mock_db1, mock_db2): - state_mock = PropertyMock( - side_effect=[CBState.OPEN, CBState.OPEN, CBState.OPEN, CBState.OPEN] + def test_execute_returns_valid_database_with_failover_attempts(self, mock_db, mock_fs): + failover_attempts = 3 + mock_fs.database.side_effect = [ + NoValidDatabaseException, + NoValidDatabaseException, + NoValidDatabaseException, + mock_db + ] + executor = DefaultFailoverStrategyExecutor( + mock_fs, + failover_attempts=failover_attempts, + failover_delay=0.1 ) - type(mock_db.circuit).state = state_mock - retry = 
Retry(ExponentialBackoff(cap=1), 3) - databases = WeightedList() - databases.add(mock_db, mock_db.weight) - databases.add(mock_db1, mock_db1.weight) - databases.add(mock_db2, mock_db2.weight) - failover_strategy = WeightBasedFailoverStrategy(retry=retry) - failover_strategy.set_databases(databases) + for i in range(failover_attempts + 1): + try: + database = executor.execute() + assert database == mock_db + except TemporaryUnavailableException as e: + assert e.args[0] == ( + "No database connections currently available. " + "This is a temporary condition - please retry the operation." + ) + sleep(0.11) + pass - with pytest.raises(NoValidDatabaseException, match='No valid database available for communication'): - assert failover_strategy.database + assert mock_fs.database.call_count == 4 - assert state_mock.call_count == 4 + def test_execute_throws_exception_on_attempts_exceed(self, mock_fs): + failover_attempts = 3 + mock_fs.database.side_effect = [ + NoValidDatabaseException, + NoValidDatabaseException, + NoValidDatabaseException, + NoValidDatabaseException + ] + executor = DefaultFailoverStrategyExecutor( + mock_fs, + failover_attempts=failover_attempts, + failover_delay=0.1 + ) - @pytest.mark.parametrize( - 'mock_db,mock_db1,mock_db2', - [ - ( - {'weight': 0.2, 'circuit': {'state': CBState.OPEN}}, - {'weight': 0.7, 'circuit': {'state': CBState.OPEN}}, - {'weight': 0.5, 'circuit': {'state': CBState.OPEN}}, - ), - ], - indirect=True, - ) - def test_throws_exception_on_empty_databases(self, mock_db, mock_db1, mock_db2): - retry = Retry(NoBackoff(), 0) - failover_strategy = WeightBasedFailoverStrategy(retry=retry) + with pytest.raises(NoValidDatabaseException): + for i in range(failover_attempts + 1): + try: + executor.execute() + except TemporaryUnavailableException as e: + assert e.args[0] == ( + "No database connections currently available. " + "This is a temporary condition - please retry the operation." + ) + sleep(0.11) + pass - with pytest.raises(NoValidDatabaseException, match='No valid database available for communication'): - assert failover_strategy.database \ No newline at end of file + assert mock_fs.database.call_count == 4 + + def test_execute_throws_exception_on_attempts_does_not_exceed_delay(self, mock_fs): + failover_attempts = 3 + mock_fs.database.side_effect = [ + NoValidDatabaseException, + NoValidDatabaseException, + NoValidDatabaseException, + NoValidDatabaseException + ] + executor = DefaultFailoverStrategyExecutor( + mock_fs, + failover_attempts=failover_attempts, + failover_delay=0.1 + ) + + with pytest.raises(TemporaryUnavailableException, match=( + "No database connections currently available. " + "This is a temporary condition - please retry the operation." + )): + for i in range(failover_attempts + 1): + try: + executor.execute() + except TemporaryUnavailableException as e: + assert e.args[0] == ( + "No database connections currently available. " + "This is a temporary condition - please retry the operation." 
+ ) + if i == failover_attempts: + raise e + + assert mock_fs.database.call_count == 4 \ No newline at end of file diff --git a/tests/test_multidb/test_healthcheck.py b/tests/test_multidb/test_healthcheck.py index 77886832e7..43ad1ac888 100644 --- a/tests/test_multidb/test_healthcheck.py +++ b/tests/test_multidb/test_healthcheck.py @@ -1,15 +1,171 @@ -from unittest.mock import MagicMock +from unittest.mock import MagicMock, Mock import pytest -from redis.backoff import ExponentialBackoff from redis.multidb.database import Database from redis.http.http_client import HttpError -from redis.multidb.healthcheck import EchoHealthCheck, LagAwareHealthCheck +from redis.multidb.healthcheck import EchoHealthCheck, LagAwareHealthCheck, HealthCheck, HealthyAllPolicy, \ + UnhealthyDatabaseException, HealthyMajorityPolicy, HealthyAnyPolicy from redis.multidb.circuit import State as CBState -from redis.exceptions import ConnectionError -from redis.retry import Retry +class TestHealthyAllPolicy: + def test_policy_returns_true_for_all_successful_probes(self): + mock_hc1 = Mock(spec=HealthCheck) + mock_hc2 = Mock(spec=HealthCheck) + mock_hc1.check_health.return_value = True + mock_hc2.check_health.return_value = True + mock_db = Mock(spec=Database) + + policy = HealthyAllPolicy(3, 0.01) + assert policy.execute([mock_hc1, mock_hc2], mock_db) == True + assert mock_hc1.check_health.call_count == 3 + assert mock_hc2.check_health.call_count == 3 + + def test_policy_returns_false_on_first_failed_probe(self): + mock_hc1 = Mock(spec=HealthCheck) + mock_hc2 = Mock(spec=HealthCheck) + mock_hc1.check_health.side_effect = [True, True, False] + mock_hc2.check_health.return_value = True + mock_db = Mock(spec=Database) + + policy = HealthyAllPolicy(3, 0.01) + assert policy.execute([mock_hc1, mock_hc2], mock_db) == False + assert mock_hc1.check_health.call_count == 3 + assert mock_hc2.check_health.call_count == 0 + + def test_policy_raise_unhealthy_database_exception(self): + mock_hc1 = Mock(spec=HealthCheck) + mock_hc2 = Mock(spec=HealthCheck) + mock_hc1.check_health.side_effect = [True, True, ConnectionError] + mock_hc2.check_health.return_value = True + mock_db = Mock(spec=Database) + + policy = HealthyAllPolicy(3, 0.01) + with pytest.raises(UnhealthyDatabaseException, match='Unhealthy database'): + policy.execute([mock_hc1, mock_hc2], mock_db) + assert mock_hc1.check_health.call_count == 3 + assert mock_hc2.check_health.call_count == 0 + +class TestHealthyMajorityPolicy: + @pytest.mark.parametrize( + "probes,hc1_side_effect,hc2_side_effect,hc1_call_count,hc2_call_count,expected_result", + [ + (3, [True, False, False], [True, True, True], 3, 0, False), + (3, [True, True, True], [True, False, False], 3, 3, False), + (3, [True, False, True], [True, True, True], 3, 3, True), + (3, [True, True, True], [True, False, True], 3, 3, True), + (3, [True, True, False], [True, False, True], 3, 3, True), + (4, [True, True, False, False], [True, True, True, True], 4, 0, False), + (4, [True, True, True, True], [True, True, False, False], 4, 4, False), + (4, [False, True, True, True], [True, True, True, True], 4, 4, True), + (4, [True, True, True, True], [True, False, True, True], 4, 4, True), + (4, [False, True, True, True], [True, True, False, True], 4, 4, True), + ], + ids=[ + 'HC1 - no majority - odd', 'HC2 - no majority - odd', 'HC1 - majority - odd', + 'HC2 - majority - odd', 'HC1 + HC2 - majority - odd', 'HC1 - no majority - even', + 'HC2 - no majority - even', 'HC1 - majority - even', 'HC2 - majority - even', + 'HC1 + HC2 - 
majority - even' + ] + ) + def test_policy_returns_true_for_majority_successful_probes( + self, + probes, + hc1_side_effect, + hc2_side_effect, + hc1_call_count, + hc2_call_count, + expected_result + ): + mock_hc1 = Mock(spec=HealthCheck) + mock_hc2 = Mock(spec=HealthCheck) + mock_hc1.check_health.side_effect = hc1_side_effect + mock_hc2.check_health.side_effect = hc2_side_effect + mock_db = Mock(spec=Database) + + policy = HealthyMajorityPolicy(probes, 0.01) + assert policy.execute([mock_hc1, mock_hc2], mock_db) == expected_result + assert mock_hc1.check_health.call_count == hc1_call_count + assert mock_hc2.check_health.call_count == hc2_call_count + + @pytest.mark.parametrize( + "probes,hc1_side_effect,hc2_side_effect,hc1_call_count,hc2_call_count", + [ + (3, [True, ConnectionError, ConnectionError], [True, True, True], 3, 0), + (3, [True, True, True], [True, ConnectionError, ConnectionError], 3, 3), + (4, [True, ConnectionError, ConnectionError, True], [True, True, True, True], 3, 0), + (4, [True, True, True, True], [True, ConnectionError, ConnectionError, False], 4, 3), + ], + ids=[ + 'HC1 - majority - odd', 'HC2 - majority - odd', + 'HC1 - majority - even', 'HC2 - majority - even', + ] + ) + def test_policy_raise_unhealthy_database_exception_on_majority_probes_exceptions( + self, + probes, + hc1_side_effect, + hc2_side_effect, + hc1_call_count, + hc2_call_count + ): + mock_hc1 = Mock(spec=HealthCheck) + mock_hc2 = Mock(spec=HealthCheck) + mock_hc1.check_health.side_effect = hc1_side_effect + mock_hc2.check_health.side_effect = hc2_side_effect + mock_db = Mock(spec=Database) + + policy = HealthyMajorityPolicy(probes, 0.01) + with pytest.raises(UnhealthyDatabaseException, match='Unhealthy database'): + policy.execute([mock_hc1, mock_hc2], mock_db) + assert mock_hc1.check_health.call_count == hc1_call_count + assert mock_hc2.check_health.call_count == hc2_call_count + +class TestHealthyAnyPolicy: + @pytest.mark.parametrize( + "hc1_side_effect,hc2_side_effect,hc1_call_count,hc2_call_count,expected_result", + [ + ([False, False, False], [True, True, True], 3, 0, False), + ([False, False, True], [False, False, False], 3, 3, False), + ([False, True, True], [False, False, True], 2, 3, True), + ([True, True, True], [False, True, False], 1, 2, True), + ], + ids=[ + 'HC1 - no successful', 'HC2 - no successful', + 'HC1 - successful', 'HC2 - successful', + ] + ) + def test_policy_returns_true_for_any_successful_probe( + self, + hc1_side_effect, + hc2_side_effect, + hc1_call_count, + hc2_call_count, + expected_result + ): + mock_hc1 = Mock(spec=HealthCheck) + mock_hc2 = Mock(spec=HealthCheck) + mock_hc1.check_health.side_effect = hc1_side_effect + mock_hc2.check_health.side_effect = hc2_side_effect + mock_db = Mock(spec=Database) + + policy = HealthyAnyPolicy(3, 0.01) + assert policy.execute([mock_hc1, mock_hc2], mock_db) == expected_result + assert mock_hc1.check_health.call_count == hc1_call_count + assert mock_hc2.check_health.call_count == hc2_call_count + + def test_policy_raise_unhealthy_database_exception_if_exception_occurs_on_failed_health_check(self): + mock_hc1 = Mock(spec=HealthCheck) + mock_hc2 = Mock(spec=HealthCheck) + mock_hc1.check_health.side_effect = [False, False, ConnectionError] + mock_hc2.check_health.side_effect = [True, True, True] + mock_db = Mock(spec=Database) + + policy = HealthyAnyPolicy(3, 0.01) + with pytest.raises(UnhealthyDatabaseException, match='Unhealthy database'): + policy.execute([mock_hc1, mock_hc2], mock_db) + assert mock_hc1.check_health.call_count == 3 + assert
mock_hc2.check_health.call_count == 0 class TestEchoHealthCheck: def test_database_is_healthy_on_echo_response(self, mock_client, mock_cb): @@ -17,33 +173,33 @@ def test_database_is_healthy_on_echo_response(self, mock_client, mock_cb): Mocking responses to mix error and actual responses to ensure that health check retry according to given configuration. """ - mock_client.execute_command.side_effect = [ConnectionError, ConnectionError, 'healthcheck'] - hc = EchoHealthCheck(Retry(backoff=ExponentialBackoff(cap=1.0), retries=3)) + mock_client.execute_command.return_value = 'healthcheck' + hc = EchoHealthCheck() db = Database(mock_client, mock_cb, 0.9) assert hc.check_health(db) == True - assert mock_client.execute_command.call_count == 3 + assert mock_client.execute_command.call_count == 1 def test_database_is_unhealthy_on_incorrect_echo_response(self, mock_client, mock_cb): """ Mocking responses to mix error and actual responses to ensure that health check retry according to given configuration. """ - mock_client.execute_command.side_effect = [ConnectionError, ConnectionError, 'wrong'] - hc = EchoHealthCheck(Retry(backoff=ExponentialBackoff(cap=1.0), retries=3)) + mock_client.execute_command.return_value = 'wrong' + hc = EchoHealthCheck() db = Database(mock_client, mock_cb, 0.9) assert hc.check_health(db) == False - assert mock_client.execute_command.call_count == 3 + assert mock_client.execute_command.call_count == 1 def test_database_close_circuit_on_successful_healthcheck(self, mock_client, mock_cb): - mock_client.execute_command.side_effect = [ConnectionError, ConnectionError, 'healthcheck'] + mock_client.execute_command.return_value = 'healthcheck' mock_cb.state = CBState.HALF_OPEN - hc = EchoHealthCheck(Retry(backoff=ExponentialBackoff(cap=1.0), retries=3)) + hc = EchoHealthCheck() db = Database(mock_client, mock_cb, 0.9) assert hc.check_health(db) == True - assert mock_client.execute_command.call_count == 3 + assert mock_client.execute_command.call_count == 1 class TestLagAwareHealthCheck: @@ -72,7 +228,6 @@ def test_database_is_healthy_when_bdb_matches_by_dns_name(self, mock_client, moc ] hc = LagAwareHealthCheck( - retry=Retry(backoff=ExponentialBackoff(cap=1.0), retries=3), rest_api_port=1234, lag_aware_tolerance=150 ) # Inject our mocked http client @@ -111,16 +266,14 @@ def test_database_is_healthy_when_bdb_matches_by_addr(self, mock_client, mock_cb None, ] - hc = LagAwareHealthCheck( - retry=Retry(backoff=ExponentialBackoff(cap=1.0), retries=3), - ) + hc = LagAwareHealthCheck() hc._http_client = mock_http db = Database(mock_client, mock_cb, 1.0, "https://healthcheck.example.com") assert hc.check_health(db) is True assert mock_http.get.call_count == 2 - assert mock_http.get.call_args_list[1].args[0] == "/v1/bdbs/bdb-42/availability?extend_check=lag&availability_lag_tolerance_ms=100" + assert mock_http.get.call_args_list[1].args[0] == "/v1/bdbs/bdb-42/availability?extend_check=lag&availability_lag_tolerance_ms=5000" def test_raises_value_error_when_no_matching_bdb(self, mock_client, mock_cb): """ @@ -136,9 +289,7 @@ def test_raises_value_error_when_no_matching_bdb(self, mock_client, mock_cb): {"uid": "b", "endpoints": [{"dns_name": "another.example.com", "addr": ["10.0.0.10"]}]}, ] - hc = LagAwareHealthCheck( - retry=Retry(backoff=ExponentialBackoff(cap=1.0), retries=3), - ) + hc = LagAwareHealthCheck() hc._http_client = mock_http db = Database(mock_client, mock_cb, 1.0, "https://healthcheck.example.com") @@ -164,9 +315,7 @@ def test_propagates_http_error_from_availability(self, 
mock_client, mock_cb): HttpError(url=f"https://{host}:9443/v1/bdbs/bdb-err/availability", status=503, message="busy"), ] - hc = LagAwareHealthCheck( - retry=Retry(backoff=ExponentialBackoff(cap=1.0), retries=3), - ) + hc = LagAwareHealthCheck() hc._http_client = mock_http db = Database(mock_client, mock_cb, 1.0, "https://healthcheck.example.com") diff --git a/tests/test_multidb/test_pipeline.py b/tests/test_multidb/test_pipeline.py index 6e7c344d85..54f6a4df17 100644 --- a/tests/test_multidb/test_pipeline.py +++ b/tests/test_multidb/test_pipeline.py @@ -7,11 +7,8 @@ from redis.client import Pipeline from redis.multidb.circuit import State as CBState, PBCircuitBreakerAdapter from redis.multidb.client import MultiDBClient -from redis.multidb.config import DEFAULT_FAILOVER_RETRIES, \ - DEFAULT_FAILOVER_BACKOFF -from redis.multidb.failover import WeightBasedFailoverStrategy -from redis.multidb.healthcheck import EchoHealthCheck, DEFAULT_HEALTH_CHECK_RETRIES, DEFAULT_HEALTH_CHECK_BACKOFF -from redis.retry import Retry +from redis.multidb.failover import WeightBasedFailoverStrategy, DEFAULT_FAILOVER_ATTEMPTS, DEFAULT_FAILOVER_DELAY +from redis.multidb.healthcheck import EchoHealthCheck from tests.test_multidb.conftest import create_weighted_list def mock_pipe() -> Pipeline: @@ -54,7 +51,7 @@ def test_executes_pipeline_against_correct_db( pipe.get('key1') assert pipe.execute() == ['OK1', 'value1'] - assert mock_hc.check_health.call_count == 3 + assert mock_hc.check_health.call_count == 9 @pytest.mark.parametrize( 'mock_multi_db_config,mock_db, mock_db1, mock_db2', @@ -79,7 +76,7 @@ def test_execute_pipeline_against_correct_db_and_closed_circuit( pipe.execute.return_value = ['OK1', 'value1'] mock_db1.client.pipeline.return_value = pipe - mock_hc.check_health.side_effect = [False, True, True] + mock_hc.check_health.side_effect = [False, True, True, True, True, True, True] client = MultiDBClient(mock_multi_db_config) assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 @@ -89,7 +86,7 @@ def test_execute_pipeline_against_correct_db_and_closed_circuit( pipe.get('key1') assert pipe.execute() == ['OK1', 'value1'] - assert mock_hc.check_health.call_count == 3 + assert mock_hc.check_health.call_count == 7 assert mock_db.circuit.state == CBState.CLOSED assert mock_db1.circuit.state == CBState.CLOSED @@ -99,7 +96,7 @@ def test_execute_pipeline_against_correct_db_and_closed_circuit( 'mock_multi_db_config,mock_db, mock_db1, mock_db2', [ ( - {}, + {"health_check_probes" : 1}, {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, @@ -125,9 +122,8 @@ def test_execute_pipeline_against_correct_db_on_background_health_check_determin databases = create_weighted_list(mock_db, mock_db1, mock_db2) with patch.object(mock_multi_db_config,'databases',return_value=databases), \ - patch.object(mock_multi_db_config,'default_health_checks', return_value=[EchoHealthCheck( - retry=Retry(retries=DEFAULT_HEALTH_CHECK_RETRIES, backoff=DEFAULT_HEALTH_CHECK_BACKOFF) - )]): + patch.object(mock_multi_db_config,'default_health_checks', return_value=[EchoHealthCheck()]): + mock_db.client.execute_command.side_effect = ['healthcheck', 'healthcheck', 'healthcheck', 'error'] mock_db1.client.execute_command.side_effect = ['healthcheck', 'error', 'error', 'healthcheck'] mock_db2.client.execute_command.side_effect = ['healthcheck', 'healthcheck', 'error', 'error'] @@ -145,9 +141,7 @@ def 
test_execute_pipeline_against_correct_db_on_background_health_check_determin mock_db2.client.pipeline.return_value = pipe2 mock_multi_db_config.health_check_interval = 0.1 - mock_multi_db_config.failover_strategy = WeightBasedFailoverStrategy( - retry=Retry(retries=DEFAULT_FAILOVER_RETRIES, backoff=DEFAULT_FAILOVER_BACKOFF) - ) + mock_multi_db_config.failover_strategy = WeightBasedFailoverStrategy() client = MultiDBClient(mock_multi_db_config) @@ -214,7 +208,7 @@ def callback(pipe: Pipeline): pipe.get('key1') assert client.transaction(callback) == ['OK1', 'value1'] - assert mock_hc.check_health.call_count == 3 + assert mock_hc.check_health.call_count == 9 @pytest.mark.parametrize( 'mock_multi_db_config,mock_db, mock_db1, mock_db2', @@ -237,7 +231,7 @@ def test_execute_transaction_against_correct_db_and_closed_circuit( patch.object(mock_multi_db_config,'default_health_checks', return_value=[mock_hc]): mock_db1.client.transaction.return_value = ['OK1', 'value1'] - mock_hc.check_health.side_effect = [False, True, True] + mock_hc.check_health.side_effect = [False, True, True, True, True, True, True] client = MultiDBClient(mock_multi_db_config) assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 @@ -247,7 +241,7 @@ def callback(pipe: Pipeline): pipe.get('key1') assert client.transaction(callback) == ['OK1', 'value1'] - assert mock_hc.check_health.call_count == 3 + assert mock_hc.check_health.call_count == 7 assert mock_db.circuit.state == CBState.CLOSED assert mock_db1.circuit.state == CBState.CLOSED @@ -257,7 +251,7 @@ def callback(pipe: Pipeline): 'mock_multi_db_config,mock_db, mock_db1, mock_db2', [ ( - {}, + {"health_check_probes" : 1}, {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, @@ -283,9 +277,7 @@ def test_execute_transaction_against_correct_db_on_background_health_check_deter databases = create_weighted_list(mock_db, mock_db1, mock_db2) with patch.object(mock_multi_db_config,'databases',return_value=databases), \ - patch.object(mock_multi_db_config,'default_health_checks', return_value=[EchoHealthCheck( - retry=Retry(retries=DEFAULT_HEALTH_CHECK_RETRIES, backoff=DEFAULT_HEALTH_CHECK_BACKOFF) - )]): + patch.object(mock_multi_db_config,'default_health_checks', return_value=[EchoHealthCheck()]): mock_db.client.execute_command.side_effect = ['healthcheck', 'healthcheck', 'healthcheck', 'error'] mock_db1.client.execute_command.side_effect = ['healthcheck', 'error', 'error', 'healthcheck'] mock_db2.client.execute_command.side_effect = ['healthcheck', 'healthcheck', 'error', 'error'] @@ -295,9 +287,7 @@ def test_execute_transaction_against_correct_db_on_background_health_check_deter mock_db2.client.transaction.return_value = ['OK2', 'value'] mock_multi_db_config.health_check_interval = 0.1 - mock_multi_db_config.failover_strategy = WeightBasedFailoverStrategy( - retry=Retry(retries=DEFAULT_FAILOVER_RETRIES, backoff=DEFAULT_FAILOVER_BACKOFF) - ) + mock_multi_db_config.failover_strategy = WeightBasedFailoverStrategy() client = MultiDBClient(mock_multi_db_config) diff --git a/tests/test_scenario/conftest.py b/tests/test_scenario/conftest.py index a0f19e1a87..d9c1c17aff 100644 --- a/tests/test_scenario/conftest.py +++ b/tests/test_scenario/conftest.py @@ -9,10 +9,10 @@ from redis.backoff import NoBackoff, ExponentialBackoff from redis.event import EventDispatcher, EventListenerInterface from redis.multidb.client import MultiDBClient -from 
redis.multidb.config import DatabaseConfig, MultiDbConfig, DEFAULT_HEALTH_CHECK_INTERVAL, \ - DEFAULT_FAILURES_THRESHOLD +from redis.multidb.config import DatabaseConfig, MultiDbConfig, DEFAULT_HEALTH_CHECK_INTERVAL from redis.multidb.event import ActiveDatabaseChanged -from redis.multidb.healthcheck import EchoHealthCheck +from redis.multidb.failure_detector import DEFAULT_FAILURES_THRESHOLD +from redis.multidb.healthcheck import EchoHealthCheck, DEFAULT_HEALTH_CHECK_DELAY from redis.retry import Retry from tests.test_scenario.fault_injector_client import FaultInjectorClient @@ -61,6 +61,7 @@ def r_multi_db(request) -> tuple[MultiDBClient, CheckActiveDatabaseChangedListen # Retry configuration different for health checks as initial health check require more time in case # if infrastructure wasn't restored from the previous test. health_check_interval = request.param.get('health_check_interval', DEFAULT_HEALTH_CHECK_INTERVAL) + health_check_delay = request.param.get('health_check_delay', DEFAULT_HEALTH_CHECK_DELAY) event_dispatcher = EventDispatcher() listener = CheckActiveDatabaseChangedListener() event_dispatcher.register_listeners({ @@ -97,10 +98,10 @@ def r_multi_db(request) -> tuple[MultiDBClient, CheckActiveDatabaseChangedListen databases_config=db_configs, command_retry=command_retry, failure_threshold=failure_threshold, - health_check_retries=3, + health_check_probes=3, health_check_interval=health_check_interval, event_dispatcher=event_dispatcher, - health_check_backoff=ExponentialBackoff(cap=5, base=0.5), + health_check_delay=health_check_delay, ) return MultiDBClient(config), listener, endpoint_config diff --git a/tests/test_scenario/test_active_active.py b/tests/test_scenario/test_active_active.py index c87ad903b1..a3056323a5 100644 --- a/tests/test_scenario/test_active_active.py +++ b/tests/test_scenario/test_active_active.py @@ -7,8 +7,13 @@ import pytest from redis import Redis, RedisCluster +from redis.backoff import ConstantBackoff from redis.client import Pipeline +from redis.multidb.exception import TemporaryUnavailableException +from redis.multidb.failover import DEFAULT_FAILOVER_ATTEMPTS, DEFAULT_FAILOVER_DELAY from redis.multidb.healthcheck import LagAwareHealthCheck +from redis.retry import Retry +from redis.utils import dummy_fail from tests.test_scenario.fault_injector_client import ActionRequest, ActionType logger = logging.getLogger(__name__) @@ -36,7 +41,7 @@ class TestActiveActive: def teardown_method(self, method): # Timeout so the cluster could recover from network failure. - sleep(10) + sleep(15) @pytest.mark.parametrize( "r_multi_db", @@ -47,10 +52,17 @@ def teardown_method(self, method): ids=["standalone", "cluster"], indirect=True ) - @pytest.mark.timeout(50) + @pytest.mark.timeout(100) def test_multi_db_client_failover_to_another_db(self, r_multi_db, fault_injector_client): r_multi_db, listener, config = r_multi_db + # Handle unavailable databases from previous test. + retry = Retry( + supported_errors=(TemporaryUnavailableException,), + retries=DEFAULT_FAILOVER_ATTEMPTS, + backoff=ConstantBackoff(backoff=DEFAULT_FAILOVER_DELAY) + ) + event = threading.Event() thread = threading.Thread( target=trigger_network_failure_action, @@ -59,31 +71,45 @@ def test_multi_db_client_failover_to_another_db(self, r_multi_db, fault_injector ) # Client initialized on the first command. 
- r_multi_db.set('key', 'value') + retry.call_with_retry( + lambda : r_multi_db.set('key', 'value'), + lambda _ : dummy_fail() + ) thread.start() # Execute commands before network failure while not event.is_set(): - assert r_multi_db.get('key') == 'value' + assert retry.call_with_retry( + lambda : r_multi_db.get('key'), + lambda _ : dummy_fail() + ) == 'value' sleep(0.5) # Execute commands until database failover while not listener.is_changed_flag: - assert r_multi_db.get('key') == 'value' + assert retry.call_with_retry( + lambda : r_multi_db.get('key'), + lambda _ : dummy_fail() + ) == 'value' sleep(0.5) @pytest.mark.parametrize( "r_multi_db", [ - {"client_class": Redis, "failure_threshold": 2}, - {"client_class": RedisCluster, "failure_threshold": 2}, + {"client_class": Redis, "failure_threshold": 2, "health_check_interval": 10}, + {"client_class": RedisCluster, "failure_threshold": 2, "health_check_interval": 10}, ], ids=["standalone", "cluster"], indirect=True ) - @pytest.mark.timeout(50) + @pytest.mark.timeout(100) def test_multi_db_client_uses_lag_aware_health_check(self, r_multi_db, fault_injector_client): r_multi_db, listener, config = r_multi_db + retry = Retry( + supported_errors=(TemporaryUnavailableException,), + retries=DEFAULT_FAILOVER_ATTEMPTS, + backoff=ConstantBackoff(backoff=DEFAULT_FAILOVER_DELAY) + ) event = threading.Event() thread = threading.Thread( @@ -101,17 +127,26 @@ def test_multi_db_client_uses_lag_aware_health_check(self, r_multi_db, fault_inj ) # Client initialized on the first command. - r_multi_db.set('key', 'value') + retry.call_with_retry( + lambda : r_multi_db.set('key', 'value'), + lambda _ : dummy_fail() + ) thread.start() # Execute commands before network failure while not event.is_set(): - assert r_multi_db.get('key') == 'value' + assert retry.call_with_retry( + lambda : r_multi_db.get('key'), + lambda _ : dummy_fail() + ) == 'value' sleep(0.5) # Execute commands after network failure while not listener.is_changed_flag: - assert r_multi_db.get('key') == 'value' + assert retry.call_with_retry( + lambda : r_multi_db.get('key'), + lambda _ : dummy_fail() + ) == 'value' sleep(0.5) @pytest.mark.parametrize( @@ -123,9 +158,14 @@ def test_multi_db_client_uses_lag_aware_health_check(self, r_multi_db, fault_inj ids=["standalone", "cluster"], indirect=True ) - @pytest.mark.timeout(50) + @pytest.mark.timeout(100) def test_context_manager_pipeline_failover_to_another_db(self, r_multi_db, fault_injector_client): r_multi_db, listener, config = r_multi_db + retry = Retry( + supported_errors=(TemporaryUnavailableException,), + retries=DEFAULT_FAILOVER_ATTEMPTS, + backoff=ConstantBackoff(backoff=DEFAULT_FAILOVER_DELAY) + ) event = threading.Event() thread = threading.Thread( @@ -134,20 +174,7 @@ def test_context_manager_pipeline_failover_to_another_db(self, r_multi_db, fault args=(fault_injector_client,config,event) ) - # Client initialized on first pipe execution. 
- with r_multi_db.pipeline() as pipe: - pipe.set('{hash}key1', 'value1') - pipe.set('{hash}key2', 'value2') - pipe.set('{hash}key3', 'value3') - pipe.get('{hash}key1') - pipe.get('{hash}key2') - pipe.get('{hash}key3') - assert pipe.execute() == [True, True, True, 'value1', 'value2', 'value3'] - - thread.start() - - # Execute pipeline before network failure - while not event.is_set(): + def callback(): with r_multi_db.pipeline() as pipe: pipe.set('{hash}key1', 'value1') pipe.set('{hash}key2', 'value2') @@ -156,18 +183,28 @@ def test_context_manager_pipeline_failover_to_another_db(self, r_multi_db, fault pipe.get('{hash}key2') pipe.get('{hash}key3') assert pipe.execute() == [True, True, True, 'value1', 'value2', 'value3'] + + # Client initialized on first pipe execution. + retry.call_with_retry( + lambda : callback(), + lambda _ : dummy_fail() + ) + thread.start() + + # Execute pipeline before network failure + while not event.is_set(): + retry.call_with_retry( + lambda: callback(), + lambda _: dummy_fail() + ) sleep(0.5) # Execute pipeline until database failover for _ in range(5): - with r_multi_db.pipeline() as pipe: - pipe.set('{hash}key1', 'value1') - pipe.set('{hash}key2', 'value2') - pipe.set('{hash}key3', 'value3') - pipe.get('{hash}key1') - pipe.get('{hash}key2') - pipe.get('{hash}key3') - assert pipe.execute() == [True, True, True, 'value1', 'value2', 'value3'] + retry.call_with_retry( + lambda: callback(), + lambda _: dummy_fail() + ) sleep(0.5) @pytest.mark.parametrize( @@ -179,9 +216,14 @@ def test_context_manager_pipeline_failover_to_another_db(self, r_multi_db, fault ids=["standalone", "cluster"], indirect=True ) - @pytest.mark.timeout(50) + @pytest.mark.timeout(100) def test_chaining_pipeline_failover_to_another_db(self, r_multi_db, fault_injector_client): r_multi_db, listener, config = r_multi_db + retry = Retry( + supported_errors=(TemporaryUnavailableException,), + retries=DEFAULT_FAILOVER_ATTEMPTS, + backoff=ConstantBackoff(backoff=DEFAULT_FAILOVER_DELAY) + ) event = threading.Event() thread = threading.Thread( @@ -190,20 +232,7 @@ def test_chaining_pipeline_failover_to_another_db(self, r_multi_db, fault_inject args=(fault_injector_client,config,event) ) - # Client initialized on first pipe execution. - pipe = r_multi_db.pipeline() - pipe.set('{hash}key1', 'value1') - pipe.set('{hash}key2', 'value2') - pipe.set('{hash}key3', 'value3') - pipe.get('{hash}key1') - pipe.get('{hash}key2') - pipe.get('{hash}key3') - assert pipe.execute() == [True, True, True, 'value1', 'value2', 'value3'] - - thread.start() - - # Execute pipeline before network failure - while not event.is_set(): + def callback(): pipe = r_multi_db.pipeline() pipe.set('{hash}key1', 'value1') pipe.set('{hash}key2', 'value2') @@ -212,18 +241,29 @@ def test_chaining_pipeline_failover_to_another_db(self, r_multi_db, fault_inject pipe.get('{hash}key2') pipe.get('{hash}key3') assert pipe.execute() == [True, True, True, 'value1', 'value2', 'value3'] + + # Client initialized on first pipe execution. 
+ retry.call_with_retry( + lambda : callback(), + lambda _ : dummy_fail() + ) + + thread.start() + + # Execute pipeline before network failure + while not event.is_set(): + retry.call_with_retry( + lambda: callback(), + lambda _: dummy_fail() + ) sleep(0.5) # Execute pipeline until database failover for _ in range(5): - pipe = r_multi_db.pipeline() - pipe.set('{hash}key1', 'value1') - pipe.set('{hash}key2', 'value2') - pipe.set('{hash}key3', 'value3') - pipe.get('{hash}key1') - pipe.get('{hash}key2') - pipe.get('{hash}key3') - assert pipe.execute() == [True, True, True, 'value1', 'value2', 'value3'] + retry.call_with_retry( + lambda: callback(), + lambda _: dummy_fail() + ) sleep(0.5) @pytest.mark.parametrize( @@ -235,9 +275,14 @@ def test_chaining_pipeline_failover_to_another_db(self, r_multi_db, fault_inject ids=["standalone", "cluster"], indirect=True ) - @pytest.mark.timeout(50) + @pytest.mark.timeout(100) def test_transaction_failover_to_another_db(self, r_multi_db, fault_injector_client): r_multi_db, listener, config = r_multi_db + retry = Retry( + supported_errors=(TemporaryUnavailableException,), + retries=DEFAULT_FAILOVER_ATTEMPTS, + backoff=ConstantBackoff(backoff=DEFAULT_FAILOVER_DELAY) + ) event = threading.Event() thread = threading.Thread( @@ -255,17 +300,26 @@ def callback(pipe: Pipeline): pipe.get('{hash}key3') # Client initialized on first transaction execution. - r_multi_db.transaction(callback) + retry.call_with_retry( + lambda : r_multi_db.transaction(callback), + lambda _ : dummy_fail() + ) thread.start() # Execute transaction before network failure while not event.is_set(): - r_multi_db.transaction(callback) + retry.call_with_retry( + lambda: r_multi_db.transaction(callback), + lambda _: dummy_fail() + ) sleep(0.5) # Execute transaction until database failover while not listener.is_changed_flag: - r_multi_db.transaction(callback) + retry.call_with_retry( + lambda: r_multi_db.transaction(callback), + lambda _: dummy_fail() + ) sleep(0.5) @pytest.mark.parametrize( @@ -277,9 +331,14 @@ def callback(pipe: Pipeline): ids=["standalone", "cluster"], indirect=True ) - @pytest.mark.timeout(50) + @pytest.mark.timeout(100) def test_pubsub_failover_to_another_db(self, r_multi_db, fault_injector_client): r_multi_db, listener, config = r_multi_db + retry = Retry( + supported_errors=(TemporaryUnavailableException,), + retries=DEFAULT_FAILOVER_ATTEMPTS, + backoff=ConstantBackoff(backoff=DEFAULT_FAILOVER_DELAY) + ) event = threading.Event() thread = threading.Thread( @@ -297,18 +356,27 @@ def handler(message): pubsub = r_multi_db.pubsub() # Assign a handler and run in a separate thread. 
- pubsub.subscribe(**{'test-channel': handler}) + retry.call_with_retry( + lambda: pubsub.subscribe(**{'test-channel': handler}), + lambda _: dummy_fail() + ) pubsub_thread = pubsub.run_in_thread(sleep_time=0.1, daemon=True) thread.start() # Execute publish before network failure while not event.is_set(): - r_multi_db.publish('test-channel', data) + retry.call_with_retry( + lambda: r_multi_db.publish('test-channel', data), + lambda _: dummy_fail() + ) sleep(0.5) # Execute publish until database failover while not listener.is_changed_flag: - r_multi_db.publish('test-channel', data) + retry.call_with_retry( + lambda: r_multi_db.publish('test-channel', data), + lambda _: dummy_fail() + ) sleep(0.5) pubsub_thread.stop() @@ -323,9 +391,14 @@ def handler(message): ids=["standalone", "cluster"], indirect=True ) - @pytest.mark.timeout(50) + @pytest.mark.timeout(100) def test_sharded_pubsub_failover_to_another_db(self, r_multi_db, fault_injector_client): r_multi_db, listener, config = r_multi_db + retry = Retry( + supported_errors=(TemporaryUnavailableException,), + retries=DEFAULT_FAILOVER_ATTEMPTS, + backoff=ConstantBackoff(backoff=DEFAULT_FAILOVER_DELAY) + ) event = threading.Event() thread = threading.Thread( @@ -343,7 +416,10 @@ def handler(message): pubsub = r_multi_db.pubsub() # Assign a handler and run in a separate thread. - pubsub.ssubscribe(**{'test-channel': handler}) + retry.call_with_retry( + lambda: pubsub.ssubscribe(**{'test-channel': handler}), + lambda _: dummy_fail() + ) pubsub_thread = pubsub.run_in_thread( sleep_time=0.1, daemon=True, @@ -353,12 +429,18 @@ def handler(message): # Execute publish before network failure while not event.is_set(): - r_multi_db.spublish('test-channel', data) + retry.call_with_retry( + lambda: r_multi_db.spublish('test-channel', data), + lambda _: dummy_fail() + ) sleep(0.5) # Execute publish until database failover while not listener.is_changed_flag: - r_multi_db.spublish('test-channel', data) + retry.call_with_retry( + lambda: r_multi_db.spublish('test-channel', data), + lambda _: dummy_fail() + ) sleep(0.5) pubsub_thread.stop() From 457a35c1962ff07c73a16751029e2e20cefa5d67 Mon Sep 17 00:00:00 2001 From: vladvildanov Date: Fri, 19 Sep 2025 11:14:13 +0300 Subject: [PATCH 15/50] Removed redundant dependency --- redis/multidb/healthcheck.py | 2 -- tests/test_scenario/conftest.py | 7 +++---- 2 files changed, 3 insertions(+), 6 deletions(-) diff --git a/redis/multidb/healthcheck.py b/redis/multidb/healthcheck.py index fcfd7e44a8..81bbec6e17 100644 --- a/redis/multidb/healthcheck.py +++ b/redis/multidb/healthcheck.py @@ -4,8 +4,6 @@ from time import sleep from typing import Optional, Tuple, Union, List -from pygments.lexers.julia import allowed_variable - from redis import Redis from redis.backoff import NoBackoff from redis.http.http_client import DEFAULT_TIMEOUT, HttpClient diff --git a/tests/test_scenario/conftest.py b/tests/test_scenario/conftest.py index ad49ce51af..c9f75948b3 100644 --- a/tests/test_scenario/conftest.py +++ b/tests/test_scenario/conftest.py @@ -40,8 +40,7 @@ def endpoint_name(request): ) -@pytest.fixture() -def endpoints_config(endpoint_name: str): +def get_endpoints_config(endpoint_name: str): endpoints_config = os.getenv("REDIS_ENDPOINTS_CONFIG_PATH", None) if not (endpoints_config and os.path.exists(endpoints_config)): @@ -69,9 +68,9 @@ def r_multi_db(request) -> tuple[MultiDBClient, CheckActiveDatabaseChangedListen client_class = request.param.get('client_class', Redis) if client_class == Redis: - endpoint_config = 
endpoints_config('re-active-active') + endpoint_config = get_endpoints_config('re-active-active') else: - endpoint_config = endpoints_config('re-active-active-oss-cluster') + endpoint_config = get_endpoints_config('re-active-active-oss-cluster') username = endpoint_config.get('username', None) password = endpoint_config.get('password', None) From c21da4b0d9efa1ca00b530cc044e97fe40c2dd95 Mon Sep 17 00:00:00 2001 From: vladvildanov Date: Fri, 19 Sep 2025 11:53:26 +0300 Subject: [PATCH 16/50] Fixed async tests --- tests/test_asyncio/test_scenario/conftest.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/tests/test_asyncio/test_scenario/conftest.py b/tests/test_asyncio/test_scenario/conftest.py index 97e66f715b..e7302e2d60 100644 --- a/tests/test_asyncio/test_scenario/conftest.py +++ b/tests/test_asyncio/test_scenario/conftest.py @@ -4,14 +4,13 @@ import pytest_asyncio from redis.asyncio import Redis -from redis.asyncio.multidb.client import MultiDBClient from redis.asyncio.multidb.config import DEFAULT_FAILURES_THRESHOLD, DEFAULT_HEALTH_CHECK_INTERVAL, DatabaseConfig, \ MultiDbConfig from redis.asyncio.multidb.event import AsyncActiveDatabaseChanged from redis.asyncio.retry import Retry from redis.backoff import ExponentialBackoff from redis.event import AsyncEventListenerInterface, EventDispatcher -from tests.test_scenario.conftest import get_endpoint_config, extract_cluster_fqdn +from tests.test_scenario.conftest import get_endpoints_config, extract_cluster_fqdn from tests.test_scenario.fault_injector_client import FaultInjectorClient @@ -32,9 +31,9 @@ async def r_multi_db(request) -> tuple[MultiDbConfig, CheckActiveDatabaseChanged client_class = request.param.get('client_class', Redis) if client_class == Redis: - endpoint_config = get_endpoint_config('re-active-active') + endpoint_config = get_endpoints_config('re-active-active') else: - endpoint_config = get_endpoint_config('re-active-active-oss-cluster') + endpoint_config = get_endpoints_config('re-active-active-oss-cluster') username = endpoint_config.get('username', None) password = endpoint_config.get('password', None) From 821cc543e8c7cf296cf36c90fb254737978274d7 Mon Sep 17 00:00:00 2001 From: vladvildanov Date: Fri, 19 Sep 2025 12:29:04 +0300 Subject: [PATCH 17/50] Increased lag-aware tolerance --- .../test_scenario/test_active_active.py | 16 ++++++++++++++-- tests/test_scenario/test_active_active.py | 6 +++++- 2 files changed, 19 insertions(+), 3 deletions(-) diff --git a/tests/test_asyncio/test_scenario/test_active_active.py b/tests/test_asyncio/test_scenario/test_active_active.py index 2540c8a99d..f99bef926d 100644 --- a/tests/test_asyncio/test_scenario/test_active_active.py +++ b/tests/test_asyncio/test_scenario/test_active_active.py @@ -95,10 +95,22 @@ async def test_multi_db_client_failover_to_another_db(self, r_multi_db, fault_in "r_multi_db", [ {"client_class": Redis, "failure_threshold": 2, "health_checks": - [LagAwareHealthCheck(verify_tls=False, auth_basic=(os.getenv('ENV0_USERNAME'),os.getenv('ENV0_PASSWORD')))] + [ + LagAwareHealthCheck( + verify_tls=False, + auth_basic=(os.getenv('ENV0_USERNAME'),os.getenv('ENV0_PASSWORD')), + lag_aware_tolerance=10000 + ) + ] }, {"client_class": RedisCluster, "failure_threshold": 2, "health_checks": - [LagAwareHealthCheck(verify_tls=False, auth_basic=(os.getenv('ENV0_USERNAME'),os.getenv('ENV0_PASSWORD')))] + [ + LagAwareHealthCheck( + verify_tls=False, + auth_basic=(os.getenv('ENV0_USERNAME'), os.getenv('ENV0_PASSWORD')), + lag_aware_tolerance=10000 + ) + ] 
}, ], ids=["standalone", "cluster"], diff --git a/tests/test_scenario/test_active_active.py b/tests/test_scenario/test_active_active.py index a3056323a5..f4eb8efe59 100644 --- a/tests/test_scenario/test_active_active.py +++ b/tests/test_scenario/test_active_active.py @@ -123,7 +123,11 @@ def test_multi_db_client_uses_lag_aware_health_check(self, r_multi_db, fault_inj # Adding additional health check to the client. r_multi_db.add_health_check( - LagAwareHealthCheck(verify_tls=False, auth_basic=(env0_username,env0_password)) + LagAwareHealthCheck( + verify_tls=False, + auth_basic=(env0_username,env0_password), + lag_aware_tolerance=10000 + ) ) # Client initialized on the first command. From a7540c15ca50872313696882c3a1e2f44332edfc Mon Sep 17 00:00:00 2001 From: vladvildanov Date: Fri, 19 Sep 2025 13:11:33 +0300 Subject: [PATCH 18/50] Fixed typing issue, increase health_check_interval, added timeout handling --- redis/asyncio/multidb/client.py | 23 ++++++++------ redis/asyncio/multidb/healthcheck.py | 3 +- redis/http/http_client.py | 2 +- redis/multidb/client.py | 31 ++++++++++--------- .../test_scenario/test_active_active.py | 8 ++--- tests/test_scenario/test_active_active.py | 4 +-- 6 files changed, 39 insertions(+), 32 deletions(-) diff --git a/redis/asyncio/multidb/client.py b/redis/asyncio/multidb/client.py index b9925ea928..6d08e096d3 100644 --- a/redis/asyncio/multidb/client.py +++ b/redis/asyncio/multidb/client.py @@ -250,16 +250,21 @@ async def _check_databases_health( Runs health checks as a recurring task. Runs health checks against all databases. """ - results = await asyncio.wait_for( - asyncio.gather( - *( - asyncio.create_task(self._check_db_health(database)) - for database, _ in self._databases + try: + results = await asyncio.wait_for( + asyncio.gather( + *( + asyncio.create_task(self._check_db_health(database)) + for database, _ in self._databases + ), + return_exceptions=True, ), - return_exceptions=True, - ), - timeout=self._health_check_interval, - ) + timeout=self._health_check_interval, + ) + except asyncio.TimeoutError: + raise asyncio.TimeoutError( + "Health check execution exceeds health_check_interval" + ) for result in results: if isinstance(result, UnhealthyDatabaseException): diff --git a/redis/asyncio/multidb/healthcheck.py b/redis/asyncio/multidb/healthcheck.py index b5bf695380..d6d2d38814 100644 --- a/redis/asyncio/multidb/healthcheck.py +++ b/redis/asyncio/multidb/healthcheck.py @@ -4,11 +4,10 @@ from enum import Enum from typing import Optional, Tuple, Union, List -from pygments.lexers.julia import allowed_variable from redis.asyncio import Redis from redis.asyncio.http.http_client import AsyncHTTPClientWrapper, DEFAULT_TIMEOUT -from redis.asyncio.retry import Retry +from redis.retry import Retry from redis.backoff import NoBackoff from redis.http.http_client import HttpClient from redis.multidb.exception import UnhealthyDatabaseException diff --git a/redis/http/http_client.py b/redis/http/http_client.py index 986e773915..af0f68f95b 100644 --- a/redis/http/http_client.py +++ b/redis/http/http_client.py @@ -265,7 +265,7 @@ def request( return self.retry.call_with_retry( lambda: self._make_request(req, context=context, timeout=timeout), lambda _: dummy_fail(), - lambda error: self._is_retryable_http_error(error), + lambda error: self._is_retryable_http_error(error) ) except HTTPError as e: # Read error body, build response, and decide on retry diff --git a/redis/multidb/client.py b/redis/multidb/client.py index 19f846bd29..c132465cd7 100644 --- 
a/redis/multidb/client.py +++ b/redis/multidb/client.py @@ -245,20 +245,23 @@ def _check_databases_health(self, on_error: Callable[[Exception], None] = None): for database, _ in self._databases } - for future in as_completed(futures, timeout=self._health_check_interval): - try: - future.result() - except UnhealthyDatabaseException as e: - unhealthy_db = e.database - unhealthy_db.circuit.state = CBState.OPEN - - logger.exception( - 'Health check failed, due to exception', - exc_info=e.original_exception - ) - - if on_error: - on_error(e.original_exception) + try: + for future in as_completed(futures, timeout=self._health_check_interval): + try: + future.result() + except UnhealthyDatabaseException as e: + unhealthy_db = e.database + unhealthy_db.circuit.state = CBState.OPEN + + logger.exception( + 'Health check failed, due to exception', + exc_info=e.original_exception + ) + + if on_error: + on_error(e.original_exception) + except TimeoutError: + raise TimeoutError("Health check execution exceeds health_check_interval") def _on_circuit_state_change_callback(self, circuit: CircuitBreaker, old_state: CBState, new_state: CBState): if new_state == CBState.HALF_OPEN: diff --git a/tests/test_asyncio/test_scenario/test_active_active.py b/tests/test_asyncio/test_scenario/test_active_active.py index f99bef926d..967c72addf 100644 --- a/tests/test_asyncio/test_scenario/test_active_active.py +++ b/tests/test_asyncio/test_scenario/test_active_active.py @@ -99,18 +99,18 @@ async def test_multi_db_client_failover_to_another_db(self, r_multi_db, fault_in LagAwareHealthCheck( verify_tls=False, auth_basic=(os.getenv('ENV0_USERNAME'),os.getenv('ENV0_PASSWORD')), - lag_aware_tolerance=10000 ) - ] + ], + "health_check_interval": 20, }, {"client_class": RedisCluster, "failure_threshold": 2, "health_checks": [ LagAwareHealthCheck( verify_tls=False, auth_basic=(os.getenv('ENV0_USERNAME'), os.getenv('ENV0_PASSWORD')), - lag_aware_tolerance=10000 ) - ] + ], + "health_check_interval": 20, }, ], ids=["standalone", "cluster"], diff --git a/tests/test_scenario/test_active_active.py b/tests/test_scenario/test_active_active.py index f4eb8efe59..a0f7c79233 100644 --- a/tests/test_scenario/test_active_active.py +++ b/tests/test_scenario/test_active_active.py @@ -96,8 +96,8 @@ def test_multi_db_client_failover_to_another_db(self, r_multi_db, fault_injector @pytest.mark.parametrize( "r_multi_db", [ - {"client_class": Redis, "failure_threshold": 2, "health_check_interval": 10}, - {"client_class": RedisCluster, "failure_threshold": 2, "health_check_interval": 10}, + {"client_class": Redis, "failure_threshold": 2, "health_check_interval": 20}, + {"client_class": RedisCluster, "failure_threshold": 2, "health_check_interval": 20}, ], ids=["standalone", "cluster"], indirect=True From ca8166c47492ffe9e873b2efa4895803ea09a877 Mon Sep 17 00:00:00 2001 From: vladvildanov Date: Mon, 22 Sep 2025 11:49:41 +0300 Subject: [PATCH 19/50] Decreased retry cap, increased failure delay --- tests/test_asyncio/test_scenario/conftest.py | 2 +- tests/test_asyncio/test_scenario/test_active_active.py | 2 +- tests/test_scenario/conftest.py | 2 +- tests/test_scenario/test_active_active.py | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/test_asyncio/test_scenario/conftest.py b/tests/test_asyncio/test_scenario/conftest.py index e7302e2d60..d8d254557d 100644 --- a/tests/test_asyncio/test_scenario/conftest.py +++ b/tests/test_asyncio/test_scenario/conftest.py @@ -38,7 +38,7 @@ async def r_multi_db(request) -> tuple[MultiDbConfig, 
CheckActiveDatabaseChanged username = endpoint_config.get('username', None) password = endpoint_config.get('password', None) failure_threshold = request.param.get('failure_threshold', DEFAULT_FAILURES_THRESHOLD) - command_retry = request.param.get('command_retry', Retry(ExponentialBackoff(cap=2, base=0.05), retries=10)) + command_retry = request.param.get('command_retry', Retry(ExponentialBackoff(cap=1, base=0.05), retries=10)) # Retry configuration different for health checks as initial health check require more time in case # if infrastructure wasn't restored from the previous test. diff --git a/tests/test_asyncio/test_scenario/test_active_active.py b/tests/test_asyncio/test_scenario/test_active_active.py index 967c72addf..dd4a556d9d 100644 --- a/tests/test_asyncio/test_scenario/test_active_active.py +++ b/tests/test_asyncio/test_scenario/test_active_active.py @@ -22,7 +22,7 @@ async def trigger_network_failure_action(fault_injector_client, config, event: asyncio.Event = None): action_request = ActionRequest( action_type=ActionType.NETWORK_FAILURE, - parameters={"bdb_id": config['bdb_id'], "delay": 2, "cluster_index": 0} + parameters={"bdb_id": config['bdb_id'], "delay": 3, "cluster_index": 0} ) result = fault_injector_client.trigger_action(action_request) diff --git a/tests/test_scenario/conftest.py b/tests/test_scenario/conftest.py index c9f75948b3..996095e0be 100644 --- a/tests/test_scenario/conftest.py +++ b/tests/test_scenario/conftest.py @@ -75,7 +75,7 @@ def r_multi_db(request) -> tuple[MultiDBClient, CheckActiveDatabaseChangedListen username = endpoint_config.get('username', None) password = endpoint_config.get('password', None) failure_threshold = request.param.get('failure_threshold', DEFAULT_FAILURES_THRESHOLD) - command_retry = request.param.get('command_retry', Retry(ExponentialBackoff(cap=2, base=0.05), retries=10)) + command_retry = request.param.get('command_retry', Retry(ExponentialBackoff(cap=1, base=0.05), retries=10)) # Retry configuration different for health checks as initial health check require more time in case # if infrastructure wasn't restored from the previous test. 
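Note on why lowering the backoff cap from 2 to 1 second matters here: with a capped exponential backoff the per-attempt delay saturates at the cap, so with retries=10 the cap dominates the worst-case wall time a single command can spend retrying. A self-contained sketch of the schedule (this assumes the usual min(cap, base * 2**attempt) formula for a capped exponential backoff, which mirrors redis-py's ExponentialBackoff; treat the exact schedule as an assumption, the totals below hold only for this formula and configuration):

def capped_exponential_delays(cap: float, base: float, retries: int) -> list:
    # Per-attempt sleep times for a capped exponential backoff.
    return [min(cap, base * 2 ** attempt) for attempt in range(retries)]

# With base=0.05 and retries=10, as configured in conftest.py:
#   cap=2 -> delays sum to ~11.2s worst case
#   cap=1 -> delays sum to ~6.6s worst case
# keeping each scenario test comfortably inside its pytest timeout.
for cap in (2.0, 1.0):
    delays = capped_exponential_delays(cap=cap, base=0.05, retries=10)
    print(f"cap={cap}: total={sum(delays):.2f}s")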
diff --git a/tests/test_scenario/test_active_active.py b/tests/test_scenario/test_active_active.py index a0f7c79233..d641fb65bb 100644 --- a/tests/test_scenario/test_active_active.py +++ b/tests/test_scenario/test_active_active.py @@ -21,7 +21,7 @@ def trigger_network_failure_action(fault_injector_client, config, event: threading.Event = None): action_request = ActionRequest( action_type=ActionType.NETWORK_FAILURE, - parameters={"bdb_id": config['bdb_id'], "delay": 2, "cluster_index": 0} + parameters={"bdb_id": config['bdb_id'], "delay": 3, "cluster_index": 0} ) result = fault_injector_client.trigger_action(action_request) From f50299ecc8e2028479b9172359cb80d383126987 Mon Sep 17 00:00:00 2001 From: vladvildanov Date: Mon, 22 Sep 2025 12:51:44 +0300 Subject: [PATCH 20/50] Fixed async teardown --- redis/asyncio/multidb/client.py | 3 ++ redis/multidb/client.py | 3 ++ tests/test_asyncio/test_scenario/conftest.py | 16 +++++-- .../test_scenario/test_active_active.py | 43 ++++++++----------- tests/test_scenario/conftest.py | 2 +- 5 files changed, 39 insertions(+), 28 deletions(-) diff --git a/redis/asyncio/multidb/client.py b/redis/asyncio/multidb/client.py index 6d08e096d3..9a78282d8b 100644 --- a/redis/asyncio/multidb/client.py +++ b/redis/asyncio/multidb/client.py @@ -305,6 +305,9 @@ def _on_circuit_state_change_callback(self, circuit: CircuitBreaker, old_state: if old_state == CBState.CLOSED and new_state == CBState.OPEN: loop.call_later(DEFAULT_GRACE_PERIOD, _half_open_circuit, circuit) + async def aclose(self): + await self.command_executor.active_database.client.aclose() + def _half_open_circuit(circuit: CircuitBreaker): circuit.state = CBState.HALF_OPEN diff --git a/redis/multidb/client.py b/redis/multidb/client.py index c132465cd7..6f2022c9de 100644 --- a/redis/multidb/client.py +++ b/redis/multidb/client.py @@ -271,6 +271,9 @@ def _on_circuit_state_change_callback(self, circuit: CircuitBreaker, old_state: if old_state == CBState.CLOSED and new_state == CBState.OPEN: self._bg_scheduler.run_once(DEFAULT_GRACE_PERIOD, _half_open_circuit, circuit) + def close(self): + self.command_executor.active_database.client.close() + def _half_open_circuit(circuit: CircuitBreaker): circuit.state = CBState.HALF_OPEN diff --git a/tests/test_asyncio/test_scenario/conftest.py b/tests/test_asyncio/test_scenario/conftest.py index d8d254557d..73607634ca 100644 --- a/tests/test_asyncio/test_scenario/conftest.py +++ b/tests/test_asyncio/test_scenario/conftest.py @@ -1,9 +1,12 @@ +import asyncio import os +from typing import Any, AsyncGenerator import pytest import pytest_asyncio from redis.asyncio import Redis +from redis.asyncio.multidb.client import MultiDBClient from redis.asyncio.multidb.config import DEFAULT_FAILURES_THRESHOLD, DEFAULT_HEALTH_CHECK_INTERVAL, DatabaseConfig, \ MultiDbConfig from redis.asyncio.multidb.event import AsyncActiveDatabaseChanged @@ -27,7 +30,7 @@ def fault_injector_client(): return FaultInjectorClient(url) @pytest_asyncio.fixture() -async def r_multi_db(request) -> tuple[MultiDbConfig, CheckActiveDatabaseChangedListener, dict]: +async def r_multi_db(request) -> AsyncGenerator[tuple[MultiDBClient, CheckActiveDatabaseChangedListener, Any], Any]: client_class = request.param.get('client_class', Redis) if client_class == Redis: @@ -38,7 +41,7 @@ async def r_multi_db(request) -> tuple[MultiDbConfig, CheckActiveDatabaseChanged username = endpoint_config.get('username', None) password = endpoint_config.get('password', None) failure_threshold = request.param.get('failure_threshold', 
DEFAULT_FAILURES_THRESHOLD) - command_retry = request.param.get('command_retry', Retry(ExponentialBackoff(cap=1, base=0.05), retries=10)) + command_retry = request.param.get('command_retry', Retry(ExponentialBackoff(cap=0.1, base=0.01), retries=10)) # Retry configuration different for health checks as initial health check require more time in case # if infrastructure wasn't restored from the previous test. @@ -86,4 +89,11 @@ async def r_multi_db(request) -> tuple[MultiDbConfig, CheckActiveDatabaseChanged event_dispatcher=event_dispatcher, ) - return config, listener, endpoint_config \ No newline at end of file + client = MultiDBClient(config) + + async def teardown(): + await client.aclose() + await asyncio.sleep(15) + + yield client, listener, endpoint_config + await teardown() \ No newline at end of file diff --git a/tests/test_asyncio/test_scenario/test_active_active.py b/tests/test_asyncio/test_scenario/test_active_active.py index dd4a556d9d..b49cde85da 100644 --- a/tests/test_asyncio/test_scenario/test_active_active.py +++ b/tests/test_asyncio/test_scenario/test_active_active.py @@ -39,11 +39,6 @@ async def trigger_network_failure_action(fault_injector_client, config, event: a logger.info(f"Action completed. Status: {status_result['status']}") class TestActiveActive: - - def teardown_method(self, method): - # Timeout so the cluster could recover from network failure. - sleep(15) - @pytest.mark.asyncio @pytest.mark.parametrize( "r_multi_db", @@ -54,9 +49,9 @@ def teardown_method(self, method): ids=["standalone", "cluster"], indirect=True ) - @pytest.mark.timeout(100) + @pytest.mark.timeout(200) async def test_multi_db_client_failover_to_another_db(self, r_multi_db, fault_injector_client): - client_config, listener, endpoint_config = r_multi_db + client, listener, endpoint_config = r_multi_db # Handle unavailable databases from previous test. 
retry = Retry( @@ -65,7 +60,7 @@ async def test_multi_db_client_failover_to_another_db(self, r_multi_db, fault_in backoff=ConstantBackoff(backoff=DEFAULT_FAILOVER_DELAY) ) - async with MultiDBClient(client_config) as r_multi_db: + async with client as r_multi_db: event = asyncio.Event() asyncio.create_task(trigger_network_failure_action(fault_injector_client, endpoint_config, event)) @@ -116,16 +111,16 @@ async def test_multi_db_client_failover_to_another_db(self, r_multi_db, fault_in ids=["standalone", "cluster"], indirect=True ) - @pytest.mark.timeout(100) + @pytest.mark.timeout(200) async def test_multi_db_client_uses_lag_aware_health_check(self, r_multi_db, fault_injector_client): - client_config, listener, endpoint_config = r_multi_db + client, listener, endpoint_config = r_multi_db retry = Retry( supported_errors=(TemporaryUnavailableException,), retries=DEFAULT_FAILOVER_ATTEMPTS, backoff=ConstantBackoff(backoff=DEFAULT_FAILOVER_DELAY) ) - async with MultiDBClient(client_config) as r_multi_db: + async with client as r_multi_db: event = asyncio.Event() asyncio.create_task(trigger_network_failure_action(fault_injector_client, endpoint_config, event)) @@ -160,9 +155,9 @@ async def test_multi_db_client_uses_lag_aware_health_check(self, r_multi_db, fau ids=["standalone", "cluster"], indirect=True ) - @pytest.mark.timeout(100) + @pytest.mark.timeout(200) async def test_context_manager_pipeline_failover_to_another_db(self, r_multi_db, fault_injector_client): - client_config, listener, endpoint_config = r_multi_db + client, listener, endpoint_config = r_multi_db retry = Retry( supported_errors=(TemporaryUnavailableException,), retries=DEFAULT_FAILOVER_ATTEMPTS, @@ -179,7 +174,7 @@ async def callback(): pipe.get('{hash}key3') assert await pipe.execute() == [True, True, True, 'value1', 'value2', 'value3'] - async with MultiDBClient(client_config) as r_multi_db: + async with client as r_multi_db: event = asyncio.Event() asyncio.create_task(trigger_network_failure_action(fault_injector_client, endpoint_config, event)) @@ -209,9 +204,9 @@ async def callback(): ids=["standalone", "cluster"], indirect=True ) - @pytest.mark.timeout(100) + @pytest.mark.timeout(200) async def test_chaining_pipeline_failover_to_another_db(self, r_multi_db, fault_injector_client): - client_config, listener, endpoint_config = r_multi_db + client, listener, endpoint_config = r_multi_db retry = Retry( supported_errors=(TemporaryUnavailableException,), retries=DEFAULT_FAILOVER_ATTEMPTS, @@ -228,7 +223,7 @@ async def callback(): pipe.get('{hash}key3') assert await pipe.execute() == [True, True, True, 'value1', 'value2', 'value3'] - async with MultiDBClient(client_config) as r_multi_db: + async with client as r_multi_db: event = asyncio.Event() asyncio.create_task(trigger_network_failure_action(fault_injector_client, endpoint_config, event)) @@ -258,9 +253,9 @@ async def callback(): ids=["standalone", "cluster"], indirect=True ) - @pytest.mark.timeout(100) + @pytest.mark.timeout(200) async def test_transaction_failover_to_another_db(self, r_multi_db, fault_injector_client): - client_config, listener, endpoint_config = r_multi_db + client, listener, endpoint_config = r_multi_db retry = Retry( supported_errors=(TemporaryUnavailableException,), @@ -276,7 +271,7 @@ async def callback(pipe: Pipeline): pipe.get('{hash}key2') pipe.get('{hash}key3') - async with MultiDBClient(client_config) as r_multi_db: + async with client as r_multi_db: event = asyncio.Event() 
asyncio.create_task(trigger_network_failure_action(fault_injector_client, endpoint_config, event)) @@ -302,9 +297,9 @@ async def callback(pipe: Pipeline): [{"failure_threshold": 2}], indirect=True ) - @pytest.mark.timeout(60) + @pytest.mark.timeout(200) async def test_pubsub_failover_to_another_db(self, r_multi_db, fault_injector_client): - client_config, listener, endpoint_config = r_multi_db + client, listener, endpoint_config = r_multi_db retry = Retry( supported_errors=(TemporaryUnavailableException,), retries=DEFAULT_FAILOVER_ATTEMPTS, @@ -318,7 +313,7 @@ async def handler(message): nonlocal messages_count messages_count += 1 - async with MultiDBClient(client_config) as r_multi_db: + async with client as r_multi_db: event = asyncio.Event() asyncio.create_task(trigger_network_failure_action(fault_injector_client, endpoint_config, event)) @@ -358,4 +353,4 @@ async def handler(message): await asyncio.sleep(0.1) task.cancel() await pubsub.unsubscribe('test-channel') is True - assert messages_count >= 5 \ No newline at end of file + assert messages_count >= 2 \ No newline at end of file diff --git a/tests/test_scenario/conftest.py b/tests/test_scenario/conftest.py index 996095e0be..6c32cf0699 100644 --- a/tests/test_scenario/conftest.py +++ b/tests/test_scenario/conftest.py @@ -75,7 +75,7 @@ def r_multi_db(request) -> tuple[MultiDBClient, CheckActiveDatabaseChangedListen username = endpoint_config.get('username', None) password = endpoint_config.get('password', None) failure_threshold = request.param.get('failure_threshold', DEFAULT_FAILURES_THRESHOLD) - command_retry = request.param.get('command_retry', Retry(ExponentialBackoff(cap=1, base=0.05), retries=10)) + command_retry = request.param.get('command_retry', Retry(ExponentialBackoff(cap=0.1, base=0.01), retries=10)) # Retry configuration different for health checks as initial health check require more time in case # if infrastructure wasn't restored from the previous test. 
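The wrapper pattern introduced above recurs in every scenario test: each client call goes through Retry.call_with_retry with TemporaryUnavailableException as the only retryable error, so a request issued while all databases are momentarily marked unhealthy is retried on a constant delay instead of failing the test outright. A condensed sketch of the shape (the helper name is illustrative; dummy_fail is the no-op failure hook from redis.utils used throughout these tests):

from redis.backoff import ConstantBackoff
from redis.multidb.exception import TemporaryUnavailableException
from redis.multidb.failover import DEFAULT_FAILOVER_ATTEMPTS, DEFAULT_FAILOVER_DELAY
from redis.retry import Retry
from redis.utils import dummy_fail

def scenario_retry() -> Retry:
    # Retry only the "no healthy database yet" error; anything else
    # should surface immediately and fail the test.
    return Retry(
        supported_errors=(TemporaryUnavailableException,),
        retries=DEFAULT_FAILOVER_ATTEMPTS,
        backoff=ConstantBackoff(backoff=DEFAULT_FAILOVER_DELAY),
    )

# Usage mirrors the tests: the second callable runs on each failure and
# deliberately does nothing, so only exhausted retries propagate.
#   retry = scenario_retry()
#   assert retry.call_with_retry(
#       lambda: r_multi_db.get('key'),
#       lambda _: dummy_fail(),
#   ) == 'value'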
From 115d996865c5ab92dfa5d28c2f10ef2c1b38079e Mon Sep 17 00:00:00 2001 From: vladvildanov Date: Mon, 22 Sep 2025 13:22:41 +0300 Subject: [PATCH 21/50] Fixed tests --- redis/asyncio/multidb/event.py | 2 +- tests/test_asyncio/test_scenario/conftest.py | 6 +++- .../test_scenario/test_active_active.py | 29 +++++++++---------- 3 files changed, 20 insertions(+), 17 deletions(-) diff --git a/redis/asyncio/multidb/event.py b/redis/asyncio/multidb/event.py index ea5534ce86..9b74367b34 100644 --- a/redis/asyncio/multidb/event.py +++ b/redis/asyncio/multidb/event.py @@ -51,7 +51,7 @@ async def listen(self, event: AsyncActiveDatabaseChanged): new_pubsub.patterns = old_pubsub.patterns await new_pubsub.on_connect(None) event.command_executor.active_pubsub = new_pubsub - await old_pubsub.close() + await old_pubsub.aclose() class RegisterCommandFailure(AsyncEventListenerInterface): """ diff --git a/tests/test_asyncio/test_scenario/conftest.py b/tests/test_asyncio/test_scenario/conftest.py index 73607634ca..4b717efbda 100644 --- a/tests/test_asyncio/test_scenario/conftest.py +++ b/tests/test_asyncio/test_scenario/conftest.py @@ -5,7 +5,7 @@ import pytest import pytest_asyncio -from redis.asyncio import Redis +from redis.asyncio import Redis, RedisCluster from redis.asyncio.multidb.client import MultiDBClient from redis.asyncio.multidb.config import DEFAULT_FAILURES_THRESHOLD, DEFAULT_HEALTH_CHECK_INTERVAL, DatabaseConfig, \ MultiDbConfig @@ -93,6 +93,10 @@ async def r_multi_db(request) -> AsyncGenerator[tuple[MultiDBClient, CheckActive async def teardown(): await client.aclose() + + if isinstance(client.command_executor.active_database.client, Redis): + await client.command_executor.active_database.client.connection_pool.disconnect() + await asyncio.sleep(15) yield client, listener, endpoint_config diff --git a/tests/test_asyncio/test_scenario/test_active_active.py b/tests/test_asyncio/test_scenario/test_active_active.py index b49cde85da..693dd33bbe 100644 --- a/tests/test_asyncio/test_scenario/test_active_active.py +++ b/tests/test_asyncio/test_scenario/test_active_active.py @@ -186,13 +186,13 @@ async def callback(): ) await asyncio.sleep(0.5) - # Execute commands until database failover - while not listener.is_changed_flag: - await retry.call_with_retry( - lambda: callback(), - lambda _: dummy_fail_async() - ) - await asyncio.sleep(0.5) + # Execute commands until database failover + while not listener.is_changed_flag: + await retry.call_with_retry( + lambda: callback(), + lambda _: dummy_fail_async() + ) + await asyncio.sleep(0.5) @pytest.mark.asyncio @pytest.mark.parametrize( @@ -235,13 +235,13 @@ async def callback(): ) await asyncio.sleep(0.5) - # Execute pipeline until database failover - while not listener.is_changed_flag: - await retry.call_with_retry( - lambda: callback(), - lambda _: dummy_fail_async() - ) - await asyncio.sleep(0.5) + # Execute pipeline until database failover + while not listener.is_changed_flag: + await retry.call_with_retry( + lambda: callback(), + lambda _: dummy_fail_async() + ) + await asyncio.sleep(0.5) @pytest.mark.asyncio @pytest.mark.parametrize( @@ -352,5 +352,4 @@ async def handler(message): # A timeout to ensure that an async handler will handle all previous messages. 
await asyncio.sleep(0.1) task.cancel() - await pubsub.unsubscribe('test-channel') is True assert messages_count >= 2 \ No newline at end of file From 6556e8a2e929918bf75cbefaa5cb1892a02422cf Mon Sep 17 00:00:00 2001 From: vladvildanov Date: Mon, 22 Sep 2025 15:18:28 +0300 Subject: [PATCH 22/50] Added graceful connection closing, added graceful hc tasks termination --- redis/asyncio/multidb/client.py | 17 +++++++++-------- redis/asyncio/multidb/command_executor.py | 5 +++-- redis/asyncio/multidb/event.py | 11 +++++++++++ redis/multidb/command_executor.py | 6 ++++-- redis/multidb/event.py | 13 +++++++++++++ 5 files changed, 40 insertions(+), 12 deletions(-) diff --git a/redis/asyncio/multidb/client.py b/redis/asyncio/multidb/client.py index 9a78282d8b..71fbd9c2e0 100644 --- a/redis/asyncio/multidb/client.py +++ b/redis/asyncio/multidb/client.py @@ -59,7 +59,8 @@ def __init__(self, config: MultiDbConfig): self._hc_lock = asyncio.Lock() self._bg_scheduler = BackgroundScheduler() self._config = config - self._hc_task = None + self._recurring_hc_task = None + self._hc_tasks = [] self._half_open_state_task = None async def __aenter__(self: "MultiDBClient") -> "MultiDBClient": @@ -68,10 +69,12 @@ async def __aenter__(self: "MultiDBClient") -> "MultiDBClient": return self async def __aexit__(self, exc_type, exc_value, traceback): - if self._hc_task: - self._hc_task.cancel() + if self._recurring_hc_task: + self._recurring_hc_task.cancel() if self._half_open_state_task: self._half_open_state_task.cancel() + for hc_task in self._hc_tasks: + hc_task.cancel() async def initialize(self): """ @@ -84,7 +87,7 @@ async def raise_exception_on_failed_hc(error): await self._check_databases_health(on_error=raise_exception_on_failed_hc) # Starts recurring health checks on the background. - self._hc_task = asyncio.create_task(self._bg_scheduler.run_recurring_async( + self._recurring_hc_task = asyncio.create_task(self._bg_scheduler.run_recurring_async( self._health_check_interval, self._check_databases_health, )) @@ -251,12 +254,10 @@ async def _check_databases_health( Runs health checks against all databases. 
""" try: + self._hc_tasks = [asyncio.create_task(self._check_db_health(database)) for database, _ in self._databases] results = await asyncio.wait_for( asyncio.gather( - *( - asyncio.create_task(self._check_db_health(database)) - for database, _ in self._databases - ), + *self._hc_tasks, return_exceptions=True, ), timeout=self._health_check_interval, diff --git a/redis/asyncio/multidb/command_executor.py b/redis/asyncio/multidb/command_executor.py index 7e622d6260..95209298f4 100644 --- a/redis/asyncio/multidb/command_executor.py +++ b/redis/asyncio/multidb/command_executor.py @@ -7,7 +7,7 @@ from redis.asyncio.client import PubSub, Pipeline from redis.asyncio.multidb.database import Databases, AsyncDatabase, Database from redis.asyncio.multidb.event import AsyncActiveDatabaseChanged, RegisterCommandFailure, \ - ResubscribeOnActiveDatabaseChanged + ResubscribeOnActiveDatabaseChanged, CloseConnectionOnActiveDatabaseChanged from redis.asyncio.multidb.failover import AsyncFailoverStrategy, FailoverStrategyExecutor, DefaultFailoverStrategyExecutor, \ DEFAULT_FAILOVER_ATTEMPTS, DEFAULT_FAILOVER_DELAY from redis.asyncio.multidb.failure_detector import AsyncFailureDetector @@ -286,7 +286,8 @@ def _setup_event_dispatcher(self): """ failure_listener = RegisterCommandFailure(self._failure_detectors) resubscribe_listener = ResubscribeOnActiveDatabaseChanged() + close_connection_listener = CloseConnectionOnActiveDatabaseChanged() self._event_dispatcher.register_listeners({ AsyncOnCommandsFailEvent: [failure_listener], - AsyncActiveDatabaseChanged: [resubscribe_listener], + AsyncActiveDatabaseChanged: [close_connection_listener, resubscribe_listener], }) \ No newline at end of file diff --git a/redis/asyncio/multidb/event.py b/redis/asyncio/multidb/event.py index 9b74367b34..a2c90eed40 100644 --- a/redis/asyncio/multidb/event.py +++ b/redis/asyncio/multidb/event.py @@ -1,5 +1,6 @@ from typing import List +from redis.asyncio import Redis from redis.asyncio.multidb.database import AsyncDatabase from redis.asyncio.multidb.failure_detector import AsyncFailureDetector from redis.event import AsyncEventListenerInterface, AsyncOnCommandsFailEvent @@ -53,6 +54,16 @@ async def listen(self, event: AsyncActiveDatabaseChanged): event.command_executor.active_pubsub = new_pubsub await old_pubsub.aclose() +class CloseConnectionOnActiveDatabaseChanged(AsyncEventListenerInterface): + """ + Close connection to the old active database. + """ + async def listen(self, event: AsyncActiveDatabaseChanged): + await event.old_database.client.aclose() + + if isinstance(event.old_database.client, Redis): + await event.old_database.client.connection_pool.disconnect() + class RegisterCommandFailure(AsyncEventListenerInterface): """ Event listener that registers command failures and passing it to the failure detectors. 
diff --git a/redis/multidb/command_executor.py b/redis/multidb/command_executor.py index 7ca7d2ec52..562dcfd6fe 100644 --- a/redis/multidb/command_executor.py +++ b/redis/multidb/command_executor.py @@ -7,7 +7,8 @@ from redis.multidb.config import DEFAULT_AUTO_FALLBACK_INTERVAL from redis.multidb.database import Database, Databases, SyncDatabase from redis.multidb.circuit import State as CBState -from redis.multidb.event import RegisterCommandFailure, ActiveDatabaseChanged, ResubscribeOnActiveDatabaseChanged +from redis.multidb.event import RegisterCommandFailure, ActiveDatabaseChanged, ResubscribeOnActiveDatabaseChanged, \ + CloseConnectionOnActiveDatabaseChanged from redis.multidb.failover import FailoverStrategy, FailoverStrategyExecutor, DEFAULT_FAILOVER_ATTEMPTS, \ DEFAULT_FAILOVER_DELAY, DefaultFailoverStrategyExecutor from redis.multidb.failure_detector import FailureDetector @@ -303,7 +304,8 @@ def _setup_event_dispatcher(self): """ failure_listener = RegisterCommandFailure(self._failure_detectors) resubscribe_listener = ResubscribeOnActiveDatabaseChanged() + close_connection_listener = CloseConnectionOnActiveDatabaseChanged() self._event_dispatcher.register_listeners({ OnCommandsFailEvent: [failure_listener], - ActiveDatabaseChanged: [resubscribe_listener], + ActiveDatabaseChanged: [close_connection_listener, resubscribe_listener], }) \ No newline at end of file diff --git a/redis/multidb/event.py b/redis/multidb/event.py index bca9482347..75da84fdf8 100644 --- a/redis/multidb/event.py +++ b/redis/multidb/event.py @@ -1,5 +1,8 @@ from typing import List +from redis.client import Redis +from sphinx.events import EventListener + from redis.event import EventListenerInterface, OnCommandsFailEvent from redis.multidb.database import SyncDatabase from redis.multidb.failure_detector import FailureDetector @@ -53,6 +56,16 @@ def listen(self, event: ActiveDatabaseChanged): event.command_executor.active_pubsub = new_pubsub old_pubsub.close() +class CloseConnectionOnActiveDatabaseChanged(EventListenerInterface): + """ + Close connection to the old active database. + """ + def listen(self, event: ActiveDatabaseChanged): + event.old_database.client.close() + + if isinstance(event.old_database.client, Redis): + event.old_database.client.connection_pool.disconnect() + class RegisterCommandFailure(EventListenerInterface): """ Event listener that registers command failures and passing it to the failure detectors. 
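Both new listeners ride on the same dispatcher mechanism the scenario tests already use, with the close-connection listener registered ahead of the resubscribe listener, presumably so stale connections to the old database are dropped before pubsub state is rebuilt against the new one. Application code can attach its own listener to the same event; a minimal sketch, assuming only the interfaces shown in this series (the logging body is illustrative, not part of the library):

import logging

from redis.event import EventDispatcher, EventListenerInterface
from redis.multidb.event import ActiveDatabaseChanged

logger = logging.getLogger(__name__)

class LogActiveDatabaseChanged(EventListenerInterface):
    # Illustrative listener: record every failover away from a database.
    def listen(self, event: ActiveDatabaseChanged):
        logger.warning("Active database changed, old client: %r",
                       event.old_database.client)

dispatcher = EventDispatcher()
dispatcher.register_listeners({ActiveDatabaseChanged: [LogActiveDatabaseChanged()]})
# Pass dispatcher as event_dispatcher= when constructing MultiDbConfig,
# as the scenario conftest does.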
From 063e795c85771801741f67f4dbab6001dd726660 Mon Sep 17 00:00:00 2001 From: vladvildanov Date: Mon, 22 Sep 2025 15:33:52 +0300 Subject: [PATCH 23/50] Make sure active connection will be disconnected on failover --- redis/multidb/event.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/redis/multidb/event.py b/redis/multidb/event.py index 75da84fdf8..26dd9aa1a0 100644 --- a/redis/multidb/event.py +++ b/redis/multidb/event.py @@ -1,7 +1,6 @@ from typing import List from redis.client import Redis -from sphinx.events import EventListener from redis.event import EventListenerInterface, OnCommandsFailEvent from redis.multidb.database import SyncDatabase @@ -64,6 +63,7 @@ def listen(self, event: ActiveDatabaseChanged): event.old_database.client.close() if isinstance(event.old_database.client, Redis): + event.old_database.client.connection_pool.update_active_connections_for_reconnect() event.old_database.client.connection_pool.disconnect() class RegisterCommandFailure(EventListenerInterface): From 3ef34b16b1f27020db785ae4e7d2b5fbb11f1389 Mon Sep 17 00:00:00 2001 From: vladvildanov Date: Mon, 22 Sep 2025 16:00:55 +0300 Subject: [PATCH 24/50] Close cluster connection on failover --- redis/asyncio/connection.py | 18 ++++++++++++++++++ redis/asyncio/multidb/event.py | 1 + redis/multidb/event.py | 4 ++++ tests/test_asyncio/test_scenario/conftest.py | 2 +- 4 files changed, 24 insertions(+), 1 deletion(-) diff --git a/redis/asyncio/connection.py b/redis/asyncio/connection.py index 4efd868f6f..a42a336024 100644 --- a/redis/asyncio/connection.py +++ b/redis/asyncio/connection.py @@ -212,6 +212,7 @@ def __init__( self._connect_callbacks: List[weakref.WeakMethod[ConnectCallbackT]] = [] self._buffer_cutoff = 6000 self._re_auth_token: Optional[TokenInterface] = None + self._should_reconnect = False try: p = int(protocol) @@ -342,6 +343,12 @@ async def connect_check_health( if task and inspect.isawaitable(task): await task + def mark_for_reconnect(self): + self._should_reconnect = True + + def should_reconnect(self): + return self._should_reconnect + @abstractmethod async def _connect(self): pass @@ -1198,6 +1205,9 @@ async def release(self, connection: AbstractConnection): # Connections should always be returned to the correct pool, # not doing so is an error that will cause an exception here. self._in_use_connections.remove(connection) + if connection.should_reconnect(): + await connection.disconnect() + self._available_connections.append(connection) await self._event_dispatcher.dispatch_async( AsyncAfterConnectionReleasedEvent(connection) @@ -1225,6 +1235,14 @@ async def disconnect(self, inuse_connections: bool = True): if exc: raise exc + async def update_active_connections_for_reconnect(self): + """ + Mark all active connections for reconnect. 
+ """ + async with self._lock: + for conn in self._in_use_connections: + conn.mark_for_reconnect() + async def aclose(self) -> None: """Close the pool, disconnecting all connections""" await self.disconnect() diff --git a/redis/asyncio/multidb/event.py b/redis/asyncio/multidb/event.py index a2c90eed40..9f6f463a0f 100644 --- a/redis/asyncio/multidb/event.py +++ b/redis/asyncio/multidb/event.py @@ -62,6 +62,7 @@ async def listen(self, event: AsyncActiveDatabaseChanged): await event.old_database.client.aclose() if isinstance(event.old_database.client, Redis): + await event.old_database.client.connection_pool.update_active_connections_for_reconnect() await event.old_database.client.connection_pool.disconnect() class RegisterCommandFailure(AsyncEventListenerInterface): diff --git a/redis/multidb/event.py b/redis/multidb/event.py index 26dd9aa1a0..8a76139752 100644 --- a/redis/multidb/event.py +++ b/redis/multidb/event.py @@ -65,6 +65,10 @@ def listen(self, event: ActiveDatabaseChanged): if isinstance(event.old_database.client, Redis): event.old_database.client.connection_pool.update_active_connections_for_reconnect() event.old_database.client.connection_pool.disconnect() + else: + for node in event.old_database.client.nodes_manager.nodes_cache.values(): + node.redis_connection.connection_pool.update_active_connections_for_reconnect() + node.redis_connection.connection_pool.disconnect() class RegisterCommandFailure(EventListenerInterface): """ diff --git a/tests/test_asyncio/test_scenario/conftest.py b/tests/test_asyncio/test_scenario/conftest.py index 4b717efbda..00ddd36a76 100644 --- a/tests/test_asyncio/test_scenario/conftest.py +++ b/tests/test_asyncio/test_scenario/conftest.py @@ -45,7 +45,7 @@ async def r_multi_db(request) -> AsyncGenerator[tuple[MultiDBClient, CheckActive # Retry configuration different for health checks as initial health check require more time in case # if infrastructure wasn't restored from the previous test. 
- health_check_interval = request.param.get('health_check_interval', DEFAULT_HEALTH_CHECK_INTERVAL) + health_check_interval = request.param.get('health_check_interval', 10) health_checks = request.param.get('health_checks', []) event_dispatcher = EventDispatcher() listener = CheckActiveDatabaseChangedListener() From a1c0633e69dd5a6c8bc28ea3705e61f80f1175e4 Mon Sep 17 00:00:00 2001 From: Vladyslav Vildanov <117659936+vladvildanov@users.noreply.github.com> Date: Thu, 25 Sep 2025 13:02:55 +0300 Subject: [PATCH 25/50] Refactored Failure Detector (#3775) --- redis/asyncio/multidb/command_executor.py | 23 ++- redis/asyncio/multidb/config.py | 22 ++- redis/asyncio/multidb/failure_detector.py | 11 +- redis/multidb/command_executor.py | 20 ++- redis/multidb/config.py | 20 ++- redis/multidb/failure_detector.py | 62 ++++--- tests/test_asyncio/test_multidb/conftest.py | 5 +- .../test_multidb/test_command_executor.py | 3 + .../test_multidb/test_failure_detector.py | 153 +++++++----------- tests/test_asyncio/test_scenario/conftest.py | 9 +- .../test_scenario/test_active_active.py | 22 +-- tests/test_multidb/conftest.py | 5 +- tests/test_multidb/test_command_executor.py | 5 +- tests/test_multidb/test_failure_detector.py | 149 +++++++---------- tests/test_scenario/conftest.py | 6 +- tests/test_scenario/test_active_active.py | 30 ++-- 16 files changed, 277 insertions(+), 268 deletions(-) diff --git a/redis/asyncio/multidb/command_executor.py b/redis/asyncio/multidb/command_executor.py index 95209298f4..2526c4ed9e 100644 --- a/redis/asyncio/multidb/command_executor.py +++ b/redis/asyncio/multidb/command_executor.py @@ -199,7 +199,9 @@ def pubsub(self, **kwargs): async def execute_command(self, *args, **options): async def callback(): - return await self._active_database.client.execute_command(*args, **options) + response = await self._active_database.client.execute_command(*args, **options) + await self._register_command_execution(args) + return response return await self._execute_with_failure_detection(callback, args) @@ -209,7 +211,9 @@ async def callback(): for command, options in command_stack: pipe.execute_command(*command, **options) - return await pipe.execute() + response = await pipe.execute() + await self._register_command_execution(command_stack) + return response return await self._execute_with_failure_detection(callback, command_stack) @@ -222,13 +226,15 @@ async def execute_transaction( watch_delay: Optional[float] = None, ): async def callback(): - return await self._active_database.client.transaction( + response = await self._active_database.client.transaction( func, *watches, shard_hint=shard_hint, value_from_callable=value_from_callable, watch_delay=watch_delay ) + await self._register_command_execution(()) + return response return await self._execute_with_failure_detection(callback) @@ -236,9 +242,12 @@ async def execute_pubsub_method(self, method_name: str, *args, **kwargs): async def callback(): method = getattr(self.active_pubsub, method_name) if iscoroutinefunction(method): - return await method(*args, **kwargs) + response = await method(*args, **kwargs) else: - return method(*args, **kwargs) + response = method(*args, **kwargs) + + await self._register_command_execution(args) + return response return await self._execute_with_failure_detection(callback, *args) @@ -280,6 +289,10 @@ async def _check_active_database(self): async def _on_command_fail(self, error, *args): await self._event_dispatcher.dispatch_async(AsyncOnCommandsFailEvent(args, error)) + async def 
_register_command_execution(self, cmd: tuple): + for detector in self._failure_detectors: + await detector.register_command_execution(cmd) + def _setup_event_dispatcher(self): """ Registers necessary listeners. diff --git a/redis/asyncio/multidb/config.py b/redis/asyncio/multidb/config.py index 354bbcf5c7..af2029c110 100644 --- a/redis/asyncio/multidb/config.py +++ b/redis/asyncio/multidb/config.py @@ -7,8 +7,7 @@ from redis.asyncio.multidb.database import Databases, Database from redis.asyncio.multidb.failover import AsyncFailoverStrategy, WeightBasedFailoverStrategy, DEFAULT_FAILOVER_DELAY, \ DEFAULT_FAILOVER_ATTEMPTS -from redis.asyncio.multidb.failure_detector import AsyncFailureDetector, FailureDetectorAsyncWrapper, \ - DEFAULT_FAILURES_THRESHOLD, DEFAULT_FAILURES_DURATION +from redis.asyncio.multidb.failure_detector import AsyncFailureDetector, FailureDetectorAsyncWrapper from redis.asyncio.multidb.healthcheck import HealthCheck, EchoHealthCheck, DEFAULT_HEALTH_CHECK_INTERVAL, \ DEFAULT_HEALTH_CHECK_PROBES, DEFAULT_HEALTH_CHECK_DELAY, HealthCheckPolicies, DEFAULT_HEALTH_CHECK_POLICY from redis.asyncio.retry import Retry @@ -16,7 +15,8 @@ from redis.data_structure import WeightedList from redis.event import EventDispatcherInterface, EventDispatcher from redis.multidb.circuit import CircuitBreaker, PBCircuitBreakerAdapter, DEFAULT_GRACE_PERIOD -from redis.multidb.failure_detector import CommandFailureDetector +from redis.multidb.failure_detector import CommandFailureDetector, DEFAULT_MIN_NUM_FAILURES, \ + DEFAULT_FAILURE_RATE_THRESHOLD, DEFAULT_FAILURES_DETECTION_WINDOW DEFAULT_AUTO_FALLBACK_INTERVAL = 120 @@ -70,8 +70,9 @@ class MultiDbConfig: client_class: The client class used to manage database connections. command_retry: Retry strategy for executing database commands. failure_detectors: Optional list of additional failure detectors for monitoring database failures. - failure_threshold: Threshold for determining database failure. - failures_interval: Time interval for tracking database failures. + min_num_failures: Minimal count of failures required for failover + failure_rate_threshold: Percentage of failures required for failover + failures_detection_window: Time interval for tracking database failures. health_checks: Optional list of additional health checks performed on databases. health_check_interval: Time interval for executing health checks. health_check_probes: Number of attempts to evaluate the health of a database. 
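The two new thresholds are conjunctive: the detector only opens the circuit when the failure count inside the detection window reaches min_num_failures and also accounts for at least failure_rate_threshold of the commands registered in that window. A worked example of the check, mirroring the _check_threshold logic in redis/multidb/failure_detector.py below:

import math

def should_trip(failures: int, executed: int,
                min_num_failures: int = 1000,
                failure_rate_threshold: float = 0.1) -> bool:
    # Both the absolute floor and the rate relative to commands
    # executed in the window must be met.
    return (failures >= min_num_failures
            and failures >= math.ceil(executed * failure_rate_threshold))

assert should_trip(1000, 10_000)      # 10% of 10k commands failed -> open circuit
assert not should_trip(999, 1_000)    # under the absolute floor of 1000
assert not should_trip(1000, 20_000)  # only 5% failed; rate threshold is 10%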
@@ -105,8 +106,9 @@ class MultiDbConfig: backoff=ExponentialWithJitterBackoff(base=1, cap=10), retries=3 ) failure_detectors: Optional[List[AsyncFailureDetector]] = None - failure_threshold: int = DEFAULT_FAILURES_THRESHOLD - failures_interval: float = DEFAULT_FAILURES_DURATION + min_num_failures: int = DEFAULT_MIN_NUM_FAILURES + failure_rate_threshold: float = DEFAULT_FAILURE_RATE_THRESHOLD + failures_detection_window: float = DEFAULT_FAILURES_DETECTION_WINDOW health_checks: Optional[List[HealthCheck]] = None health_check_interval: float = DEFAULT_HEALTH_CHECK_INTERVAL health_check_probes: int = DEFAULT_HEALTH_CHECK_PROBES @@ -151,7 +153,11 @@ def databases(self) -> Databases: def default_failure_detectors(self) -> List[AsyncFailureDetector]: return [ FailureDetectorAsyncWrapper( - CommandFailureDetector(threshold=self.failure_threshold, duration=self.failures_interval) + CommandFailureDetector( + min_num_failures=self.min_num_failures, + failure_rate_threshold=self.failure_rate_threshold, + failure_detection_window=self.failures_detection_window + ) ), ] diff --git a/redis/asyncio/multidb/failure_detector.py b/redis/asyncio/multidb/failure_detector.py index cdfcc6ff1e..e6d257e941 100644 --- a/redis/asyncio/multidb/failure_detector.py +++ b/redis/asyncio/multidb/failure_detector.py @@ -2,9 +2,6 @@ from redis.multidb.failure_detector import FailureDetector -DEFAULT_FAILURES_THRESHOLD = 1000 -DEFAULT_FAILURES_DURATION = 2 - class AsyncFailureDetector(ABC): @abstractmethod @@ -12,6 +9,11 @@ async def register_failure(self, exception: Exception, cmd: tuple) -> None: """Register a failure that occurred during command execution.""" pass + @abstractmethod + async def register_command_execution(self, cmd: tuple) -> None: + """Register a command execution.""" + pass + @abstractmethod def set_command_executor(self, command_executor) -> None: """Set the command executor for this failure.""" @@ -27,5 +29,8 @@ def __init__(self, failure_detector: FailureDetector) -> None: async def register_failure(self, exception: Exception, cmd: tuple) -> None: self._failure_detector.register_failure(exception, cmd) + async def register_command_execution(self, cmd: tuple) -> None: + self._failure_detector.register_command_execution(cmd) + def set_command_executor(self, command_executor) -> None: self._failure_detector.set_command_executor(command_executor) \ No newline at end of file diff --git a/redis/multidb/command_executor.py b/redis/multidb/command_executor.py index 562dcfd6fe..481364de9a 100644 --- a/redis/multidb/command_executor.py +++ b/redis/multidb/command_executor.py @@ -224,7 +224,9 @@ def failover_strategy_executor(self) -> FailoverStrategyExecutor: def execute_command(self, *args, **options): def callback(): - return self._active_database.client.execute_command(*args, **options) + response = self._active_database.client.execute_command(*args, **options) + self._register_command_execution(args) + return response return self._execute_with_failure_detection(callback, args) @@ -234,13 +236,17 @@ def callback(): for command, options in command_stack: pipe.execute_command(*command, **options) - return pipe.execute() + response = pipe.execute() + self._register_command_execution(command_stack) + return response return self._execute_with_failure_detection(callback, command_stack) def execute_transaction(self, transaction: Callable[[Pipeline], None], *watches, **options): def callback(): - return self._active_database.client.transaction(transaction, *watches, **options) + response = 
self._active_database.client.transaction(transaction, *watches, **options) + self._register_command_execution(()) + return response return self._execute_with_failure_detection(callback) @@ -256,7 +262,9 @@ def callback(): def execute_pubsub_method(self, method_name: str, *args, **kwargs): def callback(): method = getattr(self.active_pubsub, method_name) - return method(*args, **kwargs) + response = method(*args, **kwargs) + self._register_command_execution(args) + return response return self._execute_with_failure_detection(callback, *args) @@ -298,6 +306,10 @@ def _check_active_database(self): self.active_database = self._failover_strategy_executor.execute() self._schedule_next_fallback() + def _register_command_execution(self, cmd: tuple): + for detector in self._failure_detectors: + detector.register_command_execution(cmd) + def _setup_event_dispatcher(self): """ Registers necessary listeners. diff --git a/redis/multidb/config.py b/redis/multidb/config.py index ff9872ffd4..f78114f014 100644 --- a/redis/multidb/config.py +++ b/redis/multidb/config.py @@ -11,8 +11,8 @@ from redis.event import EventDispatcher, EventDispatcherInterface from redis.multidb.circuit import PBCircuitBreakerAdapter, CircuitBreaker, DEFAULT_GRACE_PERIOD from redis.multidb.database import Database, Databases -from redis.multidb.failure_detector import FailureDetector, CommandFailureDetector, DEFAULT_FAILURES_THRESHOLD, \ - DEFAULT_FAILURES_DURATION +from redis.multidb.failure_detector import FailureDetector, CommandFailureDetector, DEFAULT_MIN_NUM_FAILURES, \ + DEFAULT_FAILURES_DETECTION_WINDOW, DEFAULT_FAILURE_RATE_THRESHOLD from redis.multidb.healthcheck import HealthCheck, EchoHealthCheck, DEFAULT_HEALTH_CHECK_PROBES, \ DEFAULT_HEALTH_CHECK_INTERVAL, DEFAULT_HEALTH_CHECK_DELAY, HealthCheckPolicies, DEFAULT_HEALTH_CHECK_POLICY from redis.multidb.failover import FailoverStrategy, WeightBasedFailoverStrategy, DEFAULT_FAILOVER_ATTEMPTS, \ @@ -71,8 +71,9 @@ class MultiDbConfig: client_class: The client class used to manage database connections. command_retry: Retry strategy for executing database commands. failure_detectors: Optional list of additional failure detectors for monitoring database failures. - failure_threshold: Threshold for determining database failure. - failures_interval: Time interval for tracking database failures. + min_num_failures: Minimal count of failures required for failover + failure_rate_threshold: Percentage of failures required for failover + failures_detection_window: Time interval for tracking database failures. health_checks: Optional list of additional health checks performed on databases. health_check_interval: Time interval for executing health checks. health_check_probes: Number of attempts to evaluate the health of a database. 
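The executor changes above repeat one pattern across execute_command, pipelines, transactions, and pub/sub: every successful callback now feeds register_command_execution so the detectors have a denominator for the failure rate, while exceptions keep flowing through register_failure. A condensed view of that pattern, using the names from the diff (method bodies shown outside their class for brevity):

def execute_command(self, *args, **options):
    def callback():
        response = self._active_database.client.execute_command(*args, **options)
        # A successful execution is counted so a failure *rate* can be computed.
        self._register_command_execution(args)
        return response

    # Exceptions raised in callback() are reported via register_failure().
    return self._execute_with_failure_detection(callback, args)

def _register_command_execution(self, cmd: tuple):
    for detector in self._failure_detectors:
        detector.register_command_execution(cmd)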
@@ -107,8 +108,9 @@ class MultiDbConfig: backoff=ExponentialWithJitterBackoff(base=1, cap=10), retries=3 ) failure_detectors: Optional[List[FailureDetector]] = None - failure_threshold: int = DEFAULT_FAILURES_THRESHOLD - failures_interval: float = DEFAULT_FAILURES_DURATION + min_num_failures: int = DEFAULT_MIN_NUM_FAILURES + failure_rate_threshold: float = DEFAULT_FAILURE_RATE_THRESHOLD + failures_detection_window: float = DEFAULT_FAILURES_DETECTION_WINDOW health_checks: Optional[List[HealthCheck]] = None health_check_interval: float = DEFAULT_HEALTH_CHECK_INTERVAL health_check_probes: int = DEFAULT_HEALTH_CHECK_PROBES @@ -152,7 +154,11 @@ def databases(self) -> Databases: def default_failure_detectors(self) -> List[FailureDetector]: return [ - CommandFailureDetector(threshold=self.failure_threshold, duration=self.failures_interval), + CommandFailureDetector( + min_num_failures=self.min_num_failures, + failure_rate_threshold=self.failure_rate_threshold, + failure_detection_window=self.failures_detection_window + ), ] def default_health_checks(self) -> List[HealthCheck]: diff --git a/redis/multidb/failure_detector.py b/redis/multidb/failure_detector.py index 6b918b152a..ca657c4e52 100644 --- a/redis/multidb/failure_detector.py +++ b/redis/multidb/failure_detector.py @@ -1,3 +1,4 @@ +import math import threading from abc import ABC, abstractmethod from datetime import datetime, timedelta @@ -7,8 +8,9 @@ from redis.multidb.circuit import State as CBState -DEFAULT_FAILURES_THRESHOLD = 1000 -DEFAULT_FAILURES_DURATION = 2 +DEFAULT_MIN_NUM_FAILURES = 1000 +DEFAULT_FAILURE_RATE_THRESHOLD = 0.1 +DEFAULT_FAILURES_DETECTION_WINDOW = 2 class FailureDetector(ABC): @@ -17,6 +19,11 @@ def register_failure(self, exception: Exception, cmd: tuple) -> None: """Register a failure that occurred during command execution.""" pass + @abstractmethod + def register_command_execution(self, cmd: tuple) -> None: + """Register a command execution.""" + pass + @abstractmethod def set_command_executor(self, command_executor) -> None: """Set the command executor for this failure.""" @@ -28,56 +35,65 @@ class CommandFailureDetector(FailureDetector): """ def __init__( self, - threshold: int = DEFAULT_FAILURES_THRESHOLD, - duration: float = DEFAULT_FAILURES_DURATION, + min_num_failures: int = DEFAULT_MIN_NUM_FAILURES, + failure_rate_threshold: float = DEFAULT_FAILURE_RATE_THRESHOLD, + failure_detection_window: float = DEFAULT_FAILURES_DETECTION_WINDOW, error_types: Optional[List[Type[Exception]]] = None, ) -> None: """ Initialize a new CommandFailureDetector instance. Args: - threshold: The number of failures that must occur within the duration to trigger failure detection. - duration: The time window in seconds during which failures are counted. + min_num_failures: Minimal count of failures required for failover + failure_rate_threshold: Percentage of failures required for failover + failure_detection_window: Time interval for executing health checks. error_types: Optional list of exception types to trigger failover. If None, all exceptions are counted. The detector tracks command failures within a sliding time window. When the number of failures exceeds the threshold within the specified duration, it triggers failure detection. 
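For example, with min_num_failures=3 and failures_detection_window=1.0, three failures registered at t=0.1s, 0.4s and 0.8s open the circuit. The reset is driven by register_command_execution(): the first command registered after the window expires clears both the failure and the execution counters, so stale failures never count toward a new window.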
""" self._command_executor = None - self._threshold = threshold - self._duration = duration + self._min_num_failures = min_num_failures + self._failure_rate_threshold = failure_rate_threshold + self._failure_detection_window = failure_detection_window self._error_types = error_types + self._commands_executed: int = 0 self._start_time: datetime = datetime.now() - self._end_time: datetime = self._start_time + timedelta(seconds=self._duration) - self._failures_within_duration: List[tuple[datetime, tuple]] = [] + self._end_time: datetime = self._start_time + timedelta(seconds=self._failure_detection_window) + self._failures_count: int = 0 self._lock = threading.RLock() def register_failure(self, exception: Exception, cmd: tuple) -> None: - failure_time = datetime.now() - - if not self._start_time < failure_time < self._end_time: - self._reset() - with self._lock: if self._error_types: if type(exception) in self._error_types: - self._failures_within_duration.append((datetime.now(), cmd)) + self._failures_count += 1 else: - self._failures_within_duration.append((datetime.now(), cmd)) + self._failures_count += 1 - self._check_threshold() + self._check_threshold() def set_command_executor(self, command_executor) -> None: self._command_executor = command_executor - def _check_threshold(self): + def register_command_execution(self, cmd: tuple) -> None: with self._lock: - if len(self._failures_within_duration) >= self._threshold: - self._command_executor.active_database.circuit.state = CBState.OPEN + if not self._start_time < datetime.now() < self._end_time: self._reset() + self._commands_executed += 1 + + def _check_threshold(self): + if ( + self._failures_count >= self._min_num_failures + and self._failures_count >= (math.ceil(self._commands_executed * self._failure_rate_threshold)) + ): + self._command_executor.active_database.circuit.state = CBState.OPEN + self._reset() + def _reset(self) -> None: with self._lock: self._start_time = datetime.now() - self._end_time = self._start_time + timedelta(seconds=self._duration) - self._failures_within_duration = [] \ No newline at end of file + self._end_time = self._start_time + timedelta(seconds=self._failure_detection_window) + self._failures_count = 0 + self._commands_executed = 0 \ No newline at end of file diff --git a/tests/test_asyncio/test_multidb/conftest.py b/tests/test_asyncio/test_multidb/conftest.py index f5ea12d9b0..7695332754 100644 --- a/tests/test_asyncio/test_multidb/conftest.py +++ b/tests/test_asyncio/test_multidb/conftest.py @@ -9,7 +9,7 @@ DEFAULT_HEALTH_CHECK_POLICY from redis.data_structure import WeightedList from redis.multidb.circuit import State as CBState, CircuitBreaker -from redis.asyncio import Redis +from redis.asyncio import Redis, ConnectionPool from redis.asyncio.multidb.database import Database, Databases @pytest.fixture() @@ -37,6 +37,7 @@ def mock_db(request) -> Database: db = Mock(spec=Database) db.weight = request.param.get("weight", 1.0) db.client = Mock(spec=Redis) + db.client.connection_pool = Mock(spec=ConnectionPool) cb = request.param.get("circuit", {}) mock_cb = Mock(spec=CircuitBreaker) @@ -51,6 +52,7 @@ def mock_db1(request) -> Database: db = Mock(spec=Database) db.weight = request.param.get("weight", 1.0) db.client = Mock(spec=Redis) + db.client.connection_pool = Mock(spec=ConnectionPool) cb = request.param.get("circuit", {}) mock_cb = Mock(spec=CircuitBreaker) @@ -65,6 +67,7 @@ def mock_db2(request) -> Database: db = Mock(spec=Database) db.weight = request.param.get("weight", 1.0) db.client = 
Mock(spec=Redis) + db.client.connection_pool = Mock(spec=ConnectionPool) cb = request.param.get("circuit", {}) mock_cb = Mock(spec=CircuitBreaker) diff --git a/tests/test_asyncio/test_multidb/test_command_executor.py b/tests/test_asyncio/test_multidb/test_command_executor.py index 3f64e6aa0b..01a8326e5a 100644 --- a/tests/test_asyncio/test_multidb/test_command_executor.py +++ b/tests/test_asyncio/test_multidb/test_command_executor.py @@ -46,6 +46,7 @@ async def test_execute_command_on_active_database(self, mock_db, mock_db1, mock_ await executor.set_active_database(mock_db2) assert await executor.execute_command('SET', 'key', 'value') == 'OK2' assert mock_ed.register_listeners.call_count == 1 + assert mock_fd.register_command_execution.call_count == 2 @pytest.mark.asyncio @pytest.mark.parametrize( @@ -82,6 +83,7 @@ async def test_execute_command_automatically_select_active_database( assert await executor.execute_command('SET', 'key', 'value') == 'OK2' assert mock_ed.register_listeners.call_count == 1 assert mock_selector.call_count == 2 + assert mock_fd.register_command_execution.call_count == 2 @pytest.mark.asyncio @pytest.mark.parametrize( @@ -124,6 +126,7 @@ async def test_execute_command_fallback_to_another_db_after_fallback_interval( assert await executor.execute_command('SET', 'key', 'value') == 'OK1' assert mock_ed.register_listeners.call_count == 1 assert mock_selector.call_count == 3 + assert mock_fd.register_command_execution.call_count == 3 @pytest.mark.asyncio @pytest.mark.parametrize( diff --git a/tests/test_asyncio/test_multidb/test_failure_detector.py b/tests/test_asyncio/test_multidb/test_failure_detector.py index 3c1eb4fabd..a4d7407609 100644 --- a/tests/test_asyncio/test_multidb/test_failure_detector.py +++ b/tests/test_asyncio/test_multidb/test_failure_detector.py @@ -4,6 +4,7 @@ import pytest from redis.asyncio.multidb.command_executor import AsyncCommandExecutor +from redis.asyncio.multidb.database import Database from redis.asyncio.multidb.failure_detector import FailureDetectorAsyncWrapper from redis.multidb.circuit import State as CBState from redis.multidb.failure_detector import CommandFailureDetector @@ -12,127 +13,95 @@ class TestFailureDetectorAsyncWrapper: @pytest.mark.asyncio @pytest.mark.parametrize( - 'mock_db', + 'min_num_failures,failure_rate_threshold,circuit_state', [ - {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, + (2, 0.4, CBState.OPEN), + (2, 0, CBState.OPEN), + (0, 0.4, CBState.OPEN), + (3, 0.4, CBState.CLOSED), + (2, 0.41, CBState.CLOSED), ], - indirect=True, - ) - async def test_failure_detector_open_circuit_on_threshold_exceed_and_interval_not_exceed(self, mock_db): - fd = FailureDetectorAsyncWrapper(CommandFailureDetector(5, 1)) - mock_ce = Mock(spec=AsyncCommandExecutor) - mock_ce.active_database = mock_db - fd.set_command_executor(mock_ce) - assert mock_db.circuit.state == CBState.CLOSED - - await fd.register_failure(Exception(), ('SET', 'key1', 'value1')) - await fd.register_failure(Exception(), ('SET', 'key1', 'value1')) - await fd.register_failure(Exception(), ('SET', 'key1', 'value1')) - await fd.register_failure(Exception(), ('SET', 'key1', 'value1')) - await fd.register_failure(Exception(), ('SET', 'key1', 'value1')) - - assert mock_db.circuit.state == CBState.OPEN - - @pytest.mark.asyncio - @pytest.mark.parametrize( - 'mock_db', - [ - {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, + ids=[ + "exceeds min num failures AND failures rate", + "exceeds min num failures AND failures rate == 0", + "min num failures == 0 AND 
exceeds failures rate", + "do not exceeds min num failures", + "do not exceeds failures rate", ], - indirect=True, ) - async def test_failure_detector_do_not_open_circuit_if_threshold_not_exceed_and_interval_not_exceed(self, mock_db): - fd = FailureDetectorAsyncWrapper(CommandFailureDetector(5, 1)) + async def test_failure_detector_correctly_reacts_to_failures( + self, + min_num_failures, + failure_rate_threshold, + circuit_state + ): + fd = FailureDetectorAsyncWrapper(CommandFailureDetector(min_num_failures, failure_rate_threshold)) + mock_db = Mock(spec=Database) + mock_db.circuit.state = CBState.CLOSED mock_ce = Mock(spec=AsyncCommandExecutor) mock_ce.active_database = mock_db fd.set_command_executor(mock_ce) - assert mock_db.circuit.state == CBState.CLOSED - await fd.register_failure(Exception(), ('SET', 'key1', 'value1')) - await fd.register_failure(Exception(), ('SET', 'key1', 'value1')) - await fd.register_failure(Exception(), ('SET', 'key1', 'value1')) - await fd.register_failure(Exception(), ('SET', 'key1', 'value1')) + await fd.register_command_execution(('GET', 'key')) + await fd.register_command_execution(('GET','key')) + await fd.register_failure(Exception(), ('GET', 'key')) - assert mock_db.circuit.state == CBState.CLOSED + await fd.register_command_execution(('GET', 'key')) + await fd.register_command_execution(('GET','key')) + await fd.register_command_execution(('GET','key')) + await fd.register_failure(Exception(), ('GET', 'key')) - @pytest.mark.asyncio - @pytest.mark.parametrize( - 'mock_db', - [ - {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, - ], - indirect=True, - ) - async def test_failure_detector_do_not_open_circuit_on_threshold_exceed_and_interval_exceed(self, mock_db): - fd = FailureDetectorAsyncWrapper(CommandFailureDetector(5, 0.3)) - mock_ce = Mock(spec=AsyncCommandExecutor) - mock_ce.active_database = mock_db - fd.set_command_executor(mock_ce) - assert mock_db.circuit.state == CBState.CLOSED - - await fd.register_failure(Exception(), ('SET', 'key1', 'value1')) - await asyncio.sleep(0.1) - await fd.register_failure(Exception(), ('SET', 'key1', 'value1')) - await asyncio.sleep(0.1) - await fd.register_failure(Exception(), ('SET', 'key1', 'value1')) - await asyncio.sleep(0.1) - await fd.register_failure(Exception(), ('SET', 'key1', 'value1')) - await asyncio.sleep(0.1) - await fd.register_failure(Exception(), ('SET', 'key1', 'value1')) - - assert mock_db.circuit.state == CBState.CLOSED - - # 4 more failures as the last one already refreshed timer - await fd.register_failure(Exception(), ('SET', 'key1', 'value1')) - await fd.register_failure(Exception(), ('SET', 'key1', 'value1')) - await fd.register_failure(Exception(), ('SET', 'key1', 'value1')) - await fd.register_failure(Exception(), ('SET', 'key1', 'value1')) - - assert mock_db.circuit.state == CBState.OPEN + assert mock_db.circuit.state == circuit_state @pytest.mark.asyncio @pytest.mark.parametrize( - 'mock_db', + 'min_num_failures,failure_rate_threshold', [ - {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, + (3, 0.0), + (3, 0.6), + ], + ids=[ + "do not exceeds min num failures, during interval", + "do not exceeds min num failures AND failure rate, during interval", ], - indirect=True, ) - async def test_failure_detector_refresh_timer_on_expired_duration(self, mock_db): - fd = FailureDetectorAsyncWrapper(CommandFailureDetector(5, 0.3)) + async def test_failure_detector_do_not_open_circuit_on_interval_exceed(self, min_num_failures, failure_rate_threshold): + fd = FailureDetectorAsyncWrapper( 
+ CommandFailureDetector(min_num_failures, failure_rate_threshold, 0.3) + ) + mock_db = Mock(spec=Database) + mock_db.circuit.state = CBState.CLOSED mock_ce = Mock(spec=AsyncCommandExecutor) mock_ce.active_database = mock_db fd.set_command_executor(mock_ce) assert mock_db.circuit.state == CBState.CLOSED - await fd.register_failure(Exception(), ('SET', 'key1', 'value1')) - await fd.register_failure(Exception(), ('SET', 'key1', 'value1')) - await fd.register_failure(Exception(), ('SET', 'key1', 'value1')) - await fd.register_failure(Exception(), ('SET', 'key1', 'value1')) - await asyncio.sleep(0.4) + await fd.register_command_execution(('GET', 'key')) + await fd.register_failure(Exception(), ('GET', 'key')) + await asyncio.sleep(0.16) + await fd.register_command_execution(('GET', 'key')) + await fd.register_command_execution(('GET', 'key')) + await fd.register_command_execution(('GET', 'key')) + await fd.register_failure(Exception(), ('GET', 'key')) + await asyncio.sleep(0.16) + await fd.register_command_execution(('GET', 'key')) + await fd.register_failure(Exception(), ('GET', 'key')) assert mock_db.circuit.state == CBState.CLOSED - await fd.register_failure(Exception(), ('SET', 'key1', 'value1')) - await fd.register_failure(Exception(), ('SET', 'key1', 'value1')) - await fd.register_failure(Exception(), ('SET', 'key1', 'value1')) - await fd.register_failure(Exception(), ('SET', 'key1', 'value1')) - - assert mock_db.circuit.state == CBState.CLOSED - await fd.register_failure(Exception(), ('SET', 'key1', 'value1')) + # 2 more failure as last one already refreshed timer + await fd.register_command_execution(('GET', 'key')) + await fd.register_failure(Exception(), ('GET', 'key')) + await fd.register_command_execution(('GET', 'key')) + await fd.register_failure(Exception(), ('GET', 'key')) assert mock_db.circuit.state == CBState.OPEN @pytest.mark.asyncio - @pytest.mark.parametrize( - 'mock_db', - [ - {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, - ], - indirect=True, - ) - async def test_failure_detector_open_circuit_on_specific_exception_threshold_exceed(self, mock_db): + async def test_failure_detector_open_circuit_on_specific_exception_threshold_exceed(self): fd = FailureDetectorAsyncWrapper(CommandFailureDetector(5, 1, error_types=[ConnectionError])) + mock_db = Mock(spec=Database) + mock_db.circuit.state = CBState.CLOSED mock_ce = Mock(spec=AsyncCommandExecutor) mock_ce.active_database = mock_db fd.set_command_executor(mock_ce) diff --git a/tests/test_asyncio/test_scenario/conftest.py b/tests/test_asyncio/test_scenario/conftest.py index 00ddd36a76..c152e51fb5 100644 --- a/tests/test_asyncio/test_scenario/conftest.py +++ b/tests/test_asyncio/test_scenario/conftest.py @@ -7,12 +7,13 @@ from redis.asyncio import Redis, RedisCluster from redis.asyncio.multidb.client import MultiDBClient -from redis.asyncio.multidb.config import DEFAULT_FAILURES_THRESHOLD, DEFAULT_HEALTH_CHECK_INTERVAL, DatabaseConfig, \ +from redis.asyncio.multidb.config import DEFAULT_HEALTH_CHECK_INTERVAL, DatabaseConfig, \ MultiDbConfig from redis.asyncio.multidb.event import AsyncActiveDatabaseChanged from redis.asyncio.retry import Retry from redis.backoff import ExponentialBackoff from redis.event import AsyncEventListenerInterface, EventDispatcher +from redis.multidb.failure_detector import DEFAULT_MIN_NUM_FAILURES from tests.test_scenario.conftest import get_endpoints_config, extract_cluster_fqdn from tests.test_scenario.fault_injector_client import FaultInjectorClient @@ -40,7 +41,7 @@ async def 
r_multi_db(request) -> AsyncGenerator[tuple[MultiDBClient, CheckActive username = endpoint_config.get('username', None) password = endpoint_config.get('password', None) - failure_threshold = request.param.get('failure_threshold', DEFAULT_FAILURES_THRESHOLD) + min_num_failures = request.param.get('min_num_failures', DEFAULT_MIN_NUM_FAILURES) command_retry = request.param.get('command_retry', Retry(ExponentialBackoff(cap=0.1, base=0.01), retries=10)) # Retry configuration different for health checks as initial health check require more time in case @@ -82,7 +83,7 @@ async def r_multi_db(request) -> AsyncGenerator[tuple[MultiDBClient, CheckActive client_class=client_class, databases_config=db_configs, command_retry=command_retry, - failure_threshold=failure_threshold, + min_num_failures=min_num_failures, health_checks=health_checks, health_check_probes=3, health_check_interval=health_check_interval, @@ -97,7 +98,7 @@ async def teardown(): if isinstance(client.command_executor.active_database.client, Redis): await client.command_executor.active_database.client.connection_pool.disconnect() - await asyncio.sleep(15) + await asyncio.sleep(20) yield client, listener, endpoint_config await teardown() \ No newline at end of file diff --git a/tests/test_asyncio/test_scenario/test_active_active.py b/tests/test_asyncio/test_scenario/test_active_active.py index 693dd33bbe..55b604528b 100644 --- a/tests/test_asyncio/test_scenario/test_active_active.py +++ b/tests/test_asyncio/test_scenario/test_active_active.py @@ -43,8 +43,8 @@ class TestActiveActive: @pytest.mark.parametrize( "r_multi_db", [ - {"client_class": Redis, "failure_threshold": 2}, - {"client_class": RedisCluster, "failure_threshold": 2}, + {"client_class": Redis, "min_num_failures": 2}, + {"client_class": RedisCluster, "min_num_failures": 2}, ], ids=["standalone", "cluster"], indirect=True @@ -89,7 +89,7 @@ async def test_multi_db_client_failover_to_another_db(self, r_multi_db, fault_in @pytest.mark.parametrize( "r_multi_db", [ - {"client_class": Redis, "failure_threshold": 2, "health_checks": + {"client_class": Redis, "min_num_failures": 2, "health_checks": [ LagAwareHealthCheck( verify_tls=False, @@ -98,7 +98,7 @@ async def test_multi_db_client_failover_to_another_db(self, r_multi_db, fault_in ], "health_check_interval": 20, }, - {"client_class": RedisCluster, "failure_threshold": 2, "health_checks": + {"client_class": RedisCluster, "min_num_failures": 2, "health_checks": [ LagAwareHealthCheck( verify_tls=False, @@ -149,8 +149,8 @@ async def test_multi_db_client_uses_lag_aware_health_check(self, r_multi_db, fau @pytest.mark.parametrize( "r_multi_db", [ - {"client_class": Redis, "failure_threshold": 2}, - {"client_class": RedisCluster, "failure_threshold": 2}, + {"client_class": Redis, "min_num_failures": 2}, + {"client_class": RedisCluster, "min_num_failures": 2}, ], ids=["standalone", "cluster"], indirect=True @@ -198,8 +198,8 @@ async def callback(): @pytest.mark.parametrize( "r_multi_db", [ - {"client_class": Redis, "failure_threshold": 2}, - {"client_class": RedisCluster, "failure_threshold": 2}, + {"client_class": Redis, "min_num_failures": 2}, + {"client_class": RedisCluster, "min_num_failures": 2}, ], ids=["standalone", "cluster"], indirect=True @@ -247,8 +247,8 @@ async def callback(): @pytest.mark.parametrize( "r_multi_db", [ - {"client_class": Redis, "failure_threshold": 2}, - {"client_class": RedisCluster, "failure_threshold": 2}, + {"client_class": Redis, "min_num_failures": 2}, + {"client_class": RedisCluster, 
"min_num_failures": 2}, ], ids=["standalone", "cluster"], indirect=True @@ -294,7 +294,7 @@ async def callback(pipe: Pipeline): @pytest.mark.asyncio @pytest.mark.parametrize( "r_multi_db", - [{"failure_threshold": 2}], + [{"min_num_failures": 2}], indirect=True ) @pytest.mark.timeout(200) diff --git a/tests/test_multidb/conftest.py b/tests/test_multidb/conftest.py index 3b1f7f369b..ce4658868f 100644 --- a/tests/test_multidb/conftest.py +++ b/tests/test_multidb/conftest.py @@ -2,7 +2,7 @@ import pytest -from redis import Redis +from redis import Redis, ConnectionPool from redis.data_structure import WeightedList from redis.multidb.circuit import State as CBState, CircuitBreaker from redis.multidb.config import MultiDbConfig, DatabaseConfig, DEFAULT_HEALTH_CHECK_INTERVAL, \ @@ -39,6 +39,7 @@ def mock_db(request) -> Database: db = Mock(spec=Database) db.weight = request.param.get("weight", 1.0) db.client = Mock(spec=Redis) + db.client.connection_pool = Mock(spec=ConnectionPool) cb = request.param.get("circuit", {}) mock_cb = Mock(spec=CircuitBreaker) @@ -53,6 +54,7 @@ def mock_db1(request) -> Database: db = Mock(spec=Database) db.weight = request.param.get("weight", 1.0) db.client = Mock(spec=Redis) + db.client.connection_pool = Mock(spec=ConnectionPool) cb = request.param.get("circuit", {}) mock_cb = Mock(spec=CircuitBreaker) @@ -67,6 +69,7 @@ def mock_db2(request) -> Database: db = Mock(spec=Database) db.weight = request.param.get("weight", 1.0) db.client = Mock(spec=Redis) + db.client.connection_pool = Mock(spec=ConnectionPool) cb = request.param.get("circuit", {}) mock_cb = Mock(spec=CircuitBreaker) diff --git a/tests/test_multidb/test_command_executor.py b/tests/test_multidb/test_command_executor.py index 044fef0f8c..2001d64f04 100644 --- a/tests/test_multidb/test_command_executor.py +++ b/tests/test_multidb/test_command_executor.py @@ -44,6 +44,7 @@ def test_execute_command_on_active_database(self, mock_db, mock_db1, mock_db2, m executor.active_database = mock_db2 assert executor.execute_command('SET', 'key', 'value') == 'OK2' assert mock_ed.register_listeners.call_count == 1 + assert mock_fd.register_command_execution.call_count == 2 @pytest.mark.parametrize( 'mock_db,mock_db1,mock_db2', @@ -78,6 +79,7 @@ def test_execute_command_automatically_select_active_database( assert executor.execute_command('SET', 'key', 'value') == 'OK2' assert mock_ed.register_listeners.call_count == 1 assert mock_fs.database.call_count == 2 + assert mock_fd.register_command_execution.call_count == 2 @pytest.mark.parametrize( 'mock_db,mock_db1,mock_db2', @@ -118,6 +120,7 @@ def test_execute_command_fallback_to_another_db_after_fallback_interval( assert executor.execute_command('SET', 'key', 'value') == 'OK1' assert mock_ed.register_listeners.call_count == 1 assert mock_fs.database.call_count == 3 + assert mock_fd.register_command_execution.call_count == 3 @pytest.mark.parametrize( 'mock_db,mock_db1,mock_db2', @@ -137,7 +140,7 @@ def test_execute_command_fallback_to_another_db_after_failure_detection( mock_db2.client.execute_command.side_effect = ['OK2', ConnectionError, ConnectionError, ConnectionError] mock_fs.database.side_effect = [mock_db1, mock_db2, mock_db1] threshold = 3 - fd = CommandFailureDetector(threshold, 1) + fd = CommandFailureDetector(threshold, 0.0, 1) ed = EventDispatcher() databases = create_weighted_list(mock_db, mock_db1, mock_db2) diff --git a/tests/test_multidb/test_failure_detector.py b/tests/test_multidb/test_failure_detector.py index 28687f2a11..3e71ab6aa5 100644 --- 
a/tests/test_multidb/test_failure_detector.py +++ b/tests/test_multidb/test_failure_detector.py @@ -4,6 +4,7 @@ import pytest from redis.multidb.command_executor import SyncCommandExecutor +from redis.multidb.database import Database from redis.multidb.failure_detector import CommandFailureDetector from redis.multidb.circuit import State as CBState from redis.exceptions import ConnectionError @@ -11,123 +12,91 @@ class TestCommandFailureDetector: @pytest.mark.parametrize( - 'mock_db', + 'min_num_failures,failure_rate_threshold,circuit_state', [ - {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, + (2, 0.4, CBState.OPEN), + (2, 0, CBState.OPEN), + (0, 0.4, CBState.OPEN), + (3, 0.4, CBState.CLOSED), + (2, 0.41, CBState.CLOSED), ], - indirect=True, - ) - def test_failure_detector_open_circuit_on_threshold_exceed_and_interval_not_exceed(self, mock_db): - fd = CommandFailureDetector(5, 1) - mock_ce = Mock(spec=SyncCommandExecutor) - mock_ce.active_database = mock_db - fd.set_command_executor(mock_ce) - assert mock_db.circuit.state == CBState.CLOSED - - fd.register_failure(Exception(), ('SET', 'key1', 'value1')) - fd.register_failure(Exception(), ('SET', 'key1', 'value1')) - fd.register_failure(Exception(), ('SET', 'key1', 'value1')) - fd.register_failure(Exception(), ('SET', 'key1', 'value1')) - fd.register_failure(Exception(), ('SET', 'key1', 'value1')) - - assert mock_db.circuit.state == CBState.OPEN - - @pytest.mark.parametrize( - 'mock_db', - [ - {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, + ids=[ + "exceeds min num failures AND failures rate", + "exceeds min num failures AND failures rate == 0", + "min num failures == 0 AND exceeds failures rate", + "do not exceeds min num failures", + "do not exceeds failures rate", ], - indirect=True, ) - def test_failure_detector_do_not_open_circuit_if_threshold_not_exceed_and_interval_not_exceed(self, mock_db): - fd = CommandFailureDetector(5, 1) + def test_failure_detector_correctly_reacts_to_failures( + self, + min_num_failures, + failure_rate_threshold, + circuit_state + ): + fd = CommandFailureDetector(min_num_failures, failure_rate_threshold) + mock_db = Mock(spec=Database) + mock_db.circuit.state = CBState.CLOSED mock_ce = Mock(spec=SyncCommandExecutor) mock_ce.active_database = mock_db fd.set_command_executor(mock_ce) - assert mock_db.circuit.state == CBState.CLOSED - fd.register_failure(Exception(), ('SET', 'key1', 'value1')) - fd.register_failure(Exception(), ('SET', 'key1', 'value1')) - fd.register_failure(Exception(), ('SET', 'key1', 'value1')) - fd.register_failure(Exception(), ('SET', 'key1', 'value1')) + fd.register_command_execution(('GET', 'key')) + fd.register_command_execution(('GET','key')) + fd.register_failure(Exception(), ('GET', 'key')) - assert mock_db.circuit.state == CBState.CLOSED + fd.register_command_execution(('GET', 'key')) + fd.register_command_execution(('GET','key')) + fd.register_command_execution(('GET','key')) + fd.register_failure(Exception(), ('GET', 'key')) - @pytest.mark.parametrize( - 'mock_db', - [ - {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, - ], - indirect=True, - ) - def test_failure_detector_do_not_open_circuit_on_threshold_exceed_and_interval_exceed(self, mock_db): - fd = CommandFailureDetector(5, 0.3) - mock_ce = Mock(spec=SyncCommandExecutor) - mock_ce.active_database = mock_db - fd.set_command_executor(mock_ce) - assert mock_db.circuit.state == CBState.CLOSED - - fd.register_failure(Exception(), ('SET', 'key1', 'value1')) - sleep(0.1) - fd.register_failure(Exception(), 
('SET', 'key1', 'value1')) - sleep(0.1) - fd.register_failure(Exception(), ('SET', 'key1', 'value1')) - sleep(0.1) - fd.register_failure(Exception(), ('SET', 'key1', 'value1')) - sleep(0.1) - fd.register_failure(Exception(), ('SET', 'key1', 'value1')) - - assert mock_db.circuit.state == CBState.CLOSED - - # 4 more failure as last one already refreshed timer - fd.register_failure(Exception(), ('SET', 'key1', 'value1')) - fd.register_failure(Exception(), ('SET', 'key1', 'value1')) - fd.register_failure(Exception(), ('SET', 'key1', 'value1')) - fd.register_failure(Exception(), ('SET', 'key1', 'value1')) - - assert mock_db.circuit.state == CBState.OPEN + assert mock_db.circuit.state == circuit_state @pytest.mark.parametrize( - 'mock_db', + 'min_num_failures,failure_rate_threshold', [ - {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, + (3, 0.0), + (3, 0.6), + ], + ids=[ + "do not exceeds min num failures, during interval", + "do not exceeds min num failures AND failure rate, during interval", ], - indirect=True, ) - def test_failure_detector_refresh_timer_on_expired_duration(self, mock_db): - fd = CommandFailureDetector(5, 0.3) + def test_failure_detector_do_not_open_circuit_on_interval_exceed(self, min_num_failures, failure_rate_threshold): + fd = CommandFailureDetector(min_num_failures, failure_rate_threshold, 0.3) + mock_db = Mock(spec=Database) + mock_db.circuit.state = CBState.CLOSED mock_ce = Mock(spec=SyncCommandExecutor) mock_ce.active_database = mock_db fd.set_command_executor(mock_ce) assert mock_db.circuit.state == CBState.CLOSED - fd.register_failure(Exception(), ('SET', 'key1', 'value1')) - fd.register_failure(Exception(), ('SET', 'key1', 'value1')) - fd.register_failure(Exception(), ('SET', 'key1', 'value1')) - fd.register_failure(Exception(), ('SET', 'key1', 'value1')) - sleep(0.4) + fd.register_command_execution(('GET', 'key')) + fd.register_failure(Exception(), ('GET', 'key')) + sleep(0.16) + fd.register_command_execution(('GET', 'key')) + fd.register_command_execution(('GET', 'key')) + fd.register_command_execution(('GET', 'key')) + fd.register_failure(Exception(), ('GET', 'key')) + sleep(0.16) + fd.register_command_execution(('GET', 'key')) + fd.register_failure(Exception(), ('GET', 'key')) assert mock_db.circuit.state == CBState.CLOSED - fd.register_failure(Exception(), ('SET', 'key1', 'value1')) - fd.register_failure(Exception(), ('SET', 'key1', 'value1')) - fd.register_failure(Exception(), ('SET', 'key1', 'value1')) - fd.register_failure(Exception(), ('SET', 'key1', 'value1')) - - assert mock_db.circuit.state == CBState.CLOSED - fd.register_failure(Exception(), ('SET', 'key1', 'value1')) + # 2 more failure as last one already refreshed timer + fd.register_command_execution(('GET', 'key')) + fd.register_failure(Exception(), ('GET', 'key')) + fd.register_command_execution(('GET', 'key')) + fd.register_failure(Exception(), ('GET', 'key')) assert mock_db.circuit.state == CBState.OPEN - @pytest.mark.parametrize( - 'mock_db', - [ - {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, - ], - indirect=True, - ) - def test_failure_detector_open_circuit_on_specific_exception_threshold_exceed(self, mock_db): + def test_failure_detector_open_circuit_on_specific_exception_threshold_exceed(self): fd = CommandFailureDetector(5, 1, error_types=[ConnectionError]) + mock_db = Mock(spec=Database) + mock_db.circuit.state = CBState.CLOSED mock_ce = Mock(spec=SyncCommandExecutor) mock_ce.active_database = mock_db fd.set_command_executor(mock_ce) diff --git 
a/tests/test_scenario/conftest.py b/tests/test_scenario/conftest.py index 6c32cf0699..d49aca5605 100644 --- a/tests/test_scenario/conftest.py +++ b/tests/test_scenario/conftest.py @@ -13,7 +13,7 @@ from redis.multidb.client import MultiDBClient from redis.multidb.config import DatabaseConfig, MultiDbConfig, DEFAULT_HEALTH_CHECK_INTERVAL from redis.multidb.event import ActiveDatabaseChanged -from redis.multidb.failure_detector import DEFAULT_FAILURES_THRESHOLD +from redis.multidb.failure_detector import DEFAULT_MIN_NUM_FAILURES from redis.multidb.healthcheck import EchoHealthCheck, DEFAULT_HEALTH_CHECK_DELAY from redis.backoff import ExponentialWithJitterBackoff, NoBackoff from redis.client import Redis @@ -74,7 +74,7 @@ def r_multi_db(request) -> tuple[MultiDBClient, CheckActiveDatabaseChangedListen username = endpoint_config.get('username', None) password = endpoint_config.get('password', None) - failure_threshold = request.param.get('failure_threshold', DEFAULT_FAILURES_THRESHOLD) + min_num_failures = request.param.get('min_num_failures', DEFAULT_MIN_NUM_FAILURES) command_retry = request.param.get('command_retry', Retry(ExponentialBackoff(cap=0.1, base=0.01), retries=10)) # Retry configuration different for health checks as initial health check require more time in case @@ -116,7 +116,7 @@ def r_multi_db(request) -> tuple[MultiDBClient, CheckActiveDatabaseChangedListen client_class=client_class, databases_config=db_configs, command_retry=command_retry, - failure_threshold=failure_threshold, + min_num_failures=min_num_failures, health_check_probes=3, health_check_interval=health_check_interval, event_dispatcher=event_dispatcher, diff --git a/tests/test_scenario/test_active_active.py b/tests/test_scenario/test_active_active.py index d641fb65bb..327676a78d 100644 --- a/tests/test_scenario/test_active_active.py +++ b/tests/test_scenario/test_active_active.py @@ -41,13 +41,13 @@ class TestActiveActive: def teardown_method(self, method): # Timeout so the cluster could recover from network failure. 
- sleep(15) + sleep(20) @pytest.mark.parametrize( "r_multi_db", [ - {"client_class": Redis, "failure_threshold": 2}, - {"client_class": RedisCluster, "failure_threshold": 2}, + {"client_class": Redis, "min_num_failures": 2}, + {"client_class": RedisCluster, "min_num_failures": 2}, ], ids=["standalone", "cluster"], indirect=True @@ -96,8 +96,8 @@ def test_multi_db_client_failover_to_another_db(self, r_multi_db, fault_injector @pytest.mark.parametrize( "r_multi_db", [ - {"client_class": Redis, "failure_threshold": 2, "health_check_interval": 20}, - {"client_class": RedisCluster, "failure_threshold": 2, "health_check_interval": 20}, + {"client_class": Redis, "min_num_failures": 2, "health_check_interval": 20}, + {"client_class": RedisCluster, "min_num_failures": 2, "health_check_interval": 20}, ], ids=["standalone", "cluster"], indirect=True @@ -156,8 +156,8 @@ def test_multi_db_client_uses_lag_aware_health_check(self, r_multi_db, fault_inj @pytest.mark.parametrize( "r_multi_db", [ - {"client_class": Redis, "failure_threshold": 2}, - {"client_class": RedisCluster, "failure_threshold": 2}, + {"client_class": Redis, "min_num_failures": 2}, + {"client_class": RedisCluster, "min_num_failures": 2}, ], ids=["standalone", "cluster"], indirect=True @@ -214,8 +214,8 @@ def callback(): @pytest.mark.parametrize( "r_multi_db", [ - {"client_class": Redis, "failure_threshold": 2}, - {"client_class": RedisCluster, "failure_threshold": 2}, + {"client_class": Redis, "min_num_failures": 2}, + {"client_class": RedisCluster, "min_num_failures": 2}, ], ids=["standalone", "cluster"], indirect=True @@ -273,8 +273,8 @@ def callback(): @pytest.mark.parametrize( "r_multi_db", [ - {"client_class": Redis, "failure_threshold": 2}, - {"client_class": RedisCluster, "failure_threshold": 2}, + {"client_class": Redis, "min_num_failures": 2}, + {"client_class": RedisCluster, "min_num_failures": 2}, ], ids=["standalone", "cluster"], indirect=True @@ -329,8 +329,8 @@ def callback(pipe: Pipeline): @pytest.mark.parametrize( "r_multi_db", [ - {"client_class": Redis, "failure_threshold": 2}, - {"client_class": RedisCluster, "failure_threshold": 2}, + {"client_class": Redis, "min_num_failures": 2}, + {"client_class": RedisCluster, "min_num_failures": 2}, ], ids=["standalone", "cluster"], indirect=True @@ -389,8 +389,8 @@ def handler(message): @pytest.mark.parametrize( "r_multi_db", [ - {"client_class": Redis, "failure_threshold": 2}, - {"client_class": RedisCluster, "failure_threshold": 2}, + {"client_class": Redis, "min_num_failures": 2}, + {"client_class": RedisCluster, "min_num_failures": 2}, ], ids=["standalone", "cluster"], indirect=True From f5231ee5a012e02ae0a26ebd7a51ff365c688fc0 Mon Sep 17 00:00:00 2001 From: vladvildanov Date: Thu, 25 Sep 2025 13:14:11 +0300 Subject: [PATCH 26/50] Decreased timeouts --- tests/test_asyncio/test_scenario/conftest.py | 2 +- tests/test_scenario/test_active_active.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_asyncio/test_scenario/conftest.py b/tests/test_asyncio/test_scenario/conftest.py index c152e51fb5..e28904c1ae 100644 --- a/tests/test_asyncio/test_scenario/conftest.py +++ b/tests/test_asyncio/test_scenario/conftest.py @@ -98,7 +98,7 @@ async def teardown(): if isinstance(client.command_executor.active_database.client, Redis): await client.command_executor.active_database.client.connection_pool.disconnect() - await asyncio.sleep(20) + await asyncio.sleep(10) yield client, listener, endpoint_config await teardown() \ No newline at end of file diff --git 
a/tests/test_scenario/test_active_active.py b/tests/test_scenario/test_active_active.py index 327676a78d..cca84a9bb1 100644 --- a/tests/test_scenario/test_active_active.py +++ b/tests/test_scenario/test_active_active.py @@ -41,7 +41,7 @@ class TestActiveActive: def teardown_method(self, method): # Timeout so the cluster could recover from network failure. - sleep(20) + sleep(10) @pytest.mark.parametrize( "r_multi_db", From 413ea863ae681d24fa8023e567c652eca1c6fdaa Mon Sep 17 00:00:00 2001 From: vladvildanov Date: Thu, 25 Sep 2025 16:34:01 +0300 Subject: [PATCH 27/50] Added missing fixture --- tests/test_scenario/conftest.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/test_scenario/conftest.py b/tests/test_scenario/conftest.py index d49aca5605..8d6ee70f0a 100644 --- a/tests/test_scenario/conftest.py +++ b/tests/test_scenario/conftest.py @@ -56,6 +56,9 @@ def get_endpoints_config(endpoint_name: str): f"Failed to load endpoints config file: {endpoints_config}" ) from e +@pytest.fixture() +def endpoints_config(endpoint_name: str): + return get_endpoints_config(endpoint_name) @pytest.fixture() def fault_injector_client(): From 3ed14e469d4d8c3fb7e3b4e06630167bc18b0435 Mon Sep 17 00:00:00 2001 From: vladvildanov Date: Tue, 30 Sep 2025 15:40:11 +0300 Subject: [PATCH 28/50] Fixed None exception --- redis/asyncio/multidb/client.py | 5 +++-- tests/test_asyncio/test_scenario/conftest.py | 5 ++++- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/redis/asyncio/multidb/client.py b/redis/asyncio/multidb/client.py index 71fbd9c2e0..db1814f661 100644 --- a/redis/asyncio/multidb/client.py +++ b/redis/asyncio/multidb/client.py @@ -280,7 +280,7 @@ async def _check_databases_health( if on_error: on_error(result.original_exception) - async def _check_db_health(self, database: AsyncDatabase,) -> bool: + async def _check_db_health(self, database: AsyncDatabase) -> bool: """ Runs health checks on the given database until first failure. 
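The "until first failure" behavior named in this docstring can be read as the following hypothetical condensation (check_health is an assumed HealthCheck method name here, and the probe/delay policy configured via health_check_probes is ignored):

async def _check_db_health(self, database) -> bool:
    # Run every registered health check; bail out at the first failure so
    # one failing probe is enough to mark the database unhealthy.
    for health_check in self._health_checks:
        if not await health_check.check_health(database):
            return False
    return True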
""" @@ -307,7 +307,8 @@ def _on_circuit_state_change_callback(self, circuit: CircuitBreaker, old_state: loop.call_later(DEFAULT_GRACE_PERIOD, _half_open_circuit, circuit) async def aclose(self): - await self.command_executor.active_database.client.aclose() + if self.command_executor.active_database: + await self.command_executor.active_database.client.aclose() def _half_open_circuit(circuit: CircuitBreaker): circuit.state = CBState.HALF_OPEN diff --git a/tests/test_asyncio/test_scenario/conftest.py b/tests/test_asyncio/test_scenario/conftest.py index e28904c1ae..88313afdd6 100644 --- a/tests/test_asyncio/test_scenario/conftest.py +++ b/tests/test_asyncio/test_scenario/conftest.py @@ -95,7 +95,10 @@ async def r_multi_db(request) -> AsyncGenerator[tuple[MultiDBClient, CheckActive async def teardown(): await client.aclose() - if isinstance(client.command_executor.active_database.client, Redis): + if ( + client.command_executor.active_database + and isinstance(client.command_executor.active_database.client, Redis) + ): await client.command_executor.active_database.client.connection_pool.disconnect() await asyncio.sleep(10) From 71fc90fe16f21c38f2d87096607ff5f6f5a1192d Mon Sep 17 00:00:00 2001 From: vladvildanov Date: Thu, 2 Oct 2025 10:52:29 +0300 Subject: [PATCH 29/50] Codestyle changes --- redis/asyncio/client.py | 2 +- redis/asyncio/cluster.py | 4 +- redis/asyncio/connection.py | 2 +- redis/asyncio/http/http_client.py | 118 ++++-- redis/asyncio/multidb/client.py | 130 +++--- redis/asyncio/multidb/command_executor.py | 107 +++-- redis/asyncio/multidb/config.py | 78 +++- redis/asyncio/multidb/database.py | 14 +- redis/asyncio/multidb/event.py | 19 +- redis/asyncio/multidb/failover.py | 27 +- redis/asyncio/multidb/failure_detector.py | 6 +- redis/asyncio/multidb/healthcheck.py | 39 +- redis/background.py | 50 +-- redis/client.py | 16 +- redis/data_structure.py | 14 +- redis/event.py | 43 +- redis/http/http_client.py | 90 ++-- redis/multidb/circuit.py | 21 +- redis/multidb/client.py | 107 +++-- redis/multidb/command_executor.py | 97 +++-- redis/multidb/config.py | 72 +++- redis/multidb/database.py | 24 +- redis/multidb/event.py | 18 +- redis/multidb/exception.py | 5 +- redis/multidb/failover.py | 24 +- redis/multidb/failure_detector.py | 29 +- redis/multidb/healthcheck.py | 36 +- redis/retry.py | 14 +- redis/utils.py | 4 +- tests/conftest.py | 1 + tests/test_asyncio/test_multidb/conftest.py | 137 +++--- .../test_asyncio/test_multidb/test_client.py | 399 +++++++++++------- .../test_multidb/test_command_executor.py | 96 +++-- .../test_asyncio/test_multidb/test_config.py | 103 +++-- .../test_multidb/test_failover.py | 87 ++-- .../test_multidb/test_failure_detector.py | 81 ++-- .../test_multidb/test_healthcheck.py | 166 +++++--- .../test_multidb/test_pipeline.py | 272 +++++++----- tests/test_asyncio/test_scenario/conftest.py | 160 +++---- .../test_scenario/test_active_active.py | 275 +++++++----- tests/test_background.py | 13 +- tests/test_data_structure.py | 73 ++-- tests/test_event.py | 24 +- tests/test_http/test_http_client.py | 87 +++- tests/test_multidb/conftest.py | 143 ++++--- tests/test_multidb/test_circuit.py | 12 +- tests/test_multidb/test_client.py | 392 ++++++++++------- tests/test_multidb/test_command_executor.py | 97 +++-- tests/test_multidb/test_config.py | 94 +++-- tests/test_multidb/test_failover.py | 79 ++-- tests/test_multidb/test_failure_detector.py | 69 ++- tests/test_multidb/test_healthcheck.py | 158 ++++--- tests/test_multidb/test_pipeline.py | 281 +++++++----- 
tests/test_scenario/conftest.py | 145 ++++--- tests/test_scenario/test_active_active.py | 263 ++++++------ 55 files changed, 3071 insertions(+), 1846 deletions(-) diff --git a/redis/asyncio/client.py b/redis/asyncio/client.py index cf93e9c842..ab5a3ac0bd 100644 --- a/redis/asyncio/client.py +++ b/redis/asyncio/client.py @@ -1239,7 +1239,7 @@ async def run( *, exception_handler: Optional["PSWorkerThreadExcHandlerT"] = None, poll_timeout: float = 1.0, - pubsub = None + pubsub=None, ) -> None: """Process pub/sub messages using registered callbacks. diff --git a/redis/asyncio/cluster.py b/redis/asyncio/cluster.py index 78805021e0..9810654626 100644 --- a/redis/asyncio/cluster.py +++ b/redis/asyncio/cluster.py @@ -2255,7 +2255,9 @@ async def _reinitialize_on_error(self, error): self.reinitialize_counter = 0 else: if type(error) == MovedError: - self._pipe.cluster_client.nodes_manager.update_moved_exception(error) + self._pipe.cluster_client.nodes_manager.update_moved_exception( + error + ) self._executing = False diff --git a/redis/asyncio/connection.py b/redis/asyncio/connection.py index ce437599c9..c79ed690d9 100644 --- a/redis/asyncio/connection.py +++ b/redis/asyncio/connection.py @@ -1283,7 +1283,7 @@ async def update_active_connections_for_reconnect(self): """ async with self._lock: for conn in self._in_use_connections: - conn.mark_for_reconnect() + conn.mark_for_reconnect() async def aclose(self) -> None: """Close the pool, disconnecting all connections""" diff --git a/redis/asyncio/http/http_client.py b/redis/asyncio/http/http_client.py index 8f746b0a8b..51e3ba9226 100644 --- a/redis/asyncio/http/http_client.py +++ b/redis/asyncio/http/http_client.py @@ -8,15 +8,18 @@ DEFAULT_TIMEOUT = 30.0 RETRY_STATUS_CODES = {429, 500, 502, 503, 504} + class AsyncHTTPClient(ABC): @abstractmethod async def get( self, path: str, - params: Optional[Mapping[str, Union[None, str, int, float, bool, list, tuple]]] = None, + params: Optional[ + Mapping[str, Union[None, str, int, float, bool, list, tuple]] + ] = None, headers: Optional[Mapping[str, str]] = None, timeout: Optional[float] = None, - expect_json: bool = True + expect_json: bool = True, ) -> Union[HttpResponse, Any]: """ Invoke HTTP GET request.""" @@ -26,10 +29,12 @@ async def get( async def delete( self, path: str, - params: Optional[Mapping[str, Union[None, str, int, float, bool, list, tuple]]] = None, + params: Optional[ + Mapping[str, Union[None, str, int, float, bool, list, tuple]] + ] = None, headers: Optional[Mapping[str, str]] = None, timeout: Optional[float] = None, - expect_json: bool = True + expect_json: bool = True, ) -> Union[HttpResponse, Any]: """ Invoke HTTP DELETE request.""" @@ -41,10 +46,12 @@ async def post( path: str, json_body: Optional[Any] = None, data: Optional[Union[bytes, str]] = None, - params: Optional[Mapping[str, Union[None, str, int, float, bool, list, tuple]]] = None, + params: Optional[ + Mapping[str, Union[None, str, int, float, bool, list, tuple]] + ] = None, headers: Optional[Mapping[str, str]] = None, timeout: Optional[float] = None, - expect_json: bool = True + expect_json: bool = True, ) -> Union[HttpResponse, Any]: """ Invoke HTTP POST request.""" @@ -56,10 +63,12 @@ async def put( path: str, json_body: Optional[Any] = None, data: Optional[Union[bytes, str]] = None, - params: Optional[Mapping[str, Union[None, str, int, float, bool, list, tuple]]] = None, + params: Optional[ + Mapping[str, Union[None, str, int, float, bool, list, tuple]] + ] = None, headers: Optional[Mapping[str, str]] = None, timeout: 
Optional[float] = None, - expect_json: bool = True + expect_json: bool = True, ) -> Union[HttpResponse, Any]: """ Invoke HTTP PUT request.""" @@ -71,10 +80,12 @@ async def patch( path: str, json_body: Optional[Any] = None, data: Optional[Union[bytes, str]] = None, - params: Optional[Mapping[str, Union[None, str, int, float, bool, list, tuple]]] = None, + params: Optional[ + Mapping[str, Union[None, str, int, float, bool, list, tuple]] + ] = None, headers: Optional[Mapping[str, str]] = None, timeout: Optional[float] = None, - expect_json: bool = True + expect_json: bool = True, ) -> Union[HttpResponse, Any]: """ Invoke HTTP PATCH request.""" @@ -85,7 +96,9 @@ async def request( self, method: str, path: str, - params: Optional[Mapping[str, Union[None, str, int, float, bool, list, tuple]]] = None, + params: Optional[ + Mapping[str, Union[None, str, int, float, bool, list, tuple]] + ] = None, headers: Optional[Mapping[str, str]] = None, body: Optional[Union[bytes, str]] = None, timeout: Optional[float] = None, @@ -94,15 +107,13 @@ async def request( Invoke HTTP request with given method.""" pass + class AsyncHTTPClientWrapper(AsyncHTTPClient): """ An async wrapper around sync HTTP client with thread pool execution. """ - def __init__( - self, - client: HttpClient, - max_workers: int = 10 - ) -> None: + + def __init__(self, client: HttpClient, max_workers: int = 10) -> None: """ Initialize a new HTTP client instance. @@ -121,31 +132,37 @@ def __init__( async def get( self, path: str, - params: Optional[Mapping[str, Union[None, str, int, float, bool, list, tuple]]] = None, + params: Optional[ + Mapping[str, Union[None, str, int, float, bool, list, tuple]] + ] = None, headers: Optional[Mapping[str, str]] = None, timeout: Optional[float] = None, - expect_json: bool = True + expect_json: bool = True, ) -> Union[HttpResponse, Any]: loop = asyncio.get_event_loop() return await loop.run_in_executor( - self._executor, - self.client.get, - path, params, headers, timeout, expect_json + self._executor, self.client.get, path, params, headers, timeout, expect_json ) async def delete( self, path: str, - params: Optional[Mapping[str, Union[None, str, int, float, bool, list, tuple]]] = None, + params: Optional[ + Mapping[str, Union[None, str, int, float, bool, list, tuple]] + ] = None, headers: Optional[Mapping[str, str]] = None, timeout: Optional[float] = None, - expect_json: bool = True + expect_json: bool = True, ) -> Union[HttpResponse, Any]: loop = asyncio.get_event_loop() return await loop.run_in_executor( self._executor, self.client.delete, - path, params, headers, timeout, expect_json + path, + params, + headers, + timeout, + expect_json, ) async def post( @@ -153,16 +170,24 @@ async def post( path: str, json_body: Optional[Any] = None, data: Optional[Union[bytes, str]] = None, - params: Optional[Mapping[str, Union[None, str, int, float, bool, list, tuple]]] = None, + params: Optional[ + Mapping[str, Union[None, str, int, float, bool, list, tuple]] + ] = None, headers: Optional[Mapping[str, str]] = None, timeout: Optional[float] = None, - expect_json: bool = True + expect_json: bool = True, ) -> Union[HttpResponse, Any]: loop = asyncio.get_event_loop() return await loop.run_in_executor( self._executor, self.client.post, - path, json_body, data, params, headers, timeout, expect_json + path, + json_body, + data, + params, + headers, + timeout, + expect_json, ) async def put( @@ -170,16 +195,24 @@ async def put( path: str, json_body: Optional[Any] = None, data: Optional[Union[bytes, str]] = None, - 
params: Optional[Mapping[str, Union[None, str, int, float, bool, list, tuple]]] = None, + params: Optional[ + Mapping[str, Union[None, str, int, float, bool, list, tuple]] + ] = None, headers: Optional[Mapping[str, str]] = None, timeout: Optional[float] = None, - expect_json: bool = True + expect_json: bool = True, ) -> Union[HttpResponse, Any]: loop = asyncio.get_event_loop() return await loop.run_in_executor( self._executor, self.client.put, - path, json_body, data, params, headers, timeout, expect_json + path, + json_body, + data, + params, + headers, + timeout, + expect_json, ) async def patch( @@ -187,23 +220,33 @@ async def patch( path: str, json_body: Optional[Any] = None, data: Optional[Union[bytes, str]] = None, - params: Optional[Mapping[str, Union[None, str, int, float, bool, list, tuple]]] = None, + params: Optional[ + Mapping[str, Union[None, str, int, float, bool, list, tuple]] + ] = None, headers: Optional[Mapping[str, str]] = None, timeout: Optional[float] = None, - expect_json: bool = True + expect_json: bool = True, ) -> Union[HttpResponse, Any]: loop = asyncio.get_event_loop() return await loop.run_in_executor( self._executor, self.client.patch, - path, json_body, data, params, headers, timeout, expect_json + path, + json_body, + data, + params, + headers, + timeout, + expect_json, ) async def request( self, method: str, path: str, - params: Optional[Mapping[str, Union[None, str, int, float, bool, list, tuple]]] = None, + params: Optional[ + Mapping[str, Union[None, str, int, float, bool, list, tuple]] + ] = None, headers: Optional[Mapping[str, str]] = None, body: Optional[Union[bytes, str]] = None, timeout: Optional[float] = None, @@ -212,5 +255,10 @@ async def request( return await loop.run_in_executor( self._executor, self.client.request, - method, path, params, headers, body, timeout + method, + path, + params, + headers, + body, + timeout, ) diff --git a/redis/asyncio/multidb/client.py b/redis/asyncio/multidb/client.py index db1814f661..0354733b6d 100644 --- a/redis/asyncio/multidb/client.py +++ b/redis/asyncio/multidb/client.py @@ -16,11 +16,13 @@ logger = logging.getLogger(__name__) + class MultiDBClient(AsyncRedisModuleCommands, AsyncCoreCommands): """ Client that operates on multiple logical Redis databases. Should be used in Active-Active database setups. """ + def __init__(self, config: MultiDbConfig): self._databases = config.databases() self._health_checks = config.default_health_checks() @@ -30,16 +32,18 @@ def __init__(self, config: MultiDbConfig): self._health_check_interval = config.health_check_interval self._health_check_policy: HealthCheckPolicy = config.health_check_policy.value( - config.health_check_probes, - config.health_check_delay + config.health_check_probes, config.health_check_delay ) self._failure_detectors = config.default_failure_detectors() if config.failure_detectors is not None: self._failure_detectors.extend(config.failure_detectors) - self._failover_strategy = config.default_failover_strategy() \ - if config.failover_strategy is None else config.failover_strategy + self._failover_strategy = ( + config.default_failover_strategy() + if config.failover_strategy is None + else config.failover_strategy + ) self._failover_strategy.set_databases(self._databases) self._auto_fallback_interval = config.auto_fallback_interval self._event_dispatcher = config.event_dispatcher @@ -80,6 +84,7 @@ async def initialize(self): """ Perform initialization of databases to define their initial state. 
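As an aside on the AsyncHTTPClientWrapper changes above: every blocking call of the sync HTTP client is offloaded to a thread pool and awaited via run_in_executor. A minimal sketch of the same pattern, assuming a generic blocking callable instead of the real HttpClient:

    import asyncio
    import urllib.request
    from concurrent.futures import ThreadPoolExecutor

    def blocking_get(url: str) -> str:
        # Stand-in for a blocking HTTP call (hypothetical helper).
        with urllib.request.urlopen(url) as resp:
            return resp.read().decode()

    class AsyncWrapper:
        def __init__(self, max_workers: int = 10):
            self._executor = ThreadPoolExecutor(max_workers=max_workers)

        async def get(self, url: str) -> str:
            # Run the blocking callable in the pool and await its result.
            loop = asyncio.get_event_loop()
            return await loop.run_in_executor(self._executor, blocking_get, url)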
""" + async def raise_exception_on_failed_hc(error): raise error @@ -87,10 +92,12 @@ async def raise_exception_on_failed_hc(error): await self._check_databases_health(on_error=raise_exception_on_failed_hc) # Starts recurring health checks on the background. - self._recurring_hc_task = asyncio.create_task(self._bg_scheduler.run_recurring_async( - self._health_check_interval, - self._check_databases_health, - )) + self._recurring_hc_task = asyncio.create_task( + self._bg_scheduler.run_recurring_async( + self._health_check_interval, + self._check_databases_health, + ) + ) is_active_db_found = False @@ -104,7 +111,9 @@ async def raise_exception_on_failed_hc(error): is_active_db_found = True if not is_active_db_found: - raise NoValidDatabaseException('Initial connection failed - no active database found') + raise NoValidDatabaseException( + "Initial connection failed - no active database found" + ) self.initialized = True @@ -126,7 +135,7 @@ async def set_active_database(self, database: AsyncDatabase) -> None: break if not exists: - raise ValueError('Given database is not a member of database list') + raise ValueError("Given database is not a member of database list") await self._check_db_health(database) @@ -135,7 +144,9 @@ async def set_active_database(self, database: AsyncDatabase) -> None: await self.command_executor.set_active_database(database) return - raise NoValidDatabaseException('Cannot set active database, database is unhealthy') + raise NoValidDatabaseException( + "Cannot set active database, database is unhealthy" + ) async def add_database(self, database: AsyncDatabase): """ @@ -143,7 +154,7 @@ async def add_database(self, database: AsyncDatabase): """ for existing_db, _ in self._databases: if existing_db == database: - raise ValueError('Given database already exists') + raise ValueError("Given database already exists") await self._check_db_health(database) @@ -151,8 +162,13 @@ async def add_database(self, database: AsyncDatabase): self._databases.add(database, database.weight) await self._change_active_database(database, highest_weighted_db) - async def _change_active_database(self, new_database: AsyncDatabase, highest_weight_database: AsyncDatabase): - if new_database.weight > highest_weight_database.weight and new_database.circuit.state == CBState.CLOSED: + async def _change_active_database( + self, new_database: AsyncDatabase, highest_weight_database: AsyncDatabase + ): + if ( + new_database.weight > highest_weight_database.weight + and new_database.circuit.state == CBState.CLOSED + ): await self.command_executor.set_active_database(new_database) async def remove_database(self, database: AsyncDatabase): @@ -162,7 +178,10 @@ async def remove_database(self, database: AsyncDatabase): weight = self._databases.remove(database) highest_weighted_db, highest_weight = self._databases.get_top_n(1)[0] - if highest_weight <= weight and highest_weighted_db.circuit.state == CBState.CLOSED: + if ( + highest_weight <= weight + and highest_weighted_db.circuit.state == CBState.CLOSED + ): await self.command_executor.set_active_database(highest_weighted_db) async def update_database_weight(self, database: AsyncDatabase, weight: float): @@ -177,7 +196,7 @@ async def update_database_weight(self, database: AsyncDatabase, weight: float): break if not exists: - raise ValueError('Given database is not a member of database list') + raise ValueError("Given database is not a member of database list") highest_weighted_db, highest_weight = self._databases.get_top_n(1)[0] 
self._databases.update_weight(database, weight) @@ -213,12 +232,12 @@ def pipeline(self): return Pipeline(self) async def transaction( - self, - func: Callable[["Pipeline"], Union[Any, Awaitable[Any]]], - *watches: KeyT, - shard_hint: Optional[str] = None, - value_from_callable: bool = False, - watch_delay: Optional[float] = None, + self, + func: Callable[["Pipeline"], Union[Any, Awaitable[Any]]], + *watches: KeyT, + shard_hint: Optional[str] = None, + value_from_callable: bool = False, + watch_delay: Optional[float] = None, ): """ Executes callable as transaction. @@ -246,15 +265,18 @@ async def pubsub(self, **kwargs): return PubSub(self, **kwargs) async def _check_databases_health( - self, - on_error: Optional[Callable[[Exception], Coroutine[Any, Any, None]]] = None, + self, + on_error: Optional[Callable[[Exception], Coroutine[Any, Any, None]]] = None, ): """ Runs health checks as a recurring task. Runs health checks against all databases. """ try: - self._hc_tasks = [asyncio.create_task(self._check_db_health(database)) for database, _ in self._databases] + self._hc_tasks = [ + asyncio.create_task(self._check_db_health(database)) + for database, _ in self._databases + ] results = await asyncio.wait_for( asyncio.gather( *self._hc_tasks, @@ -273,8 +295,8 @@ async def _check_databases_health( unhealthy_db.circuit.state = CBState.OPEN logger.exception( - 'Health check failed, due to exception', - exc_info=result.original_exception + "Health check failed, due to exception", + exc_info=result.original_exception, ) if on_error: @@ -285,7 +307,9 @@ async def _check_db_health(self, database: AsyncDatabase) -> bool: Runs health checks on the given database until first failure. """ # Health check will setup circuit state - is_healthy = await self._health_check_policy.execute(self._health_checks, database) + is_healthy = await self._health_check_policy.execute( + self._health_checks, database + ) if not is_healthy: if database.circuit.state != CBState.OPEN: @@ -296,11 +320,15 @@ async def _check_db_health(self, database: AsyncDatabase) -> bool: return is_healthy - def _on_circuit_state_change_callback(self, circuit: CircuitBreaker, old_state: CBState, new_state: CBState): + def _on_circuit_state_change_callback( + self, circuit: CircuitBreaker, old_state: CBState, new_state: CBState + ): loop = asyncio.get_running_loop() if new_state == CBState.HALF_OPEN: - self._half_open_state_task = asyncio.create_task(self._check_db_health(circuit.database)) + self._half_open_state_task = asyncio.create_task( + self._check_db_health(circuit.database) + ) return if old_state == CBState.CLOSED and new_state == CBState.OPEN: @@ -310,13 +338,16 @@ async def aclose(self): if self.command_executor.active_database: await self.command_executor.active_database.client.aclose() + def _half_open_circuit(circuit: CircuitBreaker): circuit.state = CBState.HALF_OPEN + class Pipeline(AsyncRedisModuleCommands, AsyncCoreCommands): """ Pipeline implementation for multiple logical Redis databases. 
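The transaction wrapper above hands a user callback to the active database's pipeline. A hypothetical callback, assuming the usual redis-py semantics (immediate mode for reads before multi(), buffered mode after; key names are placeholders and the call runs inside an async context):

    # WATCH "counter", read it in immediate mode, then queue the write.
    async def increment(pipe):
        current = await pipe.get("counter")
        pipe.multi()
        pipe.set("counter", int(current or 0) + 1)

    result = await client.transaction(increment, "counter")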
""" + def __init__(self, client: MultiDBClient): self._command_stack = [] self._client = client @@ -370,17 +401,21 @@ def execute_command(self, *args, **kwargs): async def execute(self) -> List[Any]: """Execute all the commands in the current pipeline""" if not self._client.initialized: - await self._client.initialize() + await self._client.initialize() try: - return await self._client.command_executor.execute_pipeline(tuple(self._command_stack)) + return await self._client.command_executor.execute_pipeline( + tuple(self._command_stack) + ) finally: await self.reset() + class PubSub: """ PubSub object for multi database client. """ + def __init__(self, client: MultiDBClient, **kwargs): """Initialize the PubSub object for a multi-database client. @@ -399,14 +434,16 @@ async def __aexit__(self, exc_type, exc_value, traceback) -> None: await self.aclose() async def aclose(self): - return await self._client.command_executor.execute_pubsub_method('aclose') + return await self._client.command_executor.execute_pubsub_method("aclose") @property def subscribed(self) -> bool: return self._client.command_executor.active_pubsub.subscribed async def execute_command(self, *args: EncodableT): - return await self._client.command_executor.execute_pubsub_method('execute_command', *args) + return await self._client.command_executor.execute_pubsub_method( + "execute_command", *args + ) async def psubscribe(self, *args: ChannelT, **kwargs: PubSubHandler): """ @@ -417,9 +454,7 @@ async def psubscribe(self, *args: ChannelT, **kwargs: PubSubHandler): ``listen()``. """ return await self._client.command_executor.execute_pubsub_method( - 'psubscribe', - *args, - **kwargs + "psubscribe", *args, **kwargs ) async def punsubscribe(self, *args: ChannelT): @@ -428,8 +463,7 @@ async def punsubscribe(self, *args: ChannelT): all patterns. """ return await self._client.command_executor.execute_pubsub_method( - 'punsubscribe', - *args + "punsubscribe", *args ) async def subscribe(self, *args: ChannelT, **kwargs: Callable): @@ -441,9 +475,7 @@ async def subscribe(self, *args: ChannelT, **kwargs: Callable): ``get_message()``. """ return await self._client.command_executor.execute_pubsub_method( - 'subscribe', - *args, - **kwargs + "subscribe", *args, **kwargs ) async def unsubscribe(self, *args): @@ -452,8 +484,7 @@ async def unsubscribe(self, *args): all channels """ return await self._client.command_executor.execute_pubsub_method( - 'unsubscribe', - *args + "unsubscribe", *args ) async def get_message( @@ -467,8 +498,9 @@ async def get_message( number or None to wait indefinitely. 
""" return await self._client.command_executor.execute_pubsub_method( - 'get_message', - ignore_subscribe_messages=ignore_subscribe_messages, timeout=timeout + "get_message", + ignore_subscribe_messages=ignore_subscribe_messages, + timeout=timeout, ) async def run( @@ -491,7 +523,5 @@ async def run( >>> await task """ return await self._client.command_executor.execute_pubsub_run( - exception_handler=exception_handler, - sleep_time=poll_timeout, - pubsub=self - ) \ No newline at end of file + exception_handler=exception_handler, sleep_time=poll_timeout, pubsub=self + ) diff --git a/redis/asyncio/multidb/command_executor.py b/redis/asyncio/multidb/command_executor.py index 2526c4ed9e..de9dd62a85 100644 --- a/redis/asyncio/multidb/command_executor.py +++ b/redis/asyncio/multidb/command_executor.py @@ -6,10 +6,19 @@ from redis.asyncio import RedisCluster from redis.asyncio.client import PubSub, Pipeline from redis.asyncio.multidb.database import Databases, AsyncDatabase, Database -from redis.asyncio.multidb.event import AsyncActiveDatabaseChanged, RegisterCommandFailure, \ - ResubscribeOnActiveDatabaseChanged, CloseConnectionOnActiveDatabaseChanged -from redis.asyncio.multidb.failover import AsyncFailoverStrategy, FailoverStrategyExecutor, DefaultFailoverStrategyExecutor, \ - DEFAULT_FAILOVER_ATTEMPTS, DEFAULT_FAILOVER_DELAY +from redis.asyncio.multidb.event import ( + AsyncActiveDatabaseChanged, + RegisterCommandFailure, + ResubscribeOnActiveDatabaseChanged, + CloseConnectionOnActiveDatabaseChanged, +) +from redis.asyncio.multidb.failover import ( + AsyncFailoverStrategy, + FailoverStrategyExecutor, + DefaultFailoverStrategyExecutor, + DEFAULT_FAILOVER_ATTEMPTS, + DEFAULT_FAILOVER_DELAY, +) from redis.asyncio.multidb.failure_detector import AsyncFailureDetector from redis.multidb.circuit import State as CBState from redis.asyncio.retry import Retry @@ -20,7 +29,6 @@ class AsyncCommandExecutor(CommandExecutor): - @property @abstractmethod def databases(self) -> Databases: @@ -89,7 +97,9 @@ async def execute_pipeline(self, command_stack: tuple): pass @abstractmethod - async def execute_transaction(self, transaction: Callable[[Pipeline], None], *watches, **options): + async def execute_transaction( + self, transaction: Callable[[Pipeline], None], *watches, **options + ): """Executes a transaction block wrapped in callback.""" pass @@ -106,15 +116,15 @@ async def execute_pubsub_run(self, sleep_time: float, **kwargs) -> Any: class DefaultCommandExecutor(BaseCommandExecutor, AsyncCommandExecutor): def __init__( - self, - failure_detectors: List[AsyncFailureDetector], - databases: Databases, - command_retry: Retry, - failover_strategy: AsyncFailoverStrategy, - event_dispatcher: EventDispatcherInterface, - failover_attempts: int = DEFAULT_FAILOVER_ATTEMPTS, - failover_delay: float = DEFAULT_FAILOVER_DELAY, - auto_fallback_interval: float = DEFAULT_AUTO_FALLBACK_INTERVAL, + self, + failure_detectors: List[AsyncFailureDetector], + databases: Databases, + command_retry: Retry, + failover_strategy: AsyncFailoverStrategy, + event_dispatcher: EventDispatcherInterface, + failover_attempts: int = DEFAULT_FAILOVER_ATTEMPTS, + failover_delay: float = DEFAULT_FAILOVER_DELAY, + auto_fallback_interval: float = DEFAULT_AUTO_FALLBACK_INTERVAL, ): """ Initialize the DefaultCommandExecutor instance. 
@@ -138,9 +148,7 @@ def __init__( self._failure_detectors = failure_detectors self._command_retry = command_retry self._failover_strategy_executor = DefaultFailoverStrategyExecutor( - failover_strategy, - failover_attempts, - failover_delay + failover_strategy, failover_attempts, failover_delay ) self._event_dispatcher = event_dispatcher self._active_database: Optional[Database] = None @@ -170,7 +178,12 @@ async def set_active_database(self, database: AsyncDatabase) -> None: if old_active is not None and old_active is not database: await self._event_dispatcher.dispatch_async( - AsyncActiveDatabaseChanged(old_active, self._active_database, self, **self._active_pubsub_kwargs) + AsyncActiveDatabaseChanged( + old_active, + self._active_database, + self, + **self._active_pubsub_kwargs, + ) ) @property @@ -199,7 +212,9 @@ def pubsub(self, **kwargs): async def execute_command(self, *args, **options): async def callback(): - response = await self._active_database.client.execute_command(*args, **options) + response = await self._active_database.client.execute_command( + *args, **options + ) await self._register_command_execution(args) return response @@ -218,12 +233,12 @@ async def callback(): return await self._execute_with_failure_detection(callback, command_stack) async def execute_transaction( - self, - func: Callable[["Pipeline"], Union[Any, Awaitable[Any]]], - *watches: KeyT, - shard_hint: Optional[str] = None, - value_from_callable: bool = False, - watch_delay: Optional[float] = None, + self, + func: Callable[["Pipeline"], Union[Any, Awaitable[Any]]], + *watches: KeyT, + shard_hint: Optional[str] = None, + value_from_callable: bool = False, + watch_delay: Optional[float] = None, ): async def callback(): response = await self._active_database.client.transaction( @@ -231,7 +246,7 @@ async def callback(): *watches, shard_hint=shard_hint, value_from_callable=value_from_callable, - watch_delay=watch_delay + watch_delay=watch_delay, ) await self._register_command_execution(()) return response @@ -257,10 +272,13 @@ async def callback(): return await self._execute_with_failure_detection(callback) - async def _execute_with_failure_detection(self, callback: Callable, cmds: tuple = ()): + async def _execute_with_failure_detection( + self, callback: Callable, cmds: tuple = () + ): """ Execute a commands execution callback with failure detection. """ + async def wrapper(): # On each retry we need to check active database as it might change. await self._check_active_database() @@ -276,18 +294,22 @@ async def _check_active_database(self): Checks if active a database needs to be updated. 
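The predicate continued below triggers a failover in exactly two situations: there is no healthy active database, or the auto-fallback timer has elapsed. The same condition restated in isolation (hypothetical helper; the sentinel value stands for "auto-fallback disabled"):

    from datetime import datetime

    def needs_failover(has_active: bool, circuit_closed: bool,
                       next_fallback_at: datetime, interval: float,
                       disabled_sentinel: float) -> bool:
        return (
            not has_active
            or not circuit_closed
            or (interval != disabled_sentinel and next_fallback_at <= datetime.now())
        )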
""" if ( - self._active_database is None - or self._active_database.circuit.state != CBState.CLOSED - or ( - self._auto_fallback_interval != DEFAULT_AUTO_FALLBACK_INTERVAL - and self._next_fallback_attempt <= datetime.now() - ) + self._active_database is None + or self._active_database.circuit.state != CBState.CLOSED + or ( + self._auto_fallback_interval != DEFAULT_AUTO_FALLBACK_INTERVAL + and self._next_fallback_attempt <= datetime.now() + ) ): - await self.set_active_database(await self._failover_strategy_executor.execute()) + await self.set_active_database( + await self._failover_strategy_executor.execute() + ) self._schedule_next_fallback() async def _on_command_fail(self, error, *args): - await self._event_dispatcher.dispatch_async(AsyncOnCommandsFailEvent(args, error)) + await self._event_dispatcher.dispatch_async( + AsyncOnCommandsFailEvent(args, error) + ) async def _register_command_execution(self, cmd: tuple): for detector in self._failure_detectors: @@ -300,7 +322,12 @@ def _setup_event_dispatcher(self): failure_listener = RegisterCommandFailure(self._failure_detectors) resubscribe_listener = ResubscribeOnActiveDatabaseChanged() close_connection_listener = CloseConnectionOnActiveDatabaseChanged() - self._event_dispatcher.register_listeners({ - AsyncOnCommandsFailEvent: [failure_listener], - AsyncActiveDatabaseChanged: [close_connection_listener, resubscribe_listener], - }) \ No newline at end of file + self._event_dispatcher.register_listeners( + { + AsyncOnCommandsFailEvent: [failure_listener], + AsyncActiveDatabaseChanged: [ + close_connection_listener, + resubscribe_listener, + ], + } + ) diff --git a/redis/asyncio/multidb/config.py b/redis/asyncio/multidb/config.py index af2029c110..7e114eff1d 100644 --- a/redis/asyncio/multidb/config.py +++ b/redis/asyncio/multidb/config.py @@ -5,24 +5,48 @@ from redis.asyncio import ConnectionPool, Redis, RedisCluster from redis.asyncio.multidb.database import Databases, Database -from redis.asyncio.multidb.failover import AsyncFailoverStrategy, WeightBasedFailoverStrategy, DEFAULT_FAILOVER_DELAY, \ - DEFAULT_FAILOVER_ATTEMPTS -from redis.asyncio.multidb.failure_detector import AsyncFailureDetector, FailureDetectorAsyncWrapper -from redis.asyncio.multidb.healthcheck import HealthCheck, EchoHealthCheck, DEFAULT_HEALTH_CHECK_INTERVAL, \ - DEFAULT_HEALTH_CHECK_PROBES, DEFAULT_HEALTH_CHECK_DELAY, HealthCheckPolicies, DEFAULT_HEALTH_CHECK_POLICY +from redis.asyncio.multidb.failover import ( + AsyncFailoverStrategy, + WeightBasedFailoverStrategy, + DEFAULT_FAILOVER_DELAY, + DEFAULT_FAILOVER_ATTEMPTS, +) +from redis.asyncio.multidb.failure_detector import ( + AsyncFailureDetector, + FailureDetectorAsyncWrapper, +) +from redis.asyncio.multidb.healthcheck import ( + HealthCheck, + EchoHealthCheck, + DEFAULT_HEALTH_CHECK_INTERVAL, + DEFAULT_HEALTH_CHECK_PROBES, + DEFAULT_HEALTH_CHECK_DELAY, + HealthCheckPolicies, + DEFAULT_HEALTH_CHECK_POLICY, +) from redis.asyncio.retry import Retry from redis.backoff import ExponentialWithJitterBackoff, NoBackoff from redis.data_structure import WeightedList from redis.event import EventDispatcherInterface, EventDispatcher -from redis.multidb.circuit import CircuitBreaker, PBCircuitBreakerAdapter, DEFAULT_GRACE_PERIOD -from redis.multidb.failure_detector import CommandFailureDetector, DEFAULT_MIN_NUM_FAILURES, \ - DEFAULT_FAILURE_RATE_THRESHOLD, DEFAULT_FAILURES_DETECTION_WINDOW +from redis.multidb.circuit import ( + CircuitBreaker, + PBCircuitBreakerAdapter, + DEFAULT_GRACE_PERIOD, +) +from 
redis.multidb.failure_detector import ( + CommandFailureDetector, + DEFAULT_MIN_NUM_FAILURES, + DEFAULT_FAILURE_RATE_THRESHOLD, + DEFAULT_FAILURES_DETECTION_WINDOW, +) DEFAULT_AUTO_FALLBACK_INTERVAL = 120 + def default_event_dispatcher() -> EventDispatcherInterface: return EventDispatcher() + @dataclass class DatabaseConfig: """ @@ -48,6 +72,7 @@ class DatabaseConfig: default_circuit_breaker: Generates and returns a default CircuitBreaker instance adapted for use. """ + weight: float = 1.0 client_kwargs: dict = field(default_factory=dict) from_url: Optional[str] = None @@ -60,6 +85,7 @@ def default_circuit_breaker(self) -> CircuitBreaker: circuit_breaker = pybreaker.CircuitBreaker(reset_timeout=self.grace_period) return PBCircuitBreakerAdapter(circuit_breaker) + @dataclass class MultiDbConfig: """ @@ -100,6 +126,7 @@ class MultiDbConfig: Provides the default failover strategy used for handling failover scenarios with defined retry and backoff configurations. """ + databases_config: List[DatabaseConfig] client_class: Type[Union[Redis, RedisCluster]] = Redis command_retry: Retry = Retry( @@ -118,7 +145,9 @@ class MultiDbConfig: failover_attempts: int = DEFAULT_FAILOVER_ATTEMPTS failover_delay: float = DEFAULT_FAILOVER_DELAY auto_fallback_interval: float = DEFAULT_AUTO_FALLBACK_INTERVAL - event_dispatcher: EventDispatcherInterface = field(default_factory=default_event_dispatcher) + event_dispatcher: EventDispatcherInterface = field( + default_factory=default_event_dispatcher + ) def databases(self) -> Databases: databases = WeightedList() @@ -126,26 +155,37 @@ def databases(self) -> Databases: for database_config in self.databases_config: # The retry object is not used in the lower level clients, so we can safely remove it. # We rely on command_retry in terms of global retries. 
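Putting the configuration pieces together, a hypothetical two-database setup (URLs and weights are placeholders; class and field names follow this patch, and the snippet assumes an async context):

    from redis.asyncio.multidb.client import MultiDBClient
    from redis.asyncio.multidb.config import DatabaseConfig, MultiDbConfig

    config = MultiDbConfig(
        databases_config=[
            DatabaseConfig(from_url="redis://east.example.com:6379", weight=1.0),
            DatabaseConfig(from_url="redis://west.example.com:6379", weight=0.5),
        ],
    )
    client = MultiDBClient(config)
    await client.initialize()  # may also be triggered lazily on first command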
- database_config.client_kwargs.update({"retry": Retry(retries=0, backoff=NoBackoff())}) + database_config.client_kwargs.update( + {"retry": Retry(retries=0, backoff=NoBackoff())} + ) if database_config.from_url: - client = self.client_class.from_url(database_config.from_url, **database_config.client_kwargs) + client = self.client_class.from_url( + database_config.from_url, **database_config.client_kwargs + ) elif database_config.from_pool: - database_config.from_pool.set_retry(Retry(retries=0, backoff=NoBackoff())) - client = self.client_class.from_pool(connection_pool=database_config.from_pool) + database_config.from_pool.set_retry( + Retry(retries=0, backoff=NoBackoff()) + ) + client = self.client_class.from_pool( + connection_pool=database_config.from_pool + ) else: client = self.client_class(**database_config.client_kwargs) - circuit = database_config.default_circuit_breaker() \ - if database_config.circuit is None else database_config.circuit + circuit = ( + database_config.default_circuit_breaker() + if database_config.circuit is None + else database_config.circuit + ) databases.add( Database( client=client, circuit=circuit, weight=database_config.weight, - health_check_url=database_config.health_check_url + health_check_url=database_config.health_check_url, ), - database_config.weight + database_config.weight, ) return databases @@ -156,7 +196,7 @@ def default_failure_detectors(self) -> List[AsyncFailureDetector]: CommandFailureDetector( min_num_failures=self.min_num_failures, failure_rate_threshold=self.failure_rate_threshold, - failure_detection_window=self.failures_detection_window + failure_detection_window=self.failures_detection_window, ) ), ] @@ -167,4 +207,4 @@ def default_health_checks(self) -> List[HealthCheck]: ] def default_failover_strategy(self) -> AsyncFailoverStrategy: - return WeightBasedFailoverStrategy() \ No newline at end of file + return WeightBasedFailoverStrategy() diff --git a/redis/asyncio/multidb/database.py b/redis/asyncio/multidb/database.py index 6afbbbf5ea..fd91991d60 100644 --- a/redis/asyncio/multidb/database.py +++ b/redis/asyncio/multidb/database.py @@ -10,6 +10,7 @@ class AsyncDatabase(AbstractDatabase): """Database with an underlying asynchronous redis client.""" + @property @abstractmethod def client(self) -> Union[Redis, RedisCluster]: @@ -34,15 +35,17 @@ def circuit(self, circuit: CircuitBreaker): """Set the circuit breaker for the current database.""" pass + Databases = WeightedList[tuple[AsyncDatabase, Number]] + class Database(BaseDatabase, AsyncDatabase): def __init__( - self, - client: Union[Redis, RedisCluster], - circuit: CircuitBreaker, - weight: float, - health_check_url: Optional[str] = None, + self, + client: Union[Redis, RedisCluster], + circuit: CircuitBreaker, + weight: float, + health_check_url: Optional[str] = None, ): self._client = client self._cb = circuit @@ -64,4 +67,3 @@ def circuit(self) -> CircuitBreaker: @circuit.setter def circuit(self, circuit: CircuitBreaker): self._cb = circuit - diff --git a/redis/asyncio/multidb/event.py b/redis/asyncio/multidb/event.py index 9f6f463a0f..ae25f1e37c 100644 --- a/redis/asyncio/multidb/event.py +++ b/redis/asyncio/multidb/event.py @@ -10,12 +10,13 @@ class AsyncActiveDatabaseChanged: """ Event fired when an async active database has been changed. 
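Consumers can attach their own listener to this event. A hypothetical logging listener, assuming the async listener interface lives in redis.event as referenced by this patch:

    from redis.event import AsyncEventListenerInterface

    class LogActiveDatabaseChanged(AsyncEventListenerInterface):
        async def listen(self, event) -> None:
            # event is an AsyncActiveDatabaseChanged instance.
            print(f"active database switched: "
                  f"{event.old_database} -> {event.new_database}")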
""" + def __init__( - self, - old_database: AsyncDatabase, - new_database: AsyncDatabase, - command_executor, - **kwargs + self, + old_database: AsyncDatabase, + new_database: AsyncDatabase, + command_executor, + **kwargs, ): self._old_database = old_database self._new_database = new_database @@ -38,10 +39,12 @@ def command_executor(self): def kwargs(self): return self._kwargs + class ResubscribeOnActiveDatabaseChanged(AsyncEventListenerInterface): """ Re-subscribe the currently active pub / sub to a new active database. """ + async def listen(self, event: AsyncActiveDatabaseChanged): old_pubsub = event.command_executor.active_pubsub @@ -54,10 +57,12 @@ async def listen(self, event: AsyncActiveDatabaseChanged): event.command_executor.active_pubsub = new_pubsub await old_pubsub.aclose() + class CloseConnectionOnActiveDatabaseChanged(AsyncEventListenerInterface): """ Close connection to the old active database. """ + async def listen(self, event: AsyncActiveDatabaseChanged): await event.old_database.client.aclose() @@ -65,13 +70,15 @@ async def listen(self, event: AsyncActiveDatabaseChanged): await event.old_database.client.connection_pool.update_active_connections_for_reconnect() await event.old_database.client.connection_pool.disconnect() + class RegisterCommandFailure(AsyncEventListenerInterface): """ Event listener that registers command failures and passing it to the failure detectors. """ + def __init__(self, failure_detectors: List[AsyncFailureDetector]): self._failure_detectors = failure_detectors async def listen(self, event: AsyncOnCommandsFailEvent) -> None: for failure_detector in self._failure_detectors: - await failure_detector.register_failure(event.exception, event.commands) \ No newline at end of file + await failure_detector.register_failure(event.exception, event.commands) diff --git a/redis/asyncio/multidb/failover.py b/redis/asyncio/multidb/failover.py index 997b7941c4..8fbcf66955 100644 --- a/redis/asyncio/multidb/failover.py +++ b/redis/asyncio/multidb/failover.py @@ -4,13 +4,16 @@ from redis.asyncio.multidb.database import AsyncDatabase, Databases from redis.multidb.circuit import State as CBState from redis.data_structure import WeightedList -from redis.multidb.exception import NoValidDatabaseException, TemporaryUnavailableException +from redis.multidb.exception import ( + NoValidDatabaseException, + TemporaryUnavailableException, +) DEFAULT_FAILOVER_ATTEMPTS = 10 DEFAULT_FAILOVER_DELAY = 12 -class AsyncFailoverStrategy(ABC): +class AsyncFailoverStrategy(ABC): @abstractmethod async def database(self) -> AsyncDatabase: """Select the database according to the strategy.""" @@ -21,8 +24,8 @@ def set_databases(self, databases: Databases) -> None: """Set the database strategy operates on.""" pass -class FailoverStrategyExecutor(ABC): +class FailoverStrategyExecutor(ABC): @property @abstractmethod def failover_attempts(self) -> int: @@ -46,10 +49,12 @@ async def execute(self) -> AsyncDatabase: """Execute the failover strategy.""" pass + class WeightBasedFailoverStrategy(AsyncFailoverStrategy): """ Failover strategy based on database weights. 
""" + def __init__(self): self._databases = WeightedList() @@ -58,20 +63,22 @@ async def database(self) -> AsyncDatabase: if database.circuit.state == CBState.CLOSED: return database - raise NoValidDatabaseException('No valid database available for communication') + raise NoValidDatabaseException("No valid database available for communication") def set_databases(self, databases: Databases) -> None: self._databases = databases + class DefaultFailoverStrategyExecutor(FailoverStrategyExecutor): """ Executes given failover strategy. """ + def __init__( - self, - strategy: AsyncFailoverStrategy, - failover_attempts: int = DEFAULT_FAILOVER_ATTEMPTS, - failover_delay: float = DEFAULT_FAILOVER_DELAY, + self, + strategy: AsyncFailoverStrategy, + failover_attempts: int = DEFAULT_FAILOVER_ATTEMPTS, + failover_delay: float = DEFAULT_FAILOVER_DELAY, ): self._strategy = strategy self._failover_attempts = failover_attempts @@ -93,7 +100,7 @@ def strategy(self) -> AsyncFailoverStrategy: async def execute(self) -> AsyncDatabase: try: - database = await self._strategy.database() + database = await self._strategy.database() self._reset() return database except NoValidDatabaseException as e: @@ -115,4 +122,4 @@ async def execute(self) -> AsyncDatabase: def _reset(self) -> None: self._next_attempt_ts = 0 - self._failover_counter = 0 \ No newline at end of file + self._failover_counter = 0 diff --git a/redis/asyncio/multidb/failure_detector.py b/redis/asyncio/multidb/failure_detector.py index e6d257e941..9c6b61f591 100644 --- a/redis/asyncio/multidb/failure_detector.py +++ b/redis/asyncio/multidb/failure_detector.py @@ -2,8 +2,8 @@ from redis.multidb.failure_detector import FailureDetector -class AsyncFailureDetector(ABC): +class AsyncFailureDetector(ABC): @abstractmethod async def register_failure(self, exception: Exception, cmd: tuple) -> None: """Register a failure that occurred during command execution.""" @@ -19,10 +19,12 @@ def set_command_executor(self, command_executor) -> None: """Set the command executor for this failure.""" pass + class FailureDetectorAsyncWrapper(AsyncFailureDetector): """ Async wrapper for the failure detector. """ + def __init__(self, failure_detector: FailureDetector) -> None: self._failure_detector = failure_detector @@ -33,4 +35,4 @@ async def register_command_execution(self, cmd: tuple) -> None: self._failure_detector.register_command_execution(cmd) def set_command_executor(self, command_executor) -> None: - self._failure_detector.set_command_executor(command_executor) \ No newline at end of file + self._failure_detector.set_command_executor(command_executor) diff --git a/redis/asyncio/multidb/healthcheck.py b/redis/asyncio/multidb/healthcheck.py index d6d2d38814..efa765eff4 100644 --- a/redis/asyncio/multidb/healthcheck.py +++ b/redis/asyncio/multidb/healthcheck.py @@ -19,17 +19,19 @@ logger = logging.getLogger(__name__) -class HealthCheck(ABC): +class HealthCheck(ABC): @abstractmethod async def check_health(self, database) -> bool: """Function to determine the health status.""" pass + class HealthCheckPolicy(ABC): """ Health checks execution policy. 
""" + @property @abstractmethod def health_check_probes(self) -> int: @@ -47,6 +49,7 @@ async def execute(self, health_checks: List[HealthCheck], database) -> bool: """Execute health checks and return database health status.""" pass + class AbstractHealthCheckPolicy(HealthCheckPolicy): def __init__(self, health_check_probes: int, health_check_delay: float): if health_check_probes < 1: @@ -66,10 +69,12 @@ def health_check_delay(self) -> float: async def execute(self, health_checks: List[HealthCheck], database) -> bool: pass + class HealthyAllPolicy(AbstractHealthCheckPolicy): """ Policy that returns True if all health check probes are successful. """ + def __init__(self, health_check_probes: int, health_check_delay: float): super().__init__(health_check_probes, health_check_delay) @@ -80,18 +85,18 @@ async def execute(self, health_checks: List[HealthCheck], database) -> bool: if not await health_check.check_health(database): return False except Exception as e: - raise UnhealthyDatabaseException( - f"Unhealthy database", database, e - ) + raise UnhealthyDatabaseException(f"Unhealthy database", database, e) if attempt < self.health_check_probes - 1: await asyncio.sleep(self._health_check_delay) return True + class HealthyMajorityPolicy(AbstractHealthCheckPolicy): """ Policy that returns True if a majority of health check probes are successful. """ + def __init__(self, health_check_probes: int, health_check_delay: float): super().__init__(health_check_probes, health_check_delay) @@ -119,10 +124,12 @@ async def execute(self, health_checks: List[HealthCheck], database) -> bool: await asyncio.sleep(self._health_check_delay) return True + class HealthyAnyPolicy(AbstractHealthCheckPolicy): """ Policy that returns True if at least one health check probe is successful. """ + def __init__(self, health_check_probes: int, health_check_delay: float): super().__init__(health_check_probes, health_check_delay) @@ -154,22 +161,28 @@ async def execute(self, health_checks: List[HealthCheck], database) -> bool: return is_healthy + class HealthCheckPolicies(Enum): HEALTHY_ALL = HealthyAllPolicy HEALTHY_MAJORITY = HealthyMajorityPolicy HEALTHY_ANY = HealthyAnyPolicy + DEFAULT_HEALTH_CHECK_POLICY: HealthCheckPolicies = HealthCheckPolicies.HEALTHY_ALL + class EchoHealthCheck(HealthCheck): """ Health check based on ECHO command. """ + async def check_health(self, database) -> bool: expected_message = ["healthcheck", b"healthcheck"] if isinstance(database.client, Redis): - actual_message = await database.client.execute_command("ECHO", "healthcheck") + actual_message = await database.client.execute_command( + "ECHO", "healthcheck" + ) return actual_message in expected_message else: # For a cluster checks if all nodes are healthy. @@ -182,11 +195,13 @@ async def check_health(self, database) -> bool: return True + class LagAwareHealthCheck(HealthCheck): """ Health check available for Redis Enterprise deployments. Verify via REST API that the database is healthy based on different lags. 
""" + def __init__( self, rest_api_port: int = 9443, @@ -230,7 +245,7 @@ def __init__( ca_data=ca_data, client_cert_file=client_cert_file, client_key_file=client_key_file, - client_key_password=client_key_password + client_key_password=client_key_password, ) ) self._rest_api_port = rest_api_port @@ -254,12 +269,12 @@ async def check_health(self, database) -> bool: matching_bdb = None for bdb in await self._http_client.get("/v1/bdbs"): for endpoint in bdb["endpoints"]: - if endpoint['dns_name'] == db_host: + if endpoint["dns_name"] == db_host: matching_bdb = bdb break # In case if the host was set as public IP - for addr in endpoint['addr']: + for addr in endpoint["addr"]: if addr == db_host: matching_bdb = bdb break @@ -268,9 +283,11 @@ async def check_health(self, database) -> bool: logger.warning("LagAwareHealthCheck failed: Couldn't find a matching bdb") raise ValueError("Could not find a matching bdb") - url = (f"/v1/bdbs/{matching_bdb['uid']}/availability" - f"?extend_check=lag&availability_lag_tolerance_ms={self._lag_aware_tolerance}") + url = ( + f"/v1/bdbs/{matching_bdb['uid']}/availability" + f"?extend_check=lag&availability_lag_tolerance_ms={self._lag_aware_tolerance}" + ) await self._http_client.get(url, expect_json=False) # Status checked in an http client, otherwise HttpError will be raised - return True \ No newline at end of file + return True diff --git a/redis/background.py b/redis/background.py index ce43cbfa7a..b6327b9fdd 100644 --- a/redis/background.py +++ b/redis/background.py @@ -7,6 +7,7 @@ class BackgroundScheduler: """ Schedules background tasks execution either in separate thread or in the running event loop. """ + def __init__(self): self._next_timer = None @@ -23,16 +24,11 @@ def run_once(self, delay: float, callback: Callable, *args): thread = threading.Thread( target=_start_event_loop_in_thread, args=(loop, self._call_later, delay, callback, *args), - daemon=True + daemon=True, ) thread.start() - def run_recurring( - self, - interval: float, - callback: Callable, - *args - ): + def run_recurring(self, interval: float, callback: Callable, *args): """ Runs recurring callable task with given interval in seconds. """ @@ -42,15 +38,12 @@ def run_recurring( thread = threading.Thread( target=_start_event_loop_in_thread, args=(loop, self._call_later_recurring, interval, callback, *args), - daemon=True + daemon=True, ) thread.start() async def run_recurring_async( - self, - interval: float, - coro: Callable[..., Coroutine[Any, Any, Any]], - *args + self, interval: float, coro: Callable[..., Coroutine[Any, Any, Any]], *args ): """ Runs recurring coroutine with given interval in seconds in the current event loop. 
@@ -69,31 +62,27 @@ def tick(): self._next_timer = loop.call_later(interval, tick) def _call_later( - self, - loop: asyncio.AbstractEventLoop, - delay: float, - callback: Callable, - *args + self, loop: asyncio.AbstractEventLoop, delay: float, callback: Callable, *args ): self._next_timer = loop.call_later(delay, callback, *args) def _call_later_recurring( - self, - loop: asyncio.AbstractEventLoop, - interval: float, - callback: Callable, - *args + self, + loop: asyncio.AbstractEventLoop, + interval: float, + callback: Callable, + *args, ): self._call_later( loop, interval, self._execute_recurring, loop, interval, callback, *args ) def _execute_recurring( - self, - loop: asyncio.AbstractEventLoop, - interval: float, - callback: Callable, - *args + self, + loop: asyncio.AbstractEventLoop, + interval: float, + callback: Callable, + *args, ): """ Executes recurring callable task with given interval in seconds. @@ -105,7 +94,9 @@ def _execute_recurring( ) -def _start_event_loop_in_thread(event_loop: asyncio.AbstractEventLoop, call_soon_cb: Callable, *args): +def _start_event_loop_in_thread( + event_loop: asyncio.AbstractEventLoop, call_soon_cb: Callable, *args +): """ Starts event loop in a thread and schedule callback as soon as event loop is ready. Used to be able to schedule tasks using loop.call_later. @@ -117,6 +108,7 @@ def _start_event_loop_in_thread(event_loop: asyncio.AbstractEventLoop, call_soon event_loop.call_soon(call_soon_cb, event_loop, *args) event_loop.run_forever() + def _async_to_sync_wrapper(loop, coro_func, *args, **kwargs): """ Wraps an asynchronous function so it can be used with loop.call_later. @@ -132,4 +124,4 @@ def wrapped(): # Schedule the coroutine in the event loop asyncio.ensure_future(coro_func(*args, **kwargs), loop=loop) - return wrapped \ No newline at end of file + return wrapped diff --git a/redis/client.py b/redis/client.py index 7151bee0d6..c8e2ecac72 100755 --- a/redis/client.py +++ b/redis/client.py @@ -1271,7 +1271,7 @@ def run_in_thread( sleep_time: float = 0.0, daemon: bool = False, exception_handler: Optional[Callable] = None, - pubsub = None, + pubsub=None, sharded_pubsub: bool = False, ) -> "PubSubWorkerThread": for channel, handler in self.channels.items(): @@ -1288,7 +1288,11 @@ def run_in_thread( pubsub = self if pubsub is None else pubsub thread = PubSubWorkerThread( - pubsub, sleep_time, daemon=daemon, exception_handler=exception_handler, sharded_pubsub=sharded_pubsub + pubsub, + sleep_time, + daemon=daemon, + exception_handler=exception_handler, + sharded_pubsub=sharded_pubsub, ) thread.start() return thread @@ -1322,9 +1326,13 @@ def run(self) -> None: while self._running.is_set(): try: if not self.sharded_pubsub: - pubsub.get_message(ignore_subscribe_messages=True, timeout=sleep_time) + pubsub.get_message( + ignore_subscribe_messages=True, timeout=sleep_time + ) else: - pubsub.get_sharded_message(ignore_subscribe_messages=True, timeout=sleep_time) + pubsub.get_sharded_message( + ignore_subscribe_messages=True, timeout=sleep_time + ) except BaseException as e: if self.exception_handler is None: raise diff --git a/redis/data_structure.py b/redis/data_structure.py index 5b0df7f017..dc91e48650 100644 --- a/redis/data_structure.py +++ b/redis/data_structure.py @@ -3,12 +3,14 @@ from redis.typing import Number -T = TypeVar('T') +T = TypeVar("T") + class WeightedList(Generic[T]): """ Thread-safe weighted list. 
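A quick usage sketch of the weighted list (items and weights are placeholders; the API follows this patch):

    from redis.data_structure import WeightedList

    wl = WeightedList()
    wl.add("primary", 1.0)
    wl.add("fallback", 0.5)
    item, weight = wl.get_top_n(1)[0]   # ("primary", 1.0) - heaviest first
    for item, weight in wl:             # iterates in descending weight order
        print(item, weight)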
""" + def __init__(self): self._items: List[tuple[Any, Number]] = [] self._lock = threading.RLock() @@ -36,7 +38,9 @@ def remove(self, item): return weight raise ValueError("Item not found") - def get_by_weight_range(self, min_weight: float, max_weight: float) -> List[tuple[Any, Number]]: + def get_by_weight_range( + self, min_weight: float, max_weight: float + ) -> List[tuple[Any, Number]]: """Get all items within weight range""" with self._lock: result = [] @@ -60,7 +64,9 @@ def update_weight(self, item, new_weight: float): def __iter__(self): """Iterate in descending weight order""" with self._lock: - items_copy = self._items.copy() # Create snapshot as lock released after each 'yield' + items_copy = ( + self._items.copy() + ) # Create snapshot as lock released after each 'yield' for item, weight in items_copy: yield item, weight @@ -72,4 +78,4 @@ def __len__(self): def __getitem__(self, index) -> tuple[Any, Number]: with self._lock: item, weight = self._items[index] - return item, weight \ No newline at end of file + return item, weight diff --git a/redis/event.py b/redis/event.py index de38e1a069..bccf1fbf0d 100644 --- a/redis/event.py +++ b/redis/event.py @@ -44,8 +44,11 @@ async def dispatch_async(self, event: object): @abstractmethod def register_listeners( - self, - mappings: Dict[Type[object], List[Union[EventListenerInterface, AsyncEventListenerInterface]]] + self, + mappings: Dict[ + Type[object], + List[Union[EventListenerInterface, AsyncEventListenerInterface]], + ], ): """Register additional listeners.""" pass @@ -65,13 +68,17 @@ def __init__(self, exception: Exception, event: object): class EventDispatcher(EventDispatcherInterface): # TODO: Make dispatcher to accept external mappings. def __init__( - self, - event_listeners: Optional[Dict[Type[object], List[EventListenerInterface]]] = None, + self, + event_listeners: Optional[ + Dict[Type[object], List[EventListenerInterface]] + ] = None, ): """ Dispatcher that dispatches events to listeners associated with given event. """ - self._event_listeners_mapping: Dict[Type[object], List[EventListenerInterface]]= { + self._event_listeners_mapping: Dict[ + Type[object], List[EventListenerInterface] + ] = { AfterConnectionReleasedEvent: [ ReAuthConnectionListener(), ], @@ -109,17 +116,25 @@ async def dispatch_async(self, event: object): await listener.listen(event) def register_listeners( - self, - event_listeners: Dict[Type[object], List[Union[EventListenerInterface, AsyncEventListenerInterface]]] + self, + event_listeners: Dict[ + Type[object], + List[Union[EventListenerInterface, AsyncEventListenerInterface]], + ], ): with self._lock: for event_type in event_listeners: if event_type in self._event_listeners_mapping: self._event_listeners_mapping[event_type] = list( - set(self._event_listeners_mapping[event_type] + event_listeners[event_type]) + set( + self._event_listeners_mapping[event_type] + + event_listeners[event_type] + ) ) else: - self._event_listeners_mapping[event_type] = event_listeners[event_type] + self._event_listeners_mapping[event_type] = event_listeners[ + event_type + ] class AfterConnectionReleasedEvent: @@ -257,14 +272,16 @@ def nodes(self) -> dict: def credential_provider(self) -> Union[CredentialProvider, None]: return self._credential_provider + class OnCommandsFailEvent: """ Event fired whenever a command fails during the execution. 
""" + def __init__( - self, - commands: tuple, - exception: Exception, + self, + commands: tuple, + exception: Exception, ): self._commands = commands self._exception = exception @@ -277,9 +294,11 @@ def commands(self) -> tuple: def exception(self) -> Exception: return self._exception + class AsyncOnCommandsFailEvent(OnCommandsFailEvent): pass + class ReAuthConnectionListener(EventListenerInterface): """ Listener that performs re-authentication of given connection. diff --git a/redis/http/http_client.py b/redis/http/http_client.py index af0f68f95b..4f52290c00 100644 --- a/redis/http/http_client.py +++ b/redis/http/http_client.py @@ -12,12 +12,7 @@ from urllib.error import URLError, HTTPError -__all__ = [ - "HttpClient", - "HttpResponse", - "HttpError", - "DEFAULT_TIMEOUT" -] +__all__ = ["HttpClient", "HttpResponse", "HttpError", "DEFAULT_TIMEOUT"] from redis.backoff import ExponentialWithJitterBackoff from redis.retry import Retry @@ -65,6 +60,7 @@ class HttpClient: """ A lightweight HTTP client for REST API calls. """ + def __init__( self, base_url: str = "", @@ -108,7 +104,11 @@ def __init__( ca_file, ca_path or ca_data. For mutual TLS, additionally provide a client certificate and key via client_cert_file and client_key_file. """ - self.base_url = base_url.rstrip() + "/" if base_url and not base_url.endswith("/") else base_url + self.base_url = ( + base_url.rstrip() + "/" + if base_url and not base_url.endswith("/") + else base_url + ) self._default_headers = {k.lower(): v for k, v in (headers or {}).items()} self.timeout = timeout self.retry = retry @@ -130,10 +130,12 @@ def __init__( def get( self, path: str, - params: Optional[Mapping[str, Union[None, str, int, float, bool, list, tuple]]] = None, + params: Optional[ + Mapping[str, Union[None, str, int, float, bool, list, tuple]] + ] = None, headers: Optional[Mapping[str, str]] = None, timeout: Optional[float] = None, - expect_json: bool = True + expect_json: bool = True, ) -> Union[HttpResponse, Any]: return self._json_call( "GET", @@ -142,16 +144,18 @@ def get( headers=headers, timeout=timeout, body=None, - expect_json=expect_json + expect_json=expect_json, ) def delete( self, path: str, - params: Optional[Mapping[str, Union[None, str, int, float, bool, list, tuple]]] = None, + params: Optional[ + Mapping[str, Union[None, str, int, float, bool, list, tuple]] + ] = None, headers: Optional[Mapping[str, str]] = None, timeout: Optional[float] = None, - expect_json: bool = True + expect_json: bool = True, ) -> Union[HttpResponse, Any]: return self._json_call( "DELETE", @@ -160,7 +164,7 @@ def delete( headers=headers, timeout=timeout, body=None, - expect_json=expect_json + expect_json=expect_json, ) def post( @@ -168,10 +172,12 @@ def post( path: str, json_body: Optional[Any] = None, data: Optional[Union[bytes, str]] = None, - params: Optional[Mapping[str, Union[None, str, int, float, bool, list, tuple]]] = None, + params: Optional[ + Mapping[str, Union[None, str, int, float, bool, list, tuple]] + ] = None, headers: Optional[Mapping[str, str]] = None, timeout: Optional[float] = None, - expect_json: bool = True + expect_json: bool = True, ) -> Union[HttpResponse, Any]: return self._json_call( "POST", @@ -180,7 +186,7 @@ def post( headers=headers, timeout=timeout, body=self._prepare_body(json_body=json_body, data=data), - expect_json=expect_json + expect_json=expect_json, ) def put( @@ -188,10 +194,12 @@ def put( path: str, json_body: Optional[Any] = None, data: Optional[Union[bytes, str]] = None, - params: Optional[Mapping[str, Union[None, 
str, int, float, bool, list, tuple]]] = None, + params: Optional[ + Mapping[str, Union[None, str, int, float, bool, list, tuple]] + ] = None, headers: Optional[Mapping[str, str]] = None, timeout: Optional[float] = None, - expect_json: bool = True + expect_json: bool = True, ) -> Union[HttpResponse, Any]: return self._json_call( "PUT", @@ -200,7 +208,7 @@ def put( headers=headers, timeout=timeout, body=self._prepare_body(json_body=json_body, data=data), - expect_json=expect_json + expect_json=expect_json, ) def patch( @@ -208,10 +216,12 @@ def patch( path: str, json_body: Optional[Any] = None, data: Optional[Union[bytes, str]] = None, - params: Optional[Mapping[str, Union[None, str, int, float, bool, list, tuple]]] = None, + params: Optional[ + Mapping[str, Union[None, str, int, float, bool, list, tuple]] + ] = None, headers: Optional[Mapping[str, str]] = None, timeout: Optional[float] = None, - expect_json: bool = True + expect_json: bool = True, ) -> Union[HttpResponse, Any]: return self._json_call( "PATCH", @@ -220,7 +230,7 @@ def patch( headers=headers, timeout=timeout, body=self._prepare_body(json_body=json_body, data=data), - expect_json=expect_json + expect_json=expect_json, ) # Low-level request @@ -228,7 +238,9 @@ def request( self, method: str, path: str, - params: Optional[Mapping[str, Union[None, str, int, float, bool, list, tuple]]] = None, + params: Optional[ + Mapping[str, Union[None, str, int, float, bool, list, tuple]] + ] = None, headers: Optional[Mapping[str, str]] = None, body: Optional[Union[bytes, str]] = None, timeout: Optional[float] = None, @@ -265,7 +277,7 @@ def request( return self.retry.call_with_retry( lambda: self._make_request(req, context=context, timeout=timeout), lambda _: dummy_fail(), - lambda error: self._is_retryable_http_error(error) + lambda error: self._is_retryable_http_error(error), ) except HTTPError as e: # Read error body, build response, and decide on retry @@ -286,10 +298,10 @@ def request( return response def _make_request( - self, - request: Request, - context: Optional[ssl.SSLContext] = None, - timeout: Optional[float] = None, + self, + request: Request, + context: Optional[ssl.SSLContext] = None, + timeout: Optional[float] = None, ): with urlopen(request, timeout=timeout or self.timeout, context=context) as resp: raw = resp.read() @@ -312,7 +324,9 @@ def _json_call( self, method: str, path: str, - params: Optional[Mapping[str, Union[None, str, int, float, bool, list, tuple]]] = None, + params: Optional[ + Mapping[str, Union[None, str, int, float, bool, list, tuple]] + ] = None, headers: Optional[Mapping[str, str]] = None, timeout: Optional[float] = None, body: Optional[Union[bytes, str]] = None, @@ -332,7 +346,9 @@ def _json_call( return resp.json() return resp - def _prepare_body(self, json_body: Optional[Any] = None, data: Optional[Union[bytes, str]] = None) -> Optional[Union[bytes, str]]: + def _prepare_body( + self, json_body: Optional[Any] = None, data: Optional[Union[bytes, str]] = None + ) -> Optional[Union[bytes, str]]: if json_body is not None and data is not None: raise ValueError("Provide either json_body or data, not both.") if json_body is not None: @@ -342,17 +358,23 @@ def _prepare_body(self, json_body: Optional[Any] = None, data: Optional[Union[by def _build_url( self, path: str, - params: Optional[Mapping[str, Union[None, str, int, float, bool, list, tuple]]] = None, + params: Optional[ + Mapping[str, Union[None, str, int, float, bool, list, tuple]] + ] = None, ) -> str: url = urljoin(self.base_url or "", path) if params: 
# urlencode with doseq=True supports list/tuple values - query = urlencode({k: v for k, v in params.items() if v is not None}, doseq=True) + query = urlencode( + {k: v for k, v in params.items() if v is not None}, doseq=True + ) separator = "&" if ("?" in url) else "?" url = f"{url}{separator}{query}" if query else url return url - def _prepare_headers(self, headers: Optional[Mapping[str, str]], body: Optional[Union[bytes, str]]) -> Dict[str, str]: + def _prepare_headers( + self, headers: Optional[Mapping[str, str]], body: Optional[Union[bytes, str]] + ) -> Dict[str, str]: # Start with defaults prepared: Dict[str, str] = {} prepared.update(self._default_headers) @@ -401,4 +423,4 @@ def _maybe_decompress(self, content: bytes, headers: Mapping[str, str]) -> bytes except Exception: # If decompression fails, return original bytes return content - return content \ No newline at end of file + return content diff --git a/redis/multidb/circuit.py b/redis/multidb/circuit.py index 5757f3e6d9..3a4d90eeb3 100644 --- a/redis/multidb/circuit.py +++ b/redis/multidb/circuit.py @@ -6,10 +6,12 @@ DEFAULT_GRACE_PERIOD = 60 + class State(Enum): - CLOSED = 'closed' - OPEN = 'open' - HALF_OPEN = 'half-open' + CLOSED = "closed" + OPEN = "open" + HALF_OPEN = "half-open" + class CircuitBreaker(ABC): @property @@ -52,10 +54,12 @@ def on_state_changed(self, cb: Callable[["CircuitBreaker", State, State], None]) """Callback called when the state of the circuit changes.""" pass + class BaseCircuitBreaker(CircuitBreaker): """ Base implementation of Circuit Breaker interface. """ + def __init__(self, cb: pybreaker.CircuitBreaker): self._cb = cb self._state_pb_mapper = { @@ -94,12 +98,14 @@ def on_state_changed(self, cb: Callable[["CircuitBreaker", State, State], None]) """Callback called when the state of the circuit changes.""" pass + class PBListener(pybreaker.CircuitBreakerListener): """Wrapper for callback to be compatible with pybreaker implementation.""" + def __init__( - self, - cb: Callable[[CircuitBreaker, State, State], None], - database, + self, + cb: Callable[[CircuitBreaker, State, State], None], + database, ): """ Initialize a PBListener instance. @@ -119,6 +125,7 @@ def state_change(self, cb, old_state, new_state): new_state = State(value=new_state.name) self._cb(cb, old_state, new_state) + class PBCircuitBreakerAdapter(BaseCircuitBreaker): def __init__(self, cb: pybreaker.CircuitBreaker): """ @@ -134,4 +141,4 @@ def __init__(self, cb: pybreaker.CircuitBreaker): def on_state_changed(self, cb: Callable[["CircuitBreaker", State, State], None]): listener = PBListener(cb, self.database) - self._cb.add_listener(listener) \ No newline at end of file + self._cb.add_listener(listener) diff --git a/redis/multidb/client.py b/redis/multidb/client.py index 6f2022c9de..229e1b1616 100644 --- a/redis/multidb/client.py +++ b/redis/multidb/client.py @@ -16,11 +16,13 @@ logger = logging.getLogger(__name__) + class MultiDBClient(RedisModuleCommands, CoreCommands): """ Client that operates on multiple logical Redis databases. Should be used in Active-Active database setups. 
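A hypothetical end-to-end use of the synchronous client (config would be a redis.multidb.config.MultiDbConfig built as in the async example earlier; key and value are placeholders):

    from redis.multidb.client import MultiDBClient

    client = MultiDBClient(config)
    client.set("greeting", "hello")
    print(client.get("greeting"))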
""" + def __init__(self, config: MultiDbConfig): self._databases = config.databases() self._health_checks = config.default_health_checks() @@ -30,16 +32,18 @@ def __init__(self, config: MultiDbConfig): self._health_check_interval = config.health_check_interval self._health_check_policy: HealthCheckPolicy = config.health_check_policy.value( - config.health_check_probes, - config.health_check_delay + config.health_check_probes, config.health_check_delay ) self._failure_detectors = config.default_failure_detectors() if config.failure_detectors is not None: self._failure_detectors.extend(config.failure_detectors) - self._failover_strategy = config.default_failover_strategy() \ - if config.failover_strategy is None else config.failover_strategy + self._failover_strategy = ( + config.default_failover_strategy() + if config.failover_strategy is None + else config.failover_strategy + ) self._failover_strategy.set_databases(self._databases) self._auto_fallback_interval = config.auto_fallback_interval self._event_dispatcher = config.event_dispatcher @@ -89,7 +93,9 @@ def raise_exception_on_failed_hc(error): is_active_db_found = True if not is_active_db_found: - raise NoValidDatabaseException('Initial connection failed - no active database found') + raise NoValidDatabaseException( + "Initial connection failed - no active database found" + ) self.initialized = True @@ -111,7 +117,7 @@ def set_active_database(self, database: SyncDatabase) -> None: break if not exists: - raise ValueError('Given database is not a member of database list') + raise ValueError("Given database is not a member of database list") self._check_db_health(database) @@ -120,7 +126,9 @@ def set_active_database(self, database: SyncDatabase) -> None: self.command_executor.active_database = database return - raise NoValidDatabaseException('Cannot set active database, database is unhealthy') + raise NoValidDatabaseException( + "Cannot set active database, database is unhealthy" + ) def add_database(self, database: SyncDatabase): """ @@ -128,7 +136,7 @@ def add_database(self, database: SyncDatabase): """ for existing_db, _ in self._databases: if existing_db == database: - raise ValueError('Given database already exists') + raise ValueError("Given database already exists") self._check_db_health(database) @@ -136,8 +144,13 @@ def add_database(self, database: SyncDatabase): self._databases.add(database, database.weight) self._change_active_database(database, highest_weighted_db) - def _change_active_database(self, new_database: SyncDatabase, highest_weight_database: SyncDatabase): - if new_database.weight > highest_weight_database.weight and new_database.circuit.state == CBState.CLOSED: + def _change_active_database( + self, new_database: SyncDatabase, highest_weight_database: SyncDatabase + ): + if ( + new_database.weight > highest_weight_database.weight + and new_database.circuit.state == CBState.CLOSED + ): self.command_executor.active_database = new_database def remove_database(self, database: Database): @@ -147,7 +160,10 @@ def remove_database(self, database: Database): weight = self._databases.remove(database) highest_weighted_db, highest_weight = self._databases.get_top_n(1)[0] - if highest_weight <= weight and highest_weighted_db.circuit.state == CBState.CLOSED: + if ( + highest_weight <= weight + and highest_weighted_db.circuit.state == CBState.CLOSED + ): self.command_executor.active_database = highest_weighted_db def update_database_weight(self, database: SyncDatabase, weight: float): @@ -162,7 +178,7 @@ def 
update_database_weight(self, database: SyncDatabase, weight: float): break if not exists: - raise ValueError('Given database is not a member of database list') + raise ValueError("Given database is not a member of database list") highest_weighted_db, highest_weight = self._databases.get_top_n(1)[0] self._databases.update_weight(database, weight) @@ -246,7 +262,9 @@ def _check_databases_health(self, on_error: Callable[[Exception], None] = None): } try: - for future in as_completed(futures, timeout=self._health_check_interval): + for future in as_completed( + futures, timeout=self._health_check_interval + ): try: future.result() except UnhealthyDatabaseException as e: @@ -254,26 +272,33 @@ def _check_databases_health(self, on_error: Callable[[Exception], None] = None): unhealthy_db.circuit.state = CBState.OPEN logger.exception( - 'Health check failed, due to exception', - exc_info=e.original_exception + "Health check failed, due to exception", + exc_info=e.original_exception, ) if on_error: on_error(e.original_exception) except TimeoutError: - raise TimeoutError("Health check execution exceeds health_check_interval") + raise TimeoutError( + "Health check execution exceeds health_check_interval" + ) - def _on_circuit_state_change_callback(self, circuit: CircuitBreaker, old_state: CBState, new_state: CBState): + def _on_circuit_state_change_callback( + self, circuit: CircuitBreaker, old_state: CBState, new_state: CBState + ): if new_state == CBState.HALF_OPEN: self._check_db_health(circuit.database) return if old_state == CBState.CLOSED and new_state == CBState.OPEN: - self._bg_scheduler.run_once(DEFAULT_GRACE_PERIOD, _half_open_circuit, circuit) + self._bg_scheduler.run_once( + DEFAULT_GRACE_PERIOD, _half_open_circuit, circuit + ) def close(self): self.command_executor.active_database.client.close() + def _half_open_circuit(circuit: CircuitBreaker): circuit.state = CBState.HALF_OPEN @@ -282,6 +307,7 @@ class Pipeline(RedisModuleCommands, CoreCommands): """ Pipeline implementation for multiple logical Redis databases. """ + def __init__(self, client: MultiDBClient): self._command_stack = [] self._client = client @@ -337,14 +363,18 @@ def execute(self) -> List[Any]: self._client.initialize() try: - return self._client.command_executor.execute_pipeline(tuple(self._command_stack)) + return self._client.command_executor.execute_pipeline( + tuple(self._command_stack) + ) finally: self.reset() + class PubSub: """ PubSub object for multi database client. """ + def __init__(self, client: MultiDBClient, **kwargs): """Initialize the PubSub object for a multi-database client. @@ -369,7 +399,7 @@ def __del__(self) -> None: pass def reset(self) -> None: - return self._client.command_executor.execute_pubsub_method('reset') + return self._client.command_executor.execute_pubsub_method("reset") def close(self) -> None: self.reset() @@ -379,7 +409,9 @@ def subscribed(self) -> bool: return self._client.command_executor.active_pubsub.subscribed def execute_command(self, *args): - return self._client.command_executor.execute_pubsub_method('execute_command', *args) + return self._client.command_executor.execute_pubsub_method( + "execute_command", *args + ) def psubscribe(self, *args, **kwargs): """ @@ -389,14 +421,18 @@ def psubscribe(self, *args, **kwargs): received on that pattern rather than producing a message via ``listen()``. 
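
        A short sketch, assuming `client` is an initialized MultiDBClient;
        the pattern and handler are illustrative:

            pubsub = PubSub(client)
            pubsub.psubscribe(**{"news.*": lambda message: print(message)})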
""" - return self._client.command_executor.execute_pubsub_method('psubscribe', *args, **kwargs) + return self._client.command_executor.execute_pubsub_method( + "psubscribe", *args, **kwargs + ) def punsubscribe(self, *args): """ Unsubscribe from the supplied patterns. If empty, unsubscribe from all patterns. """ - return self._client.command_executor.execute_pubsub_method('punsubscribe', *args) + return self._client.command_executor.execute_pubsub_method( + "punsubscribe", *args + ) def subscribe(self, *args, **kwargs): """ @@ -406,14 +442,16 @@ def subscribe(self, *args, **kwargs): that channel rather than producing a message via ``listen()`` or ``get_message()``. """ - return self._client.command_executor.execute_pubsub_method('subscribe', *args, **kwargs) + return self._client.command_executor.execute_pubsub_method( + "subscribe", *args, **kwargs + ) def unsubscribe(self, *args): """ Unsubscribe from the supplied channels. If empty, unsubscribe from all channels """ - return self._client.command_executor.execute_pubsub_method('unsubscribe', *args) + return self._client.command_executor.execute_pubsub_method("unsubscribe", *args) def ssubscribe(self, *args, **kwargs): """ @@ -423,14 +461,18 @@ def ssubscribe(self, *args, **kwargs): when a message is received on that channel rather than producing a message via ``listen()`` or ``get_sharded_message()``. """ - return self._client.command_executor.execute_pubsub_method('ssubscribe', *args, **kwargs) + return self._client.command_executor.execute_pubsub_method( + "ssubscribe", *args, **kwargs + ) def sunsubscribe(self, *args): """ Unsubscribe from the supplied shard_channels. If empty, unsubscribe from all shard_channels """ - return self._client.command_executor.execute_pubsub_method('sunsubscribe', *args) + return self._client.command_executor.execute_pubsub_method( + "sunsubscribe", *args + ) def get_message( self, ignore_subscribe_messages: bool = False, timeout: float = 0.0 @@ -443,8 +485,9 @@ def get_message( number, or None, to wait indefinitely. """ return self._client.command_executor.execute_pubsub_method( - 'get_message', - ignore_subscribe_messages=ignore_subscribe_messages, timeout=timeout + "get_message", + ignore_subscribe_messages=ignore_subscribe_messages, + timeout=timeout, ) def get_sharded_message( @@ -458,8 +501,9 @@ def get_sharded_message( number, or None, to wait indefinitely. 
""" return self._client.command_executor.execute_pubsub_method( - 'get_sharded_message', - ignore_subscribe_messages=ignore_subscribe_messages, timeout=timeout + "get_sharded_message", + ignore_subscribe_messages=ignore_subscribe_messages, + timeout=timeout, ) def run_in_thread( @@ -476,4 +520,3 @@ def run_in_thread( pubsub=self, sharded_pubsub=sharded_pubsub, ) - diff --git a/redis/multidb/command_executor.py b/redis/multidb/command_executor.py index 481364de9a..8ba5d43e7d 100644 --- a/redis/multidb/command_executor.py +++ b/redis/multidb/command_executor.py @@ -7,16 +7,24 @@ from redis.multidb.config import DEFAULT_AUTO_FALLBACK_INTERVAL from redis.multidb.database import Database, Databases, SyncDatabase from redis.multidb.circuit import State as CBState -from redis.multidb.event import RegisterCommandFailure, ActiveDatabaseChanged, ResubscribeOnActiveDatabaseChanged, \ - CloseConnectionOnActiveDatabaseChanged -from redis.multidb.failover import FailoverStrategy, FailoverStrategyExecutor, DEFAULT_FAILOVER_ATTEMPTS, \ - DEFAULT_FAILOVER_DELAY, DefaultFailoverStrategyExecutor +from redis.multidb.event import ( + RegisterCommandFailure, + ActiveDatabaseChanged, + ResubscribeOnActiveDatabaseChanged, + CloseConnectionOnActiveDatabaseChanged, +) +from redis.multidb.failover import ( + FailoverStrategy, + FailoverStrategyExecutor, + DEFAULT_FAILOVER_ATTEMPTS, + DEFAULT_FAILOVER_DELAY, + DefaultFailoverStrategyExecutor, +) from redis.multidb.failure_detector import FailureDetector from redis.retry import Retry class CommandExecutor(ABC): - @property @abstractmethod def auto_fallback_interval(self) -> float: @@ -29,10 +37,11 @@ def auto_fallback_interval(self, auto_fallback_interval: float) -> None: """Sets auto-fallback interval.""" pass + class BaseCommandExecutor(CommandExecutor): def __init__( - self, - auto_fallback_interval: float = DEFAULT_AUTO_FALLBACK_INTERVAL, + self, + auto_fallback_interval: float = DEFAULT_AUTO_FALLBACK_INTERVAL, ): self._auto_fallback_interval = auto_fallback_interval self._next_fallback_attempt: datetime @@ -49,10 +58,12 @@ def _schedule_next_fallback(self) -> None: if self._auto_fallback_interval == DEFAULT_AUTO_FALLBACK_INTERVAL: return - self._next_fallback_attempt = datetime.now() + timedelta(seconds=self._auto_fallback_interval) + self._next_fallback_attempt = datetime.now() + timedelta( + seconds=self._auto_fallback_interval + ) -class SyncCommandExecutor(CommandExecutor): +class SyncCommandExecutor(CommandExecutor): @property @abstractmethod def databases(self) -> Databases: @@ -122,7 +133,9 @@ def execute_pipeline(self, command_stack: tuple): pass @abstractmethod - def execute_transaction(self, transaction: Callable[[Pipeline], None], *watches, **options): + def execute_transaction( + self, transaction: Callable[[Pipeline], None], *watches, **options + ): """Executes a transaction block wrapped in callback.""" pass @@ -136,17 +149,18 @@ def execute_pubsub_run(self, sleep_time: float, **kwargs) -> Any: """Executes pub/sub run in a thread.""" pass + class DefaultCommandExecutor(SyncCommandExecutor, BaseCommandExecutor): def __init__( - self, - failure_detectors: List[FailureDetector], - databases: Databases, - command_retry: Retry, - failover_strategy: FailoverStrategy, - event_dispatcher: EventDispatcherInterface, - failover_attempts: int = DEFAULT_FAILOVER_ATTEMPTS, - failover_delay: float = DEFAULT_FAILOVER_DELAY, - auto_fallback_interval: float = DEFAULT_AUTO_FALLBACK_INTERVAL, + self, + failure_detectors: List[FailureDetector], + databases: Databases, + 
command_retry: Retry, + failover_strategy: FailoverStrategy, + event_dispatcher: EventDispatcherInterface, + failover_attempts: int = DEFAULT_FAILOVER_ATTEMPTS, + failover_delay: float = DEFAULT_FAILOVER_DELAY, + auto_fallback_interval: float = DEFAULT_AUTO_FALLBACK_INTERVAL, ): """ Initialize the DefaultCommandExecutor instance. @@ -170,9 +184,7 @@ def __init__( self._failure_detectors = failure_detectors self._command_retry = command_retry self._failover_strategy_executor = DefaultFailoverStrategyExecutor( - failover_strategy, - failover_attempts, - failover_delay + failover_strategy, failover_attempts, failover_delay ) self._event_dispatcher = event_dispatcher self._active_database: Optional[Database] = None @@ -207,7 +219,12 @@ def active_database(self, database: SyncDatabase) -> None: if old_active is not None and old_active is not database: self._event_dispatcher.dispatch( - ActiveDatabaseChanged(old_active, self._active_database, self, **self._active_pubsub_kwargs) + ActiveDatabaseChanged( + old_active, + self._active_database, + self, + **self._active_pubsub_kwargs, + ) ) @property @@ -242,9 +259,13 @@ def callback(): return self._execute_with_failure_detection(callback, command_stack) - def execute_transaction(self, transaction: Callable[[Pipeline], None], *watches, **options): + def execute_transaction( + self, transaction: Callable[[Pipeline], None], *watches, **options + ): def callback(): - response = self._active_database.client.transaction(transaction, *watches, **options) + response = self._active_database.client.transaction( + transaction, *watches, **options + ) self._register_command_execution(()) return response @@ -278,6 +299,7 @@ def _execute_with_failure_detection(self, callback: Callable, cmds: tuple = ()): """ Execute a commands execution callback with failure detection. """ + def wrapper(): # On each retry we need to check active database as it might change. self._check_active_database() @@ -296,12 +318,12 @@ def _check_active_database(self): Checks if active a database needs to be updated. 
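
        A new active database is selected through the failover strategy
        when no active database has been set yet, when the active
        database's circuit is no longer closed, or when a non-default
        auto-fallback interval is configured and the next scheduled
        fallback attempt is due.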
""" if ( - self._active_database is None - or self._active_database.circuit.state != CBState.CLOSED - or ( - self._auto_fallback_interval != DEFAULT_AUTO_FALLBACK_INTERVAL - and self._next_fallback_attempt <= datetime.now() - ) + self._active_database is None + or self._active_database.circuit.state != CBState.CLOSED + or ( + self._auto_fallback_interval != DEFAULT_AUTO_FALLBACK_INTERVAL + and self._next_fallback_attempt <= datetime.now() + ) ): self.active_database = self._failover_strategy_executor.execute() self._schedule_next_fallback() @@ -317,7 +339,12 @@ def _setup_event_dispatcher(self): failure_listener = RegisterCommandFailure(self._failure_detectors) resubscribe_listener = ResubscribeOnActiveDatabaseChanged() close_connection_listener = CloseConnectionOnActiveDatabaseChanged() - self._event_dispatcher.register_listeners({ - OnCommandsFailEvent: [failure_listener], - ActiveDatabaseChanged: [close_connection_listener, resubscribe_listener], - }) \ No newline at end of file + self._event_dispatcher.register_listeners( + { + OnCommandsFailEvent: [failure_listener], + ActiveDatabaseChanged: [ + close_connection_listener, + resubscribe_listener, + ], + } + ) diff --git a/redis/multidb/config.py b/redis/multidb/config.py index f78114f014..9ee41e394f 100644 --- a/redis/multidb/config.py +++ b/redis/multidb/config.py @@ -9,21 +9,43 @@ from redis.backoff import ExponentialWithJitterBackoff, NoBackoff from redis.data_structure import WeightedList from redis.event import EventDispatcher, EventDispatcherInterface -from redis.multidb.circuit import PBCircuitBreakerAdapter, CircuitBreaker, DEFAULT_GRACE_PERIOD +from redis.multidb.circuit import ( + PBCircuitBreakerAdapter, + CircuitBreaker, + DEFAULT_GRACE_PERIOD, +) from redis.multidb.database import Database, Databases -from redis.multidb.failure_detector import FailureDetector, CommandFailureDetector, DEFAULT_MIN_NUM_FAILURES, \ - DEFAULT_FAILURES_DETECTION_WINDOW, DEFAULT_FAILURE_RATE_THRESHOLD -from redis.multidb.healthcheck import HealthCheck, EchoHealthCheck, DEFAULT_HEALTH_CHECK_PROBES, \ - DEFAULT_HEALTH_CHECK_INTERVAL, DEFAULT_HEALTH_CHECK_DELAY, HealthCheckPolicies, DEFAULT_HEALTH_CHECK_POLICY -from redis.multidb.failover import FailoverStrategy, WeightBasedFailoverStrategy, DEFAULT_FAILOVER_ATTEMPTS, \ - DEFAULT_FAILOVER_DELAY +from redis.multidb.failure_detector import ( + FailureDetector, + CommandFailureDetector, + DEFAULT_MIN_NUM_FAILURES, + DEFAULT_FAILURES_DETECTION_WINDOW, + DEFAULT_FAILURE_RATE_THRESHOLD, +) +from redis.multidb.healthcheck import ( + HealthCheck, + EchoHealthCheck, + DEFAULT_HEALTH_CHECK_PROBES, + DEFAULT_HEALTH_CHECK_INTERVAL, + DEFAULT_HEALTH_CHECK_DELAY, + HealthCheckPolicies, + DEFAULT_HEALTH_CHECK_POLICY, +) +from redis.multidb.failover import ( + FailoverStrategy, + WeightBasedFailoverStrategy, + DEFAULT_FAILOVER_ATTEMPTS, + DEFAULT_FAILOVER_DELAY, +) from redis.retry import Retry DEFAULT_AUTO_FALLBACK_INTERVAL = 120 + def default_event_dispatcher() -> EventDispatcherInterface: return EventDispatcher() + @dataclass class DatabaseConfig: """ @@ -49,6 +71,7 @@ class DatabaseConfig: default_circuit_breaker: Generates and returns a default CircuitBreaker instance adapted for use. 
""" + weight: float = 1.0 client_kwargs: dict = field(default_factory=dict) from_url: Optional[str] = None @@ -61,6 +84,7 @@ def default_circuit_breaker(self) -> CircuitBreaker: circuit_breaker = pybreaker.CircuitBreaker(reset_timeout=self.grace_period) return PBCircuitBreakerAdapter(circuit_breaker) + @dataclass class MultiDbConfig: """ @@ -102,6 +126,7 @@ class MultiDbConfig: Provides the default failover strategy used for handling failover scenarios with defined retry and backoff configurations. """ + databases_config: List[DatabaseConfig] client_class: Type[Union[Redis, RedisCluster]] = Redis command_retry: Retry = Retry( @@ -120,7 +145,9 @@ class MultiDbConfig: failover_attempts: int = DEFAULT_FAILOVER_ATTEMPTS failover_delay: float = DEFAULT_FAILOVER_DELAY auto_fallback_interval: float = DEFAULT_AUTO_FALLBACK_INTERVAL - event_dispatcher: EventDispatcherInterface = field(default_factory=default_event_dispatcher) + event_dispatcher: EventDispatcherInterface = field( + default_factory=default_event_dispatcher + ) def databases(self) -> Databases: databases = WeightedList() @@ -128,26 +155,37 @@ def databases(self) -> Databases: for database_config in self.databases_config: # The retry object is not used in the lower level clients, so we can safely remove it. # We rely on command_retry in terms of global retries. - database_config.client_kwargs.update({"retry": Retry(retries=0, backoff=NoBackoff())}) + database_config.client_kwargs.update( + {"retry": Retry(retries=0, backoff=NoBackoff())} + ) if database_config.from_url: - client = self.client_class.from_url(database_config.from_url, **database_config.client_kwargs) + client = self.client_class.from_url( + database_config.from_url, **database_config.client_kwargs + ) elif database_config.from_pool: - database_config.from_pool.set_retry(Retry(retries=0, backoff=NoBackoff())) - client = self.client_class.from_pool(connection_pool=database_config.from_pool) + database_config.from_pool.set_retry( + Retry(retries=0, backoff=NoBackoff()) + ) + client = self.client_class.from_pool( + connection_pool=database_config.from_pool + ) else: client = self.client_class(**database_config.client_kwargs) - circuit = database_config.default_circuit_breaker() \ - if database_config.circuit is None else database_config.circuit + circuit = ( + database_config.default_circuit_breaker() + if database_config.circuit is None + else database_config.circuit + ) databases.add( Database( client=client, circuit=circuit, weight=database_config.weight, - health_check_url=database_config.health_check_url + health_check_url=database_config.health_check_url, ), - database_config.weight + database_config.weight, ) return databases @@ -157,7 +195,7 @@ def default_failure_detectors(self) -> List[FailureDetector]: CommandFailureDetector( min_num_failures=self.min_num_failures, failure_rate_threshold=self.failure_rate_threshold, - failure_detection_window=self.failures_detection_window + failure_detection_window=self.failures_detection_window, ), ] diff --git a/redis/multidb/database.py b/redis/multidb/database.py index 9c2ffe3552..8c7d536a88 100644 --- a/redis/multidb/database.py +++ b/redis/multidb/database.py @@ -8,6 +8,7 @@ from redis.multidb.circuit import CircuitBreaker from redis.typing import Number + class AbstractDatabase(ABC): @property @abstractmethod @@ -33,11 +34,12 @@ def health_check_url(self, health_check_url: Optional[str]): """Set the health check URL associated with the current database.""" pass + class BaseDatabase(AbstractDatabase): def __init__( - self, 
- weight: float, - health_check_url: Optional[str] = None, + self, + weight: float, + health_check_url: Optional[str] = None, ): self._weight = weight self._health_check_url = health_check_url @@ -58,8 +60,10 @@ def health_check_url(self) -> Optional[str]: def health_check_url(self, health_check_url: Optional[str]): self._health_check_url = health_check_url + class SyncDatabase(AbstractDatabase): """Database with an underlying synchronous redis client.""" + @property @abstractmethod def client(self) -> Union[redis.Redis, RedisCluster]: @@ -84,15 +88,17 @@ def circuit(self, circuit: CircuitBreaker): """Set the circuit breaker for the current database.""" pass + Databases = WeightedList[tuple[SyncDatabase, Number]] + class Database(BaseDatabase, SyncDatabase): def __init__( - self, - client: Union[redis.Redis, RedisCluster], - circuit: CircuitBreaker, - weight: float, - health_check_url: Optional[str] = None, + self, + client: Union[redis.Redis, RedisCluster], + circuit: CircuitBreaker, + weight: float, + health_check_url: Optional[str] = None, ): """ Initialize a new Database instance. @@ -122,4 +128,4 @@ def circuit(self) -> CircuitBreaker: @circuit.setter def circuit(self, circuit: CircuitBreaker): - self._cb = circuit \ No newline at end of file + self._cb = circuit diff --git a/redis/multidb/event.py b/redis/multidb/event.py index 8a76139752..e9e9827344 100644 --- a/redis/multidb/event.py +++ b/redis/multidb/event.py @@ -6,16 +6,18 @@ from redis.multidb.database import SyncDatabase from redis.multidb.failure_detector import FailureDetector + class ActiveDatabaseChanged: """ Event fired when an active database has been changed. """ + def __init__( - self, - old_database: SyncDatabase, - new_database: SyncDatabase, - command_executor, - **kwargs + self, + old_database: SyncDatabase, + new_database: SyncDatabase, + command_executor, + **kwargs, ): self._old_database = old_database self._new_database = new_database @@ -38,10 +40,12 @@ def command_executor(self): def kwargs(self): return self._kwargs + class ResubscribeOnActiveDatabaseChanged(EventListenerInterface): """ Re-subscribe the currently active pub / sub to a new active database. """ + def listen(self, event: ActiveDatabaseChanged): old_pubsub = event.command_executor.active_pubsub @@ -55,10 +59,12 @@ def listen(self, event: ActiveDatabaseChanged): event.command_executor.active_pubsub = new_pubsub old_pubsub.close() + class CloseConnectionOnActiveDatabaseChanged(EventListenerInterface): """ Close connection to the old active database. """ + def listen(self, event: ActiveDatabaseChanged): event.old_database.client.close() @@ -70,10 +76,12 @@ def listen(self, event: ActiveDatabaseChanged): node.redis_connection.connection_pool.update_active_connections_for_reconnect() node.redis_connection.connection_pool.disconnect() + class RegisterCommandFailure(EventListenerInterface): """ Event listener that registers command failures and passing it to the failure detectors. 
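
    A wiring sketch, assuming `dispatcher` is an EventDispatcher instance;
    the detector instance is illustrative:

        listener = RegisterCommandFailure([CommandFailureDetector()])
        dispatcher.register_listeners({OnCommandsFailEvent: [listener]})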
""" + def __init__(self, failure_detectors: List[FailureDetector]): self._failure_detectors = failure_detectors diff --git a/redis/multidb/exception.py b/redis/multidb/exception.py index f54632cae7..8c08c4b540 100644 --- a/redis/multidb/exception.py +++ b/redis/multidb/exception.py @@ -1,6 +1,7 @@ class NoValidDatabaseException(Exception): pass + class UnhealthyDatabaseException(Exception): """Exception raised when a database is unhealthy due to an underlying exception.""" @@ -9,6 +10,8 @@ def __init__(self, message, database, original_exception): self.database = database self.original_exception = original_exception + class TemporaryUnavailableException(Exception): """Exception raised when all databases in setup are temporary unavailable.""" - pass \ No newline at end of file + + pass diff --git a/redis/multidb/failover.py b/redis/multidb/failover.py index fbbd254252..c373a3a6f0 100644 --- a/redis/multidb/failover.py +++ b/redis/multidb/failover.py @@ -4,13 +4,16 @@ from redis.data_structure import WeightedList from redis.multidb.database import Databases, SyncDatabase from redis.multidb.circuit import State as CBState -from redis.multidb.exception import NoValidDatabaseException, TemporaryUnavailableException +from redis.multidb.exception import ( + NoValidDatabaseException, + TemporaryUnavailableException, +) DEFAULT_FAILOVER_ATTEMPTS = 10 DEFAULT_FAILOVER_DELAY = 12 -class FailoverStrategy(ABC): +class FailoverStrategy(ABC): @abstractmethod def database(self) -> SyncDatabase: """Select the database according to the strategy.""" @@ -21,8 +24,8 @@ def set_databases(self, databases: Databases) -> None: """Set the database strategy operates on.""" pass -class FailoverStrategyExecutor(ABC): +class FailoverStrategyExecutor(ABC): @property @abstractmethod def failover_attempts(self) -> int: @@ -46,10 +49,12 @@ def execute(self) -> SyncDatabase: """Execute the failover strategy.""" pass + class WeightBasedFailoverStrategy(FailoverStrategy): """ Failover strategy based on database weights. """ + def __init__(self) -> None: self._databases = WeightedList() @@ -58,20 +63,22 @@ def database(self) -> SyncDatabase: if database.circuit.state == CBState.CLOSED: return database - raise NoValidDatabaseException('No valid database available for communication') + raise NoValidDatabaseException("No valid database available for communication") def set_databases(self, databases: Databases) -> None: self._databases = databases + class DefaultFailoverStrategyExecutor(FailoverStrategyExecutor): """ Executes given failover strategy. 
""" + def __init__( - self, - strategy: FailoverStrategy, - failover_attempts: int = DEFAULT_FAILOVER_ATTEMPTS, - failover_delay: float = DEFAULT_FAILOVER_DELAY, + self, + strategy: FailoverStrategy, + failover_attempts: int = DEFAULT_FAILOVER_ATTEMPTS, + failover_delay: float = DEFAULT_FAILOVER_DELAY, ): self._strategy = strategy self._failover_attempts = failover_attempts @@ -116,4 +123,3 @@ def execute(self) -> SyncDatabase: def _reset(self) -> None: self._next_attempt_ts = 0 self._failover_counter = 0 - diff --git a/redis/multidb/failure_detector.py b/redis/multidb/failure_detector.py index ca657c4e52..f1be28788e 100644 --- a/redis/multidb/failure_detector.py +++ b/redis/multidb/failure_detector.py @@ -12,8 +12,8 @@ DEFAULT_FAILURE_RATE_THRESHOLD = 0.1 DEFAULT_FAILURES_DETECTION_WINDOW = 2 -class FailureDetector(ABC): +class FailureDetector(ABC): @abstractmethod def register_failure(self, exception: Exception, cmd: tuple) -> None: """Register a failure that occurred during command execution.""" @@ -29,16 +29,18 @@ def set_command_executor(self, command_executor) -> None: """Set the command executor for this failure.""" pass + class CommandFailureDetector(FailureDetector): """ Detects a failure based on a threshold of failed commands during a specific period of time. """ + def __init__( - self, - min_num_failures: int = DEFAULT_MIN_NUM_FAILURES, - failure_rate_threshold: float = DEFAULT_FAILURE_RATE_THRESHOLD, - failure_detection_window: float = DEFAULT_FAILURES_DETECTION_WINDOW, - error_types: Optional[List[Type[Exception]]] = None, + self, + min_num_failures: int = DEFAULT_MIN_NUM_FAILURES, + failure_rate_threshold: float = DEFAULT_FAILURE_RATE_THRESHOLD, + failure_detection_window: float = DEFAULT_FAILURES_DETECTION_WINDOW, + error_types: Optional[List[Type[Exception]]] = None, ) -> None: """ Initialize a new CommandFailureDetector instance. 
@@ -59,7 +61,9 @@ def __init__( self._error_types = error_types self._commands_executed: int = 0 self._start_time: datetime = datetime.now() - self._end_time: datetime = self._start_time + timedelta(seconds=self._failure_detection_window) + self._end_time: datetime = self._start_time + timedelta( + seconds=self._failure_detection_window + ) self._failures_count: int = 0 self._lock = threading.RLock() @@ -84,9 +88,8 @@ def register_command_execution(self, cmd: tuple) -> None: self._commands_executed += 1 def _check_threshold(self): - if ( - self._failures_count >= self._min_num_failures - and self._failures_count >= (math.ceil(self._commands_executed * self._failure_rate_threshold)) + if self._failures_count >= self._min_num_failures and self._failures_count >= ( + math.ceil(self._commands_executed * self._failure_rate_threshold) ): self._command_executor.active_database.circuit.state = CBState.OPEN self._reset() @@ -94,6 +97,8 @@ def _check_threshold(self): def _reset(self) -> None: with self._lock: self._start_time = datetime.now() - self._end_time = self._start_time + timedelta(seconds=self._failure_detection_window) + self._end_time = self._start_time + timedelta( + seconds=self._failure_detection_window + ) self._failures_count = 0 - self._commands_executed = 0 \ No newline at end of file + self._commands_executed = 0 diff --git a/redis/multidb/healthcheck.py b/redis/multidb/healthcheck.py index 81bbec6e17..919d4d5cbd 100644 --- a/redis/multidb/healthcheck.py +++ b/redis/multidb/healthcheck.py @@ -17,17 +17,19 @@ logger = logging.getLogger(__name__) -class HealthCheck(ABC): +class HealthCheck(ABC): @abstractmethod def check_health(self, database) -> bool: """Function to determine the health status.""" pass + class HealthCheckPolicy(ABC): """ Health checks execution policy. """ + @property @abstractmethod def health_check_probes(self) -> int: @@ -45,6 +47,7 @@ def execute(self, health_checks: List[HealthCheck], database) -> bool: """Execute health checks and return database health status.""" pass + class AbstractHealthCheckPolicy(HealthCheckPolicy): def __init__(self, health_check_probes: int, health_check_delay: float): if health_check_probes < 1: @@ -64,10 +67,12 @@ def health_check_delay(self) -> float: def execute(self, health_checks: List[HealthCheck], database) -> bool: pass + class HealthyAllPolicy(AbstractHealthCheckPolicy): """ Policy that returns True if all health check probes are successful. """ + def __init__(self, health_check_probes: int, health_check_delay: float): super().__init__(health_check_probes, health_check_delay) @@ -78,18 +83,18 @@ def execute(self, health_checks: List[HealthCheck], database) -> bool: if not health_check.check_health(database): return False except Exception as e: - raise UnhealthyDatabaseException( - f"Unhealthy database", database, e - ) + raise UnhealthyDatabaseException(f"Unhealthy database", database, e) if attempt < self.health_check_probes - 1: sleep(self._health_check_delay) return True + class HealthyMajorityPolicy(AbstractHealthCheckPolicy): """ Policy that returns True if a majority of health check probes are successful. """ + def __init__(self, health_check_probes: int, health_check_delay: float): super().__init__(health_check_probes, health_check_delay) @@ -117,10 +122,12 @@ def execute(self, health_checks: List[HealthCheck], database) -> bool: sleep(self._health_check_delay) return True + class HealthyAnyPolicy(AbstractHealthCheckPolicy): """ Policy that returns True if at least one health check probe is successful. 
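
    A selection sketch; the probe settings are illustrative:

        policy = HealthyAnyPolicy(health_check_probes=3, health_check_delay=0.5)
        healthy = policy.execute([EchoHealthCheck()], database)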
""" + def __init__(self, health_check_probes: int, health_check_delay: float): super().__init__(health_check_probes, health_check_delay) @@ -152,17 +159,21 @@ def execute(self, health_checks: List[HealthCheck], database) -> bool: return is_healthy + class HealthCheckPolicies(Enum): HEALTHY_ALL = HealthyAllPolicy HEALTHY_MAJORITY = HealthyMajorityPolicy HEALTHY_ANY = HealthyAnyPolicy + DEFAULT_HEALTH_CHECK_POLICY: HealthCheckPolicies = HealthCheckPolicies.HEALTHY_ALL + class EchoHealthCheck(HealthCheck): """ Health check based on ECHO command. """ + def check_health(self, database) -> bool: expected_message = ["healthcheck", b"healthcheck"] @@ -173,7 +184,9 @@ def check_health(self, database) -> bool: # For a cluster checks if all nodes are healthy. all_nodes = database.client.get_nodes() for node in all_nodes: - actual_message = node.redis_connection.execute_command("ECHO", "healthcheck") + actual_message = node.redis_connection.execute_command( + "ECHO", "healthcheck" + ) if actual_message not in expected_message: return False @@ -186,6 +199,7 @@ class LagAwareHealthCheck(HealthCheck): Health check available for Redis Enterprise deployments. Verify via REST API that the database is healthy based on different lags. """ + def __init__( self, rest_api_port: int = 9443, @@ -228,7 +242,7 @@ def __init__( ca_data=ca_data, client_cert_file=client_cert_file, client_key_file=client_key_file, - client_key_password=client_key_password + client_key_password=client_key_password, ) self._rest_api_port = rest_api_port self._lag_aware_tolerance = lag_aware_tolerance @@ -251,12 +265,12 @@ def check_health(self, database) -> bool: matching_bdb = None for bdb in self._http_client.get("/v1/bdbs"): for endpoint in bdb["endpoints"]: - if endpoint['dns_name'] == db_host: + if endpoint["dns_name"] == db_host: matching_bdb = bdb break # In case if the host was set as public IP - for addr in endpoint['addr']: + for addr in endpoint["addr"]: if addr == db_host: matching_bdb = bdb break @@ -265,8 +279,10 @@ def check_health(self, database) -> bool: logger.warning("LagAwareHealthCheck failed: Couldn't find a matching bdb") raise ValueError("Could not find a matching bdb") - url = (f"/v1/bdbs/{matching_bdb['uid']}/availability" - f"?extend_check=lag&availability_lag_tolerance_ms={self._lag_aware_tolerance}") + url = ( + f"/v1/bdbs/{matching_bdb['uid']}/availability" + f"?extend_check=lag&availability_lag_tolerance_ms={self._lag_aware_tolerance}" + ) self._http_client.get(url, expect_json=False) # Status checked in an http client, otherwise HttpError will be raised diff --git a/redis/retry.py b/redis/retry.py index c61bf56e18..3873cafab5 100644 --- a/redis/retry.py +++ b/redis/retry.py @@ -1,7 +1,17 @@ import abc import socket from time import sleep -from typing import TYPE_CHECKING, Any, Callable, Generic, Iterable, Tuple, Type, TypeVar, Optional +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Generic, + Iterable, + Tuple, + Type, + TypeVar, + Optional, +) from redis.exceptions import ConnectionError, TimeoutError @@ -91,7 +101,7 @@ def call_with_retry( self, do: Callable[[], T], fail: Callable[[Exception], Any], - is_retryable: Optional[Callable[[Exception], bool]] = None + is_retryable: Optional[Callable[[Exception], bool]] = None, ) -> T: """ Execute an operation that might fail and returns its result, or diff --git a/redis/utils.py b/redis/utils.py index fc69a1e825..5ae8fb25fc 100644 --- a/redis/utils.py +++ b/redis/utils.py @@ -313,14 +313,16 @@ def truncate_text(txt, max_length=100): text=txt, 
width=max_length, placeholder="...", break_long_words=True ) + def dummy_fail(): """ Fake function for a Retry object if you don't need to handle each failure. """ pass + async def dummy_fail_async(): """ Async fake function for a Retry object if you don't need to handle each failure. """ - pass \ No newline at end of file + pass diff --git a/tests/conftest.py b/tests/conftest.py index fc316ea720..af2681732b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -582,6 +582,7 @@ def mock_connection() -> ConnectionInterface: mock_connection = Mock(spec=ConnectionInterface) return mock_connection + @pytest.fixture() def mock_ed() -> EventDispatcherInterface: mock_ed = Mock(spec=EventDispatcherInterface) diff --git a/tests/test_asyncio/test_multidb/conftest.py b/tests/test_asyncio/test_multidb/conftest.py index 7695332754..0666dc527a 100644 --- a/tests/test_asyncio/test_multidb/conftest.py +++ b/tests/test_asyncio/test_multidb/conftest.py @@ -2,103 +2,124 @@ import pytest -from redis.asyncio.multidb.config import MultiDbConfig, DatabaseConfig, DEFAULT_AUTO_FALLBACK_INTERVAL +from redis.asyncio.multidb.config import ( + MultiDbConfig, + DatabaseConfig, + DEFAULT_AUTO_FALLBACK_INTERVAL, +) from redis.asyncio.multidb.failover import AsyncFailoverStrategy from redis.asyncio.multidb.failure_detector import AsyncFailureDetector -from redis.asyncio.multidb.healthcheck import HealthCheck, DEFAULT_HEALTH_CHECK_PROBES, DEFAULT_HEALTH_CHECK_INTERVAL, \ - DEFAULT_HEALTH_CHECK_POLICY +from redis.asyncio.multidb.healthcheck import ( + HealthCheck, + DEFAULT_HEALTH_CHECK_PROBES, + DEFAULT_HEALTH_CHECK_INTERVAL, + DEFAULT_HEALTH_CHECK_POLICY, +) from redis.data_structure import WeightedList from redis.multidb.circuit import State as CBState, CircuitBreaker from redis.asyncio import Redis, ConnectionPool from redis.asyncio.multidb.database import Database, Databases + @pytest.fixture() def mock_client() -> Redis: return Mock(spec=Redis) + @pytest.fixture() def mock_cb() -> CircuitBreaker: return Mock(spec=CircuitBreaker) + @pytest.fixture() def mock_fd() -> AsyncFailureDetector: - return Mock(spec=AsyncFailureDetector) + return Mock(spec=AsyncFailureDetector) + @pytest.fixture() def mock_fs() -> AsyncFailoverStrategy: - return Mock(spec=AsyncFailoverStrategy) + return Mock(spec=AsyncFailoverStrategy) + @pytest.fixture() def mock_hc() -> HealthCheck: - return Mock(spec=HealthCheck) + return Mock(spec=HealthCheck) + @pytest.fixture() def mock_db(request) -> Database: - db = Mock(spec=Database) - db.weight = request.param.get("weight", 1.0) - db.client = Mock(spec=Redis) - db.client.connection_pool = Mock(spec=ConnectionPool) + db = Mock(spec=Database) + db.weight = request.param.get("weight", 1.0) + db.client = Mock(spec=Redis) + db.client.connection_pool = Mock(spec=ConnectionPool) + + cb = request.param.get("circuit", {}) + mock_cb = Mock(spec=CircuitBreaker) + mock_cb.grace_period = cb.get("grace_period", 1.0) + mock_cb.state = cb.get("state", CBState.CLOSED) - cb = request.param.get("circuit", {}) - mock_cb = Mock(spec=CircuitBreaker) - mock_cb.grace_period = cb.get("grace_period", 1.0) - mock_cb.state = cb.get("state", CBState.CLOSED) + db.circuit = mock_cb + return db - db.circuit = mock_cb - return db @pytest.fixture() def mock_db1(request) -> Database: - db = Mock(spec=Database) - db.weight = request.param.get("weight", 1.0) - db.client = Mock(spec=Redis) - db.client.connection_pool = Mock(spec=ConnectionPool) + db = Mock(spec=Database) + db.weight = request.param.get("weight", 1.0) + db.client = 
Mock(spec=Redis) + db.client.connection_pool = Mock(spec=ConnectionPool) - cb = request.param.get("circuit", {}) - mock_cb = Mock(spec=CircuitBreaker) - mock_cb.grace_period = cb.get("grace_period", 1.0) - mock_cb.state = cb.get("state", CBState.CLOSED) + cb = request.param.get("circuit", {}) + mock_cb = Mock(spec=CircuitBreaker) + mock_cb.grace_period = cb.get("grace_period", 1.0) + mock_cb.state = cb.get("state", CBState.CLOSED) + + db.circuit = mock_cb + return db - db.circuit = mock_cb - return db @pytest.fixture() def mock_db2(request) -> Database: - db = Mock(spec=Database) - db.weight = request.param.get("weight", 1.0) - db.client = Mock(spec=Redis) - db.client.connection_pool = Mock(spec=ConnectionPool) + db = Mock(spec=Database) + db.weight = request.param.get("weight", 1.0) + db.client = Mock(spec=Redis) + db.client.connection_pool = Mock(spec=ConnectionPool) + + cb = request.param.get("circuit", {}) + mock_cb = Mock(spec=CircuitBreaker) + mock_cb.grace_period = cb.get("grace_period", 1.0) + mock_cb.state = cb.get("state", CBState.CLOSED) - cb = request.param.get("circuit", {}) - mock_cb = Mock(spec=CircuitBreaker) - mock_cb.grace_period = cb.get("grace_period", 1.0) - mock_cb.state = cb.get("state", CBState.CLOSED) + db.circuit = mock_cb + return db - db.circuit = mock_cb - return db @pytest.fixture() -def mock_multi_db_config( - request, mock_fd, mock_fs, mock_hc, mock_ed -) -> MultiDbConfig: - hc_interval = request.param.get('hc_interval', DEFAULT_HEALTH_CHECK_INTERVAL) - auto_fallback_interval = request.param.get('auto_fallback_interval', DEFAULT_AUTO_FALLBACK_INTERVAL) - health_check_policy = request.param.get('health_check_policy', DEFAULT_HEALTH_CHECK_POLICY) - health_check_probes = request.param.get('health_check_probes', DEFAULT_HEALTH_CHECK_PROBES) - - config = MultiDbConfig( - databases_config=[Mock(spec=DatabaseConfig)], - failure_detectors=[mock_fd], - health_check_interval=hc_interval, - health_check_delay=0.05, - health_check_policy=health_check_policy, - health_check_probes=health_check_probes, - failover_strategy=mock_fs, - auto_fallback_interval=auto_fallback_interval, - event_dispatcher=mock_ed - ) - - return config +def mock_multi_db_config(request, mock_fd, mock_fs, mock_hc, mock_ed) -> MultiDbConfig: + hc_interval = request.param.get("hc_interval", DEFAULT_HEALTH_CHECK_INTERVAL) + auto_fallback_interval = request.param.get( + "auto_fallback_interval", DEFAULT_AUTO_FALLBACK_INTERVAL + ) + health_check_policy = request.param.get( + "health_check_policy", DEFAULT_HEALTH_CHECK_POLICY + ) + health_check_probes = request.param.get( + "health_check_probes", DEFAULT_HEALTH_CHECK_PROBES + ) + + config = MultiDbConfig( + databases_config=[Mock(spec=DatabaseConfig)], + failure_detectors=[mock_fd], + health_check_interval=hc_interval, + health_check_delay=0.05, + health_check_policy=health_check_policy, + health_check_probes=health_check_probes, + failover_strategy=mock_fs, + auto_fallback_interval=auto_fallback_interval, + event_dispatcher=mock_ed, + ) + + return config def create_weighted_list(*databases) -> Databases: @@ -107,4 +128,4 @@ def create_weighted_list(*databases) -> Databases: for db in databases: dbs.add(db, db.weight) - return dbs \ No newline at end of file + return dbs diff --git a/tests/test_asyncio/test_multidb/test_client.py b/tests/test_asyncio/test_multidb/test_client.py index e2ebb89bca..76bee8b3e6 100644 --- a/tests/test_asyncio/test_multidb/test_client.py +++ b/tests/test_asyncio/test_multidb/test_client.py @@ -18,31 +18,35 @@ class 
TestMultiDbClient: @pytest.mark.asyncio @pytest.mark.parametrize( - 'mock_multi_db_config,mock_db, mock_db1, mock_db2', + "mock_multi_db_config,mock_db, mock_db1, mock_db2", [ ( - {}, - {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, - {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, - {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, + {}, + {"weight": 0.2, "circuit": {"state": CBState.CLOSED}}, + {"weight": 0.7, "circuit": {"state": CBState.CLOSED}}, + {"weight": 0.5, "circuit": {"state": CBState.CLOSED}}, ), ], indirect=True, ) async def test_execute_command_against_correct_db_on_successful_initialization( - self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc + self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc ): databases = create_weighted_list(mock_db, mock_db1, mock_db2) - with patch.object(mock_multi_db_config,'databases',return_value=databases), \ - patch.object(mock_multi_db_config,'default_health_checks', return_value=[mock_hc]): - mock_db1.client.execute_command = AsyncMock(return_value='OK1') + with ( + patch.object(mock_multi_db_config, "databases", return_value=databases), + patch.object( + mock_multi_db_config, "default_health_checks", return_value=[mock_hc] + ), + ): + mock_db1.client.execute_command = AsyncMock(return_value="OK1") mock_hc.check_health.return_value = True client = MultiDBClient(mock_multi_db_config) assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 - assert await client.set('key', 'value') == 'OK1' + assert await client.set("key", "value") == "OK1" assert mock_hc.check_health.call_count == 9 assert mock_db.circuit.state == CBState.CLOSED @@ -51,31 +55,43 @@ async def test_execute_command_against_correct_db_on_successful_initialization( @pytest.mark.asyncio @pytest.mark.parametrize( - 'mock_multi_db_config,mock_db, mock_db1, mock_db2', + "mock_multi_db_config,mock_db, mock_db1, mock_db2", [ ( - {}, - {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, - {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, - {'weight': 0.7, 'circuit': {'state': CBState.OPEN}}, + {}, + {"weight": 0.2, "circuit": {"state": CBState.CLOSED}}, + {"weight": 0.5, "circuit": {"state": CBState.CLOSED}}, + {"weight": 0.7, "circuit": {"state": CBState.OPEN}}, ), ], indirect=True, ) async def test_execute_command_against_correct_db_and_closed_circuit( - self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc + self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc ): databases = create_weighted_list(mock_db, mock_db1, mock_db2) - with patch.object(mock_multi_db_config,'databases',return_value=databases), \ - patch.object(mock_multi_db_config,'default_health_checks', return_value=[mock_hc]): - mock_db1.client.execute_command = AsyncMock(return_value='OK1') - - mock_hc.check_health.side_effect = [False, True, True, True, True, True, True] + with ( + patch.object(mock_multi_db_config, "databases", return_value=databases), + patch.object( + mock_multi_db_config, "default_health_checks", return_value=[mock_hc] + ), + ): + mock_db1.client.execute_command = AsyncMock(return_value="OK1") + + mock_hc.check_health.side_effect = [ + False, + True, + True, + True, + True, + True, + True, + ] client = MultiDBClient(mock_multi_db_config) assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 - assert await client.set('key', 'value') == 'OK1' + assert await client.set("key", "value") == "OK1" assert mock_hc.check_health.call_count == 7 assert mock_db.circuit.state == CBState.CLOSED @@ 
-84,19 +100,19 @@ async def test_execute_command_against_correct_db_and_closed_circuit( @pytest.mark.asyncio @pytest.mark.parametrize( - 'mock_multi_db_config,mock_db, mock_db1, mock_db2', + "mock_multi_db_config,mock_db, mock_db1, mock_db2", [ ( - {"health_check_probes" : 1}, - {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, - {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, - {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, + {"health_check_probes": 1}, + {"weight": 0.2, "circuit": {"state": CBState.CLOSED}}, + {"weight": 0.7, "circuit": {"state": CBState.CLOSED}}, + {"weight": 0.5, "circuit": {"state": CBState.CLOSED}}, ), ], indirect=True, ) async def test_execute_command_against_correct_db_on_background_health_check_determine_active_db_unhealthy( - self, mock_multi_db_config, mock_db, mock_db1, mock_db2 + self, mock_multi_db_config, mock_db, mock_db1, mock_db2 ): cb = PBCircuitBreakerAdapter(pybreaker.CircuitBreaker(reset_timeout=5)) cb.database = mock_db @@ -112,248 +128,326 @@ async def test_execute_command_against_correct_db_on_background_health_check_det databases = create_weighted_list(mock_db, mock_db1, mock_db2) - with patch.object(mock_multi_db_config,'databases',return_value=databases), \ - patch.object(mock_multi_db_config,'default_health_checks', return_value=[EchoHealthCheck()]): - mock_db.client.execute_command.side_effect = ['healthcheck', 'healthcheck', 'healthcheck', 'OK', 'error'] - mock_db1.client.execute_command.side_effect = ['healthcheck', 'OK1', 'error', 'error', 'healthcheck', 'OK1'] - mock_db2.client.execute_command.side_effect = ['healthcheck', 'healthcheck', 'OK2', 'error', 'error'] + with ( + patch.object(mock_multi_db_config, "databases", return_value=databases), + patch.object( + mock_multi_db_config, + "default_health_checks", + return_value=[EchoHealthCheck()], + ), + ): + mock_db.client.execute_command.side_effect = [ + "healthcheck", + "healthcheck", + "healthcheck", + "OK", + "error", + ] + mock_db1.client.execute_command.side_effect = [ + "healthcheck", + "OK1", + "error", + "error", + "healthcheck", + "OK1", + ] + mock_db2.client.execute_command.side_effect = [ + "healthcheck", + "healthcheck", + "OK2", + "error", + "error", + ] mock_multi_db_config.health_check_interval = 0.1 mock_multi_db_config.failover_strategy = WeightBasedFailoverStrategy() client = MultiDBClient(mock_multi_db_config) - assert await client.set('key', 'value') == 'OK1' + assert await client.set("key", "value") == "OK1" await asyncio.sleep(0.15) - assert await client.set('key', 'value') == 'OK2' + assert await client.set("key", "value") == "OK2" await asyncio.sleep(0.1) - assert await client.set('key', 'value') == 'OK' + assert await client.set("key", "value") == "OK" await asyncio.sleep(0.1) - assert await client.set('key', 'value') == 'OK1' + assert await client.set("key", "value") == "OK1" @pytest.mark.asyncio @pytest.mark.parametrize( - 'mock_multi_db_config,mock_db, mock_db1, mock_db2', + "mock_multi_db_config,mock_db, mock_db1, mock_db2", [ ( - {"health_check_probes" : 1}, - {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, - {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, - {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, + {"health_check_probes": 1}, + {"weight": 0.2, "circuit": {"state": CBState.CLOSED}}, + {"weight": 0.7, "circuit": {"state": CBState.CLOSED}}, + {"weight": 0.5, "circuit": {"state": CBState.CLOSED}}, ), ], indirect=True, ) async def test_execute_command_auto_fallback_to_highest_weight_db( - self, 
mock_multi_db_config, mock_db, mock_db1, mock_db2 + self, mock_multi_db_config, mock_db, mock_db1, mock_db2 ): databases = create_weighted_list(mock_db, mock_db1, mock_db2) - with patch.object(mock_multi_db_config,'databases',return_value=databases), \ - patch.object(mock_multi_db_config,'default_health_checks', return_value=[EchoHealthCheck()]): - mock_db.client.execute_command.side_effect = ['healthcheck', 'healthcheck', 'healthcheck', 'healthcheck', 'healthcheck'] - mock_db1.client.execute_command.side_effect = ['healthcheck', 'OK1', 'error', 'healthcheck', 'healthcheck', 'OK1'] - mock_db2.client.execute_command.side_effect = ['healthcheck', 'healthcheck', 'OK2', 'healthcheck', 'healthcheck', 'healthcheck'] + with ( + patch.object(mock_multi_db_config, "databases", return_value=databases), + patch.object( + mock_multi_db_config, + "default_health_checks", + return_value=[EchoHealthCheck()], + ), + ): + mock_db.client.execute_command.side_effect = [ + "healthcheck", + "healthcheck", + "healthcheck", + "healthcheck", + "healthcheck", + ] + mock_db1.client.execute_command.side_effect = [ + "healthcheck", + "OK1", + "error", + "healthcheck", + "healthcheck", + "OK1", + ] + mock_db2.client.execute_command.side_effect = [ + "healthcheck", + "healthcheck", + "OK2", + "healthcheck", + "healthcheck", + "healthcheck", + ] mock_multi_db_config.health_check_interval = 0.1 mock_multi_db_config.auto_fallback_interval = 0.2 mock_multi_db_config.failover_strategy = WeightBasedFailoverStrategy() client = MultiDBClient(mock_multi_db_config) - assert await client.set('key', 'value') == 'OK1' + assert await client.set("key", "value") == "OK1" await asyncio.sleep(0.15) - assert await client.set('key', 'value') == 'OK2' + assert await client.set("key", "value") == "OK2" await asyncio.sleep(0.22) - assert await client.set('key', 'value') == 'OK1' + assert await client.set("key", "value") == "OK1" @pytest.mark.asyncio @pytest.mark.parametrize( - 'mock_multi_db_config,mock_db, mock_db1, mock_db2', + "mock_multi_db_config,mock_db, mock_db1, mock_db2", [ ( - {}, - {'weight': 0.2, 'circuit': {'state': CBState.OPEN}}, - {'weight': 0.7, 'circuit': {'state': CBState.OPEN}}, - {'weight': 0.5, 'circuit': {'state': CBState.OPEN}}, + {}, + {"weight": 0.2, "circuit": {"state": CBState.OPEN}}, + {"weight": 0.7, "circuit": {"state": CBState.OPEN}}, + {"weight": 0.5, "circuit": {"state": CBState.OPEN}}, ), ], indirect=True, ) async def test_execute_command_throws_exception_on_failed_initialization( - self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc + self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc ): databases = create_weighted_list(mock_db, mock_db1, mock_db2) - with patch.object(mock_multi_db_config,'databases',return_value=databases), \ - patch.object(mock_multi_db_config,'default_health_checks', return_value=[mock_hc]): + with ( + patch.object(mock_multi_db_config, "databases", return_value=databases), + patch.object( + mock_multi_db_config, "default_health_checks", return_value=[mock_hc] + ), + ): mock_hc.check_health.return_value = False client = MultiDBClient(mock_multi_db_config) assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 - with pytest.raises(NoValidDatabaseException, match='Initial connection failed - no active database found'): - await client.set('key', 'value') + with pytest.raises( + NoValidDatabaseException, + match="Initial connection failed - no active database found", + ): + await client.set("key", "value") assert 
mock_hc.check_health.call_count == 9 @pytest.mark.asyncio @pytest.mark.parametrize( - 'mock_multi_db_config,mock_db, mock_db1, mock_db2', + "mock_multi_db_config,mock_db, mock_db1, mock_db2", [ ( - {}, - {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, - {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, - {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, + {}, + {"weight": 0.2, "circuit": {"state": CBState.CLOSED}}, + {"weight": 0.7, "circuit": {"state": CBState.CLOSED}}, + {"weight": 0.5, "circuit": {"state": CBState.CLOSED}}, ), ], indirect=True, ) async def test_add_database_throws_exception_on_same_database( - self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc + self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc ): databases = create_weighted_list(mock_db, mock_db1, mock_db2) - with patch.object(mock_multi_db_config,'databases',return_value=databases), \ - patch.object(mock_multi_db_config,'default_health_checks', return_value=[mock_hc]): + with ( + patch.object(mock_multi_db_config, "databases", return_value=databases), + patch.object( + mock_multi_db_config, "default_health_checks", return_value=[mock_hc] + ), + ): mock_hc.check_health.return_value = False client = MultiDBClient(mock_multi_db_config) assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 - with pytest.raises(ValueError, match='Given database already exists'): + with pytest.raises(ValueError, match="Given database already exists"): await client.add_database(mock_db) assert mock_hc.check_health.call_count == 9 @pytest.mark.asyncio @pytest.mark.parametrize( - 'mock_multi_db_config,mock_db, mock_db1, mock_db2', + "mock_multi_db_config,mock_db, mock_db1, mock_db2", [ ( - {}, - {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, - {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, - {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, + {}, + {"weight": 0.2, "circuit": {"state": CBState.CLOSED}}, + {"weight": 0.7, "circuit": {"state": CBState.CLOSED}}, + {"weight": 0.5, "circuit": {"state": CBState.CLOSED}}, ), ], indirect=True, ) async def test_add_database_makes_new_database_active( - self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc + self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc ): databases = create_weighted_list(mock_db, mock_db2) - with patch.object(mock_multi_db_config,'databases',return_value=databases), \ - patch.object(mock_multi_db_config,'default_health_checks', return_value=[mock_hc]): - mock_db1.client.execute_command.return_value = 'OK1' - mock_db2.client.execute_command.return_value = 'OK2' + with ( + patch.object(mock_multi_db_config, "databases", return_value=databases), + patch.object( + mock_multi_db_config, "default_health_checks", return_value=[mock_hc] + ), + ): + mock_db1.client.execute_command.return_value = "OK1" + mock_db2.client.execute_command.return_value = "OK2" mock_hc.check_health.return_value = True client = MultiDBClient(mock_multi_db_config) assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 - assert await client.set('key', 'value') == 'OK2' + assert await client.set("key", "value") == "OK2" assert mock_hc.check_health.call_count == 6 await client.add_database(mock_db1) assert mock_hc.check_health.call_count == 9 - assert await client.set('key', 'value') == 'OK1' + assert await client.set("key", "value") == "OK1" @pytest.mark.asyncio @pytest.mark.parametrize( - 'mock_multi_db_config,mock_db, mock_db1, mock_db2', + "mock_multi_db_config,mock_db, mock_db1, 
mock_db2", [ ( - {}, - {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, - {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, - {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, + {}, + {"weight": 0.2, "circuit": {"state": CBState.CLOSED}}, + {"weight": 0.7, "circuit": {"state": CBState.CLOSED}}, + {"weight": 0.5, "circuit": {"state": CBState.CLOSED}}, ), ], indirect=True, ) async def test_remove_highest_weighted_database( - self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc + self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc ): databases = create_weighted_list(mock_db, mock_db1, mock_db2) - with patch.object(mock_multi_db_config,'databases',return_value=databases), \ - patch.object(mock_multi_db_config,'default_health_checks', return_value=[mock_hc]): - mock_db1.client.execute_command.return_value = 'OK1' - mock_db2.client.execute_command.return_value = 'OK2' + with ( + patch.object(mock_multi_db_config, "databases", return_value=databases), + patch.object( + mock_multi_db_config, "default_health_checks", return_value=[mock_hc] + ), + ): + mock_db1.client.execute_command.return_value = "OK1" + mock_db2.client.execute_command.return_value = "OK2" mock_hc.check_health.return_value = True client = MultiDBClient(mock_multi_db_config) assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 - assert await client.set('key', 'value') == 'OK1' + assert await client.set("key", "value") == "OK1" assert mock_hc.check_health.call_count == 9 await client.remove_database(mock_db1) - assert await client.set('key', 'value') == 'OK2' + assert await client.set("key", "value") == "OK2" @pytest.mark.asyncio @pytest.mark.parametrize( - 'mock_multi_db_config,mock_db, mock_db1, mock_db2', + "mock_multi_db_config,mock_db, mock_db1, mock_db2", [ ( - {}, - {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, - {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, - {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, + {}, + {"weight": 0.2, "circuit": {"state": CBState.CLOSED}}, + {"weight": 0.7, "circuit": {"state": CBState.CLOSED}}, + {"weight": 0.5, "circuit": {"state": CBState.CLOSED}}, ), ], indirect=True, ) async def test_update_database_weight_to_be_highest( - self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc + self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc ): databases = create_weighted_list(mock_db, mock_db1, mock_db2) - with patch.object(mock_multi_db_config,'databases',return_value=databases), \ - patch.object(mock_multi_db_config,'default_health_checks', return_value=[mock_hc]): - mock_db1.client.execute_command.return_value = 'OK1' - mock_db2.client.execute_command.return_value = 'OK2' + with ( + patch.object(mock_multi_db_config, "databases", return_value=databases), + patch.object( + mock_multi_db_config, "default_health_checks", return_value=[mock_hc] + ), + ): + mock_db1.client.execute_command.return_value = "OK1" + mock_db2.client.execute_command.return_value = "OK2" mock_hc.check_health.return_value = True client = MultiDBClient(mock_multi_db_config) assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 - assert await client.set('key', 'value') == 'OK1' + assert await client.set("key", "value") == "OK1" assert mock_hc.check_health.call_count == 9 await client.update_database_weight(mock_db2, 0.8) assert mock_db2.weight == 0.8 - assert await client.set('key', 'value') == 'OK2' + assert await client.set("key", "value") == "OK2" @pytest.mark.asyncio @pytest.mark.parametrize( - 
'mock_multi_db_config,mock_db, mock_db1, mock_db2', + "mock_multi_db_config,mock_db, mock_db1, mock_db2", [ ( - {}, - {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, - {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, - {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, + {}, + {"weight": 0.2, "circuit": {"state": CBState.CLOSED}}, + {"weight": 0.7, "circuit": {"state": CBState.CLOSED}}, + {"weight": 0.5, "circuit": {"state": CBState.CLOSED}}, ), ], indirect=True, ) async def test_add_new_failure_detector( - self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc + self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc ): databases = create_weighted_list(mock_db, mock_db1, mock_db2) - with patch.object(mock_multi_db_config,'databases',return_value=databases), \ - patch.object(mock_multi_db_config,'default_health_checks', return_value=[mock_hc]): - mock_db1.client.execute_command.return_value = 'OK1' + with ( + patch.object(mock_multi_db_config, "databases", return_value=databases), + patch.object( + mock_multi_db_config, "default_health_checks", return_value=[mock_hc] + ), + ): + mock_db1.client.execute_command.return_value = "OK1" mock_multi_db_config.event_dispatcher = EventDispatcher() mock_fd = mock_multi_db_config.failure_detectors[0] # Event fired if command against mock_db1 would fail command_fail_event = AsyncOnCommandsFailEvent( - commands=('SET', 'key', 'value'), + commands=("SET", "key", "value"), exception=Exception(), ) @@ -361,12 +455,14 @@ async def test_add_new_failure_detector( client = MultiDBClient(mock_multi_db_config) assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 - assert await client.set('key', 'value') == 'OK1' + assert await client.set("key", "value") == "OK1" assert mock_hc.check_health.call_count == 9 # Simulate failing command events that lead to a failure detection for i in range(5): - await mock_multi_db_config.event_dispatcher.dispatch_async(command_fail_event) + await mock_multi_db_config.event_dispatcher.dispatch_async( + command_fail_event + ) assert mock_fd.register_failure.call_count == 5 @@ -375,38 +471,44 @@ async def test_add_new_failure_detector( # Simulate failing command events that lead to a failure detection for i in range(5): - await mock_multi_db_config.event_dispatcher.dispatch_async(command_fail_event) + await mock_multi_db_config.event_dispatcher.dispatch_async( + command_fail_event + ) assert mock_fd.register_failure.call_count == 10 assert another_fd.register_failure.call_count == 5 @pytest.mark.asyncio @pytest.mark.parametrize( - 'mock_multi_db_config,mock_db, mock_db1, mock_db2', + "mock_multi_db_config,mock_db, mock_db1, mock_db2", [ ( - {}, - {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, - {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, - {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, + {}, + {"weight": 0.2, "circuit": {"state": CBState.CLOSED}}, + {"weight": 0.7, "circuit": {"state": CBState.CLOSED}}, + {"weight": 0.5, "circuit": {"state": CBState.CLOSED}}, ), ], indirect=True, ) async def test_add_new_health_check( - self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc + self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc ): databases = create_weighted_list(mock_db, mock_db1, mock_db2) - with patch.object(mock_multi_db_config,'databases',return_value=databases), \ - patch.object(mock_multi_db_config,'default_health_checks', return_value=[mock_hc]): - mock_db1.client.execute_command.return_value = 'OK1' + with ( + 
patch.object(mock_multi_db_config, "databases", return_value=databases), + patch.object( + mock_multi_db_config, "default_health_checks", return_value=[mock_hc] + ), + ): + mock_db1.client.execute_command.return_value = "OK1" mock_hc.check_health.return_value = True client = MultiDBClient(mock_multi_db_config) assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 - assert await client.set('key', 'value') == 'OK1' + assert await client.set("key", "value") == "OK1" assert mock_hc.check_health.call_count == 9 another_hc = Mock(spec=HealthCheck) @@ -420,41 +522,50 @@ async def test_add_new_health_check( @pytest.mark.asyncio @pytest.mark.parametrize( - 'mock_multi_db_config,mock_db, mock_db1, mock_db2', + "mock_multi_db_config,mock_db, mock_db1, mock_db2", [ ( - {}, - {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, - {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, - {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, + {}, + {"weight": 0.2, "circuit": {"state": CBState.CLOSED}}, + {"weight": 0.7, "circuit": {"state": CBState.CLOSED}}, + {"weight": 0.5, "circuit": {"state": CBState.CLOSED}}, ), ], indirect=True, ) async def test_set_active_database( - self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc + self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc ): databases = create_weighted_list(mock_db, mock_db1, mock_db2) - with patch.object(mock_multi_db_config,'databases',return_value=databases), \ - patch.object(mock_multi_db_config,'default_health_checks', return_value=[mock_hc]): - mock_db1.client.execute_command.return_value = 'OK1' - mock_db.client.execute_command.return_value = 'OK' + with ( + patch.object(mock_multi_db_config, "databases", return_value=databases), + patch.object( + mock_multi_db_config, "default_health_checks", return_value=[mock_hc] + ), + ): + mock_db1.client.execute_command.return_value = "OK1" + mock_db.client.execute_command.return_value = "OK" mock_hc.check_health.return_value = True client = MultiDBClient(mock_multi_db_config) assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 - assert await client.set('key', 'value') == 'OK1' + assert await client.set("key", "value") == "OK1" assert mock_hc.check_health.call_count == 9 await client.set_active_database(mock_db) - assert await client.set('key', 'value') == 'OK' + assert await client.set("key", "value") == "OK" - with pytest.raises(ValueError, match='Given database is not a member of database list'): + with pytest.raises( + ValueError, match="Given database is not a member of database list" + ): await client.set_active_database(Mock(spec=AsyncDatabase)) mock_hc.check_health.return_value = False - with pytest.raises(NoValidDatabaseException, match='Cannot set active database, database is unhealthy'): - await client.set_active_database(mock_db1) \ No newline at end of file + with pytest.raises( + NoValidDatabaseException, + match="Cannot set active database, database is unhealthy", + ): + await client.set_active_database(mock_db1) diff --git a/tests/test_asyncio/test_multidb/test_command_executor.py b/tests/test_asyncio/test_multidb/test_command_executor.py index 01a8326e5a..b104b90e85 100644 --- a/tests/test_asyncio/test_multidb/test_command_executor.py +++ b/tests/test_asyncio/test_multidb/test_command_executor.py @@ -17,19 +17,21 @@ class TestDefaultCommandExecutor: @pytest.mark.asyncio @pytest.mark.parametrize( - 'mock_db,mock_db1,mock_db2', + "mock_db,mock_db1,mock_db2", [ ( - {'weight': 0.2, 'circuit': {'state': 
CBState.CLOSED}}, - {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, - {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, + {"weight": 0.2, "circuit": {"state": CBState.CLOSED}}, + {"weight": 0.7, "circuit": {"state": CBState.CLOSED}}, + {"weight": 0.5, "circuit": {"state": CBState.CLOSED}}, ), ], indirect=True, ) - async def test_execute_command_on_active_database(self, mock_db, mock_db1, mock_db2, mock_fd, mock_fs, mock_ed): - mock_db1.client.execute_command = AsyncMock(return_value='OK1') - mock_db2.client.execute_command = AsyncMock(return_value='OK2') + async def test_execute_command_on_active_database( + self, mock_db, mock_db1, mock_db2, mock_fd, mock_fs, mock_ed + ): + mock_db1.client.execute_command = AsyncMock(return_value="OK1") + mock_db2.client.execute_command = AsyncMock(return_value="OK2") databases = create_weighted_list(mock_db, mock_db1, mock_db2) executor = DefaultCommandExecutor( @@ -37,34 +39,34 @@ async def test_execute_command_on_active_database(self, mock_db, mock_db1, mock_ databases=databases, failover_strategy=mock_fs, event_dispatcher=mock_ed, - command_retry=Retry(NoBackoff(), 0) + command_retry=Retry(NoBackoff(), 0), ) await executor.set_active_database(mock_db1) - assert await executor.execute_command('SET', 'key', 'value') == 'OK1' + assert await executor.execute_command("SET", "key", "value") == "OK1" await executor.set_active_database(mock_db2) - assert await executor.execute_command('SET', 'key', 'value') == 'OK2' + assert await executor.execute_command("SET", "key", "value") == "OK2" assert mock_ed.register_listeners.call_count == 1 assert mock_fd.register_command_execution.call_count == 2 @pytest.mark.asyncio @pytest.mark.parametrize( - 'mock_db,mock_db1,mock_db2', + "mock_db,mock_db1,mock_db2", [ ( - {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, - {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, - {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, + {"weight": 0.2, "circuit": {"state": CBState.CLOSED}}, + {"weight": 0.7, "circuit": {"state": CBState.CLOSED}}, + {"weight": 0.5, "circuit": {"state": CBState.CLOSED}}, ), ], indirect=True, ) async def test_execute_command_automatically_select_active_database( - self, mock_db, mock_db1, mock_db2, mock_fd, mock_fs, mock_ed + self, mock_db, mock_db1, mock_db2, mock_fd, mock_fs, mock_ed ): - mock_db1.client.execute_command = AsyncMock(return_value='OK1') - mock_db2.client.execute_command = AsyncMock(return_value='OK2') + mock_db1.client.execute_command = AsyncMock(return_value="OK1") + mock_db2.client.execute_command = AsyncMock(return_value="OK2") mock_selector = AsyncMock(side_effect=[mock_db1, mock_db2]) type(mock_fs).database = mock_selector databases = create_weighted_list(mock_db, mock_db1, mock_db2) @@ -74,34 +76,34 @@ async def test_execute_command_automatically_select_active_database( databases=databases, failover_strategy=mock_fs, event_dispatcher=mock_ed, - command_retry=Retry(NoBackoff(), 0) + command_retry=Retry(NoBackoff(), 0), ) - assert await executor.execute_command('SET', 'key', 'value') == 'OK1' + assert await executor.execute_command("SET", "key", "value") == "OK1" mock_db1.circuit.state = CBState.OPEN - assert await executor.execute_command('SET', 'key', 'value') == 'OK2' + assert await executor.execute_command("SET", "key", "value") == "OK2" assert mock_ed.register_listeners.call_count == 1 assert mock_selector.call_count == 2 assert mock_fd.register_command_execution.call_count == 2 @pytest.mark.asyncio @pytest.mark.parametrize( - 'mock_db,mock_db1,mock_db2', 
+ "mock_db,mock_db1,mock_db2", [ ( - {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, - {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, - {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, + {"weight": 0.2, "circuit": {"state": CBState.CLOSED}}, + {"weight": 0.7, "circuit": {"state": CBState.CLOSED}}, + {"weight": 0.5, "circuit": {"state": CBState.CLOSED}}, ), ], indirect=True, ) async def test_execute_command_fallback_to_another_db_after_fallback_interval( - self, mock_db, mock_db1, mock_db2, mock_fd, mock_fs, mock_ed + self, mock_db, mock_db1, mock_db2, mock_fd, mock_fs, mock_ed ): - mock_db1.client.execute_command = AsyncMock(return_value='OK1') - mock_db2.client.execute_command = AsyncMock(return_value='OK2') + mock_db1.client.execute_command = AsyncMock(return_value="OK1") + mock_db2.client.execute_command = AsyncMock(return_value="OK2") mock_selector = AsyncMock(side_effect=[mock_db1, mock_db2, mock_db1]) type(mock_fs).database = mock_selector databases = create_weighted_list(mock_db, mock_db1, mock_db2) @@ -112,39 +114,49 @@ async def test_execute_command_fallback_to_another_db_after_fallback_interval( failover_strategy=mock_fs, event_dispatcher=mock_ed, auto_fallback_interval=0.1, - command_retry=Retry(NoBackoff(), 0) + command_retry=Retry(NoBackoff(), 0), ) - assert await executor.execute_command('SET', 'key', 'value') == 'OK1' + assert await executor.execute_command("SET", "key", "value") == "OK1" mock_db1.weight = 0.1 await asyncio.sleep(0.15) - assert await executor.execute_command('SET', 'key', 'value') == 'OK2' + assert await executor.execute_command("SET", "key", "value") == "OK2" mock_db1.weight = 0.7 await asyncio.sleep(0.15) - assert await executor.execute_command('SET', 'key', 'value') == 'OK1' + assert await executor.execute_command("SET", "key", "value") == "OK1" assert mock_ed.register_listeners.call_count == 1 assert mock_selector.call_count == 3 assert mock_fd.register_command_execution.call_count == 3 @pytest.mark.asyncio @pytest.mark.parametrize( - 'mock_db,mock_db1,mock_db2', + "mock_db,mock_db1,mock_db2", [ ( - {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, - {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, - {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, + {"weight": 0.2, "circuit": {"state": CBState.CLOSED}}, + {"weight": 0.7, "circuit": {"state": CBState.CLOSED}}, + {"weight": 0.5, "circuit": {"state": CBState.CLOSED}}, ), ], indirect=True, ) async def test_execute_command_fallback_to_another_db_after_failure_detection( - self, mock_db, mock_db1, mock_db2, mock_fs + self, mock_db, mock_db1, mock_db2, mock_fs ): - mock_db1.client.execute_command = AsyncMock(side_effect=['OK1', ConnectionError, ConnectionError, ConnectionError, 'OK1']) - mock_db2.client.execute_command = AsyncMock(side_effect=['OK2', ConnectionError, ConnectionError, ConnectionError]) + mock_db1.client.execute_command = AsyncMock( + side_effect=[ + "OK1", + ConnectionError, + ConnectionError, + ConnectionError, + "OK1", + ] + ) + mock_db2.client.execute_command = AsyncMock( + side_effect=["OK2", ConnectionError, ConnectionError, ConnectionError] + ) mock_selector = AsyncMock(side_effect=[mock_db1, mock_db2, mock_db1]) type(mock_fs).database = mock_selector threshold = 3 @@ -162,7 +174,7 @@ async def test_execute_command_fallback_to_another_db_after_failure_detection( ) fd.set_command_executor(command_executor=executor) - assert await executor.execute_command('SET', 'key', 'value') == 'OK1' - assert await executor.execute_command('SET', 'key', 'value') == 
'OK2' - assert await executor.execute_command('SET', 'key', 'value') == 'OK1' - assert mock_selector.call_count == 3 \ No newline at end of file + assert await executor.execute_command("SET", "key", "value") == "OK1" + assert await executor.execute_command("SET", "key", "value") == "OK2" + assert await executor.execute_command("SET", "key", "value") == "OK1" + assert mock_selector.call_count == 3 diff --git a/tests/test_asyncio/test_multidb/test_config.py b/tests/test_asyncio/test_multidb/test_config.py index 64760740a1..76ccd29a06 100644 --- a/tests/test_asyncio/test_multidb/test_config.py +++ b/tests/test_asyncio/test_multidb/test_config.py @@ -1,11 +1,22 @@ from unittest.mock import Mock from redis.asyncio import ConnectionPool -from redis.asyncio.multidb.config import DatabaseConfig, MultiDbConfig, DEFAULT_GRACE_PERIOD, \ - DEFAULT_HEALTH_CHECK_INTERVAL, DEFAULT_AUTO_FALLBACK_INTERVAL +from redis.asyncio.multidb.config import ( + DatabaseConfig, + MultiDbConfig, + DEFAULT_GRACE_PERIOD, + DEFAULT_HEALTH_CHECK_INTERVAL, + DEFAULT_AUTO_FALLBACK_INTERVAL, +) from redis.asyncio.multidb.database import Database -from redis.asyncio.multidb.failover import WeightBasedFailoverStrategy, AsyncFailoverStrategy -from redis.asyncio.multidb.failure_detector import FailureDetectorAsyncWrapper, AsyncFailureDetector +from redis.asyncio.multidb.failover import ( + WeightBasedFailoverStrategy, + AsyncFailoverStrategy, +) +from redis.asyncio.multidb.failure_detector import ( + FailureDetectorAsyncWrapper, + AsyncFailureDetector, +) from redis.asyncio.multidb.healthcheck import EchoHealthCheck, HealthCheck from redis.asyncio.retry import Retry from redis.multidb.circuit import CircuitBreaker @@ -14,14 +25,18 @@ class TestMultiDbConfig: def test_default_config(self): db_configs = [ - DatabaseConfig(client_kwargs={'host': 'host1', 'port': 'port1'}, weight=1.0), - DatabaseConfig(client_kwargs={'host': 'host2', 'port': 'port2'}, weight=0.9), - DatabaseConfig(client_kwargs={'host': 'host3', 'port': 'port3'}, weight=0.8), - ] - - config = MultiDbConfig( - databases_config=db_configs - ) + DatabaseConfig( + client_kwargs={"host": "host1", "port": "port1"}, weight=1.0 + ), + DatabaseConfig( + client_kwargs={"host": "host2", "port": "port2"}, weight=0.9 + ), + DatabaseConfig( + client_kwargs={"host": "host3", "port": "port3"}, weight=0.8 + ), + ] + + config = MultiDbConfig(databases_config=db_configs) assert config.databases_config == db_configs databases = config.databases() @@ -33,20 +48,28 @@ def test_default_config(self): assert weight == db_configs[i].weight assert db.circuit.grace_period == DEFAULT_GRACE_PERIOD assert db.client.get_retry() is not config.command_retry - i+=1 + i += 1 assert len(config.default_failure_detectors()) == 1 - assert isinstance(config.default_failure_detectors()[0], FailureDetectorAsyncWrapper) + assert isinstance( + config.default_failure_detectors()[0], FailureDetectorAsyncWrapper + ) assert len(config.default_health_checks()) == 1 assert isinstance(config.default_health_checks()[0], EchoHealthCheck) assert config.health_check_interval == DEFAULT_HEALTH_CHECK_INTERVAL - assert isinstance(config.default_failover_strategy(), WeightBasedFailoverStrategy) + assert isinstance( + config.default_failover_strategy(), WeightBasedFailoverStrategy + ) assert config.auto_fallback_interval == DEFAULT_AUTO_FALLBACK_INTERVAL assert isinstance(config.command_retry, Retry) def test_overridden_config(self): grace_period = 2 - mock_connection_pools = [Mock(spec=ConnectionPool), 
Mock(spec=ConnectionPool), Mock(spec=ConnectionPool)] + mock_connection_pools = [ + Mock(spec=ConnectionPool), + Mock(spec=ConnectionPool), + Mock(spec=ConnectionPool), + ] mock_connection_pools[0].connection_kwargs = {} mock_connection_pools[1].connection_kwargs = {} mock_connection_pools[2].connection_kwargs = {} @@ -56,22 +79,31 @@ def test_overridden_config(self): mock_cb2.grace_period = grace_period mock_cb3 = Mock(spec=CircuitBreaker) mock_cb3.grace_period = grace_period - mock_failure_detectors = [Mock(spec=AsyncFailureDetector), Mock(spec=AsyncFailureDetector)] + mock_failure_detectors = [ + Mock(spec=AsyncFailureDetector), + Mock(spec=AsyncFailureDetector), + ] mock_health_checks = [Mock(spec=HealthCheck), Mock(spec=HealthCheck)] health_check_interval = 10 mock_failover_strategy = Mock(spec=AsyncFailoverStrategy) auto_fallback_interval = 10 db_configs = [ - DatabaseConfig( - client_kwargs={"connection_pool": mock_connection_pools[0]}, weight=1.0, circuit=mock_cb1 - ), - DatabaseConfig( - client_kwargs={"connection_pool": mock_connection_pools[1]}, weight=0.9, circuit=mock_cb2 - ), - DatabaseConfig( - client_kwargs={"connection_pool": mock_connection_pools[2]}, weight=0.8, circuit=mock_cb3 - ), - ] + DatabaseConfig( + client_kwargs={"connection_pool": mock_connection_pools[0]}, + weight=1.0, + circuit=mock_cb1, + ), + DatabaseConfig( + client_kwargs={"connection_pool": mock_connection_pools[1]}, + weight=0.9, + circuit=mock_cb2, + ), + DatabaseConfig( + client_kwargs={"connection_pool": mock_connection_pools[2]}, + weight=0.8, + circuit=mock_cb3, + ), + ] config = MultiDbConfig( databases_config=db_configs, @@ -92,7 +124,7 @@ def test_overridden_config(self): assert weight == db_configs[i].weight assert db.client.connection_pool == mock_connection_pools[i] assert db.circuit.grace_period == grace_period - i+=1 + i += 1 assert len(config.failure_detectors) == 2 assert config.failure_detectors[0] == mock_failure_detectors[0] @@ -104,11 +136,14 @@ def test_overridden_config(self): assert config.failover_strategy == mock_failover_strategy assert config.auto_fallback_interval == auto_fallback_interval + class TestDatabaseConfig: def test_default_config(self): - config = DatabaseConfig(client_kwargs={'host': 'host1', 'port': 'port1'}, weight=1.0) + config = DatabaseConfig( + client_kwargs={"host": "host1", "port": "port1"}, weight=1.0 + ) - assert config.client_kwargs == {'host': 'host1', 'port': 'port1'} + assert config.client_kwargs == {"host": "host1", "port": "port1"} assert config.weight == 1.0 assert isinstance(config.default_circuit_breaker(), CircuitBreaker) @@ -117,9 +152,11 @@ def test_overridden_config(self): mock_circuit = Mock(spec=CircuitBreaker) config = DatabaseConfig( - client_kwargs={'connection_pool': mock_connection_pool}, weight=1.0, circuit=mock_circuit + client_kwargs={"connection_pool": mock_connection_pool}, + weight=1.0, + circuit=mock_circuit, ) - assert config.client_kwargs == {'connection_pool': mock_connection_pool} + assert config.client_kwargs == {"connection_pool": mock_connection_pool} assert config.weight == 1.0 - assert config.circuit == mock_circuit \ No newline at end of file + assert config.circuit == mock_circuit diff --git a/tests/test_asyncio/test_multidb/test_failover.py b/tests/test_asyncio/test_multidb/test_failover.py index 0275969d03..22d27f6369 100644 --- a/tests/test_asyncio/test_multidb/test_failover.py +++ b/tests/test_asyncio/test_multidb/test_failover.py @@ -4,27 +4,33 @@ from redis.data_structure import WeightedList from 
redis.multidb.circuit import State as CBState -from redis.multidb.exception import NoValidDatabaseException, TemporaryUnavailableException -from redis.asyncio.multidb.failover import WeightBasedFailoverStrategy, DefaultFailoverStrategyExecutor +from redis.multidb.exception import ( + NoValidDatabaseException, + TemporaryUnavailableException, +) +from redis.asyncio.multidb.failover import ( + WeightBasedFailoverStrategy, + DefaultFailoverStrategyExecutor, +) class TestAsyncWeightBasedFailoverStrategy: @pytest.mark.asyncio @pytest.mark.parametrize( - 'mock_db,mock_db1,mock_db2', + "mock_db,mock_db1,mock_db2", [ ( - {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, - {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, - {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, + {"weight": 0.2, "circuit": {"state": CBState.CLOSED}}, + {"weight": 0.7, "circuit": {"state": CBState.CLOSED}}, + {"weight": 0.5, "circuit": {"state": CBState.CLOSED}}, ), ( - {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, - {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, - {'weight': 0.7, 'circuit': {'state': CBState.OPEN}}, + {"weight": 0.2, "circuit": {"state": CBState.CLOSED}}, + {"weight": 0.5, "circuit": {"state": CBState.CLOSED}}, + {"weight": 0.7, "circuit": {"state": CBState.OPEN}}, ), ], - ids=['all closed - highest weight', 'highest weight - open'], + ids=["all closed - highest weight", "highest weight - open"], indirect=True, ) async def test_get_valid_database(self, mock_db, mock_db1, mock_db2): @@ -40,43 +46,49 @@ async def test_get_valid_database(self, mock_db, mock_db1, mock_db2): @pytest.mark.asyncio @pytest.mark.parametrize( - 'mock_db,mock_db1,mock_db2', + "mock_db,mock_db1,mock_db2", [ ( - {'weight': 0.2, 'circuit': {'state': CBState.OPEN}}, - {'weight': 0.7, 'circuit': {'state': CBState.OPEN}}, - {'weight': 0.5, 'circuit': {'state': CBState.OPEN}}, + {"weight": 0.2, "circuit": {"state": CBState.OPEN}}, + {"weight": 0.7, "circuit": {"state": CBState.OPEN}}, + {"weight": 0.5, "circuit": {"state": CBState.OPEN}}, ), ], indirect=True, ) - async def test_throws_exception_on_empty_databases(self, mock_db, mock_db1, mock_db2): + async def test_throws_exception_on_empty_databases( + self, mock_db, mock_db1, mock_db2 + ): failover_strategy = WeightBasedFailoverStrategy() - with pytest.raises(NoValidDatabaseException, match='No valid database available for communication'): + with pytest.raises( + NoValidDatabaseException, + match="No valid database available for communication", + ): assert await failover_strategy.database() + class TestDefaultStrategyExecutor: @pytest.mark.asyncio @pytest.mark.parametrize( - 'mock_db', + "mock_db", [ - {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, + {"weight": 0.2, "circuit": {"state": CBState.CLOSED}}, ], indirect=True, ) - async def test_execute_returns_valid_database_with_failover_attempts(self, mock_db, mock_fs): + async def test_execute_returns_valid_database_with_failover_attempts( + self, mock_db, mock_fs + ): failover_attempts = 3 mock_fs.database.side_effect = [ NoValidDatabaseException, NoValidDatabaseException, NoValidDatabaseException, - mock_db + mock_db, ] executor = DefaultFailoverStrategyExecutor( - mock_fs, - failover_attempts=failover_attempts, - failover_delay=0.1 + mock_fs, failover_attempts=failover_attempts, failover_delay=0.1 ) for i in range(failover_attempts + 1): @@ -100,12 +112,10 @@ async def test_execute_throws_exception_on_attempts_exceed(self, mock_fs): NoValidDatabaseException, NoValidDatabaseException, 
NoValidDatabaseException, - NoValidDatabaseException + NoValidDatabaseException, ] executor = DefaultFailoverStrategyExecutor( - mock_fs, - failover_attempts=failover_attempts, - failover_delay=0.1 + mock_fs, failover_attempts=failover_attempts, failover_delay=0.1 ) with pytest.raises(NoValidDatabaseException): @@ -123,24 +133,27 @@ async def test_execute_throws_exception_on_attempts_exceed(self, mock_fs): assert mock_fs.database.call_count == 4 @pytest.mark.asyncio - async def test_execute_throws_exception_on_attempts_does_not_exceed_delay(self, mock_fs): + async def test_execute_throws_exception_on_attempts_does_not_exceed_delay( + self, mock_fs + ): failover_attempts = 3 mock_fs.database.side_effect = [ NoValidDatabaseException, NoValidDatabaseException, NoValidDatabaseException, - NoValidDatabaseException + NoValidDatabaseException, ] executor = DefaultFailoverStrategyExecutor( - mock_fs, - failover_attempts=failover_attempts, - failover_delay=0.1 + mock_fs, failover_attempts=failover_attempts, failover_delay=0.1 ) - with pytest.raises(TemporaryUnavailableException, match=( - "No database connections currently available. " - "This is a temporary condition - please retry the operation." - )): + with pytest.raises( + TemporaryUnavailableException, + match=( + "No database connections currently available. " + "This is a temporary condition - please retry the operation." + ), + ): for i in range(failover_attempts + 1): try: await executor.execute() @@ -152,4 +165,4 @@ async def test_execute_throws_exception_on_attempts_does_not_exceed_delay(self, if i == failover_attempts: raise e - assert mock_fs.database.call_count == 4 \ No newline at end of file + assert mock_fs.database.call_count == 4 diff --git a/tests/test_asyncio/test_multidb/test_failure_detector.py b/tests/test_asyncio/test_multidb/test_failure_detector.py index a4d7407609..0d18c9137c 100644 --- a/tests/test_asyncio/test_multidb/test_failure_detector.py +++ b/tests/test_asyncio/test_multidb/test_failure_detector.py @@ -13,7 +13,7 @@ class TestFailureDetectorAsyncWrapper: @pytest.mark.asyncio @pytest.mark.parametrize( - 'min_num_failures,failure_rate_threshold,circuit_state', + "min_num_failures,failure_rate_threshold,circuit_state", [ (2, 0.4, CBState.OPEN), (2, 0, CBState.OPEN), @@ -30,32 +30,31 @@ class TestFailureDetectorAsyncWrapper: ], ) async def test_failure_detector_correctly_reacts_to_failures( - self, - min_num_failures, - failure_rate_threshold, - circuit_state + self, min_num_failures, failure_rate_threshold, circuit_state ): - fd = FailureDetectorAsyncWrapper(CommandFailureDetector(min_num_failures, failure_rate_threshold)) + fd = FailureDetectorAsyncWrapper( + CommandFailureDetector(min_num_failures, failure_rate_threshold) + ) mock_db = Mock(spec=Database) mock_db.circuit.state = CBState.CLOSED mock_ce = Mock(spec=AsyncCommandExecutor) mock_ce.active_database = mock_db fd.set_command_executor(mock_ce) - await fd.register_command_execution(('GET', 'key')) - await fd.register_command_execution(('GET','key')) - await fd.register_failure(Exception(), ('GET', 'key')) + await fd.register_command_execution(("GET", "key")) + await fd.register_command_execution(("GET", "key")) + await fd.register_failure(Exception(), ("GET", "key")) - await fd.register_command_execution(('GET', 'key')) - await fd.register_command_execution(('GET','key')) - await fd.register_command_execution(('GET','key')) - await fd.register_failure(Exception(), ('GET', 'key')) + await fd.register_command_execution(("GET", "key")) + await 
fd.register_command_execution(("GET", "key")) + await fd.register_command_execution(("GET", "key")) + await fd.register_failure(Exception(), ("GET", "key")) assert mock_db.circuit.state == circuit_state @pytest.mark.asyncio @pytest.mark.parametrize( - 'min_num_failures,failure_rate_threshold', + "min_num_failures,failure_rate_threshold", [ (3, 0.0), (3, 0.6), ], ids=[ "does not exceed min num failures AND failure rate, during interval", "does not exceed min num failures AND failure rate, during interval", ], ) - async def test_failure_detector_do_not_open_circuit_on_interval_exceed(self, min_num_failures, failure_rate_threshold): + async def test_failure_detector_do_not_open_circuit_on_interval_exceed( + self, min_num_failures, failure_rate_threshold + ): fd = FailureDetectorAsyncWrapper( CommandFailureDetector(min_num_failures, failure_rate_threshold, 0.3) ) @@ -76,30 +77,34 @@ async def test_failure_detector_do_not_open_circuit_on_interval_exceed(self, min fd.set_command_executor(mock_ce) assert mock_db.circuit.state == CBState.CLOSED - await fd.register_command_execution(('GET', 'key')) - await fd.register_failure(Exception(), ('GET', 'key')) + await fd.register_command_execution(("GET", "key")) + await fd.register_failure(Exception(), ("GET", "key")) await asyncio.sleep(0.16) - await fd.register_command_execution(('GET', 'key')) - await fd.register_command_execution(('GET', 'key')) - await fd.register_command_execution(('GET', 'key')) - await fd.register_failure(Exception(), ('GET', 'key')) + await fd.register_command_execution(("GET", "key")) + await fd.register_command_execution(("GET", "key")) + await fd.register_command_execution(("GET", "key")) + await fd.register_failure(Exception(), ("GET", "key")) await asyncio.sleep(0.16) - await fd.register_command_execution(('GET', 'key')) - await fd.register_failure(Exception(), ('GET', 'key')) + await fd.register_command_execution(("GET", "key")) + await fd.register_failure(Exception(), ("GET", "key")) assert mock_db.circuit.state == CBState.CLOSED # 2 more failures, as the last one already refreshed the timer - await fd.register_command_execution(('GET', 'key')) - await fd.register_failure(Exception(), ('GET', 'key')) - await fd.register_command_execution(('GET', 'key')) - await fd.register_failure(Exception(), ('GET', 'key')) + await fd.register_command_execution(("GET", "key")) + await fd.register_failure(Exception(), ("GET", "key")) + await fd.register_command_execution(("GET", "key")) + await fd.register_failure(Exception(), ("GET", "key")) assert mock_db.circuit.state == CBState.OPEN @pytest.mark.asyncio - async def test_failure_detector_open_circuit_on_specific_exception_threshold_exceed(self): - fd = FailureDetectorAsyncWrapper(CommandFailureDetector(5, 1, error_types=[ConnectionError])) + async def test_failure_detector_open_circuit_on_specific_exception_threshold_exceed( + self, + ): + fd = FailureDetectorAsyncWrapper( + CommandFailureDetector(5, 1, error_types=[ConnectionError]) + ) mock_db = Mock(spec=Database) mock_db.circuit.state = CBState.CLOSED mock_ce = Mock(spec=AsyncCommandExecutor) @@ -107,16 +112,16 @@ async def test_failure_detector_open_circuit_on_specific_exception_threshold_exc fd.set_command_executor(mock_ce) assert mock_db.circuit.state == CBState.CLOSED - await fd.register_failure(Exception(), ('SET', 'key1', 'value1')) - await fd.register_failure(ConnectionError(), ('SET', 'key1', 'value1')) - await fd.register_failure(ConnectionError(), ('SET', 'key1', 'value1')) - await fd.register_failure(Exception(), ('SET', 'key1', 
'value1')) - await fd.register_failure(Exception(), ('SET', 'key1', 'value1')) + await fd.register_failure(Exception(), ("SET", "key1", "value1")) + await fd.register_failure(ConnectionError(), ("SET", "key1", "value1")) + await fd.register_failure(ConnectionError(), ("SET", "key1", "value1")) + await fd.register_failure(Exception(), ("SET", "key1", "value1")) + await fd.register_failure(Exception(), ("SET", "key1", "value1")) assert mock_db.circuit.state == CBState.CLOSED - await fd.register_failure(ConnectionError(), ('SET', 'key1', 'value1')) - await fd.register_failure(ConnectionError(), ('SET', 'key1', 'value1')) - await fd.register_failure(ConnectionError(), ('SET', 'key1', 'value1')) + await fd.register_failure(ConnectionError(), ("SET", "key1", "value1")) + await fd.register_failure(ConnectionError(), ("SET", "key1", "value1")) + await fd.register_failure(ConnectionError(), ("SET", "key1", "value1")) - assert mock_db.circuit.state == CBState.OPEN \ No newline at end of file + assert mock_db.circuit.state == CBState.OPEN diff --git a/tests/test_asyncio/test_multidb/test_healthcheck.py b/tests/test_asyncio/test_multidb/test_healthcheck.py index 72da0ef737..242446f3fb 100644 --- a/tests/test_asyncio/test_multidb/test_healthcheck.py +++ b/tests/test_asyncio/test_multidb/test_healthcheck.py @@ -2,8 +2,14 @@ from mock.mock import AsyncMock, Mock from redis.asyncio.multidb.database import Database -from redis.asyncio.multidb.healthcheck import EchoHealthCheck, LagAwareHealthCheck, HealthCheck, HealthyAllPolicy, \ - HealthyMajorityPolicy, HealthyAnyPolicy +from redis.asyncio.multidb.healthcheck import ( + EchoHealthCheck, + LagAwareHealthCheck, + HealthCheck, + HealthyAllPolicy, + HealthyMajorityPolicy, + HealthyAnyPolicy, +) from redis.http.http_client import HttpError from redis.multidb.circuit import State as CBState from redis.exceptions import ConnectionError @@ -46,11 +52,12 @@ async def test_policy_raise_unhealthy_database_exception(self): mock_db = Mock(spec=Database) policy = HealthyAllPolicy(3, 0.01) - with pytest.raises(UnhealthyDatabaseException, match='Unhealthy database'): + with pytest.raises(UnhealthyDatabaseException, match="Unhealthy database"): await policy.execute([mock_hc1, mock_hc2], mock_db) assert mock_hc1.check_health.call_count == 3 assert mock_hc2.check_health.call_count == 0 + class TestHealthyMajorityPolicy: @pytest.mark.asyncio @pytest.mark.parametrize( @@ -68,20 +75,26 @@ class TestHealthyMajorityPolicy: (4, [False, True, True, True], [True, True, False, True], 4, 4, True), ], ids=[ - 'HC1 - no majority - odd', 'HC2 - no majority - odd', 'HC1 - majority- odd', - 'HC2 - majority - odd', 'HC1 + HC2 - majority - odd', 'HC1 - no majority - even', - 'HC2 - no majority - even','HC1 - majority - even', 'HC2 - majority - even', - 'HC1 + HC2 - majority - even' - ] + "HC1 - no majority - odd", + "HC2 - no majority - odd", + "HC1 - majority - odd", + "HC2 - majority - odd", + "HC1 + HC2 - majority - odd", + "HC1 - no majority - even", + "HC2 - no majority - even", + "HC1 - majority - even", + "HC2 - majority - even", + "HC1 + HC2 - majority - even", + ], ) async def test_policy_returns_true_for_majority_successful_probes( - self, - probes, - hc1_side_effect, - hc2_side_effect, - hc1_call_count, - hc2_call_count, - expected_result + self, + probes, + hc1_side_effect, + hc2_side_effect, + hc1_call_count, + hc2_call_count, + expected_result, ): mock_hc1 = Mock(spec=HealthCheck) mock_hc2 = Mock(spec=HealthCheck) @@ -100,21 +113,30 @@ async def 
test_policy_returns_true_for_majority_successful_probes( [ (3, [True, ConnectionError, ConnectionError], [True, True, True], 3, 0), (3, [True, True, True], [True, ConnectionError, ConnectionError], 3, 3), - (4, [True, ConnectionError, ConnectionError, True], [True, True, True, True], 3, 0), - (4, [True, True, True, True], [True, ConnectionError, ConnectionError, False], 4, 3), + ( + 4, + [True, ConnectionError, ConnectionError, True], + [True, True, True, True], + 3, + 0, + ), + ( + 4, + [True, True, True, True], + [True, ConnectionError, ConnectionError, False], + 4, + 3, + ), ], ids=[ - 'HC1 - majority- odd', 'HC2 - majority - odd', - 'HC1 - majority - even', 'HC2 - majority - even', - ] + "HC1 - majority - odd", + "HC2 - majority - odd", + "HC1 - majority - even", + "HC2 - majority - even", + ], ) async def test_policy_raise_unhealthy_database_exception_on_majority_probes_exceptions( - self, - probes, - hc1_side_effect, - hc2_side_effect, - hc1_call_count, - hc2_call_count + self, probes, hc1_side_effect, hc2_side_effect, hc1_call_count, hc2_call_count ): mock_hc1 = Mock(spec=HealthCheck) mock_hc2 = Mock(spec=HealthCheck) @@ -123,11 +145,12 @@ async def test_policy_raise_unhealthy_database_exception_on_majority_probes_exce mock_db = Mock(spec=Database) policy = HealthyAllPolicy(3, 0.01) - with pytest.raises(UnhealthyDatabaseException, match='Unhealthy database'): + with pytest.raises(UnhealthyDatabaseException, match="Unhealthy database"): await policy.execute([mock_hc1, mock_hc2], mock_db) assert mock_hc1.check_health.call_count == hc1_call_count assert mock_hc2.check_health.call_count == hc2_call_count + class TestHealthyAnyPolicy: @pytest.mark.asyncio @pytest.mark.parametrize( @@ -139,17 +162,19 @@ class TestHealthyAnyPolicy: ([True, True, True], [False, True, False], 1, 2, True), ], ids=[ - 'HC1 - no successful', 'HC2 - no successful', - 'HC1 - successful', 'HC2 - successful', - ] + "HC1 - no successful", + "HC2 - no successful", + "HC1 - successful", + "HC2 - successful", + ], ) async def test_policy_returns_true_for_any_successful_probe( - self, - hc1_side_effect, - hc2_side_effect, - hc1_call_count, - hc2_call_count, - expected_result + self, + hc1_side_effect, + hc2_side_effect, + hc1_call_count, + hc2_call_count, + expected_result, ): mock_hc1 = Mock(spec=HealthCheck) mock_hc2 = Mock(spec=HealthCheck) @@ -163,7 +188,9 @@ async def test_policy_returns_true_for_any_successful_probe( assert mock_hc2.check_health.call_count == hc2_call_count @pytest.mark.asyncio - async def test_policy_raise_unhealthy_database_exception_if_exception_occurs_on_failed_health_check(self): + async def test_policy_raise_unhealthy_database_exception_if_exception_occurs_on_failed_health_check( + self, + ): mock_hc1 = Mock(spec=HealthCheck) mock_hc2 = Mock(spec=HealthCheck) mock_hc1.check_health.side_effect = [False, False, ConnectionError] @@ -171,20 +198,20 @@ async def test_policy_raise_unhealthy_database_exception_if_exception_occurs_on_ mock_db = Mock(spec=Database) policy = HealthyAnyPolicy(3, 0.01) - with pytest.raises(UnhealthyDatabaseException, match='Unhealthy database'): + with pytest.raises(UnhealthyDatabaseException, match="Unhealthy database"): await policy.execute([mock_hc1, mock_hc2], mock_db) assert mock_hc1.check_health.call_count == 3 assert mock_hc2.check_health.call_count == 0 -class TestEchoHealthCheck: +class TestEchoHealthCheck: @pytest.mark.asyncio async def test_database_is_healthy_on_echo_response(self, mock_client, mock_cb): """ Mocking responses to mix error and actual responses 
to ensure that the health check retries according to the given configuration. """ - mock_client.execute_command = AsyncMock(side_effect=['healthcheck']) + mock_client.execute_command = AsyncMock(side_effect=["healthcheck"]) hc = EchoHealthCheck() db = Database(mock_client, mock_cb, 0.9) @@ -192,12 +219,14 @@ async def test_database_is_healthy_on_echo_response(self, mock_client, mock_cb): assert mock_client.execute_command.call_count == 1 @pytest.mark.asyncio - async def test_database_is_unhealthy_on_incorrect_echo_response(self, mock_client, mock_cb): + async def test_database_is_unhealthy_on_incorrect_echo_response( + self, mock_client, mock_cb + ): """ Mocking responses to mix error and actual responses to ensure that the health check retries according to the given configuration. """ - mock_client.execute_command = AsyncMock(side_effect=['wrong']) + mock_client.execute_command = AsyncMock(side_effect=["wrong"]) hc = EchoHealthCheck() db = Database(mock_client, mock_cb, 0.9) @@ -205,8 +234,10 @@ async def test_database_is_unhealthy_on_incorrect_echo_response(self, mock_clien assert mock_client.execute_command.call_count == 1 @pytest.mark.asyncio - async def test_database_close_circuit_on_successful_healthcheck(self, mock_client, mock_cb): - mock_client.execute_command = AsyncMock(side_effect=['healthcheck']) + async def test_database_close_circuit_on_successful_healthcheck( + self, mock_client, mock_cb + ): + mock_client.execute_command = AsyncMock(side_effect=["healthcheck"]) mock_cb.state = CBState.HALF_OPEN hc = EchoHealthCheck() db = Database(mock_client, mock_cb, 0.9) @@ -214,9 +245,12 @@ async def test_database_close_circuit_on_successful_healthcheck(self, mock_clien assert await hc.check_health(db) == True assert mock_client.execute_command.call_count == 1 + class TestLagAwareHealthCheck: @pytest.mark.asyncio - async def test_database_is_healthy_when_bdb_matches_by_dns_name(self, mock_client, mock_cb): + async def test_database_is_healthy_when_bdb_matches_by_dns_name( + self, mock_client, mock_cb + ): """ Ensures health check succeeds when /v1/bdbs contains an endpoint whose dns_name matches the database host, and the availability endpoint returns success. 
@@ -240,9 +274,7 @@ async def test_database_is_healthy_when_bdb_matches_by_dns_name(self, mock_clien None, ] - hc = LagAwareHealthCheck( - rest_api_port=1234, lag_aware_tolerance=150 - ) + hc = LagAwareHealthCheck(rest_api_port=1234, lag_aware_tolerance=150) # Inject our mocked http client hc._http_client = mock_http @@ -250,17 +282,24 @@ async def test_database_is_healthy_when_bdb_matches_by_dns_name(self, mock_clien assert await hc.check_health(db) is True # Base URL must be set correctly - assert hc._http_client.client.base_url == f"https://healthcheck.example.com:1234" + assert ( + hc._http_client.client.base_url == f"https://healthcheck.example.com:1234" + ) # Calls: first to list bdbs, then to availability assert mock_http.get.call_count == 2 first_call = mock_http.get.call_args_list[0] second_call = mock_http.get.call_args_list[1] assert first_call.args[0] == "/v1/bdbs" - assert second_call.args[0] == "/v1/bdbs/bdb-1/availability?extend_check=lag&availability_lag_tolerance_ms=150" + assert ( + second_call.args[0] + == "/v1/bdbs/bdb-1/availability?extend_check=lag&availability_lag_tolerance_ms=150" + ) assert second_call.kwargs.get("expect_json") is False @pytest.mark.asyncio - async def test_database_is_healthy_when_bdb_matches_by_addr(self, mock_client, mock_cb): + async def test_database_is_healthy_when_bdb_matches_by_addr( + self, mock_client, mock_cb + ): """ Ensures health check succeeds when endpoint addr list contains the database host. """ @@ -287,7 +326,10 @@ async def test_database_is_healthy_when_bdb_matches_by_addr(self, mock_client, m assert await hc.check_health(db) is True assert mock_http.get.call_count == 2 - assert mock_http.get.call_args_list[1].args[0] == "/v1/bdbs/bdb-42/availability?extend_check=lag&availability_lag_tolerance_ms=5000" + assert ( + mock_http.get.call_args_list[1].args[0] + == "/v1/bdbs/bdb-42/availability?extend_check=lag&availability_lag_tolerance_ms=5000" + ) @pytest.mark.asyncio async def test_raises_value_error_when_no_matching_bdb(self, mock_client, mock_cb): @@ -300,8 +342,16 @@ async def test_raises_value_error_when_no_matching_bdb(self, mock_client, mock_c mock_http = AsyncMock() # Return bdbs that do not match host by dns_name nor addr mock_http.get.return_value = [ - {"uid": "a", "endpoints": [{"dns_name": "other.example.com", "addr": ["10.0.0.9"]}]}, - {"uid": "b", "endpoints": [{"dns_name": "another.example.com", "addr": ["10.0.0.10"]}]}, + { + "uid": "a", + "endpoints": [{"dns_name": "other.example.com", "addr": ["10.0.0.9"]}], + }, + { + "uid": "b", + "endpoints": [ + {"dns_name": "another.example.com", "addr": ["10.0.0.10"]} + ], + }, ] hc = LagAwareHealthCheck() @@ -328,7 +378,11 @@ async def test_propagates_http_error_from_availability(self, mock_client, mock_c mock_http.get.side_effect = [ [{"uid": "bdb-err", "endpoints": [{"dns_name": host, "addr": []}]}], # Second: availability -> raise HttpError - HttpError(url=f"https://{host}:9443/v1/bdbs/bdb-err/availability", status=503, message="busy"), + HttpError( + url=f"https://{host}:9443/v1/bdbs/bdb-err/availability", + status=503, + message="busy", + ), ] hc = LagAwareHealthCheck() @@ -341,4 +395,4 @@ async def test_propagates_http_error_from_availability(self, mock_client, mock_c assert e.status == 503 # Ensure both calls were attempted - assert mock_http.get.call_count == 2 \ No newline at end of file + assert mock_http.get.call_count == 2 diff --git a/tests/test_asyncio/test_multidb/test_pipeline.py b/tests/test_asyncio/test_multidb/test_pipeline.py index 
492919cdac..da3a6a1737 100644 --- a/tests/test_asyncio/test_multidb/test_pipeline.py +++ b/tests/test_asyncio/test_multidb/test_pipeline.py @@ -19,29 +19,34 @@ def mock_pipe() -> Pipeline: mock_pipe.__aexit__ = AsyncMock(return_value=None) return mock_pipe + class TestPipeline: @pytest.mark.asyncio @pytest.mark.parametrize( - 'mock_multi_db_config,mock_db, mock_db1, mock_db2', + "mock_multi_db_config,mock_db, mock_db1, mock_db2", [ ( - {}, - {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, - {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, - {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, + {}, + {"weight": 0.2, "circuit": {"state": CBState.CLOSED}}, + {"weight": 0.7, "circuit": {"state": CBState.CLOSED}}, + {"weight": 0.5, "circuit": {"state": CBState.CLOSED}}, ), ], indirect=True, ) async def test_executes_pipeline_against_correct_db( - self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc + self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc ): databases = create_weighted_list(mock_db, mock_db1, mock_db2) - with patch.object(mock_multi_db_config,'databases',return_value=databases), \ - patch.object(mock_multi_db_config,'default_health_checks', return_value=[mock_hc]): + with ( + patch.object(mock_multi_db_config, "databases", return_value=databases), + patch.object( + mock_multi_db_config, "default_health_checks", return_value=[mock_hc] + ), + ): pipe = mock_pipe() - pipe.execute.return_value = ['OK1', 'value1'] + pipe.execute.return_value = ["OK1", "value1"] mock_db1.client.pipeline.return_value = pipe mock_hc.check_health.return_value = True @@ -50,46 +55,58 @@ async def test_executes_pipeline_against_correct_db( assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 pipe = client.pipeline() - pipe.set('key1', 'value1') - pipe.get('key1') + pipe.set("key1", "value1") + pipe.get("key1") - assert await pipe.execute() == ['OK1', 'value1'] + assert await pipe.execute() == ["OK1", "value1"] assert mock_hc.check_health.call_count == 9 @pytest.mark.asyncio @pytest.mark.parametrize( - 'mock_multi_db_config,mock_db, mock_db1, mock_db2', + "mock_multi_db_config,mock_db, mock_db1, mock_db2", [ ( - {}, - {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, - {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, - {'weight': 0.7, 'circuit': {'state': CBState.OPEN}}, + {}, + {"weight": 0.2, "circuit": {"state": CBState.CLOSED}}, + {"weight": 0.5, "circuit": {"state": CBState.CLOSED}}, + {"weight": 0.7, "circuit": {"state": CBState.OPEN}}, ), ], indirect=True, ) async def test_execute_pipeline_against_correct_db_and_closed_circuit( - self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc + self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc ): databases = create_weighted_list(mock_db, mock_db1, mock_db2) - with patch.object(mock_multi_db_config,'databases',return_value=databases), \ - patch.object(mock_multi_db_config,'default_health_checks', return_value=[mock_hc]): + with ( + patch.object(mock_multi_db_config, "databases", return_value=databases), + patch.object( + mock_multi_db_config, "default_health_checks", return_value=[mock_hc] + ), + ): pipe = mock_pipe() - pipe.execute.return_value = ['OK1', 'value1'] + pipe.execute.return_value = ["OK1", "value1"] mock_db1.client.pipeline.return_value = pipe - mock_hc.check_health.side_effect = [False, True, True, True, True, True, True] + mock_hc.check_health.side_effect = [ + False, + True, + True, + True, + True, + True, + True, + ] client = 
MultiDBClient(mock_multi_db_config) assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 async with client.pipeline() as pipe: - pipe.set('key1', 'value1') - pipe.get('key1') + pipe.set("key1", "value1") + pipe.get("key1") - assert await pipe.execute() == ['OK1', 'value1'] + assert await pipe.execute() == ["OK1", "value1"] assert mock_hc.check_health.call_count == 7 assert mock_db.circuit.state == CBState.CLOSED @@ -98,19 +115,19 @@ async def test_execute_pipeline_against_correct_db_and_closed_circuit( @pytest.mark.asyncio @pytest.mark.parametrize( - 'mock_multi_db_config,mock_db, mock_db1, mock_db2', + "mock_multi_db_config,mock_db, mock_db1, mock_db2", [ ( - {"health_check_probes" : 1}, - {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, - {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, - {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, + {"health_check_probes": 1}, + {"weight": 0.2, "circuit": {"state": CBState.CLOSED}}, + {"weight": 0.7, "circuit": {"state": CBState.CLOSED}}, + {"weight": 0.5, "circuit": {"state": CBState.CLOSED}}, ), ], indirect=True, ) async def test_execute_pipeline_against_correct_db_on_background_health_check_determine_active_db_unhealthy( - self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc + self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc ): cb = PBCircuitBreakerAdapter(pybreaker.CircuitBreaker(reset_timeout=5)) cb.database = mock_db @@ -126,22 +143,43 @@ async def test_execute_pipeline_against_correct_db_on_background_health_check_de databases = create_weighted_list(mock_db, mock_db1, mock_db2) - with patch.object(mock_multi_db_config,'databases',return_value=databases), \ - patch.object(mock_multi_db_config,'default_health_checks', return_value=[EchoHealthCheck()]): - mock_db.client.execute_command.side_effect = ['healthcheck', 'healthcheck', 'healthcheck', 'error'] - mock_db1.client.execute_command.side_effect = ['healthcheck', 'error', 'error', 'healthcheck'] - mock_db2.client.execute_command.side_effect = ['healthcheck', 'healthcheck', 'error', 'error'] + with ( + patch.object(mock_multi_db_config, "databases", return_value=databases), + patch.object( + mock_multi_db_config, + "default_health_checks", + return_value=[EchoHealthCheck()], + ), + ): + mock_db.client.execute_command.side_effect = [ + "healthcheck", + "healthcheck", + "healthcheck", + "error", + ] + mock_db1.client.execute_command.side_effect = [ + "healthcheck", + "error", + "error", + "healthcheck", + ] + mock_db2.client.execute_command.side_effect = [ + "healthcheck", + "healthcheck", + "error", + "error", + ] pipe = mock_pipe() - pipe.execute.return_value = ['OK', 'value'] + pipe.execute.return_value = ["OK", "value"] mock_db.client.pipeline.return_value = pipe pipe1 = mock_pipe() - pipe1.execute.return_value = ['OK1', 'value'] + pipe1.execute.return_value = ["OK1", "value"] mock_db1.client.pipeline.return_value = pipe1 pipe2 = mock_pipe() - pipe2.execute.return_value = ['OK2', 'value'] + pipe2.execute.return_value = ["OK2", "value"] mock_db2.client.pipeline.return_value = pipe2 mock_multi_db_config.health_check_interval = 0.1 @@ -150,57 +188,62 @@ async def test_execute_pipeline_against_correct_db_on_background_health_check_de client = MultiDBClient(mock_multi_db_config) async with client.pipeline() as pipe: - pipe.set('key1', 'value') - pipe.get('key1') + pipe.set("key1", "value") + pipe.get("key1") - assert await pipe.execute() == ['OK1', 'value'] + assert await pipe.execute() == ["OK1", "value"] await 
asyncio.sleep(0.15) async with client.pipeline() as pipe: - pipe.set('key1', 'value') - pipe.get('key1') + pipe.set("key1", "value") + pipe.get("key1") - assert await pipe.execute() == ['OK2', 'value'] + assert await pipe.execute() == ["OK2", "value"] await asyncio.sleep(0.1) async with client.pipeline() as pipe: - pipe.set('key1', 'value') - pipe.get('key1') + pipe.set("key1", "value") + pipe.get("key1") - assert await pipe.execute() == ['OK', 'value'] + assert await pipe.execute() == ["OK", "value"] await asyncio.sleep(0.1) async with client.pipeline() as pipe: - pipe.set('key1', 'value') - pipe.get('key1') + pipe.set("key1", "value") + pipe.get("key1") + + assert await pipe.execute() == ["OK1", "value"] - assert await pipe.execute() == ['OK1', 'value'] class TestTransaction: @pytest.mark.asyncio @pytest.mark.parametrize( - 'mock_multi_db_config,mock_db, mock_db1, mock_db2', + "mock_multi_db_config,mock_db, mock_db1, mock_db2", [ ( - {}, - {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, - {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, - {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, + {}, + {"weight": 0.2, "circuit": {"state": CBState.CLOSED}}, + {"weight": 0.7, "circuit": {"state": CBState.CLOSED}}, + {"weight": 0.5, "circuit": {"state": CBState.CLOSED}}, ), ], indirect=True, ) async def test_executes_transaction_against_correct_db( - self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc + self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc ): databases = create_weighted_list(mock_db, mock_db1, mock_db2) - with patch.object(mock_multi_db_config,'databases',return_value=databases), \ - patch.object(mock_multi_db_config,'default_health_checks', return_value=[mock_hc]): - mock_db1.client.transaction.return_value = ['OK1', 'value1'] + with ( + patch.object(mock_multi_db_config, "databases", return_value=databases), + patch.object( + mock_multi_db_config, "default_health_checks", return_value=[mock_hc] + ), + ): + mock_db1.client.transaction.return_value = ["OK1", "value1"] mock_hc.check_health.return_value = True @@ -208,44 +251,56 @@ async def test_executes_transaction_against_correct_db( assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 async def callback(pipe: Pipeline): - pipe.set('key1', 'value1') - pipe.get('key1') + pipe.set("key1", "value1") + pipe.get("key1") - assert await client.transaction(callback) == ['OK1', 'value1'] + assert await client.transaction(callback) == ["OK1", "value1"] assert mock_hc.check_health.call_count == 9 @pytest.mark.asyncio @pytest.mark.parametrize( - 'mock_multi_db_config,mock_db, mock_db1, mock_db2', + "mock_multi_db_config,mock_db, mock_db1, mock_db2", [ ( - {}, - {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, - {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, - {'weight': 0.7, 'circuit': {'state': CBState.OPEN}}, + {}, + {"weight": 0.2, "circuit": {"state": CBState.CLOSED}}, + {"weight": 0.5, "circuit": {"state": CBState.CLOSED}}, + {"weight": 0.7, "circuit": {"state": CBState.OPEN}}, ), ], indirect=True, ) async def test_execute_transaction_against_correct_db_and_closed_circuit( - self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc + self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc ): databases = create_weighted_list(mock_db, mock_db1, mock_db2) - with patch.object(mock_multi_db_config,'databases',return_value=databases), \ - patch.object(mock_multi_db_config,'default_health_checks', return_value=[mock_hc]): - 
mock_db1.client.transaction.return_value = ['OK1', 'value1'] - - mock_hc.check_health.side_effect = [False, True, True, True, True, True, True] + with ( + patch.object(mock_multi_db_config, "databases", return_value=databases), + patch.object( + mock_multi_db_config, "default_health_checks", return_value=[mock_hc] + ), + ): + mock_db1.client.transaction.return_value = ["OK1", "value1"] + + mock_hc.check_health.side_effect = [ + False, + True, + True, + True, + True, + True, + True, + ] client = MultiDBClient(mock_multi_db_config) assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 async def callback(pipe: Pipeline): - pipe.set('key1', 'value1') - pipe.get('key1') + pipe.set("key1", "value1") + pipe.get("key1") - assert await client.transaction(callback) == ['OK1', 'value1'] + assert await client.transaction(callback) == ["OK1", "value1"] assert mock_hc.check_health.call_count == 7 assert mock_db.circuit.state == CBState.CLOSED @@ -254,19 +309,19 @@ async def callback(pipe: Pipeline): @pytest.mark.asyncio @pytest.mark.parametrize( - 'mock_multi_db_config,mock_db, mock_db1, mock_db2', + "mock_multi_db_config,mock_db, mock_db1, mock_db2", [ ( - {"health_check_probes" : 1}, - {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, - {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, - {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, + {"health_check_probes": 1}, + {"weight": 0.2, "circuit": {"state": CBState.CLOSED}}, + {"weight": 0.7, "circuit": {"state": CBState.CLOSED}}, + {"weight": 0.5, "circuit": {"state": CBState.CLOSED}}, ), ], indirect=True, ) async def test_execute_transaction_against_correct_db_on_background_health_check_determine_active_db_unhealthy( - self, mock_multi_db_config, mock_db, mock_db1, mock_db2 + self, mock_multi_db_config, mock_db, mock_db1, mock_db2 ): cb = PBCircuitBreakerAdapter(pybreaker.CircuitBreaker(reset_timeout=5)) cb.database = mock_db @@ -282,15 +337,36 @@ async def test_execute_transaction_against_correct_db_on_background_health_check databases = create_weighted_list(mock_db, mock_db1, mock_db2) - with patch.object(mock_multi_db_config,'databases',return_value=databases), \ - patch.object(mock_multi_db_config,'default_health_checks', return_value=[EchoHealthCheck()]): - mock_db.client.execute_command.side_effect = ['healthcheck', 'healthcheck', 'healthcheck', 'error'] - mock_db1.client.execute_command.side_effect = ['healthcheck', 'error', 'error', 'healthcheck'] - mock_db2.client.execute_command.side_effect = ['healthcheck', 'healthcheck', 'error', 'error'] - - mock_db.client.transaction.return_value = ['OK', 'value'] - mock_db1.client.transaction.return_value = ['OK1', 'value'] - mock_db2.client.transaction.return_value = ['OK2', 'value'] + with ( + patch.object(mock_multi_db_config, "databases", return_value=databases), + patch.object( + mock_multi_db_config, + "default_health_checks", + return_value=[EchoHealthCheck()], + ), + ): + mock_db.client.execute_command.side_effect = [ + "healthcheck", + "healthcheck", + "healthcheck", + "error", + ] + mock_db1.client.execute_command.side_effect = [ + "healthcheck", + "error", + "error", + "healthcheck", + ] + mock_db2.client.execute_command.side_effect = [ + "healthcheck", + "healthcheck", + "error", + "error", + ] + + mock_db.client.transaction.return_value = ["OK", "value"] + mock_db1.client.transaction.return_value = ["OK1", "value"] + mock_db2.client.transaction.return_value = ["OK2", "value"] mock_multi_db_config.health_check_interval = 0.1 
mock_multi_db_config.failover_strategy = WeightBasedFailoverStrategy() @@ -298,13 +374,13 @@ async def test_execute_transaction_against_correct_db_on_background_health_check client = MultiDBClient(mock_multi_db_config) async def callback(pipe: Pipeline): - pipe.set('key1', 'value1') - pipe.get('key1') + pipe.set("key1", "value1") + pipe.get("key1") - assert await client.transaction(callback) == ['OK1', 'value'] + assert await client.transaction(callback) == ["OK1", "value"] await asyncio.sleep(0.15) - assert await client.transaction(callback) == ['OK2', 'value'] + assert await client.transaction(callback) == ["OK2", "value"] await asyncio.sleep(0.1) - assert await client.transaction(callback) == ['OK', 'value'] + assert await client.transaction(callback) == ["OK", "value"] await asyncio.sleep(0.1) - assert await client.transaction(callback) == ['OK1', 'value'] \ No newline at end of file + assert await client.transaction(callback) == ["OK1", "value"] diff --git a/tests/test_asyncio/test_scenario/conftest.py b/tests/test_asyncio/test_scenario/conftest.py index 88313afdd6..67c9c829c3 100644 --- a/tests/test_asyncio/test_scenario/conftest.py +++ b/tests/test_asyncio/test_scenario/conftest.py @@ -7,8 +7,11 @@ from redis.asyncio import Redis, RedisCluster from redis.asyncio.multidb.client import MultiDBClient -from redis.asyncio.multidb.config import DEFAULT_HEALTH_CHECK_INTERVAL, DatabaseConfig, \ - MultiDbConfig +from redis.asyncio.multidb.config import ( + DEFAULT_HEALTH_CHECK_INTERVAL, + DatabaseConfig, + MultiDbConfig, +) from redis.asyncio.multidb.event import AsyncActiveDatabaseChanged from redis.asyncio.retry import Retry from redis.backoff import ExponentialBackoff @@ -25,83 +28,90 @@ def __init__(self): async def listen(self, event: AsyncActiveDatabaseChanged): self.is_changed_flag = True + @pytest.fixture() def fault_injector_client(): - url = os.getenv("FAULT_INJECTION_API_URL", "http://127.0.0.1:20324") - return FaultInjectorClient(url) + url = os.getenv("FAULT_INJECTION_API_URL", "http://127.0.0.1:20324") + return FaultInjectorClient(url) + @pytest_asyncio.fixture() -async def r_multi_db(request) -> AsyncGenerator[tuple[MultiDBClient, CheckActiveDatabaseChangedListener, Any], Any]: - client_class = request.param.get('client_class', Redis) - - if client_class == Redis: - endpoint_config = get_endpoints_config('re-active-active') - else: - endpoint_config = get_endpoints_config('re-active-active-oss-cluster') - - username = endpoint_config.get('username', None) - password = endpoint_config.get('password', None) - min_num_failures = request.param.get('min_num_failures', DEFAULT_MIN_NUM_FAILURES) - command_retry = request.param.get('command_retry', Retry(ExponentialBackoff(cap=0.1, base=0.01), retries=10)) - - # Retry configuration different for health checks as initial health check require more time in case - # if infrastructure wasn't restored from the previous test. 
- health_check_interval = request.param.get('health_check_interval', 10) - health_checks = request.param.get('health_checks', []) - event_dispatcher = EventDispatcher() - listener = CheckActiveDatabaseChangedListener() - event_dispatcher.register_listeners({ - AsyncActiveDatabaseChanged: [listener], - }) - db_configs = [] - - db_config = DatabaseConfig( - weight=1.0, - from_url=endpoint_config['endpoints'][0], - client_kwargs={ - 'username': username, - 'password': password, - 'decode_responses': True, - }, - health_check_url=extract_cluster_fqdn(endpoint_config['endpoints'][0]) - ) - db_configs.append(db_config) - - db_config1 = DatabaseConfig( - weight=0.9, - from_url=endpoint_config['endpoints'][1], - client_kwargs={ - 'username': username, - 'password': password, - 'decode_responses': True, - }, - health_check_url=extract_cluster_fqdn(endpoint_config['endpoints'][1]) - ) - db_configs.append(db_config1) - - config = MultiDbConfig( - client_class=client_class, - databases_config=db_configs, - command_retry=command_retry, - min_num_failures=min_num_failures, - health_checks=health_checks, - health_check_probes=3, - health_check_interval=health_check_interval, - event_dispatcher=event_dispatcher, - ) - - client = MultiDBClient(config) - - async def teardown(): - await client.aclose() - - if ( - client.command_executor.active_database - and isinstance(client.command_executor.active_database.client, Redis) - ): +async def r_multi_db( + request, +) -> AsyncGenerator[tuple[MultiDBClient, CheckActiveDatabaseChangedListener, Any], Any]: + client_class = request.param.get("client_class", Redis) + + if client_class == Redis: + endpoint_config = get_endpoints_config("re-active-active") + else: + endpoint_config = get_endpoints_config("re-active-active-oss-cluster") + + username = endpoint_config.get("username", None) + password = endpoint_config.get("password", None) + min_num_failures = request.param.get("min_num_failures", DEFAULT_MIN_NUM_FAILURES) + command_retry = request.param.get( + "command_retry", Retry(ExponentialBackoff(cap=0.1, base=0.01), retries=10) + ) + + # Retry configuration differs for health checks, as the initial health check requires more time + # if the infrastructure wasn't restored from the previous test.
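# (A hypothetical sketch, not part of this change: if the health checks needed their own,
# more lenient retry policy, it could be built from the same primitives this fixture already
# imports; the name hc_retry and its numbers are illustrative only:
#
#     hc_retry = Retry(ExponentialBackoff(cap=1.0, base=0.1), retries=20)
#
# The fixture instead relies on the longer health_check_interval configured just below.)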
+ health_check_interval = request.param.get("health_check_interval", 10) + health_checks = request.param.get("health_checks", []) + event_dispatcher = EventDispatcher() + listener = CheckActiveDatabaseChangedListener() + event_dispatcher.register_listeners( + { + AsyncActiveDatabaseChanged: [listener], + } + ) + db_configs = [] + + db_config = DatabaseConfig( + weight=1.0, + from_url=endpoint_config["endpoints"][0], + client_kwargs={ + "username": username, + "password": password, + "decode_responses": True, + }, + health_check_url=extract_cluster_fqdn(endpoint_config["endpoints"][0]), + ) + db_configs.append(db_config) + + db_config1 = DatabaseConfig( + weight=0.9, + from_url=endpoint_config["endpoints"][1], + client_kwargs={ + "username": username, + "password": password, + "decode_responses": True, + }, + health_check_url=extract_cluster_fqdn(endpoint_config["endpoints"][1]), + ) + db_configs.append(db_config1) + + config = MultiDbConfig( + client_class=client_class, + databases_config=db_configs, + command_retry=command_retry, + min_num_failures=min_num_failures, + health_checks=health_checks, + health_check_probes=3, + health_check_interval=health_check_interval, + event_dispatcher=event_dispatcher, + ) + + client = MultiDBClient(config) + + async def teardown(): + await client.aclose() + + if client.command_executor.active_database and isinstance( + client.command_executor.active_database.client, Redis + ): await client.command_executor.active_database.client.connection_pool.disconnect() - await asyncio.sleep(10) + await asyncio.sleep(10) - yield client, listener, endpoint_config - await teardown() \ No newline at end of file + yield client, listener, endpoint_config + await teardown() diff --git a/tests/test_asyncio/test_scenario/test_active_active.py b/tests/test_asyncio/test_scenario/test_active_active.py index 55b604528b..208084daf9 100644 --- a/tests/test_asyncio/test_scenario/test_active_active.py +++ b/tests/test_asyncio/test_scenario/test_active_active.py @@ -9,7 +9,10 @@ from redis.asyncio import RedisCluster from redis.asyncio.client import Pipeline, Redis from redis.asyncio.multidb.client import MultiDBClient -from redis.asyncio.multidb.failover import DEFAULT_FAILOVER_ATTEMPTS, DEFAULT_FAILOVER_DELAY +from redis.asyncio.multidb.failover import ( + DEFAULT_FAILOVER_ATTEMPTS, + DEFAULT_FAILOVER_DELAY, +) from redis.asyncio.multidb.healthcheck import LagAwareHealthCheck from redis.asyncio.retry import Retry from redis.backoff import ConstantBackoff @@ -19,25 +22,31 @@ logger = logging.getLogger(__name__) -async def trigger_network_failure_action(fault_injector_client, config, event: asyncio.Event = None): + +async def trigger_network_failure_action( + fault_injector_client, config, event: asyncio.Event = None +): action_request = ActionRequest( action_type=ActionType.NETWORK_FAILURE, - parameters={"bdb_id": config['bdb_id'], "delay": 3, "cluster_index": 0} + parameters={"bdb_id": config["bdb_id"], "delay": 3, "cluster_index": 0}, ) result = fault_injector_client.trigger_action(action_request) - status_result = fault_injector_client.get_action_status(result['action_id']) + status_result = fault_injector_client.get_action_status(result["action_id"]) - while status_result['status'] != "success": + while status_result["status"] != "success": await asyncio.sleep(0.1) - status_result = fault_injector_client.get_action_status(result['action_id']) - logger.info(f"Waiting for action to complete. 
Status: {status_result['status']}") + status_result = fault_injector_client.get_action_status(result["action_id"]) + logger.info( + f"Waiting for action to complete. Status: {status_result['status']}" + ) if event: event.set() logger.info(f"Action completed. Status: {status_result['status']}") + class TestActiveActive: @pytest.mark.asyncio @pytest.mark.parametrize( @@ -47,102 +56,130 @@ class TestActiveActive: {"client_class": RedisCluster, "min_num_failures": 2}, ], ids=["standalone", "cluster"], - indirect=True + indirect=True, ) @pytest.mark.timeout(200) - async def test_multi_db_client_failover_to_another_db(self, r_multi_db, fault_injector_client): + async def test_multi_db_client_failover_to_another_db( + self, r_multi_db, fault_injector_client + ): client, listener, endpoint_config = r_multi_db # Handle unavailable databases from previous test. retry = Retry( supported_errors=(TemporaryUnavailableException,), retries=DEFAULT_FAILOVER_ATTEMPTS, - backoff=ConstantBackoff(backoff=DEFAULT_FAILOVER_DELAY) + backoff=ConstantBackoff(backoff=DEFAULT_FAILOVER_DELAY), ) async with client as r_multi_db: event = asyncio.Event() - asyncio.create_task(trigger_network_failure_action(fault_injector_client, endpoint_config, event)) + asyncio.create_task( + trigger_network_failure_action( + fault_injector_client, endpoint_config, event + ) + ) await retry.call_with_retry( - lambda : r_multi_db.set('key', 'value'), - lambda _: dummy_fail_async() + lambda: r_multi_db.set("key", "value"), lambda _: dummy_fail_async() ) # Execute commands before network failure while not event.is_set(): - assert await retry.call_with_retry( - lambda: r_multi_db.get('key') , - lambda _: dummy_fail_async() - ) == 'value' + assert ( + await retry.call_with_retry( + lambda: r_multi_db.get("key"), lambda _: dummy_fail_async() + ) + == "value" + ) await asyncio.sleep(0.5) # Execute commands until database failover while not listener.is_changed_flag: - assert await retry.call_with_retry( - lambda: r_multi_db.get('key'), - lambda _: dummy_fail_async() - ) == 'value' + assert ( + await retry.call_with_retry( + lambda: r_multi_db.get("key"), lambda _: dummy_fail_async() + ) + == "value" + ) await asyncio.sleep(0.5) @pytest.mark.asyncio @pytest.mark.parametrize( "r_multi_db", [ - {"client_class": Redis, "min_num_failures": 2, "health_checks": - [ + { + "client_class": Redis, + "min_num_failures": 2, + "health_checks": [ LagAwareHealthCheck( verify_tls=False, - auth_basic=(os.getenv('ENV0_USERNAME'),os.getenv('ENV0_PASSWORD')), + auth_basic=( + os.getenv("ENV0_USERNAME"), + os.getenv("ENV0_PASSWORD"), + ), ) ], - "health_check_interval": 20, + "health_check_interval": 20, }, - {"client_class": RedisCluster, "min_num_failures": 2, "health_checks": - [ + { + "client_class": RedisCluster, + "min_num_failures": 2, + "health_checks": [ LagAwareHealthCheck( verify_tls=False, - auth_basic=(os.getenv('ENV0_USERNAME'), os.getenv('ENV0_PASSWORD')), + auth_basic=( + os.getenv("ENV0_USERNAME"), + os.getenv("ENV0_PASSWORD"), + ), ) ], - "health_check_interval": 20, + "health_check_interval": 20, }, ], ids=["standalone", "cluster"], - indirect=True + indirect=True, ) @pytest.mark.timeout(200) - async def test_multi_db_client_uses_lag_aware_health_check(self, r_multi_db, fault_injector_client): + async def test_multi_db_client_uses_lag_aware_health_check( + self, r_multi_db, fault_injector_client + ): client, listener, endpoint_config = r_multi_db retry = Retry( supported_errors=(TemporaryUnavailableException,), retries=DEFAULT_FAILOVER_ATTEMPTS, 
- backoff=ConstantBackoff(backoff=DEFAULT_FAILOVER_DELAY) + backoff=ConstantBackoff(backoff=DEFAULT_FAILOVER_DELAY), ) async with client as r_multi_db: event = asyncio.Event() - asyncio.create_task(trigger_network_failure_action(fault_injector_client, endpoint_config, event)) + asyncio.create_task( + trigger_network_failure_action( + fault_injector_client, endpoint_config, event + ) + ) await retry.call_with_retry( - lambda: r_multi_db.set('key', 'value'), - lambda _: dummy_fail_async() + lambda: r_multi_db.set("key", "value"), lambda _: dummy_fail_async() ) # Execute commands before network failure while not event.is_set(): - assert await retry.call_with_retry( - lambda: r_multi_db.get('key'), - lambda _: dummy_fail_async() - ) == 'value' + assert ( + await retry.call_with_retry( + lambda: r_multi_db.get("key"), lambda _: dummy_fail_async() + ) + == "value" + ) await asyncio.sleep(0.5) # Execute commands after network failure while not listener.is_changed_flag: - assert await retry.call_with_retry( - lambda: r_multi_db.get('key'), - lambda _: dummy_fail_async() - ) == 'value' + assert ( + await retry.call_with_retry( + lambda: r_multi_db.get("key"), lambda _: dummy_fail_async() + ) + == "value" + ) await asyncio.sleep(0.5) @pytest.mark.asyncio @@ -153,44 +190,55 @@ async def test_multi_db_client_uses_lag_aware_health_check(self, r_multi_db, fau {"client_class": RedisCluster, "min_num_failures": 2}, ], ids=["standalone", "cluster"], - indirect=True + indirect=True, ) @pytest.mark.timeout(200) - async def test_context_manager_pipeline_failover_to_another_db(self, r_multi_db, fault_injector_client): + async def test_context_manager_pipeline_failover_to_another_db( + self, r_multi_db, fault_injector_client + ): client, listener, endpoint_config = r_multi_db retry = Retry( supported_errors=(TemporaryUnavailableException,), retries=DEFAULT_FAILOVER_ATTEMPTS, - backoff=ConstantBackoff(backoff=DEFAULT_FAILOVER_DELAY) + backoff=ConstantBackoff(backoff=DEFAULT_FAILOVER_DELAY), ) async def callback(): async with r_multi_db.pipeline() as pipe: - pipe.set('{hash}key1', 'value1') - pipe.set('{hash}key2', 'value2') - pipe.set('{hash}key3', 'value3') - pipe.get('{hash}key1') - pipe.get('{hash}key2') - pipe.get('{hash}key3') - assert await pipe.execute() == [True, True, True, 'value1', 'value2', 'value3'] + pipe.set("{hash}key1", "value1") + pipe.set("{hash}key2", "value2") + pipe.set("{hash}key3", "value3") + pipe.get("{hash}key1") + pipe.get("{hash}key2") + pipe.get("{hash}key3") + assert await pipe.execute() == [ + True, + True, + True, + "value1", + "value2", + "value3", + ] async with client as r_multi_db: event = asyncio.Event() - asyncio.create_task(trigger_network_failure_action(fault_injector_client, endpoint_config, event)) + asyncio.create_task( + trigger_network_failure_action( + fault_injector_client, endpoint_config, event + ) + ) # Execute pipeline before network failure while not event.is_set(): await retry.call_with_retry( - lambda: callback(), - lambda _: dummy_fail_async() + lambda: callback(), lambda _: dummy_fail_async() ) await asyncio.sleep(0.5) # Execute commands until database failover while not listener.is_changed_flag: await retry.call_with_retry( - lambda: callback(), - lambda _: dummy_fail_async() + lambda: callback(), lambda _: dummy_fail_async() ) await asyncio.sleep(0.5) @@ -202,44 +250,55 @@ async def callback(): {"client_class": RedisCluster, "min_num_failures": 2}, ], ids=["standalone", "cluster"], - indirect=True + indirect=True, ) @pytest.mark.timeout(200) - async def 
test_chaining_pipeline_failover_to_another_db(self, r_multi_db, fault_injector_client): + async def test_chaining_pipeline_failover_to_another_db( + self, r_multi_db, fault_injector_client + ): client, listener, endpoint_config = r_multi_db retry = Retry( supported_errors=(TemporaryUnavailableException,), retries=DEFAULT_FAILOVER_ATTEMPTS, - backoff=ConstantBackoff(backoff=DEFAULT_FAILOVER_DELAY) + backoff=ConstantBackoff(backoff=DEFAULT_FAILOVER_DELAY), ) async def callback(): pipe = r_multi_db.pipeline() - pipe.set('{hash}key1', 'value1') - pipe.set('{hash}key2', 'value2') - pipe.set('{hash}key3', 'value3') - pipe.get('{hash}key1') - pipe.get('{hash}key2') - pipe.get('{hash}key3') - assert await pipe.execute() == [True, True, True, 'value1', 'value2', 'value3'] + pipe.set("{hash}key1", "value1") + pipe.set("{hash}key2", "value2") + pipe.set("{hash}key3", "value3") + pipe.get("{hash}key1") + pipe.get("{hash}key2") + pipe.get("{hash}key3") + assert await pipe.execute() == [ + True, + True, + True, + "value1", + "value2", + "value3", + ] async with client as r_multi_db: event = asyncio.Event() - asyncio.create_task(trigger_network_failure_action(fault_injector_client, endpoint_config, event)) + asyncio.create_task( + trigger_network_failure_action( + fault_injector_client, endpoint_config, event + ) + ) # Execute pipeline before network failure while not event.is_set(): await retry.call_with_retry( - lambda: callback(), - lambda _: dummy_fail_async() + lambda: callback(), lambda _: dummy_fail_async() ) await asyncio.sleep(0.5) # Execute pipeline until database failover while not listener.is_changed_flag: await retry.call_with_retry( - lambda: callback(), - lambda _: dummy_fail_async() + lambda: callback(), lambda _: dummy_fail_async() ) await asyncio.sleep(0.5) @@ -251,35 +310,41 @@ async def callback(): {"client_class": RedisCluster, "min_num_failures": 2}, ], ids=["standalone", "cluster"], - indirect=True + indirect=True, ) @pytest.mark.timeout(200) - async def test_transaction_failover_to_another_db(self, r_multi_db, fault_injector_client): + async def test_transaction_failover_to_another_db( + self, r_multi_db, fault_injector_client + ): client, listener, endpoint_config = r_multi_db retry = Retry( supported_errors=(TemporaryUnavailableException,), retries=DEFAULT_FAILOVER_ATTEMPTS, - backoff=ConstantBackoff(backoff=DEFAULT_FAILOVER_DELAY) + backoff=ConstantBackoff(backoff=DEFAULT_FAILOVER_DELAY), ) async def callback(pipe: Pipeline): - pipe.set('{hash}key1', 'value1') - pipe.set('{hash}key2', 'value2') - pipe.set('{hash}key3', 'value3') - pipe.get('{hash}key1') - pipe.get('{hash}key2') - pipe.get('{hash}key3') + pipe.set("{hash}key1", "value1") + pipe.set("{hash}key2", "value2") + pipe.set("{hash}key3", "value3") + pipe.get("{hash}key1") + pipe.get("{hash}key2") + pipe.get("{hash}key3") async with client as r_multi_db: event = asyncio.Event() - asyncio.create_task(trigger_network_failure_action(fault_injector_client, endpoint_config, event)) + asyncio.create_task( + trigger_network_failure_action( + fault_injector_client, endpoint_config, event + ) + ) # Execute transaction before network failure while not event.is_set(): await retry.call_with_retry( lambda: r_multi_db.transaction(callback), - lambda _: dummy_fail_async() + lambda _: dummy_fail_async(), ) await asyncio.sleep(0.5) @@ -287,26 +352,24 @@ async def callback(pipe: Pipeline): while not listener.is_changed_flag: assert await retry.call_with_retry( lambda: r_multi_db.transaction(callback), - lambda _: dummy_fail_async() - ) == 
[True, True, True, 'value1', 'value2', 'value3'] + lambda _: dummy_fail_async(), + ) == [True, True, True, "value1", "value2", "value3"] await asyncio.sleep(0.5) @pytest.mark.asyncio - @pytest.mark.parametrize( - "r_multi_db", - [{"min_num_failures": 2}], - indirect=True - ) + @pytest.mark.parametrize("r_multi_db", [{"min_num_failures": 2}], indirect=True) @pytest.mark.timeout(200) - async def test_pubsub_failover_to_another_db(self, r_multi_db, fault_injector_client): + async def test_pubsub_failover_to_another_db( + self, r_multi_db, fault_injector_client + ): client, listener, endpoint_config = r_multi_db retry = Retry( supported_errors=(TemporaryUnavailableException,), retries=DEFAULT_FAILOVER_ATTEMPTS, - backoff=ConstantBackoff(backoff=DEFAULT_FAILOVER_DELAY) + backoff=ConstantBackoff(backoff=DEFAULT_FAILOVER_DELAY), ) - data = json.dumps({'message': 'test'}) + data = json.dumps({"message": "test"}) messages_count = 0 async def handler(message): @@ -315,41 +378,45 @@ async def handler(message): async with client as r_multi_db: event = asyncio.Event() - asyncio.create_task(trigger_network_failure_action(fault_injector_client, endpoint_config, event)) + asyncio.create_task( + trigger_network_failure_action( + fault_injector_client, endpoint_config, event + ) + ) pubsub = await r_multi_db.pubsub() # Assign a handler and run in a separate thread. await retry.call_with_retry( - lambda: pubsub.subscribe(**{'test-channel': handler}), - lambda _: dummy_fail_async() + lambda: pubsub.subscribe(**{"test-channel": handler}), + lambda _: dummy_fail_async(), ) task = asyncio.create_task(pubsub.run(poll_timeout=0.1)) # Execute publish before network failure while not event.is_set(): await retry.call_with_retry( - lambda: r_multi_db.publish('test-channel', data), - lambda _: dummy_fail_async() + lambda: r_multi_db.publish("test-channel", data), + lambda _: dummy_fail_async(), ) await asyncio.sleep(0.5) # Execute publish until database failover while not listener.is_changed_flag: await retry.call_with_retry( - lambda: r_multi_db.publish('test-channel', data), - lambda _: dummy_fail_async() + lambda: r_multi_db.publish("test-channel", data), + lambda _: dummy_fail_async(), ) await asyncio.sleep(0.5) # After db changed still generates some traffic. for _ in range(5): await retry.call_with_retry( - lambda: r_multi_db.publish('test-channel', data), - lambda _: dummy_fail_async() + lambda: r_multi_db.publish("test-channel", data), + lambda _: dummy_fail_async(), ) # A timeout to ensure that an async handler will handle all previous messages. 
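# (A hypothetical alternative sketch, not what this test does: instead of relying only on a
# fixed sleep before cancelling, the cancelled reader task could be awaited to guarantee it
# has fully stopped before the final assertion:
#
#     task.cancel()
#     try:
#         await task
#     except asyncio.CancelledError:
#         pass
#
# The fixed sleep below keeps the test simpler.)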
await asyncio.sleep(0.1) task.cancel() - assert messages_count >= 2 \ No newline at end of file + assert messages_count >= 2 diff --git a/tests/test_background.py b/tests/test_background.py index ba62e5bdd9..bac9c1eef6 100644 --- a/tests/test_background.py +++ b/tests/test_background.py @@ -5,10 +5,11 @@ from redis.background import BackgroundScheduler + class TestBackgroundScheduler: def test_run_once(self): execute_counter = 0 - one = 'arg1' + one = "arg1" two = 9999 def callback(arg1: str, arg2: int): @@ -35,11 +36,11 @@ def callback(arg1: str, arg2: int): (0.012, 0.04, 3), (0.035, 0.04, 1), (0.045, 0.04, 0), - ] + ], ) def test_run_recurring(self, interval, timeout, call_count): execute_counter = 0 - one = 'arg1' + one = "arg1" two = 9999 def callback(arg1: str, arg2: int): @@ -67,11 +68,11 @@ def callback(arg1: str, arg2: int): (0.012, 0.04, 3), (0.035, 0.04, 1), (0.045, 0.04, 0), - ] + ], ) async def test_run_recurring_async(self, interval, timeout, call_count): execute_counter = 0 - one = 'arg1' + one = "arg1" two = 9999 async def callback(arg1: str, arg2: int): @@ -90,4 +91,4 @@ async def callback(arg1: str, arg2: int): await asyncio.sleep(timeout) - assert execute_counter == call_count \ No newline at end of file + assert execute_counter == call_count diff --git a/tests/test_data_structure.py b/tests/test_data_structure.py index 31ac5c4316..dd120d94d7 100644 --- a/tests/test_data_structure.py +++ b/tests/test_data_structure.py @@ -10,46 +10,61 @@ class TestWeightedList: def test_add_items(self): wlist = WeightedList() - wlist.add('item1', 3.0) - wlist.add('item2', 2.0) - wlist.add('item3', 4.0) - wlist.add('item4', 4.0) - - assert wlist.get_top_n(4) == [('item3', 4.0), ('item4', 4.0), ('item1', 3.0), ('item2', 2.0)] + wlist.add("item1", 3.0) + wlist.add("item2", 2.0) + wlist.add("item3", 4.0) + wlist.add("item4", 4.0) + + assert wlist.get_top_n(4) == [ + ("item3", 4.0), + ("item4", 4.0), + ("item1", 3.0), + ("item2", 2.0), + ] def test_remove_items(self): wlist = WeightedList() - wlist.add('item1', 3.0) - wlist.add('item2', 2.0) - wlist.add('item3', 4.0) - wlist.add('item4', 4.0) + wlist.add("item1", 3.0) + wlist.add("item2", 2.0) + wlist.add("item3", 4.0) + wlist.add("item4", 4.0) - assert wlist.remove('item2') == 2.0 - assert wlist.remove('item4') == 4.0 + assert wlist.remove("item2") == 2.0 + assert wlist.remove("item4") == 4.0 - assert wlist.get_top_n(4) == [('item3', 4.0), ('item1', 3.0)] + assert wlist.get_top_n(4) == [("item3", 4.0), ("item1", 3.0)] def test_get_by_weight_range(self): wlist = WeightedList() - wlist.add('item1', 3.0) - wlist.add('item2', 2.0) - wlist.add('item3', 4.0) - wlist.add('item4', 4.0) + wlist.add("item1", 3.0) + wlist.add("item2", 2.0) + wlist.add("item3", 4.0) + wlist.add("item4", 4.0) - assert wlist.get_by_weight_range(2.0, 3.0) == [('item1', 3.0), ('item2', 2.0)] + assert wlist.get_by_weight_range(2.0, 3.0) == [("item1", 3.0), ("item2", 2.0)] def test_update_weights(self): wlist = WeightedList() - wlist.add('item1', 3.0) - wlist.add('item2', 2.0) - wlist.add('item3', 4.0) - wlist.add('item4', 4.0) - - assert wlist.get_top_n(4) == [('item3', 4.0), ('item4', 4.0), ('item1', 3.0), ('item2', 2.0)] - - wlist.update_weight('item2', 5.0) - - assert wlist.get_top_n(4) == [('item2', 5.0), ('item3', 4.0), ('item4', 4.0), ('item1', 3.0)] + wlist.add("item1", 3.0) + wlist.add("item2", 2.0) + wlist.add("item3", 4.0) + wlist.add("item4", 4.0) + + assert wlist.get_top_n(4) == [ + ("item3", 4.0), + ("item4", 4.0), + ("item1", 3.0), + ("item2", 2.0), + ] + + 
wlist.update_weight("item2", 5.0) + + assert wlist.get_top_n(4) == [ + ("item2", 5.0), + ("item3", 4.0), + ("item4", 4.0), + ("item1", 3.0), + ] def test_thread_safety(self) -> None: """Test thread safety with concurrent operations""" @@ -76,4 +91,4 @@ def worker(worker_id): futures = [executor.submit(worker, i) for i in range(5)] concurrent.futures.wait(futures) - assert len(wl) == 500 \ No newline at end of file + assert len(wl) == 500 diff --git a/tests/test_event.py b/tests/test_event.py index 27526abeaf..f090251295 100644 --- a/tests/test_event.py +++ b/tests/test_event.py @@ -1,6 +1,10 @@ from unittest.mock import Mock, AsyncMock -from redis.event import EventListenerInterface, EventDispatcher, AsyncEventListenerInterface +from redis.event import ( + EventListenerInterface, + EventDispatcher, + AsyncEventListenerInterface, +) class TestEventDispatcher: @@ -16,7 +20,9 @@ def callback(event): mock_event_listener.listen = callback # Register via constructor - dispatcher = EventDispatcher(event_listeners={type(mock_event): [mock_event_listener]}) + dispatcher = EventDispatcher( + event_listeners={type(mock_event): [mock_event_listener]} + ) dispatcher.dispatch(mock_event) assert listener_called == 1 @@ -24,7 +30,9 @@ def callback(event): # Register additional listener for the same event mock_another_event_listener = Mock(spec=EventListenerInterface) mock_another_event_listener.listen = callback - dispatcher.register_listeners(event_listeners={type(mock_event): [mock_another_event_listener]}) + dispatcher.register_listeners( + event_listeners={type(mock_event): [mock_another_event_listener]} + ) dispatcher.dispatch(mock_event) assert listener_called == 3 @@ -41,7 +49,9 @@ async def callback(event): mock_event_listener.listen = callback # Register via constructor - dispatcher = EventDispatcher(event_listeners={type(mock_event): [mock_event_listener]}) + dispatcher = EventDispatcher( + event_listeners={type(mock_event): [mock_event_listener]} + ) await dispatcher.dispatch_async(mock_event) assert listener_called == 1 @@ -49,7 +59,9 @@ async def callback(event): # Register additional listener for the same event mock_another_event_listener = Mock(spec=AsyncEventListenerInterface) mock_another_event_listener.listen = callback - dispatcher.register_listeners(event_listeners={type(mock_event): [mock_another_event_listener]}) + dispatcher.register_listeners( + event_listeners={type(mock_event): [mock_another_event_listener]} + ) await dispatcher.dispatch_async(mock_event) - assert listener_called == 3 \ No newline at end of file + assert listener_called == 3 diff --git a/tests/test_http/test_http_client.py b/tests/test_http/test_http_client.py index 9a6d28ecd4..5dc1cf1631 100644 --- a/tests/test_http/test_http_client.py +++ b/tests/test_http/test_http_client.py @@ -13,7 +13,9 @@ class FakeResponse: - def __init__(self, *, status: int, headers: Dict[str, str], url: str, content: bytes): + def __init__( + self, *, status: int, headers: Dict[str, str], url: str, content: bytes + ): self.status = status self.headers = headers self._url = url @@ -32,8 +34,11 @@ def __enter__(self) -> "FakeResponse": def __exit__(self, exc_type, exc, tb) -> None: return None + class TestHttpClient: - def test_get_returns_parsed_json_and_uses_timeout(self, monkeypatch: pytest.MonkeyPatch) -> None: + def test_get_returns_parsed_json_and_uses_timeout( + self, monkeypatch: pytest.MonkeyPatch + ) -> None: # Arrange base_url = "https://api.example.com/" path = "v1/items" @@ -65,7 +70,9 @@ def fake_urlopen(request, *, 
timeout=None, context=None): client = HttpClient(base_url=base_url) # Act - result = client.get(path, params=params, timeout=12.34) # default expect_json=True + result = client.get( + path, params=params, timeout=12.34 + ) # default expect_json=True # Assert assert result == payload @@ -104,10 +111,13 @@ def fake_urlopen(request, *, timeout=None, context=None): # Assert assert result == payload - def test_get_retries_on_retryable_http_errors_and_succeeds(self, monkeypatch: pytest.MonkeyPatch) -> None: + def test_get_retries_on_retryable_http_errors_and_succeeds( + self, monkeypatch: pytest.MonkeyPatch + ) -> None: # Arrange: configure limited retries so we can assert attempts - retry_policy = Retry(backoff=ExponentialWithJitterBackoff(base=0, cap=0), - retries=2) # 2 retries -> up to 3 attempts + retry_policy = Retry( + backoff=ExponentialWithJitterBackoff(base=0, cap=0), retries=2 + ) # 2 retries -> up to 3 attempts base_url = "https://api.example.com/" path = "sometimes-busy" expected_url = f"{base_url}{path}" @@ -119,7 +129,13 @@ def test_get_retries_on_retryable_http_errors_and_succeeds(self, monkeypatch: py def make_http_error(url: str, code: int, body: bytes = b"busy"): # Provide a file-like object for .read() when HttpClient tries to read error content fp = BytesIO(body) - return HTTPError(url=url, code=code, msg="Service Unavailable", hdrs={"Content-Type": "text/plain"}, fp=fp) + return HTTPError( + url=url, + code=code, + msg="Service Unavailable", + hdrs={"Content-Type": "text/plain"}, + fp=fp, + ) def flaky_urlopen(request, *, timeout=None, context=None): call_count["n"] += 1 @@ -144,7 +160,9 @@ def flaky_urlopen(request, *, timeout=None, context=None): assert result == payload assert call_count["n"] == retry_policy.get_retries() + 1 - def test_post_sends_json_body_and_parses_response(self, monkeypatch: pytest.MonkeyPatch) -> None: + def test_post_sends_json_body_and_parses_response( + self, monkeypatch: pytest.MonkeyPatch + ) -> None: # Arrange base_url = "https://api.example.com/" path = "v1/create" @@ -158,9 +176,13 @@ def fake_urlopen(request, *, timeout=None, context=None): assert getattr(request, "method", "").upper() == "POST" assert request.full_url == expected_url # Content-Type should be auto-set for string JSON body - assert request.headers.get("Content-type") == "application/json; charset=utf-8" + assert ( + request.headers.get("Content-type") == "application/json; charset=utf-8" + ) # Body should be already UTF-8 encoded JSON with no spaces - assert request.data == json.dumps(send_payload, ensure_ascii=False, separators=(",", ":")).encode("utf-8") + assert request.data == json.dumps( + send_payload, ensure_ascii=False, separators=(",", ":") + ).encode("utf-8") return FakeResponse( status=200, headers={"Content-Type": "application/json; charset=utf-8"}, @@ -178,7 +200,9 @@ def fake_urlopen(request, *, timeout=None, context=None): # Assert assert result == recv_payload - def test_post_with_raw_data_and_custom_headers(self, monkeypatch: pytest.MonkeyPatch) -> None: + def test_post_with_raw_data_and_custom_headers( + self, monkeypatch: pytest.MonkeyPatch + ) -> None: # Arrange base_url = "https://api.example.com/" path = "upload" @@ -210,7 +234,9 @@ def fake_urlopen(request, *, timeout=None, context=None): # Assert assert result == recv_payload - def test_delete_returns_http_response_when_expect_json_false(self, monkeypatch: pytest.MonkeyPatch) -> None: + def test_delete_returns_http_response_when_expect_json_false( + self, monkeypatch: pytest.MonkeyPatch + ) -> None: # 
Arrange base_url = "https://api.example.com/" path = "v1/resource/42" @@ -238,7 +264,9 @@ def fake_urlopen(request, *, timeout=None, context=None): assert resp.url == expected_url assert resp.content == body - def test_put_raises_http_error_on_non_success(self, monkeypatch: pytest.MonkeyPatch) -> None: + def test_put_raises_http_error_on_non_success( + self, monkeypatch: pytest.MonkeyPatch + ) -> None: # Arrange base_url = "https://api.example.com/" path = "v1/update/1" @@ -246,7 +274,13 @@ def test_put_raises_http_error_on_non_success(self, monkeypatch: pytest.MonkeyPa def make_http_error(url: str, code: int, body: bytes = b"not found"): fp = BytesIO(body) - return HTTPError(url=url, code=code, msg="Not Found", hdrs={"Content-Type": "text/plain"}, fp=fp) + return HTTPError( + url=url, + code=code, + msg="Not Found", + hdrs={"Content-Type": "text/plain"}, + fp=fp, + ) def fake_urlopen(request, *, timeout=None, context=None): raise make_http_error(expected_url, 404) @@ -260,7 +294,9 @@ def fake_urlopen(request, *, timeout=None, context=None): assert exc.value.status == 404 assert exc.value.url == expected_url - def test_patch_with_params_encodes_query(self, monkeypatch: pytest.MonkeyPatch) -> None: + def test_patch_with_params_encodes_query( + self, monkeypatch: pytest.MonkeyPatch + ) -> None: # Arrange base_url = "https://api.example.com/" path = "v1/edit" @@ -288,11 +324,18 @@ def fake_urlopen(request, *, timeout=None, context=None): assert qs["q"] == ["hello world"] assert qs["tag"] == ["a", "b"] - def test_request_low_level_headers_auth_and_timeout_default(self, monkeypatch: pytest.MonkeyPatch) -> None: + def test_request_low_level_headers_auth_and_timeout_default( + self, monkeypatch: pytest.MonkeyPatch + ) -> None: # Arrange: use plain HTTP to verify no TLS context, and check default timeout used base_url = "http://example.com/" path = "ping" - captured = {"timeout": None, "context": "unset", "headers": None, "method": None} + captured = { + "timeout": None, + "context": "unset", + "headers": None, + "method": None, + } def fake_urlopen(request, *, timeout=None, context=None): captured["timeout"] = timeout @@ -315,10 +358,14 @@ def fake_urlopen(request, *, timeout=None, context=None): assert resp.status == 200 assert captured["method"] == "GET" assert captured["context"] is None # no TLS for http - assert pytest.approx(captured["timeout"], rel=1e-6) == client.timeout # default used + assert ( + pytest.approx(captured["timeout"], rel=1e-6) == client.timeout + ) # default used # Check some default headers and Authorization presence headers = {k.lower(): v for k, v in captured["headers"].items()} - assert "authorization" in headers and headers["authorization"].startswith("Basic ") + assert "authorization" in headers and headers["authorization"].startswith( + "Basic " + ) assert headers.get("accept") == "application/json" assert "gzip" in headers.get("accept-encoding", "").lower() - assert "user-agent" in headers \ No newline at end of file + assert "user-agent" in headers diff --git a/tests/test_multidb/conftest.py b/tests/test_multidb/conftest.py index ce4658868f..27fa5475cd 100644 --- a/tests/test_multidb/conftest.py +++ b/tests/test_multidb/conftest.py @@ -5,12 +5,20 @@ from redis import Redis, ConnectionPool from redis.data_structure import WeightedList from redis.multidb.circuit import State as CBState, CircuitBreaker -from redis.multidb.config import MultiDbConfig, DatabaseConfig, DEFAULT_HEALTH_CHECK_INTERVAL, \ - DEFAULT_AUTO_FALLBACK_INTERVAL +from redis.multidb.config import ( 
+ MultiDbConfig, + DatabaseConfig, + DEFAULT_HEALTH_CHECK_INTERVAL, + DEFAULT_AUTO_FALLBACK_INTERVAL, +) from redis.multidb.database import Database, Databases from redis.multidb.failover import FailoverStrategy from redis.multidb.failure_detector import FailureDetector -from redis.multidb.healthcheck import HealthCheck, DEFAULT_HEALTH_CHECK_PROBES, DEFAULT_HEALTH_CHECK_POLICY +from redis.multidb.healthcheck import ( + HealthCheck, + DEFAULT_HEALTH_CHECK_PROBES, + DEFAULT_HEALTH_CHECK_POLICY, +) from tests.conftest import mock_ed @@ -18,94 +26,107 @@ def mock_client() -> Redis: return Mock(spec=Redis) + @pytest.fixture() def mock_cb() -> CircuitBreaker: return Mock(spec=CircuitBreaker) + @pytest.fixture() def mock_fd() -> FailureDetector: - return Mock(spec=FailureDetector) + return Mock(spec=FailureDetector) + @pytest.fixture() def mock_fs() -> FailoverStrategy: - return Mock(spec=FailoverStrategy) + return Mock(spec=FailoverStrategy) + @pytest.fixture() def mock_hc() -> HealthCheck: - return Mock(spec=HealthCheck) + return Mock(spec=HealthCheck) + @pytest.fixture() def mock_db(request) -> Database: - db = Mock(spec=Database) - db.weight = request.param.get("weight", 1.0) - db.client = Mock(spec=Redis) - db.client.connection_pool = Mock(spec=ConnectionPool) + db = Mock(spec=Database) + db.weight = request.param.get("weight", 1.0) + db.client = Mock(spec=Redis) + db.client.connection_pool = Mock(spec=ConnectionPool) - cb = request.param.get("circuit", {}) - mock_cb = Mock(spec=CircuitBreaker) - mock_cb.grace_period = cb.get("grace_period", 1.0) - mock_cb.state = cb.get("state", CBState.CLOSED) + cb = request.param.get("circuit", {}) + mock_cb = Mock(spec=CircuitBreaker) + mock_cb.grace_period = cb.get("grace_period", 1.0) + mock_cb.state = cb.get("state", CBState.CLOSED) + + db.circuit = mock_cb + return db - db.circuit = mock_cb - return db @pytest.fixture() def mock_db1(request) -> Database: - db = Mock(spec=Database) - db.weight = request.param.get("weight", 1.0) - db.client = Mock(spec=Redis) - db.client.connection_pool = Mock(spec=ConnectionPool) + db = Mock(spec=Database) + db.weight = request.param.get("weight", 1.0) + db.client = Mock(spec=Redis) + db.client.connection_pool = Mock(spec=ConnectionPool) + + cb = request.param.get("circuit", {}) + mock_cb = Mock(spec=CircuitBreaker) + mock_cb.grace_period = cb.get("grace_period", 1.0) + mock_cb.state = cb.get("state", CBState.CLOSED) - cb = request.param.get("circuit", {}) - mock_cb = Mock(spec=CircuitBreaker) - mock_cb.grace_period = cb.get("grace_period", 1.0) - mock_cb.state = cb.get("state", CBState.CLOSED) + db.circuit = mock_cb + return db - db.circuit = mock_cb - return db @pytest.fixture() def mock_db2(request) -> Database: - db = Mock(spec=Database) - db.weight = request.param.get("weight", 1.0) - db.client = Mock(spec=Redis) - db.client.connection_pool = Mock(spec=ConnectionPool) + db = Mock(spec=Database) + db.weight = request.param.get("weight", 1.0) + db.client = Mock(spec=Redis) + db.client.connection_pool = Mock(spec=ConnectionPool) - cb = request.param.get("circuit", {}) - mock_cb = Mock(spec=CircuitBreaker) - mock_cb.grace_period = cb.get("grace_period", 1.0) - mock_cb.state = cb.get("state", CBState.CLOSED) + cb = request.param.get("circuit", {}) + mock_cb = Mock(spec=CircuitBreaker) + mock_cb.grace_period = cb.get("grace_period", 1.0) + mock_cb.state = cb.get("state", CBState.CLOSED) + + db.circuit = mock_cb + return db - db.circuit = mock_cb - return db @pytest.fixture() -def mock_multi_db_config( - request, 
mock_fd, mock_fs, mock_hc, mock_ed -) -> MultiDbConfig: - hc_interval = request.param.get('hc_interval', DEFAULT_HEALTH_CHECK_INTERVAL) - auto_fallback_interval = request.param.get('auto_fallback_interval', DEFAULT_AUTO_FALLBACK_INTERVAL) - health_check_policy = request.param.get('health_check_policy', DEFAULT_HEALTH_CHECK_POLICY) - health_check_probes = request.param.get('health_check_probes', DEFAULT_HEALTH_CHECK_PROBES) - - config = MultiDbConfig( - databases_config=[Mock(spec=DatabaseConfig)], - failure_detectors=[mock_fd], - health_check_interval=hc_interval, - health_check_delay=0.05, - health_check_policy=health_check_policy, - health_check_probes=health_check_probes, - failover_strategy=mock_fs, - auto_fallback_interval=auto_fallback_interval, - event_dispatcher=mock_ed - ) - - return config +def mock_multi_db_config(request, mock_fd, mock_fs, mock_hc, mock_ed) -> MultiDbConfig: + hc_interval = request.param.get("hc_interval", DEFAULT_HEALTH_CHECK_INTERVAL) + auto_fallback_interval = request.param.get( + "auto_fallback_interval", DEFAULT_AUTO_FALLBACK_INTERVAL + ) + health_check_policy = request.param.get( + "health_check_policy", DEFAULT_HEALTH_CHECK_POLICY + ) + health_check_probes = request.param.get( + "health_check_probes", DEFAULT_HEALTH_CHECK_PROBES + ) + + config = MultiDbConfig( + databases_config=[Mock(spec=DatabaseConfig)], + failure_detectors=[mock_fd], + health_check_interval=hc_interval, + health_check_delay=0.05, + health_check_policy=health_check_policy, + health_check_probes=health_check_probes, + failover_strategy=mock_fs, + auto_fallback_interval=auto_fallback_interval, + event_dispatcher=mock_ed, + ) + + return config + def create_weighted_list(*databases) -> Databases: - dbs = WeightedList() + dbs = WeightedList() - for db in databases: - dbs.add(db, db.weight) + for db in databases: + dbs.add(db, db.weight) - return dbs \ No newline at end of file + return dbs diff --git a/tests/test_multidb/test_circuit.py b/tests/test_multidb/test_circuit.py index 7dc642373b..9bf221ec52 100644 --- a/tests/test_multidb/test_circuit.py +++ b/tests/test_multidb/test_circuit.py @@ -1,14 +1,18 @@ import pybreaker import pytest -from redis.multidb.circuit import PBCircuitBreakerAdapter, State as CbState, CircuitBreaker +from redis.multidb.circuit import ( + PBCircuitBreakerAdapter, + State as CbState, + CircuitBreaker, +) class TestPBCircuitBreaker: @pytest.mark.parametrize( - 'mock_db', + "mock_db", [ - {'weight': 0.7, 'circuit': {'state': CbState.CLOSED}}, + {"weight": 0.7, "circuit": {"state": CbState.CLOSED}}, ], indirect=True, ) @@ -49,4 +53,4 @@ def callback(cb: CircuitBreaker, old_state: CbState, new_state: CbState): adapter.on_state_changed(callback) adapter.state = CbState.HALF_OPEN - assert called_count == 1 \ No newline at end of file + assert called_count == 1 diff --git a/tests/test_multidb/test_client.py b/tests/test_multidb/test_client.py index 5e710f23c2..dab80f2ba4 100644 --- a/tests/test_multidb/test_client.py +++ b/tests/test_multidb/test_client.py @@ -17,31 +17,35 @@ class TestMultiDbClient: @pytest.mark.parametrize( - 'mock_multi_db_config,mock_db, mock_db1, mock_db2', + "mock_multi_db_config,mock_db, mock_db1, mock_db2", [ ( - {}, - {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, - {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, - {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, + {}, + {"weight": 0.2, "circuit": {"state": CBState.CLOSED}}, + {"weight": 0.7, "circuit": {"state": CBState.CLOSED}}, + {"weight": 0.5, "circuit": {"state": 
CBState.CLOSED}}, ), ], indirect=True, ) def test_execute_command_against_correct_db_on_successful_initialization( - self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc + self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc ): databases = create_weighted_list(mock_db, mock_db1, mock_db2) - with patch.object(mock_multi_db_config,'databases',return_value=databases), \ - patch.object(mock_multi_db_config,'default_health_checks', return_value=[mock_hc]): - mock_db1.client.execute_command.return_value = 'OK1' + with ( + patch.object(mock_multi_db_config, "databases", return_value=databases), + patch.object( + mock_multi_db_config, "default_health_checks", return_value=[mock_hc] + ), + ): + mock_db1.client.execute_command.return_value = "OK1" mock_hc.check_health.return_value = True client = MultiDBClient(mock_multi_db_config) assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 - assert client.set('key', 'value') == 'OK1' + assert client.set("key", "value") == "OK1" assert mock_hc.check_health.call_count == 9 assert mock_db.circuit.state == CBState.CLOSED @@ -49,31 +53,43 @@ def test_execute_command_against_correct_db_on_successful_initialization( assert mock_db2.circuit.state == CBState.CLOSED @pytest.mark.parametrize( - 'mock_multi_db_config,mock_db, mock_db1, mock_db2', + "mock_multi_db_config,mock_db, mock_db1, mock_db2", [ ( - {}, - {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, - {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, - {'weight': 0.7, 'circuit': {'state': CBState.OPEN}}, + {}, + {"weight": 0.2, "circuit": {"state": CBState.CLOSED}}, + {"weight": 0.5, "circuit": {"state": CBState.CLOSED}}, + {"weight": 0.7, "circuit": {"state": CBState.OPEN}}, ), ], indirect=True, ) def test_execute_command_against_correct_db_and_closed_circuit( - self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc + self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc ): databases = create_weighted_list(mock_db, mock_db1, mock_db2) - with patch.object(mock_multi_db_config,'databases',return_value=databases), \ - patch.object(mock_multi_db_config,'default_health_checks', return_value=[mock_hc]): - mock_db1.client.execute_command.return_value = 'OK1' - - mock_hc.check_health.side_effect = [False, True, True, True, True, True, True] + with ( + patch.object(mock_multi_db_config, "databases", return_value=databases), + patch.object( + mock_multi_db_config, "default_health_checks", return_value=[mock_hc] + ), + ): + mock_db1.client.execute_command.return_value = "OK1" + + mock_hc.check_health.side_effect = [ + False, + True, + True, + True, + True, + True, + True, + ] client = MultiDBClient(mock_multi_db_config) assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 - assert client.set('key', 'value') == 'OK1' + assert client.set("key", "value") == "OK1" assert mock_hc.check_health.call_count == 7 assert mock_db.circuit.state == CBState.CLOSED @@ -81,19 +97,19 @@ def test_execute_command_against_correct_db_and_closed_circuit( assert mock_db2.circuit.state == CBState.OPEN @pytest.mark.parametrize( - 'mock_multi_db_config,mock_db, mock_db1, mock_db2', + "mock_multi_db_config,mock_db, mock_db1, mock_db2", [ ( - {"health_check_probes" : 1}, - {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, - {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, - {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, + {"health_check_probes": 1}, + {"weight": 0.2, "circuit": {"state": CBState.CLOSED}}, + {"weight": 0.7, 
"circuit": {"state": CBState.CLOSED}}, + {"weight": 0.5, "circuit": {"state": CBState.CLOSED}}, ), ], indirect=True, ) def test_execute_command_against_correct_db_on_background_health_check_determine_active_db_unhealthy( - self, mock_multi_db_config, mock_db, mock_db1, mock_db2 + self, mock_multi_db_config, mock_db, mock_db1, mock_db2 ): cb = PBCircuitBreakerAdapter(pybreaker.CircuitBreaker(reset_timeout=5)) cb.database = mock_db @@ -109,244 +125,321 @@ def test_execute_command_against_correct_db_on_background_health_check_determine databases = create_weighted_list(mock_db, mock_db1, mock_db2) - with patch.object(mock_multi_db_config,'databases',return_value=databases), \ - patch.object(mock_multi_db_config,'default_health_checks', return_value=[EchoHealthCheck()]): - - mock_db.client.execute_command.side_effect = ['healthcheck', 'healthcheck', 'healthcheck', 'OK', 'error'] - mock_db1.client.execute_command.side_effect = ['healthcheck', 'OK1', 'error', 'error', 'healthcheck', 'OK1'] - mock_db2.client.execute_command.side_effect = ['healthcheck', 'healthcheck', 'OK2', 'error', 'error'] + with ( + patch.object(mock_multi_db_config, "databases", return_value=databases), + patch.object( + mock_multi_db_config, + "default_health_checks", + return_value=[EchoHealthCheck()], + ), + ): + mock_db.client.execute_command.side_effect = [ + "healthcheck", + "healthcheck", + "healthcheck", + "OK", + "error", + ] + mock_db1.client.execute_command.side_effect = [ + "healthcheck", + "OK1", + "error", + "error", + "healthcheck", + "OK1", + ] + mock_db2.client.execute_command.side_effect = [ + "healthcheck", + "healthcheck", + "OK2", + "error", + "error", + ] mock_multi_db_config.health_check_interval = 0.2 mock_multi_db_config.failover_strategy = WeightBasedFailoverStrategy() client = MultiDBClient(mock_multi_db_config) - assert client.set('key', 'value') == 'OK1' + assert client.set("key", "value") == "OK1" sleep(0.3) - assert client.set('key', 'value') == 'OK2' + assert client.set("key", "value") == "OK2" sleep(0.2) - assert client.set('key', 'value') == 'OK' + assert client.set("key", "value") == "OK" sleep(0.2) - assert client.set('key', 'value') == 'OK1' + assert client.set("key", "value") == "OK1" @pytest.mark.parametrize( - 'mock_multi_db_config,mock_db, mock_db1, mock_db2', + "mock_multi_db_config,mock_db, mock_db1, mock_db2", [ ( - {"health_check_probes" : 1}, - {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, - {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, - {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, + {"health_check_probes": 1}, + {"weight": 0.2, "circuit": {"state": CBState.CLOSED}}, + {"weight": 0.7, "circuit": {"state": CBState.CLOSED}}, + {"weight": 0.5, "circuit": {"state": CBState.CLOSED}}, ), ], indirect=True, ) def test_execute_command_auto_fallback_to_highest_weight_db( - self, mock_multi_db_config, mock_db, mock_db1, mock_db2 + self, mock_multi_db_config, mock_db, mock_db1, mock_db2 ): databases = create_weighted_list(mock_db, mock_db1, mock_db2) - with patch.object(mock_multi_db_config,'databases',return_value=databases), \ - patch.object(mock_multi_db_config,'default_health_checks', return_value=[EchoHealthCheck()]): - mock_db.client.execute_command.side_effect = ['healthcheck', 'healthcheck', 'healthcheck', 'healthcheck', 'healthcheck'] - mock_db1.client.execute_command.side_effect = ['healthcheck', 'OK1', 'error', 'healthcheck', 'healthcheck', 'OK1'] - mock_db2.client.execute_command.side_effect = ['healthcheck', 'healthcheck', 'OK2', 'healthcheck', 
'healthcheck', 'healthcheck'] + with ( + patch.object(mock_multi_db_config, "databases", return_value=databases), + patch.object( + mock_multi_db_config, + "default_health_checks", + return_value=[EchoHealthCheck()], + ), + ): + mock_db.client.execute_command.side_effect = [ + "healthcheck", + "healthcheck", + "healthcheck", + "healthcheck", + "healthcheck", + ] + mock_db1.client.execute_command.side_effect = [ + "healthcheck", + "OK1", + "error", + "healthcheck", + "healthcheck", + "OK1", + ] + mock_db2.client.execute_command.side_effect = [ + "healthcheck", + "healthcheck", + "OK2", + "healthcheck", + "healthcheck", + "healthcheck", + ] mock_multi_db_config.health_check_interval = 0.2 mock_multi_db_config.auto_fallback_interval = 0.4 mock_multi_db_config.failover_strategy = WeightBasedFailoverStrategy() client = MultiDBClient(mock_multi_db_config) - assert client.set('key', 'value') == 'OK1' + assert client.set("key", "value") == "OK1" sleep(0.30) - assert client.set('key', 'value') == 'OK2' + assert client.set("key", "value") == "OK2" sleep(0.44) - assert client.set('key', 'value') == 'OK1' + assert client.set("key", "value") == "OK1" @pytest.mark.parametrize( - 'mock_multi_db_config,mock_db, mock_db1, mock_db2', + "mock_multi_db_config,mock_db, mock_db1, mock_db2", [ ( - {}, - {'weight': 0.2, 'circuit': {'state': CBState.OPEN}}, - {'weight': 0.7, 'circuit': {'state': CBState.OPEN}}, - {'weight': 0.5, 'circuit': {'state': CBState.OPEN}}, + {}, + {"weight": 0.2, "circuit": {"state": CBState.OPEN}}, + {"weight": 0.7, "circuit": {"state": CBState.OPEN}}, + {"weight": 0.5, "circuit": {"state": CBState.OPEN}}, ), ], indirect=True, ) def test_execute_command_throws_exception_on_failed_initialization( - self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc + self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc ): databases = create_weighted_list(mock_db, mock_db1, mock_db2) - with patch.object(mock_multi_db_config,'databases',return_value=databases), \ - patch.object(mock_multi_db_config,'default_health_checks', return_value=[mock_hc]): + with ( + patch.object(mock_multi_db_config, "databases", return_value=databases), + patch.object( + mock_multi_db_config, "default_health_checks", return_value=[mock_hc] + ), + ): mock_hc.check_health.return_value = False client = MultiDBClient(mock_multi_db_config) assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 - with pytest.raises(NoValidDatabaseException, match='Initial connection failed - no active database found'): - client.set('key', 'value') + with pytest.raises( + NoValidDatabaseException, + match="Initial connection failed - no active database found", + ): + client.set("key", "value") assert mock_hc.check_health.call_count == 3 @pytest.mark.parametrize( - 'mock_multi_db_config,mock_db, mock_db1, mock_db2', + "mock_multi_db_config,mock_db, mock_db1, mock_db2", [ ( - {}, - {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, - {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, - {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, + {}, + {"weight": 0.2, "circuit": {"state": CBState.CLOSED}}, + {"weight": 0.7, "circuit": {"state": CBState.CLOSED}}, + {"weight": 0.5, "circuit": {"state": CBState.CLOSED}}, ), ], indirect=True, ) def test_add_database_throws_exception_on_same_database( - self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc + self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc ): databases = create_weighted_list(mock_db, mock_db1, mock_db2) - with 
patch.object(mock_multi_db_config,'databases',return_value=databases), \ - patch.object(mock_multi_db_config,'default_health_checks', return_value=[mock_hc]): + with ( + patch.object(mock_multi_db_config, "databases", return_value=databases), + patch.object( + mock_multi_db_config, "default_health_checks", return_value=[mock_hc] + ), + ): mock_hc.check_health.return_value = False client = MultiDBClient(mock_multi_db_config) assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 - with pytest.raises(ValueError, match='Given database already exists'): + with pytest.raises(ValueError, match="Given database already exists"): client.add_database(mock_db) assert mock_hc.check_health.call_count == 3 @pytest.mark.parametrize( - 'mock_multi_db_config,mock_db, mock_db1, mock_db2', + "mock_multi_db_config,mock_db, mock_db1, mock_db2", [ ( - {}, - {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, - {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, - {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, + {}, + {"weight": 0.2, "circuit": {"state": CBState.CLOSED}}, + {"weight": 0.7, "circuit": {"state": CBState.CLOSED}}, + {"weight": 0.5, "circuit": {"state": CBState.CLOSED}}, ), ], indirect=True, ) def test_add_database_makes_new_database_active( - self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc + self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc ): databases = create_weighted_list(mock_db, mock_db2) - with patch.object(mock_multi_db_config,'databases',return_value=databases), \ - patch.object(mock_multi_db_config,'default_health_checks', return_value=[mock_hc]): - mock_db1.client.execute_command.return_value = 'OK1' - mock_db2.client.execute_command.return_value = 'OK2' + with ( + patch.object(mock_multi_db_config, "databases", return_value=databases), + patch.object( + mock_multi_db_config, "default_health_checks", return_value=[mock_hc] + ), + ): + mock_db1.client.execute_command.return_value = "OK1" + mock_db2.client.execute_command.return_value = "OK2" mock_hc.check_health.return_value = True client = MultiDBClient(mock_multi_db_config) assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 - assert client.set('key', 'value') == 'OK2' + assert client.set("key", "value") == "OK2" assert mock_hc.check_health.call_count == 6 client.add_database(mock_db1) assert mock_hc.check_health.call_count == 9 - assert client.set('key', 'value') == 'OK1' + assert client.set("key", "value") == "OK1" @pytest.mark.parametrize( - 'mock_multi_db_config,mock_db, mock_db1, mock_db2', + "mock_multi_db_config,mock_db, mock_db1, mock_db2", [ ( - {}, - {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, - {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, - {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, + {}, + {"weight": 0.2, "circuit": {"state": CBState.CLOSED}}, + {"weight": 0.7, "circuit": {"state": CBState.CLOSED}}, + {"weight": 0.5, "circuit": {"state": CBState.CLOSED}}, ), ], indirect=True, ) def test_remove_highest_weighted_database( - self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc + self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc ): databases = create_weighted_list(mock_db, mock_db1, mock_db2) - with patch.object(mock_multi_db_config,'databases',return_value=databases), \ - patch.object(mock_multi_db_config,'default_health_checks', return_value=[mock_hc]): - mock_db1.client.execute_command.return_value = 'OK1' - mock_db2.client.execute_command.return_value = 'OK2' + with ( + 
patch.object(mock_multi_db_config, "databases", return_value=databases), + patch.object( + mock_multi_db_config, "default_health_checks", return_value=[mock_hc] + ), + ): + mock_db1.client.execute_command.return_value = "OK1" + mock_db2.client.execute_command.return_value = "OK2" mock_hc.check_health.return_value = True client = MultiDBClient(mock_multi_db_config) assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 - assert client.set('key', 'value') == 'OK1' + assert client.set("key", "value") == "OK1" assert mock_hc.check_health.call_count == 9 client.remove_database(mock_db1) - assert client.set('key', 'value') == 'OK2' + assert client.set("key", "value") == "OK2" @pytest.mark.parametrize( - 'mock_multi_db_config,mock_db, mock_db1, mock_db2', + "mock_multi_db_config,mock_db, mock_db1, mock_db2", [ ( - {}, - {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, - {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, - {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, + {}, + {"weight": 0.2, "circuit": {"state": CBState.CLOSED}}, + {"weight": 0.7, "circuit": {"state": CBState.CLOSED}}, + {"weight": 0.5, "circuit": {"state": CBState.CLOSED}}, ), ], indirect=True, ) def test_update_database_weight_to_be_highest( - self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc + self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc ): databases = create_weighted_list(mock_db, mock_db1, mock_db2) - with patch.object(mock_multi_db_config,'databases',return_value=databases), \ - patch.object(mock_multi_db_config,'default_health_checks', return_value=[mock_hc]): - mock_db1.client.execute_command.return_value = 'OK1' - mock_db2.client.execute_command.return_value = 'OK2' + with ( + patch.object(mock_multi_db_config, "databases", return_value=databases), + patch.object( + mock_multi_db_config, "default_health_checks", return_value=[mock_hc] + ), + ): + mock_db1.client.execute_command.return_value = "OK1" + mock_db2.client.execute_command.return_value = "OK2" mock_hc.check_health.return_value = True client = MultiDBClient(mock_multi_db_config) assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 - assert client.set('key', 'value') == 'OK1' + assert client.set("key", "value") == "OK1" assert mock_hc.check_health.call_count == 9 client.update_database_weight(mock_db2, 0.8) assert mock_db2.weight == 0.8 - assert client.set('key', 'value') == 'OK2' + assert client.set("key", "value") == "OK2" @pytest.mark.parametrize( - 'mock_multi_db_config,mock_db, mock_db1, mock_db2', + "mock_multi_db_config,mock_db, mock_db1, mock_db2", [ ( - {}, - {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, - {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, - {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, + {}, + {"weight": 0.2, "circuit": {"state": CBState.CLOSED}}, + {"weight": 0.7, "circuit": {"state": CBState.CLOSED}}, + {"weight": 0.5, "circuit": {"state": CBState.CLOSED}}, ), ], indirect=True, ) def test_add_new_failure_detector( - self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc + self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc ): databases = create_weighted_list(mock_db, mock_db1, mock_db2) - with patch.object(mock_multi_db_config,'databases',return_value=databases), \ - patch.object(mock_multi_db_config,'default_health_checks', return_value=[mock_hc]): - mock_db1.client.execute_command.return_value = 'OK1' + with ( + patch.object(mock_multi_db_config, "databases", return_value=databases), + 
patch.object( + mock_multi_db_config, "default_health_checks", return_value=[mock_hc] + ), + ): + mock_db1.client.execute_command.return_value = "OK1" mock_multi_db_config.event_dispatcher = EventDispatcher() mock_fd = mock_multi_db_config.failure_detectors[0] # Event fired if command against mock_db1 would fail command_fail_event = OnCommandsFailEvent( - commands=('SET', 'key', 'value'), + commands=("SET", "key", "value"), exception=Exception(), ) @@ -354,7 +447,7 @@ def test_add_new_failure_detector( client = MultiDBClient(mock_multi_db_config) assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 - assert client.set('key', 'value') == 'OK1' + assert client.set("key", "value") == "OK1" assert mock_hc.check_health.call_count == 9 # Simulate failing command events that lead to a failure detection @@ -374,31 +467,35 @@ def test_add_new_failure_detector( assert another_fd.register_failure.call_count == 5 @pytest.mark.parametrize( - 'mock_multi_db_config,mock_db, mock_db1, mock_db2', + "mock_multi_db_config,mock_db, mock_db1, mock_db2", [ ( - {}, - {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, - {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, - {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, + {}, + {"weight": 0.2, "circuit": {"state": CBState.CLOSED}}, + {"weight": 0.7, "circuit": {"state": CBState.CLOSED}}, + {"weight": 0.5, "circuit": {"state": CBState.CLOSED}}, ), ], indirect=True, ) def test_add_new_health_check( - self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc + self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc ): databases = create_weighted_list(mock_db, mock_db1, mock_db2) - with patch.object(mock_multi_db_config,'databases',return_value=databases), \ - patch.object(mock_multi_db_config,'default_health_checks', return_value=[mock_hc]): - mock_db1.client.execute_command.return_value = 'OK1' + with ( + patch.object(mock_multi_db_config, "databases", return_value=databases), + patch.object( + mock_multi_db_config, "default_health_checks", return_value=[mock_hc] + ), + ): + mock_db1.client.execute_command.return_value = "OK1" mock_hc.check_health.return_value = True client = MultiDBClient(mock_multi_db_config) assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 - assert client.set('key', 'value') == 'OK1' + assert client.set("key", "value") == "OK1" assert mock_hc.check_health.call_count == 9 another_hc = Mock(spec=HealthCheck) @@ -411,41 +508,50 @@ def test_add_new_health_check( assert another_hc.check_health.call_count == 3 @pytest.mark.parametrize( - 'mock_multi_db_config,mock_db, mock_db1, mock_db2', + "mock_multi_db_config,mock_db, mock_db1, mock_db2", [ ( - {}, - {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, - {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, - {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, + {}, + {"weight": 0.2, "circuit": {"state": CBState.CLOSED}}, + {"weight": 0.7, "circuit": {"state": CBState.CLOSED}}, + {"weight": 0.5, "circuit": {"state": CBState.CLOSED}}, ), ], indirect=True, ) def test_set_active_database( - self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc + self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc ): databases = create_weighted_list(mock_db, mock_db1, mock_db2) - with patch.object(mock_multi_db_config,'databases',return_value=databases), \ - patch.object(mock_multi_db_config,'default_health_checks', return_value=[mock_hc]): - mock_db1.client.execute_command.return_value = 'OK1' - 
mock_db.client.execute_command.return_value = 'OK' + with ( + patch.object(mock_multi_db_config, "databases", return_value=databases), + patch.object( + mock_multi_db_config, "default_health_checks", return_value=[mock_hc] + ), + ): + mock_db1.client.execute_command.return_value = "OK1" + mock_db.client.execute_command.return_value = "OK" mock_hc.check_health.return_value = True client = MultiDBClient(mock_multi_db_config) assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 - assert client.set('key', 'value') == 'OK1' + assert client.set("key", "value") == "OK1" assert mock_hc.check_health.call_count == 9 client.set_active_database(mock_db) - assert client.set('key', 'value') == 'OK' + assert client.set("key", "value") == "OK" - with pytest.raises(ValueError, match='Given database is not a member of database list'): + with pytest.raises( + ValueError, match="Given database is not a member of database list" + ): client.set_active_database(Mock(spec=SyncDatabase)) mock_hc.check_health.return_value = False - with pytest.raises(NoValidDatabaseException, match='Cannot set active database, database is unhealthy'): - client.set_active_database(mock_db1) \ No newline at end of file + with pytest.raises( + NoValidDatabaseException, + match="Cannot set active database, database is unhealthy", + ): + client.set_active_database(mock_db1) diff --git a/tests/test_multidb/test_command_executor.py b/tests/test_multidb/test_command_executor.py index 2001d64f04..c27802cf09 100644 --- a/tests/test_multidb/test_command_executor.py +++ b/tests/test_multidb/test_command_executor.py @@ -15,19 +15,21 @@ class TestDefaultCommandExecutor: @pytest.mark.parametrize( - 'mock_db,mock_db1,mock_db2', + "mock_db,mock_db1,mock_db2", [ ( - {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, - {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, - {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, + {"weight": 0.2, "circuit": {"state": CBState.CLOSED}}, + {"weight": 0.7, "circuit": {"state": CBState.CLOSED}}, + {"weight": 0.5, "circuit": {"state": CBState.CLOSED}}, ), ], indirect=True, ) - def test_execute_command_on_active_database(self, mock_db, mock_db1, mock_db2, mock_fd, mock_fs, mock_ed): - mock_db1.client.execute_command.return_value = 'OK1' - mock_db2.client.execute_command.return_value = 'OK2' + def test_execute_command_on_active_database( + self, mock_db, mock_db1, mock_db2, mock_fd, mock_fs, mock_ed + ): + mock_db1.client.execute_command.return_value = "OK1" + mock_db2.client.execute_command.return_value = "OK2" databases = create_weighted_list(mock_db, mock_db1, mock_db2) executor = DefaultCommandExecutor( @@ -35,33 +37,33 @@ def test_execute_command_on_active_database(self, mock_db, mock_db1, mock_db2, m databases=databases, failover_strategy=mock_fs, event_dispatcher=mock_ed, - command_retry=Retry(NoBackoff(), 0) + command_retry=Retry(NoBackoff(), 0), ) executor.active_database = mock_db1 - assert executor.execute_command('SET', 'key', 'value') == 'OK1' + assert executor.execute_command("SET", "key", "value") == "OK1" executor.active_database = mock_db2 - assert executor.execute_command('SET', 'key', 'value') == 'OK2' + assert executor.execute_command("SET", "key", "value") == "OK2" assert mock_ed.register_listeners.call_count == 1 assert mock_fd.register_command_execution.call_count == 2 @pytest.mark.parametrize( - 'mock_db,mock_db1,mock_db2', + "mock_db,mock_db1,mock_db2", [ ( - {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, - {'weight': 0.7, 'circuit': {'state': 
CBState.CLOSED}}, - {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, + {"weight": 0.2, "circuit": {"state": CBState.CLOSED}}, + {"weight": 0.7, "circuit": {"state": CBState.CLOSED}}, + {"weight": 0.5, "circuit": {"state": CBState.CLOSED}}, ), ], indirect=True, ) def test_execute_command_automatically_select_active_database( - self, mock_db, mock_db1, mock_db2, mock_fd, mock_fs, mock_ed + self, mock_db, mock_db1, mock_db2, mock_fd, mock_fs, mock_ed ): - mock_db1.client.execute_command.return_value = 'OK1' - mock_db2.client.execute_command.return_value = 'OK2' + mock_db1.client.execute_command.return_value = "OK1" + mock_db2.client.execute_command.return_value = "OK2" mock_fs.database.side_effect = [mock_db1, mock_db2] databases = create_weighted_list(mock_db, mock_db1, mock_db2) @@ -70,33 +72,33 @@ def test_execute_command_automatically_select_active_database( databases=databases, failover_strategy=mock_fs, event_dispatcher=mock_ed, - command_retry=Retry(NoBackoff(), 0) + command_retry=Retry(NoBackoff(), 0), ) - assert executor.execute_command('SET', 'key', 'value') == 'OK1' + assert executor.execute_command("SET", "key", "value") == "OK1" mock_db1.circuit.state = CBState.OPEN - assert executor.execute_command('SET', 'key', 'value') == 'OK2' + assert executor.execute_command("SET", "key", "value") == "OK2" assert mock_ed.register_listeners.call_count == 1 assert mock_fs.database.call_count == 2 assert mock_fd.register_command_execution.call_count == 2 @pytest.mark.parametrize( - 'mock_db,mock_db1,mock_db2', + "mock_db,mock_db1,mock_db2", [ ( - {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, - {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, - {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, + {"weight": 0.2, "circuit": {"state": CBState.CLOSED}}, + {"weight": 0.7, "circuit": {"state": CBState.CLOSED}}, + {"weight": 0.5, "circuit": {"state": CBState.CLOSED}}, ), ], indirect=True, ) def test_execute_command_fallback_to_another_db_after_fallback_interval( - self, mock_db, mock_db1, mock_db2, mock_fd, mock_fs, mock_ed + self, mock_db, mock_db1, mock_db2, mock_fd, mock_fs, mock_ed ): - mock_db1.client.execute_command.return_value = 'OK1' - mock_db2.client.execute_command.return_value = 'OK2' + mock_db1.client.execute_command.return_value = "OK1" + mock_db2.client.execute_command.return_value = "OK2" mock_fs.database.side_effect = [mock_db1, mock_db2, mock_db1] databases = create_weighted_list(mock_db, mock_db1, mock_db2) @@ -106,38 +108,49 @@ def test_execute_command_fallback_to_another_db_after_fallback_interval( failover_strategy=mock_fs, event_dispatcher=mock_ed, auto_fallback_interval=0.1, - command_retry=Retry(NoBackoff(), 0) + command_retry=Retry(NoBackoff(), 0), ) - assert executor.execute_command('SET', 'key', 'value') == 'OK1' + assert executor.execute_command("SET", "key", "value") == "OK1" mock_db1.weight = 0.1 sleep(0.15) - assert executor.execute_command('SET', 'key', 'value') == 'OK2' + assert executor.execute_command("SET", "key", "value") == "OK2" mock_db1.weight = 0.7 sleep(0.15) - assert executor.execute_command('SET', 'key', 'value') == 'OK1' + assert executor.execute_command("SET", "key", "value") == "OK1" assert mock_ed.register_listeners.call_count == 1 assert mock_fs.database.call_count == 3 assert mock_fd.register_command_execution.call_count == 3 @pytest.mark.parametrize( - 'mock_db,mock_db1,mock_db2', + "mock_db,mock_db1,mock_db2", [ ( - {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, - {'weight': 0.7, 'circuit': {'state': 
CBState.CLOSED}}, - {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, + {"weight": 0.2, "circuit": {"state": CBState.CLOSED}}, + {"weight": 0.7, "circuit": {"state": CBState.CLOSED}}, + {"weight": 0.5, "circuit": {"state": CBState.CLOSED}}, ), ], indirect=True, ) def test_execute_command_fallback_to_another_db_after_failure_detection( - self, mock_db, mock_db1, mock_db2, mock_fs + self, mock_db, mock_db1, mock_db2, mock_fs ): - mock_db1.client.execute_command.side_effect = ['OK1', ConnectionError, ConnectionError, ConnectionError, 'OK1'] - mock_db2.client.execute_command.side_effect = ['OK2', ConnectionError, ConnectionError, ConnectionError] + mock_db1.client.execute_command.side_effect = [ + "OK1", + ConnectionError, + ConnectionError, + ConnectionError, + "OK1", + ] + mock_db2.client.execute_command.side_effect = [ + "OK2", + ConnectionError, + ConnectionError, + ConnectionError, + ] mock_fs.database.side_effect = [mock_db1, mock_db2, mock_db1] threshold = 3 fd = CommandFailureDetector(threshold, 0.0, 1) @@ -154,7 +167,7 @@ def test_execute_command_fallback_to_another_db_after_failure_detection( ) fd.set_command_executor(command_executor=executor) - assert executor.execute_command('SET', 'key', 'value') == 'OK1' - assert executor.execute_command('SET', 'key', 'value') == 'OK2' - assert executor.execute_command('SET', 'key', 'value') == 'OK1' - assert mock_fs.database.call_count == 3 \ No newline at end of file + assert executor.execute_command("SET", "key", "value") == "OK1" + assert executor.execute_command("SET", "key", "value") == "OK2" + assert executor.execute_command("SET", "key", "value") == "OK1" + assert mock_fs.database.call_count == 3 diff --git a/tests/test_multidb/test_config.py b/tests/test_multidb/test_config.py index abed8ec2fa..a63ac5b7c1 100644 --- a/tests/test_multidb/test_config.py +++ b/tests/test_multidb/test_config.py @@ -1,8 +1,16 @@ from unittest.mock import Mock from redis.connection import ConnectionPool -from redis.multidb.circuit import PBCircuitBreakerAdapter, CircuitBreaker, DEFAULT_GRACE_PERIOD -from redis.multidb.config import MultiDbConfig, DEFAULT_HEALTH_CHECK_INTERVAL, \ - DEFAULT_AUTO_FALLBACK_INTERVAL, DatabaseConfig +from redis.multidb.circuit import ( + PBCircuitBreakerAdapter, + CircuitBreaker, + DEFAULT_GRACE_PERIOD, +) +from redis.multidb.config import ( + MultiDbConfig, + DEFAULT_HEALTH_CHECK_INTERVAL, + DEFAULT_AUTO_FALLBACK_INTERVAL, + DatabaseConfig, +) from redis.multidb.database import Database from redis.multidb.failure_detector import CommandFailureDetector, FailureDetector from redis.multidb.healthcheck import EchoHealthCheck, HealthCheck @@ -13,14 +21,18 @@ class TestMultiDbConfig: def test_default_config(self): db_configs = [ - DatabaseConfig(client_kwargs={'host': 'host1', 'port': 'port1'}, weight=1.0), - DatabaseConfig(client_kwargs={'host': 'host2', 'port': 'port2'}, weight=0.9), - DatabaseConfig(client_kwargs={'host': 'host3', 'port': 'port3'}, weight=0.8), - ] - - config = MultiDbConfig( - databases_config=db_configs - ) + DatabaseConfig( + client_kwargs={"host": "host1", "port": "port1"}, weight=1.0 + ), + DatabaseConfig( + client_kwargs={"host": "host2", "port": "port2"}, weight=0.9 + ), + DatabaseConfig( + client_kwargs={"host": "host3", "port": "port3"}, weight=0.8 + ), + ] + + config = MultiDbConfig(databases_config=db_configs) assert config.databases_config == db_configs databases = config.databases() @@ -32,20 +44,26 @@ def test_default_config(self): assert weight == db_configs[i].weight assert 
db.circuit.grace_period == DEFAULT_GRACE_PERIOD assert db.client.get_retry() is not config.command_retry - i+=1 + i += 1 assert len(config.default_failure_detectors()) == 1 assert isinstance(config.default_failure_detectors()[0], CommandFailureDetector) assert len(config.default_health_checks()) == 1 assert isinstance(config.default_health_checks()[0], EchoHealthCheck) assert config.health_check_interval == DEFAULT_HEALTH_CHECK_INTERVAL - assert isinstance(config.default_failover_strategy(), WeightBasedFailoverStrategy) + assert isinstance( + config.default_failover_strategy(), WeightBasedFailoverStrategy + ) assert config.auto_fallback_interval == DEFAULT_AUTO_FALLBACK_INTERVAL assert isinstance(config.command_retry, Retry) def test_overridden_config(self): grace_period = 2 - mock_connection_pools = [Mock(spec=ConnectionPool), Mock(spec=ConnectionPool), Mock(spec=ConnectionPool)] + mock_connection_pools = [ + Mock(spec=ConnectionPool), + Mock(spec=ConnectionPool), + Mock(spec=ConnectionPool), + ] mock_connection_pools[0].connection_kwargs = {} mock_connection_pools[1].connection_kwargs = {} mock_connection_pools[2].connection_kwargs = {} @@ -55,22 +73,31 @@ def test_overridden_config(self): mock_cb2.grace_period = grace_period mock_cb3 = Mock(spec=CircuitBreaker) mock_cb3.grace_period = grace_period - mock_failure_detectors = [Mock(spec=FailureDetector), Mock(spec=FailureDetector)] + mock_failure_detectors = [ + Mock(spec=FailureDetector), + Mock(spec=FailureDetector), + ] mock_health_checks = [Mock(spec=HealthCheck), Mock(spec=HealthCheck)] health_check_interval = 10 mock_failover_strategy = Mock(spec=FailoverStrategy) auto_fallback_interval = 10 db_configs = [ - DatabaseConfig( - client_kwargs={"connection_pool": mock_connection_pools[0]}, weight=1.0, circuit=mock_cb1 - ), - DatabaseConfig( - client_kwargs={"connection_pool": mock_connection_pools[1]}, weight=0.9, circuit=mock_cb2 - ), - DatabaseConfig( - client_kwargs={"connection_pool": mock_connection_pools[2]}, weight=0.8, circuit=mock_cb3 - ), - ] + DatabaseConfig( + client_kwargs={"connection_pool": mock_connection_pools[0]}, + weight=1.0, + circuit=mock_cb1, + ), + DatabaseConfig( + client_kwargs={"connection_pool": mock_connection_pools[1]}, + weight=0.9, + circuit=mock_cb2, + ), + DatabaseConfig( + client_kwargs={"connection_pool": mock_connection_pools[2]}, + weight=0.8, + circuit=mock_cb3, + ), + ] config = MultiDbConfig( databases_config=db_configs, @@ -91,7 +118,7 @@ def test_overridden_config(self): assert weight == db_configs[i].weight assert db.client.connection_pool == mock_connection_pools[i] assert db.circuit.grace_period == grace_period - i+=1 + i += 1 assert len(config.failure_detectors) == 2 assert config.failure_detectors[0] == mock_failure_detectors[0] @@ -103,11 +130,14 @@ def test_overridden_config(self): assert config.failover_strategy == mock_failover_strategy assert config.auto_fallback_interval == auto_fallback_interval + class TestDatabaseConfig: def test_default_config(self): - config = DatabaseConfig(client_kwargs={'host': 'host1', 'port': 'port1'}, weight=1.0) + config = DatabaseConfig( + client_kwargs={"host": "host1", "port": "port1"}, weight=1.0 + ) - assert config.client_kwargs == {'host': 'host1', 'port': 'port1'} + assert config.client_kwargs == {"host": "host1", "port": "port1"} assert config.weight == 1.0 assert isinstance(config.default_circuit_breaker(), PBCircuitBreakerAdapter) @@ -116,9 +146,11 @@ def test_overridden_config(self): mock_circuit = Mock(spec=CircuitBreaker) config = 
DatabaseConfig( - client_kwargs={'connection_pool': mock_connection_pool}, weight=1.0, circuit=mock_circuit + client_kwargs={"connection_pool": mock_connection_pool}, + weight=1.0, + circuit=mock_circuit, ) - assert config.client_kwargs == {'connection_pool': mock_connection_pool} + assert config.client_kwargs == {"connection_pool": mock_connection_pool} assert config.weight == 1.0 - assert config.circuit == mock_circuit \ No newline at end of file + assert config.circuit == mock_circuit diff --git a/tests/test_multidb/test_failover.py b/tests/test_multidb/test_failover.py index 6ae6a9610c..1641c0ee63 100644 --- a/tests/test_multidb/test_failover.py +++ b/tests/test_multidb/test_failover.py @@ -4,26 +4,32 @@ from redis.data_structure import WeightedList from redis.multidb.circuit import State as CBState -from redis.multidb.exception import NoValidDatabaseException, TemporaryUnavailableException -from redis.multidb.failover import WeightBasedFailoverStrategy, DefaultFailoverStrategyExecutor +from redis.multidb.exception import ( + NoValidDatabaseException, + TemporaryUnavailableException, +) +from redis.multidb.failover import ( + WeightBasedFailoverStrategy, + DefaultFailoverStrategyExecutor, +) class TestWeightBasedFailoverStrategy: @pytest.mark.parametrize( - 'mock_db,mock_db1,mock_db2', + "mock_db,mock_db1,mock_db2", [ ( - {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, - {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, - {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, + {"weight": 0.2, "circuit": {"state": CBState.CLOSED}}, + {"weight": 0.7, "circuit": {"state": CBState.CLOSED}}, + {"weight": 0.5, "circuit": {"state": CBState.CLOSED}}, ), ( - {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, - {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, - {'weight': 0.7, 'circuit': {'state': CBState.OPEN}}, + {"weight": 0.2, "circuit": {"state": CBState.CLOSED}}, + {"weight": 0.5, "circuit": {"state": CBState.CLOSED}}, + {"weight": 0.7, "circuit": {"state": CBState.OPEN}}, ), ], - ids=['all closed - highest weight', 'highest weight - open'], + ids=["all closed - highest weight", "highest weight - open"], indirect=True, ) def test_get_valid_database(self, mock_db, mock_db1, mock_db2): @@ -38,12 +44,12 @@ def test_get_valid_database(self, mock_db, mock_db1, mock_db2): assert failover_strategy.database() == mock_db1 @pytest.mark.parametrize( - 'mock_db,mock_db1,mock_db2', + "mock_db,mock_db1,mock_db2", [ ( - {'weight': 0.2, 'circuit': {'state': CBState.OPEN}}, - {'weight': 0.7, 'circuit': {'state': CBState.OPEN}}, - {'weight': 0.5, 'circuit': {'state': CBState.OPEN}}, + {"weight": 0.2, "circuit": {"state": CBState.OPEN}}, + {"weight": 0.7, "circuit": {"state": CBState.OPEN}}, + {"weight": 0.5, "circuit": {"state": CBState.OPEN}}, ), ], indirect=True, @@ -51,29 +57,33 @@ def test_get_valid_database(self, mock_db, mock_db1, mock_db2): def test_throws_exception_on_empty_databases(self, mock_db, mock_db1, mock_db2): failover_strategy = WeightBasedFailoverStrategy() - with pytest.raises(NoValidDatabaseException, match='No valid database available for communication'): + with pytest.raises( + NoValidDatabaseException, + match="No valid database available for communication", + ): assert failover_strategy.database() + class TestDefaultStrategyExecutor: @pytest.mark.parametrize( - 'mock_db', + "mock_db", [ - {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, + {"weight": 0.2, "circuit": {"state": CBState.CLOSED}}, ], indirect=True, ) - def 
test_execute_returns_valid_database_with_failover_attempts(self, mock_db, mock_fs): + def test_execute_returns_valid_database_with_failover_attempts( + self, mock_db, mock_fs + ): failover_attempts = 3 mock_fs.database.side_effect = [ NoValidDatabaseException, NoValidDatabaseException, NoValidDatabaseException, - mock_db + mock_db, ] executor = DefaultFailoverStrategyExecutor( - mock_fs, - failover_attempts=failover_attempts, - failover_delay=0.1 + mock_fs, failover_attempts=failover_attempts, failover_delay=0.1 ) for i in range(failover_attempts + 1): @@ -96,12 +106,10 @@ def test_execute_throws_exception_on_attempts_exceed(self, mock_fs): NoValidDatabaseException, NoValidDatabaseException, NoValidDatabaseException, - NoValidDatabaseException + NoValidDatabaseException, ] executor = DefaultFailoverStrategyExecutor( - mock_fs, - failover_attempts=failover_attempts, - failover_delay=0.1 + mock_fs, failover_attempts=failover_attempts, failover_delay=0.1 ) with pytest.raises(NoValidDatabaseException): @@ -124,18 +132,19 @@ def test_execute_throws_exception_on_attempts_does_not_exceed_delay(self, mock_f NoValidDatabaseException, NoValidDatabaseException, NoValidDatabaseException, - NoValidDatabaseException + NoValidDatabaseException, ] executor = DefaultFailoverStrategyExecutor( - mock_fs, - failover_attempts=failover_attempts, - failover_delay=0.1 + mock_fs, failover_attempts=failover_attempts, failover_delay=0.1 ) - with pytest.raises(TemporaryUnavailableException, match=( - "No database connections currently available. " - "This is a temporary condition - please retry the operation." - )): + with pytest.raises( + TemporaryUnavailableException, + match=( + "No database connections currently available. " + "This is a temporary condition - please retry the operation." 
+ ), + ): for i in range(failover_attempts + 1): try: executor.execute() @@ -147,4 +156,4 @@ def test_execute_throws_exception_on_attempts_does_not_exceed_delay(self, mock_f if i == failover_attempts: raise e - assert mock_fs.database.call_count == 4 \ No newline at end of file + assert mock_fs.database.call_count == 4 diff --git a/tests/test_multidb/test_failure_detector.py b/tests/test_multidb/test_failure_detector.py index 3e71ab6aa5..f77a9c5d5d 100644 --- a/tests/test_multidb/test_failure_detector.py +++ b/tests/test_multidb/test_failure_detector.py @@ -12,7 +12,7 @@ class TestCommandFailureDetector: @pytest.mark.parametrize( - 'min_num_failures,failure_rate_threshold,circuit_state', + "min_num_failures,failure_rate_threshold,circuit_state", [ (2, 0.4, CBState.OPEN), (2, 0, CBState.OPEN), @@ -29,10 +29,7 @@ class TestCommandFailureDetector: ], ) def test_failure_detector_correctly_reacts_to_failures( - self, - min_num_failures, - failure_rate_threshold, - circuit_state + self, min_num_failures, failure_rate_threshold, circuit_state ): fd = CommandFailureDetector(min_num_failures, failure_rate_threshold) mock_db = Mock(spec=Database) @@ -41,19 +38,19 @@ def test_failure_detector_correctly_reacts_to_failures( mock_ce.active_database = mock_db fd.set_command_executor(mock_ce) - fd.register_command_execution(('GET', 'key')) - fd.register_command_execution(('GET','key')) - fd.register_failure(Exception(), ('GET', 'key')) + fd.register_command_execution(("GET", "key")) + fd.register_command_execution(("GET", "key")) + fd.register_failure(Exception(), ("GET", "key")) - fd.register_command_execution(('GET', 'key')) - fd.register_command_execution(('GET','key')) - fd.register_command_execution(('GET','key')) - fd.register_failure(Exception(), ('GET', 'key')) + fd.register_command_execution(("GET", "key")) + fd.register_command_execution(("GET", "key")) + fd.register_command_execution(("GET", "key")) + fd.register_failure(Exception(), ("GET", "key")) assert mock_db.circuit.state == circuit_state @pytest.mark.parametrize( - 'min_num_failures,failure_rate_threshold', + "min_num_failures,failure_rate_threshold", [ (3, 0.0), (3, 0.6), @@ -63,7 +60,9 @@ def test_failure_detector_correctly_reacts_to_failures( "do not exceeds min num failures AND failure rate, during interval", ], ) - def test_failure_detector_do_not_open_circuit_on_interval_exceed(self, min_num_failures, failure_rate_threshold): + def test_failure_detector_do_not_open_circuit_on_interval_exceed( + self, min_num_failures, failure_rate_threshold + ): fd = CommandFailureDetector(min_num_failures, failure_rate_threshold, 0.3) mock_db = Mock(spec=Database) mock_db.circuit.state = CBState.CLOSED @@ -72,24 +71,24 @@ def test_failure_detector_do_not_open_circuit_on_interval_exceed(self, min_num_f fd.set_command_executor(mock_ce) assert mock_db.circuit.state == CBState.CLOSED - fd.register_command_execution(('GET', 'key')) - fd.register_failure(Exception(), ('GET', 'key')) + fd.register_command_execution(("GET", "key")) + fd.register_failure(Exception(), ("GET", "key")) sleep(0.16) - fd.register_command_execution(('GET', 'key')) - fd.register_command_execution(('GET', 'key')) - fd.register_command_execution(('GET', 'key')) - fd.register_failure(Exception(), ('GET', 'key')) + fd.register_command_execution(("GET", "key")) + fd.register_command_execution(("GET", "key")) + fd.register_command_execution(("GET", "key")) + fd.register_failure(Exception(), ("GET", "key")) sleep(0.16) - fd.register_command_execution(('GET', 'key')) - 
fd.register_failure(Exception(), ('GET', 'key')) + fd.register_command_execution(("GET", "key")) + fd.register_failure(Exception(), ("GET", "key")) assert mock_db.circuit.state == CBState.CLOSED # 2 more failure as last one already refreshed timer - fd.register_command_execution(('GET', 'key')) - fd.register_failure(Exception(), ('GET', 'key')) - fd.register_command_execution(('GET', 'key')) - fd.register_failure(Exception(), ('GET', 'key')) + fd.register_command_execution(("GET", "key")) + fd.register_failure(Exception(), ("GET", "key")) + fd.register_command_execution(("GET", "key")) + fd.register_failure(Exception(), ("GET", "key")) assert mock_db.circuit.state == CBState.OPEN @@ -102,16 +101,16 @@ def test_failure_detector_open_circuit_on_specific_exception_threshold_exceed(se fd.set_command_executor(mock_ce) assert mock_db.circuit.state == CBState.CLOSED - fd.register_failure(Exception(), ('SET', 'key1', 'value1')) - fd.register_failure(ConnectionError(), ('SET', 'key1', 'value1')) - fd.register_failure(ConnectionError(), ('SET', 'key1', 'value1')) - fd.register_failure(Exception(), ('SET', 'key1', 'value1')) - fd.register_failure(Exception(), ('SET', 'key1', 'value1')) + fd.register_failure(Exception(), ("SET", "key1", "value1")) + fd.register_failure(ConnectionError(), ("SET", "key1", "value1")) + fd.register_failure(ConnectionError(), ("SET", "key1", "value1")) + fd.register_failure(Exception(), ("SET", "key1", "value1")) + fd.register_failure(Exception(), ("SET", "key1", "value1")) assert mock_db.circuit.state == CBState.CLOSED - fd.register_failure(ConnectionError(), ('SET', 'key1', 'value1')) - fd.register_failure(ConnectionError(), ('SET', 'key1', 'value1')) - fd.register_failure(ConnectionError(), ('SET', 'key1', 'value1')) + fd.register_failure(ConnectionError(), ("SET", "key1", "value1")) + fd.register_failure(ConnectionError(), ("SET", "key1", "value1")) + fd.register_failure(ConnectionError(), ("SET", "key1", "value1")) - assert mock_db.circuit.state == CBState.OPEN \ No newline at end of file + assert mock_db.circuit.state == CBState.OPEN diff --git a/tests/test_multidb/test_healthcheck.py b/tests/test_multidb/test_healthcheck.py index 43ad1ac888..684d5452c7 100644 --- a/tests/test_multidb/test_healthcheck.py +++ b/tests/test_multidb/test_healthcheck.py @@ -4,10 +4,18 @@ from redis.multidb.database import Database from redis.http.http_client import HttpError -from redis.multidb.healthcheck import EchoHealthCheck, LagAwareHealthCheck, HealthCheck, HealthyAllPolicy, \ - UnhealthyDatabaseException, HealthyMajorityPolicy, HealthyAnyPolicy +from redis.multidb.healthcheck import ( + EchoHealthCheck, + LagAwareHealthCheck, + HealthCheck, + HealthyAllPolicy, + UnhealthyDatabaseException, + HealthyMajorityPolicy, + HealthyAnyPolicy, +) from redis.multidb.circuit import State as CBState + class TestHealthyAllPolicy: def test_policy_returns_true_for_all_successful_probes(self): mock_hc1 = Mock(spec=HealthCheck) @@ -41,11 +49,12 @@ def test_policy_raise_unhealthy_database_exception(self): mock_db = Mock(spec=Database) policy = HealthyAllPolicy(3, 0.01) - with pytest.raises(UnhealthyDatabaseException, match='Unhealthy database'): + with pytest.raises(UnhealthyDatabaseException, match="Unhealthy database"): policy.execute([mock_hc1, mock_hc2], mock_db) assert mock_hc1.check_health.call_count == 3 assert mock_hc2.check_health.call_count == 0 + class TestHealthyMajorityPolicy: @pytest.mark.parametrize( 
"probes,hc1_side_effect,hc2_side_effect,hc1_call_count,hc2_call_count,expected_result", @@ -62,20 +71,26 @@ class TestHealthyMajorityPolicy: (4, [False, True, True, True], [True, True, False, True], 4, 4, True), ], ids=[ - 'HC1 - no majority - odd', 'HC2 - no majority - odd', 'HC1 - majority- odd', - 'HC2 - majority - odd', 'HC1 + HC2 - majority - odd', 'HC1 - no majority - even', - 'HC2 - no majority - even','HC1 - majority - even', 'HC2 - majority - even', - 'HC1 + HC2 - majority - even' - ] + "HC1 - no majority - odd", + "HC2 - no majority - odd", + "HC1 - majority- odd", + "HC2 - majority - odd", + "HC1 + HC2 - majority - odd", + "HC1 - no majority - even", + "HC2 - no majority - even", + "HC1 - majority - even", + "HC2 - majority - even", + "HC1 + HC2 - majority - even", + ], ) def test_policy_returns_true_for_majority_successful_probes( - self, - probes, - hc1_side_effect, - hc2_side_effect, - hc1_call_count, - hc2_call_count, - expected_result + self, + probes, + hc1_side_effect, + hc2_side_effect, + hc1_call_count, + hc2_call_count, + expected_result, ): mock_hc1 = Mock(spec=HealthCheck) mock_hc2 = Mock(spec=HealthCheck) @@ -93,21 +108,30 @@ def test_policy_returns_true_for_majority_successful_probes( [ (3, [True, ConnectionError, ConnectionError], [True, True, True], 3, 0), (3, [True, True, True], [True, ConnectionError, ConnectionError], 3, 3), - (4, [True, ConnectionError, ConnectionError, True], [True, True, True, True], 3, 0), - (4, [True, True, True, True], [True, ConnectionError, ConnectionError, False], 4, 3), + ( + 4, + [True, ConnectionError, ConnectionError, True], + [True, True, True, True], + 3, + 0, + ), + ( + 4, + [True, True, True, True], + [True, ConnectionError, ConnectionError, False], + 4, + 3, + ), ], ids=[ - 'HC1 - majority- odd', 'HC2 - majority - odd', - 'HC1 - majority - even', 'HC2 - majority - even', - ] + "HC1 - majority- odd", + "HC2 - majority - odd", + "HC1 - majority - even", + "HC2 - majority - even", + ], ) def test_policy_raise_unhealthy_database_exception_on_majority_probes_exceptions( - self, - probes, - hc1_side_effect, - hc2_side_effect, - hc1_call_count, - hc2_call_count + self, probes, hc1_side_effect, hc2_side_effect, hc1_call_count, hc2_call_count ): mock_hc1 = Mock(spec=HealthCheck) mock_hc2 = Mock(spec=HealthCheck) @@ -116,11 +140,12 @@ def test_policy_raise_unhealthy_database_exception_on_majority_probes_exceptions mock_db = Mock(spec=Database) policy = HealthyAllPolicy(3, 0.01) - with pytest.raises(UnhealthyDatabaseException, match='Unhealthy database'): + with pytest.raises(UnhealthyDatabaseException, match="Unhealthy database"): policy.execute([mock_hc1, mock_hc2], mock_db) assert mock_hc1.check_health.call_count == hc1_call_count assert mock_hc2.check_health.call_count == hc2_call_count + class TestHealthyAnyPolicy: @pytest.mark.parametrize( "hc1_side_effect,hc2_side_effect,hc1_call_count,hc2_call_count,expected_result", @@ -131,17 +156,19 @@ class TestHealthyAnyPolicy: ([True, True, True], [False, True, False], 1, 2, True), ], ids=[ - 'HC1 - no successful', 'HC2 - no successful', - 'HC1 - successful', 'HC2 - successful', - ] + "HC1 - no successful", + "HC2 - no successful", + "HC1 - successful", + "HC2 - successful", + ], ) def test_policy_returns_true_for_any_successful_probe( - self, - hc1_side_effect, - hc2_side_effect, - hc1_call_count, - hc2_call_count, - expected_result + self, + hc1_side_effect, + hc2_side_effect, + hc1_call_count, + hc2_call_count, + expected_result, ): mock_hc1 = Mock(spec=HealthCheck) mock_hc2 = 
Mock(spec=HealthCheck) @@ -154,7 +181,9 @@ def test_policy_returns_true_for_any_successful_probe( assert mock_hc1.check_health.call_count == hc1_call_count assert mock_hc2.check_health.call_count == hc2_call_count - def test_policy_raise_unhealthy_database_exception_if_exception_occurs_on_failed_health_check(self): + def test_policy_raise_unhealthy_database_exception_if_exception_occurs_on_failed_health_check( + self, + ): mock_hc1 = Mock(spec=HealthCheck) mock_hc2 = Mock(spec=HealthCheck) mock_hc1.check_health.side_effect = [False, False, ConnectionError] @@ -162,38 +191,43 @@ def test_policy_raise_unhealthy_database_exception_if_exception_occurs_on_failed mock_db = Mock(spec=Database) policy = HealthyAnyPolicy(3, 0.01) - with pytest.raises(UnhealthyDatabaseException, match='Unhealthy database'): + with pytest.raises(UnhealthyDatabaseException, match="Unhealthy database"): policy.execute([mock_hc1, mock_hc2], mock_db) assert mock_hc1.check_health.call_count == 3 assert mock_hc2.check_health.call_count == 0 + class TestEchoHealthCheck: def test_database_is_healthy_on_echo_response(self, mock_client, mock_cb): """ Mocking responses to mix error and actual responses to ensure that health check retry according to given configuration. """ - mock_client.execute_command.return_value = 'healthcheck' + mock_client.execute_command.return_value = "healthcheck" hc = EchoHealthCheck() db = Database(mock_client, mock_cb, 0.9) assert hc.check_health(db) == True assert mock_client.execute_command.call_count == 1 - def test_database_is_unhealthy_on_incorrect_echo_response(self, mock_client, mock_cb): + def test_database_is_unhealthy_on_incorrect_echo_response( + self, mock_client, mock_cb + ): """ Mocking responses to mix error and actual responses to ensure that health check retry according to given configuration. """ - mock_client.execute_command.return_value = 'wrong' + mock_client.execute_command.return_value = "wrong" hc = EchoHealthCheck() db = Database(mock_client, mock_cb, 0.9) assert hc.check_health(db) == False assert mock_client.execute_command.call_count == 1 - def test_database_close_circuit_on_successful_healthcheck(self, mock_client, mock_cb): - mock_client.execute_command.return_value = 'healthcheck' + def test_database_close_circuit_on_successful_healthcheck( + self, mock_client, mock_cb + ): + mock_client.execute_command.return_value = "healthcheck" mock_cb.state = CBState.HALF_OPEN hc = EchoHealthCheck() db = Database(mock_client, mock_cb, 0.9) @@ -203,7 +237,9 @@ def test_database_close_circuit_on_successful_healthcheck(self, mock_client, moc class TestLagAwareHealthCheck: - def test_database_is_healthy_when_bdb_matches_by_dns_name(self, mock_client, mock_cb): + def test_database_is_healthy_when_bdb_matches_by_dns_name( + self, mock_client, mock_cb + ): """ Ensures health check succeeds when /v1/bdbs contains an endpoint whose dns_name matches database host, and availability endpoint returns success. 
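[Editor's note, not part of the patch: the hunks around this point exercise LagAwareHealthCheck against a mocked HTTP client. As a reading aid, a minimal usage sketch follows. The constructor arguments and the availability URL shape are taken from the assertions in this test file; the database object passed to check_health is assumed to be the redis.multidb.database.Database used throughout this suite.]

    # Sketch inferred from the tests in this patch; illustrative, not authoritative.
    from redis.multidb.healthcheck import LagAwareHealthCheck

    # LagAwareHealthCheck queries the Redis Enterprise REST API: it lists
    # /v1/bdbs, selects the bdb whose endpoint dns_name (or addr) matches the
    # database host, then requests
    # /v1/bdbs/<uid>/availability?extend_check=lag&availability_lag_tolerance_ms=<N>.
    hc = LagAwareHealthCheck(rest_api_port=9443, lag_aware_tolerance=100)
    # hc.check_health(database) returns True on success and lets HttpError
    # (for example a 503 from the availability endpoint) propagate to the caller.
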
@@ -227,9 +263,7 @@ def test_database_is_healthy_when_bdb_matches_by_dns_name(self, mock_client, moc None, ] - hc = LagAwareHealthCheck( - rest_api_port=1234, lag_aware_tolerance=150 - ) + hc = LagAwareHealthCheck(rest_api_port=1234, lag_aware_tolerance=150) # Inject our mocked http client hc._http_client = mock_http @@ -243,7 +277,10 @@ def test_database_is_healthy_when_bdb_matches_by_dns_name(self, mock_client, moc first_call = mock_http.get.call_args_list[0] second_call = mock_http.get.call_args_list[1] assert first_call.args[0] == "/v1/bdbs" - assert second_call.args[0] == "/v1/bdbs/bdb-1/availability?extend_check=lag&availability_lag_tolerance_ms=150" + assert ( + second_call.args[0] + == "/v1/bdbs/bdb-1/availability?extend_check=lag&availability_lag_tolerance_ms=150" + ) assert second_call.kwargs.get("expect_json") is False def test_database_is_healthy_when_bdb_matches_by_addr(self, mock_client, mock_cb): @@ -273,7 +310,10 @@ def test_database_is_healthy_when_bdb_matches_by_addr(self, mock_client, mock_cb assert hc.check_health(db) is True assert mock_http.get.call_count == 2 - assert mock_http.get.call_args_list[1].args[0] == "/v1/bdbs/bdb-42/availability?extend_check=lag&availability_lag_tolerance_ms=5000" + assert ( + mock_http.get.call_args_list[1].args[0] + == "/v1/bdbs/bdb-42/availability?extend_check=lag&availability_lag_tolerance_ms=5000" + ) def test_raises_value_error_when_no_matching_bdb(self, mock_client, mock_cb): """ @@ -285,8 +325,16 @@ def test_raises_value_error_when_no_matching_bdb(self, mock_client, mock_cb): mock_http = MagicMock() # Return bdbs that do not match host by dns_name nor addr mock_http.get.return_value = [ - {"uid": "a", "endpoints": [{"dns_name": "other.example.com", "addr": ["10.0.0.9"]}]}, - {"uid": "b", "endpoints": [{"dns_name": "another.example.com", "addr": ["10.0.0.10"]}]}, + { + "uid": "a", + "endpoints": [{"dns_name": "other.example.com", "addr": ["10.0.0.9"]}], + }, + { + "uid": "b", + "endpoints": [ + {"dns_name": "another.example.com", "addr": ["10.0.0.10"]} + ], + }, ] hc = LagAwareHealthCheck() @@ -312,7 +360,11 @@ def test_propagates_http_error_from_availability(self, mock_client, mock_cb): mock_http.get.side_effect = [ [{"uid": "bdb-err", "endpoints": [{"dns_name": host, "addr": []}]}], # Second: availability -> raise HttpError - HttpError(url=f"https://{host}:9443/v1/bdbs/bdb-err/availability", status=503, message="busy"), + HttpError( + url=f"https://{host}:9443/v1/bdbs/bdb-err/availability", + status=503, + message="busy", + ), ] hc = LagAwareHealthCheck() @@ -325,4 +377,4 @@ def test_propagates_http_error_from_availability(self, mock_client, mock_cb): assert e.status == 503 # Ensure both calls were attempted - assert mock_http.get.call_count == 2 \ No newline at end of file + assert mock_http.get.call_count == 2 diff --git a/tests/test_multidb/test_pipeline.py b/tests/test_multidb/test_pipeline.py index 54f6a4df17..4afbb2db35 100644 --- a/tests/test_multidb/test_pipeline.py +++ b/tests/test_multidb/test_pipeline.py @@ -7,38 +7,48 @@ from redis.client import Pipeline from redis.multidb.circuit import State as CBState, PBCircuitBreakerAdapter from redis.multidb.client import MultiDBClient -from redis.multidb.failover import WeightBasedFailoverStrategy, DEFAULT_FAILOVER_ATTEMPTS, DEFAULT_FAILOVER_DELAY +from redis.multidb.failover import ( + WeightBasedFailoverStrategy, + DEFAULT_FAILOVER_ATTEMPTS, + DEFAULT_FAILOVER_DELAY, +) from redis.multidb.healthcheck import EchoHealthCheck from tests.test_multidb.conftest import 
create_weighted_list + def mock_pipe() -> Pipeline: mock_pipe = Mock(spec=Pipeline) mock_pipe.__enter__ = Mock(return_value=mock_pipe) mock_pipe.__exit__ = Mock(return_value=None) return mock_pipe + class TestPipeline: @pytest.mark.parametrize( - 'mock_multi_db_config,mock_db, mock_db1, mock_db2', + "mock_multi_db_config,mock_db, mock_db1, mock_db2", [ ( - {}, - {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, - {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, - {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, + {}, + {"weight": 0.2, "circuit": {"state": CBState.CLOSED}}, + {"weight": 0.7, "circuit": {"state": CBState.CLOSED}}, + {"weight": 0.5, "circuit": {"state": CBState.CLOSED}}, ), ], indirect=True, ) def test_executes_pipeline_against_correct_db( - self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc + self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc ): databases = create_weighted_list(mock_db, mock_db1, mock_db2) - with patch.object(mock_multi_db_config,'databases',return_value=databases), \ - patch.object(mock_multi_db_config,'default_health_checks', return_value=[mock_hc]): + with ( + patch.object(mock_multi_db_config, "databases", return_value=databases), + patch.object( + mock_multi_db_config, "default_health_checks", return_value=[mock_hc] + ), + ): pipe = mock_pipe() - pipe.execute.return_value = ['OK1', 'value1'] + pipe.execute.return_value = ["OK1", "value1"] mock_db1.client.pipeline.return_value = pipe mock_hc.check_health.return_value = True @@ -47,45 +57,57 @@ def test_executes_pipeline_against_correct_db( assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 pipe = client.pipeline() - pipe.set('key1', 'value1') - pipe.get('key1') + pipe.set("key1", "value1") + pipe.get("key1") - assert pipe.execute() == ['OK1', 'value1'] + assert pipe.execute() == ["OK1", "value1"] assert mock_hc.check_health.call_count == 9 @pytest.mark.parametrize( - 'mock_multi_db_config,mock_db, mock_db1, mock_db2', + "mock_multi_db_config,mock_db, mock_db1, mock_db2", [ ( - {}, - {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, - {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, - {'weight': 0.7, 'circuit': {'state': CBState.OPEN}}, + {}, + {"weight": 0.2, "circuit": {"state": CBState.CLOSED}}, + {"weight": 0.5, "circuit": {"state": CBState.CLOSED}}, + {"weight": 0.7, "circuit": {"state": CBState.OPEN}}, ), ], indirect=True, ) def test_execute_pipeline_against_correct_db_and_closed_circuit( - self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc + self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc ): databases = create_weighted_list(mock_db, mock_db1, mock_db2) - with patch.object(mock_multi_db_config,'databases',return_value=databases), \ - patch.object(mock_multi_db_config,'default_health_checks', return_value=[mock_hc]): + with ( + patch.object(mock_multi_db_config, "databases", return_value=databases), + patch.object( + mock_multi_db_config, "default_health_checks", return_value=[mock_hc] + ), + ): pipe = mock_pipe() - pipe.execute.return_value = ['OK1', 'value1'] + pipe.execute.return_value = ["OK1", "value1"] mock_db1.client.pipeline.return_value = pipe - mock_hc.check_health.side_effect = [False, True, True, True, True, True, True] + mock_hc.check_health.side_effect = [ + False, + True, + True, + True, + True, + True, + True, + ] client = MultiDBClient(mock_multi_db_config) assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 with client.pipeline() as pipe: - 
pipe.set('key1', 'value1') - pipe.get('key1') + pipe.set("key1", "value1") + pipe.get("key1") - assert pipe.execute() == ['OK1', 'value1'] + assert pipe.execute() == ["OK1", "value1"] assert mock_hc.check_health.call_count == 7 assert mock_db.circuit.state == CBState.CLOSED @@ -93,19 +115,19 @@ def test_execute_pipeline_against_correct_db_and_closed_circuit( assert mock_db2.circuit.state == CBState.OPEN @pytest.mark.parametrize( - 'mock_multi_db_config,mock_db, mock_db1, mock_db2', + "mock_multi_db_config,mock_db, mock_db1, mock_db2", [ ( - {"health_check_probes" : 1}, - {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, - {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, - {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, + {"health_check_probes": 1}, + {"weight": 0.2, "circuit": {"state": CBState.CLOSED}}, + {"weight": 0.7, "circuit": {"state": CBState.CLOSED}}, + {"weight": 0.5, "circuit": {"state": CBState.CLOSED}}, ), ], indirect=True, ) def test_execute_pipeline_against_correct_db_on_background_health_check_determine_active_db_unhealthy( - self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc + self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc ): cb = PBCircuitBreakerAdapter(pybreaker.CircuitBreaker(reset_timeout=5)) cb.database = mock_db @@ -121,23 +143,43 @@ def test_execute_pipeline_against_correct_db_on_background_health_check_determin databases = create_weighted_list(mock_db, mock_db1, mock_db2) - with patch.object(mock_multi_db_config,'databases',return_value=databases), \ - patch.object(mock_multi_db_config,'default_health_checks', return_value=[EchoHealthCheck()]): - - mock_db.client.execute_command.side_effect = ['healthcheck', 'healthcheck', 'healthcheck', 'error'] - mock_db1.client.execute_command.side_effect = ['healthcheck', 'error', 'error', 'healthcheck'] - mock_db2.client.execute_command.side_effect = ['healthcheck', 'healthcheck', 'error', 'error'] + with ( + patch.object(mock_multi_db_config, "databases", return_value=databases), + patch.object( + mock_multi_db_config, + "default_health_checks", + return_value=[EchoHealthCheck()], + ), + ): + mock_db.client.execute_command.side_effect = [ + "healthcheck", + "healthcheck", + "healthcheck", + "error", + ] + mock_db1.client.execute_command.side_effect = [ + "healthcheck", + "error", + "error", + "healthcheck", + ] + mock_db2.client.execute_command.side_effect = [ + "healthcheck", + "healthcheck", + "error", + "error", + ] pipe = mock_pipe() - pipe.execute.return_value = ['OK', 'value'] + pipe.execute.return_value = ["OK", "value"] mock_db.client.pipeline.return_value = pipe pipe1 = mock_pipe() - pipe1.execute.return_value = ['OK1', 'value'] + pipe1.execute.return_value = ["OK1", "value"] mock_db1.client.pipeline.return_value = pipe1 pipe2 = mock_pipe() - pipe2.execute.return_value = ['OK2', 'value'] + pipe2.execute.return_value = ["OK2", "value"] mock_db2.client.pipeline.return_value = pipe2 mock_multi_db_config.health_check_interval = 0.1 @@ -146,57 +188,61 @@ def test_execute_pipeline_against_correct_db_on_background_health_check_determin client = MultiDBClient(mock_multi_db_config) with client.pipeline() as pipe: - pipe.set('key1', 'value') - pipe.get('key1') + pipe.set("key1", "value") + pipe.get("key1") - assert pipe.execute() == ['OK1', 'value'] + assert pipe.execute() == ["OK1", "value"] sleep(0.15) with client.pipeline() as pipe: - pipe.set('key1', 'value') - pipe.get('key1') + pipe.set("key1", "value") + pipe.get("key1") - assert pipe.execute() == ['OK2', 'value'] + 
assert pipe.execute() == ["OK2", "value"] sleep(0.1) with client.pipeline() as pipe: - pipe.set('key1', 'value') - pipe.get('key1') + pipe.set("key1", "value") + pipe.get("key1") - assert pipe.execute() == ['OK', 'value'] + assert pipe.execute() == ["OK", "value"] sleep(0.1) with client.pipeline() as pipe: - pipe.set('key1', 'value') - pipe.get('key1') + pipe.set("key1", "value") + pipe.get("key1") - assert pipe.execute() == ['OK1', 'value'] + assert pipe.execute() == ["OK1", "value"] -class TestTransaction: +class TestTransaction: @pytest.mark.parametrize( - 'mock_multi_db_config,mock_db, mock_db1, mock_db2', + "mock_multi_db_config,mock_db, mock_db1, mock_db2", [ ( - {}, - {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, - {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, - {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, + {}, + {"weight": 0.2, "circuit": {"state": CBState.CLOSED}}, + {"weight": 0.7, "circuit": {"state": CBState.CLOSED}}, + {"weight": 0.5, "circuit": {"state": CBState.CLOSED}}, ), ], indirect=True, ) def test_executes_transaction_against_correct_db( - self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc + self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc ): databases = create_weighted_list(mock_db, mock_db1, mock_db2) - with patch.object(mock_multi_db_config,'databases',return_value=databases), \ - patch.object(mock_multi_db_config,'default_health_checks', return_value=[mock_hc]): - mock_db1.client.transaction.return_value = ['OK1', 'value1'] + with ( + patch.object(mock_multi_db_config, "databases", return_value=databases), + patch.object( + mock_multi_db_config, "default_health_checks", return_value=[mock_hc] + ), + ): + mock_db1.client.transaction.return_value = ["OK1", "value1"] mock_hc.check_health.return_value = True @@ -204,43 +250,55 @@ def test_executes_transaction_against_correct_db( assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 def callback(pipe: Pipeline): - pipe.set('key1', 'value1') - pipe.get('key1') + pipe.set("key1", "value1") + pipe.get("key1") - assert client.transaction(callback) == ['OK1', 'value1'] + assert client.transaction(callback) == ["OK1", "value1"] assert mock_hc.check_health.call_count == 9 @pytest.mark.parametrize( - 'mock_multi_db_config,mock_db, mock_db1, mock_db2', + "mock_multi_db_config,mock_db, mock_db1, mock_db2", [ ( - {}, - {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, - {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, - {'weight': 0.7, 'circuit': {'state': CBState.OPEN}}, + {}, + {"weight": 0.2, "circuit": {"state": CBState.CLOSED}}, + {"weight": 0.5, "circuit": {"state": CBState.CLOSED}}, + {"weight": 0.7, "circuit": {"state": CBState.OPEN}}, ), ], indirect=True, ) def test_execute_transaction_against_correct_db_and_closed_circuit( - self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc + self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc ): databases = create_weighted_list(mock_db, mock_db1, mock_db2) - with patch.object(mock_multi_db_config,'databases',return_value=databases), \ - patch.object(mock_multi_db_config,'default_health_checks', return_value=[mock_hc]): - mock_db1.client.transaction.return_value = ['OK1', 'value1'] - - mock_hc.check_health.side_effect = [False, True, True, True, True, True, True] + with ( + patch.object(mock_multi_db_config, "databases", return_value=databases), + patch.object( + mock_multi_db_config, "default_health_checks", return_value=[mock_hc] + ), + ): + 
mock_db1.client.transaction.return_value = ["OK1", "value1"] + + mock_hc.check_health.side_effect = [ + False, + True, + True, + True, + True, + True, + True, + ] client = MultiDBClient(mock_multi_db_config) assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 def callback(pipe: Pipeline): - pipe.set('key1', 'value1') - pipe.get('key1') + pipe.set("key1", "value1") + pipe.get("key1") - assert client.transaction(callback) == ['OK1', 'value1'] + assert client.transaction(callback) == ["OK1", "value1"] assert mock_hc.check_health.call_count == 7 assert mock_db.circuit.state == CBState.CLOSED @@ -248,19 +306,19 @@ def callback(pipe: Pipeline): assert mock_db2.circuit.state == CBState.OPEN @pytest.mark.parametrize( - 'mock_multi_db_config,mock_db, mock_db1, mock_db2', + "mock_multi_db_config,mock_db, mock_db1, mock_db2", [ ( - {"health_check_probes" : 1}, - {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, - {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, - {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, + {"health_check_probes": 1}, + {"weight": 0.2, "circuit": {"state": CBState.CLOSED}}, + {"weight": 0.7, "circuit": {"state": CBState.CLOSED}}, + {"weight": 0.5, "circuit": {"state": CBState.CLOSED}}, ), ], indirect=True, ) def test_execute_transaction_against_correct_db_on_background_health_check_determine_active_db_unhealthy( - self, mock_multi_db_config, mock_db, mock_db1, mock_db2 + self, mock_multi_db_config, mock_db, mock_db1, mock_db2 ): cb = PBCircuitBreakerAdapter(pybreaker.CircuitBreaker(reset_timeout=5)) cb.database = mock_db @@ -276,15 +334,36 @@ def test_execute_transaction_against_correct_db_on_background_health_check_deter databases = create_weighted_list(mock_db, mock_db1, mock_db2) - with patch.object(mock_multi_db_config,'databases',return_value=databases), \ - patch.object(mock_multi_db_config,'default_health_checks', return_value=[EchoHealthCheck()]): - mock_db.client.execute_command.side_effect = ['healthcheck', 'healthcheck', 'healthcheck', 'error'] - mock_db1.client.execute_command.side_effect = ['healthcheck', 'error', 'error', 'healthcheck'] - mock_db2.client.execute_command.side_effect = ['healthcheck', 'healthcheck', 'error', 'error'] - - mock_db.client.transaction.return_value = ['OK', 'value'] - mock_db1.client.transaction.return_value = ['OK1', 'value'] - mock_db2.client.transaction.return_value = ['OK2', 'value'] + with ( + patch.object(mock_multi_db_config, "databases", return_value=databases), + patch.object( + mock_multi_db_config, + "default_health_checks", + return_value=[EchoHealthCheck()], + ), + ): + mock_db.client.execute_command.side_effect = [ + "healthcheck", + "healthcheck", + "healthcheck", + "error", + ] + mock_db1.client.execute_command.side_effect = [ + "healthcheck", + "error", + "error", + "healthcheck", + ] + mock_db2.client.execute_command.side_effect = [ + "healthcheck", + "healthcheck", + "error", + "error", + ] + + mock_db.client.transaction.return_value = ["OK", "value"] + mock_db1.client.transaction.return_value = ["OK1", "value"] + mock_db2.client.transaction.return_value = ["OK2", "value"] mock_multi_db_config.health_check_interval = 0.1 mock_multi_db_config.failover_strategy = WeightBasedFailoverStrategy() @@ -292,13 +371,13 @@ def test_execute_transaction_against_correct_db_on_background_health_check_deter client = MultiDBClient(mock_multi_db_config) def callback(pipe: Pipeline): - pipe.set('key1', 'value1') - pipe.get('key1') + pipe.set("key1", "value1") + pipe.get("key1") - assert 
client.transaction(callback) == ['OK1', 'value'] + assert client.transaction(callback) == ["OK1", "value"] sleep(0.15) - assert client.transaction(callback) == ['OK2', 'value'] + assert client.transaction(callback) == ["OK2", "value"] sleep(0.1) - assert client.transaction(callback) == ['OK', 'value'] + assert client.transaction(callback) == ["OK", "value"] sleep(0.1) - assert client.transaction(callback) == ['OK1', 'value'] \ No newline at end of file + assert client.transaction(callback) == ["OK1", "value"] diff --git a/tests/test_scenario/conftest.py b/tests/test_scenario/conftest.py index a74da0dbf7..e39ef88045 100644 --- a/tests/test_scenario/conftest.py +++ b/tests/test_scenario/conftest.py @@ -11,7 +11,11 @@ from redis.backoff import NoBackoff, ExponentialBackoff from redis.event import EventDispatcher, EventListenerInterface from redis.multidb.client import MultiDBClient -from redis.multidb.config import DatabaseConfig, MultiDbConfig, DEFAULT_HEALTH_CHECK_INTERVAL +from redis.multidb.config import ( + DatabaseConfig, + MultiDbConfig, + DEFAULT_HEALTH_CHECK_INTERVAL, +) from redis.multidb.event import ActiveDatabaseChanged from redis.multidb.failure_detector import DEFAULT_MIN_NUM_FAILURES from redis.multidb.healthcheck import EchoHealthCheck, DEFAULT_HEALTH_CHECK_DELAY @@ -26,6 +30,7 @@ DEFAULT_ENDPOINT_NAME = "m-standard" + class CheckActiveDatabaseChangedListener(EventListenerInterface): def __init__(self): self.is_changed_flag = False @@ -33,6 +38,7 @@ def __init__(self): def listen(self, event: ActiveDatabaseChanged): self.is_changed_flag = True + @pytest.fixture() def endpoint_name(request): return request.config.getoption("--endpoint-name") or os.getenv( @@ -56,10 +62,12 @@ def get_endpoints_config(endpoint_name: str): f"Failed to load endpoints config file: {endpoints_config}" ) from e + @pytest.fixture() def endpoints_config(endpoint_name: str): return get_endpoints_config(endpoint_name) + @pytest.fixture() def fault_injector_client(): url = os.getenv("FAULT_INJECTION_API_URL", "http://127.0.0.1:20324") @@ -67,66 +75,76 @@ def fault_injector_client(): @pytest.fixture() -def r_multi_db(request) -> tuple[MultiDBClient, CheckActiveDatabaseChangedListener, dict]: - client_class = request.param.get('client_class', Redis) - - if client_class == Redis: - endpoint_config = get_endpoints_config('re-active-active') - else: - endpoint_config = get_endpoints_config('re-active-active-oss-cluster') - - username = endpoint_config.get('username', None) - password = endpoint_config.get('password', None) - min_num_failures = request.param.get('min_num_failures', DEFAULT_MIN_NUM_FAILURES) - command_retry = request.param.get('command_retry', Retry(ExponentialBackoff(cap=0.1, base=0.01), retries=10)) - - # Retry configuration different for health checks as initial health check require more time in case - # if infrastructure wasn't restored from the previous test. 
- health_check_interval = request.param.get('health_check_interval', DEFAULT_HEALTH_CHECK_INTERVAL) - health_check_delay = request.param.get('health_check_delay', DEFAULT_HEALTH_CHECK_DELAY) - event_dispatcher = EventDispatcher() - listener = CheckActiveDatabaseChangedListener() - event_dispatcher.register_listeners({ - ActiveDatabaseChanged: [listener], - }) - db_configs = [] - - db_config = DatabaseConfig( - weight=1.0, - from_url=endpoint_config['endpoints'][0], - client_kwargs={ - 'username': username, - 'password': password, - 'decode_responses': True, - }, - health_check_url=extract_cluster_fqdn(endpoint_config['endpoints'][0]) - ) - db_configs.append(db_config) - - db_config1 = DatabaseConfig( - weight=0.9, - from_url=endpoint_config['endpoints'][1], - client_kwargs={ - 'username': username, - 'password': password, - 'decode_responses': True, - }, - health_check_url=extract_cluster_fqdn(endpoint_config['endpoints'][1]) - ) - db_configs.append(db_config1) - - config = MultiDbConfig( - client_class=client_class, - databases_config=db_configs, - command_retry=command_retry, - min_num_failures=min_num_failures, - health_check_probes=3, - health_check_interval=health_check_interval, - event_dispatcher=event_dispatcher, - health_check_delay=health_check_delay, - ) - - return MultiDBClient(config), listener, endpoint_config +def r_multi_db( + request, +) -> tuple[MultiDBClient, CheckActiveDatabaseChangedListener, dict]: + client_class = request.param.get("client_class", Redis) + + if client_class == Redis: + endpoint_config = get_endpoints_config("re-active-active") + else: + endpoint_config = get_endpoints_config("re-active-active-oss-cluster") + + username = endpoint_config.get("username", None) + password = endpoint_config.get("password", None) + min_num_failures = request.param.get("min_num_failures", DEFAULT_MIN_NUM_FAILURES) + command_retry = request.param.get( + "command_retry", Retry(ExponentialBackoff(cap=0.1, base=0.01), retries=10) + ) + + # Retry configuration differs for health checks, as the initial health check requires more time + # if the infrastructure wasn't restored from the previous test.
+ health_check_interval = request.param.get( + "health_check_interval", DEFAULT_HEALTH_CHECK_INTERVAL + ) + health_check_delay = request.param.get( + "health_check_delay", DEFAULT_HEALTH_CHECK_DELAY + ) + event_dispatcher = EventDispatcher() + listener = CheckActiveDatabaseChangedListener() + event_dispatcher.register_listeners( + { + ActiveDatabaseChanged: [listener], + } + ) + db_configs = [] + + db_config = DatabaseConfig( + weight=1.0, + from_url=endpoint_config["endpoints"][0], + client_kwargs={ + "username": username, + "password": password, + "decode_responses": True, + }, + health_check_url=extract_cluster_fqdn(endpoint_config["endpoints"][0]), + ) + db_configs.append(db_config) + + db_config1 = DatabaseConfig( + weight=0.9, + from_url=endpoint_config["endpoints"][1], + client_kwargs={ + "username": username, + "password": password, + "decode_responses": True, + }, + health_check_url=extract_cluster_fqdn(endpoint_config["endpoints"][1]), + ) + db_configs.append(db_config1) + + config = MultiDbConfig( + client_class=client_class, + databases_config=db_configs, + command_retry=command_retry, + min_num_failures=min_num_failures, + health_check_probes=3, + health_check_interval=health_check_interval, + event_dispatcher=event_dispatcher, + health_check_delay=health_check_delay, + ) + + return MultiDBClient(config), listener, endpoint_config def extract_cluster_fqdn(url): @@ -142,11 +160,12 @@ def extract_cluster_fqdn(url): # Remove the 'redis-XXXX.' prefix using regex # This pattern matches 'redis-' followed by digits and a dot - cleaned_hostname = re.sub(r'^redis-\d+\.', '', hostname) + cleaned_hostname = re.sub(r"^redis-\d+\.", "", hostname) # Reconstruct the URL return f"https://{cleaned_hostname}" + @pytest.fixture() def client_maint_notifications(endpoints_config): return _get_client_maint_notifications(endpoints_config) @@ -221,4 +240,4 @@ def _get_client_maint_notifications( f"Maintenance notifications pool handler: {maintenance_handler_exists}" ) - return client \ No newline at end of file + return client diff --git a/tests/test_scenario/test_active_active.py b/tests/test_scenario/test_active_active.py index cca84a9bb1..59524ab5c1 100644 --- a/tests/test_scenario/test_active_active.py +++ b/tests/test_scenario/test_active_active.py @@ -18,27 +18,32 @@ logger = logging.getLogger(__name__) -def trigger_network_failure_action(fault_injector_client, config, event: threading.Event = None): + +def trigger_network_failure_action( + fault_injector_client, config, event: threading.Event = None +): action_request = ActionRequest( action_type=ActionType.NETWORK_FAILURE, - parameters={"bdb_id": config['bdb_id'], "delay": 3, "cluster_index": 0} + parameters={"bdb_id": config["bdb_id"], "delay": 3, "cluster_index": 0}, ) result = fault_injector_client.trigger_action(action_request) - status_result = fault_injector_client.get_action_status(result['action_id']) + status_result = fault_injector_client.get_action_status(result["action_id"]) - while status_result['status'] != "success": + while status_result["status"] != "success": sleep(0.1) - status_result = fault_injector_client.get_action_status(result['action_id']) - logger.info(f"Waiting for action to complete. Status: {status_result['status']}") + status_result = fault_injector_client.get_action_status(result["action_id"]) + logger.info( + f"Waiting for action to complete. Status: {status_result['status']}" + ) if event: event.set() logger.info(f"Action completed. 
Status: {status_result['status']}") -class TestActiveActive: +class TestActiveActive: def teardown_method(self, method): # Timeout so the cluster could recover from network failure. sleep(10) @@ -50,107 +55,121 @@ def teardown_method(self, method): {"client_class": RedisCluster, "min_num_failures": 2}, ], ids=["standalone", "cluster"], - indirect=True + indirect=True, ) @pytest.mark.timeout(100) - def test_multi_db_client_failover_to_another_db(self, r_multi_db, fault_injector_client): + def test_multi_db_client_failover_to_another_db( + self, r_multi_db, fault_injector_client + ): r_multi_db, listener, config = r_multi_db # Handle unavailable databases from previous test. retry = Retry( supported_errors=(TemporaryUnavailableException,), retries=DEFAULT_FAILOVER_ATTEMPTS, - backoff=ConstantBackoff(backoff=DEFAULT_FAILOVER_DELAY) + backoff=ConstantBackoff(backoff=DEFAULT_FAILOVER_DELAY), ) event = threading.Event() thread = threading.Thread( target=trigger_network_failure_action, daemon=True, - args=(fault_injector_client,config,event) + args=(fault_injector_client, config, event), ) # Client initialized on the first command. retry.call_with_retry( - lambda : r_multi_db.set('key', 'value'), - lambda _ : dummy_fail() + lambda: r_multi_db.set("key", "value"), lambda _: dummy_fail() ) thread.start() # Execute commands before network failure while not event.is_set(): - assert retry.call_with_retry( - lambda : r_multi_db.get('key'), - lambda _ : dummy_fail() - ) == 'value' + assert ( + retry.call_with_retry( + lambda: r_multi_db.get("key"), lambda _: dummy_fail() + ) + == "value" + ) sleep(0.5) # Execute commands until database failover while not listener.is_changed_flag: - assert retry.call_with_retry( - lambda : r_multi_db.get('key'), - lambda _ : dummy_fail() - ) == 'value' + assert ( + retry.call_with_retry( + lambda: r_multi_db.get("key"), lambda _: dummy_fail() + ) + == "value" + ) sleep(0.5) @pytest.mark.parametrize( "r_multi_db", [ {"client_class": Redis, "min_num_failures": 2, "health_check_interval": 20}, - {"client_class": RedisCluster, "min_num_failures": 2, "health_check_interval": 20}, + { + "client_class": RedisCluster, + "min_num_failures": 2, + "health_check_interval": 20, + }, ], ids=["standalone", "cluster"], - indirect=True + indirect=True, ) @pytest.mark.timeout(100) - def test_multi_db_client_uses_lag_aware_health_check(self, r_multi_db, fault_injector_client): + def test_multi_db_client_uses_lag_aware_health_check( + self, r_multi_db, fault_injector_client + ): r_multi_db, listener, config = r_multi_db retry = Retry( supported_errors=(TemporaryUnavailableException,), retries=DEFAULT_FAILOVER_ATTEMPTS, - backoff=ConstantBackoff(backoff=DEFAULT_FAILOVER_DELAY) + backoff=ConstantBackoff(backoff=DEFAULT_FAILOVER_DELAY), ) event = threading.Event() thread = threading.Thread( target=trigger_network_failure_action, daemon=True, - args=(fault_injector_client,config,event) + args=(fault_injector_client, config, event), ) - env0_username = os.getenv('ENV0_USERNAME') - env0_password = os.getenv('ENV0_PASSWORD') + env0_username = os.getenv("ENV0_USERNAME") + env0_password = os.getenv("ENV0_PASSWORD") # Adding additional health check to the client. r_multi_db.add_health_check( LagAwareHealthCheck( verify_tls=False, - auth_basic=(env0_username,env0_password), - lag_aware_tolerance=10000 + auth_basic=(env0_username, env0_password), + lag_aware_tolerance=10000, ) ) # Client initialized on the first command. 
retry.call_with_retry( - lambda : r_multi_db.set('key', 'value'), - lambda _ : dummy_fail() + lambda: r_multi_db.set("key", "value"), lambda _: dummy_fail() ) thread.start() # Execute commands before network failure while not event.is_set(): - assert retry.call_with_retry( - lambda : r_multi_db.get('key'), - lambda _ : dummy_fail() - ) == 'value' + assert ( + retry.call_with_retry( + lambda: r_multi_db.get("key"), lambda _: dummy_fail() + ) + == "value" + ) sleep(0.5) # Execute commands after network failure while not listener.is_changed_flag: - assert retry.call_with_retry( - lambda : r_multi_db.get('key'), - lambda _ : dummy_fail() - ) == 'value' + assert ( + retry.call_with_retry( + lambda: r_multi_db.get("key"), lambda _: dummy_fail() + ) + == "value" + ) sleep(0.5) @pytest.mark.parametrize( @@ -160,55 +179,55 @@ def test_multi_db_client_uses_lag_aware_health_check(self, r_multi_db, fault_inj {"client_class": RedisCluster, "min_num_failures": 2}, ], ids=["standalone", "cluster"], - indirect=True + indirect=True, ) @pytest.mark.timeout(100) - def test_context_manager_pipeline_failover_to_another_db(self, r_multi_db, fault_injector_client): + def test_context_manager_pipeline_failover_to_another_db( + self, r_multi_db, fault_injector_client + ): r_multi_db, listener, config = r_multi_db retry = Retry( supported_errors=(TemporaryUnavailableException,), retries=DEFAULT_FAILOVER_ATTEMPTS, - backoff=ConstantBackoff(backoff=DEFAULT_FAILOVER_DELAY) + backoff=ConstantBackoff(backoff=DEFAULT_FAILOVER_DELAY), ) event = threading.Event() thread = threading.Thread( target=trigger_network_failure_action, daemon=True, - args=(fault_injector_client,config,event) + args=(fault_injector_client, config, event), ) def callback(): with r_multi_db.pipeline() as pipe: - pipe.set('{hash}key1', 'value1') - pipe.set('{hash}key2', 'value2') - pipe.set('{hash}key3', 'value3') - pipe.get('{hash}key1') - pipe.get('{hash}key2') - pipe.get('{hash}key3') - assert pipe.execute() == [True, True, True, 'value1', 'value2', 'value3'] + pipe.set("{hash}key1", "value1") + pipe.set("{hash}key2", "value2") + pipe.set("{hash}key3", "value3") + pipe.get("{hash}key1") + pipe.get("{hash}key2") + pipe.get("{hash}key3") + assert pipe.execute() == [ + True, + True, + True, + "value1", + "value2", + "value3", + ] # Client initialized on first pipe execution. 
- retry.call_with_retry( - lambda : callback(), - lambda _ : dummy_fail() - ) + retry.call_with_retry(lambda: callback(), lambda _: dummy_fail()) thread.start() # Execute pipeline before network failure while not event.is_set(): - retry.call_with_retry( - lambda: callback(), - lambda _: dummy_fail() - ) + retry.call_with_retry(lambda: callback(), lambda _: dummy_fail()) sleep(0.5) # Execute pipeline until database failover for _ in range(5): - retry.call_with_retry( - lambda: callback(), - lambda _: dummy_fail() - ) + retry.call_with_retry(lambda: callback(), lambda _: dummy_fail()) sleep(0.5) @pytest.mark.parametrize( @@ -218,56 +237,49 @@ def callback(): {"client_class": RedisCluster, "min_num_failures": 2}, ], ids=["standalone", "cluster"], - indirect=True + indirect=True, ) @pytest.mark.timeout(100) - def test_chaining_pipeline_failover_to_another_db(self, r_multi_db, fault_injector_client): + def test_chaining_pipeline_failover_to_another_db( + self, r_multi_db, fault_injector_client + ): r_multi_db, listener, config = r_multi_db retry = Retry( supported_errors=(TemporaryUnavailableException,), retries=DEFAULT_FAILOVER_ATTEMPTS, - backoff=ConstantBackoff(backoff=DEFAULT_FAILOVER_DELAY) + backoff=ConstantBackoff(backoff=DEFAULT_FAILOVER_DELAY), ) event = threading.Event() thread = threading.Thread( target=trigger_network_failure_action, daemon=True, - args=(fault_injector_client,config,event) + args=(fault_injector_client, config, event), ) def callback(): pipe = r_multi_db.pipeline() - pipe.set('{hash}key1', 'value1') - pipe.set('{hash}key2', 'value2') - pipe.set('{hash}key3', 'value3') - pipe.get('{hash}key1') - pipe.get('{hash}key2') - pipe.get('{hash}key3') - assert pipe.execute() == [True, True, True, 'value1', 'value2', 'value3'] + pipe.set("{hash}key1", "value1") + pipe.set("{hash}key2", "value2") + pipe.set("{hash}key3", "value3") + pipe.get("{hash}key1") + pipe.get("{hash}key2") + pipe.get("{hash}key3") + assert pipe.execute() == [True, True, True, "value1", "value2", "value3"] # Client initialized on first pipe execution. 
- retry.call_with_retry( - lambda : callback(), - lambda _ : dummy_fail() - ) + retry.call_with_retry(lambda: callback(), lambda _: dummy_fail()) thread.start() # Execute pipeline before network failure while not event.is_set(): - retry.call_with_retry( - lambda: callback(), - lambda _: dummy_fail() - ) + retry.call_with_retry(lambda: callback(), lambda _: dummy_fail()) sleep(0.5) # Execute pipeline until database failover for _ in range(5): - retry.call_with_retry( - lambda: callback(), - lambda _: dummy_fail() - ) + retry.call_with_retry(lambda: callback(), lambda _: dummy_fail()) sleep(0.5) @pytest.mark.parametrize( @@ -277,52 +289,51 @@ def callback(): {"client_class": RedisCluster, "min_num_failures": 2}, ], ids=["standalone", "cluster"], - indirect=True + indirect=True, ) @pytest.mark.timeout(100) - def test_transaction_failover_to_another_db(self, r_multi_db, fault_injector_client): + def test_transaction_failover_to_another_db( + self, r_multi_db, fault_injector_client + ): r_multi_db, listener, config = r_multi_db retry = Retry( supported_errors=(TemporaryUnavailableException,), retries=DEFAULT_FAILOVER_ATTEMPTS, - backoff=ConstantBackoff(backoff=DEFAULT_FAILOVER_DELAY) + backoff=ConstantBackoff(backoff=DEFAULT_FAILOVER_DELAY), ) event = threading.Event() thread = threading.Thread( target=trigger_network_failure_action, daemon=True, - args=(fault_injector_client,config,event) + args=(fault_injector_client, config, event), ) def callback(pipe: Pipeline): - pipe.set('{hash}key1', 'value1') - pipe.set('{hash}key2', 'value2') - pipe.set('{hash}key3', 'value3') - pipe.get('{hash}key1') - pipe.get('{hash}key2') - pipe.get('{hash}key3') + pipe.set("{hash}key1", "value1") + pipe.set("{hash}key2", "value2") + pipe.set("{hash}key3", "value3") + pipe.get("{hash}key1") + pipe.get("{hash}key2") + pipe.get("{hash}key3") # Client initialized on first transaction execution. 
retry.call_with_retry( - lambda : r_multi_db.transaction(callback), - lambda _ : dummy_fail() + lambda: r_multi_db.transaction(callback), lambda _: dummy_fail() ) thread.start() # Execute transaction before network failure while not event.is_set(): retry.call_with_retry( - lambda: r_multi_db.transaction(callback), - lambda _: dummy_fail() + lambda: r_multi_db.transaction(callback), lambda _: dummy_fail() ) sleep(0.5) # Execute transaction until database failover while not listener.is_changed_flag: retry.call_with_retry( - lambda: r_multi_db.transaction(callback), - lambda _: dummy_fail() + lambda: r_multi_db.transaction(callback), lambda _: dummy_fail() ) sleep(0.5) @@ -333,7 +344,7 @@ def callback(pipe: Pipeline): {"client_class": RedisCluster, "min_num_failures": 2}, ], ids=["standalone", "cluster"], - indirect=True + indirect=True, ) @pytest.mark.timeout(100) def test_pubsub_failover_to_another_db(self, r_multi_db, fault_injector_client): @@ -341,16 +352,16 @@ def test_pubsub_failover_to_another_db(self, r_multi_db, fault_injector_client): retry = Retry( supported_errors=(TemporaryUnavailableException,), retries=DEFAULT_FAILOVER_ATTEMPTS, - backoff=ConstantBackoff(backoff=DEFAULT_FAILOVER_DELAY) + backoff=ConstantBackoff(backoff=DEFAULT_FAILOVER_DELAY), ) event = threading.Event() thread = threading.Thread( target=trigger_network_failure_action, daemon=True, - args=(fault_injector_client,config,event) + args=(fault_injector_client, config, event), ) - data = json.dumps({'message': 'test'}) + data = json.dumps({"message": "test"}) messages_count = 0 def handler(message): @@ -361,8 +372,8 @@ def handler(message): # Assign a handler and run in a separate thread. retry.call_with_retry( - lambda: pubsub.subscribe(**{'test-channel': handler}), - lambda _: dummy_fail() + lambda: pubsub.subscribe(**{"test-channel": handler}), + lambda _: dummy_fail(), ) pubsub_thread = pubsub.run_in_thread(sleep_time=0.1, daemon=True) thread.start() @@ -370,16 +381,14 @@ def handler(message): # Execute publish before network failure while not event.is_set(): retry.call_with_retry( - lambda: r_multi_db.publish('test-channel', data), - lambda _: dummy_fail() + lambda: r_multi_db.publish("test-channel", data), lambda _: dummy_fail() ) sleep(0.5) # Execute publish until database failover while not listener.is_changed_flag: retry.call_with_retry( - lambda: r_multi_db.publish('test-channel', data), - lambda _: dummy_fail() + lambda: r_multi_db.publish("test-channel", data), lambda _: dummy_fail() ) sleep(0.5) @@ -393,24 +402,26 @@ def handler(message): {"client_class": RedisCluster, "min_num_failures": 2}, ], ids=["standalone", "cluster"], - indirect=True + indirect=True, ) @pytest.mark.timeout(100) - def test_sharded_pubsub_failover_to_another_db(self, r_multi_db, fault_injector_client): + def test_sharded_pubsub_failover_to_another_db( + self, r_multi_db, fault_injector_client + ): r_multi_db, listener, config = r_multi_db retry = Retry( supported_errors=(TemporaryUnavailableException,), retries=DEFAULT_FAILOVER_ATTEMPTS, - backoff=ConstantBackoff(backoff=DEFAULT_FAILOVER_DELAY) + backoff=ConstantBackoff(backoff=DEFAULT_FAILOVER_DELAY), ) event = threading.Event() thread = threading.Thread( target=trigger_network_failure_action, daemon=True, - args=(fault_injector_client,config,event) + args=(fault_injector_client, config, event), ) - data = json.dumps({'message': 'test'}) + data = json.dumps({"message": "test"}) messages_count = 0 def handler(message): @@ -421,31 +432,29 @@ def handler(message): # Assign a handler 
and run in a separate thread. retry.call_with_retry( - lambda: pubsub.ssubscribe(**{'test-channel': handler}), - lambda _: dummy_fail() + lambda: pubsub.ssubscribe(**{"test-channel": handler}), + lambda _: dummy_fail(), ) pubsub_thread = pubsub.run_in_thread( - sleep_time=0.1, - daemon=True, - sharded_pubsub=True + sleep_time=0.1, daemon=True, sharded_pubsub=True ) thread.start() # Execute publish before network failure while not event.is_set(): retry.call_with_retry( - lambda: r_multi_db.spublish('test-channel', data), - lambda _: dummy_fail() + lambda: r_multi_db.spublish("test-channel", data), + lambda _: dummy_fail(), ) sleep(0.5) # Execute publish until database failover while not listener.is_changed_flag: retry.call_with_retry( - lambda: r_multi_db.spublish('test-channel', data), - lambda _: dummy_fail() + lambda: r_multi_db.spublish("test-channel", data), + lambda _: dummy_fail(), ) sleep(0.5) pubsub_thread.stop() - assert messages_count > 2 \ No newline at end of file + assert messages_count > 2 From 57d328bc56369b3ceff409c9552e00b9f7656538 Mon Sep 17 00:00:00 2001 From: vladvildanov Date: Thu, 2 Oct 2025 11:02:06 +0300 Subject: [PATCH 30/50] Codestyle changes --- redis/asyncio/cluster.py | 2 +- redis/asyncio/http/http_client.py | 5 ++- redis/asyncio/multidb/client.py | 13 ++++--- redis/asyncio/multidb/command_executor.py | 20 +++++----- redis/asyncio/multidb/config.py | 24 ++++++------ redis/asyncio/multidb/database.py | 2 +- redis/asyncio/multidb/failover.py | 4 +- redis/asyncio/multidb/healthcheck.py | 13 +++---- redis/background.py | 2 +- redis/cluster.py | 2 +- redis/data_structure.py | 2 +- redis/event.py | 2 +- redis/http/http_client.py | 5 +-- redis/multidb/circuit.py | 2 +- redis/multidb/client.py | 10 +++-- redis/multidb/command_executor.py | 12 +++--- redis/multidb/config.py | 37 +++++++++---------- redis/multidb/database.py | 5 +-- redis/multidb/event.py | 1 - redis/multidb/failover.py | 2 +- redis/multidb/healthcheck.py | 10 ++--- redis/retry.py | 2 +- .../test_multidb/test_healthcheck.py | 12 +++--- .../test_multidb/test_pipeline.py | 1 - tests/test_asyncio/test_scenario/conftest.py | 3 +- .../test_scenario/test_active_active.py | 2 - tests/test_data_structure.py | 4 +- tests/test_multidb/conftest.py | 1 - tests/test_multidb/test_command_executor.py | 1 - tests/test_multidb/test_healthcheck.py | 12 +++--- tests/test_multidb/test_pipeline.py | 2 - tests/test_scenario/conftest.py | 6 +-- 32 files changed, 105 insertions(+), 116 deletions(-) diff --git a/redis/asyncio/cluster.py b/redis/asyncio/cluster.py index 9810654626..d05de07d18 100644 --- a/redis/asyncio/cluster.py +++ b/redis/asyncio/cluster.py @@ -2254,7 +2254,7 @@ async def _reinitialize_on_error(self, error): await self._pipe.cluster_client.nodes_manager.initialize() self.reinitialize_counter = 0 else: - if type(error) == MovedError: + if type(error) is MovedError: self._pipe.cluster_client.nodes_manager.update_moved_exception( error ) diff --git a/redis/asyncio/http/http_client.py b/redis/asyncio/http/http_client.py index 51e3ba9226..688b33b2e3 100644 --- a/redis/asyncio/http/http_client.py +++ b/redis/asyncio/http/http_client.py @@ -1,8 +1,9 @@ import asyncio from abc import ABC, abstractmethod from concurrent.futures import ThreadPoolExecutor -from typing import Optional, Mapping, Union, Any -from redis.http.http_client import HttpResponse, HttpClient +from typing import Any, Mapping, Optional, Union + +from redis.http.http_client import HttpClient, HttpResponse DEFAULT_USER_AGENT = "HttpClient/1.0 
(+https://example.invalid)" DEFAULT_TIMEOUT = 30.0 diff --git a/redis/asyncio/multidb/client.py b/redis/asyncio/multidb/client.py index 0354733b6d..c972b4833a 100644 --- a/redis/asyncio/multidb/client.py +++ b/redis/asyncio/multidb/client.py @@ -1,18 +1,19 @@ import asyncio import logging -from typing import Callable, Optional, Coroutine, Any, List, Union, Awaitable +from typing import Any, Awaitable, Callable, Coroutine, List, Optional, Union -from redis.asyncio.client import PubSubHandler +from redis.asyncio.client import PSWorkerThreadExcHandlerT, PubSubHandler from redis.asyncio.multidb.command_executor import DefaultCommandExecutor +from redis.asyncio.multidb.config import DEFAULT_GRACE_PERIOD, MultiDbConfig from redis.asyncio.multidb.database import AsyncDatabase, Databases from redis.asyncio.multidb.failure_detector import AsyncFailureDetector from redis.asyncio.multidb.healthcheck import HealthCheck, HealthCheckPolicy -from redis.multidb.circuit import State as CBState, CircuitBreaker -from redis.asyncio.multidb.config import MultiDbConfig, DEFAULT_GRACE_PERIOD from redis.background import BackgroundScheduler -from redis.commands import AsyncRedisModuleCommands, AsyncCoreCommands +from redis.commands import AsyncCoreCommands, AsyncRedisModuleCommands +from redis.multidb.circuit import CircuitBreaker +from redis.multidb.circuit import State as CBState from redis.multidb.exception import NoValidDatabaseException, UnhealthyDatabaseException -from redis.typing import KeyT, EncodableT, ChannelT +from redis.typing import ChannelT, EncodableT, KeyT logger = logging.getLogger(__name__) diff --git a/redis/asyncio/multidb/command_executor.py b/redis/asyncio/multidb/command_executor.py index de9dd62a85..daf9bf339c 100644 --- a/redis/asyncio/multidb/command_executor.py +++ b/redis/asyncio/multidb/command_executor.py @@ -1,29 +1,29 @@ from abc import abstractmethod from asyncio import iscoroutinefunction from datetime import datetime -from typing import List, Optional, Callable, Any, Union, Awaitable +from typing import Any, Awaitable, Callable, List, Optional, Union from redis.asyncio import RedisCluster -from redis.asyncio.client import PubSub, Pipeline -from redis.asyncio.multidb.database import Databases, AsyncDatabase, Database +from redis.asyncio.client import Pipeline, PubSub +from redis.asyncio.multidb.database import AsyncDatabase, Database, Databases from redis.asyncio.multidb.event import ( AsyncActiveDatabaseChanged, + CloseConnectionOnActiveDatabaseChanged, RegisterCommandFailure, ResubscribeOnActiveDatabaseChanged, - CloseConnectionOnActiveDatabaseChanged, ) from redis.asyncio.multidb.failover import ( - AsyncFailoverStrategy, - FailoverStrategyExecutor, - DefaultFailoverStrategyExecutor, DEFAULT_FAILOVER_ATTEMPTS, DEFAULT_FAILOVER_DELAY, + AsyncFailoverStrategy, + DefaultFailoverStrategyExecutor, + FailoverStrategyExecutor, ) from redis.asyncio.multidb.failure_detector import AsyncFailureDetector -from redis.multidb.circuit import State as CBState from redis.asyncio.retry import Retry -from redis.event import EventDispatcherInterface, AsyncOnCommandsFailEvent -from redis.multidb.command_executor import CommandExecutor, BaseCommandExecutor +from redis.event import AsyncOnCommandsFailEvent, EventDispatcherInterface +from redis.multidb.circuit import State as CBState +from redis.multidb.command_executor import BaseCommandExecutor, CommandExecutor from redis.multidb.config import DEFAULT_AUTO_FALLBACK_INTERVAL from redis.typing import KeyT diff --git a/redis/asyncio/multidb/config.py 
b/redis/asyncio/multidb/config.py index 7e114eff1d..71f69ad133 100644 --- a/redis/asyncio/multidb/config.py +++ b/redis/asyncio/multidb/config.py @@ -1,43 +1,43 @@ from dataclasses import dataclass, field -from typing import Optional, List, Type, Union +from typing import List, Optional, Type, Union import pybreaker from redis.asyncio import ConnectionPool, Redis, RedisCluster -from redis.asyncio.multidb.database import Databases, Database +from redis.asyncio.multidb.database import Database, Databases from redis.asyncio.multidb.failover import ( + DEFAULT_FAILOVER_ATTEMPTS, + DEFAULT_FAILOVER_DELAY, AsyncFailoverStrategy, WeightBasedFailoverStrategy, - DEFAULT_FAILOVER_DELAY, - DEFAULT_FAILOVER_ATTEMPTS, ) from redis.asyncio.multidb.failure_detector import ( AsyncFailureDetector, FailureDetectorAsyncWrapper, ) from redis.asyncio.multidb.healthcheck import ( - HealthCheck, - EchoHealthCheck, + DEFAULT_HEALTH_CHECK_DELAY, DEFAULT_HEALTH_CHECK_INTERVAL, + DEFAULT_HEALTH_CHECK_POLICY, DEFAULT_HEALTH_CHECK_PROBES, - DEFAULT_HEALTH_CHECK_DELAY, + EchoHealthCheck, + HealthCheck, HealthCheckPolicies, - DEFAULT_HEALTH_CHECK_POLICY, ) from redis.asyncio.retry import Retry from redis.backoff import ExponentialWithJitterBackoff, NoBackoff from redis.data_structure import WeightedList -from redis.event import EventDispatcherInterface, EventDispatcher +from redis.event import EventDispatcher, EventDispatcherInterface from redis.multidb.circuit import ( + DEFAULT_GRACE_PERIOD, CircuitBreaker, PBCircuitBreakerAdapter, - DEFAULT_GRACE_PERIOD, ) from redis.multidb.failure_detector import ( - CommandFailureDetector, - DEFAULT_MIN_NUM_FAILURES, DEFAULT_FAILURE_RATE_THRESHOLD, DEFAULT_FAILURES_DETECTION_WINDOW, + DEFAULT_MIN_NUM_FAILURES, + CommandFailureDetector, ) DEFAULT_AUTO_FALLBACK_INTERVAL = 120 diff --git a/redis/asyncio/multidb/database.py b/redis/asyncio/multidb/database.py index fd91991d60..ecf7a1b972 100644 --- a/redis/asyncio/multidb/database.py +++ b/redis/asyncio/multidb/database.py @@ -1,5 +1,5 @@ from abc import abstractmethod -from typing import Union, Optional +from typing import Optional, Union from redis.asyncio import Redis, RedisCluster from redis.data_structure import WeightedList diff --git a/redis/asyncio/multidb/failover.py b/redis/asyncio/multidb/failover.py index 8fbcf66955..5b9202111e 100644 --- a/redis/asyncio/multidb/failover.py +++ b/redis/asyncio/multidb/failover.py @@ -1,9 +1,9 @@ import time -from abc import abstractmethod, ABC +from abc import ABC, abstractmethod from redis.asyncio.multidb.database import AsyncDatabase, Databases -from redis.multidb.circuit import State as CBState from redis.data_structure import WeightedList +from redis.multidb.circuit import State as CBState from redis.multidb.exception import ( NoValidDatabaseException, TemporaryUnavailableException, diff --git a/redis/asyncio/multidb/healthcheck.py b/redis/asyncio/multidb/healthcheck.py index efa765eff4..dcb787f6ed 100644 --- a/redis/asyncio/multidb/healthcheck.py +++ b/redis/asyncio/multidb/healthcheck.py @@ -2,15 +2,14 @@ import logging from abc import ABC, abstractmethod from enum import Enum -from typing import Optional, Tuple, Union, List - +from typing import List, Optional, Tuple, Union from redis.asyncio import Redis -from redis.asyncio.http.http_client import AsyncHTTPClientWrapper, DEFAULT_TIMEOUT -from redis.retry import Retry +from redis.asyncio.http.http_client import DEFAULT_TIMEOUT, AsyncHTTPClientWrapper from redis.backoff import NoBackoff from redis.http.http_client import HttpClient 
from redis.multidb.exception import UnhealthyDatabaseException +from redis.retry import Retry DEFAULT_HEALTH_CHECK_PROBES = 3 DEFAULT_HEALTH_CHECK_INTERVAL = 5 @@ -85,7 +84,7 @@ async def execute(self, health_checks: List[HealthCheck], database) -> bool: if not await health_check.check_health(database): return False except Exception as e: - raise UnhealthyDatabaseException(f"Unhealthy database", database, e) + raise UnhealthyDatabaseException("Unhealthy database", database, e) if attempt < self.health_check_probes - 1: await asyncio.sleep(self._health_check_delay) @@ -117,7 +116,7 @@ async def execute(self, health_checks: List[HealthCheck], database) -> bool: allowed_unsuccessful_probes -= 1 if allowed_unsuccessful_probes <= 0: raise UnhealthyDatabaseException( - f"Unhealthy database", database, e + "Unhealthy database", database, e ) if attempt < self.health_check_probes - 1: @@ -148,7 +147,7 @@ async def execute(self, health_checks: List[HealthCheck], database) -> bool: is_healthy = False except Exception as e: exception = UnhealthyDatabaseException( - f"Unhealthy database", database, e + "Unhealthy database", database, e ) if attempt < self.health_check_probes - 1: diff --git a/redis/background.py b/redis/background.py index b6327b9fdd..2f0e434967 100644 --- a/redis/background.py +++ b/redis/background.py @@ -1,6 +1,6 @@ import asyncio import threading -from typing import Callable, Coroutine, Any +from typing import Any, Callable, Coroutine class BackgroundScheduler: diff --git a/redis/cluster.py b/redis/cluster.py index b0999d52df..41b188ef74 100644 --- a/redis/cluster.py +++ b/redis/cluster.py @@ -3165,7 +3165,7 @@ def _reinitialize_on_error(self, error): self._nodes_manager.initialize() self.reinitialize_counter = 0 else: - if type(error) == MovedError: + if type(error) is MovedError: self._nodes_manager.update_moved_exception(error) self._executing = False diff --git a/redis/data_structure.py b/redis/data_structure.py index dc91e48650..0571e223ad 100644 --- a/redis/data_structure.py +++ b/redis/data_structure.py @@ -1,5 +1,5 @@ import threading -from typing import List, Any, TypeVar, Generic, Union +from typing import Any, Generic, List, TypeVar from redis.typing import Number diff --git a/redis/event.py b/redis/event.py index bccf1fbf0d..84fabb40f5 100644 --- a/redis/event.py +++ b/redis/event.py @@ -2,7 +2,7 @@ import threading from abc import ABC, abstractmethod from enum import Enum -from typing import List, Optional, Union, Dict, Type +from typing import Dict, List, Optional, Type, Union from redis.auth.token import TokenInterface from redis.credentials import CredentialProvider, StreamingCredentialProvider diff --git a/redis/http/http_client.py b/redis/http/http_client.py index 4f52290c00..7d9d5c4ad4 100644 --- a/redis/http/http_client.py +++ b/redis/http/http_client.py @@ -1,16 +1,15 @@ from __future__ import annotations import base64 +import gzip import json import ssl -import gzip import zlib from dataclasses import dataclass from typing import Any, Dict, Mapping, Optional, Tuple, Union +from urllib.error import HTTPError, URLError from urllib.parse import urlencode, urljoin from urllib.request import Request, urlopen -from urllib.error import URLError, HTTPError - __all__ = ["HttpClient", "HttpResponse", "HttpError", "DEFAULT_TIMEOUT"] diff --git a/redis/multidb/circuit.py b/redis/multidb/circuit.py index 3a4d90eeb3..8af6cc32de 100644 --- a/redis/multidb/circuit.py +++ b/redis/multidb/circuit.py @@ -1,4 +1,4 @@ -from abc import abstractmethod, ABC +from abc import ABC, 
abstractmethod from enum import Enum from typing import Callable diff --git a/redis/multidb/client.py b/redis/multidb/client.py index 229e1b1616..485174fc03 100644 --- a/redis/multidb/client.py +++ b/redis/multidb/client.py @@ -2,13 +2,15 @@ import threading from concurrent.futures import as_completed from concurrent.futures.thread import ThreadPoolExecutor -from typing import List, Any, Callable, Optional +from typing import Any, Callable, List, Optional from redis.background import BackgroundScheduler -from redis.commands import RedisModuleCommands, CoreCommands +from redis.client import PubSubWorkerThread +from redis.commands import CoreCommands, RedisModuleCommands +from redis.multidb.circuit import CircuitBreaker +from redis.multidb.circuit import State as CBState from redis.multidb.command_executor import DefaultCommandExecutor -from redis.multidb.config import MultiDbConfig, DEFAULT_GRACE_PERIOD -from redis.multidb.circuit import State as CBState, CircuitBreaker +from redis.multidb.config import DEFAULT_GRACE_PERIOD, MultiDbConfig from redis.multidb.database import Database, Databases, SyncDatabase from redis.multidb.exception import NoValidDatabaseException, UnhealthyDatabaseException from redis.multidb.failure_detector import FailureDetector diff --git a/redis/multidb/command_executor.py b/redis/multidb/command_executor.py index 8ba5d43e7d..f8e6171bc8 100644 --- a/redis/multidb/command_executor.py +++ b/redis/multidb/command_executor.py @@ -1,24 +1,24 @@ from abc import ABC, abstractmethod from datetime import datetime, timedelta -from typing import List, Optional, Callable, Any +from typing import Any, Callable, List, Optional from redis.client import Pipeline, PubSub, PubSubWorkerThread from redis.event import EventDispatcherInterface, OnCommandsFailEvent +from redis.multidb.circuit import State as CBState from redis.multidb.config import DEFAULT_AUTO_FALLBACK_INTERVAL from redis.multidb.database import Database, Databases, SyncDatabase -from redis.multidb.circuit import State as CBState from redis.multidb.event import ( - RegisterCommandFailure, ActiveDatabaseChanged, - ResubscribeOnActiveDatabaseChanged, CloseConnectionOnActiveDatabaseChanged, + RegisterCommandFailure, + ResubscribeOnActiveDatabaseChanged, ) from redis.multidb.failover import ( - FailoverStrategy, - FailoverStrategyExecutor, DEFAULT_FAILOVER_ATTEMPTS, DEFAULT_FAILOVER_DELAY, DefaultFailoverStrategyExecutor, + FailoverStrategy, + FailoverStrategyExecutor, ) from redis.multidb.failure_detector import FailureDetector from redis.retry import Retry diff --git a/redis/multidb/config.py b/redis/multidb/config.py index 9ee41e394f..4586263748 100644 --- a/redis/multidb/config.py +++ b/redis/multidb/config.py @@ -4,38 +4,37 @@ import pybreaker from typing_extensions import Optional -from redis import Redis, ConnectionPool -from redis import RedisCluster +from redis import ConnectionPool, Redis, RedisCluster from redis.backoff import ExponentialWithJitterBackoff, NoBackoff from redis.data_structure import WeightedList from redis.event import EventDispatcher, EventDispatcherInterface from redis.multidb.circuit import ( - PBCircuitBreakerAdapter, - CircuitBreaker, DEFAULT_GRACE_PERIOD, + CircuitBreaker, + PBCircuitBreakerAdapter, ) from redis.multidb.database import Database, Databases +from redis.multidb.failover import ( + DEFAULT_FAILOVER_ATTEMPTS, + DEFAULT_FAILOVER_DELAY, + FailoverStrategy, + WeightBasedFailoverStrategy, +) from redis.multidb.failure_detector import ( - FailureDetector, - CommandFailureDetector, - 
DEFAULT_MIN_NUM_FAILURES, - DEFAULT_FAILURES_DETECTION_WINDOW, DEFAULT_FAILURE_RATE_THRESHOLD, + DEFAULT_FAILURES_DETECTION_WINDOW, + DEFAULT_MIN_NUM_FAILURES, + CommandFailureDetector, + FailureDetector, ) from redis.multidb.healthcheck import ( - HealthCheck, - EchoHealthCheck, - DEFAULT_HEALTH_CHECK_PROBES, - DEFAULT_HEALTH_CHECK_INTERVAL, DEFAULT_HEALTH_CHECK_DELAY, - HealthCheckPolicies, + DEFAULT_HEALTH_CHECK_INTERVAL, DEFAULT_HEALTH_CHECK_POLICY, -) -from redis.multidb.failover import ( - FailoverStrategy, - WeightBasedFailoverStrategy, - DEFAULT_FAILOVER_ATTEMPTS, - DEFAULT_FAILOVER_DELAY, + DEFAULT_HEALTH_CHECK_PROBES, + EchoHealthCheck, + HealthCheck, + HealthCheckPolicies, ) from redis.retry import Retry diff --git a/redis/multidb/database.py b/redis/multidb/database.py index 8c7d536a88..d46de99e2d 100644 --- a/redis/multidb/database.py +++ b/redis/multidb/database.py @@ -1,8 +1,7 @@ -import redis from abc import ABC, abstractmethod -from enum import Enum -from typing import Union, Optional +from typing import Optional, Union +import redis from redis import RedisCluster from redis.data_structure import WeightedList from redis.multidb.circuit import CircuitBreaker diff --git a/redis/multidb/event.py b/redis/multidb/event.py index e9e9827344..0ffeb7f66e 100644 --- a/redis/multidb/event.py +++ b/redis/multidb/event.py @@ -1,7 +1,6 @@ from typing import List from redis.client import Redis - from redis.event import EventListenerInterface, OnCommandsFailEvent from redis.multidb.database import SyncDatabase from redis.multidb.failure_detector import FailureDetector diff --git a/redis/multidb/failover.py b/redis/multidb/failover.py index c373a3a6f0..c660eddbd3 100644 --- a/redis/multidb/failover.py +++ b/redis/multidb/failover.py @@ -2,8 +2,8 @@ from abc import ABC, abstractmethod from redis.data_structure import WeightedList -from redis.multidb.database import Databases, SyncDatabase from redis.multidb.circuit import State as CBState +from redis.multidb.database import Databases, SyncDatabase from redis.multidb.exception import ( NoValidDatabaseException, TemporaryUnavailableException, diff --git a/redis/multidb/healthcheck.py b/redis/multidb/healthcheck.py index 919d4d5cbd..5deda82f24 100644 --- a/redis/multidb/healthcheck.py +++ b/redis/multidb/healthcheck.py @@ -1,8 +1,8 @@ import logging -from abc import abstractmethod, ABC +from abc import ABC, abstractmethod from enum import Enum from time import sleep -from typing import Optional, Tuple, Union, List +from typing import List, Optional, Tuple, Union from redis import Redis from redis.backoff import NoBackoff @@ -83,7 +83,7 @@ def execute(self, health_checks: List[HealthCheck], database) -> bool: if not health_check.check_health(database): return False except Exception as e: - raise UnhealthyDatabaseException(f"Unhealthy database", database, e) + raise UnhealthyDatabaseException("Unhealthy database", database, e) if attempt < self.health_check_probes - 1: sleep(self._health_check_delay) @@ -115,7 +115,7 @@ def execute(self, health_checks: List[HealthCheck], database) -> bool: allowed_unsuccessful_probes -= 1 if allowed_unsuccessful_probes <= 0: raise UnhealthyDatabaseException( - f"Unhealthy database", database, e + "Unhealthy database", database, e ) if attempt < self.health_check_probes - 1: @@ -146,7 +146,7 @@ def execute(self, health_checks: List[HealthCheck], database) -> bool: is_healthy = False except Exception as e: exception = UnhealthyDatabaseException( - f"Unhealthy database", database, e + "Unhealthy database", 
database, e ) if attempt < self.health_check_probes - 1: diff --git a/redis/retry.py b/redis/retry.py index 3873cafab5..225e431eb2 100644 --- a/redis/retry.py +++ b/redis/retry.py @@ -7,10 +7,10 @@ Callable, Generic, Iterable, + Optional, Tuple, Type, TypeVar, - Optional, ) from redis.exceptions import ConnectionError, TimeoutError diff --git a/tests/test_asyncio/test_multidb/test_healthcheck.py b/tests/test_asyncio/test_multidb/test_healthcheck.py index 242446f3fb..623ca8e8fd 100644 --- a/tests/test_asyncio/test_multidb/test_healthcheck.py +++ b/tests/test_asyncio/test_multidb/test_healthcheck.py @@ -26,7 +26,7 @@ async def test_policy_returns_true_for_all_successful_probes(self): mock_db = Mock(spec=Database) policy = HealthyAllPolicy(3, 0.01) - assert await policy.execute([mock_hc1, mock_hc2], mock_db) == True + assert await policy.execute([mock_hc1, mock_hc2], mock_db) assert mock_hc1.check_health.call_count == 3 assert mock_hc2.check_health.call_count == 3 @@ -39,7 +39,7 @@ async def test_policy_returns_false_on_first_failed_probe(self): mock_db = Mock(spec=Database) policy = HealthyAllPolicy(3, 0.01) - assert await policy.execute([mock_hc1, mock_hc2], mock_db) == False + assert not await policy.execute([mock_hc1, mock_hc2], mock_db) assert mock_hc1.check_health.call_count == 3 assert mock_hc2.check_health.call_count == 0 @@ -215,7 +215,7 @@ async def test_database_is_healthy_on_echo_response(self, mock_client, mock_cb): hc = EchoHealthCheck() db = Database(mock_client, mock_cb, 0.9) - assert await hc.check_health(db) == True + assert await hc.check_health(db) assert mock_client.execute_command.call_count == 1 @pytest.mark.asyncio @@ -230,7 +230,7 @@ async def test_database_is_unhealthy_on_incorrect_echo_response( hc = EchoHealthCheck() db = Database(mock_client, mock_cb, 0.9) - assert await hc.check_health(db) == False + assert not await hc.check_health(db) assert mock_client.execute_command.call_count == 1 @pytest.mark.asyncio @@ -242,7 +242,7 @@ async def test_database_close_circuit_on_successful_healthcheck( hc = EchoHealthCheck() db = Database(mock_client, mock_cb, 0.9) - assert await hc.check_health(db) == True + assert await hc.check_health(db) assert mock_client.execute_command.call_count == 1 @@ -283,7 +283,7 @@ async def test_database_is_healthy_when_bdb_matches_by_dns_name( assert await hc.check_health(db) is True # Base URL must be set correctly assert ( - hc._http_client.client.base_url == f"https://healthcheck.example.com:1234" + hc._http_client.client.base_url == "https://healthcheck.example.com:1234" ) # Calls: first to list bdbs, then to availability assert mock_http.get.call_count == 2 diff --git a/tests/test_asyncio/test_multidb/test_pipeline.py b/tests/test_asyncio/test_multidb/test_pipeline.py index da3a6a1737..48990bc62a 100644 --- a/tests/test_asyncio/test_multidb/test_pipeline.py +++ b/tests/test_asyncio/test_multidb/test_pipeline.py @@ -8,7 +8,6 @@ from redis.asyncio.multidb.client import MultiDBClient from redis.asyncio.multidb.failover import WeightBasedFailoverStrategy from redis.asyncio.multidb.healthcheck import EchoHealthCheck -from redis.asyncio.retry import Retry from redis.multidb.circuit import State as CBState, PBCircuitBreakerAdapter from tests.test_asyncio.test_multidb.conftest import create_weighted_list diff --git a/tests/test_asyncio/test_scenario/conftest.py b/tests/test_asyncio/test_scenario/conftest.py index 67c9c829c3..803445f508 100644 --- a/tests/test_asyncio/test_scenario/conftest.py +++ b/tests/test_asyncio/test_scenario/conftest.py @@ 
-5,10 +5,9 @@ import pytest import pytest_asyncio -from redis.asyncio import Redis, RedisCluster +from redis.asyncio import Redis from redis.asyncio.multidb.client import MultiDBClient from redis.asyncio.multidb.config import ( - DEFAULT_HEALTH_CHECK_INTERVAL, DatabaseConfig, MultiDbConfig, ) diff --git a/tests/test_asyncio/test_scenario/test_active_active.py b/tests/test_asyncio/test_scenario/test_active_active.py index 208084daf9..c33e482050 100644 --- a/tests/test_asyncio/test_scenario/test_active_active.py +++ b/tests/test_asyncio/test_scenario/test_active_active.py @@ -2,13 +2,11 @@ import json import logging import os -from time import sleep import pytest from redis.asyncio import RedisCluster from redis.asyncio.client import Pipeline, Redis -from redis.asyncio.multidb.client import MultiDBClient from redis.asyncio.multidb.failover import ( DEFAULT_FAILOVER_ATTEMPTS, DEFAULT_FAILOVER_DELAY, diff --git a/tests/test_data_structure.py b/tests/test_data_structure.py index dd120d94d7..0911466e58 100644 --- a/tests/test_data_structure.py +++ b/tests/test_data_structure.py @@ -79,8 +79,8 @@ def worker(worker_id): try: length = len(wl) if length > 0: - top_items = wl.get_top_n(min(5, length)) - items_in_range = wl.get_by_weight_range(20, 80) + wl.get_top_n(min(5, length)) + wl.get_by_weight_range(20, 80) except Exception as e: print(f"Error in worker {worker_id}: {e}") diff --git a/tests/test_multidb/conftest.py b/tests/test_multidb/conftest.py index 27fa5475cd..f6ee6d3ec4 100644 --- a/tests/test_multidb/conftest.py +++ b/tests/test_multidb/conftest.py @@ -19,7 +19,6 @@ DEFAULT_HEALTH_CHECK_PROBES, DEFAULT_HEALTH_CHECK_POLICY, ) -from tests.conftest import mock_ed @pytest.fixture() diff --git a/tests/test_multidb/test_command_executor.py b/tests/test_multidb/test_command_executor.py index c27802cf09..10dbb5dfc7 100644 --- a/tests/test_multidb/test_command_executor.py +++ b/tests/test_multidb/test_command_executor.py @@ -1,5 +1,4 @@ from time import sleep -from unittest.mock import PropertyMock import pytest diff --git a/tests/test_multidb/test_healthcheck.py b/tests/test_multidb/test_healthcheck.py index 684d5452c7..d82388d13a 100644 --- a/tests/test_multidb/test_healthcheck.py +++ b/tests/test_multidb/test_healthcheck.py @@ -25,7 +25,7 @@ def test_policy_returns_true_for_all_successful_probes(self): mock_db = Mock(spec=Database) policy = HealthyAllPolicy(3, 0.01) - assert policy.execute([mock_hc1, mock_hc2], mock_db) == True + assert policy.execute([mock_hc1, mock_hc2], mock_db) assert mock_hc1.check_health.call_count == 3 assert mock_hc2.check_health.call_count == 3 @@ -37,7 +37,7 @@ def test_policy_returns_false_on_first_failed_probe(self): mock_db = Mock(spec=Database) policy = HealthyAllPolicy(3, 0.01) - assert policy.execute([mock_hc1, mock_hc2], mock_db) == False + assert not policy.execute([mock_hc1, mock_hc2], mock_db) assert mock_hc1.check_health.call_count == 3 assert mock_hc2.check_health.call_count == 0 @@ -207,7 +207,7 @@ def test_database_is_healthy_on_echo_response(self, mock_client, mock_cb): hc = EchoHealthCheck() db = Database(mock_client, mock_cb, 0.9) - assert hc.check_health(db) == True + assert hc.check_health(db) assert mock_client.execute_command.call_count == 1 def test_database_is_unhealthy_on_incorrect_echo_response( @@ -221,7 +221,7 @@ def test_database_is_unhealthy_on_incorrect_echo_response( hc = EchoHealthCheck() db = Database(mock_client, mock_cb, 0.9) - assert hc.check_health(db) == False + assert not hc.check_health(db) assert 
mock_client.execute_command.call_count == 1 def test_database_close_circuit_on_successful_healthcheck( @@ -232,7 +232,7 @@ def test_database_close_circuit_on_successful_healthcheck( hc = EchoHealthCheck() db = Database(mock_client, mock_cb, 0.9) - assert hc.check_health(db) == True + assert hc.check_health(db) assert mock_client.execute_command.call_count == 1 @@ -271,7 +271,7 @@ def test_database_is_healthy_when_bdb_matches_by_dns_name( assert hc.check_health(db) is True # Base URL must be set correctly - assert hc._http_client.base_url == f"https://healthcheck.example.com:1234" + assert hc._http_client.base_url == "https://healthcheck.example.com:1234" # Calls: first to list bdbs, then to availability assert mock_http.get.call_count == 2 first_call = mock_http.get.call_args_list[0] diff --git a/tests/test_multidb/test_pipeline.py b/tests/test_multidb/test_pipeline.py index 4afbb2db35..e8278de1a3 100644 --- a/tests/test_multidb/test_pipeline.py +++ b/tests/test_multidb/test_pipeline.py @@ -9,8 +9,6 @@ from redis.multidb.client import MultiDBClient from redis.multidb.failover import ( WeightBasedFailoverStrategy, - DEFAULT_FAILOVER_ATTEMPTS, - DEFAULT_FAILOVER_DELAY, ) from redis.multidb.healthcheck import EchoHealthCheck from tests.test_multidb.conftest import create_weighted_list diff --git a/tests/test_scenario/conftest.py b/tests/test_scenario/conftest.py index e39ef88045..5f568aa84e 100644 --- a/tests/test_scenario/conftest.py +++ b/tests/test_scenario/conftest.py @@ -7,7 +7,6 @@ import pytest -from redis import Redis from redis.backoff import NoBackoff, ExponentialBackoff from redis.event import EventDispatcher, EventListenerInterface from redis.multidb.client import MultiDBClient @@ -18,8 +17,8 @@ ) from redis.multidb.event import ActiveDatabaseChanged from redis.multidb.failure_detector import DEFAULT_MIN_NUM_FAILURES -from redis.multidb.healthcheck import EchoHealthCheck, DEFAULT_HEALTH_CHECK_DELAY -from redis.backoff import ExponentialWithJitterBackoff, NoBackoff +from redis.multidb.healthcheck import DEFAULT_HEALTH_CHECK_DELAY +from redis.backoff import ExponentialWithJitterBackoff from redis.client import Redis from redis.maint_notifications import EndpointType, MaintNotificationsConfig from redis.retry import Retry @@ -156,7 +155,6 @@ def extract_cluster_fqdn(url): # Extract hostname and port hostname = parsed.hostname - port = parsed.port # Remove the 'redis-XXXX.' 
prefix using regex # This pattern matches 'redis-' followed by digits and a dot From 18b5b52977f78c8be78a5ed3bc8523159e9b22e6 Mon Sep 17 00:00:00 2001 From: vladvildanov Date: Thu, 2 Oct 2025 11:02:22 +0300 Subject: [PATCH 31/50] Skip async scenario tests --- tasks.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tasks.py b/tasks.py index 20f9f245aa..d00441a5cf 100644 --- a/tasks.py +++ b/tasks.py @@ -58,11 +58,11 @@ def standalone_tests( if uvloop: run( - f"pytest {profile_arg} --protocol={protocol} {redis_mod_url} --ignore=tests/test_scenario --cov=./ --cov-report=xml:coverage_resp{protocol}_uvloop.xml -m 'not onlycluster{extra_markers}' --uvloop --junit-xml=standalone-resp{protocol}-uvloop-results.xml" + f"pytest {profile_arg} --protocol={protocol} {redis_mod_url} --ignore=tests/test_scenario --ignore=tests/test_asyncio/test_scenario --cov=./ --cov-report=xml:coverage_resp{protocol}_uvloop.xml -m 'not onlycluster{extra_markers}' --uvloop --junit-xml=standalone-resp{protocol}-uvloop-results.xml" ) else: run( - f"pytest {profile_arg} --protocol={protocol} {redis_mod_url} --ignore=tests/test_scenario --cov=./ --cov-report=xml:coverage_resp{protocol}.xml -m 'not onlycluster{extra_markers}' --junit-xml=standalone-resp{protocol}-results.xml" + f"pytest {profile_arg} --protocol={protocol} {redis_mod_url} --ignore=tests/test_scenario --ignore=tests/test_asyncio/test_scenario --cov=./ --cov-report=xml:coverage_resp{protocol}.xml -m 'not onlycluster{extra_markers}' --junit-xml=standalone-resp{protocol}-results.xml" ) From 2ad005a5e3bcd7e3af060af537c53cc5ed92a227 Mon Sep 17 00:00:00 2001 From: vladvildanov Date: Thu, 2 Oct 2025 11:06:26 +0300 Subject: [PATCH 32/50] Codestyle change --- tests/test_asyncio/test_multidb/test_healthcheck.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tests/test_asyncio/test_multidb/test_healthcheck.py b/tests/test_asyncio/test_multidb/test_healthcheck.py index 623ca8e8fd..7f17332197 100644 --- a/tests/test_asyncio/test_multidb/test_healthcheck.py +++ b/tests/test_asyncio/test_multidb/test_healthcheck.py @@ -282,9 +282,7 @@ async def test_database_is_healthy_when_bdb_matches_by_dns_name( assert await hc.check_health(db) is True # Base URL must be set correctly - assert ( - hc._http_client.client.base_url == "https://healthcheck.example.com:1234" - ) + assert hc._http_client.client.base_url == "https://healthcheck.example.com:1234" # Calls: first to list bdbs, then to availability assert mock_http.get.call_count == 2 first_call = mock_http.get.call_args_list[0] From 2672fceb3f32a0fbb1b70611d3482df6ccbb41b2 Mon Sep 17 00:00:00 2001 From: vladvildanov Date: Thu, 2 Oct 2025 11:16:16 +0300 Subject: [PATCH 33/50] Fixed unused arguments --- redis/client.py | 4 ++-- tests/test_asyncio/test_connection_pool.py | 3 +++ 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/redis/client.py b/redis/client.py index c8e2ecac72..a29d310742 100755 --- a/redis/client.py +++ b/redis/client.py @@ -636,7 +636,7 @@ def _send_command_parse_response(self, conn, command_name, *args, **options): conn.send_command(*args, **options) return self.parse_response(conn, command_name, **options) - def _close_connection(self, conn, error, *args) -> None: + def _close_connection(self, conn) -> None: """ Close the connection before retrying. 
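The hunk that follows updates the call site to match the simplified signature above: Retry.call_with_retry still invokes its failure callback with the raised error, but _close_connection no longer consumes it, so the lambda now discards the argument. A minimal sketch of that callback contract; the helper below is an illustrative stand-in, not redis-py's Retry implementation:

    from typing import Callable, TypeVar

    T = TypeVar("T")

    def call_with_retry(do: Callable[[], T], fail: Callable[[Exception], None]) -> T:
        # Illustrative stand-in: invoke the failure callback with the error
        # before the next attempt, and re-raise once attempts are exhausted.
        for attempt in range(3):
            try:
                return do()
            except Exception as error:
                fail(error)  # the callback receives the error; it may ignore it
                if attempt == 2:
                    raise

    # Mirrors the diff: the connection is captured by the closure, so the
    # error argument can simply be dropped with `lambda _: ...`:
    # call_with_retry(lambda: send_and_parse(conn), lambda _: close_connection(conn))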
@@ -666,7 +666,7 @@ def _execute_command(self, *args, **options): lambda: self._send_command_parse_response( conn, command_name, *args, **options ), - lambda error: self._close_connection(conn, error, *args), + lambda _: self._close_connection(conn), ) finally: diff --git a/tests/test_asyncio/test_connection_pool.py b/tests/test_asyncio/test_connection_pool.py index c30220fb1d..cb3dac9604 100644 --- a/tests/test_asyncio/test_connection_pool.py +++ b/tests/test_asyncio/test_connection_pool.py @@ -114,6 +114,9 @@ def set_re_auth_token(self, token: TokenInterface): async def re_auth(self): pass + def should_reconnect(self): + return False + class TestConnectionPool: @asynccontextmanager From be5d2e897eecc1ee9a83358e3226ace57543de9e Mon Sep 17 00:00:00 2001 From: vladvildanov Date: Thu, 2 Oct 2025 12:00:54 +0300 Subject: [PATCH 34/50] Refactored bg scheduler --- redis/asyncio/multidb/client.py | 6 +- redis/asyncio/multidb/command_executor.py | 10 ++- redis/background.py | 85 +++++++++++++++++++++-- redis/event.py | 10 ++- tests/test_event.py | 4 +- 5 files changed, 98 insertions(+), 17 deletions(-) diff --git a/redis/asyncio/multidb/client.py b/redis/asyncio/multidb/client.py index c972b4833a..6bea588196 100644 --- a/redis/asyncio/multidb/client.py +++ b/redis/asyncio/multidb/client.py @@ -2,7 +2,7 @@ import logging from typing import Any, Awaitable, Callable, Coroutine, List, Optional, Union -from redis.asyncio.client import PSWorkerThreadExcHandlerT, PubSubHandler +from redis.asyncio.client import PubSubHandler from redis.asyncio.multidb.command_executor import DefaultCommandExecutor from redis.asyncio.multidb.config import DEFAULT_GRACE_PERIOD, MultiDbConfig from redis.asyncio.multidb.database import AsyncDatabase, Databases @@ -507,7 +507,7 @@ async def get_message( async def run( self, *, - exception_handler: Optional["PSWorkerThreadExcHandlerT"] = None, + exception_handler=None, poll_timeout: float = 1.0, ) -> None: """Process pub/sub messages using registered callbacks. 
@@ -524,5 +524,5 @@ async def run( >>> await task """ return await self._client.command_executor.execute_pubsub_run( - exception_handler=exception_handler, sleep_time=poll_timeout, pubsub=self + sleep_time=poll_timeout, exception_handler=exception_handler, pubsub=self ) diff --git a/redis/asyncio/multidb/command_executor.py b/redis/asyncio/multidb/command_executor.py index daf9bf339c..c09b8b9969 100644 --- a/redis/asyncio/multidb/command_executor.py +++ b/redis/asyncio/multidb/command_executor.py @@ -266,9 +266,15 @@ async def callback(): return await self._execute_with_failure_detection(callback, *args) - async def execute_pubsub_run(self, sleep_time: float, **kwargs) -> Any: + async def execute_pubsub_run( + self, sleep_time: float, exception_handler=None, pubsub=None + ) -> Any: async def callback(): - return await self._active_pubsub.run(poll_timeout=sleep_time, **kwargs) + return await self._active_pubsub.run( + poll_timeout=sleep_time, + exception_handler=exception_handler, + pubsub=pubsub, + ) return await self._execute_with_failure_detection(callback) diff --git a/redis/background.py b/redis/background.py index 2f0e434967..7d0eead11e 100644 --- a/redis/background.py +++ b/redis/background.py @@ -10,17 +10,47 @@ class BackgroundScheduler: def __init__(self): self._next_timer = None + self._event_loops = [] + self._lock = threading.Lock() + self._stopped = False def __del__(self): - if self._next_timer: - self._next_timer.cancel() + self.stop() + + def stop(self): + """ + Stop all scheduled tasks and clean up resources. + """ + with self._lock: + if self._stopped: + return + self._stopped = True + + if self._next_timer: + self._next_timer.cancel() + self._next_timer = None + + # Stop all event loops + for loop in self._event_loops: + if loop.is_running(): + loop.call_soon_threadsafe(loop.stop) + + self._event_loops.clear() def run_once(self, delay: float, callback: Callable, *args): """ Runs callable task once after certain delay in seconds. """ + with self._lock: + if self._stopped: + return + # Run loop in a separate thread to unblock main thread. loop = asyncio.new_event_loop() + + with self._lock: + self._event_loops.append(loop) + thread = threading.Thread( target=_start_event_loop_in_thread, args=(loop, self._call_later, delay, callback, *args), @@ -32,9 +62,16 @@ def run_recurring(self, interval: float, callback: Callable, *args): """ Runs recurring callable task with given interval in seconds. """ + with self._lock: + if self._stopped: + return + # Run loop in a separate thread to unblock main thread. loop = asyncio.new_event_loop() + with self._lock: + self._event_loops.append(loop) + thread = threading.Thread( target=_start_event_loop_in_thread, args=(loop, self._call_later_recurring, interval, callback, *args), @@ -49,10 +86,17 @@ async def run_recurring_async( Runs recurring coroutine with given interval in seconds in the current event loop. To be used only from an async context. No additional threads are created. 
""" + with self._lock: + if self._stopped: + return + loop = asyncio.get_running_loop() wrapped = _async_to_sync_wrapper(loop, coro, *args) def tick(): + with self._lock: + if self._stopped: + return # Schedule the coroutine wrapped() # Schedule next tick @@ -64,6 +108,9 @@ def tick(): def _call_later( self, loop: asyncio.AbstractEventLoop, delay: float, callback: Callable, *args ): + with self._lock: + if self._stopped: + return self._next_timer = loop.call_later(delay, callback, *args) def _call_later_recurring( @@ -73,6 +120,9 @@ def _call_later_recurring( callback: Callable, *args, ): + with self._lock: + if self._stopped: + return self._call_later( loop, interval, self._execute_recurring, loop, interval, callback, *args ) @@ -87,7 +137,19 @@ def _execute_recurring( """ Executes recurring callable task with given interval in seconds. """ - callback(*args) + with self._lock: + if self._stopped: + return + + try: + callback(*args) + except Exception: + # Silently ignore exceptions during shutdown + pass + + with self._lock: + if self._stopped: + return self._call_later( loop, interval, self._execute_recurring, loop, interval, callback, *args @@ -106,7 +168,22 @@ def _start_event_loop_in_thread( """ asyncio.set_event_loop(event_loop) event_loop.call_soon(call_soon_cb, event_loop, *args) - event_loop.run_forever() + try: + event_loop.run_forever() + finally: + try: + # Clean up pending tasks + pending = asyncio.all_tasks(event_loop) + for task in pending: + task.cancel() + # Run loop once more to process cancellations + event_loop.run_until_complete( + asyncio.gather(*pending, return_exceptions=True) + ) + except Exception: + pass + finally: + event_loop.close() def _async_to_sync_wrapper(loop, coro_func, *args, **kwargs): diff --git a/redis/event.py b/redis/event.py index 84fabb40f5..cecd86293b 100644 --- a/redis/event.py +++ b/redis/event.py @@ -117,24 +117,22 @@ async def dispatch_async(self, event: object): def register_listeners( self, - event_listeners: Dict[ + mappings: Dict[ Type[object], List[Union[EventListenerInterface, AsyncEventListenerInterface]], ], ): with self._lock: - for event_type in event_listeners: + for event_type in mappings: if event_type in self._event_listeners_mapping: self._event_listeners_mapping[event_type] = list( set( self._event_listeners_mapping[event_type] - + event_listeners[event_type] + + mappings[event_type] ) ) else: - self._event_listeners_mapping[event_type] = event_listeners[ - event_type - ] + self._event_listeners_mapping[event_type] = mappings[event_type] class AfterConnectionReleasedEvent: diff --git a/tests/test_event.py b/tests/test_event.py index f090251295..0caab04e78 100644 --- a/tests/test_event.py +++ b/tests/test_event.py @@ -31,7 +31,7 @@ def callback(event): mock_another_event_listener = Mock(spec=EventListenerInterface) mock_another_event_listener.listen = callback dispatcher.register_listeners( - event_listeners={type(mock_event): [mock_another_event_listener]} + mappings={type(mock_event): [mock_another_event_listener]} ) dispatcher.dispatch(mock_event) @@ -60,7 +60,7 @@ async def callback(event): mock_another_event_listener = Mock(spec=AsyncEventListenerInterface) mock_another_event_listener.listen = callback dispatcher.register_listeners( - event_listeners={type(mock_event): [mock_another_event_listener]} + mappings={type(mock_event): [mock_another_event_listener]} ) await dispatcher.dispatch_async(mock_event) From bca7a31dcf068362ade92fc74ff7355233c84b52 Mon Sep 17 00:00:00 2001 From: vladvildanov Date: Thu, 2 Oct 2025 
14:54:36 +0300 Subject: [PATCH 35/50] Fixed tests --- redis/event.py | 5 ++++- tasks.py | 4 ++-- tests/test_maint_notifications_handling.py | 3 +++ 3 files changed, 9 insertions(+), 3 deletions(-) diff --git a/redis/event.py b/redis/event.py index cecd86293b..03c72c6370 100644 --- a/redis/event.py +++ b/redis/event.py @@ -96,7 +96,7 @@ def __init__( } self._lock = threading.Lock() - self._async_lock = asyncio.Lock() + self._async_lock = None if event_listeners: self.register_listeners(event_listeners) @@ -109,6 +109,9 @@ def dispatch(self, event: object): listener.listen(event) async def dispatch_async(self, event: object): + if self._async_lock is None: + self._async_lock = asyncio.Lock() + async with self._async_lock: listeners = self._event_listeners_mapping.get(type(event), []) diff --git a/tasks.py b/tasks.py index d00441a5cf..d63bd8c92d 100644 --- a/tasks.py +++ b/tasks.py @@ -74,11 +74,11 @@ def cluster_tests(c, uvloop=False, protocol=2, profile=False): cluster_tls_url = "rediss://localhost:27379/0" if uvloop: run( - f"pytest {profile_arg} --protocol={protocol} --ignore=tests/test_scenario --cov=./ --cov-report=xml:coverage_cluster_resp{protocol}_uvloop.xml -m 'not onlynoncluster and not redismod' --redis-url={cluster_url} --redis-ssl-url={cluster_tls_url} --junit-xml=cluster-resp{protocol}-uvloop-results.xml --uvloop" + f"pytest {profile_arg} --protocol={protocol} --ignore=tests/test_scenario --ignore=tests/test_asyncio/test_scenario --cov=./ --cov-report=xml:coverage_cluster_resp{protocol}_uvloop.xml -m 'not onlynoncluster and not redismod' --redis-url={cluster_url} --redis-ssl-url={cluster_tls_url} --junit-xml=cluster-resp{protocol}-uvloop-results.xml --uvloop" ) else: run( - f"pytest {profile_arg} --protocol={protocol} --ignore=tests/test_scenario --cov=./ --cov-report=xml:coverage_cluster_resp{protocol}.xml -m 'not onlynoncluster and not redismod' --redis-url={cluster_url} --redis-ssl-url={cluster_tls_url} --junit-xml=cluster-resp{protocol}-results.xml" + f"pytest {profile_arg} --protocol={protocol} --ignore=tests/test_scenario --ignore=tests/test_asyncio/test_scenario --cov=./ --cov-report=xml:coverage_cluster_resp{protocol}.xml -m 'not onlynoncluster and not redismod' --redis-url={cluster_url} --redis-ssl-url={cluster_tls_url} --junit-xml=cluster-resp{protocol}-results.xml" ) diff --git a/tests/test_maint_notifications_handling.py b/tests/test_maint_notifications_handling.py index baa7d601fa..54b6e2dff7 100644 --- a/tests/test_maint_notifications_handling.py +++ b/tests/test_maint_notifications_handling.py @@ -330,6 +330,9 @@ def setsockopt(self, level, optname, value): """Simulate setting socket options.""" pass + def setblocking(self, blocking): + pass + def getpeername(self): """Simulate getting peer name.""" return self.address From d9ad720211d50e9387e6e58507a5555b579168f1 Mon Sep 17 00:00:00 2001 From: vladvildanov Date: Thu, 2 Oct 2025 15:34:27 +0300 Subject: [PATCH 36/50] Fixed tests --- redis/asyncio/cluster.py | 2 +- redis/cluster.py | 2 +- .../test_asyncio/test_multidb/test_client.py | 47 ++++++++-------- tests/test_multidb/test_client.py | 55 +++++++++---------- 4 files changed, 50 insertions(+), 56 deletions(-) diff --git a/redis/asyncio/cluster.py b/redis/asyncio/cluster.py index d05de07d18..225fd3b79f 100644 --- a/redis/asyncio/cluster.py +++ b/redis/asyncio/cluster.py @@ -2254,7 +2254,7 @@ async def _reinitialize_on_error(self, error): await self._pipe.cluster_client.nodes_manager.initialize() self.reinitialize_counter = 0 else: - if type(error) is MovedError: 
+ if isinstance(error, AskError): self._pipe.cluster_client.nodes_manager.update_moved_exception( error ) diff --git a/redis/cluster.py b/redis/cluster.py index 41b188ef74..1d4a3e0d0c 100644 --- a/redis/cluster.py +++ b/redis/cluster.py @@ -3165,7 +3165,7 @@ def _reinitialize_on_error(self, error): self._nodes_manager.initialize() self.reinitialize_counter = 0 else: - if type(error) is MovedError: + if isinstance(error, AskError): self._nodes_manager.update_moved_exception(error) self._executing = False diff --git a/tests/test_asyncio/test_multidb/test_client.py b/tests/test_asyncio/test_multidb/test_client.py index 76bee8b3e6..bd4199a315 100644 --- a/tests/test_asyncio/test_multidb/test_client.py +++ b/tests/test_asyncio/test_multidb/test_client.py @@ -184,41 +184,38 @@ async def test_execute_command_against_correct_db_on_background_health_check_det indirect=True, ) async def test_execute_command_auto_fallback_to_highest_weight_db( - self, mock_multi_db_config, mock_db, mock_db1, mock_db2 + self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc ): databases = create_weighted_list(mock_db, mock_db1, mock_db2) + mock_hc.check_health.side_effect = [ + True, + True, + True, + False, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True + ] with ( patch.object(mock_multi_db_config, "databases", return_value=databases), patch.object( mock_multi_db_config, "default_health_checks", - return_value=[EchoHealthCheck()], + return_value=[mock_hc], ), ): - mock_db.client.execute_command.side_effect = [ - "healthcheck", - "healthcheck", - "healthcheck", - "healthcheck", - "healthcheck", - ] - mock_db1.client.execute_command.side_effect = [ - "healthcheck", - "OK1", - "error", - "healthcheck", - "healthcheck", - "OK1", - ] - mock_db2.client.execute_command.side_effect = [ - "healthcheck", - "healthcheck", - "OK2", - "healthcheck", - "healthcheck", - "healthcheck", - ] + mock_db.client.execute_command.return_value = 'OK' + mock_db1.client.execute_command.return_value = 'OK1' + mock_db2.client.execute_command.return_value = 'OK2' mock_multi_db_config.health_check_interval = 0.1 mock_multi_db_config.auto_fallback_interval = 0.2 mock_multi_db_config.failover_strategy = WeightBasedFailoverStrategy() diff --git a/tests/test_multidb/test_client.py b/tests/test_multidb/test_client.py index dab80f2ba4..34f6cbaeaa 100644 --- a/tests/test_multidb/test_client.py +++ b/tests/test_multidb/test_client.py @@ -180,50 +180,47 @@ def test_execute_command_against_correct_db_on_background_health_check_determine indirect=True, ) def test_execute_command_auto_fallback_to_highest_weight_db( - self, mock_multi_db_config, mock_db, mock_db1, mock_db2 + self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc ): databases = create_weighted_list(mock_db, mock_db1, mock_db2) + mock_hc.check_health.side_effect = [ + True, + True, + True, + False, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True + ] with ( patch.object(mock_multi_db_config, "databases", return_value=databases), patch.object( mock_multi_db_config, "default_health_checks", - return_value=[EchoHealthCheck()], + return_value=[mock_hc], ), ): - mock_db.client.execute_command.side_effect = [ - "healthcheck", - "healthcheck", - "healthcheck", - "healthcheck", - "healthcheck", - ] - mock_db1.client.execute_command.side_effect = [ - "healthcheck", - "OK1", - "error", - "healthcheck", - "healthcheck", - "OK1", - ] - mock_db2.client.execute_command.side_effect = [ - "healthcheck", - 
"healthcheck", - "OK2", - "healthcheck", - "healthcheck", - "healthcheck", - ] - mock_multi_db_config.health_check_interval = 0.2 - mock_multi_db_config.auto_fallback_interval = 0.4 + mock_db.client.execute_command.return_value = 'OK' + mock_db1.client.execute_command.return_value = 'OK1' + mock_db2.client.execute_command.return_value = 'OK2' + mock_multi_db_config.health_check_interval = 0.1 + mock_multi_db_config.auto_fallback_interval = 0.2 mock_multi_db_config.failover_strategy = WeightBasedFailoverStrategy() client = MultiDBClient(mock_multi_db_config) assert client.set("key", "value") == "OK1" - sleep(0.30) + sleep(0.15) assert client.set("key", "value") == "OK2" - sleep(0.44) + sleep(0.22) assert client.set("key", "value") == "OK1" @pytest.mark.parametrize( From 031b705b2021633fe5f1ced59684537a201dd777 Mon Sep 17 00:00:00 2001 From: vladvildanov Date: Thu, 2 Oct 2025 15:35:43 +0300 Subject: [PATCH 37/50] Codestyle fixes --- tests/test_asyncio/test_multidb/test_client.py | 8 ++++---- tests/test_multidb/test_client.py | 8 ++++---- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/tests/test_asyncio/test_multidb/test_client.py b/tests/test_asyncio/test_multidb/test_client.py index bd4199a315..1174384bf9 100644 --- a/tests/test_asyncio/test_multidb/test_client.py +++ b/tests/test_asyncio/test_multidb/test_client.py @@ -202,7 +202,7 @@ async def test_execute_command_auto_fallback_to_highest_weight_db( True, True, True, - True + True, ] with ( @@ -213,9 +213,9 @@ async def test_execute_command_auto_fallback_to_highest_weight_db( return_value=[mock_hc], ), ): - mock_db.client.execute_command.return_value = 'OK' - mock_db1.client.execute_command.return_value = 'OK1' - mock_db2.client.execute_command.return_value = 'OK2' + mock_db.client.execute_command.return_value = "OK" + mock_db1.client.execute_command.return_value = "OK1" + mock_db2.client.execute_command.return_value = "OK2" mock_multi_db_config.health_check_interval = 0.1 mock_multi_db_config.auto_fallback_interval = 0.2 mock_multi_db_config.failover_strategy = WeightBasedFailoverStrategy() diff --git a/tests/test_multidb/test_client.py b/tests/test_multidb/test_client.py index 34f6cbaeaa..908c08a359 100644 --- a/tests/test_multidb/test_client.py +++ b/tests/test_multidb/test_client.py @@ -198,7 +198,7 @@ def test_execute_command_auto_fallback_to_highest_weight_db( True, True, True, - True + True, ] with ( @@ -209,9 +209,9 @@ def test_execute_command_auto_fallback_to_highest_weight_db( return_value=[mock_hc], ), ): - mock_db.client.execute_command.return_value = 'OK' - mock_db1.client.execute_command.return_value = 'OK1' - mock_db2.client.execute_command.return_value = 'OK2' + mock_db.client.execute_command.return_value = "OK" + mock_db1.client.execute_command.return_value = "OK1" + mock_db2.client.execute_command.return_value = "OK2" mock_multi_db_config.health_check_interval = 0.1 mock_multi_db_config.auto_fallback_interval = 0.2 mock_multi_db_config.failover_strategy = WeightBasedFailoverStrategy() From e3bfa7890b658d6daf3e83afe12a257a3dc85499 Mon Sep 17 00:00:00 2001 From: vladvildanov Date: Thu, 2 Oct 2025 15:55:56 +0300 Subject: [PATCH 38/50] Reduce timeouts to avoid overlaping with healthcheck --- tests/test_asyncio/test_multidb/test_client.py | 2 +- tests/test_multidb/test_client.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_asyncio/test_multidb/test_client.py b/tests/test_asyncio/test_multidb/test_client.py index 1174384bf9..39d75f1a56 100644 --- 
a/tests/test_asyncio/test_multidb/test_client.py +++ b/tests/test_asyncio/test_multidb/test_client.py @@ -222,7 +222,7 @@ async def test_execute_command_auto_fallback_to_highest_weight_db( client = MultiDBClient(mock_multi_db_config) assert await client.set("key", "value") == "OK1" - await asyncio.sleep(0.15) + await asyncio.sleep(0.12) assert await client.set("key", "value") == "OK2" await asyncio.sleep(0.22) assert await client.set("key", "value") == "OK1" diff --git a/tests/test_multidb/test_client.py b/tests/test_multidb/test_client.py index 908c08a359..62e4233b92 100644 --- a/tests/test_multidb/test_client.py +++ b/tests/test_multidb/test_client.py @@ -218,7 +218,7 @@ def test_execute_command_auto_fallback_to_highest_weight_db( client = MultiDBClient(mock_multi_db_config) assert client.set("key", "value") == "OK1" - sleep(0.15) + sleep(0.12) assert client.set("key", "value") == "OK2" sleep(0.22) assert client.set("key", "value") == "OK1" From 5d5ff26bbb0069be678dee3ecfcfdbddc8505da6 Mon Sep 17 00:00:00 2001 From: vladvildanov Date: Thu, 2 Oct 2025 16:48:07 +0300 Subject: [PATCH 39/50] Marked tests non-clsuter only --- tests/test_asyncio/test_multidb/test_client.py | 1 + tests/test_asyncio/test_multidb/test_command_executor.py | 1 + tests/test_asyncio/test_multidb/test_config.py | 4 ++++ tests/test_asyncio/test_multidb/test_failover.py | 1 + tests/test_asyncio/test_multidb/test_failure_detector.py | 1 + tests/test_asyncio/test_multidb/test_healthcheck.py | 5 +++++ tests/test_multidb/test_circuit.py | 1 + tests/test_multidb/test_client.py | 1 + tests/test_multidb/test_command_executor.py | 1 + tests/test_multidb/test_config.py | 4 ++++ tests/test_multidb/test_failover.py | 2 ++ tests/test_multidb/test_failure_detector.py | 1 + tests/test_multidb/test_healthcheck.py | 5 +++++ tests/test_multidb/test_pipeline.py | 1 + 14 files changed, 29 insertions(+) diff --git a/tests/test_asyncio/test_multidb/test_client.py b/tests/test_asyncio/test_multidb/test_client.py index 39d75f1a56..0eb8e1fb7f 100644 --- a/tests/test_asyncio/test_multidb/test_client.py +++ b/tests/test_asyncio/test_multidb/test_client.py @@ -15,6 +15,7 @@ from tests.test_asyncio.test_multidb.conftest import create_weighted_list +@pytest.mark.onlynoncluster class TestMultiDbClient: @pytest.mark.asyncio @pytest.mark.parametrize( diff --git a/tests/test_asyncio/test_multidb/test_command_executor.py b/tests/test_asyncio/test_multidb/test_command_executor.py index b104b90e85..e0ac80a56a 100644 --- a/tests/test_asyncio/test_multidb/test_command_executor.py +++ b/tests/test_asyncio/test_multidb/test_command_executor.py @@ -14,6 +14,7 @@ from tests.test_asyncio.test_multidb.conftest import create_weighted_list +@pytest.mark.onlynoncluster class TestDefaultCommandExecutor: @pytest.mark.asyncio @pytest.mark.parametrize( diff --git a/tests/test_asyncio/test_multidb/test_config.py b/tests/test_asyncio/test_multidb/test_config.py index 76ccd29a06..d05c7a8a12 100644 --- a/tests/test_asyncio/test_multidb/test_config.py +++ b/tests/test_asyncio/test_multidb/test_config.py @@ -1,5 +1,7 @@ from unittest.mock import Mock +import pytest + from redis.asyncio import ConnectionPool from redis.asyncio.multidb.config import ( DatabaseConfig, @@ -22,6 +24,7 @@ from redis.multidb.circuit import CircuitBreaker +@pytest.mark.onlynoncluster class TestMultiDbConfig: def test_default_config(self): db_configs = [ @@ -137,6 +140,7 @@ def test_overridden_config(self): assert config.auto_fallback_interval == auto_fallback_interval +@pytest.mark.onlynoncluster 
class TestDatabaseConfig: def test_default_config(self): config = DatabaseConfig( diff --git a/tests/test_asyncio/test_multidb/test_failover.py b/tests/test_asyncio/test_multidb/test_failover.py index 22d27f6369..a34bb368c8 100644 --- a/tests/test_asyncio/test_multidb/test_failover.py +++ b/tests/test_asyncio/test_multidb/test_failover.py @@ -14,6 +14,7 @@ ) +@pytest.mark.onlynoncluster class TestAsyncWeightBasedFailoverStrategy: @pytest.mark.asyncio @pytest.mark.parametrize( diff --git a/tests/test_asyncio/test_multidb/test_failure_detector.py b/tests/test_asyncio/test_multidb/test_failure_detector.py index 0d18c9137c..279bda9605 100644 --- a/tests/test_asyncio/test_multidb/test_failure_detector.py +++ b/tests/test_asyncio/test_multidb/test_failure_detector.py @@ -10,6 +10,7 @@ from redis.multidb.failure_detector import CommandFailureDetector +@pytest.mark.onlynoncluster class TestFailureDetectorAsyncWrapper: @pytest.mark.asyncio @pytest.mark.parametrize( diff --git a/tests/test_asyncio/test_multidb/test_healthcheck.py b/tests/test_asyncio/test_multidb/test_healthcheck.py index 7f17332197..3e7ac42cd9 100644 --- a/tests/test_asyncio/test_multidb/test_healthcheck.py +++ b/tests/test_asyncio/test_multidb/test_healthcheck.py @@ -16,6 +16,7 @@ from redis.multidb.exception import UnhealthyDatabaseException +@pytest.mark.onlynoncluster class TestHealthyAllPolicy: @pytest.mark.asyncio async def test_policy_returns_true_for_all_successful_probes(self): @@ -58,6 +59,7 @@ async def test_policy_raise_unhealthy_database_exception(self): assert mock_hc2.check_health.call_count == 0 +@pytest.mark.onlynoncluster class TestHealthyMajorityPolicy: @pytest.mark.asyncio @pytest.mark.parametrize( @@ -151,6 +153,7 @@ async def test_policy_raise_unhealthy_database_exception_on_majority_probes_exce assert mock_hc2.check_health.call_count == hc2_call_count +@pytest.mark.onlynoncluster class TestHealthyAnyPolicy: @pytest.mark.asyncio @pytest.mark.parametrize( @@ -204,6 +207,7 @@ async def test_policy_raise_unhealthy_database_exception_if_exception_occurs_on_ assert mock_hc2.check_health.call_count == 0 +@pytest.mark.onlynoncluster class TestEchoHealthCheck: @pytest.mark.asyncio async def test_database_is_healthy_on_echo_response(self, mock_client, mock_cb): @@ -246,6 +250,7 @@ async def test_database_close_circuit_on_successful_healthcheck( assert mock_client.execute_command.call_count == 1 +@pytest.mark.onlynoncluster class TestLagAwareHealthCheck: @pytest.mark.asyncio async def test_database_is_healthy_when_bdb_matches_by_dns_name( diff --git a/tests/test_multidb/test_circuit.py b/tests/test_multidb/test_circuit.py index 9bf221ec52..7d0f2cb700 100644 --- a/tests/test_multidb/test_circuit.py +++ b/tests/test_multidb/test_circuit.py @@ -8,6 +8,7 @@ ) +@pytest.mark.onlynoncluster class TestPBCircuitBreaker: @pytest.mark.parametrize( "mock_db", diff --git a/tests/test_multidb/test_client.py b/tests/test_multidb/test_client.py index 62e4233b92..657509fda0 100644 --- a/tests/test_multidb/test_client.py +++ b/tests/test_multidb/test_client.py @@ -15,6 +15,7 @@ from tests.test_multidb.conftest import create_weighted_list +@pytest.mark.onlynoncluster class TestMultiDbClient: @pytest.mark.parametrize( "mock_multi_db_config,mock_db, mock_db1, mock_db2", diff --git a/tests/test_multidb/test_command_executor.py b/tests/test_multidb/test_command_executor.py index 10dbb5dfc7..43e5f47344 100644 --- a/tests/test_multidb/test_command_executor.py +++ b/tests/test_multidb/test_command_executor.py @@ -12,6 +12,7 @@ from 
tests.test_multidb.conftest import create_weighted_list +@pytest.mark.onlynoncluster class TestDefaultCommandExecutor: @pytest.mark.parametrize( "mock_db,mock_db1,mock_db2", diff --git a/tests/test_multidb/test_config.py b/tests/test_multidb/test_config.py index a63ac5b7c1..ea81f71ac9 100644 --- a/tests/test_multidb/test_config.py +++ b/tests/test_multidb/test_config.py @@ -1,4 +1,7 @@ from unittest.mock import Mock + +import pytest + from redis.connection import ConnectionPool from redis.multidb.circuit import ( PBCircuitBreakerAdapter, @@ -18,6 +21,7 @@ from redis.retry import Retry +@pytest.mark.onlynoncluster class TestMultiDbConfig: def test_default_config(self): db_configs = [ diff --git a/tests/test_multidb/test_failover.py b/tests/test_multidb/test_failover.py index 1641c0ee63..60b231ab40 100644 --- a/tests/test_multidb/test_failover.py +++ b/tests/test_multidb/test_failover.py @@ -14,6 +14,7 @@ ) +@pytest.mark.onlynoncluster class TestWeightBasedFailoverStrategy: @pytest.mark.parametrize( "mock_db,mock_db1,mock_db2", @@ -64,6 +65,7 @@ def test_throws_exception_on_empty_databases(self, mock_db, mock_db1, mock_db2): assert failover_strategy.database() +@pytest.mark.onlynoncluster class TestDefaultStrategyExecutor: @pytest.mark.parametrize( "mock_db", diff --git a/tests/test_multidb/test_failure_detector.py b/tests/test_multidb/test_failure_detector.py index f77a9c5d5d..b64ff601d2 100644 --- a/tests/test_multidb/test_failure_detector.py +++ b/tests/test_multidb/test_failure_detector.py @@ -10,6 +10,7 @@ from redis.exceptions import ConnectionError +@pytest.mark.onlynoncluster class TestCommandFailureDetector: @pytest.mark.parametrize( "min_num_failures,failure_rate_threshold,circuit_state", diff --git a/tests/test_multidb/test_healthcheck.py b/tests/test_multidb/test_healthcheck.py index d82388d13a..fb1f1e4148 100644 --- a/tests/test_multidb/test_healthcheck.py +++ b/tests/test_multidb/test_healthcheck.py @@ -16,6 +16,7 @@ from redis.multidb.circuit import State as CBState +@pytest.mark.onlynoncluster class TestHealthyAllPolicy: def test_policy_returns_true_for_all_successful_probes(self): mock_hc1 = Mock(spec=HealthCheck) @@ -55,6 +56,7 @@ def test_policy_raise_unhealthy_database_exception(self): assert mock_hc2.check_health.call_count == 0 +@pytest.mark.onlynoncluster class TestHealthyMajorityPolicy: @pytest.mark.parametrize( "probes,hc1_side_effect,hc2_side_effect,hc1_call_count,hc2_call_count,expected_result", @@ -146,6 +148,7 @@ def test_policy_raise_unhealthy_database_exception_on_majority_probes_exceptions assert mock_hc2.check_health.call_count == hc2_call_count +@pytest.mark.onlynoncluster class TestHealthyAnyPolicy: @pytest.mark.parametrize( "hc1_side_effect,hc2_side_effect,hc1_call_count,hc2_call_count,expected_result", @@ -197,6 +200,7 @@ def test_policy_raise_unhealthy_database_exception_if_exception_occurs_on_failed assert mock_hc2.check_health.call_count == 0 +@pytest.mark.onlynoncluster class TestEchoHealthCheck: def test_database_is_healthy_on_echo_response(self, mock_client, mock_cb): """ @@ -236,6 +240,7 @@ def test_database_close_circuit_on_successful_healthcheck( assert mock_client.execute_command.call_count == 1 +@pytest.mark.onlynoncluster class TestLagAwareHealthCheck: def test_database_is_healthy_when_bdb_matches_by_dns_name( self, mock_client, mock_cb diff --git a/tests/test_multidb/test_pipeline.py b/tests/test_multidb/test_pipeline.py index e8278de1a3..c3a494dd95 100644 --- a/tests/test_multidb/test_pipeline.py +++ b/tests/test_multidb/test_pipeline.py @@ 
-21,6 +21,7 @@ def mock_pipe() -> Pipeline: return mock_pipe +@pytest.mark.onlynoncluster class TestPipeline: @pytest.mark.parametrize( "mock_multi_db_config,mock_db, mock_db1, mock_db2", From cb51f8a904ae81b21d91299b23a786ad962680c1 Mon Sep 17 00:00:00 2001 From: vladvildanov Date: Thu, 2 Oct 2025 16:49:46 +0300 Subject: [PATCH 40/50] Update timeouts --- tests/test_asyncio/test_multidb/test_client.py | 2 +- tests/test_multidb/test_client.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_asyncio/test_multidb/test_client.py b/tests/test_asyncio/test_multidb/test_client.py index 0eb8e1fb7f..a9252fca54 100644 --- a/tests/test_asyncio/test_multidb/test_client.py +++ b/tests/test_asyncio/test_multidb/test_client.py @@ -223,7 +223,7 @@ async def test_execute_command_auto_fallback_to_highest_weight_db( client = MultiDBClient(mock_multi_db_config) assert await client.set("key", "value") == "OK1" - await asyncio.sleep(0.12) + await asyncio.sleep(0.13) assert await client.set("key", "value") == "OK2" await asyncio.sleep(0.22) assert await client.set("key", "value") == "OK1" diff --git a/tests/test_multidb/test_client.py b/tests/test_multidb/test_client.py index 657509fda0..7d2c363136 100644 --- a/tests/test_multidb/test_client.py +++ b/tests/test_multidb/test_client.py @@ -219,7 +219,7 @@ def test_execute_command_auto_fallback_to_highest_weight_db( client = MultiDBClient(mock_multi_db_config) assert client.set("key", "value") == "OK1" - sleep(0.12) + sleep(0.13) assert client.set("key", "value") == "OK2" sleep(0.22) assert client.set("key", "value") == "OK1" From 7f7ea768e697eb2216487ffe75e6a764c7cf41f2 Mon Sep 17 00:00:00 2001 From: vladvildanov Date: Thu, 2 Oct 2025 17:45:37 +0300 Subject: [PATCH 41/50] Skip scenario tests --- .github/workflows/install_and_test.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/install_and_test.sh b/.github/workflows/install_and_test.sh index c90027389c..85cb07cb8a 100755 --- a/.github/workflows/install_and_test.sh +++ b/.github/workflows/install_and_test.sh @@ -40,7 +40,7 @@ cd ${TESTDIR} # install, run tests pip install ${PKG} # Redis tests -pytest -m 'not onlycluster' --ignore=tests/test_scenario +pytest -m 'not onlycluster' --ignore=tests/test_scenario --ignore=tests/test_asyncio/test_scenario # RedisCluster tests CLUSTER_URL="redis://localhost:16379/0" CLUSTER_SSL_URL="rediss://localhost:27379/0" From 933b3ad2bd027e78058a22bd3bc740679a05caa2 Mon Sep 17 00:00:00 2001 From: vladvildanov Date: Thu, 2 Oct 2025 17:59:53 +0300 Subject: [PATCH 42/50] Updated timeouts --- tests/test_asyncio/test_multidb/test_client.py | 12 +++++++++--- tests/test_multidb/test_client.py | 12 +++++++++--- 2 files changed, 18 insertions(+), 6 deletions(-) diff --git a/tests/test_asyncio/test_multidb/test_client.py b/tests/test_asyncio/test_multidb/test_client.py index a9252fca54..230a01d64d 100644 --- a/tests/test_asyncio/test_multidb/test_client.py +++ b/tests/test_asyncio/test_multidb/test_client.py @@ -204,6 +204,12 @@ async def test_execute_command_auto_fallback_to_highest_weight_db( True, True, True, + True, + True, + True, + True, + True, + True, ] with ( @@ -218,14 +224,14 @@ async def test_execute_command_auto_fallback_to_highest_weight_db( mock_db1.client.execute_command.return_value = "OK1" mock_db2.client.execute_command.return_value = "OK2" mock_multi_db_config.health_check_interval = 0.1 - mock_multi_db_config.auto_fallback_interval = 0.2 + mock_multi_db_config.auto_fallback_interval = 0.5 
mock_multi_db_config.failover_strategy = WeightBasedFailoverStrategy() client = MultiDBClient(mock_multi_db_config) assert await client.set("key", "value") == "OK1" - await asyncio.sleep(0.13) + await asyncio.sleep(0.15) assert await client.set("key", "value") == "OK2" - await asyncio.sleep(0.22) + await asyncio.sleep(0.5) assert await client.set("key", "value") == "OK1" @pytest.mark.asyncio diff --git a/tests/test_multidb/test_client.py b/tests/test_multidb/test_client.py index 7d2c363136..eb8f5b374a 100644 --- a/tests/test_multidb/test_client.py +++ b/tests/test_multidb/test_client.py @@ -200,6 +200,12 @@ def test_execute_command_auto_fallback_to_highest_weight_db( True, True, True, + True, + True, + True, + True, + True, + True, ] with ( @@ -214,14 +220,14 @@ def test_execute_command_auto_fallback_to_highest_weight_db( mock_db1.client.execute_command.return_value = "OK1" mock_db2.client.execute_command.return_value = "OK2" mock_multi_db_config.health_check_interval = 0.1 - mock_multi_db_config.auto_fallback_interval = 0.2 + mock_multi_db_config.auto_fallback_interval = 0.5 mock_multi_db_config.failover_strategy = WeightBasedFailoverStrategy() client = MultiDBClient(mock_multi_db_config) assert client.set("key", "value") == "OK1" - sleep(0.13) + sleep(0.15) assert client.set("key", "value") == "OK2" - sleep(0.22) + sleep(0.5) assert client.set("key", "value") == "OK1" @pytest.mark.parametrize( From cdef58b3b7503777de5fb3962ed1970b76ccef93 Mon Sep 17 00:00:00 2001 From: vladvildanov Date: Fri, 3 Oct 2025 10:22:21 +0300 Subject: [PATCH 43/50] Increased timeout --- tests/test_multidb/test_client.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_multidb/test_client.py b/tests/test_multidb/test_client.py index eb8f5b374a..5ea2193895 100644 --- a/tests/test_multidb/test_client.py +++ b/tests/test_multidb/test_client.py @@ -225,7 +225,7 @@ def test_execute_command_auto_fallback_to_highest_weight_db( client = MultiDBClient(mock_multi_db_config) assert client.set("key", "value") == "OK1" - sleep(0.15) + sleep(0.18) assert client.set("key", "value") == "OK2" sleep(0.5) assert client.set("key", "value") == "OK1" From 77f7e4f40fc4580d2f43bf641d45cccf6061505c Mon Sep 17 00:00:00 2001 From: vladvildanov Date: Tue, 7 Oct 2025 11:31:28 +0300 Subject: [PATCH 44/50] Refactored tests --- redis/multidb/client.py | 2 +- redis/multidb/config.py | 4 +- .../test_asyncio/test_multidb/test_client.py | 55 +++++++++---------- tests/test_multidb/conftest.py | 2 +- tests/test_multidb/test_client.py | 51 +++++++++-------- tests/test_scenario/conftest.py | 2 +- 6 files changed, 55 insertions(+), 61 deletions(-) diff --git a/redis/multidb/client.py b/redis/multidb/client.py index 485174fc03..02c3516eee 100644 --- a/redis/multidb/client.py +++ b/redis/multidb/client.py @@ -34,7 +34,7 @@ def __init__(self, config: MultiDbConfig): self._health_check_interval = config.health_check_interval self._health_check_policy: HealthCheckPolicy = config.health_check_policy.value( - config.health_check_probes, config.health_check_delay + config.health_check_probes, config.health_check_probes_delay ) self._failure_detectors = config.default_failure_detectors() diff --git a/redis/multidb/config.py b/redis/multidb/config.py index 4586263748..8ce960c2dc 100644 --- a/redis/multidb/config.py +++ b/redis/multidb/config.py @@ -100,7 +100,7 @@ class MultiDbConfig: health_checks: Optional list of additional health checks performed on databases. health_check_interval: Time interval for executing health checks. 
health_check_probes: Number of attempts to evaluate the health of a database. - health_check_delay: Delay between health check attempts. + health_check_probes_delay: Delay between health check attempts. health_check_policy: Policy for determining database health based on health checks. failover_strategy: Optional strategy for handling database failover scenarios. failover_attempts: Number of retries allowed for failover operations. @@ -138,7 +138,7 @@ class MultiDbConfig: health_checks: Optional[List[HealthCheck]] = None health_check_interval: float = DEFAULT_HEALTH_CHECK_INTERVAL health_check_probes: int = DEFAULT_HEALTH_CHECK_PROBES - health_check_delay: float = DEFAULT_HEALTH_CHECK_DELAY + health_check_probes_delay: float = DEFAULT_HEALTH_CHECK_DELAY health_check_policy: HealthCheckPolicies = DEFAULT_HEALTH_CHECK_POLICY failover_strategy: Optional[FailoverStrategy] = None failover_attempts: int = DEFAULT_FAILOVER_ATTEMPTS diff --git a/tests/test_asyncio/test_multidb/test_client.py b/tests/test_asyncio/test_multidb/test_client.py index 230a01d64d..d6565865db 100644 --- a/tests/test_asyncio/test_multidb/test_client.py +++ b/tests/test_asyncio/test_multidb/test_client.py @@ -188,29 +188,24 @@ async def test_execute_command_auto_fallback_to_highest_weight_db( self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc ): databases = create_weighted_list(mock_db, mock_db1, mock_db2) - mock_hc.check_health.side_effect = [ - True, - True, - True, - False, - True, - True, - True, - True, - True, - True, - True, - True, - True, - True, - True, - True, - True, - True, - True, - True, - True, - ] + db1_counter = 0 + error_event = asyncio.Event() + check = False + + async def mock_check_health(database): + nonlocal db1_counter, check + + if database == mock_db1 and not check: + db1_counter += 1 + + if db1_counter > 1: + error_event.set() + check = True + return False + + return True + + mock_hc.check_health.side_effect = mock_check_health with ( patch.object(mock_multi_db_config, "databases", return_value=databases), @@ -224,15 +219,15 @@ async def test_execute_command_auto_fallback_to_highest_weight_db( mock_db1.client.execute_command.return_value = "OK1" mock_db2.client.execute_command.return_value = "OK2" mock_multi_db_config.health_check_interval = 0.1 - mock_multi_db_config.auto_fallback_interval = 0.5 + mock_multi_db_config.auto_fallback_interval = 0.2 mock_multi_db_config.failover_strategy = WeightBasedFailoverStrategy() - client = MultiDBClient(mock_multi_db_config) - assert await client.set("key", "value") == "OK1" - await asyncio.sleep(0.15) - assert await client.set("key", "value") == "OK2" - await asyncio.sleep(0.5) - assert await client.set("key", "value") == "OK1" + async with MultiDBClient(mock_multi_db_config) as client: + assert await client.set("key", "value") == "OK1" + await error_event.wait() + assert await client.set("key", "value") == "OK2" + await asyncio.sleep(0.2) + assert await client.set("key", "value") == "OK1" @pytest.mark.asyncio @pytest.mark.parametrize( diff --git a/tests/test_multidb/conftest.py b/tests/test_multidb/conftest.py index f6ee6d3ec4..76990aecaa 100644 --- a/tests/test_multidb/conftest.py +++ b/tests/test_multidb/conftest.py @@ -111,7 +111,7 @@ def mock_multi_db_config(request, mock_fd, mock_fs, mock_hc, mock_ed) -> MultiDb databases_config=[Mock(spec=DatabaseConfig)], failure_detectors=[mock_fd], health_check_interval=hc_interval, - health_check_delay=0.05, + health_check_probes_delay=0.05, health_check_policy=health_check_policy, 
health_check_probes=health_check_probes, failover_strategy=mock_fs, diff --git a/tests/test_multidb/test_client.py b/tests/test_multidb/test_client.py index 5ea2193895..1f1776351f 100644 --- a/tests/test_multidb/test_client.py +++ b/tests/test_multidb/test_client.py @@ -1,10 +1,13 @@ +import threading from time import sleep from unittest.mock import patch, Mock import pybreaker import pytest +from redis.backoff import NoBackoff from redis.event import EventDispatcher, OnCommandsFailEvent +from redis.exceptions import ConnectionError from redis.multidb.circuit import State as CBState, PBCircuitBreakerAdapter from redis.multidb.database import SyncDatabase from redis.multidb.client import MultiDBClient @@ -12,6 +15,7 @@ from redis.multidb.failover import WeightBasedFailoverStrategy from redis.multidb.failure_detector import FailureDetector from redis.multidb.healthcheck import HealthCheck, EchoHealthCheck +from redis.retry import Retry from tests.test_multidb.conftest import create_weighted_list @@ -184,29 +188,24 @@ def test_execute_command_auto_fallback_to_highest_weight_db( self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc ): databases = create_weighted_list(mock_db, mock_db1, mock_db2) - mock_hc.check_health.side_effect = [ - True, - True, - True, - False, - True, - True, - True, - True, - True, - True, - True, - True, - True, - True, - True, - True, - True, - True, - True, - True, - True, - ] + db1_counter = 0 + error_event = threading.Event() + check = False + + def mock_check_health(database): + nonlocal db1_counter, check + + if database == mock_db1 and not check: + db1_counter += 1 + + if db1_counter > 1: + error_event.set() + check = True + return False + + return True + + mock_hc.check_health.side_effect = mock_check_health with ( patch.object(mock_multi_db_config, "databases", return_value=databases), @@ -220,14 +219,14 @@ def test_execute_command_auto_fallback_to_highest_weight_db( mock_db1.client.execute_command.return_value = "OK1" mock_db2.client.execute_command.return_value = "OK2" mock_multi_db_config.health_check_interval = 0.1 - mock_multi_db_config.auto_fallback_interval = 0.5 + mock_multi_db_config.auto_fallback_interval = 0.2 mock_multi_db_config.failover_strategy = WeightBasedFailoverStrategy() client = MultiDBClient(mock_multi_db_config) assert client.set("key", "value") == "OK1" - sleep(0.18) + error_event.wait(timeout=0.5) assert client.set("key", "value") == "OK2" - sleep(0.5) + sleep(0.2) assert client.set("key", "value") == "OK1" @pytest.mark.parametrize( diff --git a/tests/test_scenario/conftest.py b/tests/test_scenario/conftest.py index 5f568aa84e..409f3088ca 100644 --- a/tests/test_scenario/conftest.py +++ b/tests/test_scenario/conftest.py @@ -140,7 +140,7 @@ def r_multi_db( health_check_probes=3, health_check_interval=health_check_interval, event_dispatcher=event_dispatcher, - health_check_delay=health_check_delay, + health_check_probes_delay=health_check_delay, ) return MultiDBClient(config), listener, endpoint_config From 1f3e5c6f5554d0a5161526e4f64ece34332ce0f8 Mon Sep 17 00:00:00 2001 From: vladvildanov Date: Tue, 7 Oct 2025 11:33:40 +0300 Subject: [PATCH 45/50] Codestyle changes --- tests/test_multidb/test_client.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/tests/test_multidb/test_client.py b/tests/test_multidb/test_client.py index 1f1776351f..3434e99c3a 100644 --- a/tests/test_multidb/test_client.py +++ b/tests/test_multidb/test_client.py @@ -5,9 +5,7 @@ import pybreaker import pytest -from redis.backoff import NoBackoff from 
redis.event import EventDispatcher, OnCommandsFailEvent -from redis.exceptions import ConnectionError from redis.multidb.circuit import State as CBState, PBCircuitBreakerAdapter from redis.multidb.database import SyncDatabase from redis.multidb.client import MultiDBClient @@ -15,7 +13,6 @@ from redis.multidb.failover import WeightBasedFailoverStrategy from redis.multidb.failure_detector import FailureDetector from redis.multidb.healthcheck import HealthCheck, EchoHealthCheck -from redis.retry import Retry from tests.test_multidb.conftest import create_weighted_list From 0036b5bb789bd2a2999d19e68746da35bc814d8b Mon Sep 17 00:00:00 2001 From: Vladyslav Vildanov <117659936+vladvildanov@users.noreply.github.com> Date: Tue, 7 Oct 2025 11:36:40 +0300 Subject: [PATCH 46/50] Added documentation for Active-Active (#3753) * Added Active-Active documentation page * Added documentation for Active-Active * Refactored docs * Refactored pipeline and transaction section * Updated docs * Extended list of words * Re-write documentation * Fixed spelling * Update docs/multi_database.rst Co-authored-by: Elena Kolevska * Apply suggested comments * Fixed spelling * Update docs/multi_database.rst Co-authored-by: Elena Kolevska * Update docs/multi_database.rst Co-authored-by: Elena Kolevska * Update docs/multi_database.rst Co-authored-by: Elena Kolevska --------- Co-authored-by: Elena Kolevska --- .github/wordlist.txt | 19 ++ docs/multi_database.rst | 521 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 540 insertions(+) create mode 100644 docs/multi_database.rst diff --git a/.github/wordlist.txt b/.github/wordlist.txt index 150f96a624..0a69b9092a 100644 --- a/.github/wordlist.txt +++ b/.github/wordlist.txt @@ -1,7 +1,9 @@ APM ARGV BFCommands +balancer CacheImpl +cancelling CAS CFCommands CMSCommands @@ -10,10 +12,20 @@ ClusterNodes ClusterPipeline ClusterPubSub ConnectionPool +config CoreCommands +DatabaseConfig +DNS +EchoHealthCheck EVAL EVALSHA +failover +FQDN Grokzen's +Healthcheck +HealthCheckPolicies +healthcheck +healthchecks INCR IOError Instrumentations @@ -21,8 +33,11 @@ JSONCommands Jaeger Ludovico Magnocavallo +MultiDbConfig +MultiDBClient McCurdy NOSCRIPT +NoValidDatabaseException NUMPAT NUMPT NUMSUB @@ -43,6 +58,7 @@ RedisInstrumentor RedisJSON RedisTimeSeries SHA +SLA SearchCommands SentinelCommands SentinelConnectionPool @@ -52,6 +68,8 @@ SpanKind Specfiying StatusCode TCP +TemporaryUnavailableException +TLS TOPKCommands TimeSeriesCommands Uptrace @@ -91,6 +109,7 @@ json keyslot keyspace kwarg +kwargs linters localhost lua diff --git a/docs/multi_database.rst b/docs/multi_database.rst new file mode 100644 index 0000000000..6906b09e4d --- /dev/null +++ b/docs/multi_database.rst @@ -0,0 +1,521 @@ +Multi-database client (Active-Active) +===================================== + +The multi-database client lets you connect your application to multiple logical Redis databases at once +and operate them as a single, resilient endpoint. It continuously monitors health, detects failures, +and fails over to the next healthy database using a configurable strategy. When the previous primary +becomes healthy again, the client can automatically fall back to it. + +Key concepts +------------ + +- Database and weight: + Each database has a weight indicating its priority. The failover strategy chooses the highest-weight + healthy database as the active one. + +- Circuit breaker: + Each database is guarded by a circuit breaker with states CLOSED (healthy), OPEN (unhealthy), + and HALF_OPEN (probing). 
Health checks toggle these states to avoid hammering a downed database.
+
+- Health checks:
+  A set of checks proactively determines whether a database is healthy.
+  By default, an "ECHO" check runs against the database (for a cluster, all
+  nodes must pass). You can add custom checks. A Redis Enterprise specific
+  "lag-aware" health check is also available.
+
+- Failure detector:
+  A detector observes command failures over a moving window (reactive monitoring).
+  You can specify a minimum number of failures and a failure rate to fine-tune
+  when organic traffic triggers a failover.
+
+- Failover strategy:
+  The default strategy is weight-based. It prefers the highest-weight healthy database
+  (see the sketch after this list).
+
+- Command retry:
+  Command execution supports retry with backoff. Low-level client retries are disabled
+  and a global retry setting is applied at the multi-database layer.
+
+- Auto fallback:
+  If configured with a positive interval, the client periodically attempts to fall back
+  to a higher-priority healthy database.
+
+- Events:
+  The client emits events like "active database changed" and "commands failed". Pub/Sub
+  re-subscription on database switch is handled automatically.
+
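+Conceptually, the weight-based selection works like the sketch below. This is an
+illustration only, not the library's internal API; the ``circuit`` and ``weight``
+attribute names simply mirror the concepts above.
+
+.. code-block:: python
+
+    # Sketch (not the library API): prefer the highest-weight database
+    # whose circuit breaker is CLOSED.
+    def choose_active(databases):
+        healthy = [db for db in databases if db.circuit.state == "CLOSED"]
+        if not healthy:
+            raise RuntimeError("no healthy database available")
+        return max(healthy, key=lambda db: db.weight)
+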
+Synchronous usage
+-----------------
+
+Minimal example
+^^^^^^^^^^^^^^^
+
+.. code-block:: python
+
+    from redis.multidb.client import MultiDBClient
+    from redis.multidb.config import MultiDbConfig, DatabaseConfig
+
+    # Two databases. The first has higher weight -> preferred when healthy.
+    cfg = MultiDbConfig(
+        databases_config=[
+            DatabaseConfig(from_url="redis://db-primary:6379/0", weight=1.0),
+            DatabaseConfig(from_url="redis://db-secondary:6379/0", weight=0.5),
+        ]
+    )
+
+    client = MultiDBClient(cfg)
+
+    # First call triggers initialization and health checks.
+    client.set("key", "value")
+    print(client.get("key"))
+
+    # Pipeline
+    with client.pipeline() as pipe:
+        pipe.set("a", 1)
+        pipe.incrby("a", 2)
+        values = pipe.execute()
+        print(values)
+
+    # Transaction
+    def txn(pipe):
+        current = pipe.get("balance")
+        current = int(current or 0)
+        pipe.multi()  # mark transaction
+        pipe.set("balance", current + 100)
+
+    client.transaction(txn)
+
+    # Pub/Sub usage - will automatically re-subscribe on database switch
+    pubsub = client.pubsub()
+    pubsub.subscribe("events")
+
+    # In your loop:
+    message = pubsub.get_message(timeout=1.0)
+    if message:
+        print(message)
+
+Asyncio usage
+-------------
+
+The asyncio API mirrors the synchronous one and provides async/await semantics.
+
+.. code-block:: python
+
+    import asyncio
+    from redis.asyncio.multidb.client import MultiDBClient
+    from redis.asyncio.multidb.config import MultiDbConfig, DatabaseConfig
+
+    async def main():
+        cfg = MultiDbConfig(
+            databases_config=[
+                DatabaseConfig(from_url="redis://db-primary:6379/0", weight=1.0),
+                DatabaseConfig(from_url="redis://db-secondary:6379/0", weight=0.5),
+            ]
+        )
+
+        # Context-manager approach for graceful client termination on exit.
+        # client = MultiDBClient(cfg) could be used instead.
+        async with MultiDBClient(cfg) as client:
+            await client.set("key", "value")
+            print(await client.get("key"))
+
+            # Pipeline
+            async with client.pipeline() as pipe:
+                pipe.set("a", 1)
+                pipe.incrby("a", 2)
+                values = await pipe.execute()
+                print(values)
+
+            # Transaction
+            async def txn(pipe):
+                current = await pipe.get("balance")
+                current = int(current or 0)
+                await pipe.multi()
+                await pipe.set("balance", current + 100)
+
+            await client.transaction(txn)
+
+            # Pub/Sub
+            pubsub = client.pubsub()
+            await pubsub.subscribe("events")
+            message = await pubsub.get_message(timeout=1.0)
+            if message:
+                print(message)
+
+    asyncio.run(main())
+
+
+MultiDBClient
+^^^^^^^^^^^^^
+
+The client exposes the same API as the `Redis` or `RedisCluster` client, making it fully
+interchangeable and ensuring a seamless upgrade for your application. Additionally, it supports
+runtime reconfiguration, allowing you to add features such as health checks, failure detectors,
+or even new databases without restarting.
+
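+As a small sketch of this interchangeability, code written against the regular
+client keeps working when handed a ``MultiDBClient`` (the helper function below
+is hypothetical):
+
+.. code-block:: python
+
+    from redis import Redis
+    from redis.multidb.client import MultiDBClient
+    from redis.multidb.config import MultiDbConfig, DatabaseConfig
+
+    # Hypothetical helper written against the plain Redis API;
+    # it uses nothing MultiDB-specific.
+    def cache_get_or_set(client, key, value):
+        if client.get(key) is None:
+            client.set(key, value)
+        return client.get(key)
+
+    # Works with a plain client...
+    cache_get_or_set(Redis(host="localhost", port=6379), "k", "v")
+
+    # ...and unchanged with a multi-database client.
+    cfg = MultiDbConfig(
+        databases_config=[DatabaseConfig(from_url="redis://db-a:6379/0", weight=1.0)]
+    )
+    cache_get_or_set(MultiDBClient(cfg), "k", "v")
+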
+Configuration
+-------------
+
+MultiDbConfig
+^^^^^^^^^^^^^
+
+.. code-block:: python
+
+    from redis.multidb.config import (
+        MultiDbConfig, DatabaseConfig,
+        DEFAULT_HEALTH_CHECK_INTERVAL, DEFAULT_GRACE_PERIOD
+    )
+    from redis.retry import Retry
+    from redis.backoff import ExponentialWithJitterBackoff
+
+    cfg = MultiDbConfig(
+        databases_config=[
+            # Construct via URL
+            DatabaseConfig(
+                from_url="redis://db-a:6379/0",
+                weight=1.0,
+                # Optional: use a custom circuit breaker grace period
+                grace_period=DEFAULT_GRACE_PERIOD,
+                # Optional: Redis Enterprise cluster FQDN for REST health checks
+                health_check_url="https://cluster.example.com",
+                # Optional: Underlying Redis client related configuration
+                client_kwargs={"socket_timeout": 5},
+            ),
+            # Or construct via ConnectionPool
+            # DatabaseConfig(from_pool=my_pool, weight=1.0),
+        ],
+
+        # Global command retry policy (applied at multi-db layer)
+        command_retry=Retry(
+            retries=3,
+            backoff=ExponentialWithJitterBackoff(base=1, cap=10),
+        ),
+
+        # Health checks (defaults shown for illustration)
+        health_check_interval=DEFAULT_HEALTH_CHECK_INTERVAL,  # seconds
+        health_check_probes=DEFAULT_HEALTH_CHECK_PROBES,
+        health_check_probes_delay=DEFAULT_HEALTH_CHECK_DELAY,  # seconds
+        health_check_policy=DEFAULT_HEALTH_CHECK_POLICY,
+
+        # Failure detector
+        min_num_failures=DEFAULT_MIN_NUM_FAILURES,
+        failure_rate_threshold=DEFAULT_FAILURE_RATE_THRESHOLD,
+        failures_detection_window=DEFAULT_FAILURES_DETECTION_WINDOW,  # seconds
+
+        # Failover behavior
+        failover_attempts=DEFAULT_FAILOVER_ATTEMPTS,
+        failover_delay=DEFAULT_FAILOVER_DELAY,  # seconds
+    )
+
+Notes:
+
+- Low-level client retries are disabled automatically per database. The multi-database layer handles retries.
+- For clusters, health checks validate all nodes.
+
+DatabaseConfig
+^^^^^^^^^^^^^^
+
+Each database needs a `DatabaseConfig` that specifies how to connect.
+
+Method 1: Using client_kwargs (most flexible)
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+There's an underlying instance of the `Redis` or `RedisCluster` client for each database,
+so you can pass any of their arguments via the `client_kwargs` argument:
+
+.. code-block:: python
+
+    database_config = DatabaseConfig(
+        weight=1.0,
+        client_kwargs={
+            'host': 'localhost',
+            'port': 6379,
+            'username': "username",
+            'password': "password",
+        }
+    )
+
+Method 2: Using Redis URL
+~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. code-block:: python
+
+    database_config1 = DatabaseConfig(
+        weight=1.0,
+        from_url="redis://host1:port1",
+        client_kwargs={
+            'username': "username",
+            'password': "password",
+        }
+    )
+
+Method 3: Using Custom Connection Pool
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. code-block:: python
+
+    database_config2 = DatabaseConfig(
+        weight=0.9,
+        from_pool=connection_pool,
+    )
+
+**Important**: Don't pass `Retry` objects in `client_kwargs`. `MultiDBClient`
+handles all retries at the top level through the `command_retry` configuration.
+
+Health Monitoring
+-----------------
+
+The `MultiDBClient` uses two complementary mechanisms to ensure database availability:
+
+- Health Checks (Proactive Monitoring)
+- Failure Detection (Reactive Monitoring)
+
+Health Checks (Proactive Monitoring)
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+These checks run continuously in the background at the configured interval to
+proactively detect database issues; the interval and related settings are
+defined in the `MultiDbConfig` class.
+
+To avoid false positives, you can configure the number of health check probes
+and choose one of the health check policies to evaluate the probe results (a
+configuration sketch follows the list):
+
+- **HealthCheckPolicies.HEALTHY_ALL** (default): all probes must succeed.
+- **HealthCheckPolicies.HEALTHY_MAJORITY**: a majority of probes must succeed.
+- **HealthCheckPolicies.HEALTHY_ANY**: a single successful probe is enough.
+
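+For example, a minimal sketch that requires a majority of three probes. The
+``HealthCheckPolicies`` import path is an assumption here; adjust it to wherever
+the enum lives in your version:
+
+.. code-block:: python
+
+    from redis.multidb.config import MultiDbConfig, DatabaseConfig
+    from redis.multidb.healthcheck import HealthCheckPolicies  # import path assumed
+
+    cfg = MultiDbConfig(
+        databases_config=[
+            DatabaseConfig(from_url="redis://db-a:6379/0", weight=1.0),
+        ],
+        health_check_probes=3,
+        health_check_probes_delay=0.5,  # seconds between probes
+        health_check_policy=HealthCheckPolicies.HEALTHY_MAJORITY,
+    )
+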
+Custom Health Checks
+~~~~~~~~~~~~~~~~~~~~
+
+You can add custom health checks for specific requirements:
+
+.. code-block:: python
+
+    from redis.multidb.healthcheck import AbstractHealthCheck
+    from redis.retry import Retry
+    from redis.utils import dummy_fail
+
+    class PingHealthCheck(AbstractHealthCheck):
+        def __init__(self, retry: Retry):
+            super().__init__(retry=retry)
+
+        def check_health(self, database) -> bool:
+            return self._retry.call_with_retry(
+                lambda: self._returns_pong(database),
+                lambda _: dummy_fail(),
+            )
+
+        def _returns_pong(self, database) -> bool:
+            expected_message = ["PONG", b"PONG"]
+            actual_message = database.client.execute_command("PING")
+            return actual_message in expected_message
+
+Failure Detection (Reactive Monitoring)
+---------------------------------------
+
+The failure detector monitors actual command failures and marks a database as
+unhealthy when the failure count and failure rate exceed their thresholds within
+a sliding time window of a few seconds. This catches issues that proactive health
+checks might miss during real traffic. You can extend the list of failure detectors
+with your own implementation; the configuration is defined in the `MultiDbConfig` class.
+
+By default, the failure detector triggers at 1000 failures and a 10% failure rate
+within a 2-second sliding window; adjust these thresholds to your application's
+specifics and traffic.
+
+.. code-block:: python
+
+    from redis.multidb.config import MultiDbConfig, DatabaseConfig
+    from redis.multidb.client import MultiDBClient
+
+    cfg = MultiDbConfig(
+        databases_config=[
+            DatabaseConfig(from_url="redis://db-a:6379/0", weight=1.0),
+            DatabaseConfig(from_url="redis://db-b:6379/0", weight=0.5),
+        ],
+        # The default detector is also created from config values
+    )
+
+    client = MultiDBClient(cfg)
+
+    # Add an additional detector (CustomFailureDetector is your own
+    # implementation), optionally limited to specific exception types:
+    client.add_failure_detector(
+        CustomFailureDetector()
+    )
+
+Failover and auto fallback
+--------------------------
+
+Weight-based failover chooses the highest-weight database whose circuit is CLOSED.
+If no database is healthy, the client raises `TemporaryUnavailableException`. This
+exception indicates that the application can keep sending requests for some time,
+`failover_attempts` * `failover_delay` (120 seconds with the default configuration),
+before `NoValidDatabaseException` is thrown. A handling sketch follows the
+auto-fallback example below.
+
+To enable periodic fallback to a higher-priority healthy database, set `auto_fallback_interval` (seconds):
+
+.. code-block:: python
+
+    from redis.multidb.config import MultiDbConfig, DatabaseConfig
+
+    cfg = MultiDbConfig(
+        databases_config=[
+            DatabaseConfig(from_url="redis://db-primary:6379/0", weight=1.0),
+            DatabaseConfig(from_url="redis://db-secondary:6379/0", weight=0.5),
+        ],
+        # Try to fall back to a higher-weight healthy database every 30 seconds
+        auto_fallback_interval=30.0,
+    )
+    client = MultiDBClient(cfg)
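+When every database is unhealthy, calls fail fast until the retry budget is exhausted.
+A minimal handling sketch; it assumes `TemporaryUnavailableException` is importable from
+`redis.multidb.exception` alongside `NoValidDatabaseException`:
+
+.. code-block:: python
+
+    from redis.multidb.exception import (
+        NoValidDatabaseException,
+        TemporaryUnavailableException,  # assumed import location
+    )
+
+    try:
+        client.set("key", "value")
+    except TemporaryUnavailableException:
+        # No database is healthy right now; the client may recover within roughly
+        # failover_attempts * failover_delay seconds. Serve from a local cache or
+        # retry with backoff in the meantime.
+        pass
+    except NoValidDatabaseException:
+        # The retry window elapsed without a healthy database; escalate.
+        raise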
+Managing databases at runtime
+-----------------------------
+
+You can manually add/remove databases, update weights, and promote a database if it's healthy.
+
+.. code-block:: python
+
+    from redis.multidb.client import MultiDBClient
+    from redis.multidb.config import MultiDbConfig, DatabaseConfig
+    from redis.multidb.database import Database
+    from redis.multidb.circuit import PBCircuitBreakerAdapter
+    import pybreaker
+    from redis import Redis
+
+    cfg = MultiDbConfig(
+        databases_config=[DatabaseConfig(from_url="redis://db-a:6379/0", weight=1.0)]
+    )
+    client = MultiDBClient(cfg)
+
+    # Add a database programmatically
+    other = Database(
+        client=Redis.from_url("redis://db-b:6379/0"),
+        circuit=PBCircuitBreakerAdapter(pybreaker.CircuitBreaker(reset_timeout=5.0)),
+        weight=0.5,
+        health_check_url=None,
+    )
+    client.add_database(other)
+
+    # Update weight; if it becomes the highest and healthy, it may become active
+    client.update_database_weight(other, 0.9)
+
+    # Promote a specific healthy database to active
+    client.set_active_database(other)
+
+    # Remove a database
+    client.remove_database(other)
+
+Pub/Sub and re-subscription
+---------------------------
+
+The MultiDBClient offers Pub/Sub functionality with automatic re-subscription
+to channels during failover events. For optimal failover handling,
+both publishers and subscribers should use MultiDBClient instances.
+
+1. **Subscriber failover**: Automatically reconnects to an alternative database
+   and re-subscribes to the same channels.
+2. **Publisher failover**: Seamlessly switches to an alternative database and
+   continues publishing to the same channels.
+
+**Note**: Message loss may occur if failover events happen in reverse order
+(publisher fails before subscriber).
+
+.. code-block:: python
+
+    pubsub = client.pubsub()
+    pubsub.subscribe("news", "alerts")
+    # If failover happens here, subscriptions are re-established on the new active DB.
+    msg = pubsub.get_message(timeout=1.0)
+    if msg:
+        print(msg)
+
+Pipelines and transactions
+--------------------------
+
+Pipelines and transactions are executed against the active database at execution time. The client ensures
+the active database is healthy and up-to-date before running the stack.
+
+.. code-block:: python
+
+    with client.pipeline() as pipe:
+        pipe.set("x", 1)
+        pipe.incr("x")
+        results = pipe.execute()
+
+    def txn(pipe):
+        pipe.multi()
+        pipe.set("y", "42")
+
+    client.transaction(txn)
+
+Best practices
+--------------
+
+- Assign the highest weight to your primary database and lower weights to replicas or disaster recovery sites.
+- Keep `health_check_interval` short enough to detect failures promptly, but not so short that it adds excessive load.
+- Tune `command_retry` and failover attempts to your SLA and workload profile.
+- Use `auto_fallback_interval` if you want the client to fall back to your primary automatically.
+- Handle `TemporaryUnavailableException` so you can recover before giving up; in the meantime you
+  can switch the data source (e.g. a cache). `NoValidDatabaseException` indicates that there are
+  no healthy databases left to operate on.
+
+Troubleshooting
+---------------
+
+- NoValidDatabaseException:
+  Indicates no healthy database is available. Check circuit breaker states and health checks.
+
+- TemporaryUnavailableException:
+  Indicates that there are currently no healthy databases, but you can keep sending requests until
+  `NoValidDatabaseException` is thrown. The retry window is governed by `failover_attempts` and `failover_delay`.
+
+- Health checks always failing:
+  Verify connectivity and, for clusters, that all nodes are reachable.
For `LagAwareHealthCheck`, + ensure `health_check_url` points to your Redis Enterprise endpoint and authentication/TLS options + are configured properly. + +- Pub/Sub not receiving messages after failover: + Ensure you are using the client’s Pub/Sub helper. The client re-subscribes automatically on switch. \ No newline at end of file From f7c58f1ea4d043545fac6c9cb2b738dc8c4f2835 Mon Sep 17 00:00:00 2001 From: vladvildanov Date: Tue, 7 Oct 2025 13:40:48 +0300 Subject: [PATCH 47/50] Refactor unstable tests --- .../test_asyncio/test_multidb/test_client.py | 106 ++++++--- .../test_multidb/test_pipeline.py | 209 ++++++++++++----- tests/test_multidb/test_client.py | 108 ++++++--- tests/test_multidb/test_pipeline.py | 221 +++++++++++++----- 4 files changed, 467 insertions(+), 177 deletions(-) diff --git a/tests/test_asyncio/test_multidb/test_client.py b/tests/test_asyncio/test_multidb/test_client.py index d6565865db..6b4570e9bc 100644 --- a/tests/test_asyncio/test_multidb/test_client.py +++ b/tests/test_asyncio/test_multidb/test_client.py @@ -113,7 +113,7 @@ async def test_execute_command_against_correct_db_and_closed_circuit( indirect=True, ) async def test_execute_command_against_correct_db_on_background_health_check_determine_active_db_unhealthy( - self, mock_multi_db_config, mock_db, mock_db1, mock_db2 + self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc ): cb = PBCircuitBreakerAdapter(pybreaker.CircuitBreaker(reset_timeout=5)) cb.database = mock_db @@ -129,46 +129,98 @@ async def test_execute_command_against_correct_db_on_background_health_check_det databases = create_weighted_list(mock_db, mock_db1, mock_db2) + # Track health check runs across all databases + health_check_run = 0 + + # Create events for each failover scenario + db1_became_unhealthy = asyncio.Event() + db2_became_unhealthy = asyncio.Event() + db_became_unhealthy = asyncio.Event() + counter_lock = asyncio.Lock() + + async def mock_check_health(database): + nonlocal health_check_run + + # Increment run counter for each health check call + async with counter_lock: + health_check_run += 1 + current_run = health_check_run + + # Run 1 (health_check_run 1-3): All databases healthy + if current_run <= 3: + return True + + # Run 2 (health_check_run 4-6): mock_db1 unhealthy, others healthy + elif current_run <= 6: + if database == mock_db1: + if current_run == 6: + db1_became_unhealthy.set() + return False + + # Signal that db1 has become unhealthy after all 3 checks + if current_run == 6: + db1_became_unhealthy.set() + return True + + # Run 3 (health_check_run 7-9): mock_db1 and mock_db2 unhealthy, mock_db healthy + elif current_run <= 9: + if database == mock_db1 or database == mock_db2: + if current_run == 9: + db2_became_unhealthy.set() + return False + + # Signal that db2 has become unhealthy after all 3 checks + if current_run == 9: + db2_became_unhealthy.set() + return True + + # Run 4 (health_check_run 10-12): mock_db unhealthy, others healthy + else: + if database == mock_db: + if current_run >= 12: + db_became_unhealthy.set() + return False + + # Signal that db has become unhealthy after all 3 checks + if current_run >= 12: + db_became_unhealthy.set() + return True + + mock_hc.check_health.side_effect = mock_check_health + with ( patch.object(mock_multi_db_config, "databases", return_value=databases), patch.object( mock_multi_db_config, "default_health_checks", - return_value=[EchoHealthCheck()], + return_value=[mock_hc], ), ): - mock_db.client.execute_command.side_effect = [ - "healthcheck", - "healthcheck", - 
"healthcheck", - "OK", - "error", - ] - mock_db1.client.execute_command.side_effect = [ - "healthcheck", - "OK1", - "error", - "error", - "healthcheck", - "OK1", - ] - mock_db2.client.execute_command.side_effect = [ - "healthcheck", - "healthcheck", - "OK2", - "error", - "error", - ] + mock_db.client.execute_command.return_value = "OK" + mock_db1.client.execute_command.return_value = "OK1" + mock_db2.client.execute_command.return_value = "OK2" mock_multi_db_config.health_check_interval = 0.1 mock_multi_db_config.failover_strategy = WeightBasedFailoverStrategy() client = MultiDBClient(mock_multi_db_config) assert await client.set("key", "value") == "OK1" - await asyncio.sleep(0.15) + + # Wait for mock_db1 to become unhealthy + assert await db1_became_unhealthy.wait(), "Timeout waiting for mock_db1 to become unhealthy" + await asyncio.sleep(0.01) + assert await client.set("key", "value") == "OK2" - await asyncio.sleep(0.1) + + # Wait for mock_db2 to become unhealthy + assert await db2_became_unhealthy.wait(), "Timeout waiting for mock_db2 to become unhealthy" + await asyncio.sleep(0.01) + assert await client.set("key", "value") == "OK" - await asyncio.sleep(0.1) + + # Wait for mock_db to become unhealthy + assert await db_became_unhealthy.wait(), "Timeout waiting for mock_db to become unhealthy" + await asyncio.sleep(0.01) + assert await client.set("key", "value") == "OK1" @pytest.mark.asyncio diff --git a/tests/test_asyncio/test_multidb/test_pipeline.py b/tests/test_asyncio/test_multidb/test_pipeline.py index 48990bc62a..9a3ba63d0e 100644 --- a/tests/test_asyncio/test_multidb/test_pipeline.py +++ b/tests/test_asyncio/test_multidb/test_pipeline.py @@ -142,33 +142,74 @@ async def test_execute_pipeline_against_correct_db_on_background_health_check_de databases = create_weighted_list(mock_db, mock_db1, mock_db2) + + # Track health check runs across all databases + health_check_run = 0 + + # Create events for each failover scenario + db1_became_unhealthy = asyncio.Event() + db2_became_unhealthy = asyncio.Event() + db_became_unhealthy = asyncio.Event() + counter_lock = asyncio.Lock() + + async def mock_check_health(database): + nonlocal health_check_run + + # Increment run counter for each health check call + async with counter_lock: + health_check_run += 1 + current_run = health_check_run + + # Run 1 (health_check_run 1-3): All databases healthy + if current_run <= 3: + return True + + # Run 2 (health_check_run 4-6): mock_db1 unhealthy, others healthy + elif current_run <= 6: + if database == mock_db1: + if current_run == 6: + db1_became_unhealthy.set() + return False + + # Signal that db1 has become unhealthy after all 3 checks + if current_run == 6: + db1_became_unhealthy.set() + return True + + # Run 3 (health_check_run 7-9): mock_db1 and mock_db2 unhealthy, mock_db healthy + elif current_run <= 9: + if database == mock_db1 or database == mock_db2: + if current_run == 9: + db2_became_unhealthy.set() + return False + + # Signal that db2 has become unhealthy after all 3 checks + if current_run == 9: + db2_became_unhealthy.set() + return True + + # Run 4 (health_check_run 10-12): mock_db unhealthy, others healthy + else: + if database == mock_db: + if current_run >= 12: + db_became_unhealthy.set() + return False + + # Signal that db has become unhealthy after all 3 checks + if current_run >= 12: + db_became_unhealthy.set() + return True + + mock_hc.check_health.side_effect = mock_check_health + with ( patch.object(mock_multi_db_config, "databases", return_value=databases), patch.object( 
mock_multi_db_config, "default_health_checks", - return_value=[EchoHealthCheck()], + return_value=[mock_hc], ), ): - mock_db.client.execute_command.side_effect = [ - "healthcheck", - "healthcheck", - "healthcheck", - "error", - ] - mock_db1.client.execute_command.side_effect = [ - "healthcheck", - "error", - "error", - "healthcheck", - ] - mock_db2.client.execute_command.side_effect = [ - "healthcheck", - "healthcheck", - "error", - "error", - ] - pipe = mock_pipe() pipe.execute.return_value = ["OK", "value"] mock_db.client.pipeline.return_value = pipe @@ -190,30 +231,28 @@ async def test_execute_pipeline_against_correct_db_on_background_health_check_de pipe.set("key1", "value") pipe.get("key1") + # Run 1: All databases healthy - should use mock_db1 (highest weight 0.7) assert await pipe.execute() == ["OK1", "value"] - await asyncio.sleep(0.15) - - async with client.pipeline() as pipe: - pipe.set("key1", "value") - pipe.get("key1") + # Wait for mock_db1 to become unhealthy + assert await db1_became_unhealthy.wait(), "Timeout waiting for mock_db1 to become unhealthy" + await asyncio.sleep(0.01) + # Run 2: mock_db1 unhealthy - should failover to mock_db2 (weight 0.5) assert await pipe.execute() == ["OK2", "value"] - await asyncio.sleep(0.1) - - async with client.pipeline() as pipe: - pipe.set("key1", "value") - pipe.get("key1") + # Wait for mock_db2 to become unhealthy + assert await db2_became_unhealthy.wait(), "Timeout waiting for mock_db2 to become unhealthy" + await asyncio.sleep(0.01) + # Run 3: mock_db1 and mock_db2 unhealthy - should use mock_db (weight 0.2) assert await pipe.execute() == ["OK", "value"] - await asyncio.sleep(0.1) - - async with client.pipeline() as pipe: - pipe.set("key1", "value") - pipe.get("key1") + # Wait for mock_db to become unhealthy + assert await db_became_unhealthy.wait(), "Timeout waiting for mock_db to become unhealthy" + await asyncio.sleep(0.01) + # Run 4: mock_db unhealthy, others healthy - should use mock_db1 (highest weight) assert await pipe.execute() == ["OK1", "value"] @@ -320,7 +359,7 @@ async def callback(pipe: Pipeline): indirect=True, ) async def test_execute_transaction_against_correct_db_on_background_health_check_determine_active_db_unhealthy( - self, mock_multi_db_config, mock_db, mock_db1, mock_db2 + self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc ): cb = PBCircuitBreakerAdapter(pybreaker.CircuitBreaker(reset_timeout=5)) cb.database = mock_db @@ -336,33 +375,73 @@ async def test_execute_transaction_against_correct_db_on_background_health_check databases = create_weighted_list(mock_db, mock_db1, mock_db2) + # Track health check runs across all databases + health_check_run = 0 + + # Create events for each failover scenario + db1_became_unhealthy = asyncio.Event() + db2_became_unhealthy = asyncio.Event() + db_became_unhealthy = asyncio.Event() + counter_lock = asyncio.Lock() + + async def mock_check_health(database): + nonlocal health_check_run + + # Increment run counter for each health check call + async with counter_lock: + health_check_run += 1 + current_run = health_check_run + + # Run 1 (health_check_run 1-3): All databases healthy + if current_run <= 3: + return True + + # Run 2 (health_check_run 4-6): mock_db1 unhealthy, others healthy + elif current_run <= 6: + if database == mock_db1: + if current_run == 6: + db1_became_unhealthy.set() + return False + + # Signal that db1 has become unhealthy after all 3 checks + if current_run == 6: + db1_became_unhealthy.set() + return True + + # Run 3 (health_check_run 7-9): 
mock_db1 and mock_db2 unhealthy, mock_db healthy + elif current_run <= 9: + if database == mock_db1 or database == mock_db2: + if current_run == 9: + db2_became_unhealthy.set() + return False + + # Signal that db2 has become unhealthy after all 3 checks + if current_run == 9: + db2_became_unhealthy.set() + return True + + # Run 4 (health_check_run 10-12): mock_db unhealthy, others healthy + else: + if database == mock_db: + if current_run >= 12: + db_became_unhealthy.set() + return False + + # Signal that db has become unhealthy after all 3 checks + if current_run >= 12: + db_became_unhealthy.set() + return True + + mock_hc.check_health.side_effect = mock_check_health + with ( patch.object(mock_multi_db_config, "databases", return_value=databases), patch.object( mock_multi_db_config, "default_health_checks", - return_value=[EchoHealthCheck()], + return_value=[mock_hc], ), ): - mock_db.client.execute_command.side_effect = [ - "healthcheck", - "healthcheck", - "healthcheck", - "error", - ] - mock_db1.client.execute_command.side_effect = [ - "healthcheck", - "error", - "error", - "healthcheck", - ] - mock_db2.client.execute_command.side_effect = [ - "healthcheck", - "healthcheck", - "error", - "error", - ] - mock_db.client.transaction.return_value = ["OK", "value"] mock_db1.client.transaction.return_value = ["OK1", "value"] mock_db2.client.transaction.return_value = ["OK2", "value"] @@ -377,9 +456,21 @@ async def callback(pipe: Pipeline): pipe.get("key1") assert await client.transaction(callback) == ["OK1", "value"] - await asyncio.sleep(0.15) + + # Wait for mock_db1 to become unhealthy + assert await db1_became_unhealthy.wait(), "Timeout waiting for mock_db1 to become unhealthy" + await asyncio.sleep(0.01) + assert await client.transaction(callback) == ["OK2", "value"] - await asyncio.sleep(0.1) + + # Wait for mock_db2 to become unhealthy + assert await db2_became_unhealthy.wait(), "Timeout waiting for mock_db1 to become unhealthy" + await asyncio.sleep(0.01) + assert await client.transaction(callback) == ["OK", "value"] - await asyncio.sleep(0.1) + + # Wait for mock_db to become unhealthy + assert await db_became_unhealthy.wait(), "Timeout waiting for mock_db1 to become unhealthy" + await asyncio.sleep(0.01) + assert await client.transaction(callback) == ["OK1", "value"] diff --git a/tests/test_multidb/test_client.py b/tests/test_multidb/test_client.py index 3434e99c3a..ea7b5fd2b4 100644 --- a/tests/test_multidb/test_client.py +++ b/tests/test_multidb/test_client.py @@ -111,7 +111,7 @@ def test_execute_command_against_correct_db_and_closed_circuit( indirect=True, ) def test_execute_command_against_correct_db_on_background_health_check_determine_active_db_unhealthy( - self, mock_multi_db_config, mock_db, mock_db1, mock_db2 + self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc ): cb = PBCircuitBreakerAdapter(pybreaker.CircuitBreaker(reset_timeout=5)) cb.database = mock_db @@ -127,46 +127,98 @@ def test_execute_command_against_correct_db_on_background_health_check_determine databases = create_weighted_list(mock_db, mock_db1, mock_db2) + # Track health check runs across all databases + health_check_run = 0 + + # Create events for each failover scenario + db1_became_unhealthy = threading.Event() + db2_became_unhealthy = threading.Event() + db_became_unhealthy = threading.Event() + counter_lock = threading.Lock() + + def mock_check_health(database): + nonlocal health_check_run + + # Increment run counter for each health check call + with counter_lock: + health_check_run += 1 + 
current_run = health_check_run + + # Run 1 (health_check_run 1-3): All databases healthy + if current_run <= 3: + return True + + # Run 2 (health_check_run 4-6): mock_db1 unhealthy, others healthy + elif current_run <= 6: + if database == mock_db1: + if current_run == 6: + db1_became_unhealthy.set() + return False + + # Signal that db1 has become unhealthy after all 3 checks + if current_run == 6: + db1_became_unhealthy.set() + return True + + # Run 3 (health_check_run 7-9): mock_db1 and mock_db2 unhealthy, mock_db healthy + elif current_run <= 9: + if database == mock_db1 or database == mock_db2: + if current_run == 9: + db2_became_unhealthy.set() + return False + + # Signal that db2 has become unhealthy after all 3 checks + if current_run == 9: + db2_became_unhealthy.set() + return True + + # Run 4 (health_check_run 10-12): mock_db unhealthy, others healthy + else: + if database == mock_db: + if current_run >= 12: + db_became_unhealthy.set() + return False + + # Signal that db has become unhealthy after all 3 checks + if current_run >= 12: + db_became_unhealthy.set() + return True + + mock_hc.check_health.side_effect = mock_check_health + with ( patch.object(mock_multi_db_config, "databases", return_value=databases), patch.object( mock_multi_db_config, "default_health_checks", - return_value=[EchoHealthCheck()], + return_value=[mock_hc], ), ): - mock_db.client.execute_command.side_effect = [ - "healthcheck", - "healthcheck", - "healthcheck", - "OK", - "error", - ] - mock_db1.client.execute_command.side_effect = [ - "healthcheck", - "OK1", - "error", - "error", - "healthcheck", - "OK1", - ] - mock_db2.client.execute_command.side_effect = [ - "healthcheck", - "healthcheck", - "OK2", - "error", - "error", - ] - mock_multi_db_config.health_check_interval = 0.2 + mock_multi_db_config.health_check_interval = 0.1 mock_multi_db_config.failover_strategy = WeightBasedFailoverStrategy() + mock_db.client.execute_command.return_value = "OK" + mock_db1.client.execute_command.return_value = "OK1" + mock_db2.client.execute_command.return_value = "OK2" client = MultiDBClient(mock_multi_db_config) assert client.set("key", "value") == "OK1" - sleep(0.3) + + # Wait for mock_db1 to become unhealthy + assert db1_became_unhealthy.wait(timeout=1.0), "Timeout waiting for mock_db1 to become unhealthy" + sleep(0.01) + assert client.set("key", "value") == "OK2" - sleep(0.2) + + # Wait for mock_db2 to become unhealthy + assert db2_became_unhealthy.wait(timeout=1.0), "Timeout waiting for mock_db2 to become unhealthy" + sleep(0.01) + assert client.set("key", "value") == "OK" - sleep(0.2) + + # Wait for mock_db to become unhealthy + assert db_became_unhealthy.wait(timeout=1.0), "Timeout waiting for mock_db to become unhealthy" + sleep(0.01) + assert client.set("key", "value") == "OK1" @pytest.mark.parametrize( diff --git a/tests/test_multidb/test_pipeline.py b/tests/test_multidb/test_pipeline.py index c3a494dd95..b7bfab59a1 100644 --- a/tests/test_multidb/test_pipeline.py +++ b/tests/test_multidb/test_pipeline.py @@ -1,3 +1,4 @@ +import threading from time import sleep from unittest.mock import patch, Mock @@ -142,33 +143,73 @@ def test_execute_pipeline_against_correct_db_on_background_health_check_determin databases = create_weighted_list(mock_db, mock_db1, mock_db2) + # Track health check runs across all databases + health_check_run = 0 + + # Create events for each failover scenario + db1_became_unhealthy = threading.Event() + db2_became_unhealthy = threading.Event() + db_became_unhealthy = threading.Event() + 
counter_lock = threading.Lock() + + def mock_check_health(database): + nonlocal health_check_run + + # Increment run counter for each health check call + with counter_lock: + health_check_run += 1 + current_run = health_check_run + + # Run 1 (health_check_run 1-3): All databases healthy + if current_run <= 3: + return True + + # Run 2 (health_check_run 4-6): mock_db1 unhealthy, others healthy + elif current_run <= 6: + if database == mock_db1: + if current_run == 6: + db1_became_unhealthy.set() + return False + + # Signal that db1 has become unhealthy after all 3 checks + if current_run == 6: + db1_became_unhealthy.set() + return True + + # Run 3 (health_check_run 7-9): mock_db1 and mock_db2 unhealthy, mock_db healthy + elif current_run <= 9: + if database == mock_db1 or database == mock_db2: + if current_run == 9: + db2_became_unhealthy.set() + return False + + # Signal that db2 has become unhealthy after all 3 checks + if current_run == 9: + db2_became_unhealthy.set() + return True + + # Run 4 (health_check_run 10-12): mock_db unhealthy, others healthy + else: + if database == mock_db: + if current_run >= 12: + db_became_unhealthy.set() + return False + + # Signal that db has become unhealthy after all 3 checks + if current_run >= 12: + db_became_unhealthy.set() + return True + + mock_hc.check_health.side_effect = mock_check_health + with ( patch.object(mock_multi_db_config, "databases", return_value=databases), patch.object( mock_multi_db_config, "default_health_checks", - return_value=[EchoHealthCheck()], + return_value=[mock_hc], ), ): - mock_db.client.execute_command.side_effect = [ - "healthcheck", - "healthcheck", - "healthcheck", - "error", - ] - mock_db1.client.execute_command.side_effect = [ - "healthcheck", - "error", - "error", - "healthcheck", - ] - mock_db2.client.execute_command.side_effect = [ - "healthcheck", - "healthcheck", - "error", - "error", - ] - pipe = mock_pipe() pipe.execute.return_value = ["OK", "value"] mock_db.client.pipeline.return_value = pipe @@ -190,31 +231,29 @@ def test_execute_pipeline_against_correct_db_on_background_health_check_determin pipe.set("key1", "value") pipe.get("key1") - assert pipe.execute() == ["OK1", "value"] - - sleep(0.15) - - with client.pipeline() as pipe: - pipe.set("key1", "value") - pipe.get("key1") - - assert pipe.execute() == ["OK2", "value"] + # Run 1: All databases healthy - should use mock_db1 (highest weight 0.7) + assert pipe.execute() == ["OK1", "value"] - sleep(0.1) + # Wait for mock_db1 to become unhealthy + assert db1_became_unhealthy.wait(timeout=1.0), "Timeout waiting for mock_db1 to become unhealthy" + sleep(0.01) - with client.pipeline() as pipe: - pipe.set("key1", "value") - pipe.get("key1") + # Run 2: mock_db1 unhealthy - should failover to mock_db2 (weight 0.5) + assert pipe.execute() == ["OK2", "value"] - assert pipe.execute() == ["OK", "value"] + # Wait for mock_db2 to become unhealthy + assert db2_became_unhealthy.wait(timeout=1.0), "Timeout waiting for mock_db2 to become unhealthy" + sleep(0.01) - sleep(0.1) + # Run 3: mock_db1 and mock_db2 unhealthy - should use mock_db (weight 0.2) + assert pipe.execute() == ["OK", "value"] - with client.pipeline() as pipe: - pipe.set("key1", "value") - pipe.get("key1") + # Wait for mock_db to become unhealthy + assert db_became_unhealthy.wait(timeout=1.0), "Timeout waiting for mock_db to become unhealthy" + sleep(0.01) - assert pipe.execute() == ["OK1", "value"] + # Run 4: mock_db unhealthy, others healthy - should use mock_db1 (highest weight) + assert pipe.execute() == 
["OK1", "value"] class TestTransaction: @@ -317,7 +356,7 @@ def callback(pipe: Pipeline): indirect=True, ) def test_execute_transaction_against_correct_db_on_background_health_check_determine_active_db_unhealthy( - self, mock_multi_db_config, mock_db, mock_db1, mock_db2 + self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc ): cb = PBCircuitBreakerAdapter(pybreaker.CircuitBreaker(reset_timeout=5)) cb.database = mock_db @@ -333,33 +372,73 @@ def test_execute_transaction_against_correct_db_on_background_health_check_deter databases = create_weighted_list(mock_db, mock_db1, mock_db2) + # Track health check runs across all databases + health_check_run = 0 + + # Create events for each failover scenario + db1_became_unhealthy = threading.Event() + db2_became_unhealthy = threading.Event() + db_became_unhealthy = threading.Event() + counter_lock = threading.Lock() + + def mock_check_health(database): + nonlocal health_check_run + + # Increment run counter for each health check call + with counter_lock: + health_check_run += 1 + current_run = health_check_run + + # Run 1 (health_check_run 1-3): All databases healthy + if current_run <= 3: + return True + + # Run 2 (health_check_run 4-6): mock_db1 unhealthy, others healthy + elif current_run <= 6: + if database == mock_db1: + if current_run == 6: + db1_became_unhealthy.set() + return False + + # Signal that db1 has become unhealthy after all 3 checks + if current_run == 6: + db1_became_unhealthy.set() + return True + + # Run 3 (health_check_run 7-9): mock_db1 and mock_db2 unhealthy, mock_db healthy + elif current_run <= 9: + if database == mock_db1 or database == mock_db2: + if current_run == 9: + db2_became_unhealthy.set() + return False + + # Signal that db2 has become unhealthy after all 3 checks + if current_run == 9: + db2_became_unhealthy.set() + return True + + # Run 4 (health_check_run 10-12): mock_db unhealthy, others healthy + else: + if database == mock_db: + if current_run >= 12: + db_became_unhealthy.set() + return False + + # Signal that db has become unhealthy after all 3 checks + if current_run >= 12: + db_became_unhealthy.set() + return True + + mock_hc.check_health.side_effect = mock_check_health + with ( patch.object(mock_multi_db_config, "databases", return_value=databases), patch.object( mock_multi_db_config, "default_health_checks", - return_value=[EchoHealthCheck()], + return_value=[mock_hc], ), ): - mock_db.client.execute_command.side_effect = [ - "healthcheck", - "healthcheck", - "healthcheck", - "error", - ] - mock_db1.client.execute_command.side_effect = [ - "healthcheck", - "error", - "error", - "healthcheck", - ] - mock_db2.client.execute_command.side_effect = [ - "healthcheck", - "healthcheck", - "error", - "error", - ] - mock_db.client.transaction.return_value = ["OK", "value"] mock_db1.client.transaction.return_value = ["OK1", "value"] mock_db2.client.transaction.return_value = ["OK2", "value"] @@ -373,10 +452,26 @@ def callback(pipe: Pipeline): pipe.set("key1", "value1") pipe.get("key1") + # Run 1: All databases healthy - should use mock_db1 (highest weight 0.7) assert client.transaction(callback) == ["OK1", "value"] - sleep(0.15) + + # Wait for mock_db1 to become unhealthy + assert db1_became_unhealthy.wait(timeout=1.0), "Timeout waiting for mock_db1 to become unhealthy" + sleep(0.01) + + # Run 2: mock_db1 unhealthy - should failover to mock_db2 (weight 0.5) assert client.transaction(callback) == ["OK2", "value"] - sleep(0.1) + + # Wait for mock_db2 to become unhealthy + assert 
db2_became_unhealthy.wait(timeout=1.0), "Timeout waiting for mock_db2 to become unhealthy" + sleep(0.01) + + # Run 3: mock_db1 and mock_db2 unhealthy - should use mock_db (weight 0.2) assert client.transaction(callback) == ["OK", "value"] - sleep(0.1) + + # Wait for mock_db to become unhealthy + assert db_became_unhealthy.wait(timeout=1.0), "Timeout waiting for mock_db to become unhealthy" + sleep(0.01) + + # Run 4: mock_db unhealthy, others healthy - should use mock_db1 (highest weight) assert client.transaction(callback) == ["OK1", "value"] From 8002aac75918fc92733083769987b28103eee9f4 Mon Sep 17 00:00:00 2001 From: vladvildanov Date: Tue, 7 Oct 2025 13:43:00 +0300 Subject: [PATCH 48/50] Marked tests as non-clustered --- tests/test_asyncio/test_multidb/test_pipeline.py | 4 ++-- tests/test_multidb/test_config.py | 2 +- tests/test_multidb/test_pipeline.py | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/test_asyncio/test_multidb/test_pipeline.py b/tests/test_asyncio/test_multidb/test_pipeline.py index 9a3ba63d0e..f559b87908 100644 --- a/tests/test_asyncio/test_multidb/test_pipeline.py +++ b/tests/test_asyncio/test_multidb/test_pipeline.py @@ -18,7 +18,7 @@ def mock_pipe() -> Pipeline: mock_pipe.__aexit__ = AsyncMock(return_value=None) return mock_pipe - +@pytest.mark.onlynoncluster class TestPipeline: @pytest.mark.asyncio @pytest.mark.parametrize( @@ -255,7 +255,7 @@ async def mock_check_health(database): # Run 4: mock_db unhealthy, others healthy - should use mock_db1 (highest weight) assert await pipe.execute() == ["OK1", "value"] - +@pytest.mark.onlynoncluster class TestTransaction: @pytest.mark.asyncio @pytest.mark.parametrize( diff --git a/tests/test_multidb/test_config.py b/tests/test_multidb/test_config.py index ea81f71ac9..0e57dd7821 100644 --- a/tests/test_multidb/test_config.py +++ b/tests/test_multidb/test_config.py @@ -134,7 +134,7 @@ def test_overridden_config(self): assert config.failover_strategy == mock_failover_strategy assert config.auto_fallback_interval == auto_fallback_interval - +@pytest.mark.onlynoncluster class TestDatabaseConfig: def test_default_config(self): config = DatabaseConfig( diff --git a/tests/test_multidb/test_pipeline.py b/tests/test_multidb/test_pipeline.py index b7bfab59a1..1c456cd219 100644 --- a/tests/test_multidb/test_pipeline.py +++ b/tests/test_multidb/test_pipeline.py @@ -255,7 +255,7 @@ def mock_check_health(database): # Run 4: mock_db unhealthy, others healthy - should use mock_db1 (highest weight) assert pipe.execute() == ["OK1", "value"] - +@pytest.mark.onlynoncluster class TestTransaction: @pytest.mark.parametrize( "mock_multi_db_config,mock_db, mock_db1, mock_db2", From 1b89502b3c50bb0b6325aaa4acc839298e6c0424 Mon Sep 17 00:00:00 2001 From: vladvildanov Date: Tue, 7 Oct 2025 13:45:08 +0300 Subject: [PATCH 49/50] Codestyle changes --- .../test_asyncio/test_multidb/test_client.py | 14 ++++-- .../test_multidb/test_pipeline.py | 28 +++++++---- tests/test_multidb/test_client.py | 14 ++++-- tests/test_multidb/test_config.py | 1 + tests/test_multidb/test_pipeline.py | 48 ++++++++++++------- 5 files changed, 71 insertions(+), 34 deletions(-) diff --git a/tests/test_asyncio/test_multidb/test_client.py b/tests/test_asyncio/test_multidb/test_client.py index 6b4570e9bc..d3266b7109 100644 --- a/tests/test_asyncio/test_multidb/test_client.py +++ b/tests/test_asyncio/test_multidb/test_client.py @@ -8,7 +8,7 @@ from redis.asyncio.multidb.database import AsyncDatabase from redis.asyncio.multidb.failover import 
WeightBasedFailoverStrategy from redis.asyncio.multidb.failure_detector import AsyncFailureDetector -from redis.asyncio.multidb.healthcheck import EchoHealthCheck, HealthCheck +from redis.asyncio.multidb.healthcheck import HealthCheck from redis.event import EventDispatcher, AsyncOnCommandsFailEvent from redis.multidb.circuit import State as CBState, PBCircuitBreakerAdapter from redis.multidb.exception import NoValidDatabaseException @@ -206,19 +206,25 @@ async def mock_check_health(database): assert await client.set("key", "value") == "OK1" # Wait for mock_db1 to become unhealthy - assert await db1_became_unhealthy.wait(), "Timeout waiting for mock_db1 to become unhealthy" + assert await db1_became_unhealthy.wait(), ( + "Timeout waiting for mock_db1 to become unhealthy" + ) await asyncio.sleep(0.01) assert await client.set("key", "value") == "OK2" # Wait for mock_db2 to become unhealthy - assert await db2_became_unhealthy.wait(), "Timeout waiting for mock_db2 to become unhealthy" + assert await db2_became_unhealthy.wait(), ( + "Timeout waiting for mock_db2 to become unhealthy" + ) await asyncio.sleep(0.01) assert await client.set("key", "value") == "OK" # Wait for mock_db to become unhealthy - assert await db_became_unhealthy.wait(), "Timeout waiting for mock_db to become unhealthy" + assert await db_became_unhealthy.wait(), ( + "Timeout waiting for mock_db to become unhealthy" + ) await asyncio.sleep(0.01) assert await client.set("key", "value") == "OK1" diff --git a/tests/test_asyncio/test_multidb/test_pipeline.py b/tests/test_asyncio/test_multidb/test_pipeline.py index f559b87908..528f8e813b 100644 --- a/tests/test_asyncio/test_multidb/test_pipeline.py +++ b/tests/test_asyncio/test_multidb/test_pipeline.py @@ -7,7 +7,6 @@ from redis.asyncio.client import Pipeline from redis.asyncio.multidb.client import MultiDBClient from redis.asyncio.multidb.failover import WeightBasedFailoverStrategy -from redis.asyncio.multidb.healthcheck import EchoHealthCheck from redis.multidb.circuit import State as CBState, PBCircuitBreakerAdapter from tests.test_asyncio.test_multidb.conftest import create_weighted_list @@ -18,6 +17,7 @@ def mock_pipe() -> Pipeline: mock_pipe.__aexit__ = AsyncMock(return_value=None) return mock_pipe + @pytest.mark.onlynoncluster class TestPipeline: @pytest.mark.asyncio @@ -142,7 +142,6 @@ async def test_execute_pipeline_against_correct_db_on_background_health_check_de databases = create_weighted_list(mock_db, mock_db1, mock_db2) - # Track health check runs across all databases health_check_run = 0 @@ -235,26 +234,33 @@ async def mock_check_health(database): assert await pipe.execute() == ["OK1", "value"] # Wait for mock_db1 to become unhealthy - assert await db1_became_unhealthy.wait(), "Timeout waiting for mock_db1 to become unhealthy" + assert await db1_became_unhealthy.wait(), ( + "Timeout waiting for mock_db1 to become unhealthy" + ) await asyncio.sleep(0.01) # Run 2: mock_db1 unhealthy - should failover to mock_db2 (weight 0.5) assert await pipe.execute() == ["OK2", "value"] # Wait for mock_db2 to become unhealthy - assert await db2_became_unhealthy.wait(), "Timeout waiting for mock_db2 to become unhealthy" + assert await db2_became_unhealthy.wait(), ( + "Timeout waiting for mock_db2 to become unhealthy" + ) await asyncio.sleep(0.01) # Run 3: mock_db1 and mock_db2 unhealthy - should use mock_db (weight 0.2) assert await pipe.execute() == ["OK", "value"] # Wait for mock_db to become unhealthy - assert await db_became_unhealthy.wait(), "Timeout waiting for mock_db to become 
unhealthy" + assert await db_became_unhealthy.wait(), ( + "Timeout waiting for mock_db to become unhealthy" + ) await asyncio.sleep(0.01) # Run 4: mock_db unhealthy, others healthy - should use mock_db1 (highest weight) assert await pipe.execute() == ["OK1", "value"] + @pytest.mark.onlynoncluster class TestTransaction: @pytest.mark.asyncio @@ -458,19 +464,25 @@ async def callback(pipe: Pipeline): assert await client.transaction(callback) == ["OK1", "value"] # Wait for mock_db1 to become unhealthy - assert await db1_became_unhealthy.wait(), "Timeout waiting for mock_db1 to become unhealthy" + assert await db1_became_unhealthy.wait(), ( + "Timeout waiting for mock_db1 to become unhealthy" + ) await asyncio.sleep(0.01) assert await client.transaction(callback) == ["OK2", "value"] # Wait for mock_db2 to become unhealthy - assert await db2_became_unhealthy.wait(), "Timeout waiting for mock_db1 to become unhealthy" + assert await db2_became_unhealthy.wait(), ( + "Timeout waiting for mock_db1 to become unhealthy" + ) await asyncio.sleep(0.01) assert await client.transaction(callback) == ["OK", "value"] # Wait for mock_db to become unhealthy - assert await db_became_unhealthy.wait(), "Timeout waiting for mock_db1 to become unhealthy" + assert await db_became_unhealthy.wait(), ( + "Timeout waiting for mock_db1 to become unhealthy" + ) await asyncio.sleep(0.01) assert await client.transaction(callback) == ["OK1", "value"] diff --git a/tests/test_multidb/test_client.py b/tests/test_multidb/test_client.py index ea7b5fd2b4..34ea710234 100644 --- a/tests/test_multidb/test_client.py +++ b/tests/test_multidb/test_client.py @@ -12,7 +12,7 @@ from redis.multidb.exception import NoValidDatabaseException from redis.multidb.failover import WeightBasedFailoverStrategy from redis.multidb.failure_detector import FailureDetector -from redis.multidb.healthcheck import HealthCheck, EchoHealthCheck +from redis.multidb.healthcheck import HealthCheck from tests.test_multidb.conftest import create_weighted_list @@ -204,19 +204,25 @@ def mock_check_health(database): assert client.set("key", "value") == "OK1" # Wait for mock_db1 to become unhealthy - assert db1_became_unhealthy.wait(timeout=1.0), "Timeout waiting for mock_db1 to become unhealthy" + assert db1_became_unhealthy.wait(timeout=1.0), ( + "Timeout waiting for mock_db1 to become unhealthy" + ) sleep(0.01) assert client.set("key", "value") == "OK2" # Wait for mock_db2 to become unhealthy - assert db2_became_unhealthy.wait(timeout=1.0), "Timeout waiting for mock_db2 to become unhealthy" + assert db2_became_unhealthy.wait(timeout=1.0), ( + "Timeout waiting for mock_db2 to become unhealthy" + ) sleep(0.01) assert client.set("key", "value") == "OK" # Wait for mock_db to become unhealthy - assert db_became_unhealthy.wait(timeout=1.0), "Timeout waiting for mock_db to become unhealthy" + assert db_became_unhealthy.wait(timeout=1.0), ( + "Timeout waiting for mock_db to become unhealthy" + ) sleep(0.01) assert client.set("key", "value") == "OK1" diff --git a/tests/test_multidb/test_config.py b/tests/test_multidb/test_config.py index 0e57dd7821..351f789971 100644 --- a/tests/test_multidb/test_config.py +++ b/tests/test_multidb/test_config.py @@ -134,6 +134,7 @@ def test_overridden_config(self): assert config.failover_strategy == mock_failover_strategy assert config.auto_fallback_interval == auto_fallback_interval + @pytest.mark.onlynoncluster class TestDatabaseConfig: def test_default_config(self): diff --git a/tests/test_multidb/test_pipeline.py 
b/tests/test_multidb/test_pipeline.py index 1c456cd219..0055718d4f 100644 --- a/tests/test_multidb/test_pipeline.py +++ b/tests/test_multidb/test_pipeline.py @@ -11,7 +11,6 @@ from redis.multidb.failover import ( WeightBasedFailoverStrategy, ) -from redis.multidb.healthcheck import EchoHealthCheck from tests.test_multidb.conftest import create_weighted_list @@ -235,26 +234,33 @@ def mock_check_health(database): assert pipe.execute() == ["OK1", "value"] # Wait for mock_db1 to become unhealthy - assert db1_became_unhealthy.wait(timeout=1.0), "Timeout waiting for mock_db1 to become unhealthy" + assert db1_became_unhealthy.wait(timeout=1.0), ( + "Timeout waiting for mock_db1 to become unhealthy" + ) sleep(0.01) # Run 2: mock_db1 unhealthy - should failover to mock_db2 (weight 0.5) assert pipe.execute() == ["OK2", "value"] # Wait for mock_db2 to become unhealthy - assert db2_became_unhealthy.wait(timeout=1.0), "Timeout waiting for mock_db2 to become unhealthy" + assert db2_became_unhealthy.wait(timeout=1.0), ( + "Timeout waiting for mock_db2 to become unhealthy" + ) sleep(0.01) # Run 3: mock_db1 and mock_db2 unhealthy - should use mock_db (weight 0.2) assert pipe.execute() == ["OK", "value"] # Wait for mock_db to become unhealthy - assert db_became_unhealthy.wait(timeout=1.0), "Timeout waiting for mock_db to become unhealthy" + assert db_became_unhealthy.wait(timeout=1.0), ( + "Timeout waiting for mock_db to become unhealthy" + ) sleep(0.01) # Run 4: mock_db unhealthy, others healthy - should use mock_db1 (highest weight) assert pipe.execute() == ["OK1", "value"] + @pytest.mark.onlynoncluster class TestTransaction: @pytest.mark.parametrize( @@ -374,7 +380,7 @@ def test_execute_transaction_against_correct_db_on_background_health_check_deter # Track health check runs across all databases health_check_run = 0 - + # Create events for each failover scenario db1_became_unhealthy = threading.Event() db2_became_unhealthy = threading.Event() @@ -383,16 +389,16 @@ def test_execute_transaction_against_correct_db_on_background_health_check_deter def mock_check_health(database): nonlocal health_check_run - + # Increment run counter for each health check call with counter_lock: health_check_run += 1 current_run = health_check_run - + # Run 1 (health_check_run 1-3): All databases healthy if current_run <= 3: return True - + # Run 2 (health_check_run 4-6): mock_db1 unhealthy, others healthy elif current_run <= 6: if database == mock_db1: @@ -404,7 +410,7 @@ def mock_check_health(database): if current_run == 6: db1_became_unhealthy.set() return True - + # Run 3 (health_check_run 7-9): mock_db1 and mock_db2 unhealthy, mock_db healthy elif current_run <= 9: if database == mock_db1 or database == mock_db2: @@ -416,7 +422,7 @@ def mock_check_health(database): if current_run == 9: db2_became_unhealthy.set() return True - + # Run 4 (health_check_run 10-12): mock_db unhealthy, others healthy else: if database == mock_db: @@ -454,24 +460,30 @@ def callback(pipe: Pipeline): # Run 1: All databases healthy - should use mock_db1 (highest weight 0.7) assert client.transaction(callback) == ["OK1", "value"] - + # Wait for mock_db1 to become unhealthy - assert db1_became_unhealthy.wait(timeout=1.0), "Timeout waiting for mock_db1 to become unhealthy" + assert db1_became_unhealthy.wait(timeout=1.0), ( + "Timeout waiting for mock_db1 to become unhealthy" + ) sleep(0.01) - + # Run 2: mock_db1 unhealthy - should failover to mock_db2 (weight 0.5) assert client.transaction(callback) == ["OK2", "value"] - + # Wait for mock_db2 to 
become unhealthy - assert db2_became_unhealthy.wait(timeout=1.0), "Timeout waiting for mock_db2 to become unhealthy" + assert db2_became_unhealthy.wait(timeout=1.0), ( + "Timeout waiting for mock_db2 to become unhealthy" + ) sleep(0.01) # Run 3: mock_db1 and mock_db2 unhealthy - should use mock_db (weight 0.2) assert client.transaction(callback) == ["OK", "value"] - + # Wait for mock_db to become unhealthy - assert db_became_unhealthy.wait(timeout=1.0), "Timeout waiting for mock_db to become unhealthy" + assert db_became_unhealthy.wait(timeout=1.0), ( + "Timeout waiting for mock_db to become unhealthy" + ) sleep(0.01) - + # Run 4: mock_db unhealthy, others healthy - should use mock_db1 (highest weight) assert client.transaction(callback) == ["OK1", "value"] From 67683e270bb7ed53d36145384c7b66ba8fb282f3 Mon Sep 17 00:00:00 2001 From: vladvildanov Date: Tue, 7 Oct 2025 15:05:43 +0300 Subject: [PATCH 50/50] Skipped tests in validating workflow --- .github/workflows/install_and_test.sh | 1 + tests/test_asyncio/test_multidb/test_client.py | 2 +- tests/test_multidb/test_client.py | 2 +- 3 files changed, 3 insertions(+), 2 deletions(-) diff --git a/.github/workflows/install_and_test.sh b/.github/workflows/install_and_test.sh index 85cb07cb8a..a83cef4089 100755 --- a/.github/workflows/install_and_test.sh +++ b/.github/workflows/install_and_test.sh @@ -46,5 +46,6 @@ CLUSTER_URL="redis://localhost:16379/0" CLUSTER_SSL_URL="rediss://localhost:27379/0" pytest -m 'not onlynoncluster and not redismod and not ssl' \ --ignore=tests/test_scenario \ + --ignore=tests/test_asyncio/test_scenario \ --redis-url="${CLUSTER_URL}" \ --redis-ssl-url="${CLUSTER_SSL_URL}" diff --git a/tests/test_asyncio/test_multidb/test_client.py b/tests/test_asyncio/test_multidb/test_client.py index d3266b7109..e912a00466 100644 --- a/tests/test_asyncio/test_multidb/test_client.py +++ b/tests/test_asyncio/test_multidb/test_client.py @@ -284,7 +284,7 @@ async def mock_check_health(database): assert await client.set("key", "value") == "OK1" await error_event.wait() assert await client.set("key", "value") == "OK2" - await asyncio.sleep(0.2) + await asyncio.sleep(0.5) assert await client.set("key", "value") == "OK1" @pytest.mark.asyncio diff --git a/tests/test_multidb/test_client.py b/tests/test_multidb/test_client.py index 34ea710234..cbc81b15ed 100644 --- a/tests/test_multidb/test_client.py +++ b/tests/test_multidb/test_client.py @@ -281,7 +281,7 @@ def mock_check_health(database): assert client.set("key", "value") == "OK1" error_event.wait(timeout=0.5) assert client.set("key", "value") == "OK2" - sleep(0.2) + sleep(0.5) assert client.set("key", "value") == "OK1" @pytest.mark.parametrize(