Skip to content

Commit 975175e

Browse files
committed
Add unit-tests for ShardAwarePortGenerator
1. Make it testable 2. Add unit tests for it
1 parent 4f2312f commit 975175e

File tree

2 files changed

+54
-8
lines changed

2 files changed

+54
-8
lines changed

cassandra/connection.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
import weakref
2929
import random
3030
import itertools
31-
from typing import Optional
31+
from typing import Optional, Generator
3232

3333
from cassandra.application_info import ApplicationInfoBase
3434
from cassandra.protocol_features import ProtocolFeatures
@@ -668,17 +668,22 @@ def reset_cql_frame_buffer(self):
668668
self.reset_io_buffer()
669669

670670

671-
class ShardawarePortGenerator:
672-
@classmethod
673-
def generate(cls, shard_id, total_shards):
674-
start = random.randrange(DEFAULT_LOCAL_PORT_LOW, DEFAULT_LOCAL_PORT_HIGH)
675-
available_ports = itertools.chain(range(start, DEFAULT_LOCAL_PORT_HIGH), range(DEFAULT_LOCAL_PORT_LOW, start))
671+
class ShardAwarePortGenerator:
672+
def __init__(self, start_port: int, end_port: int):
673+
self.start_port = start_port
674+
self.end_port = end_port
676675

676+
def generate(self, shard_id, total_shards):
677+
start = random.randrange(self.start_port, self.end_port)
678+
available_ports = itertools.chain(range(start, self.end_port), range(self.start_port, start))
677679
for port in available_ports:
678680
if port % total_shards == shard_id:
679681
yield port
680682

681683

684+
DefaultShardawarePortGenerator = ShardAwarePortGenerator(DEFAULT_LOCAL_PORT_LOW, DEFAULT_LOCAL_PORT_HIGH)
685+
686+
682687
class Connection(object):
683688

684689
CALLBACK_ERR_THREAD_THRESHOLD = 100
@@ -934,7 +939,7 @@ def _wrap_socket_from_context(self):
934939

935940
def _initiate_connection(self, sockaddr):
936941
if self.features.shard_id is not None:
937-
for port in ShardawarePortGenerator.generate(self.features.shard_id, self.total_shards):
942+
for port in DefaultShardawarePortGenerator.generate(self.features.shard_id, self.total_shards):
938943
try:
939944
self._socket.bind(('', port))
940945
break

tests/unit/test_connection.py

Lines changed: 42 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
import itertools
1415
import unittest
1516
from io import BytesIO
1617
import time
@@ -21,7 +22,7 @@
2122
from cassandra.cluster import Cluster
2223
from cassandra.connection import (Connection, HEADER_DIRECTION_TO_CLIENT, ProtocolError,
2324
locally_supported_compressions, ConnectionHeartbeat, _Frame, Timer, TimerManager,
24-
ConnectionException, DefaultEndPoint)
25+
ConnectionException, DefaultEndPoint, ShardAwarePortGenerator)
2526
from cassandra.marshal import uint8_pack, uint32_pack, int32_pack
2627
from cassandra.protocol import (write_stringmultimap, write_int, write_string,
2728
SupportedMessage, ProtocolHandler)
@@ -486,3 +487,43 @@ def test_endpoint_resolve(self):
486487
DefaultEndPoint('10.0.0.1', 3232).resolve(),
487488
('10.0.0.1', 3232)
488489
)
490+
491+
492+
class TestShardawarePortGenerator(unittest.TestCase):
493+
@patch('random.randrange')
494+
def test_generate_ports_basic(self, mock_randrange):
495+
mock_randrange.return_value = 10005
496+
gen = ShardAwarePortGenerator(10000, 10020)
497+
ports = list(itertools.islice(gen.generate(shard_id=1, total_shards=3), 5))
498+
499+
# Starting from aligned 10005 + shard_id (1), step by 3
500+
self.assertEqual(ports, [10006, 10009, 10012, 10015, 10018])
501+
502+
@patch('random.randrange')
503+
def test_wraps_around_to_start(self, mock_randrange):
504+
mock_randrange.return_value = 10008
505+
gen = ShardAwarePortGenerator(10000, 10020)
506+
ports = list(itertools.islice(gen.generate(shard_id=2, total_shards=4), 5))
507+
508+
# Expected wrap-around from start_port after end_port is exceeded
509+
self.assertEqual(ports, [10010, 10014, 10018, 10002, 10006])
510+
511+
@patch('random.randrange')
512+
def test_all_ports_have_correct_modulo(self, mock_randrange):
513+
mock_randrange.return_value = 10012
514+
total_shards = 5
515+
shard_id = 3
516+
gen = ShardAwarePortGenerator(10000, 10020)
517+
518+
for port in gen.generate(shard_id=shard_id, total_shards=total_shards):
519+
self.assertEqual(port % total_shards, shard_id)
520+
521+
@patch('random.randrange')
522+
def test_generate_is_repeatable_with_same_mock(self, mock_randrange):
523+
mock_randrange.return_value = 10010
524+
gen = ShardAwarePortGenerator(10000, 10020)
525+
526+
first_run = list(itertools.islice(gen.generate(0, 2), 5))
527+
second_run = list(itertools.islice(gen.generate(0, 2), 5))
528+
529+
self.assertEqual(first_run, second_run)

0 commit comments

Comments
 (0)