Commit 15dac21

[V1] AsyncLLM data parallel (#13923)
Signed-off-by: Nick Hill <[email protected]>
1 parent 112b3e5 commit 15dac21

File tree

18 files changed: +726 -160 lines

.buildkite/test-pipeline.yaml

Lines changed: 5 additions & 0 deletions
@@ -135,12 +135,14 @@ steps:
   - examples/offline_inference/rlhf.py
   - examples/offline_inference/rlhf_colocate.py
   - tests/examples/offline_inference/data_parallel.py
+  - tests/v1/test_async_llm_dp.py
   commands:
   # test with tp=2 and external_dp=2
   - VLLM_USE_V1=0 torchrun --nproc-per-node=4 distributed/test_torchrun_example.py
   - torchrun --nproc-per-node=4 distributed/test_torchrun_example.py
   # test with internal dp
   - python3 ../examples/offline_inference/data_parallel.py
+  - TP_SIZE=2 DP_SIZE=2 pytest -v -s v1/test_async_llm_dp.py
   - pytest -v -s distributed/test_utils.py
   - pytest -v -s compile/test_basic_correctness.py
   - pytest -v -s distributed/test_pynccl.py
@@ -514,7 +516,10 @@ steps:
   - vllm/worker/worker.py
   - vllm/worker/model_runner.py
   - entrypoints/llm/test_collective_rpc.py
+  - tests/v1/test_async_llm_dp.py
+  - vllm/v1/engine/
   commands:
+  - TP_SIZE=1 DP_SIZE=2 pytest -v -s v1/test_async_llm_dp.py
   - VLLM_ENABLE_V1_MULTIPROCESSING=0 pytest -v -s entrypoints/llm/test_collective_rpc.py
   - pytest -v -s ./compile/test_basic_correctness.py
   - pytest -v -s ./compile/test_wrapper.py

examples/offline_inference/data_parallel.py

Lines changed: 15 additions & 7 deletions
@@ -28,6 +28,7 @@
     --master-port=13345
 """
 import os
+from time import sleep

 from vllm import LLM, SamplingParams
 from vllm.utils import get_open_port
@@ -36,14 +37,13 @@
 def main(model, dp_size, local_dp_rank, global_dp_rank, dp_master_ip,
          dp_master_port, GPUs_per_dp_rank):
     os.environ["VLLM_DP_RANK"] = str(global_dp_rank)
+    os.environ["VLLM_DP_RANK_LOCAL"] = str(local_dp_rank)
     os.environ["VLLM_DP_SIZE"] = str(dp_size)
     os.environ["VLLM_DP_MASTER_IP"] = dp_master_ip
     os.environ["VLLM_DP_MASTER_PORT"] = str(dp_master_port)
-    # set devices for each dp_rank
-    os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(
-        str(i)
-        for i in range(local_dp_rank * GPUs_per_dp_rank, (local_dp_rank + 1) *
-                       GPUs_per_dp_rank))
+
+    # CUDA_VISIBLE_DEVICES for each DP rank is set automatically inside the
+    # engine processes.

     # Sample prompts.
     prompts = [
@@ -90,6 +90,9 @@ def main(model, dp_size, local_dp_rank, global_dp_rank, dp_master_ip,
         print(f"DP rank {global_dp_rank}, Prompt: {prompt!r}, "
               f"Generated text: {generated_text!r}")

+    # Give engines time to pause their processing loops before exiting.
+    sleep(1)
+

 if __name__ == "__main__":
     import argparse
@@ -152,8 +155,13 @@ def main(model, dp_size, local_dp_rank, global_dp_rank, dp_master_ip,
         procs.append(proc)
     exit_code = 0
     for proc in procs:
-        proc.join()
-        if proc.exitcode:
+        proc.join(timeout=300)
+        if proc.exitcode is None:
+            print(f"Killing process {proc.pid} that "
+                  f"didn't stop within 5 minutes.")
+            proc.kill()
+            exit_code = 1
+        elif proc.exitcode:
             exit_code = proc.exitcode

     exit(exit_code)
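For context, the launcher in this example computes both ranks before spawning each engine process. A minimal sketch of that mapping for a multi-node launch (the helper name and its arguments are illustrative, not the exact code from the example):

import os


def set_dp_env(node_rank: int, local_dp_rank: int, dp_per_node: int,
               dp_size: int, master_ip: str, master_port: int) -> None:
    # The global rank counts engines across all nodes; the local rank counts
    # engines within this node and is what the new VLLM_DP_RANK_LOCAL carries.
    global_dp_rank = node_rank * dp_per_node + local_dp_rank
    os.environ["VLLM_DP_RANK"] = str(global_dp_rank)
    os.environ["VLLM_DP_RANK_LOCAL"] = str(local_dp_rank)
    os.environ["VLLM_DP_SIZE"] = str(dp_size)
    os.environ["VLLM_DP_MASTER_IP"] = master_ip
    os.environ["VLLM_DP_MASTER_PORT"] = str(master_port)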

tests/v1/engine/test_engine_core_client.py

Lines changed: 4 additions & 4 deletions
@@ -167,11 +167,11 @@ def test_engine_core_client(monkeypatch: pytest.MonkeyPatch,

         core_client: SyncMPClient = client

-        result = core_client._call_utility("echo", "testarg")
+        result = core_client.call_utility("echo", "testarg")
         assert result == "testarg"

         with pytest.raises(Exception) as e_info:
-            core_client._call_utility("echo", None, "help!")
+            core_client.call_utility("echo", None, "help!")

         assert str(e_info.value) == "Call to echo method failed: help!"

@@ -238,10 +238,10 @@ async def test_engine_core_client_asyncio(monkeypatch: pytest.MonkeyPatch):

         core_client: AsyncMPClient = client

-        result = await core_client._call_utility_async("echo", "testarg")
+        result = await core_client.call_utility_async("echo", "testarg")
         assert result == "testarg"

         with pytest.raises(Exception) as e_info:
-            await core_client._call_utility_async("echo", None, "help!")
+            await core_client.call_utility_async("echo", None, "help!")

         assert str(e_info.value) == "Call to echo method failed: help!"
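The only functional change here is the rename of _call_utility and _call_utility_async to the public call_utility and call_utility_async, presumably because the new data parallel client code needs to invoke these helpers from outside the defining class.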

tests/v1/test_async_llm_dp.py

Lines changed: 109 additions & 0 deletions
@@ -0,0 +1,109 @@
+# SPDX-License-Identifier: Apache-2.0
+
+import asyncio
+import os
+from contextlib import ExitStack
+from typing import Optional
+
+import pytest
+
+from vllm import SamplingParams
+from vllm.engine.arg_utils import AsyncEngineArgs
+from vllm.inputs import PromptType
+from vllm.platforms import current_platform
+from vllm.sampling_params import RequestOutputKind
+from vllm.v1.engine.async_llm import AsyncLLM
+from vllm.v1.engine.core_client import DPAsyncMPClient
+
+engine_args = AsyncEngineArgs(
+    model="ibm-research/PowerMoE-3b",
+    enforce_eager=True,
+    disable_log_requests=True,
+    tensor_parallel_size=int(os.getenv("TP_SIZE", 1)),
+    data_parallel_size=int(os.getenv("DP_SIZE", 2)),
+)
+
+if not current_platform.supports_v1(engine_args.create_model_config()):
+    pytest.skip(reason="Requires V1-supporting platform.",
+                allow_module_level=True)
+
+
+async def generate(engine: AsyncLLM,
+                   request_id: str,
+                   prompt: PromptType,
+                   output_kind: RequestOutputKind,
+                   max_tokens: int,
+                   prompt_logprobs: Optional[int] = None) -> tuple[int, str]:
+    # Ensure generate doesn't complete too fast for cancellation test.
+    await asyncio.sleep(0.2)
+
+    count = 0
+    sampling_params = SamplingParams(max_tokens=max_tokens,
+                                     ignore_eos=True,
+                                     output_kind=output_kind,
+                                     temperature=0,
+                                     prompt_logprobs=prompt_logprobs)
+    async for out in engine.generate(request_id=request_id,
+                                     prompt=prompt,
+                                     sampling_params=sampling_params):
+
+        num_tokens = len(out.outputs[0].token_ids)
+        if output_kind == RequestOutputKind.DELTA:
+            count += num_tokens
+        else:
+            count = num_tokens
+
+        await asyncio.sleep(0.)
+
+    return count, request_id
+
+
+@pytest.mark.parametrize(
+    "output_kind", [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY])
+@pytest.mark.asyncio
+async def test_load(output_kind: RequestOutputKind):
+
+    with ExitStack() as after:
+
+        prompt = "This is a test of data parallel"
+
+        engine = AsyncLLM.from_engine_args(engine_args)
+        after.callback(engine.shutdown)
+
+        NUM_REQUESTS = 100
+        NUM_EXPECTED_TOKENS = 10
+
+        request_ids = [f"request-{i}" for i in range(NUM_REQUESTS)]
+
+        # Create concurrent requests.
+        tasks = []
+        for request_id in request_ids:
+            tasks.append(
+                asyncio.create_task(
+                    generate(engine, request_id, prompt, output_kind,
+                             NUM_EXPECTED_TOKENS)))
+
+        # Confirm that we got all the EXPECTED tokens from the requests.
+        done, pending = await asyncio.wait(tasks,
+                                           return_when=asyncio.FIRST_EXCEPTION)
+        for task in pending:
+            task.cancel()
+        for task in done:
+            num_generated_tokens, request_id = await task
+            assert num_generated_tokens == NUM_EXPECTED_TOKENS, (
+                f"{request_id} generated {num_generated_tokens} but "
+                f"expected {NUM_EXPECTED_TOKENS}")
+
+        assert not engine.output_processor.has_unfinished_requests()
+
+        # testing internals here which may break
+        core_client: DPAsyncMPClient = engine.engine_core
+        # the engines only synchronize stopping every N steps so
+        # allow a small amount of time here.
+        for _ in range(10):
+            if core_client.num_engines_running == 0:
+                break
+            await asyncio.sleep(0.5)
+
+        assert core_client.num_engines_running == 0
+        assert not core_client.reqs_in_flight
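Reading TP_SIZE and DP_SIZE at module scope lets this one file back both pipeline entries added above; it can be run locally the same way, e.g. TP_SIZE=1 DP_SIZE=2 pytest -v -s tests/v1/test_async_llm_dp.py.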

vllm/config.py

Lines changed: 16 additions & 5 deletions
@@ -40,7 +40,8 @@
 from vllm.transformers_utils.s3_utils import S3Model
 from vllm.transformers_utils.utils import is_s3, maybe_model_redirect
 from vllm.utils import (GiB_bytes, LayerBlockType, cuda_device_count_stateless,
-                        get_cpu_memory, random_uuid, resolve_obj_by_qualname)
+                        get_cpu_memory, get_open_port, random_uuid,
+                        resolve_obj_by_qualname)

 if TYPE_CHECKING:
     from ray.util.placement_group import PlacementGroup
@@ -1389,6 +1390,8 @@ class ParallelConfig:
     tensor_parallel_size: int = 1  # Number of tensor parallel groups.
     data_parallel_size: int = 1  # Number of data parallel groups.
     data_parallel_rank: int = 0  # Rank of the data parallel group.
+    # Local rank of the data parallel group, defaults to global rank.
+    data_parallel_rank_local: Optional[int] = None
     # IP of the data parallel master.
     data_parallel_master_ip: str = "127.0.0.1"
     data_parallel_master_port: int = 29500  # Port of the data parallel master.
@@ -1493,10 +1496,18 @@ def __post_init__(self) -> None:
         self.world_size = self.pipeline_parallel_size * \
             self.tensor_parallel_size

-        self.data_parallel_size = envs.VLLM_DP_SIZE
-        self.data_parallel_rank = envs.VLLM_DP_RANK
-        self.data_parallel_master_ip = envs.VLLM_DP_MASTER_IP
-        self.data_parallel_master_port = envs.VLLM_DP_MASTER_PORT
+        if self.data_parallel_size > 1:
+            # Data parallel was specified in the engine args.
+            self.data_parallel_master_port = get_open_port()
+            # TODO multi-node
+        else:
+            # Otherwise fall back to env vars (e.g. for offline SPMD case).
+            self.data_parallel_size = envs.VLLM_DP_SIZE
+            self.data_parallel_rank = envs.VLLM_DP_RANK
+            self.data_parallel_rank_local = envs.VLLM_DP_RANK_LOCAL
+            self.data_parallel_master_ip = envs.VLLM_DP_MASTER_IP
+            self.data_parallel_master_port = envs.VLLM_DP_MASTER_PORT
+
         self.world_size_across_dp = self.world_size * self.data_parallel_size

         if self.distributed_executor_backend == "external_launcher":
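The new branch separates two launch modes: when data_parallel_size is set through the engine args, vLLM manages the engine processes itself and picks a free master port; otherwise the env-var path serves externally launched (offline SPMD) workers. A rough sketch of what each mode looks like from the caller's side (values are illustrative and the snippet is untested):

import os

from vllm.config import ParallelConfig

# Mode 1: DP requested through engine args. vLLM will spawn the DP engine
# processes and choose the master port itself (single node only, per the TODO).
config = ParallelConfig(data_parallel_size=2)

# Mode 2: offline SPMD. The launcher exports the env vars and leaves
# data_parallel_size at its default of 1, so __post_init__ reads them.
os.environ["VLLM_DP_SIZE"] = "2"
os.environ["VLLM_DP_RANK"] = "0"
config = ParallelConfig()
assert config.data_parallel_size == 2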

vllm/distributed/utils.py

Lines changed: 12 additions & 0 deletions
@@ -15,6 +15,8 @@
 from torch.distributed import ProcessGroup, TCPStore
 from torch.distributed.distributed_c10d import (Backend, PrefixStore,
                                                 _get_default_timeout,
+                                                _shutdown_backend,
+                                                _unregister_process_group,
                                                 is_nccl_available)
 from torch.distributed.rendezvous import rendezvous

@@ -333,3 +335,13 @@ def stateless_init_torch_distributed_process_group(
     pg._register_backend(device, backend_type, backend_class)

     return pg
+
+
+def stateless_destroy_torch_distributed_process_group(
+        pg: ProcessGroup) -> None:
+    """
+    Destroy ProcessGroup returned by
+    stateless_init_torch_distributed_process_group().
+    """
+    _shutdown_backend(pg)
+    _unregister_process_group(pg.group_name)
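The destroy helper is the counterpart to stateless_init_torch_distributed_process_group(), so short-lived groups (such as the DP coordination group) do not leak c10d state. A sketch of the intended pairing, assuming the init helper takes host, port, rank, world_size and backend:

from vllm.distributed.utils import (
    stateless_destroy_torch_distributed_process_group,
    stateless_init_torch_distributed_process_group)

pg = stateless_init_torch_distributed_process_group(
    "127.0.0.1", 29500, rank=0, world_size=1, backend="gloo")
try:
    pass  # collective ops on pg go here
finally:
    # Shut down the backend and drop the group from c10d's registry.
    stateless_destroy_torch_distributed_process_group(pg)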

vllm/engine/arg_utils.py

Lines changed: 10 additions & 0 deletions
@@ -114,6 +114,7 @@ class EngineArgs:
     # number of P/D disaggregation (or other disaggregation) workers
     pipeline_parallel_size: int = 1
     tensor_parallel_size: int = 1
+    data_parallel_size: int = 1
     enable_expert_parallel: bool = False
     max_parallel_loading_workers: Optional[int] = None
     block_size: Optional[int] = None
@@ -442,6 +443,14 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
                             type=int,
                             default=EngineArgs.tensor_parallel_size,
                             help='Number of tensor parallel replicas.')
+        parser.add_argument('--data-parallel-size',
+                            '-dp',
+                            type=int,
+                            default=EngineArgs.data_parallel_size,
+                            help='Number of data parallel replicas. '
+                            'MoE layers will be sharded according to the '
+                            'product of the tensor-parallel-size and '
+                            'data-parallel-size.')
         parser.add_argument(
             '--enable-expert-parallel',
             action='store_true',
@@ -1359,6 +1368,7 @@ def create_engine_config(
         parallel_config = ParallelConfig(
             pipeline_parallel_size=self.pipeline_parallel_size,
             tensor_parallel_size=self.tensor_parallel_size,
+            data_parallel_size=self.data_parallel_size,
             enable_expert_parallel=self.enable_expert_parallel,
             max_parallel_loading_workers=self.max_parallel_loading_workers,
             disable_custom_all_reduce=self.disable_custom_all_reduce,
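With the flag wired through to ParallelConfig, data parallelism can now be requested on the command line (--data-parallel-size 2, or the -dp 2 shorthand) or programmatically. A minimal programmatic sketch (building the config fetches the model's HF config, so it needs network or cache access):

from vllm.engine.arg_utils import AsyncEngineArgs

# Per the help text, MoE layers end up sharded across
# tensor_parallel_size * data_parallel_size ranks.
args = AsyncEngineArgs(model="ibm-research/PowerMoE-3b",
                       tensor_parallel_size=2,
                       data_parallel_size=2)
config = args.create_engine_config()
assert config.parallel_config.data_parallel_size == 2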

vllm/envs.py

Lines changed: 8 additions & 0 deletions
@@ -2,6 +2,7 @@

 import hashlib
 import os
+import sys
 import tempfile
 from typing import TYPE_CHECKING, Any, Callable, Optional

@@ -95,6 +96,7 @@
     VLLM_CUDART_SO_PATH: Optional[str] = None
     VLLM_USE_HPU_CONTIGUOUS_CACHE_FETCH: bool = True
     VLLM_DP_RANK: int = 0
+    VLLM_DP_RANK_LOCAL: int = -1
     VLLM_DP_SIZE: int = 1
     VLLM_DP_MASTER_IP: str = ""
     VLLM_DP_MASTER_PORT: int = 0
@@ -625,6 +627,12 @@ def maybe_convert_int(value: Optional[str]) -> Optional[int]:
     "VLLM_DP_RANK":
     lambda: int(os.getenv("VLLM_DP_RANK", "0")),

+    # Rank of the process in the data parallel setting.
+    # Defaults to VLLM_DP_RANK when not set.
+    "VLLM_DP_RANK_LOCAL":
+    lambda: int(
+        os.getenv("VLLM_DP_RANK_LOCAL", sys.modules[__name__].VLLM_DP_RANK)),
+
     # World size of the data parallel setting
     "VLLM_DP_SIZE":
     lambda: int(os.getenv("VLLM_DP_SIZE", "1")),

vllm/utils.py

Lines changed: 9 additions & 5 deletions
@@ -578,7 +578,7 @@ def get_open_port() -> int:
         dp_port = envs.VLLM_DP_MASTER_PORT
         while True:
             port = _get_open_port()
-            if port >= dp_port and port < dp_port + 10:
+            if dp_port <= port < dp_port + 10:
                 continue
             return port
     return _get_open_port()
@@ -2176,19 +2176,23 @@ def make_zmq_socket(
     if socket_type == zmq.constants.PULL:
         socket.setsockopt(zmq.constants.RCVHWM, 0)
         socket.setsockopt(zmq.constants.RCVBUF, buf_size)
-        socket.connect(path)
+        socket.bind(path)
     elif socket_type == zmq.constants.PUSH:
         socket.setsockopt(zmq.constants.SNDHWM, 0)
         socket.setsockopt(zmq.constants.SNDBUF, buf_size)
-        socket.bind(path)
+        socket.connect(path)
     else:
         raise ValueError(f"Unknown Socket Type: {socket_type}")

     return socket


 @contextlib.contextmanager
-def zmq_socket_ctx(path: str, socket_type: Any) -> Iterator[zmq.Socket]:
+def zmq_socket_ctx(
+    path: str,
+    socket_type: Any,
+    linger: int = 0,
+) -> Iterator[zmq.Socket]:
     """Context manager for a ZMQ socket"""

     ctx = zmq.Context()  # type: ignore[attr-defined]
@@ -2199,7 +2203,7 @@ def zmq_socket_ctx(path: str, socket_type: Any) -> Iterator[zmq.Socket]:
         logger.debug("Got Keyboard Interrupt.")

     finally:
-        ctx.destroy(linger=0)
+        ctx.destroy(linger=linger)


 def is_in_ray_actor():
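Swapping bind and connect between PULL and PUSH matters once data parallelism puts several engine processes behind one output stream: in ZMQ the single stable endpoint should bind and the many transient peers should connect, so the lone PULL consumer now binds while each engine's PUSH socket connects. A minimal pyzmq sketch of that fan-in pattern (the endpoint path is illustrative):

import zmq

ctx = zmq.Context()

# The single consumer binds the stable endpoint...
pull = ctx.socket(zmq.PULL)
pull.bind("ipc:///tmp/engine_outputs")

# ...and each of the N engine processes connects a PUSH socket to it.
push = ctx.socket(zmq.PUSH)
push.connect("ipc:///tmp/engine_outputs")

push.send(b"engine output")
assert pull.recv() == b"engine output"

The added linger parameter on zmq_socket_ctx similarly lets a closing context wait for queued messages to flush instead of always dropping them with linger=0.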
