diff --git a/rsocket/stream_control.py b/rsocket/stream_control.py index f444e1b7..53263195 100644 --- a/rsocket/stream_control.py +++ b/rsocket/stream_control.py @@ -13,7 +13,7 @@ class StreamControl: def __init__(self, first_stream_id: int): - self._first_stream_id = first_stream_id + self._first_stream_id = (first_stream_id - 2) & MAX_STREAM_ID self._current_stream_id = self._first_stream_id self._streams: Dict[int, StreamHandler] = {} self._maximum_stream_id = MAX_STREAM_ID @@ -21,15 +21,19 @@ def __init__(self, first_stream_id: int): def allocate_stream(self) -> int: attempt_counter = 0 - while (self._current_stream_id == CONNECTION_STREAM_ID - or self._current_stream_id in self._streams): - + available_stream_id_found = False + while not available_stream_id_found: if attempt_counter > self._maximum_stream_id / 2: raise RSocketStreamAllocationFailure() self._increment_stream_id() attempt_counter += 1 + available_stream_id_found = not ( + self._current_stream_id == CONNECTION_STREAM_ID + or self._current_stream_id in self._streams + ) + return self._current_stream_id def _increment_stream_id(self): diff --git a/tests/rsocket/test_stream_control.py b/tests/rsocket/test_stream_control.py index b4b93b4d..66a51965 100644 --- a/tests/rsocket/test_stream_control.py +++ b/tests/rsocket/test_stream_control.py @@ -66,6 +66,20 @@ def test_stream_control_reuse_old_stream_ids(): assert next_stream == 5 +@pytest.mark.parametrize('first_stream_id', (1, 2)) +def test_stream_id_increments_after_allocation_and_registration_followed_by_finishing(first_stream_id: int): + control = StreamControl(first_stream_id) + dummy_stream = object() + + allocated_id = control.allocate_stream() + control.register_stream(allocated_id, dummy_stream) + + control.finish_stream(allocated_id) + new_allocated_id = control.allocate_stream() + + assert new_allocated_id != allocated_id + + def test_stream_in_use(): control = StreamControl(1)