Skip to content

Commit 2529fdf

Browse files
committed
Implement LimitedConcurrencyShardConnectionBackoffPolicy
This policy implements `ShardConnectionBackoffPolicy`. Its primary purpose is to prevent connection storms by limiting the number of concurrent pending connections per host and enforcing a backoff delay between consecutive connection attempts.
1 parent 8be8251 commit 2529fdf

File tree

3 files changed

+421
-8
lines changed

3 files changed

+421
-8
lines changed

cassandra/policies.py

Lines changed: 232 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,15 +14,15 @@
1414
from __future__ import annotations
1515

1616
import random
17-
1817
from collections import namedtuple
18+
from functools import partial
1919
from itertools import islice, cycle, groupby, repeat
2020
import logging
2121
from random import randint, shuffle
2222
from threading import Lock
2323
import socket
2424
import warnings
25-
from typing import Callable, TYPE_CHECKING
25+
from typing import Callable, TYPE_CHECKING, Iterator, List, Tuple
2626
from abc import ABC, abstractmethod
2727
from cassandra import WriteType as WT
2828

@@ -997,6 +997,236 @@ def shutdown(self):
997997
self.is_shutdown = True
998998

999999

1000+
class ShardConnectionBackoffSchedule(ABC):
    @abstractmethod
    def new_schedule(self) -> Iterator[float]:
        """
        Return an iterable of delays, each expressed as a floating point
        number of seconds.

        The iterable may be finite or infinite.  If it is finite, a fresh
        schedule is created as soon as the current one is exhausted.
        """
        raise NotImplementedError()
1011+
class ConstantShardConnectionBackoffSchedule(ShardConnectionBackoffSchedule):
    """
    A :class:`.ShardConnectionBackoffSchedule` subclass which introduces a
    constant delay, optionally widened by a random jitter, between shard
    connections.
    """

    def __init__(self, delay, jitter: float = 0.0):
        """
        `delay` should be a floating point number of seconds to wait in-between
        each connection attempt.

        `jitter` is the maximum extra random delay, in seconds, added to `delay`.
        """
        # Both parameters are validated eagerly so a misconfigured policy
        # fails at construction time rather than on first use.
        if delay < 0:
            raise ValueError("delay must not be negative")
        if jitter < 0:
            raise ValueError("jitter must not be negative")

        self.delay = delay
        self.jitter = jitter

    def new_schedule(self):
        # With no jitter the schedule is simply the same delay forever.
        if not self.jitter:
            return repeat(self.delay)
        # Otherwise each step draws a fresh uniform offset in [0, jitter].
        return (self.delay + random.uniform(0.0, self.jitter) for _ in repeat(None))
1041+
class LimitedConcurrencyShardConnectionBackoffPolicy(ShardConnectionBackoffPolicy):
    """
    A shard connection backoff policy that allows only `max_concurrent` concurrent connections per `host_id`.

    For backoff calculation, it requires either a `cassandra.policies.ShardConnectionBackoffSchedule` or
    a `cassandra.policies.ReconnectionPolicy`, as both expose the same API.

    The backoff schedule is initiated when the first request for a given `host_id` is received.
    If there are no remaining requests for that `host_id`, the schedule is reset.

    This policy also prevents multiple pending or scheduled connections for the same (host, shard) pair;
    any duplicate attempts to schedule a connection are silently ignored.
    """

    # Source of per-attempt delays; ReconnectionPolicy is accepted because it
    # exposes the same new_schedule() API.
    backoff_policy: ShardConnectionBackoffSchedule | ReconnectionPolicy

    # Max concurrent connection creation requests per scope.
    max_concurrent: int

    def __init__(
            self,
            backoff_policy: ShardConnectionBackoffSchedule | ReconnectionPolicy,
            max_concurrent: int = 1,
    ):
        if not isinstance(backoff_policy, (ShardConnectionBackoffSchedule, ReconnectionPolicy)):
            raise ValueError("backoff_policy must be a ShardConnectionBackoffSchedule or ReconnectionPolicy")
        if max_concurrent < 1:
            raise ValueError("max_concurrent must be a positive integer")
        self.backoff_policy = backoff_policy
        self.max_concurrent = max_concurrent

    def new_connection_scheduler(self, scheduler: _Scheduler) -> ShardConnectionScheduler:
        """Create a scheduler that enforces this policy's limits."""
        return _LimitedConcurrencyShardConnectionScheduler(
            scheduler, self.backoff_policy, self.max_concurrent)
