Skip to content

Commit 6eaafc3

Browse files
committed
shard aware: shard aware unique port (advenced shard aware)
shard aware port in now advertised OPTIONS messge, and we need to replace the connection with the new host/port * fixing tests to match the advenced shard awareness now that we could have two host listed (one with 9042 port, and one with 19042), we need to make the test a bit less prune to failure cause of that change
1 parent 6d6e19e commit 6eaafc3

14 files changed

+196
-32
lines changed

cassandra/c_shard_info.pyx

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,15 +22,19 @@ cdef class ShardingInfo():
2222
cdef readonly str partitioner
2323
cdef readonly str sharding_algorithm
2424
cdef readonly int sharding_ignore_msb
25+
cdef readonly int shard_aware_port
26+
cdef readonly int shard_aware_port_ssl
2527

2628
cdef object __weakref__
2729

28-
def __init__(self, shard_id, shards_count, partitioner, sharding_algorithm, sharding_ignore_msb):
30+
def __init__(self, shard_id, shards_count, partitioner, sharding_algorithm, sharding_ignore_msb, shard_aware_port,
31+
shard_aware_port_ssl):
2932
self.shards_count = int(shards_count)
3033
self.partitioner = partitioner
3134
self.sharding_algorithm = sharding_algorithm
3235
self.sharding_ignore_msb = int(sharding_ignore_msb)
33-
36+
self.shard_aware_port = int(shard_aware_port) if shard_aware_port else 0
37+
self.shard_aware_port_ssl = int(shard_aware_port_ssl) if shard_aware_port_ssl else 0
3438

3539
@staticmethod
3640
def parse_sharding_info(message):
@@ -39,12 +43,15 @@ cdef class ShardingInfo():
3943
partitioner = message.options.get('SCYLLA_PARTITIONER', [''])[0] or None
4044
sharding_algorithm = message.options.get('SCYLLA_SHARDING_ALGORITHM', [''])[0] or None
4145
sharding_ignore_msb = message.options.get('SCYLLA_SHARDING_IGNORE_MSB', [''])[0] or None
46+
shard_aware_port = message.options.get('SCYLLA_SHARD_AWARE_PORT', [''])[0] or None
47+
shard_aware_port_ssl = message.options.get('SCYLLA_SHARD_AWARE_PORT_SSL', [''])[0] or None
4248

4349
if not (shard_id or shards_count or partitioner == "org.apache.cassandra.dht.Murmur3Partitioner" or
4450
sharding_algorithm == "biased-token-round-robin" or sharding_ignore_msb):
4551
return 0, None
4652

47-
return int(shard_id), ShardingInfo(shard_id, shards_count, partitioner, sharding_algorithm, sharding_ignore_msb)
53+
return int(shard_id), ShardingInfo(shard_id, shards_count, partitioner, sharding_algorithm, sharding_ignore_msb,
54+
shard_aware_port, shard_aware_port_ssl)
4855

4956

5057
def shard_id_from_token(self, int64_t token_input):

cassandra/cluster.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1734,14 +1734,20 @@ def get_connection_holders(self):
17341734
holders.append(self.control_connection)
17351735
return holders
17361736

1737+
def get_all_pools(self):
1738+
pools = []
1739+
for s in tuple(self.sessions):
1740+
pools.extend(s.get_pools())
1741+
return pools
1742+
17371743
def is_shard_aware(self):
1738-
return bool(self.get_connection_holders()[:-1][0].host.sharding_info)
1744+
return bool(self.get_all_pools()[0].host.sharding_info)
17391745

17401746
def shard_aware_stats(self):
17411747
if self.is_shard_aware():
17421748
return {str(pool.host.endpoint): {'shards_count': pool.host.sharding_info.shards_count,
17431749
'connected': len(pool._connections.keys())}
1744-
for pool in self.get_connection_holders()[:-1]}
1750+
for pool in self.get_all_pools()}
17451751

