Skip to content

Commit 20ccb08

Browse files
committed
feat(policy): add shard reconnection policies
Add abstract classes: `ShardReconnectionPolicy` and `ShardReconnectionScheduler` And implementations: `NoDelayShardReconnectionPolicy` - policy that represents old behavior of having no delay and no concurrency restriction. `NoConcurrentShardReconnectionPolicy` - policy that limits concurrent reconnections to 1 per scope and introduces delay between reconnections within the scope.
1 parent d5834c6 commit 20ccb08

File tree

2 files changed

+461
-2
lines changed

2 files changed

+461
-2
lines changed

cassandra/policies.py

Lines changed: 212 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,20 +12,28 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
import random
15+
import threading
16+
import time
17+
import weakref
18+
from abc import ABC, abstractmethod
1519

1620
from collections import namedtuple
21+
from enum import Enum
1722
from functools import lru_cache
1823
from itertools import islice, cycle, groupby, repeat
1924
import logging
2025
from random import randint, shuffle
2126
from threading import Lock
2227
import socket
2328
import warnings
29+
from typing import TYPE_CHECKING, Callable, Any, List
2430

2531
log = logging.getLogger(__name__)
2632

2733
from cassandra import WriteType as WT
2834

35+
if TYPE_CHECKING:
36+
from cluster import Session
2937

3038
# This is done this way because WriteType was originally
3139
# defined here and in order not to break the API.
@@ -864,6 +872,210 @@ def _add_jitter(self, value):
864872
return min(max(self.base_delay, delay), self.max_delay)
865873

866874

