Skip to content

Commit a734303

Browse files
committed
Added test cases and scheduler dependency
1 parent 2c0c812 commit a734303

File tree

4 files changed

+228
-36
lines changed

4 files changed

+228
-36
lines changed

redis/client.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
_RedisCallbacksRESP3,
1616
bool_ok,
1717
)
18-
from redis.cache import CacheMixin
1918
from redis.commands import (
2019
CoreCommands,
2120
RedisModuleCommands,
@@ -86,7 +85,7 @@ class AbstractRedis:
8685
pass
8786

8887

89-
class Redis(RedisModuleCommands, CoreCommands, SentinelCommands, CacheMixin):
88+
class Redis(RedisModuleCommands, CoreCommands, SentinelCommands):
9089
"""
9190
Implementation of the Redis protocol.
9291

redis/connection.py

Lines changed: 49 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111
from time import time
1212
from typing import Any, Callable, List, Optional, Type, Union
1313
from urllib.parse import parse_qs, unquote, urlparse
14+
15+
from apscheduler.schedulers.background import BackgroundScheduler
1416
from cachetools import TTLCache, Cache, LRUCache
1517
from cachetools.keys import hashkey
1618
from redis.cache import CacheConfiguration
@@ -788,15 +790,21 @@ def send_command(self, *args, **kwargs):
788790
if self._cache.get(self._current_command_hash):
789791
return
790792

791-
# Send command over socket only if it's read-only command that not yet cached.
793+
# Set a temporary entry as a status marker to prevent a race condition with another connection.
794+
self._cache[self._current_command_hash] = "caching-in-progress"
795+
796+
# Send the command over the socket only if it's an allowed read-only command that is not yet cached.
792797
self._conn.send_command(*args, **kwargs)
793798

794799
def can_read(self, timeout=0):
795800
return self._conn.can_read(timeout)
796801

797802
def read_response(self, disable_decoding=False, *, disconnect_on_error=True, push_request=False):
798-
# Check if command response exists in a cache.
799-
if self._current_command_hash in self._cache:
803+
# Check if the command response exists in the cache and caching is not still in progress.
804+
if (
805+
self._current_command_hash in self._cache
806+
and self._cache[self._current_command_hash] != "caching-in-progress"
807+
):
800808
return self._cache[self._current_command_hash]
801809

802810
response = self._conn.read_response(
@@ -805,8 +813,12 @@ def read_response(self, disable_decoding=False, *, disconnect_on_error=True, pus
805813
push_request=push_request
806814
)
807815

808-
# Check if command that was sent is write command to prevent caching of write replies.
809-
if response is None or self._current_command_hash is None:
816+
# If the response is None, skip caching and remove the temporary cache entry.
817+
if response is None:
818+
self._cache.pop(self._current_command_hash)
819+
return response
820+
# Prevent a not-allowed command from being cached.
821+
elif self._current_command_hash is None:
810822
return response
811823

812824
# Create separate mapping for keys or add current response to associated keys.
@@ -817,7 +829,12 @@ def read_response(self, disable_decoding=False, *, disconnect_on_error=True, pus
817829
else:
818830
self._keys_mapping[key] = [self._current_command_hash]
819831

820-
self._cache[self._current_command_hash] = response
832+
cache_entry = self._cache.get(self._current_command_hash, None)
833+
834+
# Cache only responses that are still valid and weren't invalidated by another connection in the meantime.
835+
if cache_entry is not None:
836+
self._cache[self._current_command_hash] = response
837+
821838
return response
822839

823840
def pack_command(self, *args):
@@ -1218,22 +1235,28 @@ def __init__(
12181235
self.max_connections = max_connections
12191236
self._cache = None
12201237
self._cache_conf = None
1238+
self._scheduler = None
12211239

12221240
if connection_kwargs.get("use_cache"):
12231241
if connection_kwargs.get("protocol") not in [3, "3"]:
12241242
raise RedisError("Client caching is only supported with RESP version 3")
12251243

12261244
self._cache_conf = CacheConfiguration(**self.connection_kwargs)
12271245

1228-
if self.connection_kwargs.get("cache"):
1229-
self._cache = self.connection_kwargs.get("cache")
1246+
cache = self.connection_kwargs.get("cache")
1247+
if cache is not None:
1248+
self._cache = cache
12301249
else:
12311250
self._cache = TTLCache(self.connection_kwargs["cache_size"], self.connection_kwargs["cache_ttl"])
12321251

1233-
connection_kwargs.pop("use_cache", None)
1234-
connection_kwargs.pop("cache_size", None)
1235-
connection_kwargs.pop("cache_ttl", None)
1236-
connection_kwargs.pop("cache", None)
1252+
# self.scheduler = BackgroundScheduler()
1253+
# self.scheduler.add_job(self._perform_health_check, "interval", seconds=2)
1254+
# self.scheduler.start()
1255+
1256+
connection_kwargs.pop("use_cache", None)
1257+
connection_kwargs.pop("cache_size", None)
1258+
connection_kwargs.pop("cache_ttl", None)
1259+
connection_kwargs.pop("cache", None)
12371260

12381261
# a lock to protect the critical section in _checkpid().
12391262
# this lock is acquired when the process id changes, such as
@@ -1246,6 +1269,10 @@ def __init__(
12461269
self._fork_lock = threading.Lock()
12471270
self.reset()
12481271

1272+
def __del__(self):
1273+
if self._scheduler is not None:
1274+
self.scheduler.shutdown()
1275+
12491276
def __repr__(self) -> (str, str):
12501277
return (
12511278
f"<{type(self).__module__}.{type(self).__name__}"
@@ -1432,6 +1459,16 @@ def set_retry(self, retry: "Retry") -> None:
14321459
for conn in self._in_use_connections:
14331460
conn.retry = retry
14341461

1462+
def _perform_health_check(self) -> None:
1463+
self._checkpid()
1464+
with self._lock:
1465+
while self._available_connections:
1466+
conn = self._available_connections.pop()
1467+
self._in_use_connections.add(conn)
1468+
conn.send_command('PING')
1469+
conn.read_response()
1470+
self.release(conn)
1471+
14351472

14361473
class BlockingConnectionPool(ConnectionPool):
14371474
"""

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
async-timeout>=4.0.3
22
cachetools
3+
apscheduler

tests/test_cache.py

Lines changed: 177 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,35 +1,190 @@
11
import time
22

3+
import pytest
4+
from cachetools import TTLCache, LRUCache, LFUCache
5+
6+
import redis
37
from redis import Redis, RedisCluster
8+
from redis.utils import HIREDIS_AVAILABLE
9+
from tests.conftest import _get_client
10+
11+
12+
@pytest.fixture()
13+
def r(request):
14+
use_cache = request.param.get("use_cache", False)
15+
cache = request.param.get("cache")
16+
kwargs = request.param.get("kwargs", {})
17+
protocol = request.param.get("protocol", 3)
18+
single_connection_client = request.param.get("single_connection_client", False)
19+
with _get_client(
20+
redis.Redis,
21+
request,
22+
protocol=protocol,
23+
single_connection_client=single_connection_client,
24+
use_cache=use_cache,
25+
cache=cache,
26+
**kwargs,
27+
) as client:
28+
yield client, cache
29+
30+
31+
@pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only")
32+
class TestCache:
33+
@pytest.mark.parametrize("r", [{"cache": TTLCache(128, 300), "use_cache": True}], indirect=True)
34+
@pytest.mark.onlynoncluster
35+
def test_get_from_cache(self, r, r2):
36+
r, cache = r
37+
# add key to redis
38+
r.set("foo", "bar")
39+
# get key from redis and save in local cache
40+
assert r.get("foo") == b"bar"
41+
# get key from local cache
42+
assert cache.get(("GET", "foo")) == b"bar"
43+
# change key in redis (cause invalidation)
44+
r2.set("foo", "barbar")
45+
# Retrieves a new value from server and cache it
46+
assert r.get("foo") == b"barbar"
47+
# Make sure that new value was cached
48+
assert cache.get(("GET", "foo")) == b"barbar"
49+
50+
@pytest.mark.parametrize(
51+
"r",
52+
[{"cache": LRUCache(3), "use_cache": True}],
53+
indirect=True,
54+
)
55+
def test_cache_lru_eviction(self, r):
56+
r, cache = r
57+
# add 3 keys to redis
58+
r.set("foo", "bar")
59+
r.set("foo2", "bar2")
60+
r.set("foo3", "bar3")
61+
# get 3 keys from redis and save in local cache
62+
assert r.get("foo") == b"bar"
63+
assert r.get("foo2") == b"bar2"
64+
assert r.get("foo3") == b"bar3"
65+
# get the 3 keys from local cache
66+
assert cache.get(("GET", "foo")) == b"bar"
67+
assert cache.get(("GET", "foo2")) == b"bar2"
68+
assert cache.get(("GET", "foo3")) == b"bar3"
69+
# add 1 more key to redis (exceed the max size)
70+
r.set("foo4", "bar4")
71+
assert r.get("foo4") == b"bar4"
72+
# the first key is not in the local cache anymore
73+
assert cache.get(("GET", "foo")) is None
74+
75+
@pytest.mark.parametrize("r", [{"cache": TTLCache(maxsize=128, ttl=1), "use_cache": True}], indirect=True)
76+
def test_cache_ttl(self, r, cache):
77+
r, cache = r
78+
# add key to redis
79+
r.set("foo", "bar")
80+
# get key from redis and save in local cache
81+
assert r.get("foo") == b"bar"
82+
# get key from local cache
83+
assert cache.get(("GET", "foo")) == b"bar"
84+
# wait for the key to expire
85+
time.sleep(1)
86+
# the key is not in the local cache anymore
87+
assert cache.get(("GET", "foo")) is None
488

89+
@pytest.mark.parametrize(
90+
"r",
91+
[{"cache": LFUCache(3), "use_cache": True}],
92+
indirect=True,
93+
)
94+
def test_cache_lfu_eviction(self, r, cache):
95+
r, cache = r
96+
# add 3 keys to redis
97+
r.set("foo", "bar")
98+
r.set("foo2", "bar2")
99+
r.set("foo3", "bar3")
100+
# get 3 keys from redis and save in local cache
101+
assert r.get("foo") == b"bar"
102+
assert r.get("foo2") == b"bar2"
103+
assert r.get("foo3") == b"bar3"
104+
# change the order of the keys in the cache
105+
assert cache.get(("GET", "foo")) == b"bar"
106+
assert cache.get(("GET", "foo")) == b"bar"
107+
assert cache.get(("GET", "foo3")) == b"bar3"
108+
# add 1 more key to redis (exceed the max size)
109+
r.set("foo4", "bar4")
110+
assert r.get("foo4") == b"bar4"
111+
# test the eviction policy
112+
assert cache.currsize == 3
113+
assert cache.get(("GET", "foo")) == b"bar"
114+
assert cache.get(("GET", "foo2")) is None
5115

6-
def test_standalone_cached_get_and_set():
7-
r = Redis(use_cache=True, protocol=3)
8-
assert r.set("key", 5)
9-
assert r.get("key") == b"5"
116+
@pytest.mark.parametrize(
117+
"r",
118+
[{"cache": LRUCache(maxsize=128), "use_cache": True}],
119+
indirect=True,
120+
)
121+
def test_cache_ignore_not_allowed_command(self, r):
122+
r, cache = r
123+
# add fields to hash
124+
assert r.hset("foo", "bar", "baz")
125+
# get random field
126+
assert r.hrandfield("foo") == b"bar"
127+
assert cache.get(("HRANDFIELD", "foo")) is None
10128

11-
r2 = Redis(protocol=3)
12-
r2.set("key", "foo")
129+
@pytest.mark.parametrize(
130+
"r",
131+
[{"cache": LRUCache(maxsize=128), "use_cache": True}],
132+
indirect=True,
133+
)
134+
def test_cache_invalidate_all_related_responses(self, r, cache):
135+
r, cache = r
136+
# Add keys
137+
assert r.set("foo", "bar")
138+
assert r.set("bar", "foo")
13139

14-
time.sleep(0.5)
140+
# Make sure that replies was cached
141+
assert r.mget("foo", "bar") == [b"bar", b"foo"]
142+
assert cache.get(("MGET", "foo", "bar")) == [b"bar", b"foo"]
15143

16-
after_invalidation = r.get("key")
17-
print(f'after invalidation {after_invalidation}')
18-
assert after_invalidation == b"foo"
144+
# Invalidate one of the keys and make sure that all associated cached entries was removed
145+
assert r.set("foo", "baz")
146+
assert r.get("foo") == b"baz"
147+
assert cache.get(("MGET", "foo", "bar")) is None
148+
assert cache.get(("GET", "foo")) == b"baz"
19149

150+
@pytest.mark.parametrize(
151+
"r",
152+
[{"cache": LRUCache(maxsize=128), "use_cache": True}],
153+
indirect=True,
154+
)
155+
def test_cache_flushed_on_server_flush(self, r, cache):
156+
r, cache = r
157+
# Add keys
158+
assert r.set("foo", "bar")
159+
assert r.set("bar", "foo")
160+
assert r.set("baz", "bar")
20161

21-
def test_cluster_cached_get_and_set():
22-
cluster_url = "redis://localhost:16379/0"
162+
# Make sure that replies was cached
163+
assert r.get("foo") == b"bar"
164+
assert r.get("bar") == b"foo"
165+
assert r.get("baz") == b"bar"
166+
assert cache.get(("GET", "foo")) == b"bar"
167+
assert cache.get(("GET", "bar")) == b"foo"
168+
assert cache.get(("GET", "baz")) == b"bar"
23169

24-
r = RedisCluster.from_url(cluster_url, use_cache=True, protocol=3)
25-
assert r.set("key", 5)
26-
assert r.get("key") == b"5"
170+
# Flush server and trying to access cached entry
171+
assert r.flushall()
172+
assert r.get("foo") is None
173+
assert cache.currsize == 0
27174

28-
r2 = RedisCluster.from_url(cluster_url, use_cache=True, protocol=3)
29-
r2.set("key", "foo")
30175

31-
time.sleep(0.5)
32-
33-
after_invalidation = r.get("key")
34-
print(f'after invalidation {after_invalidation}')
35-
assert after_invalidation == b"foo"
176+
# def test_cluster_cached_get_and_set():
177+
# cluster_url = "redis://localhost:16379/0"
178+
#
179+
# r = RedisCluster.from_url(cluster_url, use_cache=True, protocol=3)
180+
# assert r.set("key", 5)
181+
# assert r.get("key") == b"5"
182+
#
183+
# r2 = RedisCluster.from_url(cluster_url, use_cache=True, protocol=3)
184+
# r2.set("key", "foo")
185+
#
186+
# time.sleep(0.5)
187+
#
188+
# after_invalidation = r.get("key")
189+
# print(f'after invalidation {after_invalidation}')
190+
# assert after_invalidation == b"foo"

0 commit comments

Comments
 (0)