|
9 | 9 | from itertools import chain
|
10 | 10 | from queue import Empty, Full, LifoQueue
|
11 | 11 | from time import time
|
12 |
| -from typing import Any, Callable, List, Optional, Type, Union |
| 12 | +from typing import Any, Callable, List, Optional, Type, Union, Dict |
13 | 13 | from urllib.parse import parse_qs, unquote, urlparse
|
14 | 14 |
|
15 | 15 | from redis.cache import (
|
|
43 | 43 | SSL_AVAILABLE,
|
44 | 44 | format_error_message,
|
45 | 45 | get_lib_version,
|
46 |
| - str_if_bytes, |
| 46 | + str_if_bytes, compare_versions, |
47 | 47 | )
|
48 | 48 |
|
49 | 49 | if HIREDIS_AVAILABLE:
|
@@ -197,6 +197,11 @@ def pack_command(self, *args):
|
197 | 197 | def pack_commands(self, commands):
|
198 | 198 | pass
|
199 | 199 |
|
| 200 | + @property |
| 201 | + @abstractmethod |
| 202 | + def handshake_metadata(self) -> Union[Dict[bytes, bytes], Dict[str, str]]: |
| 203 | + pass |
| 204 | + |
200 | 205 |
|
201 | 206 | class AbstractConnection(ConnectionInterface):
|
202 | 207 | "Manages communication to and from a Redis server"
|
@@ -272,6 +277,7 @@ def __init__(
|
272 | 277 | self.next_health_check = 0
|
273 | 278 | self.redis_connect_func = redis_connect_func
|
274 | 279 | self.encoder = Encoder(encoding, encoding_errors, decode_responses)
|
| 280 | + self.handshake_metadata = None |
275 | 281 | self._sock = None
|
276 | 282 | self._socket_read_size = socket_read_size
|
277 | 283 | self.set_parser(parser_class)
|
@@ -414,7 +420,7 @@ def on_connect(self):
|
414 | 420 | if len(auth_args) == 1:
|
415 | 421 | auth_args = ["default", auth_args[0]]
|
416 | 422 | self.send_command("HELLO", self.protocol, "AUTH", *auth_args)
|
417 |
| - response = self.read_response() |
| 423 | + self.handshake_metadata = self.read_response() |
418 | 424 | # if response.get(b"proto") != self.protocol and response.get(
|
419 | 425 | # "proto"
|
420 | 426 | # ) != self.protocol:
|
@@ -445,10 +451,10 @@ def on_connect(self):
|
445 | 451 | self._parser.EXCEPTION_CLASSES = parser.EXCEPTION_CLASSES
|
446 | 452 | self._parser.on_connect(self)
|
447 | 453 | self.send_command("HELLO", self.protocol)
|
448 |
| - response = self.read_response() |
| 454 | + self.handshake_metadata = self.read_response() |
449 | 455 | if (
|
450 |
| - response.get(b"proto") != self.protocol |
451 |
| - and response.get("proto") != self.protocol |
| 456 | + self.handshake_metadata.get(b"proto") != self.protocol |
| 457 | + and self.handshake_metadata.get("proto") != self.protocol |
452 | 458 | ):
|
453 | 459 | raise ConnectionError("Invalid RESP version")
|
454 | 460 |
|
@@ -649,6 +655,14 @@ def pack_commands(self, commands):
|
649 | 655 | def get_protocol(self) -> int or str:
|
650 | 656 | return self.protocol
|
651 | 657 |
|
| 658 | + @property |
| 659 | + def handshake_metadata(self) -> Union[Dict[bytes, bytes], Dict[str, str]]: |
| 660 | + return self._handshake_metadata |
| 661 | + |
| 662 | + @handshake_metadata.setter |
| 663 | + def handshake_metadata(self, value: Union[Dict[bytes, bytes], Dict[str, str]]): |
| 664 | + self._handshake_metadata = value |
| 665 | + |
652 | 666 |
|
653 | 667 | class Connection(AbstractConnection):
|
654 | 668 | "Manages TCP communication to and from a Redis server"
|
@@ -731,6 +745,7 @@ def ensure_string(key):
|
731 | 745 |
|
732 | 746 | class CacheProxyConnection(ConnectionInterface):
|
733 | 747 | DUMMY_CACHE_VALUE = b"foo"
|
| 748 | + MIN_ALLOWED_VERSION = '7.4.0' |
734 | 749 |
|
735 | 750 | def __init__(self, conn: ConnectionInterface, cache: CacheInterface):
|
736 | 751 | self.pid = os.getpid()
|
@@ -759,6 +774,17 @@ def set_parser(self, parser_class):
|
759 | 774 | def connect(self):
|
760 | 775 | self._conn.connect()
|
761 | 776 |
|
| 777 | + server_ver = self._conn.handshake_metadata.get(b"version", None) |
| 778 | + if server_ver is None: |
| 779 | + raise ConnectionError("Cannot retrieve information about server version") |
| 780 | + |
| 781 | + server_ver = server_ver.decode("utf-8") |
| 782 | + |
| 783 | + if compare_versions(server_ver, self.MIN_ALLOWED_VERSION) == 1: |
| 784 | + raise ConnectionError( |
| 785 | + "Server version does not satisfies a minimal requirement for client-side caching" |
| 786 | + ) |
| 787 | + |
762 | 788 | def on_connect(self):
|
763 | 789 | self._conn.on_connect()
|
764 | 790 |
|
|
0 commit comments