Skip to content

Commit 078a131

Browse files
committed
(improvement)remove extra len() call in lz4_compress()
In essence, we know the length of the uncompressed payload, so pass it along. To do so, changed the signature of the function (and did the same as dummy for snappy). Signed-off-by: Yaniv Kaul <[email protected]>
1 parent c89f2a5 commit 078a131

File tree

4 files changed

+29
-20
lines changed

4 files changed

+29
-20
lines changed

cassandra/connection.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -89,9 +89,9 @@
8989
# but the lz4 lib requires little endian order, so we wrap these
9090
# functions to handle that
9191

92-
def lz4_compress(byts):
92+
def lz4_compress(byts, len_bytes):
9393
# write length in big-endian instead of little-endian
94-
return int32_pack(len(byts)) + lz4_block.compress(byts)[4:]
94+
return int32_pack(len_bytes) + lz4_block.compress(byts)[4:]
9595

9696
def lz4_decompress(byts):
9797
# flip from big-endian to little-endian
@@ -106,12 +106,16 @@ def lz4_decompress(byts):
106106
log.debug("snappy package could not be imported. Snappy Compression will not be available")
107107
pass
108108
else:
109+
# unused length field, to be compatible with lz4_compress signature
110+
def snappy_compress(byts, len_bytes):
111+
return snappy.compress(byts)
112+
109113
# work around apparently buggy snappy decompress
110-
def decompress(byts):
114+
def snappy_decompress(byts):
111115
if byts == '\x00':
112116
return ''
113117
return snappy.decompress(byts)
114-
locally_supported_compressions['snappy'] = (snappy.compress, decompress)
118+
locally_supported_compressions['snappy'] = (snappy_compress, snappy_decompress)
115119

116120
DRIVER_NAME, DRIVER_VERSION = 'ScyllaDB Python Driver', sys.modules['cassandra'].__version__
117121

cassandra/protocol.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1098,11 +1098,12 @@ def encode_message(cls, msg, stream_id, protocol_version, compressor, allow_beta
10981098
write_bytesmap(body, msg.custom_payload)
10991099
msg.send_body(body, protocol_version)
11001100
body = body.getvalue()
1101-
1101+
body_length = len(body)
11021102
# With checksumming, the compression is done at the segment frame encoding
11031103
if (not ProtocolVersion.has_checksumming_support(protocol_version)
1104-
and compressor and len(body) > 0):
1105-
body = compressor(body)
1104+
and compressor and body_length > 0):
1105+
body = compressor(body, body_length)
1106+
body_length = len(body)
11061107
flags |= COMPRESSED_FLAG
11071108

