Skip to content

Commit 4acaefc

Browse files
luccafong authored and zhuohan123 committed
[DP] support torchrun external launcher with Data Parallelism (vllm-project#24899)
Signed-off-by: Lu Fang <[email protected]> Signed-off-by: Zhuohan Li <[email protected]> Co-authored-by: Zhuohan Li <[email protected]> Signed-off-by: xuebwang-amd <[email protected]>
1 parent 3811a4f commit 4acaefc

File tree

6 files changed

+202
-7
lines changed

6 files changed

+202
-7
lines changed

.buildkite/test-pipeline.yaml

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -165,10 +165,18 @@ steps:
165165
- tests/v1/test_hybrid_lb_dp.py
166166
- tests/v1/engine/test_engine_core_client.py
167167
commands:
168-
# test with tp=2 and external_dp=2
168+
# test with torchrun tp=2 and external_dp=2
169169
- torchrun --nproc-per-node=4 distributed/test_torchrun_example.py
170-
# test with tp=2 and pp=2
170+
# test with torchrun tp=2 and pp=2
171171
- PP_SIZE=2 torchrun --nproc-per-node=4 distributed/test_torchrun_example.py
172+
# test with torchrun tp=4 and dp=1
173+
- TP_SIZE=4 torchrun --nproc-per-node=4 distributed/test_torchrun_example_moe.py
174+
# test with torchrun tp=2, pp=2 and dp=1
175+
- PP_SIZE=2 TP_SIZE=2 torchrun --nproc-per-node=4 distributed/test_torchrun_example_moe.py
176+
# test with torchrun tp=1 and dp=4 with ep
177+
- DP_SIZE=4 ENABLE_EP=1 torchrun --nproc-per-node=4 distributed/test_torchrun_example_moe.py
178+
# test with torchrun tp=2 and dp=2 with ep
179+
- TP_SIZE=2 DP_SIZE=2 ENABLE_EP=1 torchrun --nproc-per-node=4 distributed/test_torchrun_example_moe.py
172180
# test with internal dp
173181
- python3 ../examples/offline_inference/data_parallel.py --enforce-eager
174182
- TP_SIZE=2 DP_SIZE=2 pytest -v -s v1/test_async_llm_dp.py
Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3+
"""
4+
experimental support for data-parallel inference with torchrun
5+
Note that data load balancing and distribution are done outside of the vLLM engine;
6+
no internal load balancing is supported in external_launcher mode.
7+
"""
8+
9+
from vllm import LLM, SamplingParams
10+
11+
# Create prompts, the same across all ranks
12+
prompts = [
13+
"Hello, my name is",
14+
"The president of the United States is",
15+
"The capital of France is",
16+
"The future of AI is",
17+
] * 50
18+
19+
# Create sampling parameters, the same across all ranks
20+
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
21+
22+
# Use `distributed_executor_backend="external_launcher"` so that
23+
# this llm engine/instance only creates one worker.
24+
# it is important to set an explicit seed to make sure that
25+
# all ranks have the same random seed, so that sampling can be
26+
# deterministic across ranks.
27+
llm = LLM(
28+
model="microsoft/Phi-mini-MoE-instruct",
29+
tensor_parallel_size=1,
30+
data_parallel_size=2,
31+
pipeline_parallel_size=1,
32+
enable_expert_parallel=False,
33+
distributed_executor_backend="external_launcher",
34+
max_model_len=4096,
35+
gpu_memory_utilization=0.6,
36+
seed=1,
37+
)
38+
39+
dp_rank = llm.llm_engine.vllm_config.parallel_config.data_parallel_rank
40+
dp_size = llm.llm_engine.vllm_config.parallel_config.data_parallel_size
41+
42+
prompts = [
43+
f"{idx}.{prompt}" for idx, prompt in enumerate(prompts) if idx % dp_size == dp_rank
44+
]
45+
46+
outputs = llm.generate(prompts, sampling_params)
47+
48+
49+
# each data-parallel rank only generates (and prints) outputs for its own subset of the prompts
50+
print("-" * 50)
51+
for output in outputs:
52+
prompt = output.prompt
53+
generated_text = output.outputs[0].text
54+
print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}\n")
55+
print("-" * 50)
56+
"""
57+
Further tips:
58+
59+
1. to communicate control messages across all ranks, use the cpu group,
60+
a PyTorch ProcessGroup with GLOO backend.
61+
62+
```python
63+
from vllm.distributed.parallel_state import get_world_group
64+
cpu_group = get_world_group().cpu_group
65+
torch_rank = dist.get_rank(group=cpu_group)
66+
if torch_rank == 0:
67+
# do something for rank 0, e.g. saving the results to disk.
68+
```
69+
70+
2. to communicate data across all ranks, use the model's device group,
71+
a PyTorch ProcessGroup with NCCL backend.
72+
```python
73+
from vllm.distributed.parallel_state import get_world_group
74+
device_group = get_world_group().device_group
75+
```
76+
77+
3. to access the model directly in every rank, use the following code:
78+
```python
79+
llm.llm_engine.model_executor.driver_worker.worker.model_runner.model
80+
```
81+
"""
Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3+
4+
# unit test for the torchrun external-launcher flow with a MoE model (TP/PP/DP/EP combinations)
5+
import os
6+
import random
7+
8+
import torch.distributed as dist
9+
10+
from vllm import LLM, SamplingParams
11+
from vllm.distributed.parallel_state import get_tp_group, get_world_group
12+
13+
dist.init_process_group(backend="gloo")
14+
15+
# Create prompts
16+
prompts = [
17+
"Hello, my name is",
18+
"The president of the United States is",
19+
"The capital of France is",
20+
"The future of AI is",
21+
] * 10
22+
dp_size = int(os.getenv("DP_SIZE", "1"))
23+
dp_rank = int(os.getenv("DP_RANK", "0"))
24+
25+
if dp_size > 1:
26+
# distribute the prompts across the data parallel ranks
27+
prompts = [
28+
prompt for idx, prompt in enumerate(prompts)
29+
if idx % dp_size == dp_rank
30+
]
31+
32+
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
33+
34+
# set different `gpu_memory_utilization` and `swap_space` for different ranks,
35+
# to test if all ranks agree on the same kv cache configuration.
36+
llm = LLM(model="microsoft/Phi-mini-MoE-instruct",
37+
tensor_parallel_size=int(os.getenv("TP_SIZE", "1")),
38+
pipeline_parallel_size=int(os.getenv("PP_SIZE", "1")),
39+
enable_expert_parallel=int(os.getenv("ENABLE_EP", "0")) == 1,
40+
distributed_executor_backend="external_launcher",
41+
gpu_memory_utilization=random.uniform(0.7, 0.9),
42+
swap_space=random.randint(1, 4),
43+
seed=0)
44+
45+
outputs = llm.generate(prompts, sampling_params)
46+
47+
group = get_world_group() if dp_size == 1 else get_tp_group()
48+
cpu_group = group.cpu_group
49+
group_rank = dist.get_rank(group=cpu_group)
50+
51+
52+
def test_consistent_across_ranks(obj):
    """Check that every rank of the comparison group holds the same ``obj``.

    The first rank of ``group`` broadcasts its value over the CPU process
    group; every other rank receives that value and asserts that it equals
    its own local ``obj``.

    Args:
        obj: Any picklable value that is expected to be identical on all
            ranks of ``group``.

    Raises:
        AssertionError: On a non-source rank whose local value differs from
            the value broadcast by the source rank.
    """
    source_rank = group.ranks[0]
    is_source = group_rank == 0
    # The source ships its own value; receivers provide a one-slot list
    # that the broadcast fills in place.
    payload = [obj] if is_source else [None]
    dist.broadcast_object_list(payload, src=source_rank, group=cpu_group)
    if not is_source:
        assert payload[0] == obj
