
Commit 9f414a1

[BugFix] Make PD work with Ray (#21072)

Signed-off-by: Kourosh Hakhamaneshi <[email protected]>
1 parent 6a971ed · commit 9f414a1

File tree

11 files changed, +329 -221 lines changed


tests/v1/kv_connector/unit/test_nixl_connector.py

Lines changed: 42 additions & 75 deletions
@@ -1,25 +1,47 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
+import os
+import tempfile
+import textwrap
 import time
-import uuid
-from collections import defaultdict
-from typing import Optional
 from unittest.mock import patch
 
 import pytest
+import ray
 
 from vllm import LLM
 from vllm.config import KVTransferConfig
 from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import (
     KVConnectorRole, NixlAgentMetadata, NixlConnector, NixlConnectorMetadata,
     NixlConnectorWorker)
 from vllm.forward_context import ForwardContext
+from vllm.mocks.mock_nixl_connector import FakeNixlWrapper
 from vllm.sampling_params import SamplingParams
 
 from .utils import create_request, create_scheduler, create_vllm_config
 
 
+def _make_stub_pkg() -> str:
+    """Return a directory that makes
+    `from nixl._api import nixl_agent` resolve to our FakeNixlWrapper."""
+    td = tempfile.mkdtemp()
+    pkg_root = os.path.join(td, "nixl", "_api")
+    os.makedirs(pkg_root, exist_ok=True)
+
+    stub = textwrap.dedent("""\
+        # Forward the real FakeNixlWrapper that the driver already defined.
+        print("In fake package")
+        from vllm.mocks.mock_nixl_connector import FakeNixlWrapper as nixl_agent
+        """)
+    with open(os.path.join(pkg_root, "__init__.py"), "w") as f:
+        f.write(stub)
+
+    # touch parent package
+    open(os.path.join(td, "nixl", "__init__.py"), "w").close()
+    return td
+
+
 def test_basic_interface():
     """Unit test for basic NixlConnector interface functionality."""
 
@@ -87,77 +109,6 @@ def test_prompt_less_than_block_size():
     assert len(scheduler_output.scheduled_new_reqs) == 1
 
 
-class FakeNixlWrapper:
-    """Mock implementation of NixlWrapper for testing.
-
-    We don't inherit from nixl._api.nixl_agent because nixl may not be
-    installed.
-    """
-
-    AGENT_METADATA = b"fake_agent_metadata"
-    REMOTE_AGENT_NAME = "remote_agent"
-
-    def __init__(self, agent_name: str, *args, **kwargs):
-        self._cycles_before_xfer_done = 0
-        self._check_xfer_state_cycles: defaultdict[int, int] = defaultdict(
-            lambda: 0)
-
-    def get_reg_descs(self, caches_data, memory_type: str) -> list:
-        return [str(uuid.uuid4()) for _ in caches_data]
-
-    def register_memory(self, descs) -> None:
-        pass
-
-    def get_xfer_descs(self, blocks_data, memory_type: str) -> list:
-        return [str(uuid.uuid4()) for _ in blocks_data]
-
-    def prep_xfer_dlist(self, agent_name: str, descs: list) -> int:
-        return uuid.uuid4().int
-
-    def get_agent_metadata(self) -> bytes:
-        return self.AGENT_METADATA
-
-    def add_remote_agent(self, agent_metadata: bytes) -> str:
-        return self.REMOTE_AGENT_NAME
-
-    def get_new_notifs(self) -> dict[str, list[bytes]]:
-        # Used to collect done_sending, which we don't test yet.
-        return {}
-
-    def check_xfer_state(self, handle: int) -> str:
-        if self._check_xfer_state_cycles[
-                handle] >= self._cycles_before_xfer_done:
-            return "DONE"
-        self._check_xfer_state_cycles[handle] += 1
-        return "PROC"
-
-    def release_xfer_handle(self, handle: int) -> None:
-        pass
-
-    def send_notif(self, agent_name: str, notif_msg: bytes) -> None:
-        pass
-
-    def make_prepped_xfer(self,
-                          xfer_type: str,
-                          local_xfer_side_handle: int,
-                          local_block_descs_ids: list[int],
-                          remote_xfer_side_handle: int,
-                          remote_block_descs_ids: list[int],
-                          notif_msg: Optional[bytes] = None) -> int:
-        return uuid.uuid4().int
-
-    def transfer(self, handle: int) -> str:
-        return "PROC"
-
-    ############################################################
-    # Follow are for changing the behavior during testing.
-    ############################################################
-
-    def set_cycles_before_xfer_done(self, cycles: int):
-        """Set the number of cycles before a transfer is considered done."""
-        self._cycles_before_xfer_done = cycles
-
-
 class FakeNixlConnectorWorker(NixlConnectorWorker):
 
     REMOTE_ENGINE_ID = "remote_engine"
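The hunk above does not change FakeNixlWrapper's behavior; it only moves the mock out of the test module so it can also be imported as vllm.mocks.mock_nixl_connector on Ray workers. Assuming the relocated class keeps the interface shown in the removed lines, its transfer-completion countdown can be exercised like this (a minimal sketch, not part of the commit):

from vllm.mocks.mock_nixl_connector import FakeNixlWrapper

# A transfer handle reports "PROC" for the configured number of polls,
# then flips to "DONE" on the next check (per the removed implementation).
wrapper = FakeNixlWrapper("agent-0")
wrapper.set_cycles_before_xfer_done(2)

handle = 1234
states = [wrapper.check_xfer_state(handle) for _ in range(3)]
assert states == ["PROC", "PROC", "DONE"]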
@@ -378,10 +329,14 @@ def test_concurrent_load_kv(
         raise TimeoutError("Took too long to complete async handshake.")
 
 
+# NOTE: resource cleanup in mp backend is a bit finicky, so the order in which
+# we put here is important. First run ray, it will clean up the resources, then
+# the rest of the tests.
+@pytest.mark.parametrize("distributed_executor_backend", ["ray", None])
 @patch(
     "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
     FakeNixlWrapper)
-def test_abort_timeout_on_prefiller(monkeypatch):
+def test_abort_timeout_on_prefiller(monkeypatch, distributed_executor_backend):
     """
     Test lifecycle of an aborted Remote Prefill request hitting the timeout.
     -----> P
@@ -399,11 +354,23 @@ def test_abort_timeout_on_prefiller(monkeypatch):
     timeout = 6
     monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")
     monkeypatch.setenv("VLLM_NIXL_ABORT_REQUEST_TIMEOUT", str(timeout))
+
+    # Build runtime_env only if we're using Ray
+    if distributed_executor_backend == "ray":
+        runtime_env = {
+            "working_dir": _make_stub_pkg(),  # ship stub package
+            "env_vars": {
+                "VLLM_NIXL_ABORT_REQUEST_TIMEOUT": str(timeout),
+            },
+        }
+        ray.init(runtime_env=runtime_env)
+
     llm = LLM(
         model=model_name,
         enforce_eager=True,
         gpu_memory_utilization=0.5,
         kv_transfer_config=kv_transfer_config,
+        distributed_executor_backend=distributed_executor_backend,
     )
     remote_prefill_opts = {
         "do_remote_decode": True,

tests/v1/executor/test_multiproc_executor.py renamed to tests/v1/kv_connector/unit/test_output_aggreagator.py

Lines changed: 9 additions & 28 deletions
@@ -1,28 +1,12 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-import threading
-from collections import defaultdict
 from concurrent.futures import Future
 from typing import Optional
 
-from vllm.v1.executor.multiproc_executor import MultiprocExecutor
+from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator
 from vllm.v1.outputs import ModelRunnerOutput
 
 
-class DummyMultiprocExecutor(MultiprocExecutor):
-
-    def __init__(self, output_rank, world_size):
-        # Manually initialize minimal required fields
-        self.output_rank = output_rank
-        self.world_size = world_size
-        self._send_remaining_count = defaultdict[str,
-                                                 int](lambda: self.world_size)
-        self._recv_remaining_count = defaultdict[str,
-                                                 int](lambda: self.world_size)
-        self.io_thread_pool = None
-        self.shutdown_event = threading.Event()
-
-
 class DummyModelRunnerOutput(ModelRunnerOutput):
 
     def __init__(self,
@@ -33,14 +17,14 @@ def __init__(self,
 
 
 def test_aggregate_workers_output():
-    executor = DummyMultiprocExecutor(output_rank=0, world_size=2)
+    aggregator = KVOutputAggregator(world_size=2)
 
     output1 = DummyModelRunnerOutput(finished_sending={'req1'},
                                      finished_recving={'req2'})
     output2 = DummyModelRunnerOutput(finished_sending=None,
                                      finished_recving=None)
 
-    aggregated = executor._aggregate_workers_output([output1, output2])
+    aggregated = aggregator.aggregate([output1, output2])
 
     assert aggregated is output1
     assert aggregated.finished_sending is None
@@ -51,7 +35,7 @@ def test_aggregate_workers_output():
     output2 = DummyModelRunnerOutput(finished_sending={'req1'},
                                      finished_recving=None)
 
-    aggregated = executor._aggregate_workers_output([output1, output2])
+    aggregated = aggregator.aggregate([output1, output2])
 
     assert aggregated is output1
     assert aggregated.finished_sending == {'req1'}
@@ -62,20 +46,19 @@ def test_aggregate_workers_output():
     output2 = DummyModelRunnerOutput(finished_sending={'req1'},
                                      finished_recving={'req2'})
 
-    aggregated = executor._aggregate_workers_output([output1, output2])
+    aggregated = aggregator.aggregate([output1, output2])
 
     assert aggregated is output1
     assert aggregated.finished_sending is None
     assert aggregated.finished_recving == {'req2'}
 
 
 def test_async_aggregate_workers_output():
-    executor = DummyMultiprocExecutor(output_rank=0, world_size=2)
+    aggregator = KVOutputAggregator(world_size=2)
 
     future1: Future[DummyModelRunnerOutput] = Future()
     future2: Future[DummyModelRunnerOutput] = Future()
-    result_future = executor._async_aggregate_workers_output(
-        [future1, future2])
+    result_future = aggregator.async_aggregate([future1, future2])
 
     output1 = DummyModelRunnerOutput(finished_sending={'req1'},
                                      finished_recving={'req2'})
@@ -92,8 +75,7 @@ def test_async_aggregate_workers_output():
 
     future1 = Future()
     future2 = Future()
-    result_future = executor._async_aggregate_workers_output(
-        [future1, future2])
+    result_future = aggregator.async_aggregate([future1, future2])
 
     output1 = DummyModelRunnerOutput(finished_sending=None,
                                      finished_recving=None)
@@ -110,8 +92,7 @@ def test_async_aggregate_workers_output():
 
     future1 = Future()
     future2 = Future()
-    result_future = executor._async_aggregate_workers_output(
-        [future1, future2])
+    result_future = aggregator.async_aggregate([future1, future2])
 
     output1 = DummyModelRunnerOutput(finished_sending=None,
                                      finished_recving=None)
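As the renamed test shows, the aggregation logic no longer needs a partially constructed MultiprocExecutor; any code holding per-worker futures can use KVOutputAggregator directly. A minimal sketch of the async path, using a SimpleNamespace stand-in for ModelRunnerOutput (an assumption for brevity; the aggregator only reads and writes the finished_sending/finished_recving fields), could look like:

from concurrent.futures import Future
from types import SimpleNamespace

from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator


def fake_output(sending=None, recving=None):
    # Stand-in for ModelRunnerOutput with only the fields the aggregator uses.
    return SimpleNamespace(finished_sending=sending, finished_recving=recving)


aggregator = KVOutputAggregator(world_size=2)
future0, future1 = Future(), Future()
result = aggregator.async_aggregate([future0, future1])

# Workers may finish out of order; the aggregated future resolves only once
# every per-worker future has a result.
future1.set_result(fake_output(recving={"req-B"}))
assert not result.done()

future0.set_result(fake_output(recving={"req-B"}))
assert result.result().finished_recving == {"req-B"}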

vllm/distributed/kv_transfer/kv_connector/utils.py

Lines changed: 90 additions & 0 deletions
@@ -3,12 +3,18 @@
 """
 KV cache helper for store.
 """
+from collections import defaultdict
+from collections.abc import Sequence
+from concurrent.futures import CancelledError, Future
+from typing import Optional, cast
+
 import torch
 
 import vllm.envs as envs
 from vllm import _custom_ops as ops
 from vllm.config import VllmConfig, get_current_vllm_config
 from vllm.logger import init_logger
+from vllm.v1.outputs import ModelRunnerOutput
 
 logger = init_logger(__name__)
 
@@ -107,3 +113,87 @@ def get_kv_connector_cache_layout():
             "layout to HND for better xfer performance.")
         return "HND"
     return "NHD"
+
+
+class KVOutputAggregator:
+    """Utility class to aggregate the output of all workers into a single
+    output corresponding to Rank 0 for scheduler."""
+
+    def __init__(self, world_size: int):
+        # Complete transfer tracker. Used by to track finished requests
+        # [req_id -> n_finished_workers]
+        self._recv_remaining_count = defaultdict[str, int](lambda: world_size)
+        self._send_remaining_count = defaultdict[str, int](lambda: world_size)
+
+    def aggregate(self,
+                  outputs: list[ModelRunnerOutput],
+                  output_rank: int = 0) -> ModelRunnerOutput:
+        # aggregate finished_sending, finished_recving from all workers
+
+        def update_finished_set(req_ids: Optional[set[str]],
+                                remaining_count_dict: dict[str, int],
+                                finished_set: set[str]) -> None:
+            for req_id in req_ids or ():
+                new_count = remaining_count_dict[req_id] - 1
+                if new_count == 0:
+                    finished_set.add(req_id)
+                    del remaining_count_dict[req_id]
+                else:
+                    remaining_count_dict[req_id] = new_count
+
+        finished_sending = set[str]()
+        finished_recving = set[str]()
+        for output in outputs:
+            update_finished_set(output.finished_sending,
+                                self._send_remaining_count, finished_sending)
+            update_finished_set(output.finished_recving,
+                                self._recv_remaining_count, finished_recving)
+
+        # select output of the worker specified by output_rank
+        output = outputs[output_rank]
+
+        # set the aggregated finished_sending / finished_recving
+        # if output.finished_sending/recving is not empty, but the other ranks
+        # still have unfinished send/recv, we want to set the aggregated
+        # finished_sending/recving to None until all ranks have finished
+        # send/recv
+        output.finished_sending = finished_sending if finished_sending else None
+        output.finished_recving = finished_recving if finished_recving else None
+
+        return output
+
+    def async_aggregate(self,
+                        output_futures: Sequence[Future[ModelRunnerOutput]],
+                        output_rank: int = 0) -> Future[ModelRunnerOutput]:
+        """Takes a list of futures and returns a single future which resolves
+        to the respective list of outputs."""
+        result_future: Future[ModelRunnerOutput] = Future()
+
+        outputs: list[Optional[ModelRunnerOutput]] = [None
+                                                      ] * len(output_futures)
+
+        def make_callback(idx):
+
+            def callback(fut):
+                if result_future.done():
+                    return
+
+                try:
+                    outputs[idx] = fut.result()
+                except CancelledError:
+                    result_future.cancel()
+                except Exception as e:
+                    result_future.set_exception(e)
+
+                # this check assumes io_thread_pool uses a single thread
+                if all(outputs):
+                    result_future.set_result(
+                        self.aggregate(cast(list[ModelRunnerOutput], outputs),
+                                       output_rank))
+
+            return callback
+
+        for i, output_future in enumerate(output_futures):
+            output_future.add_done_callback(make_callback(i))
+
+        return result_future
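The key behavior of KVOutputAggregator is the per-request countdown: a request id only appears in the aggregated finished_sending/finished_recving once every worker has reported it, and until then the field stays None so the scheduler does not act early. A small usage sketch (again with a SimpleNamespace stand-in for ModelRunnerOutput, since only those two fields are touched):

from types import SimpleNamespace

from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator


def fake_output(sending=None, recving=None):
    return SimpleNamespace(finished_sending=sending, finished_recving=recving)


aggregator = KVOutputAggregator(world_size=2)

# Step 1: only worker 0 has finished sending "req-A" -> not surfaced yet.
step1 = aggregator.aggregate([fake_output(sending={"req-A"}), fake_output()])
assert step1.finished_sending is None

# Step 2: worker 1 catches up -> the countdown reaches zero and "req-A" is
# reported exactly once, on the rank-0 output returned to the scheduler.
step2 = aggregator.aggregate([fake_output(), fake_output(sending={"req-A"})])
assert step2.finished_sending == {"req-A"}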

vllm/distributed/kv_transfer/kv_connector/v1/base.py

Lines changed: 1 addition & 1 deletion
@@ -194,7 +194,7 @@ def get_finished(
         """
         Notifies worker-side connector ids of requests that have
        finished generating tokens on the worker.
-        The scheduler process (via the MultiprocExecutor) will use this output
+        The scheduler process (via the Executors) will use this output
         to track which workers are done.
 
         Returns:
vllm/mocks/__init__.py

Whitespace-only changes.
