Skip to content

Commit dc8ac1f

Browse files
committed
Introduce NoDelayShardConnectionBackoffPolicy
This policy is an implementation of ShardConnectionBackoffPolicy. It implements the same behavior that the driver currently has: 1. No delay between creating shard connections. 2. It avoids creating multiple connections to the same (host_id, shard_id) pair.
1 parent c89ad2a commit dc8ac1f

File tree

3 files changed

+216
-26
lines changed

3 files changed

+216
-26
lines changed

cassandra/policies.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -934,6 +934,69 @@ def new_connection_scheduler(self, scheduler: _Scheduler) -> ShardConnectionSche
934934
raise NotImplementedError()
935935

936936

937+
class NoDelayShardConnectionBackoffPolicy(ShardConnectionBackoffPolicy):
    """
    A shard connection backoff policy that applies no delay between attempts.

    It ensures there is at most one pending connection request per
    (host, shard) pair: a request for a pair that already has a pending
    request is silently dropped.
    """

    def new_connection_scheduler(self, scheduler: _Scheduler) -> ShardConnectionScheduler:
        """Return a new no-delay connection scheduler that wraps ``scheduler``."""
        no_delay_scheduler = _NoDelayShardConnectionBackoffScheduler(scheduler)
        return no_delay_scheduler
946+
947+
948+
class _NoDelayShardConnectionBackoffScheduler(ShardConnectionScheduler):
    """
    A scheduler for ``cassandra.policies.NoDelayShardConnectionBackoffPolicy``.

    Every accepted request is handed to the underlying scheduler with a delay
    of ``0``. While a request for a given ``(host_id, shard_id)`` pair is
    pending, further requests for the same pair are silently dropped
    (``schedule`` returns ``False``).
    """

    # Underlying scheduler that actually runs the submitted callables.
    scheduler: _Scheduler
    # (host_id, shard_id) pairs that currently have a pending request.
    already_scheduled: set[tuple[str, int]]
    # Guards ``already_scheduled`` and writes to ``is_shutdown``.
    lock: Lock
    # Set by ``shutdown()``; once True no new requests are accepted or run.
    is_shutdown: bool = False

    def __init__(self, scheduler: _Scheduler):
        """
        :param scheduler: object whose ``schedule(delay, fn, *args)`` is used
            to run the connection callables.
        """
        self.scheduler = scheduler
        self.already_scheduled = set()
        self.lock = Lock()

    def _execute(
            self,
            host_id: str,
            shard_id: int,
            method: Callable[[], None],
    ) -> None:
        """
        Run ``method`` unless the scheduler was shut down, then release the
        ``(host_id, shard_id)`` slot so the pair can be scheduled again.

        Any exception raised by ``method`` propagates to the caller.
        """
        try:
            if not self.is_shutdown:
                method()
        finally:
            # Always release the slot, even on shutdown or when ``method``
            # raises. ``discard`` (not ``remove``) so that a missing key can
            # never raise from ``finally`` and mask the original exception.
            with self.lock:
                self.already_scheduled.discard((host_id, shard_id))

    def schedule(
            self,
            host_id: str,
            shard_id: int,
            method: Callable[[], None],
    ) -> bool:
        """
        Request execution of ``method`` for the given ``(host_id, shard_id)``.

        :return: ``True`` if the request was handed to the underlying
            scheduler; ``False`` if it was dropped because this scheduler is
            shut down or a request for the same pair is already pending.
        :raises: whatever the underlying ``scheduler.schedule`` raises; in
            that case the pair is not left marked as pending.
        """
        key = (host_id, shard_id)
        with self.lock:
            if self.is_shutdown or key in self.already_scheduled:
                return False
            self.already_scheduled.add(key)

        # Submit outside the lock so the underlying scheduler is never called
        # while the lock is held.
        try:
            self.scheduler.schedule(0, self._execute, host_id, shard_id, method)
        except BaseException:
            # The request never reached the underlying scheduler, so
            # ``_execute`` will not run to release the slot; release it here,
            # otherwise this pair would be blocked forever.
            with self.lock:
                self.already_scheduled.discard(key)
            raise
        return True

    def shutdown(self):
        """Stop accepting new requests; requests not yet run become no-ops."""
        with self.lock:
            self.is_shutdown = True
