Skip to content

Commit 9a99da7

Browse files
Copilot and dkropachev committed
Fix groupby issue in DCAwareRoundRobinPolicy and RackAwareRoundRobinPolicy
- Fixed DCAwareRoundRobinPolicy.populate() to properly collect all hosts per DC
- Fixed RackAwareRoundRobinPolicy.populate() to properly collect all hosts per DC/rack
- Added regression tests for both policies to verify interleaved hosts are not lost
- All existing tests pass (83 tests in test_policies.py)

Co-authored-by: dkropachev <[email protected]>
1 parent 38f79eb commit 9a99da7

File tree

2 files changed

+147
-6
lines changed

2 files changed

+147
-6
lines changed

cassandra/policies.py

Lines changed: 36 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -253,8 +253,17 @@ def _dc(self, host):
253253
return host.datacenter or self.local_dc
254254

255255
def populate(self, cluster, hosts):
256-
for dc, dc_hosts in groupby(hosts, lambda h: self._dc(h)):
257-
self._dc_live_hosts[dc] = tuple(set(dc_hosts))
256+
# Group hosts by datacenter without relying on groupby which only groups consecutive items
257+
dc_hosts_dict = {}
258+
for host in hosts:
259+
dc = self._dc(host)
260+
if dc not in dc_hosts_dict:
261+
dc_hosts_dict[dc] = []
262+
dc_hosts_dict[dc].append(host)
263+
264+
# Convert lists to tuples with unique hosts
265+
for dc, host_list in dc_hosts_dict.items():
266+
self._dc_live_hosts[dc] = tuple(set(host_list))
258267

