1111# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212# See the License for the specific language governing permissions and
1313# limitations under the License.
14+ import uuid
15+ from unittest .mock import Mock
16+
17+ from cassandra .policies import NoDelayShardConnectionBackoffPolicy , _NoDelayShardConnectionBackoffScheduler
1418
1519try :
1620 import unittest2 as unittest
2125from mock import MagicMock
2226from concurrent .futures import ThreadPoolExecutor
2327
24- from cassandra .cluster import ShardAwareOptions
28+ from cassandra .cluster import ShardAwareOptions , _Scheduler
2529from cassandra .pool import HostConnection , HostDistance
2630from cassandra .connection import ShardingInfo , DefaultEndPoint
2731from cassandra .metadata import Murmur3Token
@@ -53,11 +57,18 @@ class OptionsHolder(object):
5357 self .assertEqual (shard_info .shard_id_from_token (Murmur3Token .from_key (b"e" ).value ), 4 )
5458 self .assertEqual (shard_info .shard_id_from_token (Murmur3Token .from_key (b"100000" ).value ), 2 )
5559
56- def test_advanced_shard_aware_port (self ):
60+ def test_shard_aware_reconnection_policy_no_delay (self ):
61+ # with NoDelayReconnectionPolicy all the connections should be created right away
62+ self ._test_shard_aware_reconnection_policy (4 , NoDelayShardConnectionBackoffPolicy (), 4 )
63+
64+ def _test_shard_aware_reconnection_policy (self , shard_count , shard_connection_backoff_policy , expected_connections ):
5765 """
5866 Test that on given a `shard_aware_port` on the OPTIONS message (ShardInfo class)
59- the next connections would be open using this port
67+ It checks that:
68+ 1. Next connections are opened using this port
69+ 2. Connection creation pase matches `shard_connection_backoff_policy`
6070 """
71+
6172 class MockSession (MagicMock ):
6273 is_shutdown = False
6374 keyspace = "ks1"
@@ -71,45 +82,82 @@ def __init__(self, is_ssl=False, *args, **kwargs):
7182 self .cluster .ssl_options = None
7283 self .cluster .shard_aware_options = ShardAwareOptions ()
7384 self .cluster .executor = ThreadPoolExecutor (max_workers = 2 )
85+ self ._executor_submit_original = self .cluster .executor .submit
86+ self .cluster .executor .submit = self ._executor_submit
87+ self .cluster .scheduler = _Scheduler (self .cluster .executor )
88+
89+ # Collect scheduled calls and execute them right away
90+ self .scheduler_calls = []
91+ original_schedule = self .cluster .scheduler .schedule
92+
93+ def new_schedule (delay , fn , * args , ** kwargs ):
94+ self .scheduler_calls .append ((delay , fn , args , kwargs ))
95+ return original_schedule (0 , fn , * args , ** kwargs )
96+
97+ self .cluster .scheduler .schedule = Mock (side_effect = new_schedule )
7498 self .cluster .signal_connection_failure = lambda * args , ** kwargs : False
7599 self .cluster .connection_factory = self .mock_connection_factory
76100 self .connection_counter = 0
101+ self .shard_connection_backoff_scheduler = shard_connection_backoff_policy .new_connection_scheduler (
102+ self .cluster .scheduler )
77103 self .futures = []
78104
79105 def submit (self , fn , * args , ** kwargs ):
106+ if self .is_shutdown :
107+ return None
108+ return self .cluster .executor .submit (fn , * args , ** kwargs )
109+
110+ def _executor_submit (self , fn , * args , ** kwargs ):
80111 logging .info ("Scheduling %s with args: %s, kwargs: %s" , fn , args , kwargs )
81- if not self .is_shutdown :
82- f = self .cluster .executor .submit (fn , * args , ** kwargs )
83- self .futures += [f ]
84- return f
112+ f = self ._executor_submit_original (fn , * args , ** kwargs )
113+ self .futures += [f ]
114+ return f
85115
86116 def mock_connection_factory (self , * args , ** kwargs ):
87117 connection = MagicMock ()
88118 connection .is_shutdown = False
89119 connection .is_defunct = False
90120 connection .is_closed = False
91121 connection .orphaned_threshold_reached = False
92- connection .endpoint = args [0 ]
93- sharding_info = ShardingInfo (shard_id = 1 , shards_count = 4 , partitioner = "" , sharding_algorithm = "" , sharding_ignore_msb = 0 , shard_aware_port = 19042 , shard_aware_port_ssl = 19045 )
94- connection .features = ProtocolFeatures (shard_id = kwargs .get ('shard_id' , self .connection_counter ), sharding_info = sharding_info )
122+ connection .endpoint = args [0 ]
123+ sharding_info = None
124+ if shard_count :
125+ sharding_info = ShardingInfo (shard_id = 1 , shards_count = shard_count , partitioner = "" ,
126+ sharding_algorithm = "" , sharding_ignore_msb = 0 , shard_aware_port = 19042 ,
127+ shard_aware_port_ssl = 19045 )
128+ connection .features = ProtocolFeatures (
129+ shard_id = kwargs .get ('shard_id' , self .connection_counter ),
130+ sharding_info = sharding_info )
95131 self .connection_counter += 1
96132
97133 return connection
98134
99135 host = MagicMock ()
136+ host .host_id = uuid .uuid4 ()
100137 host .endpoint = DefaultEndPoint ("1.2.3.4" )
138+ session = None
139+ try :
140+ for port , is_ssl in [(19042 , False ), (19045 , True )]:
141+ session = MockSession (is_ssl = is_ssl )
142+ pool = HostConnection (host = host , host_distance = HostDistance .REMOTE , session = session )
143+ for f in session .futures :
144+ f .result ()
145+ assert len (pool ._connections ) == expected_connections
146+ for shard_id , connection in pool ._connections .items ():
147+ assert connection .features .shard_id == shard_id
148+ if shard_id == 0 :
149+ assert connection .endpoint == DefaultEndPoint ("1.2.3.4" )
150+ else :
151+ assert connection .endpoint == DefaultEndPoint ("1.2.3.4" , port = port )
101152
102- for port , is_ssl in [(19042 , False ), (19045 , True )]:
103- session = MockSession (is_ssl = is_ssl )
104- pool = HostConnection (host = host , host_distance = HostDistance .REMOTE , session = session )
105- for f in session .futures :
106- f .result ()
107- assert len (pool ._connections ) == 4
108- for shard_id , connection in pool ._connections .items ():
109- assert connection .features .shard_id == shard_id
110- if shard_id == 0 :
111- assert connection .endpoint == DefaultEndPoint ("1.2.3.4" )
112- else :
113- assert connection .endpoint == DefaultEndPoint ("1.2.3.4" , port = port )
114-
115- session .cluster .executor .shutdown (wait = True )
153+ sleep_time = 0
154+ found_related_calls = 0
155+ for delay , fn , args , kwargs in session .scheduler_calls :
156+ if fn .__self__ .__class__ is _NoDelayShardConnectionBackoffScheduler :
157+ found_related_calls += 1
158+ self .assertEqual (delay , sleep_time )
159+ self .assertLessEqual (shard_count - 1 , found_related_calls )
160+ finally :
161+ if session :
162+ session .cluster .scheduler .shutdown ()
163+ session .cluster .executor .shutdown (wait = True )
0 commit comments