Skip to content

Commit dd09f58

Browse files
committed
Optimize column_encryption_policy checks in recv_results_rows
There's no point in checking a global policy for every single value decoded, nor for every row decoded. Adjusted the code to check it only once per recv_results_rows() call: decode_row() is now defined either as it is today (when column_encryption_policy is enabled) or in a much simpler form without all those extra checks. Added a unit test from Copilot. Fixes: #582 Signed-off-by: Yaniv Kaul <[email protected]>
1 parent dd1adc7 commit dd09f58

File tree

2 files changed

+177
-9
lines changed

2 files changed

+177
-9
lines changed

cassandra/protocol.py

Lines changed: 22 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -719,24 +719,37 @@ def recv_results_rows(self, f, protocol_version, user_type_map, result_metadata,
719719
rows = [self.recv_row(f, len(column_metadata)) for _ in range(rowcount)]
720720
self.column_names = [c[2] for c in column_metadata]
721721
self.column_types = [c[3] for c in column_metadata]
722-
col_descs = [ColDesc(md[0], md[1], md[2]) for md in column_metadata]
723722

724-
def decode_val(val, col_md, col_desc):
725-
uses_ce = column_encryption_policy and column_encryption_policy.contains_column(col_desc)
726-
col_type = column_encryption_policy.column_type(col_desc) if uses_ce else col_md[3]
727-
raw_bytes = column_encryption_policy.decrypt(col_desc, val) if uses_ce else val
728-
return col_type.from_binary(raw_bytes, protocol_version)
723+
if column_encryption_policy:
724+
col_descs = [ColDesc(md[0], md[1], md[2]) for md in column_metadata]
729725

730-
def decode_row(row):
731-
return tuple(decode_val(val, col_md, col_desc) for val, col_md, col_desc in zip(row, column_metadata, col_descs))
726+
def decode_val(val, col_md, col_desc):
727+
uses_ce = column_encryption_policy.contains_column(col_desc)
728+
if uses_ce:
729+
col_type = column_encryption_policy.column_type(col_desc)
730+
raw_bytes = column_encryption_policy.decrypt(col_desc, val)
731+
return col_type.from_binary(raw_bytes, protocol_version)
732+
else:
733+
return col_md[3].from_binary(val, protocol_version)
734+
735+
def decode_row(row):
736+
return tuple(decode_val(val, col_md, col_desc) for val, col_md, col_desc in zip(row, column_metadata, col_descs))
737+
else:
738+
def decode_row(row):
739+
return tuple(col_md[3].from_binary(val, protocol_version) for val, col_md in zip(row, column_metadata))
732740

733741
try:
734742
self.parsed_rows = [decode_row(row) for row in rows]
735743
except Exception:
744+
if not column_encryption_policy:
745+
col_descs = [ColDesc(md[0], md[1], md[2]) for md in column_metadata]
736746
for row in rows:
737747
for val, col_md, col_desc in zip(row, column_metadata, col_descs):
738748
try:
739-
decode_val(val, col_md, col_desc)
749+
if column_encryption_policy:
750+
decode_val(val, col_md, col_desc)
751+
else:
752+
col_md[3].from_binary(val, protocol_version)
740753
except Exception as e:
741754
raise DriverException('Failed decoding result column "%s" of type %s: %s' % (col_md[2],
742755
col_md[3].cql_parameterized_type(),
Lines changed: 155 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,155 @@
1+
# Copyright DataStax, Inc.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import io
16+
import unittest
17+
from unittest.mock import Mock
18+
19+
from cassandra import ProtocolVersion
20+
from cassandra.cqltypes import Int32Type, UTF8Type
21+
from cassandra.marshal import int32_pack
22+
from cassandra.policies import ColDesc
23+
from cassandra.protocol import ResultMessage, RESULT_KIND_ROWS
24+
25+
26+
class DecodeOptimizationTest(unittest.TestCase):
27+
"""
28+
Tests to verify the optimization of column_encryption_policy checks
29+
in recv_results_rows. The optimization checks if the policy exists once
30+
per result message, avoiding the redundant 'column_encryption_policy and ...'
31+
check for every value.
32+
"""
33+
34+
def _create_mock_result_metadata(self):
35+
"""Create mock result metadata for testing"""
36+
return [
37+
('keyspace1', 'table1', 'col1', Int32Type),
38+
('keyspace1', 'table1', 'col2', UTF8Type),
39+
]
40+
41+
def _create_mock_result_message(self):
42+
"""Create a mock result message with data"""
43+
msg = ResultMessage(kind=RESULT_KIND_ROWS)
44+
msg.column_metadata = self._create_mock_result_metadata()
45+
msg.recv_results_metadata = Mock()
46+
msg.recv_row = Mock(side_effect=[
47+
[int32_pack(42), b'hello'],
48+
[int32_pack(100), b'world'],
49+
])
50+
return msg
51+
52+
def _create_mock_stream(self):
53+
"""Create a mock stream for reading rows"""
54+
# Pack rowcount (2 rows)
55+
data = int32_pack(2)
56+
return io.BytesIO(data)
57+
58+
def test_decode_without_encryption_policy(self):
59+
"""
60+
Test that decoding works correctly without column encryption policy.
61+
This should use the optimized simple path.
62+
"""
63+
msg = self._create_mock_result_message()
64+
f = self._create_mock_stream()
65+
66+
msg.recv_results_rows(f, ProtocolVersion.V4, {}, None, None)
67+
68+
# Verify results
69+
self.assertEqual(len(msg.parsed_rows), 2)
70+
self.assertEqual(msg.parsed_rows[0][0], 42)
71+
self.assertEqual(msg.parsed_rows[0][1], 'hello')
72+
self.assertEqual(msg.parsed_rows[1][0], 100)
73+
self.assertEqual(msg.parsed_rows[1][1], 'world')
74+
75+
def test_decode_with_encryption_policy_no_encrypted_columns(self):
76+
"""
77+
Test that decoding works with encryption policy when no columns are encrypted.
78+
"""
79+
msg = self._create_mock_result_message()
80+
f = self._create_mock_stream()
81+
82+
# Create mock encryption policy that has no encrypted columns
83+
mock_policy = Mock()
84+
mock_policy.contains_column = Mock(return_value=False)
85+
86+
msg.recv_results_rows(f, ProtocolVersion.V4, {}, None, mock_policy)
87+
88+
# Verify results
89+
self.assertEqual(len(msg.parsed_rows), 2)
90+
self.assertEqual(msg.parsed_rows[0][0], 42)
91+
self.assertEqual(msg.parsed_rows[0][1], 'hello')
92+
93+
# Verify contains_column was called for each value (but policy existence check happens once)
94+
# Should be called 4 times (2 rows × 2 columns)
95+
self.assertEqual(mock_policy.contains_column.call_count, 4)
96+
97+
def test_decode_with_encryption_policy_with_encrypted_column(self):
98+
"""
99+
Test that decoding works with encryption policy when one column is encrypted.
100+
"""
101+
msg = self._create_mock_result_message()
102+
f = self._create_mock_stream()
103+
104+
# Create mock encryption policy where first column is encrypted
105+
mock_policy = Mock()
106+
def contains_column_side_effect(col_desc):
107+
return col_desc.col == 'col1'
108+
mock_policy.contains_column = Mock(side_effect=contains_column_side_effect)
109+
mock_policy.column_type = Mock(return_value=Int32Type)
110+
mock_policy.decrypt = Mock(side_effect=lambda col_desc, val: val)
111+
112+
msg.recv_results_rows(f, ProtocolVersion.V4, {}, None, mock_policy)
113+
114+
# Verify results
115+
self.assertEqual(len(msg.parsed_rows), 2)
116+
self.assertEqual(msg.parsed_rows[0][0], 42)
117+
self.assertEqual(msg.parsed_rows[0][1], 'hello')
118+
119+
# Verify contains_column was called for each value (but policy existence check happens once)
120+
# Should be called 4 times (2 rows × 2 columns)
121+
self.assertEqual(mock_policy.contains_column.call_count, 4)
122+
123+
# Verify decrypt was called for each encrypted value (2 rows * 1 encrypted column)
124+
self.assertEqual(mock_policy.decrypt.call_count, 2)
125+
126+
def test_optimization_efficiency(self):
127+
"""
128+
Verify that the optimization checks policy existence once per result message.
129+
The key optimization is checking 'if column_encryption_policy:' once,
130+
rather than 'column_encryption_policy and ...' for every value.
131+
"""
132+
msg = self._create_mock_result_message()
133+
134+
# Create more rows to make the check pattern clear
135+
msg.recv_row = Mock(side_effect=[
136+
[int32_pack(i), f'text{i}'.encode()] for i in range(100)
137+
])
138+
139+
# Create mock stream with 100 rows
140+
f = io.BytesIO(int32_pack(100))
141+
142+
mock_policy = Mock()
143+
mock_policy.contains_column = Mock(return_value=False)
144+
145+
msg.recv_results_rows(f, ProtocolVersion.V4, {}, None, mock_policy)
146+
147+
# With optimization: policy existence checked once, contains_column called per value
148+
# = 100 rows * 2 columns = 200 calls to contains_column
149+
# The key is we avoid checking 'column_encryption_policy and ...' 200 times
150+
self.assertEqual(mock_policy.contains_column.call_count, 200,
151+
"contains_column should be called for each value when policy exists")
152+
153+
154+
if __name__ == '__main__':
155+
unittest.main()

0 commit comments

Comments
 (0)