Skip to content

Commit c0f432c

Browse files
committed
Integrate ShardConnectionBackoffPolicy
Add code that integrates ShardConnectionBackoffPolicy into: 1. Cluster 2. Session 3. HostConnection. The main idea is to put ShardConnectionBackoffPolicy in control of the shard connection creation process, removing the duplicate logic in HostConnection that tracks pending connection creation requests.
1 parent c037f47 commit c0f432c

File tree

3 files changed

+47
-27
lines changed

3 files changed

+47
-27
lines changed

cassandra/cluster.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,8 @@
7373
ExponentialReconnectionPolicy, HostDistance,
7474
RetryPolicy, IdentityTranslator, NoSpeculativeExecutionPlan,
7575
NoSpeculativeExecutionPolicy, DefaultLoadBalancingPolicy,
76-
NeverRetryPolicy)
76+
NeverRetryPolicy, ShardConnectionBackoffPolicy, NoDelayShardConnectionBackoffPolicy,
77+
ShardConnectionScheduler)
7778
from cassandra.pool import (Host, _ReconnectionHandler, _HostReconnectionHandler,
7879
HostConnectionPool, HostConnection,
7980
NoConnectionsAvailable)
@@ -757,6 +758,11 @@ def auth_provider(self, value):
757758

758759
self._auth_provider = value
759760

761+
_shard_connection_backoff_policy: ShardConnectionBackoffPolicy
762+
@property
763+
def shard_connection_backoff_policy(self) -> ShardConnectionBackoffPolicy:
764+
return self._shard_connection_backoff_policy
765+
760766
_load_balancing_policy = None
761767
@property
762768
def load_balancing_policy(self):
@@ -1219,7 +1225,8 @@ def __init__(self,
12191225
shard_aware_options=None,
12201226
metadata_request_timeout=None,
12211227
column_encryption_policy=None,
1222-
application_info:Optional[ApplicationInfoBase]=None
1228+
application_info: Optional[ApplicationInfoBase] = None,
1229+
shard_connection_backoff_policy: Optional[ShardConnectionBackoffPolicy] = None,
12231230
):
12241231
"""
12251232
``executor_threads`` defines the number of threads in a pool for handling asynchronous tasks such as
@@ -1325,6 +1332,13 @@ def __init__(self,
13251332
else:
13261333
self._load_balancing_policy = default_lbp_factory() # set internal attribute to avoid committing to legacy config mode
13271334

1335+
if shard_connection_backoff_policy is not None:
1336+
if not isinstance(shard_connection_backoff_policy, ShardConnectionBackoffPolicy):
1337+
raise TypeError("shard_connection_backoff_policy should be an instance of class derived from ShardConnectionBackoffPolicy")
1338+
self._shard_connection_backoff_policy = shard_connection_backoff_policy
1339+
else:
1340+
self._shard_connection_backoff_policy = NoDelayShardConnectionBackoffPolicy()
1341+
13281342
if reconnection_policy is not None:
13291343
if isinstance(reconnection_policy, type):
13301344
raise TypeError("reconnection_policy should not be a class, it should be an instance of that class")
@@ -2716,6 +2730,7 @@ def default_serial_consistency_level(self, cl):
27162730
_metrics = None
27172731
_request_init_callbacks = None
27182732
_graph_paging_available = False
2733+
shard_connection_backoff_scheduler: ShardConnectionScheduler
27192734

27202735
def __init__(self, cluster, hosts, keyspace=None):
27212736
self.cluster = cluster
@@ -2730,6 +2745,7 @@ def __init__(self, cluster, hosts, keyspace=None):
27302745
self._protocol_version = self.cluster.protocol_version
27312746

27322747
self.encoder = Encoder()
2748+
self.shard_connection_backoff_scheduler = cluster.shard_connection_backoff_policy.new_connection_scheduler(self.cluster.scheduler)
27332749

27342750
# create connection pools in parallel
27352751
self._initial_connect_futures = set()
@@ -3340,6 +3356,7 @@ def shutdown(self):
33403356
else:
33413357
self.is_shutdown = True
33423358

3359+
self.shard_connection_backoff_scheduler.shutdown()
33433360
# PYTHON-673. If shutdown was called shortly after session init, avoid
33443361
# a race by cancelling any initial connection attempts that haven't started,
33453362
# then blocking on any that have.

cassandra/pool.py

Lines changed: 13 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
Connection pooling and host management.
1717
"""
1818
from concurrent.futures import Future
19-
from functools import total_ordering
19+
from functools import total_ordering, partial
2020
import logging
2121
import socket
2222
import time
@@ -402,7 +402,6 @@ def __init__(self, host, host_distance, session):
402402
# this is used in conjunction with the connection streams. Not using the connection lock because the connection can be replaced in the lifetime of the pool.
403403
self._stream_available_condition = Condition(Lock())
404404
self._is_replacing = False
405-
self._connecting = set()
406405
self._connections = {}
407406
self._pending_connections = []
408407
# A pool of additional connections which are not used but affect how Scylla
@@ -418,7 +417,6 @@ def __init__(self, host, host_distance, session):
418417
# and are waiting until all requests time out or complete
419418
# so that we can dispose of them.
420419
self._trash = set()
421-
self._shard_connections_futures = []
422420
self.advanced_shardaware_block_until = 0
423421

