1111# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212# See the License for the specific language governing permissions and
1313# limitations under the License.
14+ import itertools
1415import unittest
1516from io import BytesIO
1617import time
2122from cassandra .cluster import Cluster
2223from cassandra .connection import (Connection , HEADER_DIRECTION_TO_CLIENT , ProtocolError ,
2324 locally_supported_compressions , ConnectionHeartbeat , _Frame , Timer , TimerManager ,
24- ConnectionException , DefaultEndPoint )
25+ ConnectionException , DefaultEndPoint , ShardAwarePortGenerator )
2526from cassandra .marshal import uint8_pack , uint32_pack , int32_pack
2627from cassandra .protocol import (write_stringmultimap , write_int , write_string ,
2728 SupportedMessage , ProtocolHandler )
@@ -478,3 +479,43 @@ def test_endpoint_resolve(self):
478479 DefaultEndPoint ('10.0.0.1' , 3232 ).resolve (),
479480 ('10.0.0.1' , 3232 )
480481 )
482+
483+
484+ class TestShardawarePortGenerator (unittest .TestCase ):
485+ @patch ('random.randrange' )
486+ def test_generate_ports_basic (self , mock_randrange ):
487+ mock_randrange .return_value = 10005
488+ gen = ShardAwarePortGenerator (10000 , 10020 )
489+ ports = list (itertools .islice (gen .generate (shard_id = 1 , total_shards = 3 ), 5 ))
490+
491+ # Starting from aligned 10005 + shard_id (1), step by 3
492+ self .assertEqual (ports , [10006 , 10009 , 10012 , 10015 , 10018 ])
493+
494+ @patch ('random.randrange' )
495+ def test_wraps_around_to_start (self , mock_randrange ):
496+ mock_randrange .return_value = 10008
497+ gen = ShardAwarePortGenerator (10000 , 10020 )
498+ ports = list (itertools .islice (gen .generate (shard_id = 2 , total_shards = 4 ), 5 ))
499+
500+ # Expected wrap-around from start_port after end_port is exceeded
501+ self .assertEqual (ports , [10010 , 10014 , 10018 , 10002 , 10006 ])
502+
503+ @patch ('random.randrange' )
504+ def test_all_ports_have_correct_modulo (self , mock_randrange ):
505+ mock_randrange .return_value = 10012
506+ total_shards = 5
507+ shard_id = 3
508+ gen = ShardAwarePortGenerator (10000 , 10020 )
509+
510+ for port in gen .generate (shard_id = shard_id , total_shards = total_shards ):
511+ self .assertEqual (port % total_shards , shard_id )
512+
513+ @patch ('random.randrange' )
514+ def test_generate_is_repeatable_with_same_mock (self , mock_randrange ):
515+ mock_randrange .return_value = 10010
516+ gen = ShardAwarePortGenerator (10000 , 10020 )
517+
518+ first_run = list (itertools .islice (gen .generate (0 , 2 ), 5 ))
519+ second_run = list (itertools .islice (gen .generate (0 , 2 ), 5 ))
520+
521+ self .assertEqual (first_run , second_run )
0 commit comments