1111# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212# See the License for the specific language governing permissions and
1313# limitations under the License.
14+ import uuid
1415
1516try :
1617 import unittest2 as unittest
1718except ImportError :
1819 import unittest # noqa
1920
21+ import time
2022import logging
2123from mock import MagicMock
2224from concurrent .futures import ThreadPoolExecutor
2325
24- from cassandra .cluster import ShardAwareOptions
26+ from cassandra .cluster import ShardAwareOptions , _Scheduler
27+ from cassandra .policies import ConstantReconnectionPolicy , \
28+ NoDelayShardReconnectionPolicy , NoConcurrentShardReconnectionPolicy , ShardReconnectionPolicyScope
2529from cassandra .pool import HostConnection , HostDistance
2630from cassandra .connection import ShardingInfo , DefaultEndPoint
2731from cassandra .metadata import Murmur3Token
@@ -53,7 +57,15 @@ class OptionsHolder(object):
5357 self .assertEqual (shard_info .shard_id_from_token (Murmur3Token .from_key (b"e" ).value ), 4 )
5458 self .assertEqual (shard_info .shard_id_from_token (Murmur3Token .from_key (b"100000" ).value ), 2 )
5559
56- def test_advanced_shard_aware_port (self ):
60+ def test_shard_aware_reconnection_policy_no_delay (self ):
61+ # with NoDelayReconnectionPolicy all the connections should be created right away
62+ self ._test_shard_aware_reconnection_policy (4 , NoDelayShardReconnectionPolicy (), 4 , 4 )
63+
64+ def test_shard_aware_reconnection_policy_delay (self ):
65+ # with ConstantReconnectionPolicy first connection is created right away, others are delayed
66+ self ._test_shard_aware_reconnection_policy (4 , NoConcurrentShardReconnectionPolicy (ShardReconnectionPolicyScope .Cluster , ConstantReconnectionPolicy (1 )), 1 , 4 )
67+
68+ def _test_shard_aware_reconnection_policy (self , shard_count , shard_reconnection_policy , expected_count , expected_after ):
5769 """
5870 Test that on given a `shard_aware_port` on the OPTIONS message (ShardInfo class)
5971 the next connections would be open using this port
@@ -71,17 +83,25 @@ def __init__(self, is_ssl=False, *args, **kwargs):
7183 self .cluster .ssl_options = None
7284 self .cluster .shard_aware_options = ShardAwareOptions ()
7385 self .cluster .executor = ThreadPoolExecutor (max_workers = 2 )
86+ self ._executor_submit_original = self .cluster .executor .submit
87+ self .cluster .executor .submit = self ._executor_submit
88+ self .cluster .scheduler = _Scheduler (self .cluster .executor )
7489 self .cluster .signal_connection_failure = lambda * args , ** kwargs : False
7590 self .cluster .connection_factory = self .mock_connection_factory
7691 self .connection_counter = 0
92+ self .shard_reconnection_scheduler = shard_reconnection_policy .new_scheduler (self )
7793 self .futures = []
7894
7995 def submit (self , fn , * args , ** kwargs ):
96+ if self .is_shutdown :
97+ return None
98+ return self .cluster .executor .submit (fn , * args , ** kwargs )
99+
100+ def _executor_submit (self , fn , * args , ** kwargs ):
80101 logging .info ("Scheduling %s with args: %s, kwargs: %s" , fn , args , kwargs )
81- if not self .is_shutdown :
82- f = self .cluster .executor .submit (fn , * args , ** kwargs )
83- self .futures += [f ]
84- return f
102+ f = self ._executor_submit_original (fn , * args , ** kwargs )
103+ self .futures += [f ]
104+ return f
85105
86106 def mock_connection_factory (self , * args , ** kwargs ):
87107 connection = MagicMock ()
@@ -90,26 +110,50 @@ def mock_connection_factory(self, *args, **kwargs):
90110 connection .is_closed = False
91111 connection .orphaned_threshold_reached = False
92112 connection .endpoint = args [0 ]
93- sharding_info = ShardingInfo (shard_id = 1 , shards_count = 4 , partitioner = "" , sharding_algorithm = "" , sharding_ignore_msb = 0 , shard_aware_port = 19042 , shard_aware_port_ssl = 19045 )
113+ sharding_info = ShardingInfo (shard_id = 1 , shards_count = shard_count , partitioner = "" , sharding_algorithm = "" , sharding_ignore_msb = 0 , shard_aware_port = 19042 , shard_aware_port_ssl = 19045 )
94114 connection .features = ProtocolFeatures (shard_id = kwargs .get ('shard_id' , self .connection_counter ), sharding_info = sharding_info )
95115 self .connection_counter += 1
96116
97117 return connection
98118
99119 host = MagicMock ()
120+ host .host_id = uuid .uuid4 ()
100121 host .endpoint = DefaultEndPoint ("1.2.3.4" )
122+ session = None
123+ reconnection_policy = None
124+ if isinstance (shard_reconnection_policy , NoConcurrentShardReconnectionPolicy ):
125+ reconnection_policy = shard_reconnection_policy .reconnection_policy
126+ try :
127+ for port , is_ssl in [(19042 , False ), (19045 , True )]:
128+ session = MockSession (is_ssl = is_ssl )
129+ pool = HostConnection (host = host , host_distance = HostDistance .REMOTE , session = session )
130+ for f in session .futures :
131+ f .result ()
132+ assert len (pool ._connections ) == expected_count
133+ for shard_id , connection in pool ._connections .items ():
134+ assert connection .features .shard_id == shard_id
135+ if shard_id == 0 :
136+ assert connection .endpoint == DefaultEndPoint ("1.2.3.4" )
137+ else :
138+ assert connection .endpoint == DefaultEndPoint ("1.2.3.4" , port = port )
101139
102- for port , is_ssl in [(19042 , False ), (19045 , True )]:
103- session = MockSession (is_ssl = is_ssl )
104- pool = HostConnection (host = host , host_distance = HostDistance .REMOTE , session = session )
105- for f in session .futures :
106- f .result ()
107- assert len (pool ._connections ) == 4
108- for shard_id , connection in pool ._connections .items ():
109- assert connection .features .shard_id == shard_id
110- if shard_id == 0 :
111- assert connection .endpoint == DefaultEndPoint ("1.2.3.4" )
112- else :
113- assert connection .endpoint == DefaultEndPoint ("1.2.3.4" , port = port )
140+ sleep_time = 0
141+ if reconnection_policy :
142+ # Check that connections to shards are being established according to the policy
143+ # Calculate total time it will need to establish all connections
144+ # Sleep half of the time and check that connections are not there yet
145+ # Sleep rest of the time + 1 second and check that all connections has been established
146+ schedule = reconnection_policy .new_schedule ()
147+ for _ in range (shard_count ):
148+ sleep_time += next (schedule )
149+ if sleep_time > 0 :
150+ time .sleep (sleep_time / 2 )
151+ # Check that connection are not being established quicker than expected
152+ assert len (pool ._connections ) < expected_after
153+ time .sleep (sleep_time / 2 + 1 )
114154
115- session .cluster .executor .shutdown (wait = True )
155+ assert len (pool ._connections ) == expected_after
156+ finally :
157+ if session :
158+ session .cluster .scheduler .shutdown ()
159+ session .cluster .executor .shutdown (wait = True )
0 commit comments