424422
if host_distance == HostDistance.IGNORED:
@@ -483,25 +481,25 @@ def _get_connection_for_routing_key(self, routing_key=None, keyspace=None, table
483481
self.host,
484482
routing_key
485483
)
486-
if conn.orphaned_threshold_reached and shard_id not in self._connecting:
484+
if conn.orphaned_threshold_reached:
487485
# The connection has met its orphaned stream ID limit
488486
# and needs to be replaced. Start opening a connection
489487
# to the same shard and replace when it is opened.
490-
self._connecting.add(shard_id)
491-
self._session.submit(self._open_connection_to_missing_shard, shard_id)
488+
self._session.shard_connection_backoff_scheduler.schedule(
489+
self.host.host_id, shard_id, partial(self._open_connection_to_missing_shard, shard_id))
492490
log.debug(
493-
"Connection to shard_id=%i reached orphaned stream limit, replacing on host %s (%s/%i)",
491+
"Scheduling Connection to shard_id=%i reached orphaned stream limit, replacing on host %s (%s/%i)",
494492
shard_id,
495493
self.host,
496494
len(self._connections.keys()),
497495
self.host.sharding_info.shards_count
498496
)
499-
elif shard_id not in self._connecting:
497+
else:
500498
# rate controlled optimistic attempt to connect to a missing shard
501-
self._connecting.add(shard_id)
502-
self._session.submit(self._open_connection_to_missing_shard, shard_id)
499+
self._session.shard_connection_backoff_scheduler.schedule(
500+
self.host.host_id, shard_id, partial(self._open_connection_to_missing_shard, shard_id))
503501
log.debug(
504-
"Trying to connect to missing shard_id=%i on host %s (%s/%i)",
502+
"Scheduling connection to missing shard_id=%i on host %s (%s/%i)",
505503
shard_id,
506504
self.host,
507505
len(self._connections.keys()),
@@ -609,8 +607,8 @@ def _replace(self, connection):
609607
if connection.features.shard_id in self._connections.keys():
610608
del self._connections[connection.features.shard_id]
611609
if self.host.sharding_info and not self._session.cluster.shard_aware_options.disable:
612-
self._connecting.add(connection.features.shard_id)
613-
self._session.submit(self._open_connection_to_missing_shard, connection.features.shard_id)
610+
self._session.shard_connection_backoff_scheduler.schedule(
611+
self.host.host_id, connection.features.shard_id, partial(self._open_connection_to_missing_shard, connection.features.shard_id))
614612
else:
615613
connection = self._session.cluster.connection_factory(self.host.endpoint,
616614
on_orphaned_stream_released=self.on_orphaned_stream_released)
@@ -635,9 +633,6 @@ def shutdown(self):
635633
with self._stream_available_condition:
636634
self._stream_available_condition.notify_all()
637635

638-
for future in self._shard_connections_futures:
639-
future.cancel()
640-
641636
connections_to_close = self._connections.copy()
642637
pending_connections_to_close = self._pending_connections.copy()
643638
self._connections.clear()
@@ -843,7 +838,6 @@ def _open_connection_to_missing_shard(self, shard_id):
843838
self._excess_connections.add(conn)
844839
if close_connection:
845840
conn.close()
846-
self._connecting.discard(shard_id)
847841

848842
def _open_connections_for_all_shards(self, skip_shard_id=None):
849843
"""
@@ -856,10 +850,8 @@ def _open_connections_for_all_shards(self, skip_shard_id=None):
856850
for shard_id in range(self.host.sharding_info.shards_count):
857851
if skip_shard_id is not None and skip_shard_id == shard_id:
858852
continue
859-
future = self._session.submit(self._open_connection_to_missing_shard, shard_id)
860-
if isinstance(future, Future):
861-
self._connecting.add(shard_id)
862-
self._shard_connections_futures.append(future)
853+
self._session.shard_connection_backoff_scheduler.schedule(
854+
self.host.host_id, shard_id, partial(self._open_connection_to_missing_shard, shard_id))
863855

864856
trash_conns = None
865857
with self._lock:

tests/unit/test_host_connection_pool.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,15 +22,23 @@
2222
from threading import Thread, Event, Lock
2323
from unittest.mock import Mock, NonCallableMagicMock, MagicMock
2424

25-
from cassandra.cluster import Session, ShardAwareOptions
25+
from cassandra.cluster import Session, ShardAwareOptions, _Scheduler
2626
from cassandra.connection import Connection
2727
from cassandra.pool import HostConnection, HostConnectionPool
2828
from cassandra.pool import Host, NoConnectionsAvailable
29-
from cassandra.policies import HostDistance, SimpleConvictionPolicy
29+
from cassandra.policies import HostDistance, SimpleConvictionPolicy, _NoDelayShardConnectionBackoffScheduler
3030

3131
LOGGER = logging.getLogger(__name__)
3232

3333

34+
class FakeScheduler(_Scheduler):
35+
def __init__(self):
36+
super(FakeScheduler, self).__init__(ThreadPoolExecutor())
37+
38+
def schedule(self, delay, fn, *args, **kwargs):
39+
super().schedule(0, fn, *args, **kwargs)
40+
41+
3442
class _PoolTests(unittest.TestCase):
3543
__test__ = False
3644
PoolImpl = None
@@ -41,6 +49,9 @@ def make_session(self):
4149
session.cluster.get_core_connections_per_host.return_value = 1
4250
session.cluster.get_max_requests_per_connection.return_value = 1
4351
session.cluster.get_max_connections_per_host.return_value = 1
52+
session.shard_connection_backoff_scheduler = _NoDelayShardConnectionBackoffScheduler(FakeScheduler())
53+
session.shard_connection_backoff_scheduler.schedule = Mock(wraps=session.shard_connection_backoff_scheduler.schedule)
54+
session.is_shutdown = False
4455
return session
4556

4657
def test_borrow_and_return(self):
@@ -174,9 +185,9 @@ def test_return_defunct_connection_on_down_host(self):
174185
if self.PoolImpl is HostConnection:
175186
# on the shard-aware implementation a shard connection is scheduled via the backoff scheduler regardless
176187
self.assertTrue(host.signal_connection_failure.call_args)
177-
self.assertTrue(session.submit.called)
188+
self.assertTrue(session.shard_connection_backoff_scheduler.schedule.called)
178189
else:
179-
self.assertFalse(session.submit.called)
190+
self.assertFalse(session.shard_connection_backoff_scheduler.schedule.called)
180191
self.assertTrue(session.cluster.signal_connection_failure.call_args)
181192
self.assertTrue(pool.is_shutdown)
182193

0 commit comments

Comments
 (0)