@@ -404,6 +404,7 @@ def __init__(self, host, host_distance, session):
404
404
self ._is_replacing = False
405
405
self ._connecting = set ()
406
406
self ._connections = {}
407
+ self ._pending_connections = []
407
408
# A pool of additional connections which are not used but affect how Scylla
408
409
# assigns shards to them. Scylla tends to assign the shard which has
409
410
# the lowest number of connections. If connections are not distributed
@@ -638,7 +639,9 @@ def shutdown(self):
638
639
future .cancel ()
639
640
640
641
connections_to_close = self ._connections .copy ()
642
+ pending_connections_to_close = self ._pending_connections .copy ()
641
643
self ._connections .clear ()
644
+ self ._pending_connections .clear ()
642
645
643
646
# connection.close can call pool.return_connection, which will
644
647
# obtain self._lock via self._stream_available_condition.
@@ -647,6 +650,10 @@ def shutdown(self):
647
650
log .debug ("Closing connection (%s) to %s" , id (connection ), self .host )
648
651
connection .close ()
649
652
653
+ for connection in pending_connections_to_close :
654
+ log .debug ("Closing pending connection (%s) to %s" , id (connection ), self .host )
655
+ connection .close ()
656
+
650
657
self ._close_excess_connections ()
651
658
652
659
trash_conns = None
@@ -714,12 +721,12 @@ def _open_connection_to_missing_shard(self, shard_id):
714
721
log .debug ("shard_aware_endpoint=%r" , shard_aware_endpoint )
715
722
716
723
if shard_aware_endpoint :
717
- conn = self ._session .cluster .connection_factory (shard_aware_endpoint , on_orphaned_stream_released = self .on_orphaned_stream_released ,
724
+ conn = self ._session .cluster .connection_factory (shard_aware_endpoint , host_conn = self , on_orphaned_stream_released = self .on_orphaned_stream_released ,
718
725
shard_id = shard_id ,
719
726
total_shards = self .host .sharding_info .shards_count )
720
727
conn .original_endpoint = self .host .endpoint
721
728
else :
722
- conn = self ._session .cluster .connection_factory (self .host .endpoint , on_orphaned_stream_released = self .on_orphaned_stream_released )
729
+ conn = self ._session .cluster .connection_factory (self .host .endpoint , host_conn = self , on_orphaned_stream_released = self .on_orphaned_stream_released )
723
730
724
731
log .debug ("Received a connection %s for shard_id=%i on host %s" , id (conn ), conn .features .shard_id , self .host )
725
732
if self .is_shutdown :
0 commit comments