259268
if not self.local_dc:
260269
self._endpoints = [
@@ -373,10 +382,31 @@ def _dc(self, host):
373382
return host.datacenter or self.local_dc
374383

375384
def populate(self, cluster, hosts):
376-
for (dc, rack), rack_hosts in groupby(hosts, lambda host: (self._dc(host), self._rack(host))):
377-
self._live_hosts[(dc, rack)] = tuple(set(rack_hosts))
378-
for dc, dc_hosts in groupby(hosts, lambda host: self._dc(host)):
379-
self._dc_live_hosts[dc] = tuple(set(dc_hosts))
385+
# Group hosts by (dc, rack) and by dc without relying on groupby which only groups consecutive items
386+
rack_hosts_dict = {}
387+
dc_hosts_dict = {}
388+
389+
for host in hosts:
390+
dc = self._dc(host)
391+
rack = self._rack(host)
392+
393+
# Group by (dc, rack)
394+
key = (dc, rack)
395+
if key not in rack_hosts_dict:
396+
rack_hosts_dict[key] = []
397+
rack_hosts_dict[key].append(host)
398+
399+
# Group by dc
400+
if dc not in dc_hosts_dict:
401+
dc_hosts_dict[dc] = []
402+
dc_hosts_dict[dc].append(host)
403+
404+
# Convert lists to tuples with unique hosts
405+
for key, host_list in rack_hosts_dict.items():
406+
self._live_hosts[key] = tuple(set(host_list))
407+
408+
for dc, host_list in dc_hosts_dict.items():
409+
self._dc_live_hosts[dc] = tuple(set(host_list))
380410

381411
self._position = randint(0, len(hosts) - 1) if hosts else 0
382412

tests/unit/test_policies.py

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -531,6 +531,7 @@ def test_wrong_dc(self, policy_specialization, constructor_args):
531531
qplan = list(policy.make_query_plan())
532532
assert len(qplan) == 0
533533

534+
534535
class DCAwareRoundRobinPolicyTest(unittest.TestCase):
535536

536537
def test_default_dc(self):
@@ -573,6 +574,116 @@ def test_default_dc(self):
573574
policy.on_add(host_remote)
574575
assert not policy.local_dc
575576

577+
def test_populate_with_interleaved_dcs(self):
578+
"""Test that DCAwareRoundRobinPolicy doesn't lose hosts when DCs are interleaved.
579+
580+
This is a regression test for the issue where groupby only groups consecutive items,
581+
which caused hosts to be lost when datacenters were not sorted.
582+
"""
583+
# Create hosts with interleaved datacenters (dc1, dc2, dc1, dc2)
584+
hosts = []
585+
586+
# DC1 host 1
587+
h1 = Host(DefaultEndPoint("10.0.0.1"), SimpleConvictionPolicy)
588+
h1.set_location_info("dc1", "rack1")
589+
hosts.append(h1)
590+
591+
# DC2 host 1
592+
h2 = Host(DefaultEndPoint("10.0.0.2"), SimpleConvictionPolicy)
593+
h2.set_location_info("dc2", "rack1")
594+
hosts.append(h2)
595+
596+
# DC1 host 2 (interleaved)
597+
h3 = Host(DefaultEndPoint("10.0.0.3"), SimpleConvictionPolicy)
598+
h3.set_location_info("dc1", "rack1")
599+
hosts.append(h3)
600+
601+
# DC2 host 2 (interleaved)
602+
h4 = Host(DefaultEndPoint("10.0.0.4"), SimpleConvictionPolicy)
603+
h4.set_location_info("dc2", "rack1")
604+
hosts.append(h4)
605+
606+
policy = DCAwareRoundRobinPolicy("dc1")
607+
policy.populate(Mock(), hosts)
608+
609+
# Check that all hosts are registered
610+
dc1_hosts = policy._dc_live_hosts.get("dc1", ())
611+
dc2_hosts = policy._dc_live_hosts.get("dc2", ())
612+
613+
assert len(dc1_hosts) == 2, "DC1 should have 2 hosts"
614+
assert len(dc2_hosts) == 2, "DC2 should have 2 hosts"
615+
assert h1 in dc1_hosts
616+
assert h3 in dc1_hosts
617+
assert h2 in dc2_hosts
618+
assert h4 in dc2_hosts
619+
620+
621+
class RackAwareRoundRobinPolicyTest(unittest.TestCase):
622+
623+
def test_populate_with_interleaved_racks(self):
624+
"""Test that RackAwareRoundRobinPolicy doesn't lose hosts when racks are interleaved.
625+
626+
This is a regression test for the issue where groupby only groups consecutive items,
627+
which caused hosts to be lost when racks/datacenters were not sorted.
628+
"""
629+
# Create hosts with interleaved racks
630+
hosts = []
631+
632+
# DC1 Rack1 host 1
633+
h1 = Host(DefaultEndPoint("10.0.0.1"), SimpleConvictionPolicy)
634+
h1.set_location_info("dc1", "rack1")
635+
hosts.append(h1)
636+
637+
# DC1 Rack2 host 1
638+
h2 = Host(DefaultEndPoint("10.0.0.2"), SimpleConvictionPolicy)
639+
h2.set_location_info("dc1", "rack2")
640+
hosts.append(h2)
641+
642+
# DC1 Rack1 host 2 (interleaved)
643+
h3 = Host(DefaultEndPoint("10.0.0.3"), SimpleConvictionPolicy)
644+
h3.set_location_info("dc1", "rack1")
645+
hosts.append(h3)
646+
647+
# DC1 Rack2 host 2 (interleaved)
648+
h4 = Host(DefaultEndPoint("10.0.0.4"), SimpleConvictionPolicy)
649+
h4.set_location_info("dc1", "rack2")
650+
hosts.append(h4)
651+
652+
# DC2 Rack1 host 1
653+
h5 = Host(DefaultEndPoint("10.0.0.5"), SimpleConvictionPolicy)
654+
h5.set_location_info("dc2", "rack1")
655+
hosts.append(h5)
656+
657+
# DC1 Rack1 host 3 (interleaved again)
658+
h6 = Host(DefaultEndPoint("10.0.0.6"), SimpleConvictionPolicy)
659+
h6.set_location_info("dc1", "rack1")
660+
hosts.append(h6)
661+
662+
policy = RackAwareRoundRobinPolicy("dc1", "rack1")
663+
policy.populate(Mock(), hosts)
664+
665+
# Check that all hosts are registered
666+
dc1_rack1_hosts = policy._live_hosts.get(("dc1", "rack1"), ())
667+
dc1_rack2_hosts = policy._live_hosts.get(("dc1", "rack2"), ())
668+
dc2_rack1_hosts = policy._live_hosts.get(("dc2", "rack1"), ())
669+
670+
dc1_hosts = policy._dc_live_hosts.get("dc1", ())
671+
dc2_hosts = policy._dc_live_hosts.get("dc2", ())
672+
673+
assert len(dc1_rack1_hosts) == 3, "DC1 Rack1 should have 3 hosts"
674+
assert len(dc1_rack2_hosts) == 2, "DC1 Rack2 should have 2 hosts"
675+
assert len(dc2_rack1_hosts) == 1, "DC2 Rack1 should have 1 host"
676+
assert len(dc1_hosts) == 5, "DC1 should have 5 hosts total"
677+
assert len(dc2_hosts) == 1, "DC2 should have 1 host total"
678+
679+
assert h1 in dc1_rack1_hosts
680+
assert h3 in dc1_rack1_hosts
681+
assert h6 in dc1_rack1_hosts
682+
assert h2 in dc1_rack2_hosts
683+
assert h4 in dc1_rack2_hosts
684+
assert h5 in dc2_rack1_hosts
685+
686+
576687
class TokenAwarePolicyTest(unittest.TestCase):
577688

578689
def test_wrap_round_robin(self):

0 commit comments

Comments
 (0)