|
14 | 14 |
|
15 | 15 | import unittest |
16 | 16 |
|
17 | | -from cassandra.query import BatchStatement, SimpleStatement |
| 17 | +from cassandra.query import BatchStatement, PreparedStatement, SimpleStatement |
18 | 18 |
|
19 | 19 |
|
20 | 20 | class BatchStatementTest(unittest.TestCase): |
@@ -68,3 +68,50 @@ def test_len(self): |
68 | 68 | batch.add_all(statements=['%s'] * n, |
69 | 69 | parameters=[(i,) for i in range(n)]) |
70 | 70 | assert len(batch) == n |
| 71 | + |
| 72 | + def _make_prepared_statement(self, is_lwt=False): |
| 73 | + return PreparedStatement( |
| 74 | + column_metadata=[], |
| 75 | + query_id=b"query-id", |
| 76 | + routing_key_indexes=[], |
| 77 | + query="INSERT INTO test.table (id) VALUES (1)", |
| 78 | + keyspace=None, |
| 79 | + protocol_version=4, |
| 80 | + result_metadata=[], |
| 81 | + result_metadata_id=None, |
| 82 | + is_lwt=is_lwt, |
| 83 | + ) |
| 84 | + |
| 85 | + def test_is_lwt_false_for_non_lwt_statements(self): |
| 86 | + batch = BatchStatement() |
| 87 | + batch.add(self._make_prepared_statement(is_lwt=False)) |
| 88 | + batch.add(self._make_prepared_statement(is_lwt=False).bind(())) |
| 89 | + batch.add(SimpleStatement("INSERT INTO test.table (id) VALUES (3)")) |
| 90 | + batch.add("INSERT INTO test.table (id) VALUES (4)") |
| 91 | + assert batch.is_lwt() is False |
| 92 | + |
| 93 | + def test_is_lwt_propagates_from_statements(self): |
| 94 | + batch = BatchStatement() |
| 95 | + batch.add(self._make_prepared_statement(is_lwt=False)) |
| 96 | + assert batch.is_lwt() is False |
| 97 | + |
| 98 | + batch.add(self._make_prepared_statement(is_lwt=True)) |
| 99 | + assert batch.is_lwt() is True |
| 100 | + |
| 101 | + bound_lwt = self._make_prepared_statement(is_lwt=True).bind(()) |
| 102 | + batch_with_bound = BatchStatement() |
| 103 | + batch_with_bound.add(bound_lwt) |
| 104 | + assert batch_with_bound.is_lwt() is True |
| 105 | + |
| 106 | + class LwtSimpleStatement(SimpleStatement): |
| 107 | + def __init__(self): |
| 108 | + super(LwtSimpleStatement, self).__init__( |
| 109 | + "INSERT INTO test.table (id) VALUES (2) IF NOT EXISTS" |
| 110 | + ) |
| 111 | + |
| 112 | + def is_lwt(self): |
| 113 | + return True |
| 114 | + |
| 115 | + batch_with_simple = BatchStatement() |
| 116 | + batch_with_simple.add(LwtSimpleStatement()) |
| 117 | + assert batch_with_simple.is_lwt() is True |
0 commit comments