Skip to content

Commit d1f19fb

Browse files
committed
Introduce NoDelayShardConnectionBackoffPolicy
This policy is implementation of ShardConnectionBackoffPolicy. It implements same behavior that driver currently has: 1. No delay between creating shard connections 2. It avoids creating multiple connections to same host_id, shard_id
1 parent 3d8a7c1 commit d1f19fb

File tree

2 files changed

+144
-2
lines changed

2 files changed

+144
-2
lines changed

cassandra/policies.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -936,6 +936,69 @@ def new_connection_scheduler(self, scheduler: _Scheduler) -> ShardConnectionSche
936936
raise NotImplementedError()
937937

938938

939+
class NoDelayShardConnectionBackoffPolicy(ShardConnectionBackoffPolicy):
940+
"""
941+
A shard connection backoff policy with no delay between attempts.
942+
Ensures that at most one pending request connection per (host, shard) pair.
943+
If connection attempts for the same (host, shard) it is silently dropped.
944+
"""
945+
946+
def new_connection_scheduler(self, scheduler: _Scheduler) -> ShardConnectionScheduler:
947+
return _NoDelayShardConnectionBackoffScheduler(scheduler)
948+
949+
950+
class _NoDelayShardConnectionBackoffScheduler(ShardConnectionScheduler):
951+
"""
952+
A scheduler for ``cassandra.policies.NoDelayShardConnectionBackoffPolicy``.
953+
954+
A shard connection backoff policy with no delay between attempts.
955+
Ensures that at most one pending request connection per (host, shard) pair.
956+
If connection attempts for the same (host, shard) it is silently dropped.
957+
"""
958+
959+
scheduler: _Scheduler
960+
already_scheduled: set[tuple[str, int]]
961+
lock: Lock
962+
is_shutdown: bool = False
963+
964+
def __init__(self, scheduler: _Scheduler):
965+
self.scheduler = scheduler
966+
self.already_scheduled = set()
967+
self.lock = Lock()
968+
969+
def _execute(
970+
self,
971+
host_id: str,
972+
shard_id: int,
973+
method: Callable[[], None],
974+
) -> None:
975+
if self.is_shutdown:
976+
return
977+
try:
978+
method()
979+
finally:
980+
with self.lock:
981+
self.already_scheduled.remove((host_id, shard_id))
982+
983+
def schedule(
984+
self,
985+
host_id: str,
986+
shard_id: int,
987+
method: Callable[[], None],
988+
) -> bool:
989+
with self.lock:
990+
if self.is_shutdown or (host_id, shard_id) in self.already_scheduled:
991+
return False
992+
self.already_scheduled.add((host_id, shard_id))
993+
994+
self.scheduler.schedule(0, self._execute, host_id, shard_id, method)
995+
return True
996+
997+
def shutdown(self):
998+
with self.lock:
999+
self.is_shutdown = True
1000+
1001+
9391002
class RetryPolicy(object):
9401003
"""
9411004
A policy that describes whether to retry, rethrow, or ignore coordinator

tests/unit/test_policies.py

Lines changed: 81 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414

1515
import unittest
16+
from functools import partial
1617

1718
from itertools import islice, cycle
1819
from unittest.mock import Mock, patch, call
@@ -26,13 +27,15 @@
2627
from cassandra import ConsistencyLevel
2728
from cassandra.cluster import Cluster, ControlConnection
2829
from cassandra.metadata import Metadata
29-
from cassandra.policies import (RackAwareRoundRobinPolicy, RoundRobinPolicy, WhiteListRoundRobinPolicy, DCAwareRoundRobinPolicy,
30+
from cassandra.policies import (RackAwareRoundRobinPolicy, RoundRobinPolicy, WhiteListRoundRobinPolicy,
31+
DCAwareRoundRobinPolicy,
3032
TokenAwarePolicy, SimpleConvictionPolicy,
3133
HostDistance, ExponentialReconnectionPolicy,
3234
RetryPolicy, WriteType,
3335
DowngradingConsistencyRetryPolicy, ConstantReconnectionPolicy,
3436
LoadBalancingPolicy, ConvictionPolicy, ReconnectionPolicy, FallthroughRetryPolicy,
35-
IdentityTranslator, EC2MultiRegionTranslator, HostFilterPolicy, ExponentialBackoffRetryPolicy)
37+
IdentityTranslator, EC2MultiRegionTranslator, HostFilterPolicy,
38+
ExponentialBackoffRetryPolicy, _NoDelayShardConnectionBackoffScheduler)
3639
from cassandra.connection import DefaultEndPoint, UnixSocketEndPoint
3740
from cassandra.pool import Host
3841
from cassandra.query import Statement
@@ -1579,3 +1582,79 @@ def test_create_whitelist(self):
15791582
# Only the filtered replicas should be allowed
15801583
self.assertEqual(set(query_plan), {Host(DefaultEndPoint("127.0.0.1"), SimpleConvictionPolicy),
15811584
Host(DefaultEndPoint("127.0.0.4"), SimpleConvictionPolicy)})
1585+
1586+
1587+
class MockScheduler:
1588+
def __init__(self):
1589+
self.requests = []
1590+
1591+
def schedule(self, delay, fn, *args, **kwargs):
1592+
self.requests.append((delay, fn, args, kwargs))
1593+
1594+
def execute(self):
1595+
for delay, fn, args, kwargs in self.requests:
1596+
fn(*args, **kwargs)
1597+
self.requests = []
1598+
1599+
1600+
class NoDelayShardConnectionBackoffSchedulerTests(unittest.TestCase):
1601+
def test_schedule_executes_method_immediately(self):
1602+
method = Mock()
1603+
scheduler = MockScheduler()
1604+
policy = _NoDelayShardConnectionBackoffScheduler(scheduler)
1605+
1606+
self.assertTrue(policy.schedule('host1', 0, partial(method, 1, 2, key='val')))
1607+
1608+
self.assertEqual(scheduler.requests[0][0], 0)
1609+
scheduler.execute()
1610+
1611+
method.assert_called_once_with(1, 2, key='val')
1612+
1613+
def test_schedule_skips_if_host_shard_already_scheduled(self):
1614+
method = Mock()
1615+
scheduler = MockScheduler()
1616+
policy = _NoDelayShardConnectionBackoffScheduler(scheduler)
1617+
1618+
self.assertTrue(policy.schedule('host1', 0, method))
1619+
self.assertFalse(policy.schedule('host1', 0, method))
1620+
1621+
self.assertEqual(len(scheduler.requests), 1)
1622+
scheduler.execute()
1623+
method.assert_called_once()
1624+
1625+
def test_schedule_does_not_skip_if_shard_is_different(self):
1626+
method = Mock()
1627+
scheduler = MockScheduler()
1628+
policy = _NoDelayShardConnectionBackoffScheduler(scheduler)
1629+
1630+
self.assertTrue(policy.schedule('host1', 0, method))
1631+
self.assertTrue(policy.schedule('host1', 1, method))
1632+
1633+
self.assertEqual(len(scheduler.requests), 2)
1634+
scheduler.execute()
1635+
1636+
self.assertEqual(method.call_count, 2)
1637+
1638+
def test_already_scheduled_resets_after_execution(self):
1639+
method = Mock()
1640+
scheduler = MockScheduler()
1641+
policy = _NoDelayShardConnectionBackoffScheduler(scheduler)
1642+
self.assertTrue(policy.schedule('host1', 0, method))
1643+
1644+
scheduler.execute()
1645+
1646+
self.assertTrue(policy.schedule('host1', 0, method))
1647+
1648+
scheduler.execute()
1649+
1650+
self.assertEqual(method.call_count, 2)
1651+
1652+
def test_schedule_skips_if_shutdown(self):
1653+
method = Mock()
1654+
scheduler = MockScheduler()
1655+
policy = _NoDelayShardConnectionBackoffScheduler(scheduler)
1656+
policy.shutdown()
1657+
1658+
policy.schedule('host1', 0, method)
1659+
1660+
self.assertEqual(len(scheduler.requests), 0)

0 commit comments

Comments
 (0)