Skip to content

Commit 65bd5af

Browse files
committed
Added CacheInterface abstraction
1 parent c685248 commit 65bd5af

File tree

5 files changed

+179
-68
lines changed

5 files changed

+179
-68
lines changed

redis/cache.py

Lines changed: 90 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from abc import ABC, abstractmethod
22
from enum import Enum
3+
from typing import Any, Hashable
34

45
from cachetools import Cache, LFUCache, LRUCache, RRCache, TTLCache
56

@@ -121,31 +122,112 @@ def is_allowed_to_cache(self, command: str) -> bool:
121122
return command in self.DEFAULT_ALLOW_LIST
122123

123124

124-
class CacheClass(Enum):
125+
class EvictionPolicyCacheClass(Enum):
125126
LRU = LRUCache
126127
LFU = LFUCache
127128
RANDOM = RRCache
128129
TTL = TTLCache
129130

130131

132+
class CacheClassEvictionPolicy(Enum):
133+
LRUCache = EvictionPolicy.LRU
134+
LFUCache = EvictionPolicy.LFU
135+
RRCache = EvictionPolicy.RANDOM
136+
TTLCache = EvictionPolicy.TTL
137+
138+
139+
class CacheInterface(ABC):
140+
141+
@property
142+
@abstractmethod
143+
def currsize(self) -> float:
144+
pass
145+
146+
@property
147+
@abstractmethod
148+
def maxsize(self) -> float:
149+
pass
150+
151+
@property
152+
@abstractmethod
153+
def eviction_policy(self) -> EvictionPolicy:
154+
pass
155+
156+
@abstractmethod
157+
def get(self, key: Hashable, default: Any = None):
158+
pass
159+
160+
@abstractmethod
161+
def set(self, key: Hashable, value: Any):
162+
pass
163+
164+
@abstractmethod
165+
def exists(self, key: Hashable) -> bool:
166+
pass
167+
168+
@abstractmethod
169+
def remove(self, key: Hashable):
170+
pass
171+
172+
@abstractmethod
173+
def clear(self):
174+
pass
175+
176+
131177
class CacheFactoryInterface(ABC):
132178
@abstractmethod
133-
def get_cache(self) -> Cache:
179+
def get_cache(self) -> CacheInterface:
134180
pass
135181

136182

137-
class CacheFactory(CacheFactoryInterface):
183+
class CacheToolsFactory(CacheFactoryInterface):
138184
def __init__(self, conf: CacheConfiguration):
139185
self._conf = conf
140186

141-
def get_cache(self) -> Cache:
187+
def get_cache(self) -> CacheInterface:
142188
eviction_policy = self._conf.get_eviction_policy()
143189
cache_class = self._get_cache_class(eviction_policy).value
144190

145191
if eviction_policy == EvictionPolicy.TTL:
146-
return cache_class(self._conf.get_max_size(), self._conf.get_ttl())
192+
cache_inst = cache_class(self._conf.get_max_size(), self._conf.get_ttl())
193+
else:
194+
cache_inst = cache_class(self._conf.get_max_size())
195+
196+
return CacheToolsAdapter(cache_inst)
197+
198+
def _get_cache_class(
199+
self, eviction_policy: EvictionPolicy
200+
) -> EvictionPolicyCacheClass:
201+
return EvictionPolicyCacheClass[eviction_policy.value]
202+
203+
204+
class CacheToolsAdapter(CacheInterface):
205+
def __init__(self, cache: Cache):
206+
self._cache = cache
207+
208+
def get(self, key: Hashable, default: Any = None):
209+
return self._cache.get(key, default)
210+
211+
def set(self, key: Hashable, value: Any):
212+
self._cache[key] = value
213+
214+
def exists(self, key: Hashable) -> bool:
215+
return key in self._cache
216+
217+
def remove(self, key: Hashable):
218+
self._cache.pop(key)
219+
220+
def clear(self):
221+
self._cache.clear()
222+
223+
@property
224+
def currsize(self) -> float:
225+
return self._cache.currsize
147226

148-
return cache_class(self._conf.get_max_size())
227+
@property
228+
def maxsize(self) -> float:
229+
return self._cache.maxsize
149230

150-
def _get_cache_class(self, eviction_policy: EvictionPolicy) -> CacheClass:
151-
return CacheClass[eviction_policy.value]
231+
@property
232+
def eviction_policy(self) -> EvictionPolicy:
233+
return CacheClassEvictionPolicy[self._cache.__class__.__name__].value

