|
13 | 13 | # limitations under the License. |
14 | 14 |
|
15 | 15 | import unittest |
| 16 | +from functools import partial |
16 | 17 |
|
17 | 18 | from itertools import islice, cycle |
18 | 19 | from unittest.mock import Mock, patch, call |
|
26 | 27 | from cassandra import ConsistencyLevel |
27 | 28 | from cassandra.cluster import Cluster, ControlConnection |
28 | 29 | from cassandra.metadata import Metadata |
29 | | -from cassandra.policies import (RackAwareRoundRobinPolicy, RoundRobinPolicy, WhiteListRoundRobinPolicy, DCAwareRoundRobinPolicy, |
| 30 | +from cassandra.policies import (RackAwareRoundRobinPolicy, RoundRobinPolicy, WhiteListRoundRobinPolicy, |
| 31 | + DCAwareRoundRobinPolicy, |
30 | 32 | TokenAwarePolicy, SimpleConvictionPolicy, |
31 | 33 | HostDistance, ExponentialReconnectionPolicy, |
32 | 34 | RetryPolicy, WriteType, |
33 | 35 | DowngradingConsistencyRetryPolicy, ConstantReconnectionPolicy, |
34 | 36 | LoadBalancingPolicy, ConvictionPolicy, ReconnectionPolicy, FallthroughRetryPolicy, |
35 | | - IdentityTranslator, EC2MultiRegionTranslator, HostFilterPolicy, ExponentialBackoffRetryPolicy) |
| 37 | + IdentityTranslator, EC2MultiRegionTranslator, HostFilterPolicy, |
| 38 | + ExponentialBackoffRetryPolicy, _NoDelayShardConnectionBackoffScheduler) |
36 | 39 | from cassandra.connection import DefaultEndPoint, UnixSocketEndPoint |
37 | 40 | from cassandra.pool import Host |
38 | 41 | from cassandra.query import Statement |
@@ -1579,3 +1582,79 @@ def test_create_whitelist(self): |
1579 | 1582 | # Only the filtered replicas should be allowed |
1580 | 1583 | self.assertEqual(set(query_plan), {Host(DefaultEndPoint("127.0.0.1"), SimpleConvictionPolicy), |
1581 | 1584 | Host(DefaultEndPoint("127.0.0.4"), SimpleConvictionPolicy)}) |
| 1585 | + |
| 1586 | + |
| 1587 | +class MockScheduler: |
| 1588 | + def __init__(self): |
| 1589 | + self.requests = [] |
| 1590 | + |
| 1591 | + def schedule(self, delay, fn, *args, **kwargs): |
| 1592 | + self.requests.append((delay, fn, args, kwargs)) |
| 1593 | + |
| 1594 | + def execute(self): |
| 1595 | + for delay, fn, args, kwargs in self.requests: |
| 1596 | + fn(*args, **kwargs) |
| 1597 | + self.requests = [] |
| 1598 | + |
| 1599 | + |
| 1600 | +class NoDelayShardConnectionBackoffSchedulerTests(unittest.TestCase): |
| 1601 | + def test_schedule_executes_method_immediately(self): |
| 1602 | + method = Mock() |
| 1603 | + scheduler = MockScheduler() |
| 1604 | + policy = _NoDelayShardConnectionBackoffScheduler(scheduler) |
| 1605 | + |
| 1606 | + self.assertTrue(policy.schedule('host1', 0, partial(method, 1, 2, key='val'))) |
| 1607 | + |
| 1608 | + self.assertEqual(scheduler.requests[0][0], 0) |
| 1609 | + scheduler.execute() |
| 1610 | + |
| 1611 | + method.assert_called_once_with(1, 2, key='val') |
| 1612 | + |
| 1613 | + def test_schedule_skips_if_host_shard_already_scheduled(self): |
| 1614 | + method = Mock() |
| 1615 | + scheduler = MockScheduler() |
| 1616 | + policy = _NoDelayShardConnectionBackoffScheduler(scheduler) |
| 1617 | + |
| 1618 | + self.assertTrue(policy.schedule('host1', 0, method)) |
| 1619 | + self.assertFalse(policy.schedule('host1', 0, method)) |
| 1620 | + |
| 1621 | + self.assertEqual(len(scheduler.requests), 1) |
| 1622 | + scheduler.execute() |
| 1623 | + method.assert_called_once() |
| 1624 | + |
| 1625 | + def test_schedule_does_not_skip_if_shard_is_different(self): |
| 1626 | + method = Mock() |
| 1627 | + scheduler = MockScheduler() |
| 1628 | + policy = _NoDelayShardConnectionBackoffScheduler(scheduler) |
| 1629 | + |
| 1630 | + self.assertTrue(policy.schedule('host1', 0, method)) |
| 1631 | + self.assertTrue(policy.schedule('host1', 1, method)) |
| 1632 | + |
| 1633 | + self.assertEqual(len(scheduler.requests), 2) |
| 1634 | + scheduler.execute() |
| 1635 | + |
| 1636 | + self.assertEqual(method.call_count, 2) |
| 1637 | + |
| 1638 | + def test_already_scheduled_resets_after_execution(self): |
| 1639 | + method = Mock() |
| 1640 | + scheduler = MockScheduler() |
| 1641 | + policy = _NoDelayShardConnectionBackoffScheduler(scheduler) |
| 1642 | + self.assertTrue(policy.schedule('host1', 0, method)) |
| 1643 | + |
| 1644 | + scheduler.execute() |
| 1645 | + |
| 1646 | + self.assertTrue(policy.schedule('host1', 0, method)) |
| 1647 | + |
| 1648 | + scheduler.execute() |
| 1649 | + |
| 1650 | + self.assertEqual(method.call_count, 2) |
| 1651 | + |
| 1652 | + def test_schedule_skips_if_shutdown(self): |
| 1653 | + method = Mock() |
| 1654 | + scheduler = MockScheduler() |
| 1655 | + policy = _NoDelayShardConnectionBackoffScheduler(scheduler) |
| 1656 | + policy.shutdown() |
| 1657 | + |
| 1658 | + policy.schedule('host1', 0, method) |
| 1659 | + |
| 1660 | + self.assertEqual(len(scheduler.requests), 0) |
0 commit comments