Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 8 additions & 4 deletions cassandra/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,9 +89,9 @@
# but the lz4 lib requires little endian order, so we wrap these
# functions to handle that

def lz4_compress(byts):
def lz4_compress(byts, len_bytes):
# write length in big-endian instead of little-endian
return int32_pack(len(byts)) + lz4_block.compress(byts)[4:]
return int32_pack(len_bytes) + lz4_block.compress(byts)[4:]

def lz4_decompress(byts):
# flip from big-endian to little-endian
Expand All @@ -106,12 +106,16 @@ def lz4_decompress(byts):
log.debug("snappy package could not be imported. Snappy Compression will not be available")
pass
else:
# unused length field, to be compatible with lz4_compress signature
def snappy_compress(byts, len_bytes):
return snappy.compress(byts)

# work around apparently buggy snappy decompress
def decompress(byts):
def snappy_decompress(byts):
if byts == '\x00':
return ''
return snappy.decompress(byts)
locally_supported_compressions['snappy'] = (snappy.compress, decompress)
locally_supported_compressions['snappy'] = (snappy_compress, snappy_decompress)

DRIVER_NAME, DRIVER_VERSION = 'ScyllaDB Python Driver', sys.modules['cassandra'].__version__

Expand Down
9 changes: 5 additions & 4 deletions cassandra/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -1098,11 +1098,12 @@ def encode_message(cls, msg, stream_id, protocol_version, compressor, allow_beta
write_bytesmap(body, msg.custom_payload)
msg.send_body(body, protocol_version)
body = body.getvalue()

body_length = len(body)
# With checksumming, the compression is done at the segment frame encoding
if (not ProtocolVersion.has_checksumming_support(protocol_version)
and compressor and len(body) > 0):
body = compressor(body)
and compressor and body_length > 0):
body = compressor(body, body_length)
body_length = len(body)
flags |= COMPRESSED_FLAG