1077+
class _ScopeBucket:
1078+
"""
1079+
Holds information for a shard connection backoff policy scope, schedules and executes requests to create connection.
1080+
"""
1081+
session: _Scheduler
1082+
backoff_policy: ShardConnectionBackoffSchedule
1083+
lock: Lock
1084+
is_shutdown: bool = False
1085+
1086+
max_concurrent: int
1087+
"""
1088+
Max concurrent connection creation requests in the scope.
1089+
"""
1090+
1091+
currently_pending: int
1092+
"""
1093+
Number of currently pending connections.
1094+
"""
1095+
1096+
items: List[Callable[[], None]]
1097+
"""
1098+
List of scheduled create connections requests.
1099+
"""
1100+
1101+
def __init__(
1102+
self,
1103+
scheduler: _Scheduler,
1104+
backoff_policy: ShardConnectionBackoffSchedule,
1105+
max_concurrent: int,
1106+
):
1107+
self.items = []
1108+
self.scheduler = scheduler
1109+
self.backoff_policy = backoff_policy
1110+
self.lock = Lock()
1111+
self.max_concurrent = max_concurrent
1112+
self.currently_pending = 0
1113+
1114+
def _get_delay(self, schedule: Iterator[float]) -> Tuple[Iterator[float], float]:
1115+
try:
1116+
return schedule, next(schedule)
1117+
except StopIteration:
1118+
# A bit of trickery to avoid having lock around self.schedule
1119+
schedule = self.backoff_policy.new_schedule()
1120+
delay = next(schedule)
1121+
self.schedule = schedule
1122+
return schedule, delay
1123+
1124+
def _run(self, schedule: Iterator[float]):
1125+
if self.is_shutdown:
1126+
return
1127+
1128+
with self.lock:
1129+
try:
1130+
request = self.items.pop(0)
1131+
except IndexError:
1132+
# Just in case
1133+
if self.currently_pending > 0:
1134+
self.currently_pending -= 1
1135+
# When items are exhausted reset schedule to ensure that new items going to get another schedule
1136+
# It is important for exponential policy
1137+
return
1138+
1139+
try:
1140+
request()
1141+
finally:
1142+
schedule, delay = self._get_delay(schedule)
1143+
self.scheduler.schedule(delay, self._run, schedule)
1144+
1145+
def schedule_new_connection(self, cb: Callable[[], None]):
1146+
with self.lock:
1147+
if self.is_shutdown:
1148+
return
1149+
self.items.append(cb)
1150+
if self.currently_pending < self.max_concurrent:
1151+
self.currently_pending += 1
1152+
schedule = self.backoff_policy.new_schedule()
1153+
delay = next(schedule)
1154+
self.scheduler.schedule(delay, self._run, schedule)
1155+
1156+
def shutdown(self):
1157+
with self.lock:
1158+
self.is_shutdown = True
1159+
1160+
1161+
class _LimitedConcurrencyShardConnectionScheduler(ShardConnectionScheduler):
    """
    A scheduler for ``cassandra.policies.LimitedConcurrencyShardConnectionBackoffPolicy``.

    Limits concurrency for connection creation requests to ``max_concurrent`` per host_id.
    """

    already_scheduled: set[tuple[str, int]]
    """
    Set of (host_id, shard_id) of scheduled or pending requests.
    """

    per_host_scope: dict[str, _ScopeBucket]
    """
    Scopes storage, key is host_id, value is an instance that holds scope data.
    """

    backoff_policy: ShardConnectionBackoffSchedule
    scheduler: _Scheduler
    lock: Lock
    is_shutdown: bool = False

    max_concurrent: int
    """
    Max concurrent connection creation requests per host_id.
    """

    def __init__(
            self,
            scheduler: _Scheduler,
            backoff_policy: ShardConnectionBackoffSchedule,
            max_concurrent: int,
    ):
        self.already_scheduled = set()
        self.per_host_scope = {}
        self.backoff_policy = backoff_policy
        self.max_concurrent = max_concurrent
        self.scheduler = scheduler
        self.lock = Lock()

    def _execute(self, host_id: str, shard_id: int, method: Callable[[], None]):
        # Runs inside the scope bucket; clears the (host, shard) marker once
        # the attempt finishes so a future attempt may be scheduled again.
        if self.is_shutdown:
            return
        try:
            method()
        finally:
            with self.lock:
                self.already_scheduled.remove((host_id, shard_id))

    def schedule(self, host_id: str, shard_id: int, method: Callable[[], None]) -> bool:
        # Returns True when the request was queued, False when it was dropped
        # as a duplicate or because the scheduler is shut down.
        with self.lock:
            if self.is_shutdown or (host_id, shard_id) in self.already_scheduled:
                return False
            self.already_scheduled.add((host_id, shard_id))

            # Lazily create the per-host scope bucket on first use.
            scope_info = self.per_host_scope.get(host_id)
            if scope_info is None:
                scope_info = _ScopeBucket(self.scheduler, self.backoff_policy, self.max_concurrent)
                self.per_host_scope[host_id] = scope_info
        scope_info.schedule_new_connection(partial(self._execute, host_id, shard_id, method))
        return True

    def shutdown(self):
        with self.lock:
            self.is_shutdown = True
            for scope in self.per_host_scope.values():
                scope.shutdown()
10001230
class RetryPolicy(object):
10011231
"""
10021232
A policy that describes whether to retry, rethrow, or ignore coordinator

0 commit comments

Comments
 (0)