Skip to content

Commit 51f7986

Browse files
committed
1
1 parent d5834c6 commit 51f7986

File tree

6 files changed

+258
-43
lines changed

6 files changed

+258
-43
lines changed

cassandra/cluster.py

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,8 @@
7272
ExponentialReconnectionPolicy, HostDistance,
7373
RetryPolicy, IdentityTranslator, NoSpeculativeExecutionPlan,
7474
NoSpeculativeExecutionPolicy, DefaultLoadBalancingPolicy,
75-
NeverRetryPolicy)
75+
NeverRetryPolicy, ConstantReconnectionPolicy,
76+
ShardReconnectionPolicyScope, ShardReconnectionPolicy, NoConcurrentShardReconnectionPolicy)
7677
from cassandra.pool import (Host, _ReconnectionHandler, _HostReconnectionHandler,
7778
HostConnectionPool, HostConnection,
7879
NoConnectionsAvailable)
@@ -742,6 +743,19 @@ def auth_provider(self, value):
742743

743744
self._auth_provider = value
744745

746+
_shard_reconnection_policy = None
747+
@property
def shard_reconnection_policy(self):
    """
    The shard reconnection policy stored on this cluster.

    Returns the current value of ``_shard_reconnection_policy``.
    """
    return self._shard_reconnection_policy

@shard_reconnection_policy.setter
def shard_reconnection_policy(self, srp):
    """
    Store ``srp`` as the shard reconnection policy.

    Raises :exc:`ValueError` when the cluster is in profile configuration
    mode; otherwise the cluster is switched to legacy configuration mode.
    """
    profiles_in_use = self._config_mode == _ConfigMode.PROFILES
    if profiles_in_use:
        raise ValueError(
            "Cannot set Cluster.shard_reconnection_policy while using Configuration Profiles. Set this in a profile instead.")
    # Assigning through the public property commits to legacy config mode.
    self._config_mode = _ConfigMode.LEGACY
    self._shard_reconnection_policy = srp
758+
745759
_load_balancing_policy = None
746760
@property
747761
def load_balancing_policy(self):
@@ -1204,6 +1218,7 @@ def __init__(self,
12041218
shard_aware_options=None,
12051219
metadata_request_timeout=None,
12061220
column_encryption_policy=None,
1221+
shard_reconnection_policy=None,
12071222
):
12081223
"""
12091224
``executor_threads`` defines the number of threads in a pool for handling asynchronous tasks such as
@@ -1309,6 +1324,17 @@ def __init__(self,
13091324
else:
13101325
self._load_balancing_policy = default_lbp_factory() # set internal attribute to avoid committing to legacy config mode
13111326

1327+
if shard_reconnection_policy is not None:
    # A common mistake is passing the policy class instead of an instance;
    # report it with the name of the argument that is actually wrong.
    if isinstance(shard_reconnection_policy, type):
        raise TypeError("shard_reconnection_policy should not be a class, it should be an instance of that class")
    # The check is against ShardReconnectionPolicy, so the message names it.
    if not isinstance(shard_reconnection_policy, ShardReconnectionPolicy):
        raise TypeError("shard_reconnection_policy should be an instance of class derived from ShardReconnectionPolicy")
    # Goes through the property setter, which enforces the config-mode rules.
    self.shard_reconnection_policy = shard_reconnection_policy
else:
    # set internal attribute to avoid committing to legacy config mode
    # NOTE(review): the meaning of the (2, 0) arguments is not visible here -
    # confirm against ConstantReconnectionPolicy's signature.
    self._shard_reconnection_policy = NoConcurrentShardReconnectionPolicy(
        ShardReconnectionPolicyScope.Host,
        ConstantReconnectionPolicy(2, 0))
1337+
13121338
if reconnection_policy is not None:
13131339
if isinstance(reconnection_policy, type):
13141340
raise TypeError("reconnection_policy should not be a class, it should be an instance of that class")
@@ -2707,6 +2733,7 @@ def __init__(self, cluster, hosts, keyspace=None):
27072733
self._protocol_version = self.cluster.protocol_version
27082734

27092735
self.encoder = Encoder()
2736+
self.shard_reconnection_scheduler = cluster.shard_reconnection_policy.new_scheduler(self)
27102737

27112738
# create connection pools in parallel
27122739
self._initial_connect_futures = set()
@@ -4432,6 +4459,9 @@ def shutdown(self):
44324459
self._queue.put_nowait((0, 0, None))
44334460
self.join()
44344461

4462+
def empty(self):
    """
    Report whether nothing is left to run.

    Returns ``False`` as soon as ``_scheduled_tasks`` holds anything;
    otherwise returns whatever the internal queue's ``empty()`` reports.
    """
    if len(self._scheduled_tasks) != 0:
        return False
    return self._queue.empty()
4464+
44354465
def schedule(self, delay, fn, *args, **kwargs):
44364466
self._insert_task(delay, (fn, args, tuple(kwargs.items())))
44374467

cassandra/policies.py

Lines changed: 152 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,12 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
import random
15+
import threading
16+
import time
17+
import weakref
1518

1619
from collections import namedtuple
20+
from enum import Enum
1721
from functools import lru_cache
1822
from itertools import islice, cycle, groupby, repeat
1923
import logging
@@ -778,6 +782,14 @@ def new_schedule(self):
778782
raise NotImplementedError()
779783

780784

785+
class NoDelayReconnectionPolicy(ReconnectionPolicy):
    """
    A :class:`.ReconnectionPolicy` subclass whose schedule is an endless
    stream of zero delays, so there is never a wait between attempts.
    """

    def new_schedule(self):
        # Unbounded iterator that yields 0 every time it is advanced.
        zero_delays = repeat(0)
        return zero_delays
791+
792+
781793
class ConstantReconnectionPolicy(ReconnectionPolicy):
782794
"""
783795
A :class:`.ReconnectionPolicy` subclass which sleeps for a fixed delay
@@ -864,6 +876,146 @@ def _add_jitter(self, value):
864876
return min(max(self.base_delay, delay), self.max_delay)
865877