998+
999+
9371000
class RetryPolicy(object):
9381001
"""
9391002
A policy that describes whether to retry, rethrow, or ignore coordinator

tests/unit/test_policies.py

Lines changed: 81 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414

1515
import unittest
16+
from functools import partial
1617

1718
from itertools import islice, cycle
1819
from unittest.mock import Mock, patch, call
@@ -26,13 +27,15 @@
2627
from cassandra import ConsistencyLevel
2728
from cassandra.cluster import Cluster, ControlConnection
2829
from cassandra.metadata import Metadata
29-
from cassandra.policies import (RackAwareRoundRobinPolicy, RoundRobinPolicy, WhiteListRoundRobinPolicy, DCAwareRoundRobinPolicy,
30+
from cassandra.policies import (RackAwareRoundRobinPolicy, RoundRobinPolicy, WhiteListRoundRobinPolicy,
31+
DCAwareRoundRobinPolicy,
3032
TokenAwarePolicy, SimpleConvictionPolicy,
3133
HostDistance, ExponentialReconnectionPolicy,
3234
RetryPolicy, WriteType,
3335
DowngradingConsistencyRetryPolicy, ConstantReconnectionPolicy,
3436
LoadBalancingPolicy, ConvictionPolicy, ReconnectionPolicy, FallthroughRetryPolicy,
35-
IdentityTranslator, EC2MultiRegionTranslator, HostFilterPolicy, ExponentialBackoffRetryPolicy)
37+
IdentityTranslator, EC2MultiRegionTranslator, HostFilterPolicy,
38+
ExponentialBackoffRetryPolicy, _NoDelayShardConnectionBackoffScheduler)
3639
from cassandra.connection import DefaultEndPoint, UnixSocketEndPoint
3740
from cassandra.pool import Host
3841
from cassandra.query import Statement
@@ -1579,3 +1582,79 @@ def test_create_whitelist(self):
15791582
# Only the filtered replicas should be allowed
15801583
self.assertEqual(set(query_plan), {Host(DefaultEndPoint("127.0.0.1"), SimpleConvictionPolicy),
15811584
Host(DefaultEndPoint("127.0.0.4"), SimpleConvictionPolicy)})
1585+
1586+
1587+
class MockScheduler:
    """
    Test double for a scheduler: records scheduled calls instead of running
    them, and runs them only when ``execute()`` is called.
    """

    def __init__(self):
        # Each entry is a ``(delay, fn, args, kwargs)`` tuple.
        self.requests = []

    def schedule(self, delay, fn, *args, **kwargs):
        """Record a call request without running it."""
        request = (delay, fn, args, kwargs)
        self.requests.append(request)

    def execute(self):
        """Run every recorded request in order (delay is ignored), then clear the queue."""
        for _delay, callback, positional, keyword in self.requests:
            callback(*positional, **keyword)
        self.requests = []
1598+
1599+
1600+
class NoDelayShardConnectionBackoffSchedulerTests(unittest.TestCase):
    """Unit tests for ``_NoDelayShardConnectionBackoffScheduler`` driven by ``MockScheduler``."""

    def test_schedule_executes_method_immediately(self):
        """An accepted request is queued with zero delay and runs with its bound arguments."""
        method = Mock()
        scheduler = MockScheduler()
        policy = _NoDelayShardConnectionBackoffScheduler(scheduler)

        self.assertTrue(policy.schedule('host1', 0, partial(method, 1, 2, key='val')))

        # First element of a recorded request is the delay.
        self.assertEqual(scheduler.requests[0][0], 0)
        scheduler.execute()

        method.assert_called_once_with(1, 2, key='val')

    def test_schedule_skips_if_host_shard_already_scheduled(self):
        """A second request for a pair that is still pending is dropped."""
        method = Mock()
        scheduler = MockScheduler()
        policy = _NoDelayShardConnectionBackoffScheduler(scheduler)

        self.assertTrue(policy.schedule('host1', 0, method))
        self.assertFalse(policy.schedule('host1', 0, method))

        self.assertEqual(len(scheduler.requests), 1)
        scheduler.execute()
        method.assert_called_once()

    def test_schedule_does_not_skip_if_shard_is_different(self):
        """Requests for different shards of the same host are independent."""
        method = Mock()
        scheduler = MockScheduler()
        policy = _NoDelayShardConnectionBackoffScheduler(scheduler)

        self.assertTrue(policy.schedule('host1', 0, method))
        self.assertTrue(policy.schedule('host1', 1, method))

        self.assertEqual(len(scheduler.requests), 2)
        scheduler.execute()

        self.assertEqual(method.call_count, 2)

    def test_already_scheduled_resets_after_execution(self):
        """Once a request has run, the same pair can be scheduled again."""
        method = Mock()
        scheduler = MockScheduler()
        policy = _NoDelayShardConnectionBackoffScheduler(scheduler)
        self.assertTrue(policy.schedule('host1', 0, method))

        scheduler.execute()

        self.assertTrue(policy.schedule('host1', 0, method))

        scheduler.execute()

        self.assertEqual(method.call_count, 2)

    def test_schedule_skips_if_shutdown(self):
        """After shutdown, requests are rejected and nothing is queued or run."""
        method = Mock()
        scheduler = MockScheduler()
        policy = _NoDelayShardConnectionBackoffScheduler(scheduler)
        policy.shutdown()

        # The rejection must be reported to the caller, not only reflected
        # in the underlying scheduler's queue.
        self.assertFalse(policy.schedule('host1', 0, method))

        self.assertEqual(len(scheduler.requests), 0)
        method.assert_not_called()

tests/unit/test_shard_aware.py

Lines changed: 72 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,10 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
import uuid
15+
from unittest.mock import Mock
16+
17+
from cassandra.policies import NoDelayShardConnectionBackoffPolicy, _NoDelayShardConnectionBackoffScheduler
1418

1519
try:
1620
import unittest2 as unittest
@@ -21,7 +25,7 @@
2125
from mock import MagicMock
2226
from concurrent.futures import ThreadPoolExecutor
2327

24-
from cassandra.cluster import ShardAwareOptions
28+
from cassandra.cluster import ShardAwareOptions, _Scheduler
2529
from cassandra.pool import HostConnection, HostDistance
2630
from cassandra.connection import ShardingInfo, DefaultEndPoint
2731
from cassandra.metadata import Murmur3Token
@@ -53,11 +57,18 @@ class OptionsHolder(object):
5357
self.assertEqual(shard_info.shard_id_from_token(Murmur3Token.from_key(b"e").value), 4)
5458
self.assertEqual(shard_info.shard_id_from_token(Murmur3Token.from_key(b"100000").value), 2)
5559

56-
def test_advanced_shard_aware_port(self):
60+
def test_shard_aware_reconnection_policy_no_delay(self):
61+
# with NoDelayShardConnectionBackoffPolicy all the connections should be created right away
62+
self._test_shard_aware_reconnection_policy(4, NoDelayShardConnectionBackoffPolicy(), 4)
63+
64+
def _test_shard_aware_reconnection_policy(self, shard_count, shard_connection_backoff_policy, expected_connections):
5765
"""
5866
Test that on given a `shard_aware_port` on the OPTIONS message (ShardInfo class)
59-
the next connections would be open using this port
67+
It checks that:
68+
1. Next connections are opened using this port
69+
2. Connection creation pace matches `shard_connection_backoff_policy`
6070
"""
71+
6172
class MockSession(MagicMock):
6273
is_shutdown = False
6374
keyspace = "ks1"
@@ -71,45 +82,82 @@ def __init__(self, is_ssl=False, *args, **kwargs):
7182
self.cluster.ssl_options = None
7283
self.cluster.shard_aware_options = ShardAwareOptions()
7384
self.cluster.executor = ThreadPoolExecutor(max_workers=2)
85+
self._executor_submit_original = self.cluster.executor.submit
86+
self.cluster.executor.submit = self._executor_submit
87+
self.cluster.scheduler = _Scheduler(self.cluster.executor)
88+
89+
# Collect scheduled calls and execute them right away
90+
self.scheduler_calls = []
91+
original_schedule = self.cluster.scheduler.schedule
92+
93+
def new_schedule(delay, fn, *args, **kwargs):
94+
self.scheduler_calls.append((delay, fn, args, kwargs))
95+
return original_schedule(0, fn, *args, **kwargs)
96+
97+
self.cluster.scheduler.schedule = Mock(side_effect=new_schedule)
7498
self.cluster.signal_connection_failure = lambda *args, **kwargs: False
7599
self.cluster.connection_factory = self.mock_connection_factory
76100
self.connection_counter = 0
101+
self.shard_connection_backoff_scheduler = shard_connection_backoff_policy.new_connection_scheduler(
102+
self.cluster.scheduler)
77103
self.futures = []
78104

79105
def submit(self, fn, *args, **kwargs):
106+
if self.is_shutdown:
107+
return None
108+
return self.cluster.executor.submit(fn, *args, **kwargs)
109+
110+
def _executor_submit(self, fn, *args, **kwargs):
80111
logging.info("Scheduling %s with args: %s, kwargs: %s", fn, args, kwargs)
81-
if not self.is_shutdown:
82-
f = self.cluster.executor.submit(fn, *args, **kwargs)
83-
self.futures += [f]
84-
return f
112+
f = self._executor_submit_original(fn, *args, **kwargs)
113+
self.futures += [f]
114+
return f
85115

86116
def mock_connection_factory(self, *args, **kwargs):
87117
connection = MagicMock()
88118
connection.is_shutdown = False
89119
connection.is_defunct = False
90120
connection.is_closed = False
91121
connection.orphaned_threshold_reached = False
92-
connection.endpoint = args[0]
93-
sharding_info = ShardingInfo(shard_id=1, shards_count=4, partitioner="", sharding_algorithm="", sharding_ignore_msb=0, shard_aware_port=19042, shard_aware_port_ssl=19045)
94-
connection.features = ProtocolFeatures(shard_id=kwargs.get('shard_id', self.connection_counter), sharding_info=sharding_info)
122+
connection.endpoint = args[0]
123+
sharding_info = None
124+
if shard_count:
125+
sharding_info = ShardingInfo(shard_id=1, shards_count=shard_count, partitioner="",
126+
sharding_algorithm="", sharding_ignore_msb=0, shard_aware_port=19042,
127+
shard_aware_port_ssl=19045)
128+
connection.features = ProtocolFeatures(
129+
shard_id=kwargs.get('shard_id', self.connection_counter),
130+
sharding_info=sharding_info)
95131
self.connection_counter += 1
96132

97133
return connection
98134

99135
host = MagicMock()
136+
host.host_id = uuid.uuid4()
100137
host.endpoint = DefaultEndPoint("1.2.3.4")
138+
session = None
139+
try:
140+
for port, is_ssl in [(19042, False), (19045, True)]:
141+
session = MockSession(is_ssl=is_ssl)
142+
pool = HostConnection(host=host, host_distance=HostDistance.REMOTE, session=session)
143+
for f in session.futures:
144+
f.result()
145+
assert len(pool._connections) == expected_connections
146+
for shard_id, connection in pool._connections.items():
147+
assert connection.features.shard_id == shard_id
148+
if shard_id == 0:
149+
assert connection.endpoint == DefaultEndPoint("1.2.3.4")
150+
else:
151+
assert connection.endpoint == DefaultEndPoint("1.2.3.4", port=port)
101152

102-
for port, is_ssl in [(19042, False), (19045, True)]:
103-
session = MockSession(is_ssl=is_ssl)
104-
pool = HostConnection(host=host, host_distance=HostDistance.REMOTE, session=session)
105-
for f in session.futures:
106-
f.result()
107-
assert len(pool._connections) == 4
108-
for shard_id, connection in pool._connections.items():
109-
assert connection.features.shard_id == shard_id
110-
if shard_id == 0:
111-
assert connection.endpoint == DefaultEndPoint("1.2.3.4")
112-
else:
113-
assert connection.endpoint == DefaultEndPoint("1.2.3.4", port=port)
114-
115-
session.cluster.executor.shutdown(wait=True)
153+
sleep_time = 0
154+
found_related_calls = 0
155+
for delay, fn, args, kwargs in session.scheduler_calls:
156+
if fn.__self__.__class__ is _NoDelayShardConnectionBackoffScheduler:
157+
found_related_calls += 1
158+
self.assertEqual(delay, sleep_time)
159+
self.assertLessEqual(len(session.hosts) * (shard_count - 1), found_related_calls)
160+
finally:
161+
if session:
162+
session.cluster.scheduler.shutdown()
163+
session.cluster.executor.shutdown(wait=True)

0 commit comments

Comments
 (0)