Skip to content

Commit 5d529e1

Browse files
committed
Implement support of scylla cloud config bundle
```python path_to_bundle_yaml='/file/download/from/cloud/config.yaml' cluster= Cluster(scylla_cloud=path_to_bundle_yaml) ```
1 parent 4761032 commit 5d529e1

File tree

4 files changed

+166
-5
lines changed

4 files changed

+166
-5
lines changed

cassandra/cluster.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,7 @@
9191
GraphSON3Serializer)
9292
from cassandra.datastax.graph.query import _request_timeout_key, _GraphSONContextRowFactory
9393
from cassandra.datastax import cloud as dscloud
94+
from cassandra.scylla.cloud import CloudConfiguration
9495

9596
try:
9697
from cassandra.io.twistedreactor import TwistedConnection
@@ -1137,6 +1138,7 @@ def __init__(self,
11371138
monitor_reporting_interval=30,
11381139
client_id=None,
11391140
cloud=None,
1141+
scylla_cloud=None,
11401142
shard_aware_options=None):
11411143
"""
11421144
``executor_threads`` defines the number of threads in a pool for handling asynchronous tasks such as
@@ -1157,6 +1159,21 @@ def __init__(self,
11571159
if connection_class is not None:
11581160
self.connection_class = connection_class
11591161

1162+
if scylla_cloud is not None:
1163+
if contact_points is not _NOT_SET or endpoint_factory or ssl_context or ssl_options:
1164+
raise ValueError("contact_points, endpoint_factory, ssl_context, and ssl_options "
1165+
"cannot be specified with a scylla cloud configuration")
1166+
1167+
uses_twisted = TwistedConnection and issubclass(self.connection_class, TwistedConnection)
1168+
uses_eventlet = EventletConnection and issubclass(self.connection_class, EventletConnection)
1169+
1170+
scylla_cloud_config = CloudConfiguration.create(scylla_cloud, pyopenssl=uses_twisted or uses_eventlet)
1171+
ssl_context = scylla_cloud_config.ssl_context
1172+
endpoint_factory = scylla_cloud_config.endpoint_factory
1173+
contact_points = scylla_cloud_config.contact_points
1174+
ssl_options = scylla_cloud_config.ssl_options
1175+
auth_provider = scylla_cloud_config.auth_provider
1176+
11601177
if cloud is not None:
11611178
self.cloud = cloud
11621179
if contact_points is not _NOT_SET or endpoint_factory or ssl_context or ssl_options:

cassandra/connection.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -309,16 +309,17 @@ def __repr__(self):
309309

310310
class SniEndPointFactory(EndPointFactory):
311311

312-
def __init__(self, proxy_address, port):
312+
def __init__(self, proxy_address, port, node_domain=None):
313313
self._proxy_address = proxy_address
314314
self._port = port
315+
self._node_domain = node_domain
315316

316317
def create(self, row):
317318
host_id = row.get("host_id")
318319
if host_id is None:
319320
raise ValueError("No host_id to create the SniEndPoint")
320-
321-
return SniEndPoint(self._proxy_address, str(host_id), self._port)
321+
address = "{}.{}".format(host_id, self._node_domain) if self._node_domain else str(host_id)
322+
return SniEndPoint(self._proxy_address, str(address), self._port)
322323

323324
def create_from_sni(self, sni):
324325
return SniEndPoint(self._proxy_address, sni, self._port)

cassandra/scylla/cloud.py

Lines changed: 142 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,142 @@
1+
# Copyright ScyllaDB, Inc.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import sys
16+
import ssl
17+
import tempfile
18+
import base64
19+
from ssl import SSLContext
20+
from contextlib import contextmanager
21+
from itertools import islice
22+
23+
import six
24+
import yaml
25+
26+
from cassandra.connection import SniEndPointFactory
27+
from cassandra.auth import AuthProvider, PlainTextAuthProvider
28+
29+
30+
@contextmanager
31+
def file_or_memory(path=None, data=None):
32+
# since we can't read keys/cert from memory yet
33+
# see https://github.com/python/cpython/pull/2449 which isn't accepted and PEP-543 that was withdrawn
34+
# so we use temporary file to load the key
35+
if data:
36+
with tempfile.NamedTemporaryFile(mode="wb") as f:
37+
d = base64.decodebytes(bytes(data, encoding='utf-8'))
38+
f.write(d)
39+
if not d.endswith(b"\n"):
40+
f.write(b"\n")
41+
42+
f.flush()
43+
yield f.name
44+
45+
if path:
46+
yield path
47+
48+
49+
def nth(iterable, n, default=None):
50+
"Returns the nth item or a default value"
51+
return next(islice(iterable, n, None), default)
52+
53+
54+
class CloudConfiguration:
55+
endpoint_factory: SniEndPointFactory
56+
contact_points: list
57+
auth_provider: AuthProvider = None
58+
ssl_options: dict
59+
ssl_context: SSLContext
60+
skip_tls_verify: bool
61+
62+
def __init__(self, configuration_file, pyopenssl=False):
63+
cloud_config = yaml.safe_load(open(configuration_file))
64+
65+
self.current_context = cloud_config['contexts'][cloud_config['currentContext']]
66+
self.data_centers = cloud_config['datacenters']
67+
self.auth_info = cloud_config['authInfos'][self.current_context['authInfoName']]
68+
self.ssl_options = {}
69+
self.skip_tls_verify = self.auth_info.get('insecureSkipTLSVerify', False)
70+
self.ssl_context = self.create_pyopenssl_context() if pyopenssl else self.create_ssl_context()
71+
72+
proxy_address, port, node_domain = self.get_server(self.data_centers[self.current_context['datacenterName']],
73+
keys_order=['testServer', 'server'])
74+
self.endpoint_factory = SniEndPointFactory(proxy_address, port=int(port), node_domain=node_domain)
75+
76+
username, password = self.auth_info.get('username'), self.auth_info.get('password')
77+
if username and password:
78+
self.auth_provider = PlainTextAuthProvider(username, password)
79+
80+
81+
@property
82+
def contact_points(self):
83+
_contact_points = []
84+
for data_center in self.data_centers.values():
85+
address, _, _ = self.get_server(data_center)
86+
_contact_points.append(self.endpoint_factory.create_from_sni(address))
87+
return _contact_points
88+
89+
def get_server(self, data_center, keys_order=None):
90+
keys_order = keys_order or ['server']
91+
for key in keys_order:
92+
address = data_center.get(key, '')
93+
if not address:
94+
continue
95+
address = address.split(":")
96+
port = nth(address, 1, default=443)
97+
address = nth(address, 0)
98+
node_domain = data_center.get('nodeDomain')
99+
return address, port, node_domain
100+
101+
def create_ssl_context(self):
102+
ssl_context = ssl.SSLContext(protocol=ssl.PROTOCOL_SSLv23)
103+
ssl_context.verify_mode = ssl.VerifyMode.CERT_NONE if self.skip_tls_verify else ssl.VerifyMode.CERT_REQUIRED
104+
for data_center in self.data_centers.values():
105+
with file_or_memory(path=data_center.get('certificateAuthorityPath'),
106+
data=data_center.get('certificateAuthorityData')) as cafile:
107+
ssl_context.load_verify_locations(cadata=open(cafile).read())
108+
with file_or_memory(path=self.auth_info.get('clientCertificatePath'),
109+
data=self.auth_info.get('clientCertificateData')) as certfile, \
110+
file_or_memory(path=self.auth_info.get('clientKeyPath'), data=self.auth_info.get('clientKeyData')) as keyfile:
111+
ssl_context.load_cert_chain(keyfile=keyfile,
112+
certfile=certfile)
113+
114+
return ssl_context
115+
116+
def create_pyopenssl_context(self):
117+
try:
118+
from OpenSSL import SSL
119+
except ImportError as e:
120+
six.reraise(
121+
ImportError,
122+
ImportError(
123+
"PyOpenSSL must be installed to connect to scylla-cloud with the Eventlet or Twisted event loops"),
124+
sys.exc_info()[2]
125+
)
126+
ssl_context = SSL.Context(SSL.TLS_METHOD)
127+
ssl_context.set_verify(SSL.VERIFY_PEER, callback=lambda _1, _2, _3, _4, ok: True if self.skip_tls_verify else ok)
128+
for data_center in self.data_centers.values():
129+
with file_or_memory(path=data_center.get('certificateAuthorityPath'),
130+
data=data_center.get('certificateAuthorityData')) as cafile:
131+
ssl_context.load_verify_locations(cafile)
132+
with file_or_memory(path=self.auth_info.get('clientCertificatePath'),
133+
data=self.auth_info.get('clientCertificateData')) as certfile, \
134+
file_or_memory(path=self.auth_info.get('clientKeyPath'), data=self.auth_info.get('clientKeyData')) as keyfile:
135+
ssl_context.use_privatekey_file(keyfile)
136+
ssl_context.use_certificate_file(certfile)
137+
138+
return ssl_context
139+
140+
@classmethod
141+
def create(cls, configuration_file, pyopenssl=False):
142+
return cls(configuration_file, pyopenssl)

setup.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -404,7 +404,8 @@ def run_setup(extensions):
404404
sys.stderr.write("Bypassing Cython setup requirement\n")
405405

406406
dependencies = ['six >=1.9',
407-
'geomet>=0.1,<0.3']
407+
'geomet>=0.1,<0.3',
408+
'pyyaml > 5.0']
408409

409410
if not PY3:
410411
dependencies.append('futures')
@@ -429,7 +430,7 @@ def run_setup(extensions):
429430
packages=[
430431
'cassandra', 'cassandra.io', 'cassandra.cqlengine', 'cassandra.graph',
431432
'cassandra.datastax', 'cassandra.datastax.insights', 'cassandra.datastax.graph',
432-
'cassandra.datastax.graph.fluent', 'cassandra.datastax.cloud'
433+
'cassandra.datastax.graph.fluent', 'cassandra.datastax.cloud', 'cassandra.scylla'
433434
],
434435
keywords='cassandra,cql,orm,dse,graph',
435436
include_package_data=True,

0 commit comments

Comments
 (0)