diff --git a/redis/_cache.py b/redis/_cache.py deleted file mode 100644 index 90288383d6..0000000000 --- a/redis/_cache.py +++ /dev/null @@ -1,385 +0,0 @@ -import copy -import random -import time -from abc import ABC, abstractmethod -from collections import OrderedDict, defaultdict -from enum import Enum -from typing import List, Sequence, Union - -from redis.typing import KeyT, ResponseT - - -class EvictionPolicy(Enum): - LRU = "lru" - LFU = "lfu" - RANDOM = "random" - - -DEFAULT_EVICTION_POLICY = EvictionPolicy.LRU - -DEFAULT_DENY_LIST = [ - "BF.CARD", - "BF.DEBUG", - "BF.EXISTS", - "BF.INFO", - "BF.MEXISTS", - "BF.SCANDUMP", - "CF.COMPACT", - "CF.COUNT", - "CF.DEBUG", - "CF.EXISTS", - "CF.INFO", - "CF.MEXISTS", - "CF.SCANDUMP", - "CMS.INFO", - "CMS.QUERY", - "DUMP", - "EXPIRETIME", - "FT.AGGREGATE", - "FT.ALIASADD", - "FT.ALIASDEL", - "FT.ALIASUPDATE", - "FT.CURSOR", - "FT.EXPLAIN", - "FT.EXPLAINCLI", - "FT.GET", - "FT.INFO", - "FT.MGET", - "FT.PROFILE", - "FT.SEARCH", - "FT.SPELLCHECK", - "FT.SUGGET", - "FT.SUGLEN", - "FT.SYNDUMP", - "FT.TAGVALS", - "FT._ALIASADDIFNX", - "FT._ALIASDELIFX", - "HRANDFIELD", - "JSON.DEBUG", - "PEXPIRETIME", - "PFCOUNT", - "PTTL", - "SRANDMEMBER", - "TDIGEST.BYRANK", - "TDIGEST.BYREVRANK", - "TDIGEST.CDF", - "TDIGEST.INFO", - "TDIGEST.MAX", - "TDIGEST.MIN", - "TDIGEST.QUANTILE", - "TDIGEST.RANK", - "TDIGEST.REVRANK", - "TDIGEST.TRIMMED_MEAN", - "TOPK.INFO", - "TOPK.LIST", - "TOPK.QUERY", - "TOUCH", - "TTL", -] - -DEFAULT_ALLOW_LIST = [ - "BITCOUNT", - "BITFIELD_RO", - "BITPOS", - "EXISTS", - "GEODIST", - "GEOHASH", - "GEOPOS", - "GEORADIUSBYMEMBER_RO", - "GEORADIUS_RO", - "GEOSEARCH", - "GET", - "GETBIT", - "GETRANGE", - "HEXISTS", - "HGET", - "HGETALL", - "HKEYS", - "HLEN", - "HMGET", - "HSTRLEN", - "HVALS", - "JSON.ARRINDEX", - "JSON.ARRLEN", - "JSON.GET", - "JSON.MGET", - "JSON.OBJKEYS", - "JSON.OBJLEN", - "JSON.RESP", - "JSON.STRLEN", - "JSON.TYPE", - "LCS", - "LINDEX", - "LLEN", - "LPOS", - "LRANGE", - "MGET", - "SCARD", - "SDIFF", - "SINTER", - "SINTERCARD", - "SISMEMBER", - "SMEMBERS", - "SMISMEMBER", - "SORT_RO", - "STRLEN", - "SUBSTR", - "SUNION", - "TS.GET", - "TS.INFO", - "TS.RANGE", - "TS.REVRANGE", - "TYPE", - "XLEN", - "XPENDING", - "XRANGE", - "XREAD", - "XREVRANGE", - "ZCARD", - "ZCOUNT", - "ZDIFF", - "ZINTER", - "ZINTERCARD", - "ZLEXCOUNT", - "ZMSCORE", - "ZRANGE", - "ZRANGEBYLEX", - "ZRANGEBYSCORE", - "ZRANK", - "ZREVRANGE", - "ZREVRANGEBYLEX", - "ZREVRANGEBYSCORE", - "ZREVRANK", - "ZSCORE", - "ZUNION", -] - -_RESPONSE = "response" -_KEYS = "keys" -_CTIME = "ctime" -_ACCESS_COUNT = "access_count" - - -class AbstractCache(ABC): - """ - An abstract base class for client caching implementations. - If you want to implement your own cache you must support these methods. - """ - - @abstractmethod - def set( - self, - command: Union[str, Sequence[str]], - response: ResponseT, - keys_in_command: List[KeyT], - ): - pass - - @abstractmethod - def get(self, command: Union[str, Sequence[str]]) -> ResponseT: - pass - - @abstractmethod - def delete_command(self, command: Union[str, Sequence[str]]): - pass - - @abstractmethod - def delete_commands(self, commands: List[Union[str, Sequence[str]]]): - pass - - @abstractmethod - def flush(self): - pass - - @abstractmethod - def invalidate_key(self, key: KeyT): - pass - - -class _LocalCache(AbstractCache): - """ - A caching mechanism for storing redis commands and their responses. - - Args: - max_size (int): The maximum number of commands to be stored in the cache. 
- ttl (int): The time-to-live for each command in seconds. - eviction_policy (EvictionPolicy): The eviction policy to use for removing commands when the cache is full. - - Attributes: - max_size (int): The maximum number of commands to be stored in the cache. - ttl (int): The time-to-live for each command in seconds. - eviction_policy (EvictionPolicy): The eviction policy used for cache management. - cache (OrderedDict): The ordered dictionary to store commands and their metadata. - key_commands_map (defaultdict): A mapping of keys to the set of commands that use each key. - commands_ttl_list (list): A list to keep track of the commands in the order they were added. # noqa - """ - - def __init__( - self, - max_size: int = 10000, - ttl: int = 0, - eviction_policy: EvictionPolicy = DEFAULT_EVICTION_POLICY, - ): - self.max_size = max_size - self.ttl = ttl - self.eviction_policy = eviction_policy - self.cache = OrderedDict() - self.key_commands_map = defaultdict(set) - self.commands_ttl_list = [] - - def set( - self, - command: Union[str, Sequence[str]], - response: ResponseT, - keys_in_command: List[KeyT], - ): - """ - Set a redis command and its response in the cache. - - Args: - command (Union[str, Sequence[str]]): The redis command. - response (ResponseT): The response associated with the command. - keys_in_command (List[KeyT]): The list of keys used in the command. - """ - if len(self.cache) >= self.max_size: - self._evict() - self.cache[command] = { - _RESPONSE: response, - _KEYS: keys_in_command, - _CTIME: time.monotonic(), - _ACCESS_COUNT: 0, # Used only for LFU - } - self._update_key_commands_map(keys_in_command, command) - self.commands_ttl_list.append(command) - - def get(self, command: Union[str, Sequence[str]]) -> ResponseT: - """ - Get the response for a redis command from the cache. - - Args: - command (Union[str, Sequence[str]]): The redis command. - - Returns: - ResponseT: The response associated with the command, or None if the command is not in the cache. # noqa - """ - if command in self.cache: - if self._is_expired(command): - self.delete_command(command) - return - self._update_access(command) - return copy.deepcopy(self.cache[command]["response"]) - - def delete_command(self, command: Union[str, Sequence[str]]): - """ - Delete a redis command and its metadata from the cache. - - Args: - command (Union[str, Sequence[str]]): The redis command to be deleted. - """ - if command in self.cache: - keys_in_command = self.cache[command].get("keys") - self._del_key_commands_map(keys_in_command, command) - self.commands_ttl_list.remove(command) - del self.cache[command] - - def delete_commands(self, commands: List[Union[str, Sequence[str]]]): - """ - Delete multiple commands and their metadata from the cache. - - Args: - commands (List[Union[str, Sequence[str]]]): The list of commands to be - deleted. - """ - for command in commands: - self.delete_command(command) - - def flush(self): - """Clear the entire cache, removing all redis commands and metadata.""" - self.cache.clear() - self.key_commands_map.clear() - self.commands_ttl_list = [] - - def _is_expired(self, command: Union[str, Sequence[str]]) -> bool: - """ - Check if a redis command has expired based on its time-to-live. - - Args: - command (Union[str, Sequence[str]]): The redis command. - - Returns: - bool: True if the command has expired, False otherwise. 
- """ - if self.ttl == 0: - return False - return time.monotonic() - self.cache[command]["ctime"] > self.ttl - - def _update_access(self, command: Union[str, Sequence[str]]): - """ - Update the access information for a redis command based on the eviction policy. - - Args: - command (Union[str, Sequence[str]]): The redis command. - """ - if self.eviction_policy == EvictionPolicy.LRU: - self.cache.move_to_end(command) - elif self.eviction_policy == EvictionPolicy.LFU: - self.cache[command]["access_count"] = ( - self.cache.get(command, {}).get("access_count", 0) + 1 - ) - self.cache.move_to_end(command) - elif self.eviction_policy == EvictionPolicy.RANDOM: - pass # Random eviction doesn't require updates - - def _evict(self): - """Evict a redis command from the cache based on the eviction policy.""" - if self._is_expired(self.commands_ttl_list[0]): - self.delete_command(self.commands_ttl_list[0]) - elif self.eviction_policy == EvictionPolicy.LRU: - self.cache.popitem(last=False) - elif self.eviction_policy == EvictionPolicy.LFU: - min_access_command = min( - self.cache, key=lambda k: self.cache[k].get("access_count", 0) - ) - self.cache.pop(min_access_command) - elif self.eviction_policy == EvictionPolicy.RANDOM: - random_command = random.choice(list(self.cache.keys())) - self.cache.pop(random_command) - - def _update_key_commands_map( - self, keys: List[KeyT], command: Union[str, Sequence[str]] - ): - """ - Update the key_commands_map with command that uses the keys. - - Args: - keys (List[KeyT]): The list of keys used in the command. - command (Union[str, Sequence[str]]): The redis command. - """ - for key in keys: - self.key_commands_map[key].add(command) - - def _del_key_commands_map( - self, keys: List[KeyT], command: Union[str, Sequence[str]] - ): - """ - Remove a redis command from the key_commands_map. - - Args: - keys (List[KeyT]): The list of keys used in the redis command. - command (Union[str, Sequence[str]]): The redis command. - """ - for key in keys: - self.key_commands_map[key].remove(command) - - def invalidate_key(self, key: KeyT): - """ - Invalidate (delete) all redis commands associated with a specific key. - - Args: - key (KeyT): The key to be invalidated. 
- """ - if key not in self.key_commands_map: - return - commands = list(self.key_commands_map[key]) - for command in commands: - self.delete_command(command) diff --git a/redis/_parsers/resp3.py b/redis/_parsers/resp3.py index cc210b9df5..0e0a6655d2 100644 --- a/redis/_parsers/resp3.py +++ b/redis/_parsers/resp3.py @@ -120,6 +120,12 @@ def _read_response(self, disable_decoding=False, push_request=False): response = self.handle_push_response( response, disable_decoding, push_request ) + if not push_request: + return self._read_response( + disable_decoding=disable_decoding, push_request=push_request + ) + else: + return response else: raise InvalidResponse(f"Protocol Error: {raw!r}") @@ -128,19 +134,10 @@ def _read_response(self, disable_decoding=False, push_request=False): return response def handle_push_response(self, response, disable_decoding, push_request): - if response[0] in _INVALIDATION_MESSAGE: - if self.invalidation_push_handler_func: - res = self.invalidation_push_handler_func(response) - else: - res = None - else: - res = self.pubsub_push_handler_func(response) - if not push_request: - return self._read_response( - disable_decoding=disable_decoding, push_request=push_request - ) - else: - return res + if response[0] not in _INVALIDATION_MESSAGE: + return self.pubsub_push_handler_func(response) + if self.invalidation_push_handler_func: + return self.invalidation_push_handler_func(response) def set_pubsub_push_handler(self, pubsub_push_handler_func): self.pubsub_push_handler_func = pubsub_push_handler_func @@ -155,7 +152,7 @@ def __init__(self, socket_read_size): self.pubsub_push_handler_func = self.handle_pubsub_push_response self.invalidation_push_handler_func = None - def handle_pubsub_push_response(self, response): + async def handle_pubsub_push_response(self, response): logger = getLogger("push_response") logger.info("Push response: " + str(response)) return response @@ -267,6 +264,12 @@ async def _read_response( response = await self.handle_push_response( response, disable_decoding, push_request ) + if not push_request: + return await self._read_response( + disable_decoding=disable_decoding, push_request=push_request + ) + else: + return response else: raise InvalidResponse(f"Protocol Error: {raw!r}") @@ -275,19 +278,10 @@ async def _read_response( return response async def handle_push_response(self, response, disable_decoding, push_request): - if response[0] in _INVALIDATION_MESSAGE: - if self.invalidation_push_handler_func: - res = self.invalidation_push_handler_func(response) - else: - res = None - else: - res = self.pubsub_push_handler_func(response) - if not push_request: - return await self._read_response( - disable_decoding=disable_decoding, push_request=push_request - ) - else: - return res + if response[0] not in _INVALIDATION_MESSAGE: + return await self.pubsub_push_handler_func(response) + if self.invalidation_push_handler_func: + return await self.invalidation_push_handler_func(response) def set_pubsub_push_handler(self, pubsub_push_handler_func): self.pubsub_push_handler_func = pubsub_push_handler_func diff --git a/redis/asyncio/client.py b/redis/asyncio/client.py index 1845b7252f..5d93c83b12 100644 --- a/redis/asyncio/client.py +++ b/redis/asyncio/client.py @@ -26,12 +26,6 @@ cast, ) -from redis._cache import ( - DEFAULT_ALLOW_LIST, - DEFAULT_DENY_LIST, - DEFAULT_EVICTION_POLICY, - AbstractCache, -) from redis._parsers.helpers import ( _RedisCallbacks, _RedisCallbacksRESP2, @@ -239,13 +233,6 @@ def __init__( redis_connect_func=None, credential_provider: 
Optional[CredentialProvider] = None, protocol: Optional[int] = 2, - cache_enabled: bool = False, - client_cache: Optional[AbstractCache] = None, - cache_max_size: int = 100, - cache_ttl: int = 0, - cache_policy: str = DEFAULT_EVICTION_POLICY, - cache_deny_list: List[str] = DEFAULT_DENY_LIST, - cache_allow_list: List[str] = DEFAULT_ALLOW_LIST, ): """ Initialize a new Redis client. @@ -295,13 +282,6 @@ def __init__( "lib_version": lib_version, "redis_connect_func": redis_connect_func, "protocol": protocol, - "cache_enabled": cache_enabled, - "client_cache": client_cache, - "cache_max_size": cache_max_size, - "cache_ttl": cache_ttl, - "cache_policy": cache_policy, - "cache_deny_list": cache_deny_list, - "cache_allow_list": cache_allow_list, } # based on input, setup appropriate connection args if unix_socket_path is not None: @@ -624,31 +604,22 @@ async def _disconnect_raise(self, conn: Connection, error: Exception): async def execute_command(self, *args, **options): """Execute a command and return a parsed response""" await self.initialize() - command_name = args[0] - keys = options.pop("keys", None) # keys are used only for client side caching pool = self.connection_pool + command_name = args[0] conn = self.connection or await pool.get_connection(command_name, **options) - response_from_cache = await conn._get_from_local_cache(args) + + if self.single_connection_client: + await self._single_conn_lock.acquire() try: - if response_from_cache is not None: - return response_from_cache - else: - try: - if self.single_connection_client: - await self._single_conn_lock.acquire() - response = await conn.retry.call_with_retry( - lambda: self._send_command_parse_response( - conn, command_name, *args, **options - ), - lambda error: self._disconnect_raise(conn, error), - ) - if keys: - conn._add_to_local_cache(args, response, keys) - return response - finally: - if self.single_connection_client: - self._single_conn_lock.release() + return await conn.retry.call_with_retry( + lambda: self._send_command_parse_response( + conn, command_name, *args, **options + ), + lambda error: self._disconnect_raise(conn, error), + ) finally: + if self.single_connection_client: + self._single_conn_lock.release() if not self.connection: await pool.release(conn) @@ -677,24 +648,6 @@ async def parse_response( return await retval if inspect.isawaitable(retval) else retval return response - def flush_cache(self): - if self.connection: - self.connection.flush_cache() - else: - self.connection_pool.flush_cache() - - def delete_command_from_cache(self, command): - if self.connection: - self.connection.delete_command_from_cache(command) - else: - self.connection_pool.delete_command_from_cache(command) - - def invalidate_key_from_cache(self, key): - if self.connection: - self.connection.invalidate_key_from_cache(key) - else: - self.connection_pool.invalidate_key_from_cache(key) - StrictRedis = Redis @@ -1331,7 +1284,6 @@ def multi(self): def execute_command( self, *args, **kwargs ) -> Union["Pipeline", Awaitable["Pipeline"]]: - kwargs.pop("keys", None) # the keys are used only for client side caching if (self.watching or args[0] == "WATCH") and not self.explicit_transaction: return self.immediate_execute_command(*args, **kwargs) return self.pipeline_execute_command(*args, **kwargs) diff --git a/redis/asyncio/cluster.py b/redis/asyncio/cluster.py index 40b2948a7f..cbceccf401 100644 --- a/redis/asyncio/cluster.py +++ b/redis/asyncio/cluster.py @@ -19,12 +19,6 @@ Union, ) -from redis._cache import ( - DEFAULT_ALLOW_LIST, - 
DEFAULT_DENY_LIST, - DEFAULT_EVICTION_POLICY, - AbstractCache, -) from redis._parsers import AsyncCommandsParser, Encoder from redis._parsers.helpers import ( _RedisCallbacks, @@ -276,13 +270,6 @@ def __init__( ssl_ciphers: Optional[str] = None, protocol: Optional[int] = 2, address_remap: Optional[Callable[[Tuple[str, int]], Tuple[str, int]]] = None, - cache_enabled: bool = False, - client_cache: Optional[AbstractCache] = None, - cache_max_size: int = 100, - cache_ttl: int = 0, - cache_policy: str = DEFAULT_EVICTION_POLICY, - cache_deny_list: List[str] = DEFAULT_DENY_LIST, - cache_allow_list: List[str] = DEFAULT_ALLOW_LIST, ) -> None: if db: raise RedisClusterException( @@ -326,14 +313,6 @@ def __init__( "socket_timeout": socket_timeout, "retry": retry, "protocol": protocol, - # Client cache related kwargs - "cache_enabled": cache_enabled, - "client_cache": client_cache, - "cache_max_size": cache_max_size, - "cache_ttl": cache_ttl, - "cache_policy": cache_policy, - "cache_deny_list": cache_deny_list, - "cache_allow_list": cache_allow_list, } if ssl: @@ -938,18 +917,6 @@ def lock( thread_local=thread_local, ) - def flush_cache(self): - if self.nodes_manager: - self.nodes_manager.flush_cache() - - def delete_command_from_cache(self, command): - if self.nodes_manager: - self.nodes_manager.delete_command_from_cache(command) - - def invalidate_key_from_cache(self, key): - if self.nodes_manager: - self.nodes_manager.invalidate_key_from_cache(key) - class ClusterNode: """ @@ -1076,25 +1043,16 @@ async def parse_response( async def execute_command(self, *args: Any, **kwargs: Any) -> Any: # Acquire connection connection = self.acquire_connection() - keys = kwargs.pop("keys", None) - response_from_cache = await connection._get_from_local_cache(args) - if response_from_cache is not None: - self._free.append(connection) - return response_from_cache - else: - # Execute command - await connection.send_packed_command(connection.pack_command(*args), False) + # Execute command + await connection.send_packed_command(connection.pack_command(*args), False) - # Read response - try: - response = await self.parse_response(connection, args[0], **kwargs) - if keys: - connection._add_to_local_cache(args, response, keys) - return response - finally: - # Release connection - self._free.append(connection) + # Read response + try: + return await self.parse_response(connection, args[0], **kwargs) + finally: + # Release connection + self._free.append(connection) async def execute_pipeline(self, commands: List["PipelineCommand"]) -> bool: # Acquire connection @@ -1121,18 +1079,6 @@ async def execute_pipeline(self, commands: List["PipelineCommand"]) -> bool: return ret - def flush_cache(self): - for connection in self._connections: - connection.flush_cache() - - def delete_command_from_cache(self, command): - for connection in self._connections: - connection.delete_command_from_cache(command) - - def invalidate_key_from_cache(self, key): - for connection in self._connections: - connection.invalidate_key_from_cache(key) - class NodesManager: __slots__ = ( @@ -1416,18 +1362,6 @@ def remap_host_port(self, host: str, port: int) -> Tuple[str, int]: return self.address_remap((host, port)) return host, port - def flush_cache(self): - for node in self.nodes_cache.values(): - node.flush_cache() - - def delete_command_from_cache(self, command): - for node in self.nodes_cache.values(): - node.delete_command_from_cache(command) - - def invalidate_key_from_cache(self, key): - for node in self.nodes_cache.values(): - 
node.invalidate_key_from_cache(key) - class ClusterPipeline(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommands): """ @@ -1516,7 +1450,6 @@ def execute_command( or List[:class:`~.ClusterNode`] or Dict[Any, :class:`~.ClusterNode`] - Rest of the kwargs are passed to the Redis connection """ - kwargs.pop("keys", None) # the keys are used only for client side caching self._command_stack.append( PipelineCommand(len(self._command_stack), *args, **kwargs) ) diff --git a/redis/asyncio/connection.py b/redis/asyncio/connection.py index 2ac6637986..ddbd22c95d 100644 --- a/redis/asyncio/connection.py +++ b/redis/asyncio/connection.py @@ -49,16 +49,9 @@ ResponseError, TimeoutError, ) -from redis.typing import EncodableT, KeysT, ResponseT +from redis.typing import EncodableT from redis.utils import HIREDIS_AVAILABLE, get_lib_version, str_if_bytes -from .._cache import ( - DEFAULT_ALLOW_LIST, - DEFAULT_DENY_LIST, - DEFAULT_EVICTION_POLICY, - AbstractCache, - _LocalCache, -) from .._parsers import ( BaseParser, Encoder, @@ -121,9 +114,6 @@ class AbstractConnection: "encoder", "ssl_context", "protocol", - "client_cache", - "cache_deny_list", - "cache_allow_list", "_reader", "_writer", "_parser", @@ -158,13 +148,6 @@ def __init__( encoder_class: Type[Encoder] = Encoder, credential_provider: Optional[CredentialProvider] = None, protocol: Optional[int] = 2, - cache_enabled: bool = False, - client_cache: Optional[AbstractCache] = None, - cache_max_size: int = 10000, - cache_ttl: int = 0, - cache_policy: str = DEFAULT_EVICTION_POLICY, - cache_deny_list: List[str] = DEFAULT_DENY_LIST, - cache_allow_list: List[str] = DEFAULT_ALLOW_LIST, ): if (username or password) and credential_provider is not None: raise DataError( @@ -222,18 +205,6 @@ def __init__( if p < 2 or p > 3: raise ConnectionError("protocol must be either 2 or 3") self.protocol = protocol - if cache_enabled: - _cache = _LocalCache(cache_max_size, cache_ttl, cache_policy) - else: - _cache = None - self.client_cache = client_cache if client_cache is not None else _cache - if self.client_cache is not None: - if self.protocol not in [3, "3"]: - raise RedisError( - "client caching is only supported with protocol version 3 or higher" - ) - self.cache_deny_list = cache_deny_list - self.cache_allow_list = cache_allow_list def __del__(self, _warnings: Any = warnings): # For some reason, the individual streams don't get properly garbage @@ -425,11 +396,6 @@ async def on_connect(self) -> None: # if a database is specified, switch to it. Also pipeline this if self.db: await self.send_command("SELECT", self.db) - # if client caching is enabled, start tracking - if self.client_cache: - await self.send_command("CLIENT", "TRACKING", "ON") - await self.read_response() - self._parser.set_invalidation_push_handler(self._cache_invalidation_process) # read responses from pipeline for _ in (sent for sent in (self.lib_name, self.lib_version) if sent): @@ -464,9 +430,6 @@ async def disconnect(self, nowait: bool = False) -> None: raise TimeoutError( f"Timed out closing connection after {self.socket_connect_timeout}" ) from None - finally: - if self.client_cache: - self.client_cache.flush() async def _send_ping(self): """Send PING, expect PONG in return""" @@ -688,60 +651,9 @@ def _socket_is_empty(self): """Check if the socket is empty""" return len(self._reader._buffer) == 0 - def _cache_invalidation_process( - self, data: List[Union[str, Optional[List[str]]]] - ) -> None: - """ - Invalidate (delete) all redis commands associated with a specific key. 
- `data` is a list of strings, where the first string is the invalidation message - and the second string is the list of keys to invalidate. - (if the list of keys is None, then all keys are invalidated) - """ - if data[1] is None: - self.client_cache.flush() - else: - for key in data[1]: - self.client_cache.invalidate_key(str_if_bytes(key)) - - async def _get_from_local_cache(self, command: str): - """ - If the command is in the local cache, return the response - """ - if ( - self.client_cache is None - or command[0] in self.cache_deny_list - or command[0] not in self.cache_allow_list - ): - return None + async def process_invalidation_messages(self): while not self._socket_is_empty(): await self.read_response(push_request=True) - return self.client_cache.get(command) - - def _add_to_local_cache( - self, command: Tuple[str], response: ResponseT, keys: List[KeysT] - ): - """ - Add the command and response to the local cache if the command - is allowed to be cached - """ - if ( - self.client_cache is not None - and (self.cache_deny_list == [] or command[0] not in self.cache_deny_list) - and (self.cache_allow_list == [] or command[0] in self.cache_allow_list) - ): - self.client_cache.set(command, response, keys) - - def flush_cache(self): - if self.client_cache: - self.client_cache.flush() - - def delete_command_from_cache(self, command): - if self.client_cache: - self.client_cache.delete_command(command) - - def invalidate_key_from_cache(self, key): - if self.client_cache: - self.client_cache.invalidate_key(key) class Connection(AbstractConnection): @@ -1177,18 +1089,12 @@ def make_connection(self): async def ensure_connection(self, connection: AbstractConnection): """Ensure that the connection object is connected and valid""" await connection.connect() - # if client caching is not enabled connections that the pool - # provides should be ready to send a command. - # if not, the connection was either returned to the + # connections that the pool provides should be ready to send + # a command. if not, the connection was either returned to the # pool before all data has been read or the socket has been # closed. either way, reconnect and verify everything is good. 
- # (if caching enabled the connection will not always be ready - # to send a command because it may contain invalidation messages) try: - if ( - await connection.can_read_destructive() - and connection.client_cache is None - ): + if await connection.can_read_destructive(): raise ConnectionError("Connection has data") from None except (ConnectionError, OSError): await connection.disconnect() @@ -1235,21 +1141,6 @@ def set_retry(self, retry: "Retry") -> None: for conn in self._in_use_connections: conn.retry = retry - def flush_cache(self): - connections = chain(self._available_connections, self._in_use_connections) - for connection in connections: - connection.flush_cache() - - def delete_command_from_cache(self, command: str): - connections = chain(self._available_connections, self._in_use_connections) - for connection in connections: - connection.delete_command_from_cache(command) - - def invalidate_key_from_cache(self, key: str): - connections = chain(self._available_connections, self._in_use_connections) - for connection in connections: - connection.invalidate_key_from_cache(key) - class BlockingConnectionPool(ConnectionPool): """ diff --git a/redis/asyncio/sentinel.py b/redis/asyncio/sentinel.py index 6fd233adc8..5d4608ed2f 100644 --- a/redis/asyncio/sentinel.py +++ b/redis/asyncio/sentinel.py @@ -225,7 +225,6 @@ async def execute_command(self, *args, **kwargs): once - If set to True, then execute the resulting command on a single node at random, rather than across the entire sentinel cluster. """ - kwargs.pop("keys", None) # the keys are used only for client side caching once = bool(kwargs.get("once", False)) if "once" in kwargs.keys(): kwargs.pop("once") diff --git a/redis/cache.py b/redis/cache.py new file mode 100644 index 0000000000..c79d5af6a9 --- /dev/null +++ b/redis/cache.py @@ -0,0 +1,248 @@ +from typing import Callable, TypeVar, Any, NoReturn, List, Union +from typing import Optional +from enum import Enum + +from cachetools import TTLCache, Cache, LRUCache +from cachetools.keys import hashkey + +from redis.typing import ResponseT + +T = TypeVar('T') + + +class EvictionPolicy(Enum): + LRU = "lru" + LFU = "lfu" + RANDOM = "random" + + +class CacheConfiguration: + DEFAULT_EVICTION_POLICY = EvictionPolicy.LRU + + DEFAULT_ALLOW_LIST = [ + "BITCOUNT", + "BITFIELD_RO", + "BITPOS", + "EXISTS", + "GEODIST", + "GEOHASH", + "GEOPOS", + "GEORADIUSBYMEMBER_RO", + "GEORADIUS_RO", + "GEOSEARCH", + "GET", + "GETBIT", + "GETRANGE", + "HEXISTS", + "HGET", + "HGETALL", + "HKEYS", + "HLEN", + "HMGET", + "HSTRLEN", + "HVALS", + "JSON.ARRINDEX", + "JSON.ARRLEN", + "JSON.GET", + "JSON.MGET", + "JSON.OBJKEYS", + "JSON.OBJLEN", + "JSON.RESP", + "JSON.STRLEN", + "JSON.TYPE", + "LCS", + "LINDEX", + "LLEN", + "LPOS", + "LRANGE", + "MGET", + "SCARD", + "SDIFF", + "SINTER", + "SINTERCARD", + "SISMEMBER", + "SMEMBERS", + "SMISMEMBER", + "SORT_RO", + "STRLEN", + "SUBSTR", + "SUNION", + "TS.GET", + "TS.INFO", + "TS.RANGE", + "TS.REVRANGE", + "TYPE", + "XLEN", + "XPENDING", + "XRANGE", + "XREAD", + "XREVRANGE", + "ZCARD", + "ZCOUNT", + "ZDIFF", + "ZINTER", + "ZINTERCARD", + "ZLEXCOUNT", + "ZMSCORE", + "ZRANGE", + "ZRANGEBYLEX", + "ZRANGEBYSCORE", + "ZRANK", + "ZREVRANGE", + "ZREVRANGEBYLEX", + "ZREVRANGEBYSCORE", + "ZREVRANK", + "ZSCORE", + "ZUNION", + ] + + def __init__(self, **kwargs): + self._max_size = kwargs.get("cache_size", 10000) + self._ttl = kwargs.get("cache_ttl", 0) + self._eviction_policy = kwargs.get("eviction_policy", self.DEFAULT_EVICTION_POLICY) + + def get_ttl(self) -> int: + return 
self._ttl
+
+    def get_eviction_policy(self) -> EvictionPolicy:
+        return self._eviction_policy
+
+    def is_exceeds_max_size(self, count: int) -> bool:
+        return count > self._max_size
+
+    def is_allowed_to_cache(self, command: str) -> bool:
+        return command in self.DEFAULT_ALLOW_LIST
+
+
+def ensure_string(key):
+    if isinstance(key, bytes):
+        return key.decode('utf-8')
+    elif isinstance(key, str):
+        return key
+    else:
+        raise TypeError("Key must be either a string or bytes")
+
+
+class CacheMixin:
+    def __init__(self,
+                 use_cache: bool,
+                 connection_pool: "ConnectionPool",
+                 cache: Optional[Cache] = None,
+                 cache_size: int = 128,
+                 cache_ttl: int = 300,
+                 ) -> None:
+        self.use_cache = use_cache
+        if not use_cache:
+            return
+        if cache is not None:
+            self.cache = cache
+        else:
+            self.cache = TTLCache(maxsize=cache_size, ttl=cache_ttl)
+        self.keys_mapping = LRUCache(maxsize=10000)
+        self.wrap_connection_pool(connection_pool)
+        self.connections = []
+
+    def cached_call(self,
+                    func: Callable[..., ResponseT],
+                    *args,
+                    **options) -> ResponseT:
+        if not self.use_cache:
+            return func(*args, **options)
+
+        keys = None
+        if 'keys' in options:
+            keys = options['keys']
+            if not isinstance(keys, list):
+                raise TypeError("Cache keys must be a list.")
+        if not keys:
+            return func(*args, **options)
+
+        cache_key = hashkey(*args)
+
+        # drain any pending invalidation pushes before consulting the cache
+        for conn in self.connections:
+            conn.process_invalidation_messages()
+
+        for key in keys:
+            if key in self.keys_mapping:
+                if cache_key not in self.keys_mapping[key]:
+                    self.keys_mapping[key].append(cache_key)
+            else:
+                self.keys_mapping[key] = [cache_key]
+
+        if cache_key in self.cache:
+            return self.cache[cache_key]
+
+        result = func(*args, **options)
+        self.cache[cache_key] = result
+        return result
+
+    def get_cache_entry(self, *args: Any) -> Any:
+        cache_key = hashkey(*args)
+        return self.cache.get(cache_key, None)
+
+    def invalidate_cache_entry(self, *args: Any) -> None:
+        cache_key = hashkey(*args)
+        if cache_key in self.cache:
+            self.cache.pop(cache_key)
+
+    def wrap_connection_pool(self, connection_pool: "ConnectionPool"):
+        if not self.use_cache:
+            return
+        if connection_pool is None:
+            return
+        original_maker = connection_pool.make_connection
+        connection_pool.make_connection = lambda: self._make_connection(original_maker)
+
+    def _make_connection(self, original_maker: Callable[[], "Connection"]):
+        conn = original_maker()
+        original_disconnect = conn.disconnect
+        conn.disconnect = lambda: self._wrapped_disconnect(conn, original_disconnect)
+        self.add_connection(conn)
+        return conn
+
+    def _wrapped_disconnect(self, connection: "Connection",
+                            original_disconnect: Callable[[], NoReturn]):
+        original_disconnect()
+        self.remove_connection(connection)
+
+    def add_connection(self, conn):
+        conn.register_connect_callback(self._on_connect)
+        self.connections.append(conn)
+
+    def _on_connect(self, conn):
+        conn.send_command("CLIENT", "TRACKING", "ON")
+        conn.read_response()
+        conn._parser.set_invalidation_push_handler(self._cache_invalidation_process)
+
+    def _cache_invalidation_process(
+        self, data: List[Union[str, Optional[List[str]]]]
+    ) -> None:
+        """
+        Invalidate (delete) all redis commands associated with a specific key.
+        `data` is a list where the first element is the invalidation message
+        and the second element is the list of keys to invalidate.
+        (if the list of keys is None, then all keys are invalidated)
+        """
+        if data[1] is None:
+            self.cache.clear()
+        else:
+            for key in data[1]:
+                normalized_key = ensure_string(key)
+                if normalized_key in self.keys_mapping:
+                    for cache_key in self.keys_mapping[normalized_key]:
+                        self.cache.pop(cache_key, None)
+
+    def remove_connection(self, conn):
+        self.connections.remove(conn)
diff --git a/redis/client.py b/redis/client.py
index b7a1f88d92..adbf380b8e 100755
--- a/redis/client.py
+++ b/redis/client.py
@@ -6,12 +6,8 @@
 from itertools import chain
 from typing import Any, Callable, Dict, List, Optional, Type, Union

-from redis._cache import (
-    DEFAULT_ALLOW_LIST,
-    DEFAULT_DENY_LIST,
-    DEFAULT_EVICTION_POLICY,
-    AbstractCache,
-)
+from cachetools import Cache
+
 from redis._parsers.encoders import Encoder
 from redis._parsers.helpers import (
     _RedisCallbacks,
@@ -19,6 +15,7 @@
     _RedisCallbacksRESP3,
     bool_ok,
 )
+from redis.cache import CacheMixin
 from redis.commands import (
     CoreCommands,
     RedisModuleCommands,
@@ -89,7 +86,7 @@ class AbstractRedis:
     pass


-class Redis(RedisModuleCommands, CoreCommands, SentinelCommands):
+class Redis(RedisModuleCommands, CoreCommands, SentinelCommands, CacheMixin):
     """
     Implementation of the Redis protocol.

@@ -147,10 +144,12 @@ class initializer. In the case of conflicting arguments, querystring
         """
         single_connection_client = kwargs.pop("single_connection_client", False)
+        use_cache = kwargs.pop("use_cache", False)
         connection_pool = ConnectionPool.from_url(url, **kwargs)
         client = cls(
             connection_pool=connection_pool,
             single_connection_client=single_connection_client,
+            use_cache=use_cache,
         )
         client.auto_close_connection_pool = True
         return client
@@ -216,13 +215,10 @@ def __init__(
         redis_connect_func=None,
         credential_provider: Optional[CredentialProvider] = None,
         protocol: Optional[int] = 2,
-        cache_enabled: bool = False,
-        client_cache: Optional[AbstractCache] = None,
-        cache_max_size: int = 10000,
-        cache_ttl: int = 0,
-        cache_policy: str = DEFAULT_EVICTION_POLICY,
-        cache_deny_list: List[str] = DEFAULT_DENY_LIST,
-        cache_allow_list: List[str] = DEFAULT_ALLOW_LIST,
+        use_cache: bool = False,
+        cache: Optional[Cache] = None,
+        cache_size: int = 128,
+        cache_ttl: int = 300,
     ) -> None:
         """
         Initialize a new Redis client.
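
Taken together, the new `redis/cache.py` module and the `Redis.__init__` parameters above replace the old `cache_enabled`/`_LocalCache` surface. A minimal usage sketch of the API introduced here (host, port, and the workload are illustrative assumptions; RESP3 is mandatory per the checks in this diff):

```python
import redis

# Caching is opt-in and requires protocol=3 (RESP3); when no explicit `cache`
# is passed, a cachetools.TTLCache(cache_size, cache_ttl) is created internally.
r = redis.Redis(host="localhost", port=6379, protocol=3,
                use_cache=True, cache_size=128, cache_ttl=300)
r.set("greeting", "hello")
r.get("greeting")  # first read goes to the server and populates the local cache
r.get("greeting")  # repeat read is served locally until the TTL expires or an
                   # invalidation push evicts the entry
```
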
@@ -274,13 +270,6 @@ def __init__( "redis_connect_func": redis_connect_func, "credential_provider": credential_provider, "protocol": protocol, - "cache_enabled": cache_enabled, - "client_cache": client_cache, - "cache_max_size": cache_max_size, - "cache_ttl": cache_ttl, - "cache_policy": cache_policy, - "cache_deny_list": cache_deny_list, - "cache_allow_list": cache_allow_list, } # based on input, setup appropriate connection args if unix_socket_path is not None: @@ -322,12 +311,25 @@ def __init__( "ssl_ciphers": ssl_ciphers, } ) + if use_cache and protocol in [3, "3"]: + kwargs.update( + { + "use_cache": use_cache, + "cache": cache, + "cache_size": cache_size, + "cache_ttl": cache_ttl, + } + ) connection_pool = ConnectionPool(**kwargs) self.auto_close_connection_pool = True else: self.auto_close_connection_pool = False self.connection_pool = connection_pool + + if use_cache and self.connection_pool.get_protocol() not in [3, "3"]: + raise RedisError("Client caching is only supported with RESP version 3") + self.connection = None if single_connection_client: self.connection = self.connection_pool.get_connection("_") @@ -541,7 +543,7 @@ def _send_command_parse_response(self, conn, command_name, *args, **options): """ Send a command and parse the response """ - conn.send_command(*args) + conn.send_command(*args, **options) return self.parse_response(conn, command_name, **options) def _disconnect_raise(self, conn, error): @@ -559,25 +561,20 @@ def _disconnect_raise(self, conn, error): # COMMAND EXECUTION AND PROTOCOL PARSING def execute_command(self, *args, **options): + return self._execute_command(*args, **options) + + def _execute_command(self, *args, **options): """Execute a command and return a parsed response""" - command_name = args[0] - keys = options.pop("keys", None) pool = self.connection_pool + command_name = args[0] conn = self.connection or pool.get_connection(command_name, **options) - response_from_cache = conn._get_from_local_cache(args) try: - if response_from_cache is not None: - return response_from_cache - else: - response = conn.retry.call_with_retry( - lambda: self._send_command_parse_response( - conn, command_name, *args, **options - ), - lambda error: self._disconnect_raise(conn, error), - ) - if keys: - conn._add_to_local_cache(args, response, keys) - return response + return conn.retry.call_with_retry( + lambda: self._send_command_parse_response( + conn, command_name, *args, **options + ), + lambda error: self._disconnect_raise(conn, error), + ) finally: if not self.connection: pool.release(conn) @@ -602,24 +599,6 @@ def parse_response(self, connection, command_name, **options): return self.response_callbacks[command_name](response, **options) return response - def flush_cache(self): - if self.connection: - self.connection.flush_cache() - else: - self.connection_pool.flush_cache() - - def delete_command_from_cache(self, command): - if self.connection: - self.connection.delete_command_from_cache(command) - else: - self.connection_pool.delete_command_from_cache(command) - - def invalidate_key_from_cache(self, key): - if self.connection: - self.connection.invalidate_key_from_cache(key) - else: - self.connection_pool.invalidate_key_from_cache(key) - StrictRedis = Redis @@ -1314,7 +1293,6 @@ def multi(self) -> None: self.explicit_transaction = True def execute_command(self, *args, **kwargs): - kwargs.pop("keys", None) # the keys are used only for client side caching if (self.watching or args[0] == "WATCH") and not self.explicit_transaction: return 
self.immediate_execute_command(*args, **kwargs)
         return self.pipeline_execute_command(*args, **kwargs)
diff --git a/redis/cluster.py b/redis/cluster.py
index be7685e9a1..39e8c4b9ea 100644
--- a/redis/cluster.py
+++ b/redis/cluster.py
@@ -6,9 +6,12 @@
 from collections import OrderedDict
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union

+from cachetools import Cache
+
 from redis._parsers import CommandsParser, Encoder
 from redis._parsers.helpers import parse_scan
 from redis.backoff import default_backoff
+from redis.cache import CacheMixin
 from redis.client import CaseInsensitiveDict, PubSub, Redis
 from redis.commands import READ_COMMANDS, RedisClusterCommands
 from redis.commands.helpers import list_or_args
@@ -167,13 +170,7 @@ def parse_cluster_myshardid(resp, **options):
     "ssl_password",
     "unix_socket_path",
     "username",
-    "cache_enabled",
-    "client_cache",
-    "cache_max_size",
-    "cache_ttl",
-    "cache_policy",
-    "cache_deny_list",
-    "cache_allow_list",
+    "use_cache",
 )
 KWARGS_DISABLED_KEYS = ("host", "port")
@@ -449,7 +446,7 @@ def replace_default_node(self, target_node: "ClusterNode" = None) -> None:
         self.nodes_manager.default_node = random.choice(replicas)


-class RedisCluster(AbstractRedisCluster, RedisClusterCommands):
+class RedisCluster(AbstractRedisCluster, RedisClusterCommands, CacheMixin):
     @classmethod
     def from_url(cls, url, **kwargs):
         """
@@ -507,6 +504,7 @@ def __init__(
         dynamic_startup_nodes: bool = True,
         url: Optional[str] = None,
         address_remap: Optional[Callable[[Tuple[str, int]], Tuple[str, int]]] = None,
+        use_cache: Optional[bool] = False,
         **kwargs,
     ):
         """
@@ -642,6 +640,7 @@ def __init__(
             require_full_coverage=require_full_coverage,
             dynamic_startup_nodes=dynamic_startup_nodes,
             address_remap=address_remap,
+            use_cache=use_cache,
             **kwargs,
         )

@@ -649,6 +648,12 @@ def __init__(
             self.__class__.CLUSTER_COMMANDS_RESPONSE_CALLBACKS
         )
         self.result_callbacks = CaseInsensitiveDict(self.__class__.RESULT_CALLBACKS)
+
+        protocol = kwargs.get("protocol", None)
+        if use_cache and protocol not in [3, "3"]:
+            raise RedisError("Client caching is only supported with RESP version 3")
+        CacheMixin.__init__(self, use_cache, None)
+
         self.commands_parser = CommandsParser(self)
         self._lock = threading.Lock()

@@ -1051,7 +1056,12 @@ def _parse_target_nodes(self, target_nodes):
             )
         return nodes

-    def execute_command(self, *args, **kwargs):
+    def execute_command(self, *args, **options):
+        if self.use_cache:
+            return self.cached_call(self._internal_execute_command, *args, **options)
+        return self._internal_execute_command(*args, **options)
+
+    def _internal_execute_command(self, *args, **kwargs):
         """
         Wrapper for ERRORS_ALLOW_RETRY error handling.
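
Stripped of connection tracking and the key-to-hash bookkeeping, the `cached_call` path that cluster `execute_command` now routes through is plain memoization over `cachetools`, keyed on the positional command arguments. A reduced sketch of that core (the mixin's extra state is deliberately omitted):

```python
from cachetools import TTLCache
from cachetools.keys import hashkey

cache = TTLCache(maxsize=128, ttl=300)

def cached_call(func, *args, **options):
    key = hashkey(*args)             # ("GET", "greeting") hashes to a stable key
    if key in cache:
        return cache[key]            # cache hit: skip the network round trip
    result = func(*args, **options)  # cache miss: execute and remember
    cache[key] = result
    return result
```
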
@@ -1125,7 +1135,6 @@ def _execute_command(self, target_node, *args, **kwargs): """ Send a command to a node in the cluster """ - keys = kwargs.pop("keys", None) command = args[0] redis_node = None connection = None @@ -1154,19 +1163,13 @@ def _execute_command(self, target_node, *args, **kwargs): connection.send_command("ASKING") redis_node.parse_response(connection, "ASKING", **kwargs) asking = False - response_from_cache = connection._get_from_local_cache(args) - if response_from_cache is not None: - return response_from_cache - else: - connection.send_command(*args) - response = redis_node.parse_response(connection, command, **kwargs) - if command in self.cluster_response_callbacks: - response = self.cluster_response_callbacks[command]( - response, **kwargs - ) - if keys: - connection._add_to_local_cache(args, response, keys) - return response + connection.send_command(*args) + response = redis_node.parse_response(connection, command, **kwargs) + if command in self.cluster_response_callbacks: + response = self.cluster_response_callbacks[command]( + response, **kwargs + ) + return response except AuthenticationError: raise except (ConnectionError, TimeoutError) as e: @@ -1266,18 +1269,6 @@ def load_external_module(self, funcname, func): """ setattr(self, funcname, func) - def flush_cache(self): - if self.nodes_manager: - self.nodes_manager.flush_cache() - - def delete_command_from_cache(self, command): - if self.nodes_manager: - self.nodes_manager.delete_command_from_cache(command) - - def invalidate_key_from_cache(self, key): - if self.nodes_manager: - self.nodes_manager.invalidate_key_from_cache(key) - class ClusterNode: def __init__(self, host, port, server_type=None, redis_connection=None): @@ -1306,18 +1297,6 @@ def __del__(self): if self.redis_connection is not None: self.redis_connection.close() - def flush_cache(self): - if self.redis_connection is not None: - self.redis_connection.flush_cache() - - def delete_command_from_cache(self, command): - if self.redis_connection is not None: - self.redis_connection.delete_command_from_cache(command) - - def invalidate_key_from_cache(self, key): - if self.redis_connection is not None: - self.redis_connection.invalidate_key_from_cache(key) - class LoadBalancer: """ @@ -1338,7 +1317,7 @@ def reset(self) -> None: self.primary_to_idx.clear() -class NodesManager: +class NodesManager(CacheMixin): def __init__( self, startup_nodes, @@ -1348,6 +1327,8 @@ def __init__( dynamic_startup_nodes=True, connection_pool_class=ConnectionPool, address_remap: Optional[Callable[[Tuple[str, int]], Tuple[str, int]]] = None, + use_cache: Optional[bool] = False, + cache: Optional[Cache] = None, **kwargs, ): self.nodes_cache = {} @@ -1360,12 +1341,14 @@ def __init__( self._dynamic_startup_nodes = dynamic_startup_nodes self.connection_pool_class = connection_pool_class self.address_remap = address_remap + self.use_cache = use_cache self._moved_exception = None self.connection_kwargs = kwargs self.read_load_balancer = LoadBalancer() if lock is None: lock = threading.Lock() self._lock = lock + CacheMixin.__init__(self, use_cache, None, cache) self.initialize() def get_node(self, host=None, port=None, node_name=None): @@ -1503,9 +1486,9 @@ def create_redis_node(self, host, port, **kwargs): # Create a redis node with a costumed connection pool kwargs.update({"host": host}) kwargs.update({"port": port}) - r = Redis(connection_pool=self.connection_pool_class(**kwargs)) + r = Redis(connection_pool=self.connection_pool_class(**kwargs), use_cache=self.use_cache, 
cache=self.cache) else: - r = Redis(host=host, port=port, **kwargs) + r = Redis(host=host, port=port, use_cache=self.use_cache, cache=self.cache, **kwargs) return r def _get_or_create_cluster_node(self, host, port, role, tmp_nodes_cache): @@ -1681,18 +1664,6 @@ def remap_host_port(self, host: str, port: int) -> Tuple[str, int]: return self.address_remap((host, port)) return host, port - def flush_cache(self): - for node in self.nodes_cache.values(): - node.flush_cache() - - def delete_command_from_cache(self, command): - for node in self.nodes_cache.values(): - node.delete_command_from_cache(command) - - def invalidate_key_from_cache(self, key): - for node in self.nodes_cache.values(): - node.invalidate_key_from_cache(key) - class ClusterPubSub(PubSub): """ @@ -2008,7 +1979,6 @@ def execute_command(self, *args, **kwargs): """ Wrapper function for pipeline_execute_command """ - kwargs.pop("keys", None) # the keys are used only for client side caching return self.pipeline_execute_command(*args, **kwargs) def pipeline_execute_command(self, *args, **options): diff --git a/redis/connection.py b/redis/connection.py index 1f862d0371..5d3640e816 100644 --- a/redis/connection.py +++ b/redis/connection.py @@ -9,16 +9,12 @@ from itertools import chain from queue import Empty, Full, LifoQueue from time import time -from typing import Any, Callable, List, Optional, Sequence, Type, Union +from typing import Any, Callable, List, Optional, Type, Union from urllib.parse import parse_qs, unquote, urlparse +from cachetools import TTLCache, Cache, LRUCache +from cachetools.keys import hashkey +from redis.cache import CacheConfiguration -from ._cache import ( - DEFAULT_ALLOW_LIST, - DEFAULT_DENY_LIST, - DEFAULT_EVICTION_POLICY, - AbstractCache, - _LocalCache, -) from ._parsers import Encoder, _HiredisParser, _RESP2Parser, _RESP3Parser from .backoff import NoBackoff from .credentials import CredentialProvider, UsernamePasswordCredentialProvider @@ -33,7 +29,6 @@ TimeoutError, ) from .retry import Retry -from .typing import KeysT, ResponseT from .utils import ( CRYPTOGRAPHY_AVAILABLE, HIREDIS_AVAILABLE, @@ -107,9 +102,9 @@ def pack(self, *args): # output list if we're sending large values or memoryviews arg_length = len(arg) if ( - len(buff) > buffer_cutoff - or arg_length > buffer_cutoff - or isinstance(arg, memoryview) + len(buff) > buffer_cutoff + or arg_length > buffer_cutoff + or isinstance(arg, memoryview) ): buff = SYM_EMPTY.join( (buff, SYM_DOLLAR, str(arg_length).encode(), SYM_CRLF) @@ -132,39 +127,96 @@ def pack(self, *args): return output -class AbstractConnection: +class ConnectionInterface: + @abstractmethod + def repr_pieces(self): + pass + + @abstractmethod + def register_connect_callback(self, callback): + pass + + @abstractmethod + def deregister_connect_callback(self, callback): + pass + + @abstractmethod + def set_parser(self, parser_class): + pass + + @abstractmethod + def connect(self): + pass + + @abstractmethod + def on_connect(self): + pass + + @abstractmethod + def disconnect(self, *args): + pass + + @abstractmethod + def check_health(self): + pass + + @abstractmethod + def send_packed_command(self, command, check_health=True): + pass + + @abstractmethod + def send_command(self, *args, **kwargs): + pass + + @abstractmethod + def can_read(self, timeout=0): + pass + + @abstractmethod + def read_response( + self, + disable_decoding=False, + *, + disconnect_on_error=True, + push_request=False, + ): + pass + + @abstractmethod + def pack_command(self, *args): + pass + + @abstractmethod + def 
pack_commands(self, commands): + pass + + +class AbstractConnection(ConnectionInterface): "Manages communication to and from a Redis server" def __init__( - self, - db: int = 0, - password: Optional[str] = None, - socket_timeout: Optional[float] = None, - socket_connect_timeout: Optional[float] = None, - retry_on_timeout: bool = False, - retry_on_error=SENTINEL, - encoding: str = "utf-8", - encoding_errors: str = "strict", - decode_responses: bool = False, - parser_class=DefaultParser, - socket_read_size: int = 65536, - health_check_interval: int = 0, - client_name: Optional[str] = None, - lib_name: Optional[str] = "redis-py", - lib_version: Optional[str] = get_lib_version(), - username: Optional[str] = None, - retry: Union[Any, None] = None, - redis_connect_func: Optional[Callable[[], None]] = None, - credential_provider: Optional[CredentialProvider] = None, - protocol: Optional[int] = 2, - command_packer: Optional[Callable[[], None]] = None, - cache_enabled: bool = False, - client_cache: Optional[AbstractCache] = None, - cache_max_size: int = 10000, - cache_ttl: int = 0, - cache_policy: str = DEFAULT_EVICTION_POLICY, - cache_deny_list: List[str] = DEFAULT_DENY_LIST, - cache_allow_list: List[str] = DEFAULT_ALLOW_LIST, + self, + db: int = 0, + password: Optional[str] = None, + socket_timeout: Optional[float] = None, + socket_connect_timeout: Optional[float] = None, + retry_on_timeout: bool = False, + retry_on_error=SENTINEL, + encoding: str = "utf-8", + encoding_errors: str = "strict", + decode_responses: bool = False, + parser_class=DefaultParser, + socket_read_size: int = 65536, + health_check_interval: int = 0, + client_name: Optional[str] = None, + lib_name: Optional[str] = "redis-py", + lib_version: Optional[str] = get_lib_version(), + username: Optional[str] = None, + retry: Union[Any, None] = None, + redis_connect_func: Optional[Callable[[], None]] = None, + credential_provider: Optional[CredentialProvider] = None, + protocol: Optional[int] = 2, + command_packer: Optional[Callable[[], None]] = None, ): """ Initialize a new Connection. 
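
The `ConnectionInterface` extraction above exists so a wrapper can stand in for a real connection; `CacheProxyConnection`, added further down in this file, is exactly such a proxy. A hypothetical logging wrapper showing the shape of the pattern (illustrative only, not part of this diff):

```python
class LoggingConnectionProxy(ConnectionInterface):
    def __init__(self, conn: ConnectionInterface):
        self._conn = conn

    def send_command(self, *args, **kwargs):
        print(f"sending: {args[0]}")        # the cache proxy consults its cache here
        self._conn.send_command(*args, **kwargs)

    def read_response(self, **kwargs):
        response = self._conn.read_response(**kwargs)
        print(f"received: {response!r}")    # ...and stores/serves responses here
        return response

    def __getattr__(self, name):            # delegate everything else unchanged
        return getattr(self._conn, name)
```
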
@@ -230,18 +282,6 @@ def __init__( # p = DEFAULT_RESP_VERSION self.protocol = p self._command_packer = self._construct_command_packer(command_packer) - if cache_enabled: - _cache = _LocalCache(cache_max_size, cache_ttl, cache_policy) - else: - _cache = None - self.client_cache = client_cache if client_cache is not None else _cache - if self.client_cache is not None: - if self.protocol not in [3, "3"]: - raise RedisError( - "client caching is only supported with protocol version 3 or higher" - ) - self.cache_deny_list = cache_deny_list - self.cache_allow_list = cache_allow_list def __repr__(self): repr_args = ",".join([f"{k}={v}" for k, v in self.repr_pieces()]) @@ -351,8 +391,8 @@ def on_connect(self): # if credential provider or username and/or password are set, authenticate if self.credential_provider or (self.username or self.password): cred_provider = ( - self.credential_provider - or UsernamePasswordCredentialProvider(self.username, self.password) + self.credential_provider + or UsernamePasswordCredentialProvider(self.username, self.password) ) auth_args = cred_provider.get_credentials() @@ -400,8 +440,8 @@ def on_connect(self): self.send_command("HELLO", self.protocol) response = self.read_response() if ( - response.get(b"proto") != self.protocol - and response.get("proto") != self.protocol + response.get(b"proto") != self.protocol + and response.get("proto") != self.protocol ): raise ConnectionError("Invalid RESP version") @@ -432,12 +472,6 @@ def on_connect(self): if str_if_bytes(self.read_response()) != "OK": raise ConnectionError("Invalid Database") - # if client caching is enabled, start tracking - if self.client_cache: - self.send_command("CLIENT", "TRACKING", "ON") - self.read_response() - self._parser.set_invalidation_push_handler(self._cache_invalidation_process) - def disconnect(self, *args): "Disconnects from the Redis server" self._parser.on_disconnect() @@ -458,9 +492,6 @@ def disconnect(self, *args): except OSError: pass - if self.client_cache: - self.client_cache.flush() - def _send_ping(self): """Send PING, expect PONG in return""" self.send_command("PING", check_health=False) @@ -529,11 +560,11 @@ def can_read(self, timeout=0): raise ConnectionError(f"Error while reading from {host_error}: {e.args}") def read_response( - self, - disable_decoding=False, - *, - disconnect_on_error=True, - push_request=False, + self, + disable_decoding=False, + *, + disconnect_on_error=True, + push_request=False, ): """Read the response from a previously sent command""" @@ -589,9 +620,9 @@ def pack_commands(self, commands): for chunk in self._command_packer.pack(*cmd): chunklen = len(chunk) if ( - buffer_length > buffer_cutoff - or chunklen > buffer_cutoff - or isinstance(chunk, memoryview) + buffer_length > buffer_cutoff + or chunklen > buffer_cutoff + or isinstance(chunk, memoryview) ): if pieces: output.append(SYM_EMPTY.join(pieces)) @@ -608,73 +639,21 @@ def pack_commands(self, commands): output.append(SYM_EMPTY.join(pieces)) return output - def _cache_invalidation_process( - self, data: List[Union[str, Optional[List[str]]]] - ) -> None: - """ - Invalidate (delete) all redis commands associated with a specific key. - `data` is a list of strings, where the first string is the invalidation message - and the second string is the list of keys to invalidate. 
- (if the list of keys is None, then all keys are invalidated) - """ - if data[1] is None: - self.client_cache.flush() - else: - for key in data[1]: - self.client_cache.invalidate_key(str_if_bytes(key)) - - def _get_from_local_cache(self, command: Sequence[str]): - """ - If the command is in the local cache, return the response - """ - if ( - self.client_cache is None - or command[0] in self.cache_deny_list - or command[0] not in self.cache_allow_list - ): - return None - while self.can_read(): - self.read_response(push_request=True) - return self.client_cache.get(command) - - def _add_to_local_cache( - self, command: Sequence[str], response: ResponseT, keys: List[KeysT] - ): - """ - Add the command and response to the local cache if the command - is allowed to be cached - """ - if ( - self.client_cache is not None - and (self.cache_deny_list == [] or command[0] not in self.cache_deny_list) - and (self.cache_allow_list == [] or command[0] in self.cache_allow_list) - ): - self.client_cache.set(command, response, keys) - - def flush_cache(self): - if self.client_cache: - self.client_cache.flush() - - def delete_command_from_cache(self, command: Union[str, Sequence[str]]): - if self.client_cache: - self.client_cache.delete_command(command) - - def invalidate_key_from_cache(self, key: KeysT): - if self.client_cache: - self.client_cache.invalidate_key(key) + def get_protocol(self) -> int or str: + return self.protocol class Connection(AbstractConnection): "Manages TCP communication to and from a Redis server" def __init__( - self, - host="localhost", - port=6379, - socket_keepalive=False, - socket_keepalive_options=None, - socket_type=0, - **kwargs, + self, + host="localhost", + port=6379, + socket_keepalive=False, + socket_keepalive_options=None, + socket_type=0, + **kwargs, ): self.host = host self.port = int(port) @@ -696,7 +675,7 @@ def _connect(self): # socket.connect() err = None for res in socket.getaddrinfo( - self.host, self.port, self.socket_type, socket.SOCK_STREAM + self.host, self.port, self.socket_type, socket.SOCK_STREAM ): family, socktype, proto, canonname, socket_address = res sock = None @@ -734,6 +713,151 @@ def _host_error(self): return f"{self.host}:{self.port}" +def ensure_string(key): + if isinstance(key, bytes): + return key.decode('utf-8') + elif isinstance(key, str): + return key + else: + raise TypeError("Key must be either a string or bytes") + + +class CacheProxyConnection(ConnectionInterface): + def __init__(self, conn: ConnectionInterface, cache: Cache, conf: CacheConfiguration): + self.pid = os.getpid() + self._conn = conn + self.retry = self._conn.retry + self._cache = cache + self._conf = conf + self._current_command_hash = None + self._current_command_keys = None + self._current_options = None + self._keys_mapping = LRUCache(maxsize=10000) + self.register_connect_callback(self._enable_tracking_callback) + + def repr_pieces(self): + return self._conn.repr_pieces() + + def register_connect_callback(self, callback): + self._conn.register_connect_callback(callback) + + def deregister_connect_callback(self, callback): + self._conn.deregister_connect_callback(callback) + + def set_parser(self, parser_class): + self._conn.set_parser(parser_class) + + def connect(self): + self._conn.connect() + + def on_connect(self): + self._conn.on_connect() + + def disconnect(self, *args): + self._conn.disconnect(*args) + + def check_health(self): + self._conn.check_health() + + def send_packed_command(self, command, check_health=True): + cache_key = hashkey(command) + + if 
+    def send_packed_command(self, command, check_health=True):
+        cache_key = hashkey(command)
+
+        # Membership test rather than truthiness of get(): a cached falsy
+        # reply (0, empty list, ...) is still a cache hit.
+        if cache_key in self._cache:
+            self._current_command_hash = cache_key
+            return
+
+        self._current_command_hash = None
+        self._conn.send_packed_command(command)
+
+    def send_command(self, *args, **kwargs):
+        if not self._conf.is_allowed_to_cache(args[0]):
+            self._current_command_hash = None
+            self._current_command_keys = None
+            self._conn.send_command(*args, **kwargs)
+            return
+
+        self._current_command_hash = hashkey(*args)
+
+        # Reset the tracked keys on every cacheable command so a previous
+        # command's keys cannot leak into this command's mapping.
+        self._current_command_keys = kwargs.get("keys")
+
+        if self._current_command_keys is not None and not isinstance(
+            self._current_command_keys, list
+        ):
+            raise TypeError("Cache keys must be a list.")
+
+        # On a hit the reply is served from the cache in read_response(),
+        # so skip the network round trip.
+        if self._current_command_hash in self._cache:
+            return
+
+        self._conn.send_command(*args, **kwargs)

+    def can_read(self, timeout=0):
+        return self._conn.can_read(timeout)
+
+    def read_response(
+        self, disable_decoding=False, *, disconnect_on_error=True, push_request=False
+    ):
+        response = self._conn.read_response(
+            disable_decoding=disable_decoding,
+            disconnect_on_error=disconnect_on_error,
+            push_request=push_request
+        )
+
+        # The push type arrives as bytes or str depending on
+        # decode_responses, so match both encodings.
+        if (
+            isinstance(response, list)
+            and len(response) > 0
+            and response[0] in (b"invalidate", "invalidate")
+        ):
+            self._on_invalidation_callback(response)
+            # The push message is not the command's reply; read again and
+            # keep that result as the response.
+            response = self.read_response(
+                disable_decoding=disable_decoding,
+                disconnect_on_error=disconnect_on_error,
+                push_request=push_request
+            )
+
+        if response is None or self._current_command_hash is None:
+            return response
+
+        if self._current_command_hash in self._cache:
+            return self._cache[self._current_command_hash]
+
+        # Without key information the entry could never be invalidated,
+        # so only cache replies whose keys are known.
+        if self._current_command_keys:
+            for key in self._current_command_keys:
+                if key in self._keys_mapping:
+                    if self._current_command_hash not in self._keys_mapping[key]:
+                        self._keys_mapping[key].append(self._current_command_hash)
+                else:
+                    self._keys_mapping[key] = [self._current_command_hash]
+
+            self._cache[self._current_command_hash] = response
+        return response
+
+    def pack_command(self, *args):
+        # Packing has no caching side effects; delegate to the wrapped
+        # connection instead of silently returning None.
+        return self._conn.pack_command(*args)
+
+    def pack_commands(self, commands):
+        return self._conn.pack_commands(commands)
+
+    def _connect(self):
+        return self._conn._connect()
+
+    def _host_error(self):
+        # Propagate the wrapped connection's host description.
+        return self._conn._host_error()
+
+    def _enable_tracking_callback(self, conn: ConnectionInterface) -> None:
+        conn.send_command("CLIENT", "TRACKING", "ON")
+        conn.read_response()
+
+    def _process_pending_invalidations(self):
+        # Drain any invalidation push messages already buffered on the
+        # socket before serving cached data.
+        while self.can_read():
+            self.read_response(push_request=True)
+
+    def _on_invalidation_callback(
+        self, data: List[Union[str, Optional[List[str]]]]
+    ):
+        # A None payload means the server invalidated everything;
+        # otherwise evict every cached command that depends on one of the
+        # listed keys.
+        if data[1] is None:
+            self._cache.clear()
+        else:
+            for key in data[1]:
+                normalized_key = ensure_string(key)
+                if normalized_key in self._keys_mapping:
+                    for cache_key in self._keys_mapping[normalized_key]:
+                        # The entry may already have been evicted by the
+                        # cache itself (TTL/size), so pop defensively.
+                        self._cache.pop(cache_key, None)
+
+
 class SSLConnection(Connection):
     """Manages SSL connections to and from the Redis server(s).
This class extends the Connection class, adding SSL functionality, and making @@ -741,22 +865,22 @@ class SSLConnection(Connection): """ # noqa def __init__( - self, - ssl_keyfile=None, - ssl_certfile=None, - ssl_cert_reqs="required", - ssl_ca_certs=None, - ssl_ca_data=None, - ssl_check_hostname=False, - ssl_ca_path=None, - ssl_password=None, - ssl_validate_ocsp=False, - ssl_validate_ocsp_stapled=False, - ssl_ocsp_context=None, - ssl_ocsp_expected_cert=None, - ssl_min_version=None, - ssl_ciphers=None, - **kwargs, + self, + ssl_keyfile=None, + ssl_certfile=None, + ssl_cert_reqs="required", + ssl_ca_certs=None, + ssl_ca_data=None, + ssl_check_hostname=False, + ssl_ca_path=None, + ssl_password=None, + ssl_validate_ocsp=False, + ssl_validate_ocsp_stapled=False, + ssl_ocsp_context=None, + ssl_ocsp_expected_cert=None, + ssl_min_version=None, + ssl_ciphers=None, + **kwargs, ): """Constructor @@ -843,9 +967,9 @@ def _wrap_socket_with_ssl(self, sock): password=self.certificate_password, ) if ( - self.ca_certs is not None - or self.ca_path is not None - or self.ca_data is not None + self.ca_certs is not None + or self.ca_path is not None + or self.ca_data is not None ): context.load_verify_locations( cafile=self.ca_certs, capath=self.ca_path, cadata=self.ca_data @@ -961,9 +1085,9 @@ def to_bool(value): def parse_url(url): if not ( - url.startswith("redis://") - or url.startswith("rediss://") - or url.startswith("unix://") + url.startswith("redis://") + or url.startswith("rediss://") + or url.startswith("unix://") ): raise ValueError( "Redis URL must specify one of the following " @@ -1080,18 +1204,37 @@ class initializer. In the case of conflicting arguments, querystring return cls(**kwargs) def __init__( - self, - connection_class=Connection, - max_connections: Optional[int] = None, - **connection_kwargs, + self, + connection_class=Connection, + max_connections: Optional[int] = None, + **connection_kwargs, ): - max_connections = max_connections or 2**31 + max_connections = max_connections or 2 ** 31 if not isinstance(max_connections, int) or max_connections < 0: raise ValueError('"max_connections" must be a positive integer') self.connection_class = connection_class self.connection_kwargs = connection_kwargs self.max_connections = max_connections + self._cache = None + self._cache_conf = None + + if connection_kwargs.get("use_cache"): + if connection_kwargs.get("protocol") not in [3, "3"]: + raise RedisError("Client caching is only supported with RESP version 3") + + self._cache_conf = CacheConfiguration(**self.connection_kwargs) + + if self.connection_kwargs.get("cache"): + self._cache = self.connection_kwargs.get("cache") + else: + self._cache = TTLCache(self.connection_kwargs["cache_size"], self.connection_kwargs["cache_ttl"]) + + connection_kwargs.pop("use_cache", None) + connection_kwargs.pop("cache_size", None) + connection_kwargs.pop("cache_ttl", None) + connection_kwargs.pop("cache", None) + # a lock to protect the critical section in _checkpid(). # this lock is acquired when the process id changes, such as @@ -1110,6 +1253,14 @@ def __repr__(self) -> (str, str): f"({repr(self.connection_class(**self.connection_kwargs))})>" ) + def get_protocol(self): + """ + Returns: + The RESP protocol version, or ``None`` if the protocol is not specified, + in which case the server default will be used. 
+ """ + return self.connection_kwargs.get("protocol", None) + def reset(self) -> None: self._lock = threading.Lock() self._created_connections = 0 @@ -1187,15 +1338,12 @@ def get_connection(self, command_name: str, *keys, **options) -> "Connection": try: # ensure this connection is connected to Redis connection.connect() - # if client caching is not enabled connections that the pool - # provides should be ready to send a command. - # if not, the connection was either returned to the + # connections that the pool provides should be ready to send + # a command. if not, the connection was either returned to the # pool before all data has been read or the socket has been # closed. either way, reconnect and verify everything is good. - # (if caching enabled the connection will not always be ready - # to send a command because it may contain invalidation messages) try: - if connection.can_read() and connection.client_cache is None: + if connection.can_read(): raise ConnectionError("Connection has data") except (ConnectionError, OSError): connection.disconnect() @@ -1219,11 +1367,15 @@ def get_encoder(self) -> Encoder: decode_responses=kwargs.get("decode_responses", False), ) - def make_connection(self) -> "Connection": + def make_connection(self) -> "ConnectionInterface": "Create a new connection" if self._created_connections >= self.max_connections: raise ConnectionError("Too many connections") self._created_connections += 1 + + if self._cache is not None and self._cache_conf is not None: + return CacheProxyConnection(self.connection_class(**self.connection_kwargs), self._cache, self._cache_conf) + return self.connection_class(**self.connection_kwargs) def release(self, connection: "Connection") -> None: @@ -1281,27 +1433,6 @@ def set_retry(self, retry: "Retry") -> None: for conn in self._in_use_connections: conn.retry = retry - def flush_cache(self): - self._checkpid() - with self._lock: - connections = chain(self._available_connections, self._in_use_connections) - for connection in connections: - connection.flush_cache() - - def delete_command_from_cache(self, command: str): - self._checkpid() - with self._lock: - connections = chain(self._available_connections, self._in_use_connections) - for connection in connections: - connection.delete_command_from_cache(command) - - def invalidate_key_from_cache(self, key: str): - self._checkpid() - with self._lock: - connections = chain(self._available_connections, self._in_use_connections) - for connection in connections: - connection.invalidate_key_from_cache(key) - class BlockingConnectionPool(ConnectionPool): """ @@ -1338,12 +1469,12 @@ class BlockingConnectionPool(ConnectionPool): """ def __init__( - self, - max_connections=50, - timeout=20, - connection_class=Connection, - queue_class=LifoQueue, - **connection_kwargs, + self, + max_connections=50, + timeout=20, + connection_class=Connection, + queue_class=LifoQueue, + **connection_kwargs, ): self.queue_class = queue_class self.timeout = timeout diff --git a/redis/sentinel.py b/redis/sentinel.py index 72b5bef548..e0437c81cd 100644 --- a/redis/sentinel.py +++ b/redis/sentinel.py @@ -252,7 +252,6 @@ def execute_command(self, *args, **kwargs): once - If set to True, then execute the resulting command on a single node at random, rather than across the entire sentinel cluster. 
""" - kwargs.pop("keys", None) # the keys are used only for client side caching once = bool(kwargs.get("once", False)) if "once" in kwargs.keys(): kwargs.pop("once") diff --git a/requirements.txt b/requirements.txt index 3274a80f62..26aed50b9d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1 +1,2 @@ async-timeout>=4.0.3 +cachetools diff --git a/tests/conftest.py b/tests/conftest.py index dd78bb6a2c..97d73773ba 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -441,7 +441,6 @@ def _gen_cluster_mock_resp(r, response): connection = Mock(spec=Connection) connection.retry = Retry(NoBackoff(), 0) connection.read_response.return_value = response - connection._get_from_local_cache.return_value = None with mock.patch.object(r, "connection", connection): yield r diff --git a/tests/test_asyncio/conftest.py b/tests/test_asyncio/conftest.py index 6e93407b4c..41b47b2268 100644 --- a/tests/test_asyncio/conftest.py +++ b/tests/test_asyncio/conftest.py @@ -146,7 +146,6 @@ def _gen_cluster_mock_resp(r, response): connection = mock.AsyncMock(spec=Connection) connection.retry = Retry(NoBackoff(), 0) connection.read_response.return_value = response - connection._get_from_local_cache.return_value = None with mock.patch.object(r, "connection", connection): yield r diff --git a/tests/test_asyncio/test_cache.py b/tests/test_asyncio/test_cache.py deleted file mode 100644 index 7a7f881ce2..0000000000 --- a/tests/test_asyncio/test_cache.py +++ /dev/null @@ -1,408 +0,0 @@ -import time - -import pytest -import pytest_asyncio -from redis._cache import EvictionPolicy, _LocalCache -from redis.utils import HIREDIS_AVAILABLE - - -@pytest_asyncio.fixture -async def r(request, create_redis): - cache = request.param.get("cache") - kwargs = request.param.get("kwargs", {}) - r = await create_redis(protocol=3, client_cache=cache, **kwargs) - yield r, cache - - -@pytest_asyncio.fixture() -async def local_cache(): - yield _LocalCache() - - -@pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only") -class TestLocalCache: - @pytest.mark.parametrize("r", [{"cache": _LocalCache()}], indirect=True) - @pytest.mark.onlynoncluster - async def test_get_from_cache(self, r, r2): - r, cache = r - # add key to redis - await r.set("foo", "bar") - # get key from redis and save in local cache - assert await r.get("foo") == b"bar" - # get key from local cache - assert cache.get(("GET", "foo")) == b"bar" - # change key in redis (cause invalidation) - await r2.set("foo", "barbar") - # send any command to redis (process invalidation in background) - await r.ping() - # the command is not in the local cache anymore - assert cache.get(("GET", "foo")) is None - # get key from redis - assert await r.get("foo") == b"barbar" - - @pytest.mark.parametrize("r", [{"cache": _LocalCache(max_size=3)}], indirect=True) - async def test_cache_lru_eviction(self, r): - r, cache = r - # add 3 keys to redis - await r.set("foo", "bar") - await r.set("foo2", "bar2") - await r.set("foo3", "bar3") - # get 3 keys from redis and save in local cache - assert await r.get("foo") == b"bar" - assert await r.get("foo2") == b"bar2" - assert await r.get("foo3") == b"bar3" - # get the 3 keys from local cache - assert cache.get(("GET", "foo")) == b"bar" - assert cache.get(("GET", "foo2")) == b"bar2" - assert cache.get(("GET", "foo3")) == b"bar3" - # add 1 more key to redis (exceed the max size) - await r.set("foo4", "bar4") - assert await r.get("foo4") == b"bar4" - # the first key is not in the local cache anymore - assert cache.get(("GET", "foo")) is None - - 
@pytest.mark.parametrize("r", [{"cache": _LocalCache(ttl=1)}], indirect=True) - async def test_cache_ttl(self, r): - r, cache = r - # add key to redis - await r.set("foo", "bar") - # get key from redis and save in local cache - assert await r.get("foo") == b"bar" - # get key from local cache - assert cache.get(("GET", "foo")) == b"bar" - # wait for the key to expire - time.sleep(1) - # the key is not in the local cache anymore - assert cache.get(("GET", "foo")) is None - - @pytest.mark.parametrize( - "r", - [{"cache": _LocalCache(max_size=3, eviction_policy=EvictionPolicy.LFU)}], - indirect=True, - ) - async def test_cache_lfu_eviction(self, r): - r, cache = r - # add 3 keys to redis - await r.set("foo", "bar") - await r.set("foo2", "bar2") - await r.set("foo3", "bar3") - # get 3 keys from redis and save in local cache - assert await r.get("foo") == b"bar" - assert await r.get("foo2") == b"bar2" - assert await r.get("foo3") == b"bar3" - # change the order of the keys in the cache - assert cache.get(("GET", "foo")) == b"bar" - assert cache.get(("GET", "foo")) == b"bar" - assert cache.get(("GET", "foo3")) == b"bar3" - # add 1 more key to redis (exceed the max size) - await r.set("foo4", "bar4") - assert await r.get("foo4") == b"bar4" - # test the eviction policy - assert len(cache.cache) == 3 - assert cache.get(("GET", "foo")) == b"bar" - assert cache.get(("GET", "foo2")) is None - - @pytest.mark.parametrize( - "r", - [{"cache": _LocalCache(), "kwargs": {"decode_responses": True}}], - indirect=True, - ) - @pytest.mark.onlynoncluster - async def test_cache_decode_response(self, r): - r, cache = r - await r.set("foo", "bar") - # get key from redis and save in local cache - assert await r.get("foo") == "bar" - # get key from local cache - assert cache.get(("GET", "foo")) == "bar" - # change key in redis (cause invalidation) - await r.set("foo", "barbar") - # send any command to redis (process invalidation in background) - await r.ping() - # the command is not in the local cache anymore - assert cache.get(("GET", "foo")) is None - # get key from redis - assert await r.get("foo") == "barbar" - - @pytest.mark.parametrize( - "r", - [{"cache": _LocalCache(), "kwargs": {"cache_deny_list": ["LLEN"]}}], - indirect=True, - ) - async def test_cache_deny_list(self, r): - r, cache = r - # add list to redis - await r.lpush("mylist", "foo", "bar", "baz") - assert await r.llen("mylist") == 3 - assert await r.lindex("mylist", 1) == b"bar" - assert cache.get(("LLEN", "mylist")) is None - assert cache.get(("LINDEX", "mylist", 1)) == b"bar" - - @pytest.mark.parametrize( - "r", - [{"cache": _LocalCache(), "kwargs": {"cache_allow_list": ["LLEN"]}}], - indirect=True, - ) - async def test_cache_allow_list(self, r): - r, cache = r - # add list to redis - await r.lpush("mylist", "foo", "bar", "baz") - assert await r.llen("mylist") == 3 - assert await r.lindex("mylist", 1) == b"bar" - assert cache.get(("LLEN", "mylist")) == 3 - assert cache.get(("LINDEX", "mylist", 1)) is None - - @pytest.mark.parametrize("r", [{"cache": _LocalCache()}], indirect=True) - async def test_cache_return_copy(self, r): - r, cache = r - await r.lpush("mylist", "foo", "bar", "baz") - assert await r.lrange("mylist", 0, -1) == [b"baz", b"bar", b"foo"] - res = cache.get(("LRANGE", "mylist", 0, -1)) - assert res == [b"baz", b"bar", b"foo"] - res.append(b"new") - check = cache.get(("LRANGE", "mylist", 0, -1)) - assert check == [b"baz", b"bar", b"foo"] - - @pytest.mark.parametrize( - "r", - [{"cache": _LocalCache(), "kwargs": {"decode_responses": 
True}}], - indirect=True, - ) - @pytest.mark.onlynoncluster - async def test_csc_not_cause_disconnects(self, r): - r, cache = r - id1 = await r.client_id() - await r.mset({"a": 1, "b": 1, "c": 1, "d": 1, "e": 1}) - assert await r.mget("a", "b", "c", "d", "e") == ["1", "1", "1", "1", "1"] - id2 = await r.client_id() - - # client should get value from client cache - assert await r.mget("a", "b", "c", "d", "e") == ["1", "1", "1", "1", "1"] - assert cache.get(("MGET", "a", "b", "c", "d", "e")) == [ - "1", - "1", - "1", - "1", - "1", - ] - - await r.mset({"a": 2, "b": 2, "c": 2, "d": 2, "e": 2}) - id3 = await r.client_id() - # client should get value from redis server post invalidate messages - assert await r.mget("a", "b", "c", "d", "e") == ["2", "2", "2", "2", "2"] - - await r.mset({"a": 3, "b": 3, "c": 3, "d": 3, "e": 3}) - # need to check that we get correct value 3 and not 2 - assert await r.mget("a", "b", "c", "d", "e") == ["3", "3", "3", "3", "3"] - # client should get value from client cache - assert await r.mget("a", "b", "c", "d", "e") == ["3", "3", "3", "3", "3"] - - await r.mset({"a": 4, "b": 4, "c": 4, "d": 4, "e": 4}) - # need to check that we get correct value 4 and not 3 - assert await r.mget("a", "b", "c", "d", "e") == ["4", "4", "4", "4", "4"] - # client should get value from client cache - assert await r.mget("a", "b", "c", "d", "e") == ["4", "4", "4", "4", "4"] - id4 = await r.client_id() - assert id1 == id2 == id3 == id4 - - @pytest.mark.parametrize( - "r", - [{"cache": _LocalCache(), "kwargs": {"decode_responses": True}}], - indirect=True, - ) - async def test_execute_command_keys_provided(self, r): - r, cache = r - assert await r.execute_command("SET", "b", "2") is True - assert await r.execute_command("GET", "b", keys=["b"]) == "2" - assert cache.get(("GET", "b")) == "2" - - @pytest.mark.parametrize( - "r", - [{"cache": _LocalCache(), "kwargs": {"decode_responses": True}}], - indirect=True, - ) - async def test_execute_command_keys_not_provided(self, r): - r, cache = r - assert await r.execute_command("SET", "b", "2") is True - assert ( - await r.execute_command("GET", "b") == "2" - ) # keys not provided, not cached - assert cache.get(("GET", "b")) is None - - @pytest.mark.parametrize( - "r", - [{"cache": _LocalCache(), "kwargs": {"decode_responses": True}}], - indirect=True, - ) - async def test_delete_one_command(self, r): - r, cache = r - assert await r.mset({"a{a}": 1, "b{a}": 1}) is True - assert await r.set("c", 1) is True - assert await r.mget("a{a}", "b{a}") == ["1", "1"] - assert await r.get("c") == "1" - # values should be in local cache - assert cache.get(("MGET", "a{a}", "b{a}")) == ["1", "1"] - assert cache.get(("GET", "c")) == "1" - # delete one command from the cache - r.delete_command_from_cache(("MGET", "a{a}", "b{a}")) - # the other command is still in the local cache anymore - assert cache.get(("MGET", "a{a}", "b{a}")) is None - assert cache.get(("GET", "c")) == "1" - # get from redis - assert await r.mget("a{a}", "b{a}") == ["1", "1"] - assert await r.get("c") == "1" - - @pytest.mark.parametrize( - "r", - [{"cache": _LocalCache(), "kwargs": {"decode_responses": True}}], - indirect=True, - ) - async def test_invalidate_key(self, r): - r, cache = r - assert await r.mset({"a{a}": 1, "b{a}": 1}) is True - assert await r.set("c", 1) is True - assert await r.mget("a{a}", "b{a}") == ["1", "1"] - assert await r.get("c") == "1" - # values should be in local cache - assert cache.get(("MGET", "a{a}", "b{a}")) == ["1", "1"] - assert cache.get(("GET", "c")) == "1" 
- # invalidate one key from the cache - r.invalidate_key_from_cache("b{a}") - # one other command is still in the local cache anymore - assert cache.get(("MGET", "a{a}", "b{a}")) is None - assert cache.get(("GET", "c")) == "1" - # get from redis - assert await r.mget("a{a}", "b{a}") == ["1", "1"] - assert await r.get("c") == "1" - - @pytest.mark.parametrize( - "r", - [{"cache": _LocalCache(), "kwargs": {"decode_responses": True}}], - indirect=True, - ) - async def test_flush_entire_cache(self, r): - r, cache = r - assert await r.mset({"a{a}": 1, "b{a}": 1}) is True - assert await r.set("c", 1) is True - assert await r.mget("a{a}", "b{a}") == ["1", "1"] - assert await r.get("c") == "1" - # values should be in local cache - assert cache.get(("MGET", "a{a}", "b{a}")) == ["1", "1"] - assert cache.get(("GET", "c")) == "1" - # flush the local cache - r.flush_cache() - # the commands are not in the local cache anymore - assert cache.get(("MGET", "a{a}", "b{a}")) is None - assert cache.get(("GET", "c")) is None - # get from redis - assert await r.mget("a{a}", "b{a}") == ["1", "1"] - assert await r.get("c") == "1" - - -@pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only") -@pytest.mark.onlycluster -class TestClusterLocalCache: - @pytest.mark.parametrize("r", [{"cache": _LocalCache()}], indirect=True) - async def test_get_from_cache(self, r, r2): - r, cache = r - # add key to redis - await r.set("foo", "bar") - # get key from redis and save in local cache - assert await r.get("foo") == b"bar" - # get key from local cache - assert cache.get(("GET", "foo")) == b"bar" - # change key in redis (cause invalidation) - await r2.set("foo", "barbar") - # send any command to redis (process invalidation in background) - node = r.get_node_from_key("foo") - await r.ping(target_nodes=node) - # the command is not in the local cache anymore - assert cache.get(("GET", "foo")) is None - # get key from redis - assert await r.get("foo") == b"barbar" - - @pytest.mark.parametrize( - "r", - [{"cache": _LocalCache(), "kwargs": {"decode_responses": True}}], - indirect=True, - ) - async def test_cache_decode_response(self, r): - r, cache = r - await r.set("foo", "bar") - # get key from redis and save in local cache - assert await r.get("foo") == "bar" - # get key from local cache - assert cache.get(("GET", "foo")) == "bar" - # change key in redis (cause invalidation) - await r.set("foo", "barbar") - # send any command to redis (process invalidation in background) - node = r.get_node_from_key("foo") - await r.ping(target_nodes=node) - # the command is not in the local cache anymore - assert cache.get(("GET", "foo")) is None - # get key from redis - assert await r.get("foo") == "barbar" - - @pytest.mark.parametrize( - "r", - [{"cache": _LocalCache(), "kwargs": {"decode_responses": True}}], - indirect=True, - ) - async def test_execute_command_keys_provided(self, r): - r, cache = r - assert await r.execute_command("SET", "b", "2") is True - assert await r.execute_command("GET", "b", keys=["b"]) == "2" - assert cache.get(("GET", "b")) == "2" - - @pytest.mark.parametrize( - "r", - [{"cache": _LocalCache(), "kwargs": {"decode_responses": True}}], - indirect=True, - ) - async def test_execute_command_keys_not_provided(self, r): - r, cache = r - assert await r.execute_command("SET", "b", "2") is True - assert ( - await r.execute_command("GET", "b") == "2" - ) # keys not provided, not cached - assert cache.get(("GET", "b")) is None - - -@pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only") 
-@pytest.mark.onlynoncluster -class TestSentinelLocalCache: - - async def test_get_from_cache(self, local_cache, master): - await master.set("foo", "bar") - # get key from redis and save in local cache - assert await master.get("foo") == b"bar" - # get key from local cache - assert local_cache.get(("GET", "foo")) == b"bar" - # change key in redis (cause invalidation) - await master.set("foo", "barbar") - # send any command to redis (process invalidation in background) - await master.ping() - # the command is not in the local cache anymore - assert local_cache.get(("GET", "foo")) is None - # get key from redis - assert await master.get("foo") == b"barbar" - - @pytest.mark.parametrize( - "sentinel_setup", - [{"kwargs": {"decode_responses": True}}], - indirect=True, - ) - async def test_cache_decode_response(self, local_cache, sentinel_setup, master): - await master.set("foo", "bar") - # get key from redis and save in local cache - assert await master.get("foo") == "bar" - # get key from local cache - assert local_cache.get(("GET", "foo")) == "bar" - # change key in redis (cause invalidation) - await master.set("foo", "barbar") - # send any command to redis (process invalidation in background) - await master.ping() - # the command is not in the local cache anymore - assert local_cache.get(("GET", "foo")) is None - # get key from redis - assert await master.get("foo") == "barbar" diff --git a/tests/test_asyncio/test_cluster.py b/tests/test_asyncio/test_cluster.py index a36040f11b..57dfd25fb6 100644 --- a/tests/test_asyncio/test_cluster.py +++ b/tests/test_asyncio/test_cluster.py @@ -190,7 +190,6 @@ def mock_node_resp(node: ClusterNode, response: Any) -> ClusterNode: connection = mock.AsyncMock(spec=Connection) connection.is_connected = True connection.read_response.return_value = response - connection._get_from_local_cache.return_value = None while node._free: node._free.pop() node._free.append(connection) @@ -201,7 +200,6 @@ def mock_node_resp_exc(node: ClusterNode, exc: Exception) -> ClusterNode: connection = mock.AsyncMock(spec=Connection) connection.is_connected = True connection.read_response.side_effect = exc - connection._get_from_local_cache.return_value = None while node._free: node._free.pop() node._free.append(connection) diff --git a/tests/test_asyncio/test_connection.py b/tests/test_asyncio/test_connection.py index 8f79f7d947..e584fc6999 100644 --- a/tests/test_asyncio/test_connection.py +++ b/tests/test_asyncio/test_connection.py @@ -75,7 +75,6 @@ async def call_with_retry(self, _, __): mock_conn = mock.AsyncMock(spec=Connection) mock_conn.retry = Retry_() - mock_conn._get_from_local_cache.return_value = None async def get_conn(_): # Validate only one client is created in single-client mode when diff --git a/tests/test_cache.py b/tests/test_cache.py index 022364e87a..4eda78ebbb 100644 --- a/tests/test_cache.py +++ b/tests/test_cache.py @@ -1,587 +1,35 @@ import time -from collections import defaultdict -from typing import List, Sequence, Union -import cachetools -import pytest -import redis -from redis import RedisError -from redis._cache import AbstractCache, EvictionPolicy, _LocalCache -from redis.typing import KeyT, ResponseT -from redis.utils import HIREDIS_AVAILABLE -from tests.conftest import _get_client +from redis import Redis, RedisCluster -@pytest.fixture() -def r(request): - cache = request.param.get("cache") - kwargs = request.param.get("kwargs", {}) - protocol = request.param.get("protocol", 3) - single_connection_client = 
request.param.get("single_connection_client", False)
-    with _get_client(
-        redis.Redis,
-        request,
-        single_connection_client=single_connection_client,
-        protocol=protocol,
-        client_cache=cache,
-        **kwargs,
-    ) as client:
-        yield client, cache
+def test_standalone_cached_get_and_set():
+    r = Redis(use_cache=True, protocol=3)
+    assert r.set("key", 5)
+    assert r.get("key") == b"5"

-@pytest.fixture()
-def local_cache():
-    return _LocalCache()
+    r2 = Redis(protocol=3)
+    r2.set("key", "foo")

-@pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only")
-class TestLocalCache:
-    @pytest.mark.parametrize("r", [{"cache": _LocalCache()}], indirect=True)
-    @pytest.mark.onlynoncluster
-    def test_get_from_cache(self, r, r2):
-        r, cache = r
-        # add key to redis
-        r.set("foo", "bar")
-        # get key from redis and save in local cache
-        assert r.get("foo") == b"bar"
-        # get key from local cache
-        assert cache.get(("GET", "foo")) == b"bar"
-        # change key in redis (cause invalidation)
-        r2.set("foo", "barbar")
-        # send any command to redis (process invalidation in background)
-        r.ping()
-        # the command is not in the local cache anymore
-        assert cache.get(("GET", "foo")) is None
-        # get key from redis
-        assert r.get("foo") == b"barbar"
+    # give the server time to deliver the invalidation push
+    time.sleep(0.5)
+    after_invalidation = r.get("key")
+    assert after_invalidation == b"foo"

-    @pytest.mark.parametrize(
-        "r",
-        [{"cache": _LocalCache(max_size=3)}],
-        indirect=True,
-    )
-    def test_cache_lru_eviction(self, r):
-        r, cache = r
-        # add 3 keys to redis
-        r.set("foo", "bar")
-        r.set("foo2", "bar2")
-        r.set("foo3", "bar3")
-        # get 3 keys from redis and save in local cache
-        assert r.get("foo") == b"bar"
-        assert r.get("foo2") == b"bar2"
-        assert r.get("foo3") == b"bar3"
-        # get the 3 keys from local cache
-        assert cache.get(("GET", "foo")) == b"bar"
-        assert cache.get(("GET", "foo2")) == b"bar2"
-        assert cache.get(("GET", "foo3")) == b"bar3"
-        # add 1 more key to redis (exceed the max size)
-        r.set("foo4", "bar4")
-        assert r.get("foo4") == b"bar4"
-        # the first key is not in the local cache anymore
-        assert cache.get(("GET", "foo")) is None
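+# Mirrors the standalone test above for RedisCluster. Assumes a cluster
+# node is reachable at redis://localhost:16379/0 (the URL used below).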
+def test_cluster_cached_get_and_set():
+    cluster_url = "redis://localhost:16379/0"

-    @pytest.mark.parametrize("r", [{"cache": _LocalCache(ttl=1)}], indirect=True)
-    def test_cache_ttl(self, r):
-        r, cache = r
-        # add key to redis
-        r.set("foo", "bar")
-        # get key from redis and save in local cache
-        assert r.get("foo") == b"bar"
-        # get key from local cache
-        assert cache.get(("GET", "foo")) == b"bar"
-        # wait for the key to expire
-        time.sleep(1)
-        # the key is not in the local cache anymore
-        assert cache.get(("GET", "foo")) is None
+    r = RedisCluster.from_url(cluster_url, use_cache=True, protocol=3)
+    assert r.set("key", 5)
+    assert r.get("key") == b"5"

-    @pytest.mark.parametrize(
-        "r",
-        [{"cache": _LocalCache(max_size=3, eviction_policy=EvictionPolicy.LFU)}],
-        indirect=True,
-    )
-    def test_cache_lfu_eviction(self, r):
-        r, cache = r
-        # add 3 keys to redis
-        r.set("foo", "bar")
-        r.set("foo2", "bar2")
-        r.set("foo3", "bar3")
-        # get 3 keys from redis and save in local cache
-        assert r.get("foo") == b"bar"
-        assert r.get("foo2") == b"bar2"
-        assert r.get("foo3") == b"bar3"
-        # change the order of the keys in the cache
-        assert cache.get(("GET", "foo")) == b"bar"
-        assert cache.get(("GET", "foo")) == b"bar"
-        assert cache.get(("GET", "foo3")) == b"bar3"
-        # add 1 more key to redis (exceed the max size)
-        r.set("foo4", "bar4")
-        assert r.get("foo4") == b"bar4"
-        # test the eviction policy
-        assert len(cache.cache) == 3
-        assert cache.get(("GET", "foo")) == b"bar"
-        assert cache.get(("GET", "foo2")) is None
+    r2 = RedisCluster.from_url(cluster_url, use_cache=True, protocol=3)
+    r2.set("key", "foo")

-    @pytest.mark.parametrize(
-        "r",
-        [{"cache": _LocalCache(), "kwargs": {"decode_responses": True}}],
-        indirect=True,
-    )
-    @pytest.mark.onlynoncluster
-    def test_cache_decode_response(self, r):
-        r, cache = r
-        r.set("foo", "bar")
-        # get key from redis and save in local cache
-        assert r.get("foo") == "bar"
-        # get key from local cache
-        assert cache.get(("GET", "foo")) == "bar"
-        # change key in redis (cause invalidation)
-        r.set("foo", "barbar")
-        # send any command to redis (process invalidation in background)
-        r.ping()
-        # the command is not in the local cache anymore
-        assert cache.get(("GET", "foo")) is None
-        # get key from redis
-        assert r.get("foo") == "barbar"
+    # give the server time to deliver the invalidation push
+    time.sleep(0.5)
+
+    after_invalidation = r.get("key")
+    assert after_invalidation == b"foo"

-    @pytest.mark.parametrize(
-        "r",
-        [{"cache": _LocalCache(), "kwargs": {"cache_deny_list": ["LLEN"]}}],
-        indirect=True,
-    )
-    def test_cache_deny_list(self, r):
-        r, cache = r
-        # add list to redis
-        r.lpush("mylist", "foo", "bar", "baz")
-        assert r.llen("mylist") == 3
-        assert r.lindex("mylist", 1) == b"bar"
-        assert cache.get(("LLEN", "mylist")) is None
-        assert cache.get(("LINDEX", "mylist", 1)) == b"bar"

-    @pytest.mark.parametrize(
-        "r",
-        [{"cache": _LocalCache(), "kwargs": {"cache_allow_list": ["LLEN"]}}],
-        indirect=True,
-    )
-    def test_cache_allow_list(self, r):
-        r, cache = r
-        r.lpush("mylist", "foo", "bar", "baz")
-        assert r.llen("mylist") == 3
-        assert r.lindex("mylist", 1) == b"bar"
-        assert cache.get(("LLEN", "mylist")) == 3
-        assert cache.get(("LINDEX", "mylist", 1)) is None

-    @pytest.mark.parametrize("r", [{"cache": _LocalCache()}], indirect=True)
-    def test_cache_return_copy(self, r):
-        r, cache = r
-        r.lpush("mylist", "foo", "bar", "baz")
-        assert r.lrange("mylist", 0, -1) == [b"baz", b"bar", b"foo"]
-        res = cache.get(("LRANGE", "mylist", 0, -1))
-        assert res == [b"baz", b"bar", b"foo"]
-        res.append(b"new")
-        check = cache.get(("LRANGE", "mylist", 0, -1))
-        assert check == [b"baz", b"bar", b"foo"]

-    @pytest.mark.parametrize(
-        "r",
-        [{"cache": _LocalCache(), "kwargs": {"decode_responses": True}}],
-        indirect=True,
-    )
-    @pytest.mark.onlynoncluster
-    def test_csc_not_cause_disconnects(self, r):
-        r, cache = r
-        id1 = r.client_id()
-        r.mset({"a": 1, "b": 1, "c": 1, "d": 1, "e": 1, "f": 1})
-        assert r.mget("a", "b", "c", "d", "e", "f") == ["1", "1", "1", "1", "1", "1"]
-        id2 = r.client_id()
-
-        # client should get value from client cache
-        assert r.mget("a", "b", "c", "d", "e", "f") == ["1", "1", "1", "1", "1", "1"]
-        assert cache.get(("MGET", "a", "b", "c", "d", "e", "f")) == [
-            "1",
-            "1",
-            "1",
-            "1",
-            "1",
-            "1",
-        ]
-
-        r.mset({"a": 2, "b": 2, "c": 2, "d": 2, "e": 2, "f": 2})
-        id3 = r.client_id()
-        # client should get value from redis server post invalidate messages
-        assert r.mget("a", "b", "c", "d", "e", "f") == ["2", "2", "2", "2", "2", "2"]
-
-        r.mset({"a": 3, "b": 3, "c": 3, "d": 3, "e": 3, "f": 3})
-        # need to check that we get correct value 3 and not 2
-        assert r.mget("a", "b", "c", "d", "e", "f") == ["3", "3", "3", "3", "3", "3"]
-        # client should get value from client cache
-        assert r.mget("a", "b", "c", "d", "e", "f") == ["3", "3", "3", "3", "3", "3"]
-
-        r.mset({"a": 4, "b": 4, "c": 4, "d": 4, "e": 4, "f": 4})
-        # need to check that we get correct value 4 and not 3
-        assert r.mget("a", "b", "c", "d", "e", "f") == ["4", "4", "4", "4", "4", "4"]
- # client should get value from client cache - assert r.mget("a", "b", "c", "d", "e", "f") == ["4", "4", "4", "4", "4", "4"] - id4 = r.client_id() - assert id1 == id2 == id3 == id4 - - @pytest.mark.parametrize( - "r", - [{"cache": _LocalCache(), "kwargs": {"decode_responses": True}}], - indirect=True, - ) - @pytest.mark.onlynoncluster - def test_multiple_commands_same_key(self, r): - r, cache = r - r.mset({"a": 1, "b": 1}) - assert r.mget("a", "b") == ["1", "1"] - # value should be in local cache - assert cache.get(("MGET", "a", "b")) == ["1", "1"] - # set only one key - r.set("a", 2) - # send any command to redis (process invalidation in background) - r.ping() - # the command is not in the local cache anymore - assert cache.get(("MGET", "a", "b")) is None - # get from redis - assert r.mget("a", "b") == ["2", "1"] - - @pytest.mark.parametrize( - "r", - [{"cache": _LocalCache(), "kwargs": {"decode_responses": True}}], - indirect=True, - ) - def test_delete_one_command(self, r): - r, cache = r - r.mset({"a{a}": 1, "b{a}": 1}) - r.set("c", 1) - assert r.mget("a{a}", "b{a}") == ["1", "1"] - assert r.get("c") == "1" - # values should be in local cache - assert cache.get(("MGET", "a{a}", "b{a}")) == ["1", "1"] - assert cache.get(("GET", "c")) == "1" - # delete one command from the cache - r.delete_command_from_cache(("MGET", "a{a}", "b{a}")) - # the other command is still in the local cache anymore - assert cache.get(("MGET", "a{a}", "b{a}")) is None - assert cache.get(("GET", "c")) == "1" - # get from redis - assert r.mget("a{a}", "b{a}") == ["1", "1"] - assert r.get("c") == "1" - - @pytest.mark.parametrize( - "r", - [{"cache": _LocalCache(), "kwargs": {"decode_responses": True}}], - indirect=True, - ) - def test_delete_several_commands(self, r): - r, cache = r - r.mset({"a{a}": 1, "b{a}": 1}) - r.set("c", 1) - assert r.mget("a{a}", "b{a}") == ["1", "1"] - assert r.get("c") == "1" - # values should be in local cache - assert cache.get(("MGET", "a{a}", "b{a}")) == ["1", "1"] - assert cache.get(("GET", "c")) == "1" - # delete the commands from the cache - cache.delete_commands([("MGET", "a{a}", "b{a}"), ("GET", "c")]) - # the commands are not in the local cache anymore - assert cache.get(("MGET", "a{a}", "b{a}")) is None - assert cache.get(("GET", "c")) is None - # get from redis - assert r.mget("a{a}", "b{a}") == ["1", "1"] - assert r.get("c") == "1" - - @pytest.mark.parametrize( - "r", - [{"cache": _LocalCache(), "kwargs": {"decode_responses": True}}], - indirect=True, - ) - def test_invalidate_key(self, r): - r, cache = r - r.mset({"a{a}": 1, "b{a}": 1}) - r.set("c", 1) - assert r.mget("a{a}", "b{a}") == ["1", "1"] - assert r.get("c") == "1" - # values should be in local cache - assert cache.get(("MGET", "a{a}", "b{a}")) == ["1", "1"] - assert cache.get(("GET", "c")) == "1" - # invalidate one key from the cache - r.invalidate_key_from_cache("b{a}") - # one other command is still in the local cache anymore - assert cache.get(("MGET", "a{a}", "b{a}")) is None - assert cache.get(("GET", "c")) == "1" - # get from redis - assert r.mget("a{a}", "b{a}") == ["1", "1"] - assert r.get("c") == "1" - - @pytest.mark.parametrize( - "r", - [{"cache": _LocalCache(), "kwargs": {"decode_responses": True}}], - indirect=True, - ) - def test_flush_entire_cache(self, r): - r, cache = r - r.mset({"a{a}": 1, "b{a}": 1}) - r.set("c", 1) - assert r.mget("a{a}", "b{a}") == ["1", "1"] - assert r.get("c") == "1" - # values should be in local cache - assert cache.get(("MGET", "a{a}", "b{a}")) == ["1", "1"] - assert 
cache.get(("GET", "c")) == "1" - # flush the local cache - r.flush_cache() - # the commands are not in the local cache anymore - assert cache.get(("MGET", "a{a}", "b{a}")) is None - assert cache.get(("GET", "c")) is None - # get from redis - assert r.mget("a{a}", "b{a}") == ["1", "1"] - assert r.get("c") == "1" - - @pytest.mark.onlynoncluster - def test_cache_not_available_with_resp2(self, request): - with pytest.raises(RedisError) as e: - _get_client(redis.Redis, request, protocol=2, client_cache=_LocalCache()) - assert "protocol version 3 or higher" in str(e.value) - - @pytest.mark.parametrize( - "r", - [{"cache": _LocalCache(), "kwargs": {"decode_responses": True}}], - indirect=True, - ) - @pytest.mark.onlynoncluster - def test_execute_command_args_not_split(self, r): - r, cache = r - assert r.execute_command("SET a 1") == "OK" - assert r.execute_command("GET a") == "1" - # "get a" is not whitelisted by default, the args should be separated - assert cache.get(("GET a",)) is None - - @pytest.mark.parametrize( - "r", - [{"cache": _LocalCache(), "kwargs": {"decode_responses": True}}], - indirect=True, - ) - def test_execute_command_keys_provided(self, r): - r, cache = r - assert r.execute_command("SET", "b", "2") is True - assert r.execute_command("GET", "b", keys=["b"]) == "2" - assert cache.get(("GET", "b")) == "2" - - @pytest.mark.parametrize( - "r", - [{"cache": _LocalCache(), "kwargs": {"decode_responses": True}}], - indirect=True, - ) - def test_execute_command_keys_not_provided(self, r): - r, cache = r - assert r.execute_command("SET", "b", "2") is True - assert r.execute_command("GET", "b") == "2" # keys not provided, not cached - assert cache.get(("GET", "b")) is None - - @pytest.mark.parametrize( - "r", - [{"cache": _LocalCache(), "single_connection_client": True}], - indirect=True, - ) - @pytest.mark.onlynoncluster - def test_single_connection(self, r): - r, cache = r - # add key to redis - r.set("foo", "bar") - # get key from redis and save in local cache - assert r.get("foo") == b"bar" - # get key from local cache - assert cache.get(("GET", "foo")) == b"bar" - # change key in redis (cause invalidation) - r.set("foo", "barbar") - # send any command to redis (process invalidation in background) - r.ping() - # the command is not in the local cache anymore - assert cache.get(("GET", "foo")) is None - # get key from redis - assert r.get("foo") == b"barbar" - - @pytest.mark.parametrize("r", [{"cache": _LocalCache()}], indirect=True) - def test_get_from_cache_invalidate_via_get(self, r, r2): - r, cache = r - # add key to redis - r.set("foo", "bar") - # get key from redis and save in local cache - assert r.get("foo") == b"bar" - # get key from local cache - assert cache.get(("GET", "foo")) == b"bar" - # change key in redis (cause invalidation) - r2.set("foo", "barbar") - # don't send any command to redis, just run another get - # it should process the invalidation in background - assert r.get("foo") == b"barbar" - - -@pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only") -@pytest.mark.onlycluster -class TestClusterLocalCache: - @pytest.mark.parametrize("r", [{"cache": _LocalCache()}], indirect=True) - def test_get_from_cache(self, r, r2): - r, cache = r - # add key to redis - r.set("foo", "bar") - # get key from redis and save in local cache - assert r.get("foo") == b"bar" - # get key from local cache - assert cache.get(("GET", "foo")) == b"bar" - # change key in redis (cause invalidation) - r2.set("foo", "barbar") - # send any command to redis (process invalidation in 
background) - node = r.get_node_from_key("foo") - r.ping(target_nodes=node) - # the command is not in the local cache anymore - assert cache.get(("GET", "foo")) is None - # get key from redis - assert r.get("foo") == b"barbar" - - @pytest.mark.parametrize( - "r", - [{"cache": _LocalCache(), "kwargs": {"decode_responses": True}}], - indirect=True, - ) - def test_cache_decode_response(self, r): - r, cache = r - r.set("foo", "bar") - # get key from redis and save in local cache - assert r.get("foo") == "bar" - # get key from local cache - assert cache.get(("GET", "foo")) == "bar" - # change key in redis (cause invalidation) - r.set("foo", "barbar") - # send any command to redis (process invalidation in background) - node = r.get_node_from_key("foo") - r.ping(target_nodes=node) - # the command is not in the local cache anymore - assert cache.get(("GET", "foo")) is None - # get key from redis - assert r.get("foo") == "barbar" - - @pytest.mark.parametrize( - "r", - [{"cache": _LocalCache(), "kwargs": {"decode_responses": True}}], - indirect=True, - ) - def test_execute_command_keys_provided(self, r): - r, cache = r - assert r.execute_command("SET", "b", "2") is True - assert r.execute_command("GET", "b", keys=["b"]) == "2" - assert cache.get(("GET", "b")) == "2" - - @pytest.mark.parametrize( - "r", - [{"cache": _LocalCache(), "kwargs": {"decode_responses": True}}], - indirect=True, - ) - def test_execute_command_keys_not_provided(self, r): - r, cache = r - assert r.execute_command("SET", "b", "2") is True - assert r.execute_command("GET", "b") == "2" # keys not provided, not cached - assert cache.get(("GET", "b")) is None - - -@pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only") -@pytest.mark.onlynoncluster -class TestSentinelLocalCache: - - def test_get_from_cache(self, local_cache, master): - master.set("foo", "bar") - # get key from redis and save in local cache - assert master.get("foo") == b"bar" - # get key from local cache - assert local_cache.get(("GET", "foo")) == b"bar" - # change key in redis (cause invalidation) - master.set("foo", "barbar") - # send any command to redis (process invalidation in background) - master.ping() - # the command is not in the local cache anymore - assert local_cache.get(("GET", "foo")) is None - # get key from redis - assert master.get("foo") == b"barbar" - - @pytest.mark.parametrize( - "sentinel_setup", - [{"kwargs": {"decode_responses": True}}], - indirect=True, - ) - def test_cache_decode_response(self, local_cache, sentinel_setup, master): - master.set("foo", "bar") - # get key from redis and save in local cache - assert master.get("foo") == "bar" - # get key from local cache - assert local_cache.get(("GET", "foo")) == "bar" - # change key in redis (cause invalidation) - master.set("foo", "barbar") - # send any command to redis (process invalidation in background) - master.ping() - # the command is not in the local cache anymore - assert local_cache.get(("GET", "foo")) is None - # get key from redis - assert master.get("foo") == "barbar" - - -@pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only") -@pytest.mark.onlynoncluster -class TestCustomCache: - class _CustomCache(AbstractCache): - def __init__(self): - self.responses = cachetools.LRUCache(maxsize=1000) - self.keys_to_commands = defaultdict(list) - self.commands_to_keys = defaultdict(list) - - def set( - self, - command: Union[str, Sequence[str]], - response: ResponseT, - keys_in_command: List[KeyT], - ): - self.responses[command] = response - for key in keys_in_command: - 
self.keys_to_commands[key].append(tuple(command)) - self.commands_to_keys[command].append(tuple(keys_in_command)) - - def get(self, command: Union[str, Sequence[str]]) -> ResponseT: - return self.responses.get(command) - - def delete_command(self, command: Union[str, Sequence[str]]): - self.responses.pop(command, None) - keys = self.commands_to_keys.pop(command, []) - for key in keys: - if command in self.keys_to_commands[key]: - self.keys_to_commands[key].remove(command) - - def delete_commands(self, commands: List[Union[str, Sequence[str]]]): - for command in commands: - self.delete_command(command) - - def flush(self): - self.responses.clear() - self.commands_to_keys.clear() - self.keys_to_commands.clear() - - def invalidate_key(self, key: KeyT): - commands = self.keys_to_commands.pop(key, []) - for command in commands: - self.delete_command(command) - - @pytest.mark.parametrize("r", [{"cache": _CustomCache()}], indirect=True) - def test_get_from_cache(self, r, r2): - r, cache = r - # add key to redis - r.set("foo", "bar") - # get key from redis and save in local cache - assert r.get("foo") == b"bar" - # get key from local cache - assert cache.get(("GET", "foo")) == b"bar" - # change key in redis (cause invalidation) - r2.set("foo", "barbar") - # send any command to redis (process invalidation in background) - r.ping() - # the command is not in the local cache anymore - assert cache.get(("GET", "foo")) is None - # get key from redis - assert r.get("foo") == b"barbar" + time.sleep(0.5) + + after_invalidation = r.get("key") + print(f'after invalidation {after_invalidation}') + assert after_invalidation == b"foo" diff --git a/tests/test_cluster.py b/tests/test_cluster.py index 5a32bd6a7e..229e0fc6e6 100644 --- a/tests/test_cluster.py +++ b/tests/test_cluster.py @@ -208,7 +208,6 @@ def cmd_init_mock(self, r): def mock_node_resp(node, response): connection = Mock() connection.read_response.return_value = response - connection._get_from_local_cache.return_value = None node.redis_connection.connection = connection return node @@ -216,7 +215,6 @@ def mock_node_resp(node, response): def mock_node_resp_func(node, func): connection = Mock() connection.read_response.side_effect = func - connection._get_from_local_cache.return_value = None node.redis_connection.connection = connection return node @@ -485,7 +483,6 @@ def mock_execute_command(*_args, **_kwargs): redis_mock_node.execute_command.side_effect = mock_execute_command # Mock response value for all other commands redis_mock_node.parse_response.return_value = "MOCK_OK" - redis_mock_node.connection._get_from_local_cache.return_value = None for node in r.get_nodes(): if node.port != primary.port: node.redis_connection = redis_mock_node