Skip to content

Commit 97ebebf

Browse files
committed
Added version restrictions
1 parent fd361a7 commit 97ebebf

File tree

4 files changed

+94
-7
lines changed

4 files changed

+94
-7
lines changed

redis/connection.py

Lines changed: 32 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from itertools import chain
1010
from queue import Empty, Full, LifoQueue
1111
from time import time
12-
from typing import Any, Callable, List, Optional, Type, Union
12+
from typing import Any, Callable, List, Optional, Type, Union, Dict
1313
from urllib.parse import parse_qs, unquote, urlparse
1414

1515
from redis.cache import (
@@ -43,7 +43,7 @@
4343
SSL_AVAILABLE,
4444
format_error_message,
4545
get_lib_version,
46-
str_if_bytes,
46+
str_if_bytes, compare_versions,
4747
)
4848

4949
if HIREDIS_AVAILABLE:
@@ -197,6 +197,11 @@ def pack_command(self, *args):
197197
def pack_commands(self, commands):
198198
pass
199199

200+
@property
201+
@abstractmethod
202+
def handshake_metadata(self) -> Union[Dict[bytes, bytes], Dict[str, str]]:
203+
pass
204+
200205

201206
class AbstractConnection(ConnectionInterface):
202207
"Manages communication to and from a Redis server"
@@ -272,6 +277,7 @@ def __init__(
272277
self.next_health_check = 0
273278
self.redis_connect_func = redis_connect_func
274279
self.encoder = Encoder(encoding, encoding_errors, decode_responses)
280+
self.handshake_metadata = None
275281
self._sock = None
276282
self._socket_read_size = socket_read_size
277283
self.set_parser(parser_class)
@@ -414,7 +420,7 @@ def on_connect(self):
414420
if len(auth_args) == 1:
415421
auth_args = ["default", auth_args[0]]
416422
self.send_command("HELLO", self.protocol, "AUTH", *auth_args)
417-
response = self.read_response()
423+
self.handshake_metadata = self.read_response()
418424
# if response.get(b"proto") != self.protocol and response.get(
419425
# "proto"
420426
# ) != self.protocol:
@@ -445,10 +451,10 @@ def on_connect(self):
445451
self._parser.EXCEPTION_CLASSES = parser.EXCEPTION_CLASSES
446452
self._parser.on_connect(self)
447453
self.send_command("HELLO", self.protocol)
448-
response = self.read_response()
454+
self.handshake_metadata = self.read_response()
449455
if (
450-
response.get(b"proto") != self.protocol
451-
and response.get("proto") != self.protocol
456+
self.handshake_metadata.get(b"proto") != self.protocol
457+
and self.handshake_metadata.get("proto") != self.protocol
452458
):
453459
raise ConnectionError("Invalid RESP version")
454460

@@ -649,6 +655,14 @@ def pack_commands(self, commands):
649655
def get_protocol(self) -> int or str:
650656
return self.protocol
651657

658+
@property
659+
def handshake_metadata(self) -> Union[Dict[bytes, bytes], Dict[str, str]]:
660+
return self._handshake_metadata
661+
662+
@handshake_metadata.setter
663+
def handshake_metadata(self, value: Union[Dict[bytes, bytes], Dict[str, str]]):
664+
self._handshake_metadata = value
665+
652666

653667
class Connection(AbstractConnection):
654668
"Manages TCP communication to and from a Redis server"
@@ -731,6 +745,7 @@ def ensure_string(key):
731745

732746
class CacheProxyConnection(ConnectionInterface):
733747
DUMMY_CACHE_VALUE = b"foo"
748+
MIN_ALLOWED_VERSION = '7.4.0'
734749

735750
def __init__(self, conn: ConnectionInterface, cache: CacheInterface):
736751
self.pid = os.getpid()
@@ -759,6 +774,17 @@ def set_parser(self, parser_class):
759774
def connect(self):
760775
self._conn.connect()
761776

777+
server_ver = self._conn.handshake_metadata.get(b"version", None)
778+
if server_ver is None:
779+
raise ConnectionError("Cannot retrieve information about server version")
780+
781+
server_ver = server_ver.decode("utf-8")
782+
783+
if compare_versions(server_ver, self.MIN_ALLOWED_VERSION) == 1:
784+
raise ConnectionError(
785+
"Server version does not satisfies a minimal requirement for client-side caching"
786+
)
787+
762788
def on_connect(self):
763789
self._conn.on_connect()
764790

redis/utils.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,3 +153,33 @@ def format_error_message(host_error: str, exception: BaseException) -> str:
153153
f"Error {exception.args[0]} connecting to {host_error}. "
154154
f"{exception.args[1]}."
155155
)
156+
157+
158+
def compare_versions(version1: str, version2: str) -> int:
159+
"""
160+
Compare two versions.
161+
162+
:return: -1 if version1 > version2
163+
0 if both versions are equal
164+
1 if version1 < version2
165+
"""
166+
167+
num_versions1 = list(map(int, version1.split(".")))
168+
num_versions2 = list(map(int, version2.split(".")))
169+
170+
if len(num_versions1) > len(num_versions2):
171+
diff = len(num_versions1) - len(num_versions2)
172+
for _ in range(diff):
173+
num_versions2.append(0)
174+
elif len(num_versions1) < len(num_versions2):
175+
diff = len(num_versions2) - len(num_versions1)
176+
for _ in range(diff):
177+
num_versions1.append(0)
178+
179+
for i, ver in enumerate(num_versions1):
180+
if num_versions1[i] > num_versions2[i]:
181+
return -1
182+
elif num_versions1[i] < num_versions2[i]:
183+
return 1
184+
185+
return 0

