
Commit 74499a9

Added custom background scheduler, added unit testing
1 parent 2572167 commit 74499a9

8 files changed: +290 -21 lines changed

dev_requirements.txt

Lines changed: 0 additions & 1 deletion
@@ -1,6 +1,5 @@
 black==24.3.0
 cachetools>=5.5.0
-apscheduler
 click==8.0.4
 flake8-isort
 flake8

redis/connection.py

Lines changed: 27 additions & 12 deletions
@@ -8,11 +8,10 @@
 from abc import abstractmethod
 from itertools import chain
 from queue import Empty, Full, LifoQueue
-from time import time
+from time import time, sleep
 from typing import Any, Callable, List, Optional, Type, Union
 from urllib.parse import parse_qs, unquote, urlparse
 
-from apscheduler.schedulers.background import BackgroundScheduler
 from cachetools import LRUCache
 from cachetools.keys import hashkey
 from redis.cache import (
@@ -21,6 +20,7 @@
     CacheInterface,
     CacheToolsFactory,
 )
+from . import scheduler
 
 from ._parsers import Encoder, _HiredisParser, _RESP2Parser, _RESP3Parser
 from .backoff import NoBackoff
@@ -36,6 +36,7 @@
     TimeoutError,
 )
 from .retry import Retry
+from .scheduler import Scheduler
 from .utils import (
     CRYPTOGRAPHY_AVAILABLE,
     HIREDIS_AVAILABLE,
@@ -1265,7 +1266,9 @@ def __init__(
         self.cache = None
         self._cache_conf = None
         self._cache_factory = cache_factory
-        self.scheduler = None
+        self._scheduler = None
+        self._hc_cancel_event = None
+        self._hc_thread = None
 
         if connection_kwargs.get("use_cache"):
             if connection_kwargs.get("protocol") not in [3, "3"]:
@@ -1286,14 +1289,7 @@ def __init__(
             else:
                 self.cache = CacheToolsFactory(self._cache_conf).get_cache()
 
-            self.scheduler = BackgroundScheduler()
-            self.scheduler.add_job(
-                self._perform_health_check,
-                "interval",
-                seconds=2,
-                id="cache_health_check",
-            )
-            self.scheduler.start()
+            self._scheduler = Scheduler()
 
         connection_kwargs.pop("use_cache", None)
         connection_kwargs.pop("cache_eviction", None)
@@ -1312,6 +1308,16 @@ def __init__(
         self._fork_lock = threading.Lock()
         self.reset()
 
+        # Run scheduled healthcheck to avoid stale invalidations in idle connections.
+        if self.cache is not None and self._scheduler is not None:
+            self._hc_cancel_event = threading.Event()
+            self._hc_thread = self._scheduler.run_with_interval(
+                self._perform_health_check,
+                2,
+                self._hc_cancel_event
+            )
+
+
     def __repr__(self) -> (str, str):
         return (
             f"<{type(self).__module__}.{type(self).__name__}"
@@ -1491,6 +1497,14 @@ def disconnect(self, inuse_connections: bool = True) -> None:
             for connection in connections:
                 connection.disconnect()
 
+        # Send an event to stop scheduled healthcheck execution.
+        if self._hc_cancel_event is not None and not self._hc_cancel_event.is_set():
+            self._hc_cancel_event.set()
+
+        # Joins healthcheck thread on disconnect.
+        if self._hc_thread is not None and not self._hc_thread.is_alive():
+            self._hc_thread.join()
+
     def close(self) -> None:
         """Close the pool, disconnecting all connections"""
         self.disconnect()
@@ -1502,13 +1516,14 @@ def set_retry(self, retry: "Retry") -> None:
         for conn in self._in_use_connections:
             conn.retry = retry
 
-    def _perform_health_check(self) -> None:
+    def _perform_health_check(self, done: threading.Event) -> None:
         self._checkpid()
         with self._lock:
             while self._available_connections:
                 conn = self._available_connections.pop()
                 conn.send_command("PING")
                 conn.read_response()
+        done.set()
 
 
 class BlockingConnectionPool(ConnectionPool):
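
For orientation, here is a minimal standalone sketch of the lifecycle the pool now implements, assuming this commit's redis.scheduler module is importable; the callback body and the 5-second runtime are illustrative only, not part of the commit.

import threading
import time

from redis.scheduler import Scheduler


def health_check(done: threading.Event) -> None:
    # Stand-in for ConnectionPool._perform_health_check: do the periodic work,
    # then signal the scheduler that this run has finished so the next timer
    # can be armed.
    print("PING")
    done.set()


scheduler = Scheduler()
cancel = threading.Event()
# Re-run health_check roughly every 2 seconds until the cancel event is set.
thread = scheduler.run_with_interval(health_check, 2, cancel)

time.sleep(5)

# Shutdown mirrors ConnectionPool.disconnect(): signal cancellation, then join.
cancel.set()
thread.join()

Note that the interval is measured from the moment the polling loop observes the previous run's done event, since only then is a fresh threading.Timer armed.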

redis/scheduler.py

Lines changed: 61 additions & 0 deletions
@@ -0,0 +1,61 @@
+import threading
+import time
+from typing import Callable
+
+
+class Scheduler:
+
+    def __init__(self, polling_period: float = 0.1):
+        """
+        :param polling_period: Period between polling operations.
+        Needed to detect when a new job has to be scheduled.
+        """
+        self.polling_period = polling_period
+
+    def run_with_interval(
+        self,
+        func: Callable[[threading.Event, ...], None],
+        interval: float,
+        cancel: threading.Event,
+        args: tuple = (),
+    ) -> threading.Thread:
+        """
+        Run a scheduled execution of func with the given interval
+        in a separate thread until the cancel event is set.
+        """
+        done = threading.Event()
+        thread = threading.Thread(target=self._run_timer, args=(func, interval, (done, *args), done, cancel))
+        thread.start()
+        return thread
+
+    def _get_timer(
+        self,
+        func: Callable[[threading.Event, ...], None],
+        interval: float,
+        args: tuple
+    ) -> threading.Timer:
+        timer = threading.Timer(interval=interval, function=func, args=args)
+        return timer
+
+    def _run_timer(
+        self,
+        func: Callable[[threading.Event, ...], None],
+        interval: float,
+        args: tuple,
+        done: threading.Event,
+        cancel: threading.Event
+    ):
+        timer = self._get_timer(func, interval, args)
+        timer.start()
+
+        while not cancel.is_set():
+            if done.is_set():
+                done.clear()
+                timer.join()
+                timer = self._get_timer(func, interval, args)
+                timer.start()
+            else:
+                time.sleep(self.polling_period)
+
+        timer.cancel()
+        timer.join()
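
The done/cancel handshake above is easy to miss: the scheduled function must set the done event passed as its first argument, otherwise _run_timer never re-arms the timer and the job executes only once. A hypothetical unit-test sketch (not part of this commit) that exercises this behaviour:

import threading
import time

from redis.scheduler import Scheduler


def test_run_with_interval_repeats_until_cancelled():
    calls = []

    def job(done: threading.Event) -> None:
        calls.append(time.monotonic())
        done.set()  # signal completion so the scheduler arms the next timer

    scheduler = Scheduler(polling_period=0.01)
    cancel = threading.Event()
    thread = scheduler.run_with_interval(job, 0.05, cancel)

    time.sleep(0.3)
    cancel.set()
    thread.join(timeout=1)

    assert not thread.is_alive()
    assert len(calls) >= 2  # the job ran repeatedly before cancellation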

requirements.txt

Lines changed: 1 addition & 2 deletions
@@ -1,3 +1,2 @@
 async-timeout>=4.0.3
-cachetools>=5.5.0
-apscheduler
+cachetools>=5.5.0

tests/conftest.py

Lines changed: 32 additions & 1 deletion
@@ -1,17 +1,21 @@
 import argparse
 import random
+import threading
 import time
 from typing import Callable, TypeVar
 from unittest import mock
 from unittest.mock import Mock
 from urllib.parse import urlparse
 
 import pytest
+from _pytest import unittest
+
 import redis
 from packaging.version import Version
 from redis import Sentinel
 from redis.backoff import NoBackoff
-from redis.connection import Connection, SSLConnection, parse_url
+from redis.cache import CacheConfiguration, EvictionPolicy, CacheFactoryInterface, CacheInterface
+from redis.connection import Connection, SSLConnection, parse_url, ConnectionPool, ConnectionInterface
 from redis.exceptions import RedisClusterException
 from redis.retry import Retry
 from tests.ssl_utils import get_ssl_filename
@@ -537,6 +541,33 @@ def master_host(request):
     return parts.hostname, (parts.port or 6379)
 
 
+@pytest.fixture()
+def cache_conf() -> CacheConfiguration:
+    return CacheConfiguration(
+        cache_size=100,
+        cache_ttl=20,
+        cache_eviction=EvictionPolicy.TTL
+    )
+
+
+@pytest.fixture()
+def mock_cache_factory() -> CacheFactoryInterface:
+    mock_factory = Mock(spec=CacheFactoryInterface)
+    return mock_factory
+
+
+@pytest.fixture()
+def mock_cache() -> CacheInterface:
+    mock_cache = Mock(spec=CacheInterface)
+    return mock_cache
+
+
+@pytest.fixture()
+def mock_connection() -> ConnectionInterface:
+    mock_connection = Mock(spec=ConnectionInterface)
+    return mock_connection
+
+
 def wait_for_command(client, monitor, command, key=None):
     # issue a command with a key name that's local to this process.
     # if we find a command with our key before the command we're waiting

tests/test_cache.py

Lines changed: 26 additions & 3 deletions
@@ -4,7 +4,7 @@
 import pytest
 import redis
 from cachetools import LFUCache, LRUCache, TTLCache
-from redis.cache import CacheToolsAdapter, EvictionPolicy
+from redis.cache import CacheToolsAdapter, EvictionPolicy, CacheConfiguration
 from redis.utils import HIREDIS_AVAILABLE
 from tests.conftest import _get_client, skip_if_resp_version
 
@@ -185,7 +185,7 @@ def test_get_from_cache_multithreaded(self, r):
         indirect=True,
     )
     @pytest.mark.onlynoncluster
-    def test_health_check_invalidate_cache(self, r, r2):
+    def test_health_check_invalidate_cache(self, r):
         cache = r.get_cache()
         # add key to redis
         r.set("foo", "bar")
@@ -194,7 +194,7 @@ def test_health_check_invalidate_cache(self, r, r2):
         # get key from local cache
         assert cache.get(("GET", "foo")) == b"bar"
         # change key in redis (cause invalidation)
-        r2.set("foo", "barbar")
+        r.set("foo", "barbar")
         # Wait for health check
         time.sleep(2)
         # Make sure that value was invalidated
@@ -1154,3 +1154,26 @@ def test_cache_invalidate_all_related_responses(self, r):
         assert r.get("foo") == b"baz"
         assert cache.get(("MGET", "foo", "bar")) is None
         assert cache.get(("GET", "foo")) == b"baz"
+
+
+class TestUnitCacheConfiguration:
+    TTL = 20
+    MAX_SIZE = 100
+    EVICTION_POLICY = EvictionPolicy.TTL
+
+    def test_get_ttl(self, cache_conf: CacheConfiguration):
+        assert self.TTL == cache_conf.get_ttl()
+
+    def test_get_max_size(self, cache_conf: CacheConfiguration):
+        assert self.MAX_SIZE == cache_conf.get_max_size()
+
+    def test_get_eviction_policy(self, cache_conf: CacheConfiguration):
+        assert self.EVICTION_POLICY == cache_conf.get_eviction_policy()
+
+    def test_is_exceeds_max_size(self, cache_conf: CacheConfiguration):
+        assert not cache_conf.is_exceeds_max_size(self.MAX_SIZE)
+        assert cache_conf.is_exceeds_max_size(self.MAX_SIZE + 1)
+
+    def test_is_allowed_to_cache(self, cache_conf: CacheConfiguration):
+        assert cache_conf.is_allowed_to_cache("GET")
+        assert not cache_conf.is_allowed_to_cache("SET")
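
As a usage note (not part of the diff): the new configuration tests should run without a live Redis server, e.g. with a targeted invocation such as pytest tests/test_cache.py -k TestUnitCacheConfiguration, since they only exercise the cache_conf fixture added in tests/conftest.py.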
