diff --git a/cassandra/connection.py b/cassandra/connection.py index a2540a967b..f090a5a3eb 100644 --- a/cassandra/connection.py +++ b/cassandra/connection.py @@ -1110,7 +1110,7 @@ def send_msg(self, msg, request_id, cb, encoder=ProtocolHandler.encode_message, # queue the decoder function with the request # this allows us to inject custom functions per request to encode, decode messages self._requests[request_id] = (cb, decoder, result_metadata) - msg = encoder(msg, request_id, self.protocol_version, compressor=self.compressor, + msg = encoder(msg, request_id, self.protocol_version, self.features, compressor=self.compressor, allow_beta_protocol_version=self.allow_beta_protocol_version) if self._is_checksumming_enabled: diff --git a/cassandra/protocol.py b/cassandra/protocol.py index 29ae404048..4ed1c7dfa8 100644 --- a/cassandra/protocol.py +++ b/cassandra/protocol.py @@ -421,7 +421,7 @@ def __init__(self, cqlversion, options): self.cqlversion = cqlversion self.options = options - def send_body(self, f, protocol_version): + def send_body(self, f, protocol_version, protocol_features): optmap = self.options.copy() optmap['CQL_VERSION'] = self.cqlversion write_stringmap(f, optmap) @@ -456,7 +456,7 @@ class CredentialsMessage(_MessageType): def __init__(self, creds): self.creds = creds - def send_body(self, f, protocol_version): + def send_body(self, f, protocol_version, protocol_features): if protocol_version > 1: raise UnsupportedOperation( "Credentials-based authentication is not supported with " @@ -487,7 +487,7 @@ class AuthResponseMessage(_MessageType): def __init__(self, response): self.response = response - def send_body(self, f, protocol_version): + def send_body(self, f, protocol_version, protocol_features): write_longstring(f, self.response) @@ -507,7 +507,7 @@ class OptionsMessage(_MessageType): opcode = 0x05 name = 'OPTIONS' - def send_body(self, f, protocol_version): + def send_body(self, f, protocol_version, protocol_features): pass @@ -606,6 +606,9 @@ def _write_query_params(self, f, protocol_version): "Keyspaces may only be set on queries with protocol version " "5 or DSE_V2 or higher. Consider setting Cluster.protocol_version.") + if self.skip_meta is not None and self.skip_meta: + flags |= _SKIP_METADATA_FLAG + if ProtocolVersion.uses_int_query_flags(protocol_version): write_uint(f, flags) else: @@ -645,7 +648,7 @@ def __init__(self, query, consistency_level, serial_consistency_level=None, super(QueryMessage, self).__init__(None, consistency_level, serial_consistency_level, fetch_size, paging_state, timestamp, False, continuous_paging_options, keyspace) - def send_body(self, f, protocol_version): + def send_body(self, f, protocol_version, protocol_features): write_longstring(f, self.query) self._write_query_params(f, protocol_version) @@ -681,9 +684,9 @@ def _write_query_params(self, f, protocol_version): else: super(ExecuteMessage, self)._write_query_params(f, protocol_version) - def send_body(self, f, protocol_version): + def send_body(self, f, protocol_version, protocol_features): write_string(f, self.query_id) - if ProtocolVersion.uses_prepared_metadata(protocol_version): + if ProtocolVersion.uses_prepared_metadata(protocol_version) or protocol_features.use_metadata_id: write_string(f, self.result_metadata_id) self._write_query_params(f, protocol_version) @@ -734,7 +737,7 @@ class ResultMessage(_MessageType): def __init__(self, kind): self.kind = kind - def recv(self, f, protocol_version, user_type_map, result_metadata, column_encryption_policy): + def recv(self, f, protocol_version, protocol_features, user_type_map, result_metadata, column_encryption_policy): if self.kind == RESULT_KIND_VOID: return elif self.kind == RESULT_KIND_ROWS: @@ -742,7 +745,7 @@ def recv(self, f, protocol_version, user_type_map, result_metadata, column_encry elif self.kind == RESULT_KIND_SET_KEYSPACE: self.new_keyspace = read_string(f) elif self.kind == RESULT_KIND_PREPARED: - self.recv_results_prepared(f, protocol_version, user_type_map) + self.recv_results_prepared(f, protocol_version, protocol_features, user_type_map) elif self.kind == RESULT_KIND_SCHEMA_CHANGE: self.recv_results_schema_change(f, protocol_version) else: @@ -752,7 +755,7 @@ def recv(self, f, protocol_version, user_type_map, result_metadata, column_encry def recv_body(cls, f, protocol_version, protocol_features, user_type_map, result_metadata, column_encryption_policy): kind = read_int(f) msg = cls(kind) - msg.recv(f, protocol_version, user_type_map, result_metadata, column_encryption_policy) + msg.recv(f, protocol_version, protocol_features, user_type_map, result_metadata, column_encryption_policy) return msg def recv_results_rows(self, f, protocol_version, user_type_map, result_metadata, column_encryption_policy): @@ -785,9 +788,9 @@ def decode_row(row): col_md[3].cql_parameterized_type(), str(e))) - def recv_results_prepared(self, f, protocol_version, user_type_map): + def recv_results_prepared(self, f, protocol_version, protocol_features, user_type_map): self.query_id = read_binary_string(f) - if ProtocolVersion.uses_prepared_metadata(protocol_version): + if ProtocolVersion.uses_prepared_metadata(protocol_version) or protocol_features.use_metadata_id: self.result_metadata_id = read_binary_string(f) else: self.result_metadata_id = None @@ -909,7 +912,7 @@ def __init__(self, query, keyspace=None): self.query = query self.keyspace = keyspace - def send_body(self, f, protocol_version): + def send_body(self, f, protocol_version, protocol_features): write_longstring(f, self.query) flags = 0x00 @@ -953,7 +956,7 @@ def __init__(self, batch_type, queries, consistency_level, self.timestamp = timestamp self.keyspace = keyspace - def send_body(self, f, protocol_version): + def send_body(self, f, protocol_version, protocol_features): write_byte(f, self.batch_type.value) write_short(f, len(self.queries)) for prepared, string_or_query_id, params in self.queries: @@ -1012,7 +1015,7 @@ class RegisterMessage(_MessageType): def __init__(self, event_list): self.event_list = event_list - def send_body(self, f, protocol_version): + def send_body(self, f, protocol_version, protocol_features): write_stringlist(f, self.event_list) @@ -1086,7 +1089,7 @@ def __init__(self, op_type, op_id, next_pages=0): self.op_id = op_id self.next_pages = next_pages - def send_body(self, f, protocol_version): + def send_body(self, f, protocol_version, protocol_features): write_int(f, self.op_type) write_int(f, self.op_id) if self.op_type == ReviseRequestMessage.RevisionType.PAGING_BACKPRESSURE: @@ -1122,7 +1125,7 @@ class _ProtocolHandler(object): """Instance of :class:`cassandra.policies.ColumnEncryptionPolicy` in use by this handler""" @classmethod - def encode_message(cls, msg, stream_id, protocol_version, compressor, allow_beta_protocol_version): + def encode_message(cls, msg, stream_id, protocol_version, protocol_features, compressor, allow_beta_protocol_version): """ Encodes a message using the specified frame parameters, and compressor @@ -1138,7 +1141,7 @@ def encode_message(cls, msg, stream_id, protocol_version, compressor, allow_beta raise UnsupportedOperation("Custom key/value payloads can only be used with protocol version 4 or higher") flags |= CUSTOM_PAYLOAD_FLAG write_bytesmap(body, msg.custom_payload) - msg.send_body(body, protocol_version) + msg.send_body(body, protocol_version, protocol_features) body = body.getvalue() # With checksumming, the compression is done at the segment frame encoding diff --git a/cassandra/protocol_features.py b/cassandra/protocol_features.py index 4eb7019f84..84c108319a 100644 --- a/cassandra/protocol_features.py +++ b/cassandra/protocol_features.py @@ -7,25 +7,29 @@ RATE_LIMIT_ERROR_EXTENSION = "SCYLLA_RATE_LIMIT_ERROR" TABLETS_ROUTING_V1 = "TABLETS_ROUTING_V1" +USE_METADATA_ID = "SCYLLA_USE_METADATA_ID" class ProtocolFeatures(object): rate_limit_error = None shard_id = 0 sharding_info = None tablets_routing_v1 = False + use_metadata_id = False - def __init__(self, rate_limit_error=None, shard_id=0, sharding_info=None, tablets_routing_v1=False): + def __init__(self, rate_limit_error=None, shard_id=0, sharding_info=None, tablets_routing_v1=False, use_metadata_id=False): self.rate_limit_error = rate_limit_error self.shard_id = shard_id self.sharding_info = sharding_info self.tablets_routing_v1 = tablets_routing_v1 + self.use_metadata_id = use_metadata_id @staticmethod def parse_from_supported(supported): rate_limit_error = ProtocolFeatures.maybe_parse_rate_limit_error(supported) shard_id, sharding_info = ProtocolFeatures.parse_sharding_info(supported) tablets_routing_v1 = ProtocolFeatures.parse_tablets_info(supported) - return ProtocolFeatures(rate_limit_error, shard_id, sharding_info, tablets_routing_v1) + use_metadata_id = ProtocolFeatures.parse_metadata_id_info(supported) + return ProtocolFeatures(rate_limit_error, shard_id, sharding_info, tablets_routing_v1, use_metadata_id) @staticmethod def maybe_parse_rate_limit_error(supported): @@ -49,6 +53,8 @@ def add_startup_options(self, options): options[RATE_LIMIT_ERROR_EXTENSION] = "" if self.tablets_routing_v1: options[TABLETS_ROUTING_V1] = "" + if self.use_metadata_id: + options[USE_METADATA_ID] = "" @staticmethod def parse_sharding_info(options): @@ -72,3 +78,7 @@ def parse_sharding_info(options): @staticmethod def parse_tablets_info(options): return TABLETS_ROUTING_V1 in options + + @staticmethod + def parse_metadata_id_info(options): + return USE_METADATA_ID in options diff --git a/tests/unit/test_protocol.py b/tests/unit/test_protocol.py index 907f62f2bb..b614293a94 100644 --- a/tests/unit/test_protocol.py +++ b/tests/unit/test_protocol.py @@ -23,6 +23,7 @@ _PAGE_SIZE_FLAG, _WITH_PAGING_STATE_FLAG, BatchMessage ) +from cassandra.protocol_features import ProtocolFeatures from cassandra.query import BatchType from cassandra.marshal import uint32_unpack from cassandra.cluster import ContinuousPagingOptions @@ -43,11 +44,11 @@ def test_prepare_message(self): message = PrepareMessage("a") io = Mock() - message.send_body(io, 4) + message.send_body(io, 4, ProtocolFeatures()) self._check_calls(io, [(b'\x00\x00\x00\x01',), (b'a',)]) io.reset_mock() - message.send_body(io, 5) + message.send_body(io, 5, ProtocolFeatures()) self._check_calls(io, [(b'\x00\x00\x00\x01',), (b'a',), (b'\x00\x00\x00\x00',)]) @@ -55,12 +56,12 @@ def test_execute_message(self): message = ExecuteMessage('1', [], 4) io = Mock() - message.send_body(io, 4) + message.send_body(io, 4, ProtocolFeatures()) self._check_calls(io, [(b'\x00\x01',), (b'1',), (b'\x00\x04',), (b'\x01',), (b'\x00\x00',)]) io.reset_mock() message.result_metadata_id = 'foo' - message.send_body(io, 5) + message.send_body(io, 5, ProtocolFeatures()) self._check_calls(io, [(b'\x00\x01',), (b'1',), (b'\x00\x03',), (b'foo',), @@ -80,11 +81,11 @@ def test_query_message(self): message = QueryMessage("a", 3) io = Mock() - message.send_body(io, 4) + message.send_body(io, 4, ProtocolFeatures()) self._check_calls(io, [(b'\x00\x00\x00\x01',), (b'a',), (b'\x00\x03',), (b'\x00',)]) io.reset_mock() - message.send_body(io, 5) + message.send_body(io, 5, ProtocolFeatures()) self._check_calls(io, [(b'\x00\x00\x00\x01',), (b'a',), (b'\x00\x03',), (b'\x00\x00\x00\x00',)]) def _check_calls(self, io, expected): @@ -112,10 +113,10 @@ def test_continuous_paging(self): io = Mock() for version in [version for version in ProtocolVersion.SUPPORTED_VERSIONS if not ProtocolVersion.has_continuous_paging_support(version)]: - self.assertRaises(UnsupportedOperation, message.send_body, io, version) + self.assertRaises(UnsupportedOperation, message.send_body, io, version, ProtocolFeatures()) io.reset_mock() - message.send_body(io, ProtocolVersion.DSE_V1) + message.send_body(io, ProtocolVersion.DSE_V1, ProtocolFeatures()) # continuous paging adds two write calls to the buffer self.assertEqual(len(io.write.mock_calls), 6) @@ -142,7 +143,7 @@ def test_prepare_flag(self): message = PrepareMessage("a") io = Mock() for version in ProtocolVersion.SUPPORTED_VERSIONS: - message.send_body(io, version) + message.send_body(io, version, ProtocolFeatures()) if ProtocolVersion.uses_prepare_flags(version): self.assertEqual(len(io.write.mock_calls), 3) else: @@ -155,7 +156,7 @@ def test_prepare_flag_with_keyspace(self): for version in ProtocolVersion.SUPPORTED_VERSIONS: if ProtocolVersion.uses_keyspace_flag(version): - message.send_body(io, version) + message.send_body(io, version, ProtocolFeatures()) self._check_calls(io, [ (b'\x00\x00\x00\x01',), (b'a',), @@ -165,7 +166,7 @@ def test_prepare_flag_with_keyspace(self): ]) else: with self.assertRaises(UnsupportedOperation): - message.send_body(io, version) + message.send_body(io, version, ProtocolFeatures()) io.reset_mock() def test_keyspace_flag_raises_before_v5(self): @@ -173,7 +174,7 @@ def test_keyspace_flag_raises_before_v5(self): io = Mock(name='io') with self.assertRaisesRegex(UnsupportedOperation, 'Keyspaces.*set'): - keyspace_message.send_body(io, protocol_version=4) + keyspace_message.send_body(io, protocol_version=4, protocol_features=ProtocolFeatures()) io.assert_not_called() def test_keyspace_written_with_length(self): @@ -186,7 +187,7 @@ def test_keyspace_written_with_length(self): ] QueryMessage('a', consistency_level=3, keyspace='ks').send_body( - io, protocol_version=5 + io, protocol_version=5, protocol_features=ProtocolFeatures() ) self._check_calls(io, base_expected + [ (b'\x00\x02',), # length of keyspace string @@ -196,7 +197,7 @@ def test_keyspace_written_with_length(self): io.reset_mock() QueryMessage('a', consistency_level=3, keyspace='keyspace').send_body( - io, protocol_version=5 + io, protocol_version=5, protocol_features=ProtocolFeatures() ) self._check_calls(io, base_expected + [ (b'\x00\x08',), # length of keyspace string @@ -215,7 +216,7 @@ def test_batch_message_with_keyspace(self): consistency_level=3, keyspace='ks' ) - batch.send_body(io, protocol_version=5) + batch.send_body(io, protocol_version=5, protocol_features=ProtocolFeatures()) self._check_calls(io, ((b'\x00',), (b'\x00\x03',), (b'\x00',), (b'\x00\x00\x00\x06',), (b'stmt a',),