Skip to content

Commit 7608a51

Browse files
authored
Merge pull request #297 from sylwiaszunejko/shutdown_hanging
Close pending connections during shutdown
2 parents 4fc60e7 + 67a108e commit 7608a51

File tree

3 files changed

+16
-5
lines changed

3 files changed

+16
-5
lines changed

cassandra/cluster.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1691,13 +1691,13 @@ def set_max_connections_per_host(self, host_distance, max_connections):
16911691
"when using protocol_version 1 or 2.")
16921692
self._max_connections_per_host[host_distance] = max_connections
16931693

1694-
def connection_factory(self, endpoint, *args, **kwargs):
1694+
def connection_factory(self, endpoint, host_conn = None, *args, **kwargs):
16951695
"""
16961696
Called to create a new connection with proper configuration.
16971697
Intended for internal use only.
16981698
"""
16991699
kwargs = self._make_connection_kwargs(endpoint, kwargs)
1700-
return self.connection_class.factory(endpoint, self.connect_timeout, *args, **kwargs)
1700+
return self.connection_class.factory(endpoint, self.connect_timeout, host_conn, *args, **kwargs)
17011701

17021702
def _make_connection_factory(self, host, *args, **kwargs):
17031703
kwargs = self._make_connection_kwargs(host.endpoint, kwargs)

cassandra/connection.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -865,7 +865,7 @@ def create_timer(cls, timeout, callback):
865865
raise NotImplementedError()
866866

867867
@classmethod
868-
def factory(cls, endpoint, timeout, *args, **kwargs):
868+
def factory(cls, endpoint, timeout, host_conn = None, *args, **kwargs):
869869
"""
870870
A factory function which returns connections which have
871871
succeeded in connecting and are ready for service (or
@@ -874,6 +874,10 @@ def factory(cls, endpoint, timeout, *args, **kwargs):
874874
start = time.time()
875875
kwargs['connect_timeout'] = timeout
876876
conn = cls(endpoint, *args, **kwargs)
877+
if host_conn is not None:
878+
host_conn._pending_connections.append(conn)
879+
if host_conn.is_shutdown:
880+
conn.close()
877881
elapsed = time.time() - start
878882
conn.connected_event.wait(timeout - elapsed)
879883
if conn.last_error:

cassandra/pool.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -404,6 +404,7 @@ def __init__(self, host, host_distance, session):
404404
self._is_replacing = False
405405
self._connecting = set()
406406
self._connections = {}
407+
self._pending_connections = []
407408
# A pool of additional connections which are not used but affect how Scylla
408409
# assigns shards to them. Scylla tends to assign the shard which has
409410
# the lowest number of connections. If connections are not distributed
@@ -638,7 +639,9 @@ def shutdown(self):
638639
future.cancel()
639640

640641
connections_to_close = self._connections.copy()
642+
pending_connections_to_close = self._pending_connections.copy()
641643
self._connections.clear()
644+
self._pending_connections.clear()
642645

643646
# connection.close can call pool.return_connection, which will
644647
# obtain self._lock via self._stream_available_condition.
@@ -647,6 +650,10 @@ def shutdown(self):
647650
log.debug("Closing connection (%s) to %s", id(connection), self.host)
648651
connection.close()
649652

653+
for connection in pending_connections_to_close:
654+
log.debug("Closing pending connection (%s) to %s", id(connection), self.host)
655+
connection.close()
656+
650657
self._close_excess_connections()
651658

652659
trash_conns = None
@@ -714,12 +721,12 @@ def _open_connection_to_missing_shard(self, shard_id):
714721
log.debug("shard_aware_endpoint=%r", shard_aware_endpoint)
715722

716723
if shard_aware_endpoint:
717-
conn = self._session.cluster.connection_factory(shard_aware_endpoint, on_orphaned_stream_released=self.on_orphaned_stream_released,
724+
conn = self._session.cluster.connection_factory(shard_aware_endpoint, host_conn=self, on_orphaned_stream_released=self.on_orphaned_stream_released,
718725
shard_id=shard_id,
719726
total_shards=self.host.sharding_info.shards_count)
720727
conn.original_endpoint = self.host.endpoint
721728
else:
722-
conn = self._session.cluster.connection_factory(self.host.endpoint, on_orphaned_stream_released=self.on_orphaned_stream_released)
729+
conn = self._session.cluster.connection_factory(self.host.endpoint, host_conn=self, on_orphaned_stream_released=self.on_orphaned_stream_released)
723730

724731
log.debug("Received a connection %s for shard_id=%i on host %s", id(conn), conn.features.shard_id, self.host)
725732
if self.is_shutdown:

0 commit comments

Comments
 (0)