Skip to content

Commit 02117bc

Browse files
committed
shard_aware: adding shard_aware_options to Cluster options
In some cases users don't want the automatic opening of so many connections (num of shard * num of nodes), this is adding a new Cluster parameter that can disable shard awareness ```python cluster = Cluster(contact_points=["127.0.0.1"], shard_aware_options=dict(disable=True), load_balancing_policy=TokenAwarePolicy(RoundRobinPolicy())) ```
1 parent 6eaafc3 commit 02117bc

File tree

6 files changed

+55
-13
lines changed

6 files changed

+55
-13
lines changed

cassandra/cluster.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -553,6 +553,20 @@ def default(self):
553553
"""
554554

555555

556+
class ShardAwareOptions:
557+
disable = None
558+
disable_shardaware_port = False
559+
560+
def __init__(self, opts=None, disable=None, disable_shardaware_port=None):
561+
self.disable = disable
562+
self.disable_shardaware_port = disable_shardaware_port
563+
if opts:
564+
if isinstance(opts, ShardAwareOptions):
565+
self.__dict__.update(opts.__dict__)
566+
elif isinstance(opts, dict):
567+
self.__dict__.update(opts)
568+
569+
556570
class _ConfigMode(object):
557571
UNCOMMITTED = 0
558572
LEGACY = 1
@@ -1003,6 +1017,12 @@ def default_retry_policy(self, policy):
10031017
load the configuration and certificates.
10041018
"""
10051019