if msg.tracing:
Expand All @@ -1112,7 +1113,7 @@ def encode_message(cls, msg, stream_id, protocol_version, compressor, allow_beta
flags |= USE_BETA_FLAG

buff = io.BytesIO()
cls._write_header(buff, protocol_version, flags, stream_id, msg.opcode, len(body))
cls._write_header(buff, protocol_version, flags, stream_id, msg.opcode, body_length)
buff.write(body)

return buff.getvalue()
Expand Down
17 changes: 9 additions & 8 deletions cassandra/segment.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,10 +116,10 @@ def header_length_with_crc(self):
def compression(self):
return self.compressor and self.decompressor

def compress(self, data):
def compress(self, data, length):
# the uncompressed length is already encoded in the header, so
# we remove it here
return self.compressor(data)[4:]
return self.compressor(data, length)[4:]

def decompress(self, encoded_data, uncompressed_length):
return self.decompressor(int32_pack(uncompressed_length) + encoded_data)
Expand Down Expand Up @@ -150,35 +150,36 @@ def _encode_segment(self, buffer, payload, is_self_contained):
uncompressed_payload_length = len(payload)

if self.compression:
compressed_payload = self.compress(uncompressed_payload)
compressed_payload = self.compress(uncompressed_payload, uncompressed_payload_length)
if len(compressed_payload) >= uncompressed_payload_length:
encoded_payload = uncompressed_payload
payload_length = uncompressed_payload_length
uncompressed_payload_length = 0
else:
encoded_payload = compressed_payload
payload_length = len(compressed_payload)
else:
encoded_payload = uncompressed_payload

payload_length = len(encoded_payload)
payload_length = uncompressed_payload_length
self.encode_header(buffer, payload_length, uncompressed_payload_length, is_self_contained)
payload_crc = compute_crc32(encoded_payload, CRC32_INITIAL)
buffer.write(encoded_payload)
write_uint_le(buffer, payload_crc)

def encode(self, buffer, msg):
"""
Encode a message to one of more segments.
Encode a message to one or more segments.
"""
msg_length = len(msg)

if msg_length > Segment.MAX_PAYLOAD_LENGTH:
payloads = []
for i in range(0, msg_length, Segment.MAX_PAYLOAD_LENGTH):
payloads.append(msg[i:i + Segment.MAX_PAYLOAD_LENGTH])
is_self_contained = False
else:
payloads = [msg]

is_self_contained = len(payloads) == 1
is_self_contained = True
for payload in payloads:
self._encode_segment(buffer, payload, is_self_contained)

Expand Down
22 changes: 13 additions & 9 deletions tests/unit/test_segment.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,9 @@ def test_encode_uncompressed_header(self):
@unittest.skipUnless(segment_codec_lz4, ' lz4 not installed')
def test_encode_compressed_header(self):
buffer = BytesIO()
compressed_length = len(segment_codec_lz4.compress(self.small_msg))
segment_codec_lz4.encode_header(buffer, compressed_length, len(self.small_msg), True)
len_small_msg = len(self.small_msg)
compressed_length = len(segment_codec_lz4.compress(self.small_msg, len_small_msg))
segment_codec_lz4.encode_header(buffer, compressed_length, len_small_msg, True)

assert buffer.tell() == 8
assert self._header_to_bits(buffer.getvalue()) == "{:017b}".format(compressed_length) + "00000000000110010" + "1" + "00000"
Expand Down Expand Up @@ -87,17 +88,19 @@ def test_encode_uncompressed_header_not_self_contained_msg(self):
@unittest.skipUnless(segment_codec_lz4, ' lz4 not installed')
def test_encode_compressed_header_with_max_payload(self):
buffer = BytesIO()
compressed_length = len(segment_codec_lz4.compress(self.max_msg))
segment_codec_lz4.encode_header(buffer, compressed_length, len(self.max_msg), True)
len_self_max_msg = len(self.max_msg)
compressed_length = len(segment_codec_lz4.compress(self.max_msg, len_self_max_msg))
segment_codec_lz4.encode_header(buffer, compressed_length, len_self_max_msg, True)
assert buffer.tell() == 8
assert self._header_to_bits(buffer.getvalue()) == "{:017b}".format(compressed_length) + "11111111111111111" + "1" + "00000"

@unittest.skipUnless(segment_codec_lz4, ' lz4 not installed')
def test_encode_compressed_header_not_self_contained_msg(self):
buffer = BytesIO()
# simulate the first chunk with the max size
compressed_length = len(segment_codec_lz4.compress(self.max_msg))
segment_codec_lz4.encode_header(buffer, compressed_length, len(self.max_msg), False)
len_self_max_msg = len(self.max_msg)
compressed_length = len(segment_codec_lz4.compress(self.max_msg, len_self_max_msg))
segment_codec_lz4.encode_header(buffer, compressed_length, len_self_max_msg, False)
assert buffer.tell() == 8
assert self._header_to_bits(buffer.getvalue()) == ("{:017b}".format(compressed_length) +
"11111111111111111"
Expand All @@ -116,11 +119,12 @@ def test_decode_uncompressed_header(self):
@unittest.skipUnless(segment_codec_lz4, ' lz4 not installed')
def test_decode_compressed_header(self):
buffer = BytesIO()
compressed_length = len(segment_codec_lz4.compress(self.small_msg))
segment_codec_lz4.encode_header(buffer, compressed_length, len(self.small_msg), True)
len_self_small_msg = len(self.small_msg)
compressed_length = len(segment_codec_lz4.compress(self.small_msg, len_self_small_msg))
segment_codec_lz4.encode_header(buffer, compressed_length, len_self_small_msg, True)
buffer.seek(0)
header = segment_codec_lz4.decode_header(buffer)
assert header.uncompressed_payload_length == len(self.small_msg)
assert header.uncompressed_payload_length == len_self_small_msg
assert header.payload_length == compressed_length
assert header.is_self_contained == True

Expand Down
Loading