866878

879+
class ShardReconnectionScheduler(object):
    """
    Interface for objects that decide when a per-shard callable gets run.

    Subclasses must override :meth:`schedule`.
    """

    def schedule(self, host_id, shard_id, method, *args, **kwargs):
        """
        Request that ``method(*args, **kwargs)`` be run on behalf of the
        shard identified by ``host_id`` and ``shard_id``.

        The base implementation always raises :exc:`NotImplementedError`.
        """
        raise NotImplementedError()
882+
883+
class ShardReconnectionPolicy(object):
    """
    Interface for factories of :class:`ShardReconnectionScheduler` objects.

    Subclasses must override :meth:`new_scheduler`.
    """

    def new_scheduler(self, session) -> ShardReconnectionScheduler:
        """
        Build the scheduler that ``session`` will use.

        The base implementation always raises :exc:`NotImplementedError`.
        """
        raise NotImplementedError()
886+
887+
888+
class NoDelayShardReconnectionPolicy(ShardReconnectionPolicy):
    """
    A :class:`ShardReconnectionPolicy` that hands out
    :class:`NoDelayShardReconnectionScheduler` instances.
    """

    def new_scheduler(self, session) -> ShardReconnectionScheduler:
        # One scheduler per call, bound to the session it was requested for.
        scheduler = NoDelayShardReconnectionScheduler(session)
        return scheduler
891+
892+
893+
class NoDelayShardReconnectionScheduler(ShardReconnectionScheduler):
    """
    A :class:`ShardReconnectionScheduler` that submits work to the session
    immediately, while ensuring that at most one request per
    ``host_id``/``shard_id`` pair is outstanding at a time.
    """

    def __init__(self, session):
        # Weak proxy so the scheduler does not keep the session alive.
        self.session = weakref.proxy(session)
        # Maps "<host_id>-<shard_id>" to True while a request is outstanding.
        self.already_scheduled = {}
        # Guards already_scheduled; schedule() and _execute() may run on
        # different threads.
        self._lock = threading.Lock()

    def _execute(self, scheduled_key, method, *args, **kwargs):
        """
        Run ``method(*args, **kwargs)`` and then release ``scheduled_key`` so
        the same shard can be scheduled again, even if ``method`` raised.
        """
        try:
            method(*args, **kwargs)
        finally:
            with self._lock:
                self.already_scheduled[scheduled_key] = False

    def schedule(self, host_id, shard_id, method, *args, **kwargs):
        """
        Submit ``method(*args, **kwargs)`` to the session unless a request
        for the same ``host_id``/``shard_id`` is already outstanding.

        Nothing is submitted, and nothing is marked as outstanding, when the
        session is shut down.
        """
        # Check shutdown before marking the key: otherwise the key would be
        # flagged as scheduled with no task left to ever clear it.
        if self.session.is_shutdown:
            return

        scheduled_key = f'{host_id}-{shard_id}'
        with self._lock:
            if self.already_scheduled.get(scheduled_key):
                return
            self.already_scheduled[scheduled_key] = True

        # Submit outside the lock so an executor that runs the task inline
        # can re-acquire it in _execute().
        try:
            self.session.submit(self._execute, scheduled_key, method, *args, **kwargs)
        except Exception:
            # The task never started, so release the key before propagating.
            with self._lock:
                self.already_scheduled[scheduled_key] = False
            raise
