Skip to content

Commit 2fc623a

Browse files
Copilot and mykaul committed
Optimize column_encryption_policy checks in recv_results_rows
- Check column_encryption_policy once per recv_results_rows call instead of per value
- Create two separate decode paths: one with encryption, one without
- Pre-compute encryption info per column to avoid repeated lookups
- Add comprehensive unit tests to verify optimization

Co-authored-by: mykaul <[email protected]>
1 parent 89ac019 commit 2fc623a

File tree

2 files changed

+205
-20
lines changed

2 files changed

+205
-20
lines changed

cassandra/protocol.py

Lines changed: 53 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -720,26 +720,59 @@ def recv_results_rows(self, f, protocol_version, user_type_map, result_metadata,
720720
self.column_types = [c[3] for c in column_metadata]
721721
col_descs = [ColDesc(md[0], md[1], md[2]) for md in column_metadata]
722722

723-
def decode_val(val, col_md, col_desc):
724-
uses_ce = column_encryption_policy and column_encryption_policy.contains_column(col_desc)
725-
col_type = column_encryption_policy.column_type(col_desc) if uses_ce else col_md[3]
726-
raw_bytes = column_encryption_policy.decrypt(col_desc, val) if uses_ce else val
727-
return col_type.from_binary(raw_bytes, protocol_version)
728-
729-
def decode_row(row):
730-
return tuple(decode_val(val, col_md, col_desc) for val, col_md, col_desc in zip(row, column_metadata, col_descs))
731-
732-
try:
733-
self.parsed_rows = [decode_row(row) for row in rows]
734-
except Exception:
735-
for row in rows:
736-
for val, col_md, col_desc in zip(row, column_metadata, col_descs):
737-
try:
738-
decode_val(val, col_md, col_desc)
739-
except Exception as e:
740-
raise DriverException('Failed decoding result column "%s" of type %s: %s' % (col_md[2],
741-
col_md[3].cql_parameterized_type(),
742-
str(e)))
723+
# Optimize by checking column_encryption_policy once and creating appropriate decode path
724+
if column_encryption_policy:
725+
# Pre-compute encryption info for each column to avoid repeated lookups
726+
column_encryption_info = [
727+
(column_encryption_policy.contains_column(col_desc), col_desc)
728+
for col_desc in col_descs
729+
]
730+
731+
def decode_val_with_encryption(val, col_md, uses_ce, col_desc):
732+
if uses_ce:
733+
col_type = column_encryption_policy.column_type(col_desc)
734+
raw_bytes = column_encryption_policy.decrypt(col_desc, val)
735+
else:
736+
col_type = col_md[3]
737+
raw_bytes = val
738+
return col_type.from_binary(raw_bytes, protocol_version)
739+
740+
def decode_row(row):
741+
return tuple(
742+
decode_val_with_encryption(val, col_md, uses_ce, col_desc)
743+
for val, col_md, (uses_ce, col_desc) in zip(row, column_metadata, column_encryption_info)
744+
)
745+
746+
try:
747+
self.parsed_rows = [decode_row(row) for row in rows]
748+
except Exception:
749+
for row in rows:
750+
for val, col_md, (uses_ce, col_desc) in zip(row, column_metadata, column_encryption_info):
751+
try:
752+
decode_val_with_encryption(val, col_md, uses_ce, col_desc)
753+
except Exception as e:
754+
raise DriverException('Failed decoding result column "%s" of type %s: %s' % (col_md[2],
755+
col_md[3].cql_parameterized_type(),
756+
str(e)))
757+
else:
758+
# Simple path without encryption - just decode raw bytes directly
759+
def decode_val_simple(val, col_type):
760+
return col_type.from_binary(val, protocol_version)
761+
762+
def decode_row(row):
763+
return tuple(decode_val_simple(val, col_md[3]) for val, col_md in zip(row, column_metadata))
764+
765+
try:
766+
self.parsed_rows = [decode_row(row) for row in rows]
767+
except Exception:
768+
for row in rows:
769+
for val, col_md in zip(row, column_metadata):
770+
try:
771+
decode_val_simple(val, col_md[3])
772+
except Exception as e:
773+
raise DriverException('Failed decoding result column "%s" of type %s: %s' % (col_md[2],
774+
col_md[3].cql_parameterized_type(),
775+
str(e)))
743776

744777
def recv_results_prepared(self, f, protocol_version, user_type_map):
745778
self.query_id = read_binary_string(f)
Lines changed: 152 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,152 @@
1+
# Copyright DataStax, Inc.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import unittest
16+
from unittest.mock import Mock, MagicMock
17+
import io
18+
19+
from cassandra import ProtocolVersion
20+
from cassandra.protocol import ResultMessage, RESULT_KIND_ROWS
21+
from cassandra.cqltypes import Int32Type, UTF8Type
22+
from cassandra.policies import ColDesc
23+
from cassandra.marshal import int32_pack
24+
25+
26+
class DecodeOptimizationTest(unittest.TestCase):
    """
    Exercise ResultMessage.recv_results_rows with and without a column
    encryption policy.

    The behaviour pinned down here is that the policy's contains_column is
    consulted once per column for a whole result set, not once per decoded
    value.
    """

    def _create_mock_result_metadata(self):
        """Return column metadata tuples for an int column and a text column"""
        columns = (('col1', Int32Type), ('col2', UTF8Type))
        return [('keyspace1', 'table1', name, cql_type) for name, cql_type in columns]

    def _create_mock_result_message(self):
        """Build a ROWS result message whose metadata and row reads are stubbed"""
        canned_rows = [
            [int32_pack(42), b'hello'],
            [int32_pack(100), b'world'],
        ]
        message = ResultMessage(kind=RESULT_KIND_ROWS)
        message.column_metadata = self._create_mock_result_metadata()
        # Metadata is assigned directly above, so reading it from the stream
        # is replaced by a no-op mock.
        message.recv_results_metadata = Mock()
        message.recv_row = Mock(side_effect=canned_rows)
        return message

    def _create_mock_stream(self):
        """Return a byte stream that holds only the packed row count (2)"""
        return io.BytesIO(int32_pack(2))

    def _policy_without_encrypted_columns(self):
        """Return a mock policy whose contains_column always answers False"""
        policy = Mock()
        policy.contains_column = Mock(return_value=False)
        return policy

    def test_decode_without_encryption_policy(self):
        """
        Rows decode correctly when no column encryption policy is supplied.
        """
        stream = self._create_mock_stream()
        message = self._create_mock_result_message()

        message.recv_results_rows(stream, ProtocolVersion.V4, {}, None, None)

        decoded = message.parsed_rows
        self.assertEqual(len(decoded), 2)
        self.assertEqual(decoded[0][0], 42)
        self.assertEqual(decoded[0][1], 'hello')
        self.assertEqual(decoded[1][0], 100)
        self.assertEqual(decoded[1][1], 'world')

    def test_decode_with_encryption_policy_no_encrypted_columns(self):
        """
        Rows decode correctly when a policy is present but covers no column.
        """
        stream = self._create_mock_stream()
        message = self._create_mock_result_message()
        policy = self._policy_without_encrypted_columns()

        message.recv_results_rows(stream, ProtocolVersion.V4, {}, None, policy)

        decoded = message.parsed_rows
        self.assertEqual(len(decoded), 2)
        self.assertEqual(decoded[0][0], 42)
        self.assertEqual(decoded[0][1], 'hello')

        # Two columns -> two lookups, regardless of how many rows were read.
        self.assertEqual(policy.contains_column.call_count, 2)

    def test_decode_with_encryption_policy_with_encrypted_column(self):
        """
        Rows decode correctly when the policy marks exactly one column as encrypted.
        """
        stream = self._create_mock_stream()
        message = self._create_mock_result_message()

        # Only col1 is reported as encrypted; decrypt hands the bytes back
        # untouched so the regular Int32Type decoding still applies.
        policy = Mock()
        policy.contains_column = Mock(side_effect=lambda col_desc: col_desc.col == 'col1')
        policy.column_type = Mock(return_value=Int32Type)
        policy.decrypt = Mock(side_effect=lambda col_desc, val: val)

        message.recv_results_rows(stream, ProtocolVersion.V4, {}, None, policy)

        decoded = message.parsed_rows
        self.assertEqual(len(decoded), 2)
        self.assertEqual(decoded[0][0], 42)
        self.assertEqual(decoded[0][1], 'hello')

        # One lookup per column.
        self.assertEqual(policy.contains_column.call_count, 2)

        # One decrypt per encrypted value: 2 rows * 1 encrypted column.
        self.assertEqual(policy.decrypt.call_count, 2)

    def test_optimization_efficiency(self):
        """
        The number of contains_column lookups does not grow with the row count.
        A per-value check would need 100 rows * 2 columns = 200 lookups here;
        a per-column check needs 2.
        """
        message = self._create_mock_result_message()

        # Swap the two canned rows for a hundred, and announce the same
        # count in the stream.
        many_rows = [
            [int32_pack(i), f'text{i}'.encode()] for i in range(100)
        ]
        message.recv_row = Mock(side_effect=many_rows)
        stream = io.BytesIO(int32_pack(100))

        policy = self._policy_without_encrypted_columns()

        message.recv_results_rows(stream, ProtocolVersion.V4, {}, None, policy)

        self.assertEqual(policy.contains_column.call_count, 2,
                         "Optimization failed: contains_column should be called once per column, not per value")
149+
150+
151+
if __name__ == '__main__':
152+
unittest.main()

0 commit comments

Comments
 (0)