diff --git a/cassandra/connection.py b/cassandra/connection.py index 7cd104ab29..650273454d 100644 --- a/cassandra/connection.py +++ b/cassandra/connection.py @@ -89,9 +89,9 @@ # but the lz4 lib requires little endian order, so we wrap these # functions to handle that - def lz4_compress(byts): + def lz4_compress(byts, len_bytes): # write length in big-endian instead of little-endian - return int32_pack(len(byts)) + lz4_block.compress(byts)[4:] + return int32_pack(len_bytes) + lz4_block.compress(byts)[4:] def lz4_decompress(byts): # flip from big-endian to little-endian @@ -106,12 +106,16 @@ def lz4_decompress(byts): log.debug("snappy package could not be imported. Snappy Compression will not be available") pass else: + # unused length field, to be compatible with lz4_compress signature + def snappy_compress(byts, len_bytes): + return snappy.compress(byts) + # work around apparently buggy snappy decompress - def decompress(byts): + def snappy_decompress(byts): if byts == '\x00': return '' return snappy.decompress(byts) - locally_supported_compressions['snappy'] = (snappy.compress, decompress) + locally_supported_compressions['snappy'] = (snappy_compress, snappy_decompress) DRIVER_NAME, DRIVER_VERSION = 'ScyllaDB Python Driver', sys.modules['cassandra'].__version__ diff --git a/cassandra/protocol.py b/cassandra/protocol.py index 5fe4ed2be4..14a19380ef 100644 --- a/cassandra/protocol.py +++ b/cassandra/protocol.py @@ -1098,11 +1098,12 @@ def encode_message(cls, msg, stream_id, protocol_version, compressor, allow_beta write_bytesmap(body, msg.custom_payload) msg.send_body(body, protocol_version) body = body.getvalue() - + body_length = len(body) # With checksumming, the compression is done at the segment frame encoding if (not ProtocolVersion.has_checksumming_support(protocol_version) - and compressor and len(body) > 0): - body = compressor(body) + and compressor and body_length > 0): + body = compressor(body, body_length) + body_length = len(body) flags |= COMPRESSED_FLAG if msg.tracing: @@ -1112,7 +1113,7 @@ def encode_message(cls, msg, stream_id, protocol_version, compressor, allow_beta flags |= USE_BETA_FLAG buff = io.BytesIO() - cls._write_header(buff, protocol_version, flags, stream_id, msg.opcode, len(body)) + cls._write_header(buff, protocol_version, flags, stream_id, msg.opcode, body_length) buff.write(body) return buff.getvalue() diff --git a/cassandra/segment.py b/cassandra/segment.py index 78161fe520..5b87d76f1e 100644 --- a/cassandra/segment.py +++ b/cassandra/segment.py @@ -116,10 +116,10 @@ def header_length_with_crc(self): def compression(self): return self.compressor and self.decompressor - def compress(self, data): + def compress(self, data, length): # the uncompressed length is already encoded in the header, so # we remove it here - return self.compressor(data)[4:] + return self.compressor(data, length)[4:] def decompress(self, encoded_data, uncompressed_length): return self.decompressor(int32_pack(uncompressed_length) + encoded_data) @@ -150,16 +150,17 @@ def _encode_segment(self, buffer, payload, is_self_contained): uncompressed_payload_length = len(payload) if self.compression: - compressed_payload = self.compress(uncompressed_payload) + compressed_payload = self.compress(uncompressed_payload, uncompressed_payload_length) if len(compressed_payload) >= uncompressed_payload_length: encoded_payload = uncompressed_payload + payload_length = uncompressed_payload_length uncompressed_payload_length = 0 else: encoded_payload = compressed_payload + payload_length = len(compressed_payload) else: encoded_payload = uncompressed_payload - - payload_length = len(encoded_payload) + payload_length = uncompressed_payload_length self.encode_header(buffer, payload_length, uncompressed_payload_length, is_self_contained) payload_crc = compute_crc32(encoded_payload, CRC32_INITIAL) buffer.write(encoded_payload) @@ -167,7 +168,7 @@ def _encode_segment(self, buffer, payload, is_self_contained): def encode(self, buffer, msg): """ - Encode a message to one of more segments. + Encode a message to one or more segments. """ msg_length = len(msg) @@ -175,10 +176,10 @@ def encode(self, buffer, msg): payloads = [] for i in range(0, msg_length, Segment.MAX_PAYLOAD_LENGTH): payloads.append(msg[i:i + Segment.MAX_PAYLOAD_LENGTH]) + is_self_contained = False else: payloads = [msg] - - is_self_contained = len(payloads) == 1 + is_self_contained = True for payload in payloads: self._encode_segment(buffer, payload, is_self_contained) diff --git a/tests/unit/test_segment.py b/tests/unit/test_segment.py index bfc273db05..b5bde550b7 100644 --- a/tests/unit/test_segment.py +++ b/tests/unit/test_segment.py @@ -57,8 +57,9 @@ def test_encode_uncompressed_header(self): @unittest.skipUnless(segment_codec_lz4, ' lz4 not installed') def test_encode_compressed_header(self): buffer = BytesIO() - compressed_length = len(segment_codec_lz4.compress(self.small_msg)) - segment_codec_lz4.encode_header(buffer, compressed_length, len(self.small_msg), True) + len_small_msg = len(self.small_msg) + compressed_length = len(segment_codec_lz4.compress(self.small_msg, len_small_msg)) + segment_codec_lz4.encode_header(buffer, compressed_length, len_small_msg, True) assert buffer.tell() == 8 assert self._header_to_bits(buffer.getvalue()) == "{:017b}".format(compressed_length) + "00000000000110010" + "1" + "00000" @@ -87,8 +88,9 @@ def test_encode_uncompressed_header_not_self_contained_msg(self): @unittest.skipUnless(segment_codec_lz4, ' lz4 not installed') def test_encode_compressed_header_with_max_payload(self): buffer = BytesIO() - compressed_length = len(segment_codec_lz4.compress(self.max_msg)) - segment_codec_lz4.encode_header(buffer, compressed_length, len(self.max_msg), True) + len_self_max_msg = len(self.max_msg) + compressed_length = len(segment_codec_lz4.compress(self.max_msg, len_self_max_msg)) + segment_codec_lz4.encode_header(buffer, compressed_length, len_self_max_msg, True) assert buffer.tell() == 8 assert self._header_to_bits(buffer.getvalue()) == "{:017b}".format(compressed_length) + "11111111111111111" + "1" + "00000" @@ -96,8 +98,9 @@ def test_encode_compressed_header_with_max_payload(self): def test_encode_compressed_header_not_self_contained_msg(self): buffer = BytesIO() # simulate the first chunk with the max size - compressed_length = len(segment_codec_lz4.compress(self.max_msg)) - segment_codec_lz4.encode_header(buffer, compressed_length, len(self.max_msg), False) + len_self_max_msg = len(self.max_msg) + compressed_length = len(segment_codec_lz4.compress(self.max_msg, len_self_max_msg)) + segment_codec_lz4.encode_header(buffer, compressed_length, len_self_max_msg, False) assert buffer.tell() == 8 assert self._header_to_bits(buffer.getvalue()) == ("{:017b}".format(compressed_length) + "11111111111111111" @@ -116,11 +119,12 @@ def test_decode_uncompressed_header(self): @unittest.skipUnless(segment_codec_lz4, ' lz4 not installed') def test_decode_compressed_header(self): buffer = BytesIO() - compressed_length = len(segment_codec_lz4.compress(self.small_msg)) - segment_codec_lz4.encode_header(buffer, compressed_length, len(self.small_msg), True) + len_self_small_msg = len(self.small_msg) + compressed_length = len(segment_codec_lz4.compress(self.small_msg, len_self_small_msg)) + segment_codec_lz4.encode_header(buffer, compressed_length, len_self_small_msg, True) buffer.seek(0) header = segment_codec_lz4.decode_header(buffer) - assert header.uncompressed_payload_length == len(self.small_msg) + assert header.uncompressed_payload_length == len_self_small_msg assert header.payload_length == compressed_length assert header.is_self_contained == True