61+
62+
63+
test_consistent_across_ranks(
64+
llm.llm_engine.vllm_config.cache_config.num_cpu_blocks)
65+
test_consistent_across_ranks(
66+
llm.llm_engine.vllm_config.cache_config.num_gpu_blocks)
67+
68+
# make sure we can access the model parameters from the calling process
69+
# of the `LLM` instance.
70+
params = list(llm.llm_engine.model_executor.driver_worker.worker.model_runner.
71+
model.parameters())
72+
test_consistent_across_ranks(len(params))
73+
74+
# all ranks in the comparison group (world group, or TP group when dp_size > 1) should have the same outputs
75+
for output in outputs:
76+
prompt = output.prompt
77+
generated_text = output.outputs[0].text
78+
test_consistent_across_ranks(prompt)
79+
test_consistent_across_ranks(generated_text)
80+
print(f"Rank {group_rank}, Prompt: {prompt!r}, "
81+
f"Generated text: {generated_text!r}")

vllm/config/parallel.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
33

44
import hashlib
5+
import os
56
from dataclasses import field
67
from typing import TYPE_CHECKING, Any, Literal, Optional, Union
78

@@ -351,13 +352,24 @@ def __post_init__(self) -> None:
351352
self.world_size = self.pipeline_parallel_size * \
352353
self.tensor_parallel_size
353354

355+
if self.distributed_executor_backend == "external_launcher":
356+
logger.info("Using external launcher for distributed inference.")
357+
self.world_size *= self.data_parallel_size
358+
354359
if self.data_parallel_size_local > self.data_parallel_size:
355360
raise ValueError(
356361
f"data_parallel_size_local ({self.data_parallel_size_local}) "
357362
f"must be <= data_parallel_size ({self.data_parallel_size})")
358363

