Skip to content

Commit 050ead4

Browse files
absurdfarce authored and dkropachev committed
PYTHON-1341 Impl of client-side column-level encryption/decryption (datastax#1150)
1 parent 5a1e0da commit 050ead4

15 files changed

+573
-48
lines changed

cassandra/cluster.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1034,6 +1034,12 @@ def default_retry_policy(self, policy):
10341034
or to disable the shardaware port (advanced shardaware)
10351035
"""
10361036

1037+
column_encryption_policy = None
1038+
"""
1039+
An instance of :class:`cassandra.policies.ColumnEncryptionPolicy` specifying encryption materials to be
1040+
used for columns in this cluster.
1041+
"""
1042+
10371043
metadata_request_timeout = datetime.timedelta(seconds=2)
10381044
"""
10391045
Timeout for all queries used by driver it self.
@@ -1157,6 +1163,7 @@ def __init__(self,
11571163
scylla_cloud=None,
11581164
shard_aware_options=None,
11591165
metadata_request_timeout=None,
1166+
column_encryption_policy=None,
11601167
):
11611168
"""
11621169
``executor_threads`` defines the number of threads in a pool for handling asynchronous tasks such as
@@ -1234,6 +1241,9 @@ def __init__(self,
12341241

12351242
self.port = port
12361243

1244+
if column_encryption_policy is not None:
1245+
self.column_encryption_policy = column_encryption_policy
1246+
12371247
self.endpoint_factory = endpoint_factory or DefaultEndPointFactory(port=self.port)
12381248
self.endpoint_factory.configure(self)
12391249

@@ -2658,6 +2668,12 @@ def __init__(self, cluster, hosts, keyspace=None):
26582668

26592669
self.encoder = Encoder()
26602670

2671+
if self.cluster.column_encryption_policy is not None:
2672+
try:
2673+
self.client_protocol_handler.column_encryption_policy = self.cluster.column_encryption_policy
2674+
except AttributeError:
2675+
log.info("Unable to set column encryption policy for session")
2676+
26612677
# create connection pools in parallel
26622678
self._initial_connect_futures = set()
26632679
for host in hosts:
@@ -3197,7 +3213,7 @@ def prepare(self, query, custom_payload=None, keyspace=None):
31973213
prepared_keyspace = keyspace if keyspace else None
31983214
prepared_statement = PreparedStatement.from_message(
31993215
response.query_id, response.bind_metadata, response.pk_indexes, self.cluster.metadata, query, prepared_keyspace,
3200-
self._protocol_version, response.column_metadata, response.result_metadata_id)
3216+
self._protocol_version, response.column_metadata, response.result_metadata_id, self.cluster.column_encryption_policy)
32013217
prepared_statement.custom_payload = future.custom_payload
32023218

32033219
self.cluster.add_prepared(response.query_id, prepared_statement)

cassandra/obj_parser.pyx

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,12 @@ include "ioutils.pyx"
1717
from cassandra import DriverException
1818
from cassandra.bytesio cimport BytesIOReader
1919
from cassandra.deserializers cimport Deserializer, from_binary
20+
from cassandra.deserializers import find_deserializer
2021
from cassandra.parsing cimport ParseDesc, ColumnParser, RowParser
2122
from cassandra.tuple cimport tuple_new, tuple_set
2223

24+
from cpython.bytes cimport PyBytes_AsStringAndSize
25+
2326

2427
cdef class ListParser(ColumnParser):
2528
"""Decode a ResultMessage into a list of tuples (or other objects)"""
@@ -58,18 +61,29 @@ cdef class TupleRowParser(RowParser):
5861
assert desc.rowsize >= 0
5962

6063
cdef Buffer buf
64+
cdef Buffer newbuf
6165
cdef Py_ssize_t i, rowsize = desc.rowsize
6266
cdef Deserializer deserializer
6367
cdef tuple res = tuple_new(desc.rowsize)
6468

69+
ce_policy = desc.column_encryption_policy
6570
for i in range(rowsize):
6671
# Read the next few bytes
6772
get_buf(reader, &buf)
6873

6974
# Deserialize bytes to python object
7075
deserializer = desc.deserializers[i]
76+
coldesc = desc.coldescs[i]
77+
uses_ce = ce_policy and ce_policy.contains_column(coldesc)
7178
try:
72-
val = from_binary(deserializer, &buf, desc.protocol_version)
79+
if uses_ce:
80+
col_type = ce_policy.column_type(coldesc)
81+
decrypted_bytes = ce_policy.decrypt(coldesc, to_bytes(&buf))
82+
PyBytes_AsStringAndSize(decrypted_bytes, &newbuf.ptr, &newbuf.size)
83+
deserializer = find_deserializer(ce_policy.column_type(coldesc))
84+
val = from_binary(deserializer, &newbuf, desc.protocol_version)
85+
else:
86+
val = from_binary(deserializer, &buf, desc.protocol_version)
7387
except Exception as e:
7488
raise DriverException('Failed decoding result column "%s" of type %s: %s' % (desc.colnames[i],
7589
desc.coltypes[i].cql_parameterized_type(),

cassandra/parsing.pxd

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@ from cassandra.deserializers cimport Deserializer
1818
cdef class ParseDesc:
1919
cdef public object colnames
2020
cdef public object coltypes
21+
cdef public object column_encryption_policy
22+
cdef public list coldescs
2123
cdef Deserializer[::1] deserializers
2224
cdef public int protocol_version
2325
cdef Py_ssize_t rowsize

cassandra/parsing.pyx

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,11 @@ Module containing the definitions and declarations (parsing.pxd) for parsers.
1919
cdef class ParseDesc:
2020
"""Description of what structure to parse"""
2121

22-
def __init__(self, colnames, coltypes, deserializers, protocol_version):
22+
def __init__(self, colnames, coltypes, column_encryption_policy, coldescs, deserializers, protocol_version):
2323
self.colnames = colnames
2424
self.coltypes = coltypes
25+
self.column_encryption_policy = column_encryption_policy
26+
self.coldescs = coldescs
2527
self.deserializers = deserializers
2628
self.protocol_version = protocol_version
2729
self.rowsize = len(colnames)

cassandra/policies.py

Lines changed: 177 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -12,14 +12,22 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
import random
15+
16+
from collections import namedtuple
17+
from functools import lru_cache
1518
from itertools import islice, cycle, groupby, repeat
1619
import logging
20+
import os
1721
from random import randint, shuffle
1822
from threading import Lock
1923
import socket
2024
import warnings
25+
26+
from cryptography.hazmat.primitives import padding
27+
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
28+
2129
from cassandra import WriteType as WT
22-
from cassandra.connection import UnixSocketEndPoint
30+
from cassandra.cqltypes import _cqltypes
2331

2432

2533
# This is done this way because WriteType was originally
@@ -572,8 +580,9 @@ def __init__(self, hosts):
572580
self._allowed_hosts = tuple(hosts)
573581
self._allowed_hosts_resolved = []
574582
for h in self._allowed_hosts:
575-
if isinstance(h, UnixSocketEndPoint):
576-
self._allowed_hosts_resolved.append(h._unix_socket_path)
583+
unix_socket_path = getattr(h, "_unix_socket_path", None)
584+
if unix_socket_path:
585+
self._allowed_hosts_resolved.append(unix_socket_path)
577586
else:
578587
self._allowed_hosts_resolved.extend([endpoint[4][0]
579588
for endpoint in socket.getaddrinfo(h, None, socket.AF_UNSPEC, socket.SOCK_STREAM)])
@@ -608,7 +617,7 @@ class HostFilterPolicy(LoadBalancingPolicy):
608617
A :class:`.LoadBalancingPolicy` subclass configured with a child policy,
609618
and a single-argument predicate. This policy defers to the child policy for
610619
hosts where ``predicate(host)`` is truthy. Hosts for which
611-
``predicate(host)`` is falsey will be considered :attr:`.IGNORED`, and will
620+
``predicate(host)`` is falsy will be considered :attr:`.IGNORED`, and will
612621
not be used in a query plan.
613622
614623
This can be used in the cases where you need a whitelist or blacklist
@@ -644,7 +653,7 @@ def __init__(self, child_policy, predicate):
644653
:param child_policy: an instantiated :class:`.LoadBalancingPolicy`
645654
that this one will defer to.
646655
:param predicate: a one-parameter function that takes a :class:`.Host`.
647-
If it returns a falsey value, the :class:`.Host` will
656+
If it returns a falsy value, the :class:`.Host` will
648657
be :attr:`.IGNORED` and not returned in query plans.
649658
"""
650659
super(HostFilterPolicy, self).__init__()
@@ -680,7 +689,7 @@ def predicate(self):
680689
def distance(self, host):
681690
"""
682691
Checks if ``predicate(host)``, then returns
683-
:attr:`~HostDistance.IGNORED` if falsey, and defers to the child policy
692+
:attr:`~HostDistance.IGNORED` if falsy, and defers to the child policy
684693
otherwise.
685694
"""
686695
if self.predicate(host):
@@ -769,7 +778,7 @@ class ReconnectionPolicy(object):
769778
def new_schedule(self):
770779
"""
771780
This should return a finite or infinite iterable of delays (each as a
772-
floating point number of seconds) inbetween each failed reconnection
781+
floating point number of seconds) in-between each failed reconnection
773782
attempt. Note that if the iterable is finite, reconnection attempts
774783
will cease once the iterable is exhausted.
775784
"""
@@ -779,12 +788,12 @@ def new_schedule(self):
779788
class ConstantReconnectionPolicy(ReconnectionPolicy):
780789
"""
781790
A :class:`.ReconnectionPolicy` subclass which sleeps for a fixed delay
782-
inbetween each reconnection attempt.
791+
in-between each reconnection attempt.
783792
"""
784793

785794
def __init__(self, delay, max_attempts=64):
786795
"""
787-
`delay` should be a floating point number of seconds to wait inbetween
796+
`delay` should be a floating point number of seconds to wait in-between
788797
each attempt.
789798
790799
`max_attempts` should be a total number of attempts to be made before
@@ -808,7 +817,7 @@ def new_schedule(self):
808817
class ExponentialReconnectionPolicy(ReconnectionPolicy):
809818
"""
810819
A :class:`.ReconnectionPolicy` subclass which exponentially increases
811-
the length of the delay inbetween each reconnection attempt up to
820+
the length of the delay in-between each reconnection attempt up to
812821
a set maximum delay.
813822
814823
A random amount of jitter (+/- 15%) will be added to the pure exponential
@@ -868,7 +877,7 @@ class RetryPolicy(object):
868877
timeout and unavailable failures. These are failures reported from the
869878
server side. Timeouts are configured by
870879
`settings in cassandra.yaml <https://github.com/apache/cassandra/blob/cassandra-2.1.4/conf/cassandra.yaml#L568-L584>`_.
871-
Unavailable failures occur when the coordinator cannot acheive the consistency
880+
Unavailable failures occur when the coordinator cannot achieve the consistency
872881
level for a request. For further information see the method descriptions
873882
below.
874883
@@ -1385,3 +1394,160 @@ def _rethrow(self, *args, **kwargs):
13851394
on_read_timeout = _rethrow
13861395
on_write_timeout = _rethrow
13871396
on_unavailable = _rethrow
1397+
1398+
1399+
# Identifies a single column: keyspace name, table name and column name.
ColDesc = namedtuple(
    'ColDesc',
    ('ks', 'table', 'col'),
)

# Per-column encryption configuration: the key material and the column's type.
ColData = namedtuple(
    'ColData',
    ('key', 'type'),
)
1401+
1402+
class ColumnEncryptionPolicy(object):
    """
    A policy enabling (mostly) transparent encryption and decryption of data before it is
    sent to the cluster.

    Key materials and other configurations are specified on a per-column basis.  This policy can
    then be used by driver structures which are aware of the underlying columns involved in their
    work.  In practice this includes the following cases:

    * Prepared statements - data for columns specified by the cluster's policy will be transparently
      encrypted before they are sent
    * Rows returned from any query - data for columns specified by the cluster's policy will be
      transparently decrypted before they are returned to the user

    To enable this functionality, create an instance of this class (or more likely a subclass)
    before creating a cluster.  This policy should then be configured and supplied to the Cluster
    at creation time via the :attr:`.Cluster.column_encryption_policy` attribute.

    Every method of this base class raises :exc:`NotImplementedError`; subclasses are expected
    to override all of them.
    """

    def encrypt(self, coldesc, obj_bytes):
        """
        Encrypt the specified bytes using the cryptography materials for the specified column.
        Largely used internally, although this could also be used to encrypt values supplied
        to non-prepared statements in a way that is consistent with this policy.
        """
        raise NotImplementedError()

    def decrypt(self, coldesc, encrypted_bytes):
        """
        Decrypt the specified (encrypted) bytes using the cryptography materials for the
        specified column.  Used internally; could be used externally as well but there's
        not currently an obvious use case.
        """
        raise NotImplementedError()

    def add_column(self, coldesc, key):
        """
        Provide cryptography materials to be used when encrypting and/or decrypting data
        for the specified column.
        """
        raise NotImplementedError()

    def contains_column(self, coldesc):
        """
        Predicate to determine if a specific column is supported by this policy.
        Currently only used internally.
        """
        raise NotImplementedError()

    def column_type(self, coldesc):
        """
        Return the type registered for the specified column.  The row parser asks the
        policy for this type in order to pick a deserializer for the decrypted bytes,
        so every concrete policy has to provide it.
        """
        raise NotImplementedError()

    def encode_and_encrypt(self, coldesc, obj):
        """
        Helper function to enable use of this policy on simple (i.e. non-prepared)
        statements.
        """
        raise NotImplementedError()
1457+
1458+
# AES block size and AES-256 key size, expressed in bits and in bytes.
_BITS_PER_BYTE = 8
AES256_BLOCK_SIZE = 128
AES256_BLOCK_SIZE_BYTES = AES256_BLOCK_SIZE // _BITS_PER_BYTE
AES256_KEY_SIZE = 256
AES256_KEY_SIZE_BYTES = AES256_KEY_SIZE // _BITS_PER_BYTE
1462+
1463+
class AES256ColumnEncryptionPolicy(ColumnEncryptionPolicy):
    """
    A :class:`.ColumnEncryptionPolicy` which encrypts column values with AES using
    256-bit keys.  Values are PKCS7-padded to the AES block size before encryption
    and unpadded after decryption.

    Columns are registered with :meth:`add_column`, which records the key and the
    type for each column.  The same ``mode`` and ``iv`` are used for every column
    handled by a given policy instance.
    """

    # CBC uses an IV that's the same size as the block size
    #
    # TODO: Need to find some way to expose mode options
    # (CBC etc.) without leaking classes from the underlying
    # impl here
    def __init__(self, mode=modes.CBC, iv=None):
        """
        :param mode: callable taking the IV and returning the cipher mode object
                     (defaults to ``modes.CBC``).
        :param iv: initialization vector handed to ``mode``.  When omitted, a random
                   IV of one AES block is generated for this instance.  Pass an
                   explicit value if the IV has to be known outside this instance.
        """
        self.mode = mode

        # The random IV must be generated here rather than in the signature: a default
        # argument is evaluated only once, when the function is defined, which would
        # make every policy instance in the process silently share one IV.
        self.iv = os.urandom(AES256_BLOCK_SIZE_BYTES) if iv is None else iv

        # ColData for a given ColDesc is always preserved.  We only create a Cipher
        # when there's an actual need to for a given ColDesc
        self.coldata = {}
        self.ciphers = {}

    def encrypt(self, coldesc, obj_bytes):
        """
        Pad ``obj_bytes`` with PKCS7 and encrypt them with the cipher for ``coldesc``.

        :raises ValueError: if ``coldesc`` was never registered via :meth:`add_column`.
        """
        # AES256 has a 128-bit block size so if the input bytes don't align perfectly on
        # those blocks we have to pad them.  There's plenty of room for optimization here:
        #
        # * Instances of the PKCS7 padder should be managed in a bounded pool
        # * It would be nice if we could get a flag from encrypted data to indicate
        #   whether it was padded or not
        #   * Might be able to make this happen with a leading block of flags in encrypted data
        padder = padding.PKCS7(AES256_BLOCK_SIZE).padder()
        padded_bytes = padder.update(obj_bytes) + padder.finalize()

        cipher = self._get_cipher(coldesc)
        encryptor = cipher.encryptor()
        return encryptor.update(padded_bytes) + encryptor.finalize()

    def decrypt(self, coldesc, encrypted_bytes):
        """
        Decrypt ``encrypted_bytes`` with the cipher for ``coldesc`` and strip the
        PKCS7 padding.

        :raises ValueError: if ``coldesc`` was never registered via :meth:`add_column`.
        """
        cipher = self._get_cipher(coldesc)
        decryptor = cipher.decryptor()
        padded_bytes = decryptor.update(encrypted_bytes) + decryptor.finalize()

        unpadder = padding.PKCS7(AES256_BLOCK_SIZE).unpadder()
        return unpadder.update(padded_bytes) + unpadder.finalize()

    def add_column(self, coldesc, key, type):
        """
        Register the key and type to use for ``coldesc``.

        :param coldesc: column descriptor used as the lookup key for this column.
        :param key: encryption key; must be exactly 32 bytes (256 bits) long.
        :param type: a key of the ``_cqltypes`` mapping naming the column's type.
        :raises ValueError: if any argument is missing, the type is unknown, or the
                            key has the wrong length.
        """
        if not coldesc:
            raise ValueError("ColDesc supplied to add_column cannot be None")
        if not key:
            raise ValueError("Key supplied to add_column cannot be None")
        if not type:
            raise ValueError("Type supplied to add_column cannot be None")
        if type not in _cqltypes:
            raise ValueError("Type {} is not a supported type".format(type))
        if len(key) != AES256_KEY_SIZE_BYTES:
            raise ValueError("AES256 column encryption policy expects a 256-bit encryption key")
        self.coldata[coldesc] = ColData(key, _cqltypes[type])

    def contains_column(self, coldesc):
        """Return True if ``coldesc`` has been registered via :meth:`add_column`."""
        return coldesc in self.coldata

    def encode_and_encrypt(self, coldesc, obj):
        """
        Serialize ``obj`` with the type registered for ``coldesc`` and encrypt the
        resulting bytes.

        :raises ValueError: if ``coldesc`` or ``obj`` is missing, or if ``coldesc``
                            was never registered via :meth:`add_column`.
        """
        if not coldesc:
            raise ValueError("ColDesc supplied to encode_and_encrypt cannot be None")
        # Compare against None explicitly: falsy values such as 0, False or an empty
        # string are legitimate column values and must not be rejected.
        if obj is None:
            raise ValueError("Object supplied to encode_and_encrypt cannot be None")
        coldata = self.coldata.get(coldesc)
        if coldata is None:
            raise ValueError("Could not find ColData for ColDesc {}".format(coldesc))
        return self.encrypt(coldesc, coldata.type.serialize(obj, None))

    def cache_info(self):
        """Return the ``lru_cache`` statistics of the shared cipher cache."""
        return AES256ColumnEncryptionPolicy._build_cipher.cache_info()

    def column_type(self, coldesc):
        """
        Return the type registered for ``coldesc``.

        :raises KeyError: if ``coldesc`` was never registered via :meth:`add_column`.
        """
        return self.coldata[coldesc].type

    def _get_cipher(self, coldesc):
        """
        Access relevant state from this instance necessary to create a Cipher and then get one,
        hopefully returning a cached instance if we've already done so (and it hasn't been evicted)
        """
        # Keep the try limited to the lookup so that an unrelated KeyError raised while
        # building the cipher is not misreported as a missing column.
        try:
            coldata = self.coldata[coldesc]
        except KeyError:
            raise ValueError("Could not find column {}".format(coldesc))
        return AES256ColumnEncryptionPolicy._build_cipher(coldata.key, self.mode, self.iv)

    # Explicitly use a static method here to avoid caching self; the cache is keyed
    # only on (key, mode, iv) and is shared by all instances of this class.
    @staticmethod
    @lru_cache(maxsize=128)
    def _build_cipher(key, mode, iv):
        return Cipher(algorithms.AES256(key), mode(iv))

0 commit comments

Comments
 (0)