redis/client.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
_RedisCallbacksRESP3,
1515
bool_ok,
1616
)
17-
from redis.cache import EvictionPolicy
17+
from redis.cache import CacheInterface, EvictionPolicy
1818
from redis.commands import (
1919
CoreCommands,
2020
RedisModuleCommands,
@@ -215,7 +215,7 @@ def __init__(
215215
credential_provider: Optional[CredentialProvider] = None,
216216
protocol: Optional[int] = 2,
217217
use_cache: bool = False,
218-
cache: Optional[Cache] = None,
218+
cache: Optional[CacheInterface] = None,
219219
cache_eviction: Optional[EvictionPolicy] = None,
220220
cache_size: int = 128,
221221
cache_ttl: int = 300,
@@ -603,7 +603,7 @@ def parse_response(self, connection, command_name, **options):
603603
return self.response_callbacks[command_name](response, **options)
604604
return response
605605

606-
def get_cache(self) -> Optional[Cache]:
606+
def get_cache(self) -> Optional[CacheInterface]:
607607
return self.connection_pool.cache
608608

609609

redis/cluster.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from redis._parsers import CommandsParser, Encoder
1111
from redis._parsers.helpers import parse_scan
1212
from redis.backoff import default_backoff
13-
from redis.cache import EvictionPolicy
13+
from redis.cache import CacheInterface, EvictionPolicy
1414
from redis.client import CaseInsensitiveDict, PubSub, Redis
1515
from redis.commands import READ_COMMANDS, RedisClusterCommands
1616
from redis.commands.helpers import list_or_args
@@ -170,6 +170,9 @@ def parse_cluster_myshardid(resp, **options):
170170
"unix_socket_path",
171171
"username",
172172
"use_cache",
173+
"cache",
174+
"cache_size",
175+
"cache_ttl",
173176
)
174177
KWARGS_DISABLED_KEYS = ("host", "port")
175178

@@ -504,7 +507,7 @@ def __init__(
504507
url: Optional[str] = None,
505508
address_remap: Optional[Callable[[Tuple[str, int]], Tuple[str, int]]] = None,
506509
use_cache: bool = False,
507-
cache: Optional[Cache] = None,
510+
cache: Optional[CacheInterface] = None,
508511
cache_eviction: Optional[EvictionPolicy] = None,
509512
cache_size: int = 128,
510513
cache_ttl: int = 300,

redis/connection.py

Lines changed: 24 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,12 @@
1515
from apscheduler.schedulers.background import BackgroundScheduler
1616
from cachetools import Cache, LRUCache
1717
from cachetools.keys import hashkey
18-
from redis.cache import CacheConfiguration, CacheFactory
18+
from redis.cache import (
19+
CacheConfiguration,
20+
CacheFactoryInterface,
21+
CacheInterface,
22+
CacheToolsFactory,
23+
)
1924

2025
from ._parsers import Encoder, _HiredisParser, _RESP2Parser, _RESP3Parser
2126
from .backoff import NoBackoff
@@ -728,7 +733,7 @@ class CacheProxyConnection(ConnectionInterface):
728733
def __init__(
729734
self,
730735
conn: ConnectionInterface,
731-
cache: Cache,
736+
cache: CacheInterface,
732737
conf: CacheConfiguration,
733738
cache_lock: threading.Lock,
734739
):
@@ -806,7 +811,7 @@ def send_command(self, *args, **kwargs):
806811

807812
# Set temporary entry as a status to prevent
808813
# race condition from another connection.
809-
self._cache[self._current_command_hash] = "caching-in-progress"
814+
self._cache.set(self._current_command_hash, "caching-in-progress")
810815

811816
# Send command over socket only if it's allowed
812817
# read-only command that not yet cached.
@@ -821,10 +826,10 @@ def read_response(
821826
with self._cache_lock:
822827
# Check if command response exists in a cache and it's not in progress.
823828
if (
824-
self._current_command_hash in self._cache
825-
and self._cache[self._current_command_hash] != "caching-in-progress"
829+
self._cache.exists(self._current_command_hash)
830+
and self._cache.get(self._current_command_hash) != "caching-in-progress"
826831
):
827-
return copy.deepcopy(self._cache[self._current_command_hash])
832+
return copy.deepcopy(self._cache.get(self._current_command_hash))
828833

829834
response = self._conn.read_response(
830835
disable_decoding=disable_decoding,
@@ -835,7 +840,7 @@ def read_response(
835840
with self._cache_lock:
836841
# If response is None prevent from caching.
837842
if response is None:
838-
self._cache.pop(self._current_command_hash)
843+
self._cache.remove(self._current_command_hash)
839844
return response
840845
# Prevent not-allowed command from caching.
841846
elif self._current_command_hash is None:
@@ -855,7 +860,7 @@ def read_response(
855860
# Cache only responses that still valid
856861
# and wasn't invalidated by another connection in meantime.
857862
if cache_entry is not None:
858-
self._cache[self._current_command_hash] = response
863+
self._cache.set(self._current_command_hash, response)
859864

860865
return response
861866

@@ -892,7 +897,7 @@ def _on_invalidation_callback(self, data: List[Union[str, Optional[List[str]]]])
892897
# Make sure that all command responses
893898
# associated with this key will be deleted
894899
for cache_key in self._keys_mapping[normalized_key]:
895-
self._cache.pop(cache_key)
900+
self._cache.remove(cache_key)
896901
# Removes key from mapping cache
897902
self._keys_mapping.pop(normalized_key)
898903

@@ -1246,6 +1251,7 @@ def __init__(
12461251
self,
12471252
connection_class=Connection,
12481253
max_connections: Optional[int] = None,
1254+
cache_factory: Optional[CacheFactoryInterface] = None,
12491255
**connection_kwargs,
12501256
):
12511257
max_connections = max_connections or 2**31
@@ -1257,7 +1263,7 @@ def __init__(
12571263
self.max_connections = max_connections
12581264
self.cache = None
12591265
self._cache_conf = None
1260-
self._cache_factory = None
1266+
self._cache_factory = cache_factory
12611267
self.cache_lock = None
12621268
self.scheduler = None
12631269

@@ -1269,11 +1275,17 @@ def __init__(
12691275
self._cache_lock = threading.Lock()
12701276

12711277
cache = self.connection_kwargs.get("cache")
1278+
12721279
if cache is not None:
1280+
if not isinstance(cache, CacheInterface):
1281+
raise ValueError("Cache must implement CacheInterface")
1282+
12731283
self.cache = cache
12741284
else:
1275-
cache_factory = CacheFactory(self._cache_conf)
1276-
self.cache = cache_factory.get_cache()
1285+
if self._cache_factory is not None:
1286+
self.cache = self._cache_factory.get_cache()
1287+
else:
1288+
self.cache = CacheToolsFactory(self._cache_conf).get_cache()
12771289

12781290
self.scheduler = BackgroundScheduler()
12791291
self.scheduler.add_job(

0 commit comments

Comments
 (0)