Skip to content

Commit 5c31bb6

Browse files
sylwiaszunejkodkropachev
authored andcommitted
Parse LWT flags when creating prepared statement
1 parent 77e0492 commit 5c31bb6

File tree

4 files changed

+27
-11
lines changed

4 files changed

+27
-11
lines changed

cassandra/cluster.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3110,7 +3110,7 @@ def prepare(self, query, custom_payload=None, keyspace=None):
31103110
prepared_keyspace = keyspace if keyspace else None
31113111
prepared_statement = PreparedStatement.from_message(
31123112
response.query_id, response.bind_metadata, response.pk_indexes, self.cluster.metadata, query, prepared_keyspace,
3113-
self._protocol_version, response.column_metadata, response.result_metadata_id, self.cluster.column_encryption_policy)
3113+
self._protocol_version, response.column_metadata, response.result_metadata_id, response.is_lwt, self.cluster.column_encryption_policy)
31143114
prepared_statement.custom_payload = future.custom_payload
31153115

31163116
self.cluster.add_prepared(response.query_id, prepared_statement)

cassandra/lwt_info.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,3 +19,6 @@ class _LwtInfo:
1919

2020
def __init__(self, lwt_meta_bit_mask):
2121
self.lwt_meta_bit_mask = lwt_meta_bit_mask
22+
23+
def get_lwt_flag(self, flags):
24+
return (flags & self.lwt_meta_bit_mask) == self.lwt_meta_bit_mask

cassandra/protocol.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -686,19 +686,20 @@ class ResultMessage(_MessageType):
686686
bind_metadata = None
687687
pk_indexes = None
688688
schema_change_event = None
689+
is_lwt = False
689690

690691
def __init__(self, kind):
691692
self.kind = kind
692693

693-
def recv(self, f, protocol_version, user_type_map, result_metadata, column_encryption_policy):
694+
def recv(self, f, protocol_version, protocol_features, user_type_map, result_metadata, column_encryption_policy):
694695
if self.kind == RESULT_KIND_VOID:
695696
return
696697
elif self.kind == RESULT_KIND_ROWS:
697698
self.recv_results_rows(f, protocol_version, user_type_map, result_metadata, column_encryption_policy)
698699
elif self.kind == RESULT_KIND_SET_KEYSPACE:
699700
self.new_keyspace = read_string(f)
700701
elif self.kind == RESULT_KIND_PREPARED:
701-
self.recv_results_prepared(f, protocol_version, user_type_map)
702+
self.recv_results_prepared(f, protocol_version, protocol_features, user_type_map)
702703
elif self.kind == RESULT_KIND_SCHEMA_CHANGE:
703704
self.recv_results_schema_change(f, protocol_version)
704705
else:
@@ -708,7 +709,7 @@ def recv(self, f, protocol_version, user_type_map, result_metadata, column_encry
708709
def recv_body(cls, f, protocol_version, protocol_features, user_type_map, result_metadata, column_encryption_policy):
709710
kind = read_int(f)
710711
msg = cls(kind)
711-
msg.recv(f, protocol_version, user_type_map, result_metadata, column_encryption_policy)
712+
msg.recv(f, protocol_version, protocol_features, user_type_map, result_metadata, column_encryption_policy)
712713
return msg
713714

714715
def recv_results_rows(self, f, protocol_version, user_type_map, result_metadata, column_encryption_policy):
@@ -741,13 +742,13 @@ def decode_row(row):
741742
col_md[3].cql_parameterized_type(),
742743
str(e)))
743744

744-
def recv_results_prepared(self, f, protocol_version, user_type_map):
745+
def recv_results_prepared(self, f, protocol_version, protocol_features, user_type_map):
745746
self.query_id = read_binary_string(f)
746747
if ProtocolVersion.uses_prepared_metadata(protocol_version):
747748
self.result_metadata_id = read_binary_string(f)
748749
else:
749750
self.result_metadata_id = None
750-
self.recv_prepared_metadata(f, protocol_version, user_type_map)
751+
self.recv_prepared_metadata(f, protocol_version, protocol_features, user_type_map)
751752

