Commit c62665f

Add RackAwareRoundRobinPolicy for host selection
1 parent: 1c3cff8

6 files changed: 369 additions, 84 deletions

cassandra/cluster.py

Lines changed: 7 additions & 2 deletions

@@ -492,7 +492,8 @@ def _profiles_without_explicit_lbps(self):

     def distance(self, host):
         distances = set(p.load_balancing_policy.distance(host) for p in self.profiles.values())
-        return HostDistance.LOCAL if HostDistance.LOCAL in distances else \
+        return HostDistance.LOCAL_RACK if HostDistance.LOCAL_RACK in distances else \
+            HostDistance.LOCAL if HostDistance.LOCAL in distances else \
             HostDistance.REMOTE if HostDistance.REMOTE in distances else \
             HostDistance.IGNORED

@@ -609,7 +610,7 @@ class Cluster(object):

    Defaults to loopback interface.

-   Note: When using :class:`.DCAwareLoadBalancingPolicy` with no explicit
+   Note: When using :class:`.DCAwareRoundRobinPolicy` with no explicit
    local_dc set (as is the default), the DC is chosen from an arbitrary
    host in contact_points. In this case, contact_points should contain
    only nodes from a single, local DC.

@@ -1369,21 +1370,25 @@ def __init__(self,
         self._user_types = defaultdict(dict)

         self._min_requests_per_connection = {
+            HostDistance.LOCAL_RACK: DEFAULT_MIN_REQUESTS,
             HostDistance.LOCAL: DEFAULT_MIN_REQUESTS,
             HostDistance.REMOTE: DEFAULT_MIN_REQUESTS
         }

         self._max_requests_per_connection = {
+            HostDistance.LOCAL_RACK: DEFAULT_MAX_REQUESTS,
             HostDistance.LOCAL: DEFAULT_MAX_REQUESTS,
             HostDistance.REMOTE: DEFAULT_MAX_REQUESTS
         }

         self._core_connections_per_host = {
+            HostDistance.LOCAL_RACK: DEFAULT_MIN_CONNECTIONS_PER_LOCAL_HOST,
             HostDistance.LOCAL: DEFAULT_MIN_CONNECTIONS_PER_LOCAL_HOST,
             HostDistance.REMOTE: DEFAULT_MIN_CONNECTIONS_PER_REMOTE_HOST
         }

         self._max_connections_per_host = {
+            HostDistance.LOCAL_RACK: DEFAULT_MAX_CONNECTIONS_PER_LOCAL_HOST,
             HostDistance.LOCAL: DEFAULT_MAX_CONNECTIONS_PER_LOCAL_HOST,
             HostDistance.REMOTE: DEFAULT_MAX_CONNECTIONS_PER_REMOTE_HOST
         }
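
Since the per-distance connection defaults above now include HostDistance.LOCAL_RACK, rack-local hosts can presumably be tuned separately through the existing per-distance setters. A minimal sketch, assuming the setters simply index the dictionaries shown above (the contact point is illustrative):

    from cassandra.cluster import Cluster
    from cassandra.policies import HostDistance

    cluster = Cluster(["127.0.0.1"])  # illustrative contact point

    # Assumption: the existing per-distance setters accept the new distance,
    # so rack-local hosts can be tuned independently of LOCAL and REMOTE.
    cluster.set_min_requests_per_connection(HostDistance.LOCAL_RACK, 5)
    cluster.set_max_requests_per_connection(HostDistance.LOCAL_RACK, 150)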

cassandra/metadata.py

Lines changed: 1 addition & 1 deletion

@@ -3436,7 +3436,7 @@ def group_keys_by_replica(session, keyspace, table, keys):
         all_replicas = cluster.metadata.get_replicas(keyspace, routing_key)
         # First check if there are local replicas
         valid_replicas = [host for host in all_replicas if
-                          host.is_up and distance(host) == HostDistance.LOCAL]
+                          host.is_up and distance(host) in [HostDistance.LOCAL, HostDistance.LOCAL_RACK]]
         if not valid_replicas:
             valid_replicas = [host for host in all_replicas if host.is_up]
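
With this change, replicas at LOCAL_RACK distance also count as local when grouping keys. A brief usage sketch of the helper (keyspace, table, and keys are illustrative):

    from cassandra.cluster import Cluster
    from cassandra.metadata import group_keys_by_replica

    cluster = Cluster(["127.0.0.1"])  # illustrative contact point
    session = cluster.connect()

    # Keys are partition-key tuples; replicas at LOCAL or LOCAL_RACK distance
    # are now both treated as "local" when choosing the grouping host.
    keys_by_host = group_keys_by_replica(session, "test1", "table1",
                                         [(i,) for i in range(10)])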

cassandra/policies.py

Lines changed: 146 additions & 6 deletions

@@ -46,7 +46,18 @@ class HostDistance(object):
     connections opened to it.
     """

-    LOCAL = 0
+    LOCAL_RACK = 0
+    """
+    Nodes with ``LOCAL_RACK`` distance will be preferred for operations
+    under some load balancing policies (such as :class:`.RackAwareRoundRobinPolicy`)
+    and will have a greater number of connections opened against
+    them by default.
+
+    This distance is typically used for nodes within the same
+    datacenter and the same rack as the client.
+    """
+
+    LOCAL = 1
     """
     Nodes with ``LOCAL`` distance will be preferred for operations
     under some load balancing policies (such as :class:`.DCAwareRoundRobinPolicy`)
@@ -57,12 +68,12 @@ class HostDistance(object):
     datacenter as the client.
     """

-    REMOTE = 1
+    REMOTE = 2
     """
     Nodes with ``REMOTE`` distance will be treated as a last resort
-    by some load balancing policies (such as :class:`.DCAwareRoundRobinPolicy`)
-    and will have a smaller number of connections opened against
-    them by default.
+    by some load balancing policies (such as :class:`.DCAwareRoundRobinPolicy`
+    and :class:`.RackAwareRoundRobinPolicy`) and will have a smaller number of
+    connections opened against them by default.

     This distance is typically used for nodes outside of the
     datacenter that the client is running in.
@@ -102,6 +113,11 @@ class LoadBalancingPolicy(HostStateListener):

    You may also use subclasses of :class:`.LoadBalancingPolicy` for
    custom behavior.
+
+   You should always use immutable collections (e.g., tuples or
+   frozensets) to store information about hosts to prevent accidental
+   modification. When the set of hosts changes (e.g., a host goes down
+   or comes back up), the old collection should be replaced with a new one.
    """

    _hosts_lock = None
@@ -316,6 +332,130 @@ def on_add(self, host):
     def on_remove(self, host):
         self.on_down(host)

+class RackAwareRoundRobinPolicy(LoadBalancingPolicy):
+    """
+    Similar to :class:`.DCAwareRoundRobinPolicy`, but prefers hosts
+    in the local rack, then hosts in the local datacenter but a
+    different rack, then hosts in all other datacenters.
+    """
+
+    local_dc = None
+    local_rack = None
+    used_hosts_per_remote_dc = 0
+
+    def __init__(self, local_dc, local_rack, used_hosts_per_remote_dc=0):
+        """
+        The `local_dc` and `local_rack` parameters should be the names of the
+        datacenter and rack (such as are reported by ``nodetool ring``) that
+        should be considered local.
+
+        `used_hosts_per_remote_dc` controls how many nodes in
+        each remote datacenter will have connections opened
+        against them. In other words, `used_hosts_per_remote_dc` hosts
+        will be considered :attr:`~.HostDistance.REMOTE` and the
+        rest will be considered :attr:`~.HostDistance.IGNORED`.
+        By default, all remote hosts are ignored.
+        """
+        self.local_rack = local_rack
+        self.local_dc = local_dc
+        self.used_hosts_per_remote_dc = used_hosts_per_remote_dc
+        self._live_hosts = {}
+        self._dc_live_hosts = {}
+        self._endpoints = []
+        self._position = 0
+        LoadBalancingPolicy.__init__(self)
+
+    def _rack(self, host):
+        return host.rack or self.local_rack
+
+    def _dc(self, host):
+        return host.datacenter or self.local_dc
+
+    def populate(self, cluster, hosts):
+        for (dc, rack), rack_hosts in groupby(hosts, lambda host: (self._dc(host), self._rack(host))):
+            self._live_hosts[(dc, rack)] = tuple(set(rack_hosts))
+        for dc, dc_hosts in groupby(hosts, lambda host: self._dc(host)):
+            self._dc_live_hosts[dc] = tuple(set(dc_hosts))
+
+        self._position = randint(0, len(hosts) - 1) if hosts else 0
+
+    def distance(self, host):
+        rack = self._rack(host)
+        dc = self._dc(host)
+        if rack == self.local_rack and dc == self.local_dc:
+            return HostDistance.LOCAL_RACK
+
+        if dc == self.local_dc:
+            return HostDistance.LOCAL
+
+        if not self.used_hosts_per_remote_dc:
+            return HostDistance.IGNORED
+
+        dc_hosts = self._dc_live_hosts.get(dc, ())
+        if not dc_hosts:
+            return HostDistance.IGNORED
+        if host in dc_hosts and dc_hosts.index(host) < self.used_hosts_per_remote_dc:
+            return HostDistance.REMOTE
+        else:
+            return HostDistance.IGNORED
+
+    def make_query_plan(self, working_keyspace=None, query=None):
+        pos = self._position
+        self._position += 1
+
+        local_rack_live = self._live_hosts.get((self.local_dc, self.local_rack), ())
+        pos = (pos % len(local_rack_live)) if local_rack_live else 0
+        # Slice the cyclic iterator so it starts at pos and yields the next
+        # len(local_rack_live) elements: exactly one full cycle starting from pos.
+        for host in islice(cycle(local_rack_live), pos, pos + len(local_rack_live)):
+            yield host
+
+        local_live = [host for host in self._dc_live_hosts.get(self.local_dc, ()) if host.rack != self.local_rack]
+        pos = (pos % len(local_live)) if local_live else 0
+        for host in islice(cycle(local_live), pos, pos + len(local_live)):
+            yield host
+
+        # the dict can change, so get candidate DCs iterating over keys of a copy
+        for dc, remote_live in self._dc_live_hosts.copy().items():
+            if dc != self.local_dc:
+                for host in remote_live[:self.used_hosts_per_remote_dc]:
+                    yield host
+
+    def on_up(self, host):
+        dc = self._dc(host)
+        rack = self._rack(host)
+        with self._hosts_lock:
+            current_rack_hosts = self._live_hosts.get((dc, rack), ())
+            if host not in current_rack_hosts:
+                self._live_hosts[(dc, rack)] = current_rack_hosts + (host, )
+            current_dc_hosts = self._dc_live_hosts.get(dc, ())
+            if host not in current_dc_hosts:
+                self._dc_live_hosts[dc] = current_dc_hosts + (host, )
+
+    def on_down(self, host):
+        dc = self._dc(host)
+        rack = self._rack(host)
+        with self._hosts_lock:
+            current_rack_hosts = self._live_hosts.get((dc, rack), ())
+            if host in current_rack_hosts:
+                hosts = tuple(h for h in current_rack_hosts if h != host)
+                if hosts:
+                    self._live_hosts[(dc, rack)] = hosts
+                else:
+                    del self._live_hosts[(dc, rack)]
+            current_dc_hosts = self._dc_live_hosts.get(dc, ())
+            if host in current_dc_hosts:
+                hosts = tuple(h for h in current_dc_hosts if h != host)
+                if hosts:
+                    self._dc_live_hosts[dc] = hosts
+                else:
+                    del self._dc_live_hosts[dc]
+
+    def on_add(self, host):
+        self.on_up(host)
+
+    def on_remove(self, host):
+        self.on_down(host)

 class TokenAwarePolicy(LoadBalancingPolicy):
     """
@@ -390,7 +530,7 @@ def make_query_plan(self, working_keyspace=None, query=None):
                     shuffle(replicas)

                 for replica in replicas:
-                    if replica.is_up and child.distance(replica) == HostDistance.LOCAL:
+                    if replica.is_up and child.distance(replica) in [HostDistance.LOCAL, HostDistance.LOCAL_RACK]:
                         yield replica

                 for host in child.make_query_plan(keyspace, query):
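
A minimal usage sketch of the new policy wired into an execution profile (contact point, datacenter, and rack names are illustrative; wrapping with TokenAwarePolicy is optional):

    from cassandra.cluster import Cluster, ExecutionProfile, EXEC_PROFILE_DEFAULT
    from cassandra.policies import RackAwareRoundRobinPolicy, TokenAwarePolicy

    # Prefer replicas in DC1/RC1, then the rest of DC1; remote DCs are ignored
    # because used_hosts_per_remote_dc defaults to 0.
    policy = TokenAwarePolicy(RackAwareRoundRobinPolicy("DC1", "RC1"))
    profile = ExecutionProfile(load_balancing_policy=policy)

    cluster = Cluster(["127.0.0.1"],
                      execution_profiles={EXEC_PROFILE_DEFAULT: profile})
    session = cluster.connect()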

docs/api/cassandra/policies.rst

Lines changed: 3 additions & 0 deletions

@@ -18,6 +18,9 @@ Load Balancing
 .. autoclass:: DCAwareRoundRobinPolicy
    :members:

+.. autoclass:: RackAwareRoundRobinPolicy
+   :members:
+
 .. autoclass:: WhiteListRoundRobinPolicy
    :members:

Lines changed: 89 additions & 0 deletions

@@ -0,0 +1,89 @@
+import logging
+import unittest
+
+from cassandra.cluster import Cluster
+from cassandra.policies import ConstantReconnectionPolicy, RackAwareRoundRobinPolicy
+
+from tests.integration import PROTOCOL_VERSION, get_cluster, use_multidc
+
+LOGGER = logging.getLogger(__name__)
+
+def setup_module():
+    use_multidc({'DC1': {'RC1': 2, 'RC2': 2}, 'DC2': {'RC1': 3}})
+
+class RackAwareRoundRobinPolicyTests(unittest.TestCase):
+    @classmethod
+    def setup_class(cls):
+        cls.cluster = Cluster(contact_points=[node.address() for node in get_cluster().nodelist()],
+                              protocol_version=PROTOCOL_VERSION,
+                              load_balancing_policy=RackAwareRoundRobinPolicy("DC1", "RC1", used_hosts_per_remote_dc=0),
+                              reconnection_policy=ConstantReconnectionPolicy(1))
+        cls.session = cls.cluster.connect()
+        cls.create_ks_and_cf(cls)
+        cls.create_data(cls.session)
+        cls.node1, cls.node2, cls.node3, cls.node4, cls.node5, cls.node6, cls.node7 = get_cluster().nodes.values()
+
+    @classmethod
+    def teardown_class(cls):
+        cls.cluster.shutdown()
+
+    def create_ks_and_cf(self):
+        self.session.execute(
+            """
+            DROP KEYSPACE IF EXISTS test1
+            """
+        )
+        self.session.execute(
+            """
+            CREATE KEYSPACE test1
+            WITH replication = {
+                'class': 'NetworkTopologyStrategy',
+                'replication_factor': 3
+            }
+            """)
+
+        self.session.execute(
+            """
+            CREATE TABLE test1.table1 (pk int, ck int, v int, PRIMARY KEY (pk, ck));
+            """)
+
+    @staticmethod
+    def create_data(session):
+        prepared = session.prepare(
+            """
+            INSERT INTO test1.table1 (pk, ck, v) VALUES (?, ?, ?)
+            """)
+
+        for i in range(50):
+            bound = prepared.bind((i, i % 5, i % 2))
+            session.execute(bound)
+
+    def test_rack_aware(self):
+        prepared = self.session.prepare(
+            """
+            SELECT pk, ck, v FROM test1.table1 WHERE pk = ?
+            """)
+
+        for i in range(10):
+            bound = prepared.bind([i])
+            results = self.session.execute(bound)
+            self.assertEqual(results, [(i, i % 5, i % 2)])
+            coordinator = str(results.response_future.coordinator_host.endpoint)
+            self.assertTrue(coordinator in set(["127.0.0.1:9042", "127.0.0.2:9042"]))
+
+        self.node2.stop(wait_other_notice=True, gently=True)
+
+        for i in range(10):
+            bound = prepared.bind([i])
+            results = self.session.execute(bound)
+            self.assertEqual(results, [(i, i % 5, i % 2)])
+            coordinator = str(results.response_future.coordinator_host.endpoint)
+            self.assertEqual(coordinator, "127.0.0.1:9042")
+
+        self.node1.stop(wait_other_notice=True, gently=True)
+
+        for i in range(10):
+            bound = prepared.bind([i])
+            results = self.session.execute(bound)
+            self.assertEqual(results, [(i, i % 5, i % 2)])
+            coordinator = str(results.response_future.coordinator_host.endpoint)
+            self.assertTrue(coordinator in set(["127.0.0.3:9042", "127.0.0.4:9042"]))
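
Outside the test harness, a small sketch for inspecting the distance the policy assigns to each discovered host (contact point and names are illustrative; the legacy load_balancing_policy keyword is used here as in the test above):

    from cassandra.cluster import Cluster
    from cassandra.policies import RackAwareRoundRobinPolicy

    cluster = Cluster(["127.0.0.1"],
                      load_balancing_policy=RackAwareRoundRobinPolicy("DC1", "RC1"))
    session = cluster.connect()

    # Expect LOCAL_RACK for DC1/RC1 hosts, LOCAL for other DC1 racks, and
    # IGNORED elsewhere (used_hosts_per_remote_dc defaults to 0).
    for host in cluster.metadata.all_hosts():
        print(host.address, host.datacenter, host.rack,
              cluster.load_balancing_policy.distance(host))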
