Skip to content

Commit 02e7ce9

Browse files
Use tablets in token and shard awareness
Add mechanism to parse system.tablets periodically. In TokenAwarePolicy check if keyspace uses tablets if so try to use them to find replicas. Make shard awareness work when using tablets. Everything is wrapped in experimental setting, because tablets are still experimental in ScyllaDB and changes in the tablets format are possible.
1 parent e8d7151 commit 02e7ce9

File tree

8 files changed

+199
-15
lines changed

8 files changed

+199
-15
lines changed

cassandra/cluster.py

Lines changed: 31 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@
4141
import weakref
4242
from weakref import WeakValueDictionary
4343

44-
from cassandra import (ConsistencyLevel, AuthenticationFailed,
44+
from cassandra import (ConsistencyLevel, AuthenticationFailed, InvalidRequest,
4545
OperationTimedOut, UnsupportedOperation,
4646
SchemaTargetType, DriverException, ProtocolVersion,
4747
UnresolvableContactPoints)
@@ -51,6 +51,7 @@
5151
EndPoint, DefaultEndPoint, DefaultEndPointFactory,
5252
ContinuousPagingState, SniEndPointFactory, ConnectionBusy)
5353
from cassandra.cqltypes import UserType
54+
import cassandra.cqltypes as types
5455
from cassandra.encoder import Encoder
5556
from cassandra.protocol import (QueryMessage, ResultMessage,
5657
ErrorMessage, ReadTimeoutErrorMessage,
@@ -79,6 +80,7 @@
7980
named_tuple_factory, dict_factory, tuple_factory, FETCH_SIZE_UNSET,
8081
HostTargetingStatement)
8182
from cassandra.marshal import int64_pack
83+
from cassandra.tablets import Tablet, Tablets
8284
from cassandra.timestamps import MonotonicTimestampGenerator
8385
from cassandra.compat import Mapping
8486
from cassandra.util import _resolve_contact_points_to_string_map, Version
@@ -1775,6 +1777,14 @@ def connect(self, keyspace=None, wait_for_all_pools=False):
17751777
self.shutdown()
17761778
raise
17771779

1780+
# Update the information about tablet support after connection handshake.
1781+
self.load_balancing_policy._tablets_routing_v1 = self.control_connection._tablets_routing_v1
1782+
child_policy = self.load_balancing_policy.child_policy if hasattr(self.load_balancing_policy, 'child_policy') else None
1783+
while child_policy is not None:
1784+
if hasattr(child_policy, '_tablet_routing_v1'):
1785+
child_policy._tablet_routing_v1 = self.control_connection._tablets_routing_v1
1786+
child_policy = child_policy.child_policy if hasattr(child_policy, 'child_policy') else None
1787+
17781788
self.profile_manager.check_supported() # todo: rename this method
17791789

17801790
if self.idle_heartbeat_interval:
@@ -2389,7 +2399,6 @@ def add_prepared(self, query_id, prepared_statement):
23892399
with self._prepared_statement_lock:
23902400
self._prepared_statements[query_id] = prepared_statement
23912401