11081109
if msg.tracing:
@@ -1112,7 +1113,7 @@ def encode_message(cls, msg, stream_id, protocol_version, compressor, allow_beta
11121113
flags |= USE_BETA_FLAG
11131114

11141115
buff = io.BytesIO()
1115-
cls._write_header(buff, protocol_version, flags, stream_id, msg.opcode, len(body))
1116+
cls._write_header(buff, protocol_version, flags, stream_id, msg.opcode, body_length)
11161117
buff.write(body)
11171118

11181119
return buff.getvalue()

cassandra/segment.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -116,10 +116,10 @@ def header_length_with_crc(self):
116116
def compression(self):
117117
return self.compressor and self.decompressor
118118

119-
def compress(self, data):
119+
def compress(self, data, length):
120120
# the uncompressed length is already encoded in the header, so
121121
# we remove it here
122-
return self.compressor(data)[4:]
122+
return self.compressor(data, length)[4:]
123123

124124
def decompress(self, encoded_data, uncompressed_length):
125125
return self.decompressor(int32_pack(uncompressed_length) + encoded_data)
@@ -150,7 +150,7 @@ def _encode_segment(self, buffer, payload, is_self_contained):
150150
uncompressed_payload_length = len(payload)
151151

152152
if self.compression:
153-
compressed_payload = self.compress(uncompressed_payload)
153+
compressed_payload = self.compress(uncompressed_payload, uncompressed_payload_length)
154154
if len(compressed_payload) >= uncompressed_payload_length:
155155
encoded_payload = uncompressed_payload
156156
payload_length = uncompressed_payload_length

tests/unit/test_segment.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -57,8 +57,9 @@ def test_encode_uncompressed_header(self):
5757
@unittest.skipUnless(segment_codec_lz4, ' lz4 not installed')
5858
def test_encode_compressed_header(self):
5959
buffer = BytesIO()
60-
compressed_length = len(segment_codec_lz4.compress(self.small_msg))
61-
segment_codec_lz4.encode_header(buffer, compressed_length, len(self.small_msg), True)
60+
len_small_msg = len(self.small_msg)
61+
compressed_length = len(segment_codec_lz4.compress(self.small_msg, len_small_msg))
62+
segment_codec_lz4.encode_header(buffer, compressed_length, len_small_msg, True)
6263

6364
assert buffer.tell() == 8
6465
assert self._header_to_bits(buffer.getvalue()) == "{:017b}".format(compressed_length) + "00000000000110010" + "1" + "00000"
@@ -87,17 +88,19 @@ def test_encode_uncompressed_header_not_self_contained_msg(self):
8788
@unittest.skipUnless(segment_codec_lz4, ' lz4 not installed')
8889
def test_encode_compressed_header_with_max_payload(self):
8990
buffer = BytesIO()
90-
compressed_length = len(segment_codec_lz4.compress(self.max_msg))
91-
segment_codec_lz4.encode_header(buffer, compressed_length, len(self.max_msg), True)
91+
len_self_max_msg = len(self.max_msg)
92+
compressed_length = len(segment_codec_lz4.compress(self.max_msg, len_self_max_msg))
93+
segment_codec_lz4.encode_header(buffer, compressed_length, len_self_max_msg, True)
9294
assert buffer.tell() == 8
9395
assert self._header_to_bits(buffer.getvalue()) == "{:017b}".format(compressed_length) + "11111111111111111" + "1" + "00000"
9496

9597
@unittest.skipUnless(segment_codec_lz4, ' lz4 not installed')
9698
def test_encode_compressed_header_not_self_contained_msg(self):
9799
buffer = BytesIO()
98100
# simulate the first chunk with the max size
99-
compressed_length = len(segment_codec_lz4.compress(self.max_msg))
100-
segment_codec_lz4.encode_header(buffer, compressed_length, len(self.max_msg), False)
101+
len_self_max_msg = len(self.max_msg)
102+
compressed_length = len(segment_codec_lz4.compress(self.max_msg, len_self_max_msg))
103+
segment_codec_lz4.encode_header(buffer, compressed_length, len_self_max_msg, False)
101104
assert buffer.tell() == 8
102105
assert self._header_to_bits(buffer.getvalue()) == ("{:017b}".format(compressed_length) +
103106
"11111111111111111"
@@ -116,11 +119,12 @@ def test_decode_uncompressed_header(self):
116119
@unittest.skipUnless(segment_codec_lz4, ' lz4 not installed')
117120
def test_decode_compressed_header(self):
118121
buffer = BytesIO()
119-
compressed_length = len(segment_codec_lz4.compress(self.small_msg))
120-
segment_codec_lz4.encode_header(buffer, compressed_length, len(self.small_msg), True)
122+
len_self_small_msg = len(self.small_msg)
123+
compressed_length = len(segment_codec_lz4.compress(self.small_msg, len_self_small_msg))
124+
segment_codec_lz4.encode_header(buffer, compressed_length, len_self_small_msg, True)
121125
buffer.seek(0)
122126
header = segment_codec_lz4.decode_header(buffer)
123-
assert header.uncompressed_payload_length == len(self.small_msg)
127+
assert header.uncompressed_payload_length == len_self_small_msg
124128
assert header.payload_length == compressed_length
125129
assert header.is_self_contained == True
126130

0 commit comments

Comments
 (0)