Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 55 additions & 0 deletions tests/distributed/omni_coordinator/test_load_balancer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from time import time

from vllm_omni.distributed.omni_coordinator import (
InstanceInfo,
RandomBalancer,
StageStatus,
)


def test_load_balancer_select_returns_valid_index():
"""Verify RandomBalancer.select() returns a valid index for instances."""
# Task structure mirrors async_omni; RandomBalancer ignores task contents.
task: dict = {
"request_id": "test",
"engine_inputs": None,
"sampling_params": None,
}

now = time()
instances = [
InstanceInfo(
zmq_addr="tcp://host:10001",
stage_id=0,
status=StageStatus.UP,
queue_length=0,
last_heartbeat=now,
registered_at=now,
),
InstanceInfo(
zmq_addr="tcp://host:10002",
stage_id=0,
status=StageStatus.UP,
queue_length=1,
last_heartbeat=now,
registered_at=now,
),
InstanceInfo(
zmq_addr="tcp://host:10003",
stage_id=1,
status=StageStatus.UP,
queue_length=2,
last_heartbeat=now,
registered_at=now,
),
]

balancer = RandomBalancer()

index = balancer.select(task, instances)

assert isinstance(index, int)
assert 0 <= index < len(instances)
115 changes: 115 additions & 0 deletions tests/distributed/omni_coordinator/test_omni_coord_client_for_hub.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import json
import time

import pytest
import zmq

from vllm_omni.distributed.omni_coordinator import (
InstanceList,
OmniCoordClientForHub,
)


def _bind_pub() -> tuple[zmq.Context, zmq.Socket, str]:
ctx = zmq.Context.instance()
pub = ctx.socket(zmq.PUB)
pub.bind("tcp://127.0.0.1:*")
endpoint = pub.getsockopt(zmq.LAST_ENDPOINT).decode("ascii")
return ctx, pub, endpoint


def _wait_for_condition(cond, timeout: float = 2.0, interval: float = 0.01) -> bool:
start = time.time()
while time.time() - start < timeout:
if cond():
return True
time.sleep(interval)
return False


def test_hub_client_caches_instance_list_from_pub():
"""Verify OmniCoordClientForHub receives instance list updates from OmniCoordinator and caches for get_instance_list()."""
ctx, pub, endpoint = _bind_pub()

client = OmniCoordClientForHub(endpoint)
# ZMQ PUB/SUB slow-joiner: allow SUB to finish connecting before first send
time.sleep(0.2)

now = time.time()
instances_payload = [
{
"zmq_addr": "tcp://stage:10001",
"stage_id": 0,
"status": "up",
"queue_length": 0,
"last_heartbeat": now,
"registered_at": now,
},
{
"zmq_addr": "tcp://stage:10002",
"stage_id": 0,
"status": "up",
"queue_length": 1,
"last_heartbeat": now,
"registered_at": now,
},
{
"zmq_addr": "tcp://stage:10003",
"stage_id": 1,
"status": "error",
"queue_length": 5,
"last_heartbeat": now,
"registered_at": now,
},
]

payload = {"instances": instances_payload, "timestamp": now}
pub.send(json.dumps(payload).encode("utf-8"))

assert _wait_for_condition(lambda: len(client.get_instance_list().instances) == 3)

inst_list = client.get_instance_list()
assert isinstance(inst_list, InstanceList)
assert len(inst_list.instances) == 3

for src, inst in zip(instances_payload, inst_list.instances, strict=True):
assert inst.zmq_addr == src["zmq_addr"]
assert inst.stage_id == src["stage_id"]
assert inst.status.value == src["status"]

stage0 = client.get_instances_for_stage(0)
stage1 = client.get_instances_for_stage(1)

assert all(inst.stage_id == 0 for inst in stage0.instances)
assert all(inst.stage_id == 1 for inst in stage1.instances)

# Send an updated list with fewer instances and verify cache refresh.
updated_payload = {
"instances": instances_payload[:2],
"timestamp": now + 1.0,
}
pub.send(json.dumps(updated_payload).encode("utf-8"))

assert _wait_for_condition(lambda: len(client.get_instance_list().instances) == 2)
updated_list = client.get_instance_list()
assert len(updated_list.instances) == 2

client.close()
pub.close(0)
ctx.term()


def test_hub_client_close_closes_sub_socket():
"""Verify OmniCoordClientForHub.close() marks client as closed; second close raises."""
ctx, pub, endpoint = _bind_pub()
client = OmniCoordClientForHub(endpoint)
client.close()

with pytest.raises(RuntimeError, match="already closed"):
client.close()

pub.close(0)
ctx.term()
103 changes: 103 additions & 0 deletions tests/distributed/omni_coordinator/test_omni_coord_client_for_stage.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import json

import zmq

from vllm_omni.distributed.omni_coordinator import (
OmniCoordClientForStage,
StageStatus,
)


def _bind_router() -> tuple[zmq.Context, zmq.Socket, str]:
ctx = zmq.Context.instance()
router = ctx.socket(zmq.ROUTER)
router.bind("tcp://127.0.0.1:*")
endpoint = router.getsockopt(zmq.LAST_ENDPOINT).decode("ascii")
return ctx, router, endpoint


def _recv_event(router: zmq.Socket) -> dict:
frames = router.recv_multipart()
# ROUTER adds identity frame; the last frame is the payload.
payload = frames[-1]
return json.loads(payload.decode("utf-8"))


def test_stage_client_auto_register_on_init():
"""Verify OmniCoordClientForStage automatically sends initial registration/status-up event when created."""
ctx, router, endpoint = _bind_router()

instance_addr = "tcp://stage:10001"
stage_id = 0

client = OmniCoordClientForStage(endpoint, instance_addr, stage_id)

event = _recv_event(router)

assert event["event_type"] == "update"
assert event["status"] == StageStatus.UP.value
assert event["stage_id"] == stage_id
assert event["zmq_addr"] == instance_addr

client.close()
router.close(0)
ctx.term()


def test_stage_client_update_info_sends_correct_event():
"""Verify OmniCoordClientForStage.update_info() sends status/load update events with expected fields."""
ctx, router, endpoint = _bind_router()

instance_addr = "tcp://stage:10002"
stage_id = 1

client = OmniCoordClientForStage(endpoint, instance_addr, stage_id)

# Discard initial registration event.
_recv_event(router)

client.update_info(status=StageStatus.ERROR)
client.update_info(queue_length=10)

first = _recv_event(router)
second = _recv_event(router)

assert first["status"] == StageStatus.ERROR.value
assert first["stage_id"] == stage_id
assert first["zmq_addr"] == instance_addr

assert second["queue_length"] == 10
assert second["stage_id"] == stage_id
assert second["zmq_addr"] == instance_addr

client.close()
router.close(0)
ctx.term()


def test_stage_client_close_sends_down_status():
"""Verify close() sends final status-down event before closing underlying socket."""
ctx, router, endpoint = _bind_router()

instance_addr = "tcp://stage:10003"
stage_id = 2

client = OmniCoordClientForStage(endpoint, instance_addr, stage_id)

# Discard initial registration event.
_recv_event(router)

client.close()

event = _recv_event(router)
assert event["status"] == StageStatus.DOWN.value
assert event["stage_id"] == stage_id
assert event["zmq_addr"] == instance_addr

assert client._socket.closed # DEALER socket no longer usable after close

router.close(0)
ctx.term()
Loading