912+
913+
914+
class ShardReconnectionPolicyScope(Enum):
    """
    Selects which scope a shard reconnection policy is applied to.
    """

    # One scope shared by every host of the cluster.
    Cluster = 0
    # A separate scope for each host.
    Host = 1
917+
918+
919+
class NoConcurrentShardReconnectionPolicy(ShardReconnectionPolicy):
    """
    A :class:`ShardReconnectionPolicy` that hands out
    :class:`NoConcurrentShardReconnectionScheduler` instances configured with
    a scope and a :class:`.ReconnectionPolicy`.
    """

    def __init__(self, shard_reconnection_scope, reconnection_policy):
        """
        Validate and store the scope and reconnection policy.

        Raises :exc:`ValueError` if ``shard_reconnection_scope`` is not a
        :class:`ShardReconnectionPolicyScope` or ``reconnection_policy`` is
        not a :class:`.ReconnectionPolicy`.
        """
        scope_is_valid = isinstance(shard_reconnection_scope, ShardReconnectionPolicyScope)
        if not scope_is_valid:
            raise ValueError("shard_reconnection_scope must be a ShardReconnectionPolicyScope")

        policy_is_valid = isinstance(reconnection_policy, ReconnectionPolicy)
        if not policy_is_valid:
            raise ValueError("reconnection_policy must be a ReconnectionPolicy")

        self.reconnection_policy = reconnection_policy
        self.shard_reconnection_scope = shard_reconnection_scope

    def new_scheduler(self, session) -> ShardReconnectionScheduler:
        # Each scheduler receives the same scope and policy objects.
        scheduler = NoConcurrentShardReconnectionScheduler(
            session,
            self.shard_reconnection_scope,
            self.reconnection_policy,
        )
        return scheduler
930+
931+
932+
class _ScopeBucket(object):
933+
def __init__(self, session, shard_reconnection_policy):
934+
self._items = []
935+
self.last_run = None
936+
self.session = session
937+
self.policy = shard_reconnection_policy
938+
self.lock = threading.Lock()
939+
self.running = False
940+
self.schedule = self.policy.new_schedule()
941+
942+
def add(self, method, *args, **kwargs):
943+
with self.lock:
944+
self._items.append([method, args, kwargs])
945+
if not self.running:
946+
self.running = True
947+
self._schedule()
948+
949+
def _get_delay(self):
950+
try:
951+
return next(self.schedule)
952+
except StopIteration:
953+
self.schedule = self.policy.new_schedule()
954+
return next(self.schedule)
955+
956+
def _schedule(self):
957+
if self.session.is_shutdown:
958+
return
959+
delay = self._get_delay()
960+
if delay:
961+
self.session.cluster.scheduler.schedule(delay, self.run)
962+
else:
963+
self.session.submit(self.run)
964+
965+
def run(self):
966+
if self.session.is_shutdown:
967+
return
968+
969+
with self.lock:
970+
try:
971+
item = self._items.pop()
972+
except IndexError:
973+
self.running = False
974+
return
975+
976+
method, args, kwargs = item
977+
try:
978+
method(*args, **kwargs)
979+
finally:
980+
self._schedule()
981+
982+
983+
class NoConcurrentShardReconnectionScheduler(ShardReconnectionScheduler):
    """
    A :class:`ShardReconnectionScheduler` that funnels requests into one
    :class:`_ScopeBucket` per scope, and keeps at most one outstanding
    request per ``host_id``/``shard_id`` pair.
    """

    def __init__(self, session, shard_reconnection_scope, reconnection_policy):
        # Maps "<host_id>-<shard_id>" to True while a request is outstanding.
        self.already_scheduled = {}
        # Maps a scope key to its _ScopeBucket.
        self.scopes = {}
        self.shard_reconnection_scope = shard_reconnection_scope
        self.reconnection_policy = reconnection_policy
        self.session = session
        # Guards already_scheduled and scopes.
        self.lock = threading.Lock()

    def _execute(self, scheduled_key, method, *args, **kwargs):
        """
        Run ``method(*args, **kwargs)`` and then release ``scheduled_key`` so
        the same shard can be scheduled again, even if ``method`` raised.
        """
        try:
            method(*args, **kwargs)
        finally:
            with self.lock:
                self.already_scheduled[scheduled_key] = False

    def schedule(self, host_id, shard_id, method, *args, **kwargs):
        """
        Queue ``method(*args, **kwargs)`` in the bucket of the matching scope.

        Returns ``False`` when a request for the same ``host_id``/``shard_id``
        is already outstanding, ``True`` once the request has been queued.
        """
        if self.shard_reconnection_scope == ShardReconnectionPolicyScope.Cluster:
            # Cluster scope: every host shares a single bucket.
            scope_hash = "global-cluster-scope"
        else:
            scope_hash = host_id
        scheduled_key = f'{host_id}-{shard_id}'

        with self.lock:
            if self.already_scheduled.get(scheduled_key):
                return False
            self.already_scheduled[scheduled_key] = True

            scope_info = self.scopes.get(scope_hash)
            if scope_info is None:
                scope_info = _ScopeBucket(self.session, self.reconnection_policy)
                self.scopes[scope_hash] = scope_info

        # Hand the work over outside self.lock: _execute() takes the same
        # non-reentrant lock, so running it while the lock is still held
        # here (e.g. by an inline executor) would deadlock.
        scope_info.add(self._execute, scheduled_key, method, *args, **kwargs)
        return True
