diff --git a/redis/client.py b/redis/client.py index 060fc29493..e22ca3d73d 100755 --- a/redis/client.py +++ b/redis/client.py @@ -1217,6 +1217,8 @@ def run_in_thread( sleep_time: float = 0.0, daemon: bool = False, exception_handler: Optional[Callable] = None, + pubsub = None, + sharded_pubsub: bool = False, ) -> "PubSubWorkerThread": for channel, handler in self.channels.items(): if handler is None: @@ -1230,8 +1232,9 @@ def run_in_thread( f"Shard Channel: '{s_channel}' has no handler registered" ) + pubsub = self if pubsub is None else pubsub thread = PubSubWorkerThread( - self, sleep_time, daemon=daemon, exception_handler=exception_handler + pubsub, sleep_time, daemon=daemon, exception_handler=exception_handler, sharded_pubsub=sharded_pubsub ) thread.start() return thread @@ -1246,12 +1249,14 @@ def __init__( exception_handler: Union[ Callable[[Exception, "PubSub", "PubSubWorkerThread"], None], None ] = None, + sharded_pubsub: bool = False, ): super().__init__() self.daemon = daemon self.pubsub = pubsub self.sleep_time = sleep_time self.exception_handler = exception_handler + self.sharded_pubsub = sharded_pubsub self._running = threading.Event() def run(self) -> None: @@ -1262,7 +1267,10 @@ def run(self) -> None: sleep_time = self.sleep_time while self._running.is_set(): try: - pubsub.get_message(ignore_subscribe_messages=True, timeout=sleep_time) + if not self.sharded_pubsub: + pubsub.get_message(ignore_subscribe_messages=True, timeout=sleep_time) + else: + pubsub.get_sharded_message(ignore_subscribe_messages=True, timeout=sleep_time) except BaseException as e: if self.exception_handler is None: raise diff --git a/redis/cluster.py b/redis/cluster.py index dbcf5cc2b7..2fd4625e6b 100644 --- a/redis/cluster.py +++ b/redis/cluster.py @@ -3154,7 +3154,8 @@ def _reinitialize_on_error(self, error): self._nodes_manager.initialize() self.reinitialize_counter = 0 else: - self._nodes_manager.update_moved_exception(error) + if type(error) == MovedError: + self._nodes_manager.update_moved_exception(error) self._executing = False diff --git a/redis/multidb/client.py b/redis/multidb/client.py index 367060bcc3..9159ec3599 100644 --- a/redis/multidb/client.py +++ b/redis/multidb/client.py @@ -1,14 +1,15 @@ import threading import socket -from typing import List, Any, Callable +from typing import List, Any, Callable, Optional from redis.background import BackgroundScheduler +from redis.client import PubSubWorkerThread from redis.exceptions import ConnectionError, TimeoutError from redis.commands import RedisModuleCommands, CoreCommands from redis.multidb.command_executor import DefaultCommandExecutor from redis.multidb.config import MultiDbConfig, DEFAULT_GRACE_PERIOD from redis.multidb.circuit import State as CBState, CircuitBreaker -from redis.multidb.database import State as DBState, Database, AbstractDatabase, Databases +from redis.multidb.database import Database, AbstractDatabase, Databases from redis.multidb.exception import NoValidDatabaseException from redis.multidb.failure_detector import FailureDetector from redis.multidb.healthcheck import HealthCheck @@ -70,13 +71,8 @@ def raise_exception_on_failed_hc(error): # Set states according to a weights and circuit state if database.circuit.state == CBState.CLOSED and not is_active_db_found: - database.state = DBState.ACTIVE self.command_executor.active_database = database is_active_db_found = True - elif database.circuit.state == CBState.CLOSED and is_active_db_found: - database.state = DBState.PASSIVE - else: - database.state = 
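A quick usage sketch for the new `sharded_pubsub` flag on `run_in_thread` shown above, against a standalone client (host, channel name and handler are placeholders, and a Redis 7+ server with sharded pub/sub support is assumed):

```python
import redis

r = redis.Redis(host="localhost", port=6379, decode_responses=True)
p = r.pubsub()

# Shard-channel messages are only drained by get_sharded_message(), which is
# why the worker thread needs to know it should poll that method instead.
p.ssubscribe(**{"orders": lambda msg: print(msg["data"])})

worker = p.run_in_thread(sleep_time=0.1, daemon=True, sharded_pubsub=True)
# elsewhere: r.spublish("orders", "hello")
# worker.stop()
```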
DBState.DISCONNECTED if not is_active_db_found: raise NoValidDatabaseException('Initial connection failed - no active database found') @@ -107,8 +103,6 @@ def set_active_database(self, database: AbstractDatabase) -> None: if database.circuit.state == CBState.CLOSED: highest_weighted_db, _ = self._databases.get_top_n(1)[0] - highest_weighted_db.state = DBState.PASSIVE - database.state = DBState.ACTIVE self.command_executor.active_database = database return @@ -130,9 +124,7 @@ def add_database(self, database: AbstractDatabase): def _change_active_database(self, new_database: AbstractDatabase, highest_weight_database: AbstractDatabase): if new_database.weight > highest_weight_database.weight and new_database.circuit.state == CBState.CLOSED: - new_database.state = DBState.ACTIVE self.command_executor.active_database = new_database - highest_weight_database.state = DBState.PASSIVE def remove_database(self, database: Database): """ @@ -142,7 +134,6 @@ def remove_database(self, database: Database): highest_weighted_db, highest_weight = self._databases.get_top_n(1)[0] if highest_weight <= weight and highest_weighted_db.circuit.state == CBState.CLOSED: - highest_weighted_db.state = DBState.ACTIVE self.command_executor.active_database = highest_weighted_db def update_database_weight(self, database: AbstractDatabase, weight: float): @@ -201,6 +192,17 @@ def transaction(self, func: Callable[["Pipeline"], None], *watches, **options): return self.command_executor.execute_transaction(func, *watches, *options) + def pubsub(self, **kwargs): + """ + Return a Publish/Subscribe object. With this object, you can + subscribe to channels and listen for messages that get published to + them. + """ + if not self.initialized: + self.initialize() + + return PubSub(self, **kwargs) + def _check_db_health(self, database: AbstractDatabase, on_error: Callable[[Exception], None] = None) -> None: """ Runs health checks on the given database until first failure. @@ -221,7 +223,7 @@ def _check_db_health(self, database: AbstractDatabase, on_error: Callable[[Excep database.circuit.state = CBState.OPEN elif is_healthy and database.circuit.state != CBState.CLOSED: database.circuit.state = CBState.CLOSED - except (ConnectionError, TimeoutError, socket.timeout, ConnectionRefusedError) as e: + except (ConnectionError, TimeoutError, socket.timeout, ConnectionRefusedError, ValueError) as e: if database.circuit.state != CBState.OPEN: database.circuit.state = CBState.OPEN is_healthy = False @@ -311,3 +313,133 @@ def execute(self) -> List[Any]: return self._client.command_executor.execute_pipeline(tuple(self._command_stack)) finally: self.reset() + +class PubSub: + """ + PubSub object for multi database client. + """ + def __init__(self, client: MultiDBClient, **kwargs): + self._client = client + self._client.command_executor.pubsub(**kwargs) + + def __enter__(self) -> "PubSub": + return self + + def __exit__(self, exc_type, exc_value, traceback) -> None: + self.reset() + + def __del__(self) -> None: + try: + # if this object went out of scope prior to shutting down + # subscriptions, close the connection manually before + # returning it to the connection pool + self.reset() + except Exception: + pass + + def reset(self) -> None: + pass + + def close(self) -> None: + self.reset() + + @property + def subscribed(self) -> bool: + return self._client.command_executor.active_pubsub.subscribed + + def psubscribe(self, *args, **kwargs): + """ + Subscribe to channel patterns. 
Patterns supplied as keyword arguments + expect a pattern name as the key and a callable as the value. A + pattern's callable will be invoked automatically when a message is + received on that pattern rather than producing a message via + ``listen()``. + """ + return self._client.command_executor.execute_pubsub_method('psubscribe', *args, **kwargs) + + def punsubscribe(self, *args): + """ + Unsubscribe from the supplied patterns. If empty, unsubscribe from + all patterns. + """ + return self._client.command_executor.execute_pubsub_method('punsubscribe', *args) + + def subscribe(self, *args, **kwargs): + """ + Subscribe to channels. Channels supplied as keyword arguments expect + a channel name as the key and a callable as the value. A channel's + callable will be invoked automatically when a message is received on + that channel rather than producing a message via ``listen()`` or + ``get_message()``. + """ + return self._client.command_executor.execute_pubsub_method('subscribe', *args, **kwargs) + + def unsubscribe(self, *args): + """ + Unsubscribe from the supplied channels. If empty, unsubscribe from + all channels + """ + return self._client.command_executor.execute_pubsub_method('unsubscribe', *args) + + def ssubscribe(self, *args, **kwargs): + """ + Subscribes the client to the specified shard channels. + Channels supplied as keyword arguments expect a channel name as the key + and a callable as the value. A channel's callable will be invoked automatically + when a message is received on that channel rather than producing a message via + ``listen()`` or ``get_sharded_message()``. + """ + return self._client.command_executor.execute_pubsub_method('ssubscribe', *args, **kwargs) + + def sunsubscribe(self, *args): + """ + Unsubscribe from the supplied shard_channels. If empty, unsubscribe from + all shard_channels + """ + return self._client.command_executor.execute_pubsub_method('sunsubscribe', *args) + + def get_message( + self, ignore_subscribe_messages: bool = False, timeout: float = 0.0 + ): + """ + Get the next message if one is available, otherwise None. + + If timeout is specified, the system will wait for `timeout` seconds + before returning. Timeout should be specified as a floating point + number, or None, to wait indefinitely. + """ + return self._client.command_executor.execute_pubsub_method( + 'get_message', + ignore_subscribe_messages=ignore_subscribe_messages, timeout=timeout + ) + + def get_sharded_message( + self, ignore_subscribe_messages: bool = False, timeout: float = 0.0 + ): + """ + Get the next message if one is available in a sharded channel, otherwise None. + + If timeout is specified, the system will wait for `timeout` seconds + before returning. Timeout should be specified as a floating point + number, or None, to wait indefinitely. 
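The new `MultiDBClient.pubsub()` entry point and the wrapper methods above all delegate to the executor's active `PubSub`. A minimal end-to-end sketch, with the caveat that the endpoint URLs, weights and the exact `DatabaseConfig` field names are assumptions rather than something this diff defines:

```python
from redis.multidb.client import MultiDBClient
from redis.multidb.config import MultiDbConfig, DatabaseConfig

config = MultiDbConfig(
    databases_config=[
        DatabaseConfig(from_url="redis://db-east:6379", weight=1.0),  # field names assumed
        DatabaseConfig(from_url="redis://db-west:6379", weight=0.5),
    ],
)
client = MultiDBClient(config)

received = []
pubsub = client.pubsub()  # lazily builds a PubSub on the currently active database
pubsub.subscribe(**{"news": lambda m: received.append(m["data"])})
worker = pubsub.run_in_thread(sleep_time=0.1, daemon=True)

client.publish("news", "hello")
# If the active database changes, ResubscribeOnActiveDatabaseChanged rebuilds the
# underlying PubSub on the new database and re-subscribes 'news' automatically.
# worker.stop()
```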
+ """ + return self._client.command_executor.execute_pubsub_method( + 'get_sharded_message', + ignore_subscribe_messages=ignore_subscribe_messages, timeout=timeout + ) + + def run_in_thread( + self, + sleep_time: float = 0.0, + daemon: bool = False, + exception_handler: Optional[Callable] = None, + sharded_pubsub: bool = False, + ) -> "PubSubWorkerThread": + return self._client.command_executor.execute_pubsub_run_in_thread( + sleep_time=sleep_time, + daemon=daemon, + exception_handler=exception_handler, + pubsub=self, + sharded_pubsub=sharded_pubsub, + ) + diff --git a/redis/multidb/command_executor.py b/redis/multidb/command_executor.py index 690ea49a5c..ab5e121c86 100644 --- a/redis/multidb/command_executor.py +++ b/redis/multidb/command_executor.py @@ -1,13 +1,13 @@ from abc import ABC, abstractmethod from datetime import datetime, timedelta -from typing import List, Union, Optional, Callable +from typing import List, Optional, Callable -from redis.client import Pipeline +from redis.client import Pipeline, PubSub, PubSubWorkerThread from redis.event import EventDispatcherInterface, OnCommandsFailEvent from redis.multidb.config import DEFAULT_AUTO_FALLBACK_INTERVAL from redis.multidb.database import Database, AbstractDatabase, Databases from redis.multidb.circuit import State as CBState -from redis.multidb.event import RegisterCommandFailure +from redis.multidb.event import RegisterCommandFailure, ActiveDatabaseChanged, ResubscribeOnActiveDatabaseChanged from redis.multidb.failover import FailoverStrategy from redis.multidb.failure_detector import FailureDetector from redis.retry import Retry @@ -34,7 +34,7 @@ def databases(self) -> Databases: @property @abstractmethod - def active_database(self) -> Union[Database, None]: + def active_database(self) -> Optional[Database]: """Returns currently active database.""" pass @@ -44,6 +44,23 @@ def active_database(self, database: AbstractDatabase) -> None: """Sets currently active database.""" pass + @abstractmethod + def pubsub(self, **kwargs): + """Initializes a PubSub object on a currently active database""" + pass + + @property + @abstractmethod + def active_pubsub(self) -> Optional[PubSub]: + """Returns currently active pubsub.""" + pass + + @active_pubsub.setter + @abstractmethod + def active_pubsub(self, pubsub: PubSub) -> None: + """Sets currently active pubsub.""" + pass + @property @abstractmethod def failover_strategy(self) -> FailoverStrategy: @@ -103,7 +120,9 @@ def __init__( self._event_dispatcher = event_dispatcher self._auto_fallback_interval = auto_fallback_interval self._next_fallback_attempt: datetime - self._active_database: Union[Database, None] = None + self._active_database: Optional[Database] = None + self._active_pubsub: Optional[PubSub] = None + self._active_pubsub_kwargs = {} self._setup_event_dispatcher() self._schedule_next_fallback() @@ -128,8 +147,22 @@ def active_database(self) -> Optional[AbstractDatabase]: @active_database.setter def active_database(self, database: AbstractDatabase) -> None: + old_active = self._active_database self._active_database = database + if old_active is not None and old_active is not database: + self._event_dispatcher.dispatch( + ActiveDatabaseChanged(old_active, self._active_database, self, **self._active_pubsub_kwargs) + ) + + @property + def active_pubsub(self) -> Optional[PubSub]: + return self._active_pubsub + + @active_pubsub.setter + def active_pubsub(self, pubsub: PubSub) -> None: + self._active_pubsub = pubsub + @property def failover_strategy(self) -> FailoverStrategy: return 
self._failover_strategy @@ -143,6 +176,7 @@ def auto_fallback_interval(self, auto_fallback_interval: int) -> None: self._auto_fallback_interval = auto_fallback_interval def execute_command(self, *args, **options): + """Executes a command and returns the result.""" def callback(): return self._active_database.client.execute_command(*args, **options) @@ -170,6 +204,44 @@ def callback(): return self._execute_with_failure_detection(callback) + def pubsub(self, **kwargs): + def callback(): + if self._active_pubsub is None: + self._active_pubsub = self._active_database.client.pubsub(**kwargs) + self._active_pubsub_kwargs = kwargs + return None + + return self._execute_with_failure_detection(callback) + + def execute_pubsub_method(self, method_name: str, *args, **kwargs): + """ + Executes given method on active pub/sub. + """ + def callback(): + method = getattr(self.active_pubsub, method_name) + return method(*args, **kwargs) + + return self._execute_with_failure_detection(callback, *args) + + def execute_pubsub_run_in_thread( + self, + pubsub, + sleep_time: float = 0.0, + daemon: bool = False, + exception_handler: Optional[Callable] = None, + sharded_pubsub: bool = False, + ) -> "PubSubWorkerThread": + def callback(): + return self._active_pubsub.run_in_thread( + sleep_time, + daemon=daemon, + exception_handler=exception_handler, + pubsub=pubsub, + sharded_pubsub=sharded_pubsub + ) + + return self._execute_with_failure_detection(callback) + def _execute_with_failure_detection(self, callback: Callable, cmds: tuple = ()): """ Execute a commands execution callback with failure detection. @@ -199,7 +271,7 @@ def _check_active_database(self): and self._next_fallback_attempt <= datetime.now() ) ): - self._active_database = self._failover_strategy.database + self.active_database = self._failover_strategy.database self._schedule_next_fallback() def _schedule_next_fallback(self) -> None: @@ -210,9 +282,11 @@ def _schedule_next_fallback(self) -> None: def _setup_event_dispatcher(self): """ - Registers command failure event listener. + Registers necessary listeners. """ - event_listener = RegisterCommandFailure(self._failure_detectors) + failure_listener = RegisterCommandFailure(self._failure_detectors) + resubscribe_listener = ResubscribeOnActiveDatabaseChanged() self._event_dispatcher.register_listeners({ - OnCommandsFailEvent: [event_listener], + OnCommandsFailEvent: [failure_listener], + ActiveDatabaseChanged: [resubscribe_listener], }) \ No newline at end of file diff --git a/redis/multidb/database.py b/redis/multidb/database.py index 15db52e909..1161ba936f 100644 --- a/redis/multidb/database.py +++ b/redis/multidb/database.py @@ -8,12 +8,6 @@ from redis.multidb.circuit import CircuitBreaker from redis.typing import Number - -class State(Enum): - ACTIVE = 0 - PASSIVE = 1 - DISCONNECTED = 2 - class AbstractDatabase(ABC): @property @abstractmethod @@ -39,18 +33,6 @@ def weight(self, weight: float): """Set the weight of this database in compare to others.""" pass - @property - @abstractmethod - def state(self) -> State: - """The state of the current database.""" - pass - - @state.setter - @abstractmethod - def state(self, state: State): - """Set the state of the current database.""" - pass - @property @abstractmethod def circuit(self) -> CircuitBreaker: @@ -70,8 +52,7 @@ def __init__( self, client: Union[redis.Redis, RedisCluster], circuit: CircuitBreaker, - weight: float, - state: State = State.DISCONNECTED, + weight: float ): """ param: client: Client instance for communication with the database. 
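`_setup_event_dispatcher` above registers the built-in resubscribe listener for `ActiveDatabaseChanged` (the event and listener are defined in the `redis/multidb/event.py` hunk below). Since the dispatcher is pluggable, application code can register its own listeners for the same event; a small illustrative sketch:

```python
import logging

from redis.event import EventDispatcher, EventListenerInterface
from redis.multidb.event import ActiveDatabaseChanged

class LogFailover(EventListenerInterface):
    """Illustrative listener: log every active-database switch."""
    def listen(self, event: ActiveDatabaseChanged) -> None:
        logging.warning(
            "active database switched: %r -> %r",
            event.old_database, event.new_database,
        )

dispatcher = EventDispatcher()
dispatcher.register_listeners({ActiveDatabaseChanged: [LogFailover()]})
# Pass the dispatcher via MultiDbConfig(event_dispatcher=dispatcher); the
# executor still registers ResubscribeOnActiveDatabaseChanged on top of it.
```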
@@ -83,7 +64,6 @@ def __init__( self._cb = circuit self._cb.database = self self._weight = weight - self._state = state @property def client(self) -> Union[redis.Redis, RedisCluster]: @@ -101,14 +81,6 @@ def weight(self) -> float: def weight(self, weight: float): self._weight = weight - @property - def state(self) -> State: - return self._state - - @state.setter - def state(self, state: State): - self._state = state - @property def circuit(self) -> CircuitBreaker: return self._cb diff --git a/redis/multidb/event.py b/redis/multidb/event.py index 315802e812..2598bc4d06 100644 --- a/redis/multidb/event.py +++ b/redis/multidb/event.py @@ -1,8 +1,58 @@ from typing import List from redis.event import EventListenerInterface, OnCommandsFailEvent +from redis.multidb.config import Databases +from redis.multidb.database import AbstractDatabase from redis.multidb.failure_detector import FailureDetector +class ActiveDatabaseChanged: + """ + Event fired when an active database has been changed. + """ + def __init__( + self, + old_database: AbstractDatabase, + new_database: AbstractDatabase, + command_executor, + **kwargs + ): + self._old_database = old_database + self._new_database = new_database + self._command_executor = command_executor + self._kwargs = kwargs + + @property + def old_database(self) -> AbstractDatabase: + return self._old_database + + @property + def new_database(self) -> AbstractDatabase: + return self._new_database + + @property + def command_executor(self): + return self._command_executor + + @property + def kwargs(self): + return self._kwargs + +class ResubscribeOnActiveDatabaseChanged(EventListenerInterface): + """ + Re-subscribe currently active pub/sub to a new active database. + """ + def listen(self, event: ActiveDatabaseChanged): + old_pubsub = event.command_executor.active_pubsub + + if old_pubsub is not None: + # Re-assign old channels and patterns so they will be automatically subscribed on connection. + new_pubsub = event.new_database.client.pubsub(**event.kwargs) + new_pubsub.channels = old_pubsub.channels + new_pubsub.patterns = old_pubsub.patterns + new_pubsub.shard_channels = old_pubsub.shard_channels + new_pubsub.on_connect(None) + event.command_executor.active_pubsub = new_pubsub + old_pubsub.close() class RegisterCommandFailure(EventListenerInterface): """ diff --git a/redis/multidb/healthcheck.py b/redis/multidb/healthcheck.py index a96b9cf815..cca220dc3f 100644 --- a/redis/multidb/healthcheck.py +++ b/redis/multidb/healthcheck.py @@ -1,4 +1,7 @@ from abc import abstractmethod, ABC + +import redis +from redis import Redis from redis.retry import Retry @@ -21,6 +24,7 @@ def __init__( retry: Retry, ) -> None: self._retry = retry + self._retry.update_supported_errors([ConnectionRefusedError]) @property def retry(self) -> Retry: @@ -50,8 +54,20 @@ def check_health(self, database) -> bool: def _returns_echoed_message(self, database) -> bool: expected_message = ["healthcheck", b"healthcheck"] - actual_message = database.client.execute_command('ECHO', "healthcheck") - return actual_message in expected_message + + if isinstance(database.client, Redis): + actual_message = database.client.execute_command("ECHO" ,"healthcheck") + return actual_message in expected_message + else: + # For a cluster checks if all nodes are healthy. 
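`EchoHealthCheck` now branches on the client type: one `ECHO` round trip for a standalone `Redis`, and (continuing just below) one per node for a cluster, failing the check if any node misses. A usage sketch with an illustrative retry policy:

```python
from redis.backoff import ExponentialBackoff
from redis.retry import Retry
from redis.multidb.healthcheck import EchoHealthCheck

# Backoff and retry numbers are illustrative, mirroring the unit tests further down.
hc = EchoHealthCheck(Retry(backoff=ExponentialBackoff(cap=1.0), retries=3))

# hc.check_health(database) returns True only if the echoed value comes back as
# "healthcheck"; for a RedisCluster database every node has to pass.
```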
+ all_nodes = database.client.get_nodes() + for node in all_nodes: + actual_message = node.redis_connection.execute_command("ECHO" ,"healthcheck") + + if actual_message not in expected_message: + return False + + return True def _dummy_fail(self): pass \ No newline at end of file diff --git a/tests/test_multidb/conftest.py b/tests/test_multidb/conftest.py index ad2057a118..f52b66e0a6 100644 --- a/tests/test_multidb/conftest.py +++ b/tests/test_multidb/conftest.py @@ -7,7 +7,7 @@ from redis.multidb.circuit import CircuitBreaker, State as CBState from redis.multidb.config import MultiDbConfig, DatabaseConfig, DEFAULT_HEALTH_CHECK_INTERVAL, \ DEFAULT_AUTO_FALLBACK_INTERVAL -from redis.multidb.database import Database, State, Databases +from redis.multidb.database import Database, Databases from redis.multidb.failover import FailoverStrategy from redis.multidb.failure_detector import FailureDetector from redis.multidb.healthcheck import HealthCheck @@ -38,7 +38,6 @@ def mock_hc() -> HealthCheck: def mock_db(request) -> Database: db = Mock(spec=Database) db.weight = request.param.get("weight", 1.0) - db.state = request.param.get("state", State.ACTIVE) db.client = Mock(spec=Redis) cb = request.param.get("circuit", {}) @@ -53,7 +52,6 @@ def mock_db(request) -> Database: def mock_db1(request) -> Database: db = Mock(spec=Database) db.weight = request.param.get("weight", 1.0) - db.state = request.param.get("state", State.ACTIVE) db.client = Mock(spec=Redis) cb = request.param.get("circuit", {}) @@ -68,7 +66,6 @@ def mock_db1(request) -> Database: def mock_db2(request) -> Database: db = Mock(spec=Database) db.weight = request.param.get("weight", 1.0) - db.state = request.param.get("state", State.ACTIVE) db.client = Mock(spec=Redis) cb = request.param.get("circuit", {}) diff --git a/tests/test_multidb/test_client.py b/tests/test_multidb/test_client.py index cf3877957f..2503ed98fc 100644 --- a/tests/test_multidb/test_client.py +++ b/tests/test_multidb/test_client.py @@ -8,7 +8,7 @@ from redis.multidb.circuit import State as CBState, PBCircuitBreakerAdapter from redis.multidb.config import DEFAULT_HEALTH_CHECK_RETRIES, DEFAULT_HEALTH_CHECK_BACKOFF, DEFAULT_FAILOVER_RETRIES, \ DEFAULT_FAILOVER_BACKOFF -from redis.multidb.database import State as DBState, AbstractDatabase +from redis.multidb.database import AbstractDatabase from redis.multidb.client import MultiDBClient from redis.multidb.exception import NoValidDatabaseException from redis.multidb.failover import WeightBasedFailoverStrategy @@ -190,26 +190,14 @@ def test_execute_command_auto_fallback_to_highest_weight_db( client = MultiDBClient(mock_multi_db_config) assert client.set('key', 'value') == 'OK1' - assert mock_db.state == DBState.PASSIVE - assert mock_db1.state == DBState.ACTIVE - assert mock_db2.state == DBState.PASSIVE - sleep(0.15) assert client.set('key', 'value') == 'OK2' - assert mock_db.state == DBState.PASSIVE - assert mock_db1.state == DBState.ACTIVE - assert mock_db2.state == DBState.PASSIVE - sleep(0.22) assert client.set('key', 'value') == 'OK1' - assert mock_db.state == DBState.PASSIVE - assert mock_db1.state == DBState.ACTIVE - assert mock_db2.state == DBState.PASSIVE - @pytest.mark.parametrize( 'mock_multi_db_config,mock_db, mock_db1, mock_db2', [ @@ -244,10 +232,6 @@ def test_execute_command_throws_exception_on_failed_initialization( for hc in mock_multi_db_config.health_checks: assert hc.check_health.call_count == 3 - assert mock_db.state == DBState.DISCONNECTED - assert mock_db1.state == DBState.DISCONNECTED - assert 
mock_db2.state == DBState.DISCONNECTED - @pytest.mark.parametrize( 'mock_multi_db_config,mock_db, mock_db1, mock_db2', [ @@ -318,9 +302,6 @@ def test_add_database_makes_new_database_active( for hc in mock_multi_db_config.health_checks: assert hc.check_health.call_count == 2 - assert mock_db.state == DBState.PASSIVE - assert mock_db2.state == DBState.ACTIVE - client.add_database(mock_db1) for hc in mock_multi_db_config.health_checks: @@ -328,10 +309,6 @@ def test_add_database_makes_new_database_active( assert client.set('key', 'value') == 'OK1' - assert mock_db.state == DBState.PASSIVE - assert mock_db1.state == DBState.ACTIVE - assert mock_db2.state == DBState.PASSIVE - @pytest.mark.parametrize( 'mock_multi_db_config,mock_db, mock_db1, mock_db2', [ @@ -368,17 +345,10 @@ def test_remove_highest_weighted_database( for hc in mock_multi_db_config.health_checks: assert hc.check_health.call_count == 3 - assert mock_db.state == DBState.PASSIVE - assert mock_db1.state == DBState.ACTIVE - assert mock_db2.state == DBState.PASSIVE - client.remove_database(mock_db1) assert client.set('key', 'value') == 'OK2' - assert mock_db.state == DBState.PASSIVE - assert mock_db2.state == DBState.ACTIVE - @pytest.mark.parametrize( 'mock_multi_db_config,mock_db, mock_db1, mock_db2', [ @@ -415,19 +385,11 @@ def test_update_database_weight_to_be_highest( for hc in mock_multi_db_config.health_checks: assert hc.check_health.call_count == 3 - assert mock_db.state == DBState.PASSIVE - assert mock_db1.state == DBState.ACTIVE - assert mock_db2.state == DBState.PASSIVE - client.update_database_weight(mock_db2, 0.8) assert mock_db2.weight == 0.8 assert client.set('key', 'value') == 'OK2' - assert mock_db.state == DBState.PASSIVE - assert mock_db1.state == DBState.PASSIVE - assert mock_db2.state == DBState.ACTIVE - @pytest.mark.parametrize( 'mock_multi_db_config,mock_db, mock_db1, mock_db2', [ @@ -563,17 +525,9 @@ def test_set_active_database( for hc in mock_multi_db_config.health_checks: assert hc.check_health.call_count == 3 - assert mock_db.state == DBState.PASSIVE - assert mock_db1.state == DBState.ACTIVE - assert mock_db2.state == DBState.PASSIVE - client.set_active_database(mock_db) assert client.set('key', 'value') == 'OK' - assert mock_db.state == DBState.ACTIVE - assert mock_db1.state == DBState.PASSIVE - assert mock_db2.state == DBState.PASSIVE - with pytest.raises(ValueError, match='Given database is not a member of database list'): client.set_active_database(Mock(spec=AbstractDatabase)) diff --git a/tests/test_multidb/test_healthcheck.py b/tests/test_multidb/test_healthcheck.py index 9601638913..08bd8ab0c4 100644 --- a/tests/test_multidb/test_healthcheck.py +++ b/tests/test_multidb/test_healthcheck.py @@ -1,5 +1,5 @@ from redis.backoff import ExponentialBackoff -from redis.multidb.database import Database, State +from redis.multidb.database import Database from redis.multidb.healthcheck import EchoHealthCheck from redis.multidb.circuit import State as CBState from redis.exceptions import ConnectionError @@ -14,7 +14,7 @@ def test_database_is_healthy_on_echo_response(self, mock_client, mock_cb): """ mock_client.execute_command.side_effect = [ConnectionError, ConnectionError, 'healthcheck'] hc = EchoHealthCheck(Retry(backoff=ExponentialBackoff(cap=1.0), retries=3)) - db = Database(mock_client, mock_cb, 0.9, State.ACTIVE) + db = Database(mock_client, mock_cb, 0.9) assert hc.check_health(db) == True assert mock_client.execute_command.call_count == 3 @@ -26,7 +26,7 @@ def 
test_database_is_unhealthy_on_incorrect_echo_response(self, mock_client, moc """ mock_client.execute_command.side_effect = [ConnectionError, ConnectionError, 'wrong'] hc = EchoHealthCheck(Retry(backoff=ExponentialBackoff(cap=1.0), retries=3)) - db = Database(mock_client, mock_cb, 0.9, State.ACTIVE) + db = Database(mock_client, mock_cb, 0.9) assert hc.check_health(db) == False assert mock_client.execute_command.call_count == 3 @@ -35,7 +35,7 @@ def test_database_close_circuit_on_successful_healthcheck(self, mock_client, moc mock_client.execute_command.side_effect = [ConnectionError, ConnectionError, 'healthcheck'] mock_cb.state = CBState.HALF_OPEN hc = EchoHealthCheck(Retry(backoff=ExponentialBackoff(cap=1.0), retries=3)) - db = Database(mock_client, mock_cb, 0.9, State.ACTIVE) + db = Database(mock_client, mock_cb, 0.9) assert hc.check_health(db) == True assert mock_client.execute_command.call_count == 3 \ No newline at end of file diff --git a/tests/test_scenario/conftest.py b/tests/test_scenario/conftest.py index 8ae7441e98..20f9557bdd 100644 --- a/tests/test_scenario/conftest.py +++ b/tests/test_scenario/conftest.py @@ -3,13 +3,23 @@ import pytest -from redis.backoff import NoBackoff +from redis import Redis +from redis.backoff import NoBackoff, ExponentialBackoff +from redis.event import EventDispatcher, EventListenerInterface from redis.multidb.client import MultiDBClient from redis.multidb.config import DatabaseConfig, MultiDbConfig, DEFAULT_HEALTH_CHECK_INTERVAL, \ DEFAULT_FAILURES_THRESHOLD +from redis.multidb.event import ActiveDatabaseChanged +from redis.multidb.healthcheck import EchoHealthCheck from redis.retry import Retry from tests.test_scenario.fault_injector_client import FaultInjectorClient +class CheckActiveDatabaseChangedListener(EventListenerInterface): + def __init__(self): + self.is_changed_flag = False + + def listen(self, event: ActiveDatabaseChanged): + self.is_changed_flag = True def get_endpoint_config(endpoint_name: str): endpoints_config = os.getenv("REDIS_ENDPOINTS_CONFIG_PATH", None) @@ -33,13 +43,28 @@ def fault_injector_client(): return FaultInjectorClient(url) @pytest.fixture() -def r_multi_db(request) -> MultiDBClient: - endpoint_config = get_endpoint_config('re-active-active') +def r_multi_db(request) -> tuple[MultiDBClient, CheckActiveDatabaseChangedListener, dict]: + client_class = request.param.get('client_class', Redis) + + if client_class == Redis: + endpoint_config = get_endpoint_config('re-active-active') + else: + endpoint_config = get_endpoint_config('re-active-active-oss-cluster') + username = endpoint_config.get('username', None) password = endpoint_config.get('password', None) failure_threshold = request.param.get('failure_threshold', DEFAULT_FAILURES_THRESHOLD) - command_retry = request.param.get('command_retry', Retry(NoBackoff(), retries=3)) + command_retry = request.param.get('command_retry', Retry(ExponentialBackoff(cap=2, base=0.05), retries=10)) + + # Retry configuration different for health checks as initial health check require more time in case + # if infrastructure wasn't restored from the previous test. 
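As the healthcheck tests above show, `Database` now takes only a client, a circuit breaker and a weight; the `State` argument is gone. A construction sketch using the same kind of stand-ins the tests use:

```python
from unittest.mock import Mock

from redis import Redis
from redis.multidb.circuit import CircuitBreaker
from redis.multidb.database import Database

db = Database(Mock(spec=Redis), Mock(spec=CircuitBreaker), 0.9)  # client, circuit, weight
assert db.weight == 0.9
```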
+ health_checks = [EchoHealthCheck(Retry(ExponentialBackoff(cap=5, base=0.5), retries=3))] health_check_interval = request.param.get('health_check_interval', DEFAULT_HEALTH_CHECK_INTERVAL) + event_dispatcher = EventDispatcher() + listener = CheckActiveDatabaseChangedListener() + event_dispatcher.register_listeners({ + ActiveDatabaseChanged: [listener], + }) db_configs = [] db_config = DatabaseConfig( @@ -64,12 +89,14 @@ def r_multi_db(request) -> MultiDBClient: ) db_configs.append(db_config1) - config = MultiDbConfig( + client_class=client_class, databases_config=db_configs, + health_checks=health_checks, command_retry=command_retry, failure_threshold=failure_threshold, health_check_interval=health_check_interval, + event_dispatcher=event_dispatcher, ) - return MultiDBClient(config) \ No newline at end of file + return MultiDBClient(config), listener, endpoint_config \ No newline at end of file diff --git a/tests/test_scenario/test_active_active.py b/tests/test_scenario/test_active_active.py index a8afea4b18..c34b56f0c3 100644 --- a/tests/test_scenario/test_active_active.py +++ b/tests/test_scenario/test_active_active.py @@ -1,23 +1,20 @@ +import json import logging import threading from time import sleep import pytest -from redis.backoff import NoBackoff +from redis import Redis, RedisCluster from redis.client import Pipeline -from redis.exceptions import ConnectionError -from redis.retry import Retry -from tests.test_scenario.conftest import get_endpoint_config from tests.test_scenario.fault_injector_client import ActionRequest, ActionType logger = logging.getLogger(__name__) -def trigger_network_failure_action(fault_injector_client, event: threading.Event = None): - endpoint_config = get_endpoint_config('re-active-active') +def trigger_network_failure_action(fault_injector_client, config, event: threading.Event = None): action_request = ActionRequest( action_type=ActionType.NETWORK_FAILURE, - parameters={"bdb_id": endpoint_config['bdb_id'], "delay": 3, "cluster_index": 0} + parameters={"bdb_id": config['bdb_id'], "delay": 2, "cluster_index": 0} ) result = fault_injector_client.trigger_action(action_request) @@ -37,26 +34,30 @@ class TestActiveActiveStandalone: def teardown_method(self, method): # Timeout so the cluster could recover from network failure. - sleep(3) + sleep(6) @pytest.mark.parametrize( "r_multi_db", [ - {"failure_threshold": 2} + {"client_class": Redis, "failure_threshold": 2}, + {"client_class": RedisCluster, "failure_threshold": 2}, ], + ids=["standalone", "cluster"], indirect=True ) + @pytest.mark.timeout(50) def test_multi_db_client_failover_to_another_db(self, r_multi_db, fault_injector_client): + r_multi_db, listener, config = r_multi_db + event = threading.Event() thread = threading.Thread( target=trigger_network_failure_action, daemon=True, - args=(fault_injector_client,event) + args=(fault_injector_client,config,event) ) # Client initialized on the first command. r_multi_db.set('key', 'value') - current_active_db = r_multi_db.command_executor.active_database thread.start() # Execute commands before network failure @@ -64,49 +65,29 @@ def test_multi_db_client_failover_to_another_db(self, r_multi_db, fault_injector assert r_multi_db.get('key') == 'value' sleep(0.1) - # Active db has been changed. 
- assert current_active_db != r_multi_db.command_executor.active_database - - # Execute commands after network failure - for _ in range(3): + # Execute commands until database failover + while not listener.is_changed_flag: assert r_multi_db.get('key') == 'value' sleep(0.1) @pytest.mark.parametrize( "r_multi_db", [ - {"failure_threshold": 2} - ], - indirect=True - ) - def test_multi_db_client_throws_error_on_retry_exceed(self, r_multi_db, fault_injector_client): - event = threading.Event() - thread = threading.Thread( - target=trigger_network_failure_action, - daemon=True, - args=(fault_injector_client, event) - ) - thread.start() - - with pytest.raises(ConnectionError): - # Retries count > failure threshold, so a client gives up earlier. - while not event.is_set(): - assert r_multi_db.get('key') == 'value' - sleep(0.1) - - @pytest.mark.parametrize( - "r_multi_db", - [ - {"failure_threshold": 2} + {"client_class": Redis, "failure_threshold": 2}, + {"client_class": RedisCluster, "failure_threshold": 2}, ], + ids=["standalone", "cluster"], indirect=True ) + @pytest.mark.timeout(50) def test_context_manager_pipeline_failover_to_another_db(self, r_multi_db, fault_injector_client): + r_multi_db, listener, config = r_multi_db + event = threading.Event() thread = threading.Thread( target=trigger_network_failure_action, daemon=True, - args=(fault_injector_client,event) + args=(fault_injector_client,config,event) ) # Client initialized on first pipe execution. @@ -133,8 +114,8 @@ def test_context_manager_pipeline_failover_to_another_db(self, r_multi_db, fault assert pipe.execute() == [True, True, True, 'value1', 'value2', 'value3'] sleep(0.1) - # Execute pipeline after network failure - for _ in range(3): + # Execute pipeline until database failover + while not listener.is_changed_flag: with r_multi_db.pipeline() as pipe: pipe.set('{hash}key1', 'value1') pipe.set('{hash}key2', 'value2') @@ -148,16 +129,21 @@ def test_context_manager_pipeline_failover_to_another_db(self, r_multi_db, fault @pytest.mark.parametrize( "r_multi_db", [ - {"failure_threshold": 2} + {"client_class": Redis, "failure_threshold": 2}, + {"client_class": RedisCluster, "failure_threshold": 2}, ], + ids=["standalone", "cluster"], indirect=True ) + @pytest.mark.timeout(50) def test_chaining_pipeline_failover_to_another_db(self, r_multi_db, fault_injector_client): + r_multi_db, listener, config = r_multi_db + event = threading.Event() thread = threading.Thread( target=trigger_network_failure_action, daemon=True, - args=(fault_injector_client,event) + args=(fault_injector_client,config,event) ) # Client initialized on first pipe execution. 
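The scenario tests below replace the old fixed "run three more commands" loops with a wait on `listener.is_changed_flag`. The same pattern as a standalone helper, with an explicit deadline (the helper and its deadline value are illustrative; the tests rely on `pytest.mark.timeout` instead):

```python
import time

def wait_for_failover(listener, run_command, deadline: float = 30.0) -> None:
    """Keep issuing commands until the active database changes or the deadline passes."""
    start = time.monotonic()
    while not listener.is_changed_flag:
        if time.monotonic() - start > deadline:
            raise TimeoutError("active database did not change in time")
        run_command()
        time.sleep(0.1)

# e.g. wait_for_failover(listener, lambda: r_multi_db.get("key"))
```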
@@ -174,6 +160,7 @@ def test_chaining_pipeline_failover_to_another_db(self, r_multi_db, fault_inject # Execute pipeline before network failure while not event.is_set(): + pipe = r_multi_db.pipeline() pipe.set('{hash}key1', 'value1') pipe.set('{hash}key2', 'value2') pipe.set('{hash}key3', 'value3') @@ -183,8 +170,9 @@ def test_chaining_pipeline_failover_to_another_db(self, r_multi_db, fault_inject assert pipe.execute() == [True, True, True, 'value1', 'value2', 'value3'] sleep(0.1) - # Execute pipeline after network failure - for _ in range(3): + # Execute pipeline until database failover + while not listener.is_changed_flag: + pipe = r_multi_db.pipeline() pipe.set('{hash}key1', 'value1') pipe.set('{hash}key2', 'value2') pipe.set('{hash}key3', 'value3') @@ -197,16 +185,21 @@ def test_chaining_pipeline_failover_to_another_db(self, r_multi_db, fault_inject @pytest.mark.parametrize( "r_multi_db", [ - {"failure_threshold": 2} + {"client_class": Redis, "failure_threshold": 2}, + {"client_class": RedisCluster, "failure_threshold": 2}, ], + ids=["standalone", "cluster"], indirect=True ) + @pytest.mark.timeout(50) def test_transaction_failover_to_another_db(self, r_multi_db, fault_injector_client): + r_multi_db, listener, config = r_multi_db + event = threading.Event() thread = threading.Thread( target=trigger_network_failure_action, daemon=True, - args=(fault_injector_client,event) + args=(fault_injector_client,config,event) ) def callback(pipe: Pipeline): @@ -226,7 +219,103 @@ def callback(pipe: Pipeline): r_multi_db.transaction(callback) sleep(0.1) - # Execute pipeline after network failure - for _ in range(3): + # Execute pipeline until database failover + while not listener.is_changed_flag: r_multi_db.transaction(callback) - sleep(0.1) \ No newline at end of file + sleep(0.1) + + @pytest.mark.parametrize( + "r_multi_db", + [ + {"client_class": Redis, "failure_threshold": 2}, + {"client_class": RedisCluster, "failure_threshold": 2}, + ], + ids=["standalone", "cluster"], + indirect=True + ) + @pytest.mark.timeout(50) + def test_pubsub_failover_to_another_db(self, r_multi_db, fault_injector_client): + r_multi_db, listener, config = r_multi_db + + event = threading.Event() + thread = threading.Thread( + target=trigger_network_failure_action, + daemon=True, + args=(fault_injector_client,config,event) + ) + data = json.dumps({'message': 'test'}) + messages_count = 0 + + def handler(message): + nonlocal messages_count + messages_count += 1 + + pubsub = r_multi_db.pubsub() + + # Assign a handler and run in a separate thread. 
+ pubsub.subscribe(**{'test-channel': handler}) + pubsub_thread = pubsub.run_in_thread(sleep_time=0.1, daemon=True) + thread.start() + + # Execute pipeline before network failure + while not event.is_set(): + r_multi_db.publish('test-channel', data) + sleep(0.1) + + # Execute pipeline until database failover + while not listener.is_changed_flag: + r_multi_db.publish('test-channel', data) + sleep(0.1) + + pubsub_thread.stop() + assert messages_count > 5 + + @pytest.mark.parametrize( + "r_multi_db", + [ + {"client_class": Redis, "failure_threshold": 2}, + {"client_class": RedisCluster, "failure_threshold": 2}, + ], + ids=["standalone", "cluster"], + indirect=True + ) + @pytest.mark.timeout(50) + def test_sharded_pubsub_failover_to_another_db(self, r_multi_db, fault_injector_client): + r_multi_db, listener, config = r_multi_db + + event = threading.Event() + thread = threading.Thread( + target=trigger_network_failure_action, + daemon=True, + args=(fault_injector_client,config,event) + ) + data = json.dumps({'message': 'test'}) + messages_count = 0 + + def handler(message): + nonlocal messages_count + messages_count += 1 + + pubsub = r_multi_db.pubsub() + + # Assign a handler and run in a separate thread. + pubsub.ssubscribe(**{'test-channel': handler}) + pubsub_thread = pubsub.run_in_thread( + sleep_time=0.1, + daemon=True, + sharded_pubsub=True + ) + thread.start() + + # Execute pipeline before network failure + while not event.is_set(): + r_multi_db.spublish('test-channel', data) + sleep(0.1) + + # Execute pipeline until database failover + while not listener.is_changed_flag: + r_multi_db.spublish('test-channel', data) + sleep(0.1) + + pubsub_thread.stop() + assert messages_count > 5 \ No newline at end of file
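For completeness, the sharded path mirrors the regular one end to end on the multi-database client. A closing sketch (the `client` construction is the same as in the earlier pub/sub sketch and is omitted here; channel name and payload are placeholders):

```python
# `client` is a MultiDBClient built as in the earlier sketch (construction omitted).
counts = {"received": 0}

def on_message(message):
    counts["received"] += 1

spubsub = client.pubsub()
spubsub.ssubscribe(**{"orders": on_message})  # shard channel with a handler
worker = spubsub.run_in_thread(sleep_time=0.1, daemon=True, sharded_pubsub=True)

client.spublish("orders", "hello")  # SPUBLISH is routed through the active database
# On failover the resubscribe listener also copies shard_channels onto the new
# PubSub, so sharded delivery resumes on the new active database.
# worker.stop()
```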