1111# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212# See the License for the specific language governing permissions and
1313# limitations under the License.
14+ import itertools
1415import unittest
1516from io import BytesIO
1617import time
2122from cassandra .cluster import Cluster
2223from cassandra .connection import (Connection , HEADER_DIRECTION_TO_CLIENT , ProtocolError ,
2324 locally_supported_compressions , ConnectionHeartbeat , _Frame , Timer , TimerManager ,
24- ConnectionException , DefaultEndPoint )
25+ ConnectionException , DefaultEndPoint , ShardAwarePortGenerator )
2526from cassandra .marshal import uint8_pack , uint32_pack , int32_pack
2627from cassandra .protocol import (write_stringmultimap , write_int , write_string ,
2728 SupportedMessage , ProtocolHandler )
@@ -486,3 +487,43 @@ def test_endpoint_resolve(self):
486487 DefaultEndPoint ('10.0.0.1' , 3232 ).resolve (),
487488 ('10.0.0.1' , 3232 )
488489 )
490+
491+
492+ class TestShardawarePortGenerator (unittest .TestCase ):
493+ @patch ('random.randrange' )
494+ def test_generate_ports_basic (self , mock_randrange ):
495+ mock_randrange .return_value = 10005
496+ gen = ShardAwarePortGenerator (10000 , 10020 )
497+ ports = list (itertools .islice (gen .generate (shard_id = 1 , total_shards = 3 ), 5 ))
498+
499+ # Starting from aligned 10005 + shard_id (1), step by 3
500+ self .assertEqual (ports , [10006 , 10009 , 10012 , 10015 , 10018 ])
501+
502+ @patch ('random.randrange' )
503+ def test_wraps_around_to_start (self , mock_randrange ):
504+ mock_randrange .return_value = 10008
505+ gen = ShardAwarePortGenerator (10000 , 10020 )
506+ ports = list (itertools .islice (gen .generate (shard_id = 2 , total_shards = 4 ), 5 ))
507+
508+ # Expected wrap-around from start_port after end_port is exceeded
509+ self .assertEqual (ports , [10010 , 10014 , 10018 , 10002 , 10006 ])
510+
511+ @patch ('random.randrange' )
512+ def test_all_ports_have_correct_modulo (self , mock_randrange ):
513+ mock_randrange .return_value = 10012
514+ total_shards = 5
515+ shard_id = 3
516+ gen = ShardAwarePortGenerator (10000 , 10020 )
517+
518+ for port in gen .generate (shard_id = shard_id , total_shards = total_shards ):
519+ self .assertEqual (port % total_shards , shard_id )
520+
521+ @patch ('random.randrange' )
522+ def test_generate_is_repeatable_with_same_mock (self , mock_randrange ):
523+ mock_randrange .return_value = 10010
524+ gen = ShardAwarePortGenerator (10000 , 10020 )
525+
526+ first_run = list (itertools .islice (gen .generate (0 , 2 ), 5 ))
527+ second_run = list (itertools .islice (gen .generate (0 , 2 ), 5 ))
528+
529+ self .assertEqual (first_run , second_run )
0 commit comments