1017+
1018+
8671019
class RetryPolicy(object):
8681020
"""
8691021
A policy that describes whether to retry, rethrow, or ignore coordinator

cassandra/pool.py

Lines changed: 12 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -402,7 +402,6 @@ def __init__(self, host, host_distance, session):
402402
# this is used in conjunction with the connection streams. Not using the connection lock because the connection can be replaced in the lifetime of the pool.
403403
self._stream_available_condition = Condition(Lock())
404404
self._is_replacing = False
405-
self._connecting = set()
406405
self._connections = {}
407406
self._pending_connections = []
408407
# A pool of additional connections which are not used but affect how Scylla
@@ -418,7 +417,6 @@ def __init__(self, host, host_distance, session):
418417
# and are waiting until all requests time out or complete
419418
# so that we can dispose of them.
420419
self._trash = set()
421-
self._shard_connections_futures = []
422420
self.advanced_shardaware_block_until = 0
423421

424422
if host_distance == HostDistance.IGNORED:
@@ -483,25 +481,25 @@ def _get_connection_for_routing_key(self, routing_key=None, keyspace=None, table
483481
self.host,
484482
routing_key
485483
)
486-
if conn.orphaned_threshold_reached and shard_id not in self._connecting:
484+
if conn.orphaned_threshold_reached:
487485
# The connection has met its orphaned stream ID limit
488486
# and needs to be replaced. Start opening a connection
489487
# to the same shard and replace when it is opened.
490-
self._connecting.add(shard_id)
491-
self._session.submit(self._open_connection_to_missing_shard, shard_id)
488+
self._session.shard_reconnection_scheduler.schedule(
489+
self.host.host_id, shard_id, self._open_connection_to_missing_shard, shard_id)
492490
log.debug(
493-
"Connection to shard_id=%i reached orphaned stream limit, replacing on host %s (%s/%i)",
491+
"Scheduling Connection to shard_id=%i reached orphaned stream limit, replacing on host %s (%s/%i)",
494492
shard_id,
495493
self.host,
496494
len(self._connections.keys()),
497495
self.host.sharding_info.shards_count
498496
)
499-
elif shard_id not in self._connecting:
497+
else:
500498
# rate controlled optimistic attempt to connect to a missing shard
501-
self._connecting.add(shard_id)
502-
self._session.submit(self._open_connection_to_missing_shard, shard_id)
499+
self._session.shard_reconnection_scheduler.schedule(
500+
self.host.host_id, shard_id, self._open_connection_to_missing_shard, shard_id)
503501
log.debug(
504-
"Trying to connect to missing shard_id=%i on host %s (%s/%i)",
502+
"Scheduling connection to missing shard_id=%i on host %s (%s/%i)",
505503
shard_id,
506504
self.host,
507505
len(self._connections.keys()),
@@ -609,8 +607,8 @@ def _replace(self, connection):
609607
if connection.features.shard_id in self._connections.keys():
610608
del self._connections[connection.features.shard_id]
611609
if self.host.sharding_info and not self._session.cluster.shard_aware_options.disable:
612-
self._connecting.add(connection.features.shard_id)
613-
self._session.submit(self._open_connection_to_missing_shard, connection.features.shard_id)
610+
self._session.shard_reconnection_scheduler.schedule(
611+
self.host.host_id, connection.features.shard_id, self._open_connection_to_missing_shard, connection.features.shard_id)
614612
else:
615613
connection = self._session.cluster.connection_factory(self.host.endpoint,
616614
on_orphaned_stream_released=self.on_orphaned_stream_released)
@@ -635,9 +633,6 @@ def shutdown(self):
635633
with self._stream_available_condition:
636634
self._stream_available_condition.notify_all()
637635

638-
for future in self._shard_connections_futures:
639-
future.cancel()
640-
641636
connections_to_close = self._connections.copy()
642637
pending_connections_to_close = self._pending_connections.copy()
643638
self._connections.clear()
@@ -843,7 +838,6 @@ def _open_connection_to_missing_shard(self, shard_id):
843838
self._excess_connections.add(conn)
844839
if close_connection:
845840
conn.close()
846-
self._connecting.discard(shard_id)
847841

848842
def _open_connections_for_all_shards(self, skip_shard_id=None):
849843
"""
@@ -856,10 +850,8 @@ def _open_connections_for_all_shards(self, skip_shard_id=None):
856850
for shard_id in range(self.host.sharding_info.shards_count):
857851
if skip_shard_id is not None and skip_shard_id == shard_id:
858852
continue
859-
future = self._session.submit(self._open_connection_to_missing_shard, shard_id)
860-
if isinstance(future, Future):
861-
self._connecting.add(shard_id)
862-
self._shard_connections_futures.append(future)
853+
self._session.shard_reconnection_scheduler.schedule(
854+
self.host.host_id, shard_id, self._open_connection_to_missing_shard, shard_id)
863855