2392-
23932402
class Session(object):
23942403
"""
23952404
A collection of connection pools for each host in the cluster.
@@ -3541,6 +3550,7 @@ class PeersQueryType(object):
35413550
_schema_meta_page_size = 1000
35423551

35433552
_uses_peers_v2 = True
3553+
_tablets_routing_v1 = False
35443554

35453555
# for testing purposes
35463556
_time = time
@@ -3674,6 +3684,8 @@ def _try_connect(self, host):
36743684
# If sharding information is available, it's a ScyllaDB cluster, so do not use peers_v2 table.
36753685
if connection.features.sharding_info is not None:
36763686
self._uses_peers_v2 = False
3687+
3688+
self._tablets_routing_v1 = connection.features.tablets_routing_v1
36773689

36783690
# use weak references in both directions
36793691
# _clear_watcher will be called when this ControlConnection is about to be finalized
@@ -4600,7 +4612,10 @@ def _query(self, host, message=None, cb=None):
46004612
connection = None
46014613
try:
46024614
# TODO get connectTimeout from cluster settings
4603-
connection, request_id = pool.borrow_connection(timeout=2.0, routing_key=self.query.routing_key if self.query else None)
4615+
if self.query:
4616+
connection, request_id = pool.borrow_connection(timeout=2.0, routing_key=self.query.routing_key, keyspace=self.query.keyspace, table=self.query.table)
4617+
else:
4618+
connection, request_id = pool.borrow_connection(timeout=2.0)
46044619
self._connection = connection
46054620
result_meta = self.prepared_statement.result_metadata if self.prepared_statement else []
46064621

@@ -4719,6 +4734,19 @@ def _set_result(self, host, connection, pool, response):
47194734
self._warnings = getattr(response, 'warnings', None)
47204735
self._custom_payload = getattr(response, 'custom_payload', None)
47214736

4737+
if self._custom_payload and self.session.cluster.control_connection._tablets_routing_v1 and 'tablets-routing-v1' in self._custom_payload:
4738+
protocol = self.session.cluster.protocol_version
4739+
info = self._custom_payload.get('tablets-routing-v1')
4740+
ctype = types.lookup_casstype('TupleType(LongType, LongType, ListType(TupleType(UUIDType, Int32Type)))')
4741+
tablet_routing_info = ctype.from_binary(info, protocol)
4742+
first_token = tablet_routing_info[0]
4743+
last_token = tablet_routing_info[1]
4744+
tablet_replicas = tablet_routing_info[2]
4745+
tablet = Tablet.from_row(first_token, last_token, tablet_replicas)
4746+
keyspace = self.query.keyspace
4747+
table = self.query.table
4748+
self.session.cluster.metadata._tablets.add_tablet(keyspace, table, tablet)
4749+
47224750
if isinstance(response, ResultMessage):
47234751
if response.kind == RESULT_KIND_SET_KEYSPACE:
47244752
session = getattr(self, 'session', None)

cassandra/metadata.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
from cassandra.pool import HostDistance
4545
from cassandra.connection import EndPoint
4646
from cassandra.compat import Mapping
47+
from cassandra.tablets import Tablets
4748

4849
log = logging.getLogger(__name__)
4950

@@ -126,6 +127,7 @@ def __init__(self):
126127
self._hosts = {}
127128
self._host_id_by_endpoint = {}
128129
self._hosts_lock = RLock()
130+
self._tablets = Tablets({})
129131

130132
def export_schema_as_string(self):
131133
"""

cassandra/policies.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -335,6 +335,7 @@ class TokenAwarePolicy(LoadBalancingPolicy):
335335