752753
def recv_results_metadata(self, f, user_type_map):
753754
flags = read_int(f)
@@ -785,8 +786,9 @@ def recv_results_metadata(self, f, user_type_map):
785786

786787
self.column_metadata = column_metadata
787788

788-
def recv_prepared_metadata(self, f, protocol_version, user_type_map):
789+
def recv_prepared_metadata(self, f, protocol_version, protocol_features, user_type_map):
789790
flags = read_int(f)
791+
self.is_lwt = protocol_features.lwt_info.get_lwt_flag(flags) if protocol_features.lwt_info is not None else False
790792
colcount = read_int(f)
791793
pk_indexes = None
792794
if protocol_version >= 4:

cassandra/query.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -345,6 +345,9 @@ def _set_serial_consistency_level(self, serial_consistency_level):
345345
def _del_serial_consistency_level(self):
346346
self._serial_consistency_level = None
347347

348+
def is_lwt(self):
349+
return False
350+
348351
serial_consistency_level = property(
349352
_get_serial_consistency_level,
350353
_set_serial_consistency_level,
@@ -454,10 +457,11 @@ class PreparedStatement(object):
454457
routing_key_indexes = None
455458
_routing_key_index_set = None
456459
serial_consistency_level = None # TODO never used?
460+
_is_lwt = False
457461

458462
def __init__(self, column_metadata, query_id, routing_key_indexes, query,
459463
keyspace, protocol_version, result_metadata, result_metadata_id,
460-
column_encryption_policy=None):
464+
is_lwt=False, column_encryption_policy=None):
461465
self.column_metadata = column_metadata
462466
self.query_id = query_id
463467
self.routing_key_indexes = routing_key_indexes
@@ -468,15 +472,16 @@ def __init__(self, column_metadata, query_id, routing_key_indexes, query,
468472
self.result_metadata_id = result_metadata_id
469473
self.column_encryption_policy = column_encryption_policy
470474
self.is_idempotent = False
475+
self._is_lwt = is_lwt
471476

472477
@classmethod
473478
def from_message(cls, query_id, column_metadata, pk_indexes, cluster_metadata,
474479
query, prepared_keyspace, protocol_version, result_metadata,
475-
result_metadata_id, column_encryption_policy=None):
480+
result_metadata_id, is_lwt, column_encryption_policy=None):
476481
if not column_metadata:
477482
return PreparedStatement(column_metadata, query_id, None,
478483
query, prepared_keyspace, protocol_version, result_metadata,
479-
result_metadata_id, column_encryption_policy)
484+
result_metadata_id, is_lwt, column_encryption_policy)
480485

481486
if pk_indexes:
482487
routing_key_indexes = pk_indexes
@@ -502,7 +507,7 @@ def from_message(cls, query_id, column_metadata, pk_indexes, cluster_metadata,
502507

503508
return PreparedStatement(column_metadata, query_id, routing_key_indexes,
504509
query, prepared_keyspace, protocol_version, result_metadata,
505-
result_metadata_id, column_encryption_policy)
510+
result_metadata_id, is_lwt, column_encryption_policy)
506511

507512
def bind(self, values):
508513
"""
@@ -517,6 +522,9 @@ def is_routing_key_index(self, i):
517522
self._routing_key_index_set = set(self.routing_key_indexes) if self.routing_key_indexes else set()
518523
return i in self._routing_key_index_set
519524

525+
def is_lwt(self):
526+
return self._is_lwt
527+
520528
def __str__(self):
521529
consistency = ConsistencyLevel.value_to_name.get(self.consistency_level, 'Not Set')
522530
return (u'<PreparedStatement query="%s", consistency=%s>' %
@@ -682,6 +690,9 @@ def routing_key(self):
682690

683691
return self._routing_key
684692

693+
def is_lwt(self):
694+
return self.prepared_statement.is_lwt()
695+
685696
def __str__(self):
686697
consistency = ConsistencyLevel.value_to_name.get(self.consistency_level, 'Not Set')
687698
return (u'<BoundStatement query="%s", values=%s, consistency=%s>' %

0 commit comments

Comments
 (0)