Skip to content

Commit 1a51cb4

Browse files
committed
Stabilize more tests and fix linters
1 parent 9ddbd87 commit 1a51cb4

File tree

6 files changed

+41
-41
lines changed

6 files changed

+41
-41
lines changed

.flake8

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ exclude =
1616
ignore =
1717
E126
1818
E203
19+
E231
1920
E701
2021
E704
2122
F405

redis/asyncio/connection.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1034,7 +1034,7 @@ def parse_url(url: str) -> ConnectKwargs:
10341034
try:
10351035
kwargs[name] = parser(value)
10361036
except (TypeError, ValueError):
1037-
raise ValueError(f"Invalid value for `{name}` in connection URL.")
1037+
raise ValueError(f"Invalid value for '{name}' in connection URL.")
10381038
else:
10391039
kwargs[name] = value
10401040

redis/commands/graph/query_result.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import sys
22
from collections import OrderedDict
3-
from distutils.util import strtobool
43

54
# from prettytable import PrettyTable
65
from redis import ResponseError
@@ -571,3 +570,19 @@ async def parse_array(self, value):
571570
"""
572571
scalar = [await self.parse_scalar(value[i]) for i in range(len(value))]
573572
return scalar
573+
574+
575+
def strtobool(val):
576+
"""
577+
Convert a string representation of truth to true (1) or false (0).
578+
True values are 'y', 'yes', 't', 'true', 'on', and '1'; false values
579+
are 'n', 'no', 'f', 'false', 'off', and '0'. Raises ValueError if
580+
'val' is anything else.
581+
"""
582+
val = val.lower()
583+
if val in ("y", "yes", "t", "true", "on", "1"):
584+
return True
585+
elif val in ("n", "no", "f", "false", "off", "0"):
586+
return False
587+
else:
588+
raise ValueError(f"invalid truth value {val!r}")

redis/connection.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -992,7 +992,7 @@ def parse_url(url):
992992
try:
993993
kwargs[name] = parser(value)
994994
except (TypeError, ValueError):
995-
raise ValueError(f"Invalid value for `{name}` in connection URL.")
995+
raise ValueError(f"Invalid value for '{name}' in connection URL.")
996996
else:
997997
kwargs[name] = value
998998

tests/test_asyncio/test_connect.py

Lines changed: 11 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import asyncio
2-
import logging
32
import re
43
import socket
54
import ssl
@@ -14,9 +13,6 @@
1413

1514
from ..ssl_utils import get_ssl_filename
1615

17-
_logger = logging.getLogger(__name__)
18-
19-
2016
_CLIENT_NAME = "test-suite-client"
2117
_CMD_SEP = b"\r\n"
2218
_SUCCESS_RESP = b"+OK" + _CMD_SEP
@@ -125,7 +121,7 @@ async def test_tcp_ssl_version_mismatch(tcp_address):
125121
tcp_address,
126122
certfile=certfile,
127123
keyfile=keyfile,
128-
ssl_version=ssl.TLSVersion.TLSv1_2,
124+
maximum_ssl_version=ssl.TLSVersion.TLSv1_2,
129125
)
130126
await conn.disconnect()
131127

@@ -135,7 +131,8 @@ async def _assert_connect(
135131
server_address,
136132
certfile=None,
137133
keyfile=None,
138-
ssl_version=None,
134+
minimum_ssl_version=ssl.TLSVersion.TLSv1_2,
135+
maximum_ssl_version=ssl.TLSVersion.TLSv1_3,
139136
):
140137
stop_event = asyncio.Event()
141138
finished = asyncio.Event()
@@ -153,9 +150,8 @@ async def _handler(reader, writer):
153150
elif certfile:
154151
host, port = server_address
155152
context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
156-
if ssl_version is not None:
157-
context.minimum_version = ssl_version
158-
context.maximum_version = ssl_version
153+
context.minimum_version = minimum_ssl_version
154+
context.maximum_version = maximum_ssl_version
159155
context.load_cert_chain(certfile=certfile, keyfile=keyfile)
160156
server = await asyncio.start_server(_handler, host=host, port=port, ssl=context)
161157
else:
@@ -178,23 +174,18 @@ async def _handler(reader, writer):
178174

179175

180176
async def _redis_request_handler(reader, writer, stop_event):
181-
buffer = b""
182177
command = None
183178
command_ptr = None
184179
fragment_length = None
185-
while not stop_event.is_set() or buffer:
186-
_logger.info(str(stop_event.is_set()))
187-
try:
188-
buffer += await asyncio.wait_for(reader.read(1024), timeout=0.5)
189-
except TimeoutError:
190-
continue
180+
while not stop_event.is_set():
181+
buffer = await reader.read(1024)
191182
if not buffer:
192-
continue
183+
break
193184
parts = re.split(_CMD_SEP, buffer)
194-
buffer = parts[-1]
195-
for fragment in parts[:-1]:
185+
for fragment in parts:
196186
fragment = fragment.decode()
197-
_logger.info("Command fragment: %s", fragment)
187+
if not fragment:
188+
continue
198189

199190
if fragment.startswith("*") and command is None:
200191
command = [None for _ in range(int(fragment[1:]))]
@@ -214,10 +205,7 @@ async def _redis_request_handler(reader, writer, stop_event):
214205
continue
215206

216207
command = " ".join(command)
217-
_logger.info("Command %s", command)
218208
resp = _SUPPORTED_CMDS.get(command, _ERROR_RESP)
219-
_logger.info("Response from %s", resp)
220209
writer.write(resp)
221210
await writer.drain()
222211
command = None
223-
_logger.info("Exit handler")

tests/test_connect.py

Lines changed: 11 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
import pytest
99
from redis.connection import Connection, SSLConnection, UnixDomainSocketConnection
10-
from redis.exceptions import ConnectionError
10+
from redis.exceptions import RedisError
1111

1212
from .ssl_utils import get_ssl_filename
1313

@@ -126,16 +126,16 @@ def test_tcp_ssl_version_mismatch(tcp_address):
126126
port=port,
127127
client_name=_CLIENT_NAME,
128128
ssl_ca_certs=certfile,
129-
socket_timeout=10,
129+
socket_timeout=3,
130130
ssl_min_version=ssl.TLSVersion.TLSv1_3,
131131
)
132-
with pytest.raises(ConnectionError):
132+
with pytest.raises(RedisError):
133133
_assert_connect(
134134
conn,
135135
tcp_address,
136136
certfile=certfile,
137137
keyfile=keyfile,
138-
ssl_version=ssl.PROTOCOL_TLSv1_2,
138+
maximum_ssl_version=ssl.PROTOCOL_TLSv1_2,
139139
)
140140

141141

@@ -164,14 +164,16 @@ def __init__(
164164
*args,
165165
certfile=None,
166166
keyfile=None,
167-
ssl_version=ssl.PROTOCOL_TLS,
167+
minimum_ssl_version=ssl.TLSVersion.TLSv1_2,
168+
maximum_ssl_version=ssl.TLSVersion.TLSv1_3,
168169
**kw,
169170
) -> None:
170171
self._ready_event = threading.Event()
171172
self._stop_requested = False
172173
self._certfile = certfile
173174
self._keyfile = keyfile
174-
self._ssl_version = ssl_version
175+
self._minimum_ssl_version = minimum_ssl_version
176+
self._maximum_ssl_version = maximum_ssl_version
175177
super().__init__(*args, **kw)
176178

177179
def service_actions(self):
@@ -193,15 +195,9 @@ def get_request(self):
193195
newsocket, fromaddr = self.socket.accept()
194196
context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
195197
context.load_cert_chain(certfile=self._certfile, keyfile=self._keyfile)
196-
context.minimum_version = ssl.TLSVersion.TLSv1_2
197-
context.maximum_version = ssl.TLSVersion.TLSv1_3
198-
connstream = context.wrap_socket(
199-
newsocket,
200-
server_side=True,
201-
certfile=self._certfile,
202-
keyfile=self._keyfile,
203-
ssl_version=self._ssl_version,
204-
)
198+
context.minimum_version = self._minimum_ssl_version
199+
context.maximum_version = self._maximum_ssl_version
200+
connstream = context.wrap_socket(newsocket, server_side=True)
205201
return connstream, fromaddr
206202

207203

0 commit comments

Comments
 (0)