17461752
def shutdown(self):
17471753
"""
@@ -3756,7 +3762,7 @@ def _refresh_node_list_and_token_map(self, connection, preloaded_results=None,
37563762
partitioner = local_row.get("partitioner")
37573763
tokens = local_row.get("tokens")
37583764

3759-
host = self._cluster.metadata.get_host(connection.endpoint)
3765+
host = self._cluster.metadata.get_host(connection.original_endpoint)
37603766
if host:
37613767
datacenter = local_row.get("data_center")
37623768
rack = local_row.get("rack")
@@ -4049,9 +4055,8 @@ def _get_peers_query(self, peers_query_type, connection=None):
40494055
query_template = (self._SELECT_SCHEMA_PEERS_TEMPLATE
40504056
if peers_query_type == self.PeersQueryType.PEERS_SCHEMA
40514057
else self._SELECT_PEERS_NO_TOKENS_TEMPLATE)
4052-
4053-
host_release_version = self._cluster.metadata.get_host(connection.endpoint).release_version
4054-
host_dse_version = self._cluster.metadata.get_host(connection.endpoint).dse_version
4058+
host_release_version = self._cluster.metadata.get_host(connection.original_endpoint).release_version
4059+
host_dse_version = self._cluster.metadata.get_host(connection.original_endpoint).dse_version
40554060
uses_native_address_query = (
40564061
host_dse_version and Version(host_dse_version) >= self._MINIMUM_NATIVE_ADDRESS_DSE_VERSION)
40574062

cassandra/connection.py

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,8 @@
2828
import time
2929
import ssl
3030
import weakref
31-
31+
import random
32+
import itertools
3233

3334
if 'gevent.monkey' in sys.modules:
3435
from gevent.queue import Queue, Empty
@@ -116,6 +117,10 @@ def decompress(byts):
116117
HEADER_DIRECTION_TO_CLIENT = 0x80
117118
HEADER_DIRECTION_MASK = 0x80
118119

120+
# shard aware default for opening per shard connection
121+
DEFAULT_LOCAL_PORT_LOW = 49152
122+
DEFAULT_LOCAL_PORT_HIGH = 65535
123+
119124
frame_header_v1_v2 = struct.Struct('>BbBi')
120125
frame_header_v3 = struct.Struct('>BhBi')
121126

@@ -666,6 +671,17 @@ def reset_cql_frame_buffer(self):
666671
self.reset_io_buffer()
667672

668673

674+
class ShardawarePortGenerator:
675+
@classmethod
676+
def generate(cls, shard_id, total_shards):
677+
start = random.randrange(DEFAULT_LOCAL_PORT_LOW, DEFAULT_LOCAL_PORT_HIGH)
678+
available_ports = itertools.chain(range(start, DEFAULT_LOCAL_PORT_HIGH), range(DEFAULT_LOCAL_PORT_LOW, start))
679+
680+
for port in available_ports:
681+
if port % total_shards == shard_id:
682+
yield port
683+
684+
669685
class Connection(object):
670686

671687
CALLBACK_ERR_THREAD_THRESHOLD = 100
@@ -762,7 +778,7 @@ def __init__(self, host='127.0.0.1', port=9042, authenticator=None,
762778
ssl_options=None, sockopts=None, compression=True,
763779
cql_version=None, protocol_version=ProtocolVersion.MAX_SUPPORTED, is_control_connection=False,
764780
user_type_map=None, connect_timeout=None, allow_beta_protocol_version=False, no_compact=False,
765-
ssl_context=None, owning_pool=None):
781+
ssl_context=None, owning_pool=None, shard_id=None, total_shards=None):
766782

767783
# TODO next major rename host to endpoint and remove port kwarg.
768784
self.endpoint = host if isinstance(host, EndPoint) else DefaultEndPoint(host, port)
@@ -812,6 +828,9 @@ def __init__(self, host='127.0.0.1', port=9042, authenticator=None,
812828

813829
self.lock = RLock()
814830
self.connected_event = Event()
831+
self.shard_id = shard_id
832+
self.total_shards = total_shards
833+
self.original_endpoint = self.endpoint
815834

816835
@property
817836
def host(self):
@@ -874,6 +893,15 @@ def _wrap_socket_from_context(self):
874893
self._socket = self.ssl_context.wrap_socket(self._socket, **ssl_options)
875894

876895
def _initiate_connection(self, sockaddr):
896+
if self.shard_id is not None:
897+
for port in ShardawarePortGenerator.generate(self.shard_id, self.total_shards):
898+
try:
899+
self._socket.bind(('', port))
900+
break
901+
except Exception as ex:
902+
log.debug("port=%d couldn't bind cause: %s", port, str(ex))
903+
log.debug(f'connection (%r) port=%d should be shard_id=%d', id(self), port, port % self.total_shards)
904+
877905
self._socket.connect(sockaddr)
878906

879907
def _match_hostname(self):
@@ -894,6 +922,7 @@ def _get_socket_addresses(self):
894922
def _connect_socket(self):
895923
sockerr = None
896924
addresses = self._get_socket_addresses()
925+
port = None
897926
for (af, socktype, proto, _, sockaddr) in addresses:
898927
try:
899928
self._socket = self._socket_impl.socket(af, socktype, proto)

cassandra/metadata.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -134,8 +134,8 @@ def export_schema_as_string(self):
134134

135135
def refresh(self, connection, timeout, target_type=None, change_type=None, **kwargs):
136136

137-
server_version = self.get_host(connection.endpoint).release_version
138-
dse_version = self.get_host(connection.endpoint).dse_version
137+
server_version = self.get_host(connection.original_endpoint).release_version
138+
dse_version = self.get_host(connection.original_endpoint).dse_version
139139
parser = get_schema_parser(connection, server_version, dse_version, timeout)
140140

141141
if not target_type:

cassandra/pool.py

Lines changed: 40 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import socket
2222
import time
2323
import random
24+
import copy
2425
from threading import Lock, RLock, Condition
2526
import weakref
2627
try:
@@ -412,6 +413,7 @@ def __init__(self, host, host_distance, session):
412413
# so that we can dispose of them.
413414
self._trash = set()
414415
self._shard_connections_futures = []
416+
self.advanced_shardaware_block_until = 0
415417

416418
if host_distance == HostDistance.IGNORED:
417419
log.debug("Not opening connection to ignored host %s", self.host)
@@ -431,7 +433,7 @@ def __init__(self, host, host_distance, session):
431433

432434
if first_connection.sharding_info:
433435
self.host.sharding_info = first_connection.sharding_info
434-
self._open_connections_for_all_shards()
436+
self._open_connections_for_all_shards(first_connection.shard_id)
435437

436438
log.debug("Finished initializing connection for host %s", self.host)
437439

@@ -645,6 +647,24 @@ def _close_excess_connections(self):
645647
log.debug("Closing excess connection (%s) to %s", id(c), self.host)
646648
c.close()
647649

650+
def disable_advanced_shard_aware(self, secs):
651+
log.warning("disabling advanced_shard_aware for %i seconds, could be that this client is behind NAT?", secs)
652+
self.advanced_shardaware_block_until = max(time.time() + secs, self.advanced_shardaware_block_until)
653+
654+
def _get_shard_aware_endpoint(self):
655+
if self.advanced_shardaware_block_until and self.advanced_shardaware_block_until < time.time():
656+
return None
657+
658+
endpoint = None
659+
if self._session.cluster.ssl_options and self.host.sharding_info.shard_aware_port_ssl:
660+
endpoint = copy.copy(self.host.endpoint)
661+
endpoint._port = self.host.sharding_info.shard_aware_port_ssl
662+
elif self.host.sharding_info.shard_aware_port:
663+
endpoint = copy.copy(self.host.endpoint)
664+
endpoint._port = self.host.sharding_info.shard_aware_port
665+
666+
return endpoint
667+
648668
def _open_connection_to_missing_shard(self, shard_id):
649669
"""
650670
Creates a new connection, checks its shard_id and populates our shard
@@ -666,13 +686,28 @@ def _open_connection_to_missing_shard(self, shard_id):
666686
with self._lock:
667687
if self.is_shutdown:
668688
return
689+
shard_aware_endpoint = self._get_shard_aware_endpoint()
690+
log.debug("shard_aware_endpoint=%r", shard_aware_endpoint)
691+
692+
if shard_aware_endpoint:
693+
conn = self._session.cluster.connection_factory(shard_aware_endpoint, owning_pool=self,
694+
shard_id=shard_id,
695+
total_shards=self.host.sharding_info.shards_count)
696+
conn.original_endpoint = self.host.endpoint
697+
else:
698+
conn = self._session.cluster.connection_factory(self.host.endpoint, owning_pool=self)
669699