tests/test_cache.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
LRUPolicy,
1414
)
1515
from redis.utils import HIREDIS_AVAILABLE
16-
from tests.conftest import _get_client, skip_if_resp_version
16+
from tests.conftest import _get_client, skip_if_resp_version, skip_if_server_version_lt
1717

1818

1919
@pytest.fixture()
@@ -40,6 +40,7 @@ def r(request):
4040
@pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only")
4141
@pytest.mark.onlynoncluster
4242
@skip_if_resp_version(2)
43+
@skip_if_server_version_lt("7.4.0")
4344
class TestCache:
4445
@pytest.mark.parametrize(
4546
"r",
@@ -343,6 +344,7 @@ def test_cache_flushed_on_server_flush(self, r):
343344
@pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only")
344345
@pytest.mark.onlycluster
345346
@skip_if_resp_version(2)
347+
@skip_if_server_version_lt("7.4.0")
346348
class TestClusterCache:
347349
@pytest.mark.parametrize(
348350
"r",
@@ -605,6 +607,7 @@ def test_cache_flushed_on_server_flush(self, r, cache):
605607
@pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only")
606608
@pytest.mark.onlynoncluster
607609
@skip_if_resp_version(2)
610+
@skip_if_server_version_lt("7.4.0")
608611
class TestSentinelCache:
609612
@pytest.mark.parametrize(
610613
"sentinel_setup",
@@ -729,6 +732,7 @@ def test_cache_clears_on_disconnect(self, master, cache):
729732
@pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only")
730733
@pytest.mark.onlynoncluster
731734
@skip_if_resp_version(2)
735+
@skip_if_server_version_lt("7.4.0")
732736
class TestSSLCache:
733737
@pytest.mark.parametrize(
734738
"r",

tests/test_utils.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
import pytest
2+
from redis.utils import compare_versions
3+
4+
5+
@pytest.mark.parametrize(
6+
"version1,version2,expected_res",
7+
[
8+
("1.0.0", "0.9.0", -1),
9+
("1.0.0", "1.0.0", 0),
10+
("0.9.0", "1.0.0", 1),
11+
("1.09.0", "1.9.0", 0),
12+
("1.090.0", "1.9.0", -1),
13+
("1", "0.9.0", -1),
14+
("1", "1.0.0", 0),
15+
],
16+
ids=[
17+
"version1 > version2",
18+
"version1 == version2",
19+
"version1 < version2",
20+
"version1 == version2 - different minor format",
21+
"version1 > version2 - different minor format",
22+
"version1 > version2 - major version only",
23+
"version1 == version2 - major version only",
24+
],
25+
)
26+
def test_compare_versions(version1, version2, expected_res):
27+
assert compare_versions(version1, version2) == expected_res

0 commit comments

Comments
 (0)