Skip to content

Commit 01eb2e8

Browse files
committed
Add support of LWT flag for Batch statements
1 parent 390e165 commit 01eb2e8

File tree

2 files changed

+58
-1
lines changed

2 files changed

+58
-1
lines changed

cassandra/query.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -761,6 +761,7 @@ class BatchStatement(Statement):
761761

762762
_statements_and_parameters = None
763763
_session = None
764+
_is_lwt = False
764765

765766
def __init__(self, batch_type=BatchType.LOGGED, retry_policy=None,
766767
consistency_level=None, serial_consistency_level=None,
@@ -845,13 +846,17 @@ def add(self, statement, parameters=None):
845846
query_id = statement.query_id
846847
bound_statement = statement.bind(() if parameters is None else parameters)
847848
self._update_state(bound_statement)
849+
if statement.is_lwt():
850+
self._is_lwt = True
848851
self._add_statement_and_params(True, query_id, bound_statement.values)
849852
elif isinstance(statement, BoundStatement):
850853
if parameters:
851854
raise ValueError(
852855
"Parameters cannot be passed with a BoundStatement "
853856
"to BatchStatement.add()")
854857
self._update_state(statement)
858+
if statement.is_lwt():
859+
self._is_lwt = True
855860
self._add_statement_and_params(True, statement.prepared_statement.query_id, statement.values)
856861
else:
857862
# it must be a SimpleStatement
@@ -860,6 +865,8 @@ def add(self, statement, parameters=None):
860865
encoder = Encoder() if self._session is None else self._session.encoder
861866
query_string = bind_params(query_string, parameters, encoder)
862867
self._update_state(statement)
868+
if statement.is_lwt():
869+
self._is_lwt = True
863870
self._add_statement_and_params(False, query_string, ())
864871
return self
865872

@@ -893,6 +900,9 @@ def _update_state(self, statement):
893900
self._maybe_set_routing_attributes(statement)
894901
self._update_custom_payload(statement)
895902

903+
def is_lwt(self):
904+
return self._is_lwt
905+
896906
def __len__(self):
897907
return len(self._statements_and_parameters)
898908

tests/unit/test_query.py

Lines changed: 48 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
import unittest
1616

17-
from cassandra.query import BatchStatement, SimpleStatement
17+
from cassandra.query import BatchStatement, PreparedStatement, SimpleStatement
1818

1919

2020
class BatchStatementTest(unittest.TestCase):
@@ -68,3 +68,50 @@ def test_len(self):
6868
batch.add_all(statements=['%s'] * n,
6969
parameters=[(i,) for i in range(n)])
7070
assert len(batch) == n
71+
72+
def _make_prepared_statement(self, is_lwt=False):
73+
return PreparedStatement(
74+
column_metadata=[],
75+
query_id=b"query-id",
76+
routing_key_indexes=[],
77+
query="INSERT INTO test.table (id) VALUES (1)",
78+
keyspace=None,
79+
protocol_version=4,
80+
result_metadata=[],
81+
result_metadata_id=None,
82+
is_lwt=is_lwt,
83+
)
84+
85+
def test_is_lwt_false_for_non_lwt_statements(self):
86+
batch = BatchStatement()
87+
batch.add(self._make_prepared_statement(is_lwt=False))
88+
batch.add(self._make_prepared_statement(is_lwt=False).bind(()))
89+
batch.add(SimpleStatement("INSERT INTO test.table (id) VALUES (3)"))
90+
batch.add("INSERT INTO test.table (id) VALUES (4)")
91+
assert batch.is_lwt() is False
92+
93+
def test_is_lwt_propagates_from_statements(self):
94+
batch = BatchStatement()
95+
batch.add(self._make_prepared_statement(is_lwt=False))
96+
assert batch.is_lwt() is False
97+
98+
batch.add(self._make_prepared_statement(is_lwt=True))
99+
assert batch.is_lwt() is True
100+
101+
bound_lwt = self._make_prepared_statement(is_lwt=True).bind(())
102+
batch_with_bound = BatchStatement()
103+
batch_with_bound.add(bound_lwt)
104+
assert batch_with_bound.is_lwt() is True
105+
106+
class LwtSimpleStatement(SimpleStatement):
107+
def __init__(self):
108+
super(LwtSimpleStatement, self).__init__(
109+
"INSERT INTO test.table (id) VALUES (2) IF NOT EXISTS"
110+
)
111+
112+
def is_lwt(self):
113+
return True
114+
115+
batch_with_simple = BatchStatement()
116+
batch_with_simple.add(LwtSimpleStatement())
117+
assert batch_with_simple.is_lwt() is True

0 commit comments

Comments
 (0)