670-
conn = self._session.cluster.connection_factory(self.host.endpoint)
671700
log.debug("Received a connection %s for shard_id=%i on host %s", id(conn), conn.shard_id, self.host)
672701
if self.is_shutdown:
673702
log.debug("Pool for host %s is in shutdown, closing the new connection (%s)", self.host, id(conn))
674703
conn.close()
675704
return
705+
706+
if shard_aware_endpoint and shard_id != conn.shard_id:
707+
# connection didn't land on expected shared
708+
# assuming behind a NAT, disabling advanced shard aware for a while
709+
self.disable_advanced_shard_aware(10 * 60)
710+
676711
old_conn = self._connections.get(conn.shard_id)
677712
if old_conn is None or old_conn.orphaned_threshold_reached:
678713
log.debug(
@@ -776,7 +811,7 @@ def _open_connection_to_missing_shard(self, shard_id):
776811
conn.close()
777812
self._connecting.discard(shard_id)
778813

779-
def _open_connections_for_all_shards(self):
814+
def _open_connections_for_all_shards(self, skip_shard_id=None):
780815
"""
781816
Loop over all the shards and try to open a connection to each one.
782817
"""
@@ -785,6 +820,8 @@ def _open_connections_for_all_shards(self):
785820
return
786821

787822
for shard_id in range(self.host.sharding_info.shards_count):
823+
if skip_shard_id and skip_shard_id == shard_id:
824+
continue
788825
future = self._session.submit(self._open_connection_to_missing_shard, shard_id)
789826
if isinstance(future, Future):
790827
self._connecting.add(shard_id)

cassandra/shard_info.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,13 @@
2020

2121
class _ShardingInfo(object):
2222

23-
def __init__(self, shard_id, shards_count, partitioner, sharding_algorithm, sharding_ignore_msb):
23+
def __init__(self, shard_id, shards_count, partitioner, sharding_algorithm, sharding_ignore_msb, shard_aware_port, shard_aware_port_ssl):
2424
self.shards_count = int(shards_count)
2525
self.partitioner = partitioner
2626
self.sharding_algorithm = sharding_algorithm
2727
self.sharding_ignore_msb = int(sharding_ignore_msb)
28+
self.shard_aware_port = int(shard_aware_port) if shard_aware_port else None
29+
self.shard_aware_port_ssl = int(shard_aware_port_ssl) if shard_aware_port_ssl else None
2830

2931
@staticmethod
3032
def parse_sharding_info(message):
@@ -33,13 +35,16 @@ def parse_sharding_info(message):
3335
partitioner = message.options.get('SCYLLA_PARTITIONER', [''])[0] or None
3436
sharding_algorithm = message.options.get('SCYLLA_SHARDING_ALGORITHM', [''])[0] or None
3537
sharding_ignore_msb = message.options.get('SCYLLA_SHARDING_IGNORE_MSB', [''])[0] or None
38+
shard_aware_port = message.options.get('SCYLLA_SHARD_AWARE_PORT', [''])[0] or None
39+
shard_aware_port_ssl = message.options.get('SCYLLA_SHARD_AWARE_PORT_SSL', [''])[0] or None
3640
log.debug("Parsing sharding info from message options %s", message.options)
3741

3842
if not (shard_id or shards_count or partitioner == "org.apache.cassandra.dht.Murmur3Partitioner" or
3943
sharding_algorithm == "biased-token-round-robin" or sharding_ignore_msb):
4044
return 0, None
4145

42-
return int(shard_id), _ShardingInfo(shard_id, shards_count, partitioner, sharding_algorithm, sharding_ignore_msb)
46+
return int(shard_id), _ShardingInfo(shard_id, shards_count, partitioner, sharding_algorithm, sharding_ignore_msb,
47+
shard_aware_port, shard_aware_port_ssl)
4348

4449
def shard_id_from_token(self, token):
4550
"""

docs/scylla_specific.rst

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,13 @@ Shard Awareness
88
As a result, latency is significantly reduced because there is no need to pass data between the shards.
99

1010
Details on the scylla cql protocol extensions
11-
https://github.com/scylladb/scylla/blob/master/docs/design-notes/protocol-extensions.md
11+
https://github.com/scylladb/scylla/blob/master/docs/design-notes/protocol-extensions.md#intranode-sharding
1212

1313
For using it you only need to enable ``TokenAwarePolicy`` on the ``Cluster``
1414

15+
See the configuration of ``native_shard_aware_transport_port`` and ``native_shard_aware_transport_port_ssl`` on scylla.yaml:
16+
https://github.com/scylladb/scylla/blob/master/docs/design-notes/protocols.md#cql-client-protocol
17+
1518
.. code:: python
1619
1720
from cassandra.cluster import Cluster

tests/integration/standard/test_cluster.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1514,7 +1514,7 @@ def test_prepare_on_ignored_hosts(self):
15141514
# the length of mock_calls will vary, but all should use the unignored
15151515
# address
15161516
for c in cluster.connection_factory.mock_calls:
1517-
self.assertEqual(call(DefaultEndPoint(unignored_address)), c)
1517+
self.assertEqual(unignored_address, c.args[0].address)
15181518
cluster.shutdown()
15191519

15201520

tests/integration/standard/test_shard_aware.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import time
1616
import random
1717
from subprocess import run
18+
import logging
1819

1920
try:
2021
from concurrent.futures import ThreadPoolExecutor, as_completed
@@ -28,10 +29,12 @@
2829

2930
from cassandra.cluster import Cluster
3031
from cassandra.policies import TokenAwarePolicy, RoundRobinPolicy, ConstantReconnectionPolicy
31-
from cassandra import OperationTimedOut
32+
from cassandra import OperationTimedOut, ConsistencyLevel
3233

3334
from tests.integration import use_cluster, get_node, PROTOCOL_VERSION
3435

36+
LOGGER = logging.getLogger(__name__)
37+
3538

3639
def setup_module():
3740
os.environ['SCYLLA_EXT_OPTS'] = "--smp 4 --memory 2048M"
@@ -41,12 +44,12 @@ def setup_module():
4144
class TestShardAwareIntegration(unittest.TestCase):
4245
@classmethod
4346
def setup_class(cls):
44-
cls.cluster = Cluster(protocol_version=PROTOCOL_VERSION, load_balancing_policy=TokenAwarePolicy(RoundRobinPolicy()),
47+
cls.cluster = Cluster(contact_points=["127.0.0.1"], protocol_version=PROTOCOL_VERSION,
48+
load_balancing_policy=TokenAwarePolicy(RoundRobinPolicy()),
4549
reconnection_policy=ConstantReconnectionPolicy(1))
4650
cls.session = cls.cluster.connect()
47-
48-
print(cls.cluster.is_shard_aware())
49-
print(cls.cluster.shard_aware_stats())
51+
LOGGER.info(cls.cluster.is_shard_aware())
52+
LOGGER.info(cls.cluster.shard_aware_stats())
5053

5154
@classmethod
5255
def teardown_class(cls):
@@ -56,7 +59,7 @@ def verify_same_shard_in_tracing(self, results, shard_name):
5659
traces = results.get_query_trace()
5760
events = traces.events
5861
for event in events:
59-
print(event.thread_name, event.description)
62+
LOGGER.info("%s %s", event.thread_name, event.description)
6063
for event in events:
6164
self.assertEqual(event.thread_name, shard_name)
6265
self.assertIn('querying locally', "\n".join([event.description for event in events]))
@@ -65,7 +68,7 @@ def verify_same_shard_in_tracing(self, results, shard_name):
6568
traces = self.session.execute("SELECT * FROM system_traces.events WHERE session_id = %s", (trace_id,))
6669
events = [event for event in traces]
6770
for event in events:
68-
print(event.thread, event.activity)
71+
LOGGER.info("%s %s", event.thread, event.activity)
6972
for event in events:
7073
self.assertEqual(event.thread, shard_name)
7174
self.assertIn('querying locally', "\n".join([event.activity for event in events]))

0 commit comments

Comments
 (0)