Skip to content

Commit 18982c6

Browse files
Parse LWT flags when creating prepared statement
1 parent cc55ad6 commit 18982c6

File tree

4 files changed

+29
-9
lines changed

4 files changed

+29
-9
lines changed

cassandra/cluster.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3109,7 +3109,9 @@ def prepare(self, query, custom_payload=None, keyspace=None):
31093109
prepared_keyspace = keyspace if keyspace else None
31103110
prepared_statement = PreparedStatement.from_message(
31113111
response.query_id, response.bind_metadata, response.pk_indexes, self.cluster.metadata, query, prepared_keyspace,
3112-
self._protocol_version, response.column_metadata, response.result_metadata_id, self.cluster.column_encryption_policy)
3112+
self._protocol_version, response.column_metadata, response.result_metadata_id,
3113+
response.lwt_info.is_lwt(response.flags) if response.lwt_info is not None else False,
3114+
self.cluster.column_encryption_policy)
31133115
prepared_statement.custom_payload = future.custom_payload
31143116

31153117
self.cluster.add_prepared(response.query_id, prepared_statement)

cassandra/lwt_info.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,3 +19,6 @@ class _LwtInfo:
1919

2020
def __init__(self, lwt_meta_bit_mask):
2121
self.lwt_meta_bit_mask = lwt_meta_bit_mask
22+
23+
def is_lwt(self, flags):
24+
return (flags & self.lwt_meta_bit_mask) == self.lwt_meta_bit_mask

cassandra/protocol.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -686,19 +686,21 @@ class ResultMessage(_MessageType):
686686
bind_metadata = None
687687
pk_indexes = None
688688
schema_change_event = None
689+
flags = None
690+
lwt_info = None
689691

690692
def __init__(self, kind):
691693
self.kind = kind
692694

693-
def recv(self, f, protocol_version, user_type_map, result_metadata, column_encryption_policy):
695+
def recv(self, f, protocol_version, protocol_features, user_type_map, result_metadata, column_encryption_policy):
694696
if self.kind == RESULT_KIND_VOID:
695697
return
696698
elif self.kind == RESULT_KIND_ROWS:
697699
self.recv_results_rows(f, protocol_version, user_type_map, result_metadata, column_encryption_policy)
698700
elif self.kind == RESULT_KIND_SET_KEYSPACE:
699701
self.new_keyspace = read_string(f)
700702
elif self.kind == RESULT_KIND_PREPARED:
701-
self.recv_results_prepared(f, protocol_version, user_type_map)
703+
self.recv_results_prepared(f, protocol_version, protocol_features, user_type_map)
702704
elif self.kind == RESULT_KIND_SCHEMA_CHANGE:
703705
self.recv_results_schema_change(f, protocol_version)
704706
else:
@@ -708,7 +710,7 @@ def recv(self, f, protocol_version, user_type_map, result_metadata, column_encry
708710
def recv_body(cls, f, protocol_version, protocol_features, user_type_map, result_metadata, column_encryption_policy):
709711
kind = read_int(f)
710712
msg = cls(kind)
711-
msg.recv(f, protocol_version, user_type_map, result_metadata, column_encryption_policy)
713+
msg.recv(f, protocol_version, protocol_features, user_type_map, result_metadata, column_encryption_policy)
712714
return msg
713715

714716
def recv_results_rows(self, f, protocol_version, user_type_map, result_metadata, column_encryption_policy):
@@ -741,8 +743,9 @@ def decode_row(row):
741743
col_md[3].cql_parameterized_type(),
742744
str(e)))
743745

744-
def recv_results_prepared(self, f, protocol_version, user_type_map):
746+
def recv_results_prepared(self, f, protocol_version, protocol_features, user_type_map):
745747
self.query_id = read_binary_string(f)
748+
self.lwt_info = protocol_features.lwt_info
746749
if ProtocolVersion.uses_prepared_metadata(protocol_version):
747750
self.result_metadata_id = read_binary_string(f)
748751
else:
@@ -787,6 +790,7 @@ def recv_results_metadata(self, f, user_type_map):
787790

788791
def recv_prepared_metadata(self, f, protocol_version, user_type_map):
789792
flags = read_int(f)
793+
self.flags = flags
790794
colcount = read_int(f)
791795
pk_indexes = None
792796
if protocol_version >= 4:

cassandra/query.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -345,6 +345,9 @@ def _set_serial_consistency_level(self, serial_consistency_level):
345345
def _del_serial_consistency_level(self):
346346
self._serial_consistency_level = None
347347

348+
def is_lwt(self):
349+
return False
350+
348351
serial_consistency_level = property(
349352
_get_serial_consistency_level,
350353
_set_serial_consistency_level,
@@ -454,10 +457,11 @@ class PreparedStatement(object):
454457
routing_key_indexes = None
455458
_routing_key_index_set = None
456459
serial_consistency_level = None # TODO never used?
460+
_is_lwt = False
457461

458462
def __init__(self, column_metadata, query_id, routing_key_indexes, query,
459463
keyspace, protocol_version, result_metadata, result_metadata_id,
460-
column_encryption_policy=None):
464+
is_lwt=None, column_encryption_policy=None):
461465
self.column_metadata = column_metadata
462466
self.query_id = query_id
463467
self.routing_key_indexes = routing_key_indexes
@@ -468,15 +472,16 @@ def __init__(self, column_metadata, query_id, routing_key_indexes, query,
468472
self.result_metadata_id = result_metadata_id
469473
self.column_encryption_policy = column_encryption_policy
470474
self.is_idempotent = False
475+
self._is_lwt = is_lwt
471476

472477
@classmethod
473478
def from_message(cls, query_id, column_metadata, pk_indexes, cluster_metadata,
474479
query, prepared_keyspace, protocol_version, result_metadata,
475-
result_metadata_id, column_encryption_policy=None):
480+
result_metadata_id, is_lwt, column_encryption_policy=None):
476481
if not column_metadata:
477482
return PreparedStatement(column_metadata, query_id, None,
478483
query, prepared_keyspace, protocol_version, result_metadata,
479-
result_metadata_id, column_encryption_policy)
484+
result_metadata_id, is_lwt, column_encryption_policy)
480485

481486
if pk_indexes:
482487
routing_key_indexes = pk_indexes
@@ -502,7 +507,7 @@ def from_message(cls, query_id, column_metadata, pk_indexes, cluster_metadata,
502507

503508
return PreparedStatement(column_metadata, query_id, routing_key_indexes,
504509
query, prepared_keyspace, protocol_version, result_metadata,
505-
result_metadata_id, column_encryption_policy)
510+
result_metadata_id, is_lwt, column_encryption_policy)
506511

507512
def bind(self, values):
508513
"""
@@ -517,6 +522,9 @@ def is_routing_key_index(self, i):
517522
self._routing_key_index_set = set(self.routing_key_indexes) if self.routing_key_indexes else set()
518523
return i in self._routing_key_index_set
519524

525+
def is_lwt(self):
526+
return self._is_lwt
527+
520528
def __str__(self):
521529
consistency = ConsistencyLevel.value_to_name.get(self.consistency_level, 'Not Set')
522530
return (u'<PreparedStatement query="%s", consistency=%s>' %
@@ -682,6 +690,9 @@ def routing_key(self):
682690

683691
return self._routing_key
684692

693+
def is_lwt(self):
694+
return self.prepared_statement.is_lwt()
695+
685696
def __str__(self):
686697
consistency = ConsistencyLevel.value_to_name.get(self.consistency_level, 'Not Set')
687698
return (u'<BoundStatement query="%s", values=%s, consistency=%s>' %

0 commit comments

Comments
 (0)