359364
if self.data_parallel_size > 1 or self.data_parallel_size_local == 0:
360365
# Data parallel was specified in the engine args.
366+
if self.distributed_executor_backend == "external_launcher":
367+
# For external launcher,
368+
# we need to set the data parallel rank automatically
369+
self.data_parallel_rank = int(os.environ["RANK"]) \
370+
// (self.world_size // self.data_parallel_size)
371+
logger.info("Set data_parallel_rank to %d automatically.",
372+
self.data_parallel_rank)
361373
if not self._data_parallel_master_port_list:
362374
self._data_parallel_master_port_list = get_open_ports_list(5)
363375
self.data_parallel_master_port = \
@@ -380,7 +392,6 @@ def __post_init__(self) -> None:
380392
"be set when data_parallel_size > 1")
381393

382394
if self.distributed_executor_backend == "external_launcher":
383-
import os
384395
os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0"
385396
logger.info("Disabling V1 multiprocessing for external launcher.")
386397

vllm/distributed/parallel_state.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1032,7 +1032,9 @@ def init_distributed_environment(world_size: int = -1,
10321032
distributed_init_method, backend)
10331033
from vllm.config import get_current_vllm_config
10341034
config = get_current_vllm_config()
1035-
if config is not None and config.parallel_config.data_parallel_size > 1:
1035+
if config is not None and config.parallel_config.data_parallel_size > 1 \
1036+
and config.parallel_config.distributed_executor_backend \
1037+
!= "external_launcher":
10361038
parallel_config = config.parallel_config
10371039
# adjust to take into account data parallelism
10381040
# offset the rank by the data parallel rank

vllm/v1/engine/llm_engine.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import vllm.envs as envs
1212
from vllm.config import ParallelConfig, VllmConfig
1313
from vllm.distributed import stateless_destroy_torch_distributed_process_group
14+
from vllm.distributed.parallel_state import get_dp_group
1415
from vllm.engine.arg_utils import EngineArgs
1516
from vllm.inputs import PromptType
1617
from vllm.logger import init_logger
@@ -77,10 +78,15 @@ def __init__(
7778
if self.log_stats:
7879
self.stat_logger = PrometheusStatLogger(vllm_config)
7980

81+
executor_backend = (
82+
self.vllm_config.parallel_config.distributed_executor_backend)
83+
parallel_config = vllm_config.parallel_config
84+
self.external_launcher_dp = (parallel_config.data_parallel_size > 1 and
85+
executor_backend == "external_launcher")
8086
# important: init dp group before init the engine_core
8187
# In the decoupled engine case this is handled in EngineCoreProc.
82-
parallel_config = vllm_config.parallel_config
83-
if not multiprocess_mode and parallel_config.data_parallel_size > 1:
88+
if not multiprocess_mode and parallel_config.data_parallel_size > 1 \
89+
and not self.external_launcher_dp:
8490
self.dp_group = parallel_config.stateless_init_dp_group()
8591
else:
8692
self.dp_group = None
@@ -120,6 +126,11 @@ def __init__(
120126
# for v0 compatibility
121127
self.model_executor = self.engine_core.engine_core.model_executor # type: ignore
122128

129+
if self.external_launcher_dp:
130+
# If we use DP in external launcher mode, we reuse the
131+
# existing DP group used for data communication.
132+
self.dp_group = get_dp_group().cpu_group
133+
123134
# Don't keep the dummy data in memory
124135
self.reset_mm_cache()
125136

@@ -331,5 +342,6 @@ def apply_model(self, func: Callable[[nn.Module], _R]) -> list[_R]:
331342
return self.collective_rpc("apply_model", args=(func, ))
332343

333344
def __del__(self):
    """Release the stateless data-parallel process group owned by this engine.

    Only a group created by this engine is destroyed. In external-launcher
    DP mode ``self.dp_group`` is the shared DP CPU group obtained from
    ``get_dp_group()``; it is reused, not owned, so it is left untouched.
    """
    # Both attributes are read defensively: __del__ can run on a partially
    # constructed instance if __init__ raised before setting them.
    dp_group = getattr(self, "dp_group", None)
    external_launcher_dp = getattr(self, "external_launcher_dp", False)
    # Evaluate the group and the flag separately. The previous form
    # `dp_group := getattr(...) and not flag` bound the *boolean* result of
    # the `and` to dp_group and then passed `True` to the destroy call.
    if dp_group is not None and not external_launcher_dp:
        stateless_destroy_torch_distributed_process_group(dp_group)

0 commit comments

Comments
 (0)