Skip to content

Commit 3a85b30

Browse files
committed
feat(omni_coordinator): add coordinator module and corresponding unit test
Signed-off-by: Jeff Wan <wantszkin2003@gmail.com> update coordinator Signed-off-by: Jeff Wan <wantszkin2003@gmail.com> update coordinator Signed-off-by: Jeff Wan <wantszkin2003@gmail.com> bugdfix clientForStage Signed-off-by: Jeff Wan <wantszkin2003@gmail.com> bugdfix clientForStage Signed-off-by: Jeff Wan <wantszkin2003@gmail.com> bugdfix omni_coordinator Signed-off-by: Jeff Wan <wantszkin2003@gmail.com>
1 parent 4de077e commit 3a85b30

File tree

10 files changed

+1214
-0
lines changed

10 files changed

+1214
-0
lines changed
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3+
4+
from time import time
5+
6+
from vllm_omni.distributed.omni_coordinator import (
7+
InstanceInfo,
8+
RandomBalancer,
9+
StageStatus,
10+
)
11+
12+
13+
def test_load_balancer_select_returns_valid_index():
14+
"""Verify RandomBalancer.select() returns a valid index for instances."""
15+
# Dummy task object; RandomBalancer ignores task contents.
16+
class DummyTask:
17+
pass
18+
19+
task = DummyTask()
20+
21+
now = time()
22+
instances = [
23+
InstanceInfo(
24+
zmq_addr="tcp://host:10001",
25+
stage_id=0,
26+
status=StageStatus.UP,
27+
queue_length=0,
28+
last_heartbeat=now,
29+
registered_at=now,
30+
),
31+
InstanceInfo(
32+
zmq_addr="tcp://host:10002",
33+
stage_id=0,
34+
status=StageStatus.UP,
35+
queue_length=1,
36+
last_heartbeat=now,
37+
registered_at=now,
38+
),
39+
InstanceInfo(
40+
zmq_addr="tcp://host:10003",
41+
stage_id=1,
42+
status=StageStatus.UP,
43+
queue_length=2,
44+
last_heartbeat=now,
45+
registered_at=now,
46+
),
47+
]
48+
49+
balancer = RandomBalancer()
50+
51+
index = balancer.select(task, instances)
52+
53+
assert isinstance(index, int)
54+
assert 0 <= index < len(instances)
55+
Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3+
4+
import json
5+
import time
6+
7+
import pytest
8+
import zmq
9+
10+
from vllm_omni.distributed.omni_coordinator import (
11+
InstanceInfo,
12+
InstanceList,
13+
OmniCoordClientForHub,
14+
StageStatus,
15+
)
16+
17+
18+
def _bind_pub() -> tuple[zmq.Context, zmq.Socket, str]:
19+
ctx = zmq.Context.instance()
20+
pub = ctx.socket(zmq.PUB)
21+
pub.bind("tcp://127.0.0.1:*")
22+
endpoint = pub.getsockopt(zmq.LAST_ENDPOINT).decode("ascii")
23+
return ctx, pub, endpoint
24+
25+
26+
def _wait_for_condition(cond, timeout: float = 2.0, interval: float = 0.01) -> bool:
27+
start = time.time()
28+
while time.time() - start < timeout:
29+
if cond():
30+
return True
31+
time.sleep(interval)
32+
return False
33+
34+
35+
def test_hub_client_caches_instance_list_from_pub():
36+
"""Verify OmniCoordClientForHub receives instance list updates from OmniCoordinator and caches for get_instance_list()."""
37+
ctx, pub, endpoint = _bind_pub()
38+
39+
client = OmniCoordClientForHub(endpoint)
40+
# ZMQ PUB/SUB slow-joiner: allow SUB to finish connecting before first send
41+
time.sleep(0.2)
42+
43+
now = time.time()
44+
instances_payload = [
45+
{
46+
"zmq_addr": "tcp://stage:10001",
47+
"stage_id": 0,
48+
"status": "up",
49+
"queue_length": 0,
50+
"last_heartbeat": now,
51+
"registered_at": now,
52+
},
53+
{
54+
"zmq_addr": "tcp://stage:10002",
55+
"stage_id": 0,
56+
"status": "up",
57+
"queue_length": 1,
58+
"last_heartbeat": now,
59+
"registered_at": now,
60+
},
61+
{
62+
"zmq_addr": "tcp://stage:10003",
63+
"stage_id": 1,
64+
"status": "error",
65+
"queue_length": 5,
66+
"last_heartbeat": now,
67+
"registered_at": now,
68+
},
69+
]
70+
71+
payload = {"instances": instances_payload, "timestamp": now}
72+
pub.send(json.dumps(payload).encode("utf-8"))
73+
74+
assert _wait_for_condition(lambda: len(client.get_instance_list().instances) == 3)
75+
76+
inst_list = client.get_instance_list()
77+
assert isinstance(inst_list, InstanceList)
78+
assert len(inst_list.instances) == 3
79+
80+
for src, inst in zip(instances_payload, inst_list.instances, strict=True):
81+
assert inst.zmq_addr == src["zmq_addr"]
82+
assert inst.stage_id == src["stage_id"]
83+
assert inst.status.value == src["status"]
84+
85+
stage0 = client.get_instances_for_stage(0)
86+
stage1 = client.get_instances_for_stage(1)
87+
88+
assert all(inst.stage_id == 0 for inst in stage0.instances)
89+
assert all(inst.stage_id == 1 for inst in stage1.instances)
90+
91+
# Send an updated list with fewer instances and verify cache refresh.
92+
updated_payload = {
93+
"instances": instances_payload[:2],
94+
"timestamp": now + 1.0,
95+
}
96+
pub.send(json.dumps(updated_payload).encode("utf-8"))
97+
98+
assert _wait_for_condition(lambda: len(client.get_instance_list().instances) == 2)
99+
updated_list = client.get_instance_list()
100+
assert len(updated_list.instances) == 2
101+
102+
client.close()
103+
pub.close(0)
104+
ctx.term()
105+
106+
107+
def test_hub_client_close_closes_sub_socket():
108+
"""Verify OmniCoordClientForHub.close() marks client as closed; second close raises."""
109+
ctx, pub, endpoint = _bind_pub()
110+
client = OmniCoordClientForHub(endpoint)
111+
client.close()
112+
113+
with pytest.raises(RuntimeError, match="already closed"):
114+
client.close()
115+
116+
pub.close(0)
117+
ctx.term()
118+
Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3+
4+
import json
5+
6+
import zmq
7+
8+
from vllm_omni.distributed.omni_coordinator import (
9+
OmniCoordClientForStage,
10+
StageStatus,
11+
)
12+
13+
14+
def _bind_router() -> tuple[zmq.Context, zmq.Socket, str]:
15+
ctx = zmq.Context.instance()
16+
router = ctx.socket(zmq.ROUTER)
17+
router.bind("tcp://127.0.0.1:*")
18+
endpoint = router.getsockopt(zmq.LAST_ENDPOINT).decode("ascii")
19+
return ctx, router, endpoint
20+
21+
22+
def _recv_event(router: zmq.Socket) -> dict:
23+
frames = router.recv_multipart()
24+
# ROUTER adds identity frame; the last frame is the payload.
25+
payload = frames[-1]
26+
return json.loads(payload.decode("utf-8"))
27+
28+
29+
def test_stage_client_auto_register_on_init():
30+
"""Verify OmniCoordClientForStage automatically sends initial registration/status-up event when created."""
31+
ctx, router, endpoint = _bind_router()
32+
33+
instance_addr = "tcp://stage:10001"
34+
stage_id = 0
35+
36+
client = OmniCoordClientForStage(endpoint, instance_addr, stage_id)
37+
38+
event = _recv_event(router)
39+
40+
assert event["event_type"] == "update"
41+
assert event["status"] == StageStatus.UP.value
42+
assert event["stage_id"] == stage_id
43+
assert event["zmq_addr"] == instance_addr
44+
45+
client.close()
46+
router.close(0)
47+
ctx.term()
48+
49+
50+
def test_stage_client_update_info_sends_correct_event():
51+
"""Verify OmniCoordClientForStage.update_info() sends status/load update events with expected fields."""
52+
ctx, router, endpoint = _bind_router()
53+
54+
instance_addr = "tcp://stage:10002"
55+
stage_id = 1
56+
57+
client = OmniCoordClientForStage(endpoint, instance_addr, stage_id)
58+
59+
# Discard initial registration event.
60+
_recv_event(router)
61+
62+
client.update_info(status=StageStatus.ERROR)
63+
client.update_info(queue_length=10)
64+
65+
first = _recv_event(router)
66+
second = _recv_event(router)
67+
68+
assert first["status"] == StageStatus.ERROR.value
69+
assert first["stage_id"] == stage_id
70+
assert first["zmq_addr"] == instance_addr
71+
72+
assert second["queue_length"] == 10
73+
assert second["stage_id"] == stage_id
74+
assert second["zmq_addr"] == instance_addr
75+
76+
client.close()
77+
router.close(0)
78+
ctx.term()
79+
80+
81+
def test_stage_client_close_sends_down_status():
82+
"""Verify close() sends final status-down event before closing underlying socket."""
83+
ctx, router, endpoint = _bind_router()
84+
85+
instance_addr = "tcp://stage:10003"
86+
stage_id = 2
87+
88+
client = OmniCoordClientForStage(endpoint, instance_addr, stage_id)
89+
90+
# Discard initial registration event.
91+
_recv_event(router)
92+
93+
client.close()
94+
95+
event = _recv_event(router)
96+
assert event["status"] == StageStatus.DOWN.value
97+
assert event["stage_id"] == stage_id
98+
assert event["zmq_addr"] == instance_addr
99+
100+
assert client._socket.closed # DEALER socket no longer usable after close
101+
102+
router.close(0)
103+
ctx.term()
104+

0 commit comments

Comments
 (0)