1020+
shard_aware_options = None
1021+
"""
1022+
Can be set with :class:`ShardAwareOptions` or with a dict, to disable the automatic shardaware,
1023+
or to disable the shardaware port (advanced shardaware)
1024+
"""
1025+
10061026
@property
10071027
def schema_metadata_enabled(self):
10081028
"""
@@ -1104,7 +1124,8 @@ def __init__(self,
11041124
monitor_reporting_enabled=True,
11051125
monitor_reporting_interval=30,
11061126
client_id=None,
1107-
cloud=None):
1127+
cloud=None,
1128+
shard_aware_options=None):
11081129
"""
11091130
``executor_threads`` defines the number of threads in a pool for handling asynchronous tasks such as
11101131
extablishing connection pools or refreshing metadata.
@@ -1304,6 +1325,7 @@ def __init__(self,
13041325
self.reprepare_on_up = reprepare_on_up
13051326
self.monitor_reporting_enabled = monitor_reporting_enabled
13061327
self.monitor_reporting_interval = monitor_reporting_interval
1328+
self.shard_aware_options = ShardAwareOptions(opts=shard_aware_options)
13071329

13081330
self._listeners = set()
13091331
self._listener_lock = Lock()

cassandra/pool.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -430,8 +430,7 @@ def __init__(self, host, host_distance, session):
430430

431431
if self._keyspace:
432432
first_connection.set_keyspace_blocking(self._keyspace)
433-
434-
if first_connection.sharding_info:
433+
if first_connection.sharding_info and not self._session.cluster.shard_aware_options.disable:
435434
self.host.sharding_info = first_connection.sharding_info
436435
self._open_connections_for_all_shards(first_connection.shard_id)
437436

@@ -446,7 +445,7 @@ def _get_connection_for_routing_key(self, routing_key=None):
446445
raise NoConnectionsAvailable()
447446

448447
shard_id = None
449-
if self.host.sharding_info and routing_key:
448+
if not self._session.cluster.shard_aware_options.disable and self.host.sharding_info and routing_key:
450449
t = self._session.cluster.metadata.token_map.token_class.from_key(routing_key)
451450
shard_id = self.host.sharding_info.shard_id_from_token(t.value)
452451

@@ -585,7 +584,7 @@ def _replace(self, connection):
585584
try:
586585
if connection.shard_id in self._connections.keys():
587586
del self._connections[connection.shard_id]
588-
if self.host.sharding_info:
587+
if self.host.sharding_info and not self._session.cluster.shard_aware_options.disable:
589588
self._connecting.add(connection.shard_id)
590589
self._session.submit(self._open_connection_to_missing_shard, connection.shard_id)
591590
else:
@@ -652,7 +651,8 @@ def disable_advanced_shard_aware(self, secs):
652651
self.advanced_shardaware_block_until = max(time.time() + secs, self.advanced_shardaware_block_until)
653652

654653
def _get_shard_aware_endpoint(self):
655-
if self.advanced_shardaware_block_until and self.advanced_shardaware_block_until < time.time():
654+
if (self.advanced_shardaware_block_until and self.advanced_shardaware_block_until < time.time()) or \
655+
self._session.cluster.shard_aware_options.disable_shardaware_port:
656656
return None
657657

658658
endpoint = None
@@ -820,7 +820,7 @@ def _open_connections_for_all_shards(self, skip_shard_id=None):
820820
return
821821

822822
for shard_id in range(self.host.sharding_info.shards_count):
823-
if skip_shard_id and skip_shard_id == shard_id:
823+
if skip_shard_id is not None and skip_shard_id == shard_id:
824824
continue
825825
future = self._session.submit(self._open_connection_to_missing_shard, shard_id)
826826
if isinstance(future, Future):

docs/scylla_specific.rst

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,25 @@ https://github.com/scylladb/scylla/blob/master/docs/design-notes/protocols.md#cq
2626
New Cluster Helpers
2727
-------------------
2828

29+
* ``shard_aware_options``
30+
31+
Setting it to ``dict(disable=True)`` would disable the shard aware functionally, for cases favoring once connection per host (example, lots of processes connecting from one client host, generating a big load of connections
32+
33+
Other option is to configure scylla by setting ``enable_shard_aware_drivers: false`` on scylla.yaml.
34+
35+
.. code:: python
36+
37+
from cassandra.cluster import Cluster
38+
39+
cluster = Cluster(shard_aware_options=dict(disable=True))
40+
session = cluster.connect()
41+
42+
assert not cluster.is_shard_aware(), "Shard aware should be disabled"
43+
44+
# or just disable the shard aware port logic
45+
cluster = Cluster(shard_aware_options=dict(disable_shardaware_port=True))
46+
session = cluster.connect()
47+
2948
* ``cluster.is_shard_aware()``
3049

3150
New method available on ``Cluster`` allowing to check whether the remote cluster supports shard awareness (bool)

tests/unit/test_control_connection.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -100,9 +100,6 @@ def on_up(self, host):
100100
def on_down(self, host, is_host_addition):
101101
self.down_host = host
102102

103-
def get_control_connection_host(self):
104-
return self.added_hosts[0] if self.added_hosts else None
105-
106103

107104
def _node_meta_results(local_results, peer_results):
108105
"""

tests/unit/test_host_connection_pool.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
from mock import Mock, NonCallableMagicMock, MagicMock
2727
from threading import Thread, Event, Lock
2828

29-
from cassandra.cluster import Session
29+
from cassandra.cluster import Session, ShardAwareOptions
3030
from cassandra.connection import Connection
3131
from cassandra.pool import HostConnection, HostConnectionPool
3232
from cassandra.pool import Host, NoConnectionsAvailable
@@ -160,6 +160,7 @@ def test_return_defunct_connection_on_down_host(self):
160160
conn = NonCallableMagicMock(spec=Connection, in_flight=0, is_defunct=False, is_closed=False,
161161
max_request_id=100, signaled_error=False)
162162
session.cluster.connection_factory.return_value = conn
163+
session.cluster.shard_aware_options = ShardAwareOptions()
163164

164165
pool = self.PoolImpl(host, HostDistance.LOCAL, session)
165166
session.cluster.connection_factory.assert_called_once_with(host.endpoint, owning_pool=pool)

tests/unit/test_shard_aware.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from unittest.mock import MagicMock
2222
from futures.thread import ThreadPoolExecutor
2323

24+
from cassandra.cluster import ShardAwareOptions
2425
from cassandra.pool import HostConnection, HostDistance
2526
from cassandra.connection import ShardingInfo, DefaultEndPoint
2627
from cassandra.metadata import Murmur3Token
@@ -67,10 +68,11 @@ def __init__(self, is_ssl=False, *args, **kwargs):
6768
self.cluster.ssl_options = {'some_ssl_options': True}
6869
else:
6970
self.cluster.ssl_options = None
71+
self.cluster.shard_aware_options = ShardAwareOptions()
7072
self.cluster.executor = ThreadPoolExecutor(max_workers=2)
7173
self.cluster.signal_connection_failure = lambda *args, **kwargs: False
7274
self.cluster.connection_factory = self.mock_connection_factory
73-
self.connection_counter = -1
75+
self.connection_counter = 0
7476
self.futures = []
7577

7678
def submit(self, fn, *args, **kwargs):
@@ -87,7 +89,8 @@ def mock_connection_factory(self, *args, **kwargs):
8789
connection.is_closed = False
8890
connection.orphaned_threshold_reached = False
8991
connection.endpoint = args[0]
90-
connection.shard_id = kwargs.get('shard_id', 0)
92+
connection.shard_id = kwargs.get('shard_id', self.connection_counter)
93+
self.connection_counter += 1
9194
connection.sharding_info = ShardingInfo(shard_id=1, shards_count=4,
9295
partitioner="", sharding_algorithm="", sharding_ignore_msb=0,
9396
shard_aware_port=19042, shard_aware_port_ssl=19045)

0 commit comments

Comments
 (0)