Skip to content

Commit 838682d

Browse files
Merge branch 'master' into fix_unique_cache_key
2 parents 376eda7 + b102c3b commit 838682d

File tree

12 files changed

+444
-55
lines changed

12 files changed

+444
-55
lines changed

redis/asyncio/client.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,10 +81,11 @@
8181
)
8282

8383
if TYPE_CHECKING and SSL_AVAILABLE:
84-
from ssl import TLSVersion, VerifyMode
84+
from ssl import TLSVersion, VerifyFlags, VerifyMode
8585
else:
8686
TLSVersion = None
8787
VerifyMode = None
88+
VerifyFlags = None
8889

8990
PubSubHandler = Callable[[Dict[str, str]], Awaitable[None]]
9091
_KeyT = TypeVar("_KeyT", bound=KeyT)
@@ -238,6 +239,8 @@ def __init__(
238239
ssl_keyfile: Optional[str] = None,
239240
ssl_certfile: Optional[str] = None,
240241
ssl_cert_reqs: Union[str, VerifyMode] = "required",
242+
ssl_include_verify_flags: Optional[List[VerifyFlags]] = None,
243+
ssl_exclude_verify_flags: Optional[List[VerifyFlags]] = None,
241244
ssl_ca_certs: Optional[str] = None,
242245
ssl_ca_data: Optional[str] = None,
243246
ssl_check_hostname: bool = True,
@@ -347,6 +350,8 @@ def __init__(
347350
"ssl_keyfile": ssl_keyfile,
348351
"ssl_certfile": ssl_certfile,
349352
"ssl_cert_reqs": ssl_cert_reqs,
353+
"ssl_include_verify_flags": ssl_include_verify_flags,
354+
"ssl_exclude_verify_flags": ssl_exclude_verify_flags,
350355
"ssl_ca_certs": ssl_ca_certs,
351356
"ssl_ca_data": ssl_ca_data,
352357
"ssl_check_hostname": ssl_check_hostname,

redis/asyncio/cluster.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,10 +86,11 @@
8686
)
8787

8888
if SSL_AVAILABLE:
89-
from ssl import TLSVersion, VerifyMode
89+
from ssl import TLSVersion, VerifyFlags, VerifyMode
9090
else:
9191
TLSVersion = None
9292
VerifyMode = None
93+
VerifyFlags = None
9394

9495
TargetNodesT = TypeVar(
9596
"TargetNodesT", str, "ClusterNode", List["ClusterNode"], Dict[Any, "ClusterNode"]
@@ -299,6 +300,8 @@ def __init__(
299300
ssl_ca_certs: Optional[str] = None,
300301
ssl_ca_data: Optional[str] = None,
301302
ssl_cert_reqs: Union[str, VerifyMode] = "required",
303+
ssl_include_verify_flags: Optional[List[VerifyFlags]] = None,
304+
ssl_exclude_verify_flags: Optional[List[VerifyFlags]] = None,
302305
ssl_certfile: Optional[str] = None,
303306
ssl_check_hostname: bool = True,
304307
ssl_keyfile: Optional[str] = None,
@@ -358,6 +361,8 @@ def __init__(
358361
"ssl_ca_certs": ssl_ca_certs,
359362
"ssl_ca_data": ssl_ca_data,
360363
"ssl_cert_reqs": ssl_cert_reqs,
364+
"ssl_include_verify_flags": ssl_include_verify_flags,
365+
"ssl_exclude_verify_flags": ssl_exclude_verify_flags,
361366
"ssl_certfile": ssl_certfile,
362367
"ssl_check_hostname": ssl_check_hostname,
363368
"ssl_keyfile": ssl_keyfile,

redis/asyncio/connection.py

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,11 +30,12 @@
3030

3131
if SSL_AVAILABLE:
3232
import ssl
33-
from ssl import SSLContext, TLSVersion
33+
from ssl import SSLContext, TLSVersion, VerifyFlags
3434
else:
3535
ssl = None
3636
TLSVersion = None
3737
SSLContext = None
38+
VerifyFlags = None
3839

3940
from ..auth.token import TokenInterface
4041
from ..event import AsyncAfterConnectionReleasedEvent, EventDispatcher
@@ -793,6 +794,8 @@ def __init__(
793794
ssl_keyfile: Optional[str] = None,
794795
ssl_certfile: Optional[str] = None,
795796
ssl_cert_reqs: Union[str, ssl.VerifyMode] = "required",
797+
ssl_include_verify_flags: Optional[List["ssl.VerifyFlags"]] = None,
798+
ssl_exclude_verify_flags: Optional[List["ssl.VerifyFlags"]] = None,
796799
ssl_ca_certs: Optional[str] = None,
797800
ssl_ca_data: Optional[str] = None,
798801
ssl_check_hostname: bool = True,
@@ -807,6 +810,8 @@ def __init__(
807810
keyfile=ssl_keyfile,
808811
certfile=ssl_certfile,
809812
cert_reqs=ssl_cert_reqs,
813+
include_verify_flags=ssl_include_verify_flags,
814+
exclude_verify_flags=ssl_exclude_verify_flags,
810815
ca_certs=ssl_ca_certs,
811816
ca_data=ssl_ca_data,
812817
check_hostname=ssl_check_hostname,
@@ -832,6 +837,14 @@ def certfile(self):
832837
def cert_reqs(self):
833838
return self.ssl_context.cert_reqs
834839

840+
@property
841+
def include_verify_flags(self):
842+
return self.ssl_context.include_verify_flags
843+
844+
@property
845+
def exclude_verify_flags(self):
846+
return self.ssl_context.exclude_verify_flags
847+
835848
@property
836849
def ca_certs(self):
837850
return self.ssl_context.ca_certs
@@ -854,6 +867,8 @@ class RedisSSLContext:
854867
"keyfile",
855868
"certfile",
856869
"cert_reqs",
870+
"include_verify_flags",
871+
"exclude_verify_flags",
857872
"ca_certs",
858873
"ca_data",
859874
"context",
@@ -867,6 +882,8 @@ def __init__(
867882
keyfile: Optional[str] = None,
868883
certfile: Optional[str] = None,
869884
cert_reqs: Optional[Union[str, ssl.VerifyMode]] = None,
885+
include_verify_flags: Optional[List["ssl.VerifyFlags"]] = None,
886+
exclude_verify_flags: Optional[List["ssl.VerifyFlags"]] = None,
870887
ca_certs: Optional[str] = None,
871888
ca_data: Optional[str] = None,
872889
check_hostname: bool = False,
@@ -892,6 +909,8 @@ def __init__(
892909
)
893910
cert_reqs = CERT_REQS[cert_reqs]
894911
self.cert_reqs = cert_reqs
912+
self.include_verify_flags = include_verify_flags
913+
self.exclude_verify_flags = exclude_verify_flags
895914
self.ca_certs = ca_certs
896915
self.ca_data = ca_data
897916
self.check_hostname = (
@@ -906,6 +925,12 @@ def get(self) -> SSLContext:
906925
context = ssl.create_default_context()
907926
context.check_hostname = self.check_hostname
908927
context.verify_mode = self.cert_reqs
928+
if self.include_verify_flags:
929+
for flag in self.include_verify_flags:
930+
context.verify_flags |= flag
931+
if self.exclude_verify_flags:
932+
for flag in self.exclude_verify_flags:
933+
context.verify_flags &= ~flag
909934
if self.certfile and self.keyfile:
910935
context.load_cert_chain(certfile=self.certfile, keyfile=self.keyfile)
911936
if self.ca_certs or self.ca_data:
@@ -953,6 +978,20 @@ def to_bool(value) -> Optional[bool]:
953978
return bool(value)
954979

955980

981+
def parse_ssl_verify_flags(value):
982+
# flags are passed in as a string representation of a list,
983+
# e.g. VERIFY_X509_STRICT, VERIFY_X509_PARTIAL_CHAIN
984+
verify_flags_str = value.replace("[", "").replace("]", "")
985+
986+
verify_flags = []
987+
for flag in verify_flags_str.split(","):
988+
flag = flag.strip()
989+
if not hasattr(VerifyFlags, flag):
990+
raise ValueError(f"Invalid ssl verify flag: {flag}")
991+
verify_flags.append(getattr(VerifyFlags, flag))
992+
return verify_flags
993+
994+
956995
URL_QUERY_ARGUMENT_PARSERS: Mapping[str, Callable[..., object]] = MappingProxyType(
957996
{
958997
"db": int,
@@ -963,6 +1002,8 @@ def to_bool(value) -> Optional[bool]:
9631002
"max_connections": int,
9641003
"health_check_interval": int,
9651004
"ssl_check_hostname": to_bool,
1005+
"ssl_include_verify_flags": parse_ssl_verify_flags,
1006+
"ssl_exclude_verify_flags": parse_ssl_verify_flags,
9661007
"timeout": float,
9671008
}
9681009
)
@@ -1021,6 +1062,7 @@ def parse_url(url: str) -> ConnectKwargs:
10211062

10221063
if parsed.scheme == "rediss":
10231064
kwargs["connection_class"] = SSLConnection
1065+
10241066
else:
10251067
valid_schemes = "redis://, rediss://, unix://"
10261068
raise ValueError(

redis/client.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -224,6 +224,8 @@ def __init__(
224224
ssl_keyfile: Optional[str] = None,
225225
ssl_certfile: Optional[str] = None,
226226
ssl_cert_reqs: Union[str, "ssl.VerifyMode"] = "required",
227+
ssl_include_verify_flags: Optional[List["ssl.VerifyFlags"]] = None,
228+
ssl_exclude_verify_flags: Optional[List["ssl.VerifyFlags"]] = None,
227229
ssl_ca_certs: Optional[str] = None,
228230
ssl_ca_path: Optional[str] = None,
229231
ssl_ca_data: Optional[str] = None,
@@ -330,6 +332,8 @@ def __init__(
330332
"ssl_keyfile": ssl_keyfile,
331333
"ssl_certfile": ssl_certfile,
332334
"ssl_cert_reqs": ssl_cert_reqs,
335+
"ssl_include_verify_flags": ssl_include_verify_flags,
336+
"ssl_exclude_verify_flags": ssl_exclude_verify_flags,
333337
"ssl_ca_certs": ssl_ca_certs,
334338
"ssl_ca_data": ssl_ca_data,
335339
"ssl_check_hostname": ssl_check_hostname,

redis/cluster.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,8 @@ def parse_cluster_myshardid(resp, **options):
184184
"ssl_ca_data",
185185
"ssl_certfile",
186186
"ssl_cert_reqs",
187+
"ssl_include_verify_flags",
188+
"ssl_exclude_verify_flags",
187189
"ssl_keyfile",
188190
"ssl_password",
189191
"ssl_check_hostname",

redis/connection.py

Lines changed: 48 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -68,8 +68,10 @@
6868

6969
if SSL_AVAILABLE:
7070
import ssl
71+
from ssl import VerifyFlags
7172
else:
7273
ssl = None
74+
VerifyFlags = None
7375

7476
if HIREDIS_AVAILABLE:
7577
import hiredis
@@ -686,8 +688,12 @@ def on_connect_check_health(self, check_health: bool = True):
686688
):
687689
raise ConnectionError("Invalid RESP version")
688690

689-
# Send maintenance notifications handshake if RESP3 is active and maintenance notifications are enabled
691+
# Send maintenance notifications handshake if RESP3 is active
692+
# and maintenance notifications are enabled
690693
# and we have a host to determine the endpoint type from
694+
# When the maint_notifications_config enabled mode is "auto",
695+
# we just log a warning if the handshake fails
696+
# When the mode is enabled=True, we raise an exception in case of failure
691697
if (
692698
self.protocol not in [2, "2"]
693699
and self.maint_notifications_config
@@ -709,15 +715,21 @@ def on_connect_check_health(self, check_health: bool = True):
709715
)
710716
response = self.read_response()
711717
if str_if_bytes(response) != "OK":
712-
raise ConnectionError(
718+
raise ResponseError(
713719
"The server doesn't support maintenance notifications"
714720
)
715721
except Exception as e:
716-
# Log warning but don't fail the connection
717-
import logging
722+
if (
723+
isinstance(e, ResponseError)
724+
and self.maint_notifications_config.enabled == "auto"
725+
):
726+
# Log warning but don't fail the connection
727+
import logging
718728

719-
logger = logging.getLogger(__name__)
720-
logger.warning(f"Failed to enable maintenance notifications: {e}")
729+
logger = logging.getLogger(__name__)
730+
logger.warning(f"Failed to enable maintenance notifications: {e}")
731+
else:
732+
raise
721733

722734
# if a client_name is given, set it
723735
if self.client_name:
@@ -1362,6 +1374,8 @@ def __init__(
13621374
ssl_keyfile=None,
13631375
ssl_certfile=None,
13641376
ssl_cert_reqs="required",
1377+
ssl_include_verify_flags: Optional[List["VerifyFlags"]] = None,
1378+
ssl_exclude_verify_flags: Optional[List["VerifyFlags"]] = None,
13651379
ssl_ca_certs=None,
13661380
ssl_ca_data=None,
13671381
ssl_check_hostname=True,
@@ -1380,7 +1394,10 @@ def __init__(
13801394
Args:
13811395
ssl_keyfile: Path to an ssl private key. Defaults to None.
13821396
ssl_certfile: Path to an ssl certificate. Defaults to None.
1383-
ssl_cert_reqs: The string value for the SSLContext.verify_mode (none, optional, required), or an ssl.VerifyMode. Defaults to "required".
1397+
ssl_cert_reqs: The string value for the SSLContext.verify_mode (none, optional, required),
1398+
or an ssl.VerifyMode. Defaults to "required".
1399+
ssl_include_verify_flags: A list of flags to be included in the SSLContext.verify_flags. Defaults to None.
1400+
ssl_exclude_verify_flags: A list of flags to be excluded from the SSLContext.verify_flags. Defaults to None.
13841401
ssl_ca_certs: The path to a file of concatenated CA certificates in PEM format. Defaults to None.
13851402
ssl_ca_data: Either an ASCII string of one or more PEM-encoded certificates or a bytes-like object of DER-encoded certificates.
13861403
ssl_check_hostname: If set, match the hostname during the SSL handshake. Defaults to True.
@@ -1416,6 +1433,8 @@ def __init__(
14161433
)
14171434
ssl_cert_reqs = CERT_REQS[ssl_cert_reqs]
14181435
self.cert_reqs = ssl_cert_reqs
1436+
self.ssl_include_verify_flags = ssl_include_verify_flags
1437+
self.ssl_exclude_verify_flags = ssl_exclude_verify_flags
14191438
self.ca_certs = ssl_ca_certs
14201439
self.ca_data = ssl_ca_data
14211440
self.ca_path = ssl_ca_path
@@ -1455,6 +1474,12 @@ def _wrap_socket_with_ssl(self, sock):
14551474
context = ssl.create_default_context()
14561475
context.check_hostname = self.check_hostname
14571476
context.verify_mode = self.cert_reqs
1477+
if self.ssl_include_verify_flags:
1478+
for flag in self.ssl_include_verify_flags:
1479+
context.verify_flags |= flag
1480+
if self.ssl_exclude_verify_flags:
1481+
for flag in self.ssl_exclude_verify_flags:
1482+
context.verify_flags &= ~flag
14581483
if self.certfile or self.keyfile:
14591484
context.load_cert_chain(
14601485
certfile=self.certfile,
@@ -1568,6 +1593,20 @@ def to_bool(value):
15681593
return bool(value)
15691594

15701595

1596+
def parse_ssl_verify_flags(value):
1597+
# flags are passed in as a string representation of a list,
1598+
# e.g. VERIFY_X509_STRICT, VERIFY_X509_PARTIAL_CHAIN
1599+
verify_flags_str = value.replace("[", "").replace("]", "")
1600+
1601+
verify_flags = []
1602+
for flag in verify_flags_str.split(","):
1603+
flag = flag.strip()
1604+
if not hasattr(VerifyFlags, flag):
1605+
raise ValueError(f"Invalid ssl verify flag: {flag}")
1606+
verify_flags.append(getattr(VerifyFlags, flag))
1607+
return verify_flags
1608+
1609+
15711610
URL_QUERY_ARGUMENT_PARSERS = {
15721611
"db": int,
15731612
"socket_timeout": float,
@@ -1578,6 +1617,8 @@ def to_bool(value):
15781617
"max_connections": int,
15791618
"health_check_interval": int,
15801619
"ssl_check_hostname": to_bool,
1620+
"ssl_include_verify_flags": parse_ssl_verify_flags,
1621+
"ssl_exclude_verify_flags": parse_ssl_verify_flags,
15811622
"timeout": float,
15821623
}
15831624

redis/maint_notifications.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import threading
66
import time
77
from abc import ABC, abstractmethod
8-
from typing import TYPE_CHECKING, Optional, Union
8+
from typing import TYPE_CHECKING, Literal, Optional, Union
99

1010
from redis.typing import Number
1111

@@ -447,7 +447,7 @@ class MaintNotificationsConfig:
447447

448448
def __init__(
449449
self,
450-
enabled: bool = True,
450+
enabled: Union[bool, Literal["auto"]] = "auto",
451451
proactive_reconnect: bool = True,
452452
relaxed_timeout: Optional[Number] = 10,
453453
endpoint_type: Optional[EndpointType] = None,
@@ -456,8 +456,13 @@ def __init__(
456456
Initialize a new MaintNotificationsConfig.
457457
458458
Args:
459-
enabled (bool): Whether to enable maintenance notifications handling.
460-
Defaults to False.
459+
enabled (bool | "auto"): Controls maintenance notifications handling behavior.
460+
- True: The CLIENT MAINT_NOTIFICATIONS command must succeed during connection setup,
461+
otherwise a ResponseError is raised.
462+
- "auto": The CLIENT MAINT_NOTIFICATIONS command is attempted but failures are
463+
gracefully handled - a warning is logged and normal operation continues.
464+
- False: Maintenance notifications are completely disabled.
465+
Defaults to "auto".
461466
proactive_reconnect (bool): Whether to proactively reconnect when a node is replaced.
462467
Defaults to True.
463468
relaxed_timeout (Number): The relaxed timeout to use for the connection during maintenance.

0 commit comments

Comments
 (0)