864856
trash_conns = None
865857
with self._lock:

tests/integration/standard/test_shard_aware.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,8 @@
2929
import pytest
3030

3131
from cassandra.cluster import Cluster
32-
from cassandra.policies import TokenAwarePolicy, RoundRobinPolicy, ConstantReconnectionPolicy
32+
from cassandra.policies import TokenAwarePolicy, RoundRobinPolicy, ConstantReconnectionPolicy, \
33+
NoDelayShardReconnectionPolicy
3334
from cassandra import OperationTimedOut, ConsistencyLevel
3435

3536
from tests.integration import use_cluster, get_node, PROTOCOL_VERSION
@@ -47,6 +48,7 @@ class TestShardAwareIntegration(unittest.TestCase):
4748
def setup_class(cls):
4849
cls.cluster = Cluster(contact_points=["127.0.0.1"], protocol_version=PROTOCOL_VERSION,
4950
load_balancing_policy=TokenAwarePolicy(RoundRobinPolicy()),
51+
shard_reconnection_policy=NoDelayShardReconnectionPolicy(),
5052
reconnection_policy=ConstantReconnectionPolicy(1))
5153
cls.session = cls.cluster.connect()
5254
LOGGER.info(cls.cluster.is_shard_aware())

tests/unit/test_host_connection_pool.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from threading import Thread, Event, Lock
2323
from unittest.mock import Mock, NonCallableMagicMock, MagicMock
2424

25-
from cassandra.cluster import Session, ShardAwareOptions
25+
from cassandra.cluster import Session, ShardAwareOptions, NoDelayShardReconnectionScheduler
2626
from cassandra.connection import Connection
2727
from cassandra.pool import HostConnection, HostConnectionPool
2828
from cassandra.pool import Host, NoConnectionsAvailable
@@ -41,6 +41,8 @@ def make_session(self):
4141
session.cluster.get_core_connections_per_host.return_value = 1
4242
session.cluster.get_max_requests_per_connection.return_value = 1
4343
session.cluster.get_max_connections_per_host.return_value = 1
44+
session.shard_reconnection_scheduler = NoDelayShardReconnectionScheduler(session)
45+
session.is_shutdown = False
4446
return session
4547

4648
def test_borrow_and_return(self):

0 commit comments

Comments
 (0)