336336
_child_policy = None
337337
_cluster_metadata = None
338+
_tablets_routing_v1 = False
338339
shuffle_replicas = False
339340
"""
340341
Yield local replicas in a random order.
@@ -346,6 +347,7 @@ def __init__(self, child_policy, shuffle_replicas=False):
346347

347348
def populate(self, cluster, hosts):
348349
self._cluster_metadata = cluster.metadata
350+
self._tablets_routing_v1 = cluster.control_connection._tablets_routing_v1
349351
self._child_policy.populate(cluster, hosts)
350352

351353
def check_supported(self):
@@ -376,7 +378,19 @@ def make_query_plan(self, working_keyspace=None, query=None):
376378
for host in child.make_query_plan(keyspace, query):
377379
yield host
378380
else:
379-
replicas = self._cluster_metadata.get_replicas(keyspace, routing_key)
381+
replicas = []
382+
if self._tablets_routing_v1:
383+
tablet = self._cluster_metadata._tablets.get_tablet_for_key(keyspace, query.table, self._cluster_metadata.token_map.token_class.from_key(routing_key))
384+
385+
if tablet is not None:
386+
replicas_mapped = set(map(lambda r: r[0], tablet.replicas))
387+
child_plan = child.make_query_plan(keyspace, query)
388+
389+
replicas = [host for host in child_plan if host.host_id in replicas_mapped]
390+
391+
if replicas == []:
392+
replicas = self._cluster_metadata.get_replicas(keyspace, routing_key)
393+
380394
if self.shuffle_replicas:
381395
shuffle(replicas)
382396
for replica in replicas:

cassandra/pool.py

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -392,6 +392,8 @@ class HostConnection(object):
392392
# the number below, all excess connections will be closed.
393393
max_excess_connections_per_shard_multiplier = 3
394394

395+
tablets_routing_v1 = False
396+
395397
def __init__(self, host, host_distance, session):
396398
self.host = host
397399
self.host_distance = host_distance
@@ -436,10 +438,11 @@ def __init__(self, host, host_distance, session):
436438
if first_connection.features.sharding_info and not self._session.cluster.shard_aware_options.disable:
437439
self.host.sharding_info = first_connection.features.sharding_info
438440
self._open_connections_for_all_shards(first_connection.features.shard_id)
441+
self.tablets_routing_v1 = first_connection.features.tablets_routing_v1
439442

440443
log.debug("Finished initializing connection for host %s", self.host)
441444

442-
def _get_connection_for_routing_key(self, routing_key=None):
445+
def _get_connection_for_routing_key(self, routing_key=None, keyspace=None, table=None):
443446
if self.is_shutdown:
444447
raise ConnectionException(
445448
"Pool for %s is shutdown" % (self.host,), self.host)
@@ -450,7 +453,22 @@ def _get_connection_for_routing_key(self, routing_key=None):
450453
shard_id = None
451454
if not self._session.cluster.shard_aware_options.disable and self.host.sharding_info and routing_key:
452455
t = self._session.cluster.metadata.token_map.token_class.from_key(routing_key)
453-
shard_id = self.host.sharding_info.shard_id_from_token(t.value)
456+
457+
shard_id = None
458+
if self.tablets_routing_v1 and table is not None:
459+
if keyspace is None:
460+
keyspace = self._keyspace
461+
462+
tablet = self._session.cluster.metadata._tablets.get_tablet_for_key(keyspace, table, t)
463+
464+
if tablet is not None:
465+
for replica in tablet.replicas:
466+
if replica[0] == self.host.host_id:
467+
shard_id = replica[1]
468+
break
469+
470+
if shard_id is None:
471+
shard_id = self.host.sharding_info.shard_id_from_token(t.value)
454472

455473
conn = self._connections.get(shard_id)
456474

@@ -496,15 +514,15 @@ def _get_connection_for_routing_key(self, routing_key=None):
496514
return random.choice(active_connections)
497515
return random.choice(list(self._connections.values()))
498516

499-
def borrow_connection(self, timeout, routing_key=None):
500-
conn = self._get_connection_for_routing_key(routing_key)
517+
def borrow_connection(self, timeout, routing_key=None, keyspace=None, table=None):
518+
conn = self._get_connection_for_routing_key(routing_key, keyspace, table)
501519
start = time.time()
502520
remaining = timeout
503521
last_retry = False
504522
while True:
505523
if conn.is_closed:
506524
# The connection might have been closed in the meantime - if so, try again
507-
conn = self._get_connection_for_routing_key(routing_key)
525+
conn = self._get_connection_for_routing_key(routing_key, keyspace, table)
508526
with conn.lock:
509527
if (not conn.is_closed or last_retry) and conn.in_flight < conn.max_request_id:
510528
# On last retry we ignore connection status, since it is better to return closed connection than

cassandra/query.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -253,6 +253,13 @@ class Statement(object):
253253
.. versionadded:: 2.1.3
254254
"""
255255