875+
class ShardReconnectionScheduler(ABC):
876+
@abstractmethod
877+
def schedule(
878+
self,
879+
host_id: str,
880+
shard_id: int,
881+
method: Callable[..., Any],
882+
*args: List[...],
883+
**kwargs: dict[Any, Any]) -> None:
884+
raise NotImplementedError()
885+
886+
887+
class ShardReconnectionPolicy(ABC):
888+
"""
889+
Base class for shard reconnection policies.
890+
891+
On `new_scheduler` instantiate a scheduler that behaves according to the policy
892+
"""
893+
894+
@abstractmethod
895+
def new_scheduler(self, session: Session) -> ShardReconnectionScheduler:
896+
raise NotImplementedError()
897+
898+
899+
class NoDelayShardReconnectionPolicy(ShardReconnectionPolicy):
900+
"""
901+
A shard reconnection policy with no delay between attempts and no concurrency restrictions.
902+
Ensures at most one pending reconnection per (host, shard) pair — any additional
903+
reconnection attempts for the same (host, shard) are silently ignored.
904+
905+
On `new_scheduler` instantiate a scheduler that behaves according to the policy
906+
"""
907+
908+
def new_scheduler(self, session: Session) -> ShardReconnectionScheduler:
909+
return _NoDelayShardReconnectionScheduler(session)
910+
911+
912+
class _NoDelayShardReconnectionScheduler(ShardReconnectionScheduler):
913+
def __init__(self, session: Session):
914+
self.session: Session = weakref.proxy(session)
915+
self.already_scheduled: dict[str, bool] = {}
916+
917+
def _execute(
918+
self,
919+
scheduled_key: str,
920+
method: Callable[..., Any],
921+
*args: List[...],
922+
**kwargs: dict[Any, Any]) -> None:
923+
try:
924+
method(*args, **kwargs)
925+
finally:
926+
self.already_scheduled[scheduled_key] = False
927+
928+
def schedule(
929+
self,
930+
host_id: str,
931+
shard_id: int,
932+
method: Callable[..., Any],
933+
*args: List[...],
934+
**kwargs: dict[Any, Any]) -> None:
935+
scheduled_key = f'{host_id}-{shard_id}'
936+
if self.already_scheduled.get(scheduled_key):
937+
return
938+
939+
self.already_scheduled[scheduled_key] = True
940+
if not self.session.is_shutdown:
941+
self.session.submit(self._execute, scheduled_key, method, *args, **kwargs)
942+
943+
944+
class ShardReconnectionPolicyScope(Enum):
945+
"""
946+
A scope for `ShardReconnectionPolicy`, in particular `NoConcurrentShardReconnectionPolicy`
947+
"""
948+
Cluster = 0
949+
Host = 1
950+
951+
952+
class NoConcurrentShardReconnectionPolicy(ShardReconnectionPolicy):
953+
"""
954+
A shard reconnection policy that allows only one pending connection per scope, where scope could be `Host`, `Cluster`
955+
For backoff it uses `ReconnectionPolicy`, when there is no more reconnections to scheduled backoff policy is reminded
956+
For all scopes does not allow schedule multiple reconnections for same host+shard, it silently ignores attempts to do that.
957+
958+
On `new_scheduler` instantiate a scheduler that behaves according to the policy
959+
"""
960+
961+
def __init__(
962+
self,
963+
shard_reconnection_scope: ShardReconnectionPolicyScope,
964+
reconnection_policy: ReconnectionPolicy,
965+
):
966+
if not isinstance(shard_reconnection_scope, ShardReconnectionPolicyScope):
967+
raise ValueError("shard_reconnection_scope must be a ShardReconnectionPolicyScope")
968+
if not isinstance(reconnection_policy, ReconnectionPolicy):
969+
raise ValueError("reconnection_policy must be a ReconnectionPolicy")
970+
self.shard_reconnection_scope = shard_reconnection_scope
971+
self.reconnection_policy = reconnection_policy
972+
973+
def new_scheduler(self, session: Session) -> ShardReconnectionScheduler:
974+
return _NoConcurrentShardReconnectionScheduler(session, self.shard_reconnection_scope, self.reconnection_policy)
975+
976+
977+
class _ScopeBucket:
978+
"""
979+
Holds information for a shard reconnection scope, schedules and executes reconnections.
980+
"""
981+
982+
def __init__(
983+
self,
984+
session: Session,
985+
reconnection_policy: ReconnectionPolicy,
986+
):
987+
self._items = []
988+
self.session = session
989+
self.reconnection_policy = reconnection_policy
990+
self.lock = threading.Lock()
991+
self.running = False
992+
self.schedule = self.reconnection_policy.new_schedule()
993+
994+
def _get_delay(self):
995+
if self.schedule is None:
996+
self.schedule = self.reconnection_policy.new_schedule()
997+
try:
998+
return next(self.schedule)
999+
except StopIteration:
1000+
self.schedule = self.reconnection_policy.new_schedule()
1001+
return next(self.schedule)
1002+
1003+
def _schedule(self):
1004+
if self.session.is_shutdown:
1005+
return
1006+
delay = self._get_delay()
1007+
if delay:
1008+
self.session.cluster.scheduler.schedule(delay, self._run)
1009+
else:
1010+
self.session.submit(self._run)
1011+
1012+
def _run(self):
1013+
if self.session.is_shutdown:
1014+
return
1015+
1016+
with self.lock:
1017+
try:
1018+
item = self._items.pop()
1019+
except IndexError:
1020+
self.running = False
1021+
self.schedule = None
1022+
return
1023+
1024+
method, args, kwargs = item
1025+
try:
1026+
method(*args, **kwargs)
1027+
finally:
1028+
self._schedule()
1029+
1030+
def add(self, method, *args, **kwargs):
1031+
with self.lock:
1032+
self._items.append([method, args, kwargs])
1033+
if not self.running:
1034+
self.running = True
1035+
self._schedule()
1036+
1037+
1038+
class _NoConcurrentShardReconnectionScheduler(ShardReconnectionScheduler):
1039+
def __init__(
1040+
self,
1041+
session: Session,
1042+
shard_reconnection_scope: ShardReconnectionPolicyScope,
1043+
reconnection_policy: ReconnectionPolicy,
1044+
):
1045+
self.already_scheduled: dict[str, bool] = {}
1046+
self.scopes: dict[str, _ScopeBucket] = {}
1047+
self.shard_reconnection_scope = shard_reconnection_scope
1048+
self.reconnection_policy = reconnection_policy
1049+
self.session = session
1050+
self.lock = threading.Lock()
1051+
1052+
def _execute(self, scheduled_key, method, *args, **kwargs):
1053+
try:
1054+
method(*args, **kwargs)
1055+
finally:
1056+
with self.lock:
1057+
self.already_scheduled[scheduled_key] = False
1058+
1059+
def schedule(self, host_id, shard_id, method, *args, **kwargs):
1060+
if self.shard_reconnection_scope == ShardReconnectionPolicyScope.Cluster:
1061+
scope_hash = "global-cluster-scope"
1062+
else:
1063+
scope_hash = host_id
1064+
scheduled_key = f'{host_id}-{shard_id}'
1065+
1066+
with self.lock:
1067+
if self.already_scheduled.get(scheduled_key):
1068+
return False
1069+
self.already_scheduled[scheduled_key] = True
1070+
1071+
scope_info = self.scopes.get(scope_hash, 0)
1072+
if not scope_info:
1073+
scope_info = _ScopeBucket(self.session, self.reconnection_policy)
1074+
self.scopes[scope_hash] = scope_info
1075+
scope_info.add(self._execute, scheduled_key, method, *args, **kwargs)
1076+
return True
1077+
1078+
8671079
class RetryPolicy(object):
8681080
"""
8691081
A policy that describes whether to retry, rethrow, or ignore coordinator

0 commit comments

Comments
 (0)