256+
table = None
257+
"""
258+
The string name of the table this query acts on. This is used when the tablet
259+
experimental feature is enabled and in the same time :class`~.TokenAwarePolicy`
260+
is configured in the profile load balancing policy.
261+
"""
262+
256263
custom_payload = None
257264
"""
258265
:ref:`custom_payload` to be passed to the server.
@@ -272,7 +279,7 @@ class Statement(object):
272279

273280
def __init__(self, retry_policy=None, consistency_level=None, routing_key=None,
274281
serial_consistency_level=None, fetch_size=FETCH_SIZE_UNSET, keyspace=None, custom_payload=None,
275-
is_idempotent=False):
282+
is_idempotent=False, table=None):
276283
if retry_policy and not hasattr(retry_policy, 'on_read_timeout'): # just checking one method to detect positional parameter errors
277284
raise ValueError('retry_policy should implement cassandra.policies.RetryPolicy')
278285
if retry_policy is not None:
@@ -286,6 +293,8 @@ def __init__(self, retry_policy=None, consistency_level=None, routing_key=None,
286293
self.fetch_size = fetch_size
287294
if keyspace is not None:
288295
self.keyspace = keyspace
296+
if table is not None:
297+
self.table = table
289298
if custom_payload is not None:
290299
self.custom_payload = custom_payload
291300
self.is_idempotent = is_idempotent
@@ -548,6 +557,7 @@ def __init__(self, prepared_statement, retry_policy=None, consistency_level=None
548557
meta = prepared_statement.column_metadata
549558
if meta:
550559
self.keyspace = meta[0].keyspace_name
560+
self.table = meta[0].table_name
551561

552562
Statement.__init__(self, retry_policy, consistency_level, routing_key,
553563
serial_consistency_level, fetch_size, keyspace, custom_payload,

cassandra/tablets.py

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
# Experimental, this interface and use may change
2+
from threading import Lock
3+
4+
class Tablet(object):
5+
"""
6+
Represents a single ScyllaDB tablet.
7+
It stores information about each replica, its host and shard,
8+
and the token interval in the format (first_token, last_token].
9+
"""
10+
first_token = 0
11+
last_token = 0
12+
replicas = None
13+
14+
def __init__(self, first_token = 0, last_token = 0, replicas = None):
15+
self.first_token = first_token
16+
self.last_token = last_token
17+
self.replicas = replicas
18+
19+
def __str__(self):
20+
return "<Tablet: first_token=%s last_token=%s replicas=%s>" \
21+
% (self.first_token, self.last_token, self.replicas)
22+
__repr__ = __str__
23+
24+
@staticmethod
25+
def _is_valid_tablet(replicas):
26+
return replicas is not None and len(replicas) != 0
27+
28+
@staticmethod
29+
def from_row(first_token, last_token, replicas):
30+
if Tablet._is_valid_tablet(replicas):
31+
tablet = Tablet(first_token, last_token,replicas)
32+
return tablet
33+
return None
34+
35+
# Experimental, this interface and use may change
36+
class Tablets(object):
37+
_lock = None
38+
_tablets = {}
39+
40+
def __init__(self, tablets):
41+
self._tablets = tablets
42+
self._lock = Lock()
43+
44+
def get_tablet_for_key(self, keyspace, table, t):
45+
tablet = self._tablets.get((keyspace, table), [])
46+
if tablet == []:
47+
return None
48+
49+
id = bisect_left(tablet, t.value, key = lambda tablet: tablet.last_token)
50+
if id < len(tablet) and t.value > tablet[id].first_token:
51+
return tablet[id]
52+
return None
53+
54+
def add_tablet(self, keyspace, table, tablet):
55+
with self._lock:
56+
tablets_for_table = self._tablets.setdefault((keyspace, table), [])
57+
58+
# find first overlaping range
59+
start = bisect_left(tablets_for_table, tablet.first_token, key = lambda t: t.first_token)
60+
if start > 0 and tablets_for_table[start - 1].last_token > tablet.first_token:
61+
start = start - 1
62+
63+
# find last overlaping range
64+
end = bisect_left(tablets_for_table, tablet.last_token, key = lambda t: t.last_token)
65+
if end < len(tablets_for_table) and tablets_for_table[end].first_token >= tablet.last_token:
66+
end = end - 1
67+
68+
if start <= end:
69+
del tablets_for_table[start:end + 1]
70+
71+
tablets_for_table.insert(start, tablet)
72+
73+
# bisect.bisect_left implementation from Python 3.11, needed untill support for
74+
# Python < 3.10 is dropped, it is needed to use `key` to extract last_token from
75+
# Tablet list - better solution performance-wise than materialize list of last_tokens
76+
def bisect_left(a, x, lo=0, hi=None, *, key=None):
77+
"""Return the index where to insert item x in list a, assuming a is sorted.
78+
79+
The return value i is such that all e in a[:i] have e < x, and all e in
80+
a[i:] have e >= x. So if x already appears in the list, a.insert(i, x) will
81+
insert just before the leftmost x already there.
82+
83+
Optional args lo (default 0) and hi (default len(a)) bound the
84+
slice of a to be searched.
85+
"""
86+
87+
if lo < 0:
88+
raise ValueError('lo must be non-negative')
89+
if hi is None:
90+
hi = len(a)
91+
# Note, the comparison uses "<" to match the
92+
# __lt__() logic in list.sort() and in heapq.
93+
if key is None:
94+
while lo < hi:
95+
mid = (lo + hi) // 2
96+
if a[mid] < x:
97+
lo = mid + 1
98+
else:
99+
hi = mid
100+
else:
101+
while lo < hi:
102+
mid = (lo + hi) // 2
103+
if key(a[mid]) < x:
104+
lo = mid + 1
105+
else:
106+
hi = mid
107+
return lo

0 commit comments

Comments
 (0)