
Commit 521eeab

support torchrun dp

Signed-off-by: Lu Fang <[email protected]>

1 parent c4afdb6

4 files changed: +98 −3 lines changed
Lines changed: 83 additions & 0 deletions
@@ -0,0 +1,83 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Experimental support for data-parallel inference with torchrun.
Note that data load balancing and distribution are handled outside the
vLLM engine; no internal load balancing is supported in external_launcher mode.
"""

from vllm import LLM, SamplingParams

# Create prompts, the same across all ranks.
prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
] * 50

# Create sampling parameters, the same across all ranks.
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

# Use `distributed_executor_backend="external_launcher"` so that
# this LLM engine/instance only creates one worker.
# It is important to set an explicit seed to make sure that
# all ranks have the same random seed, so that sampling is
# deterministic across ranks.
llm = LLM(
    model="/data/local/models/oss/qwen1.5_2.7B_moe_chat",
    tensor_parallel_size=2,
    data_parallel_size=4,
    pipeline_parallel_size=1,
    enable_expert_parallel=True,
    distributed_executor_backend="external_launcher",
    max_model_len=32768,
    # FIXME: with torch.compile, the torchrun processes do not exit properly.
    enforce_eager=True,
    seed=1,
)

dp_rank = llm.llm_engine.vllm_config.parallel_config.data_parallel_rank
dp_size = llm.llm_engine.vllm_config.parallel_config.data_parallel_size

# Shard the prompts round-robin: each DP rank keeps every dp_size-th prompt.
prompts = [
    f"{idx}.{prompt}" for idx, prompt in enumerate(prompts) if idx % dp_size == dp_rank
]

outputs = llm.generate(prompts, sampling_params)


# All ranks in the same DP group will have the same outputs; outputs
# differ across DP ranks because the prompts were sharded above.
print("-" * 50)
for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}\n")
    print("-" * 50)

"""
Further tips:

1. To communicate control messages across all ranks, use the CPU group,
a PyTorch ProcessGroup with the GLOO backend.

```python
import torch.distributed as dist

from vllm.distributed.parallel_state import get_world_group

cpu_group = get_world_group().cpu_group
torch_rank = dist.get_rank(group=cpu_group)
if torch_rank == 0:
    # do something for rank 0, e.g. saving the results to disk.
    ...
```

2. To communicate data across all ranks, use the model's device group,
a PyTorch ProcessGroup with the NCCL backend.

```python
from vllm.distributed.parallel_state import get_world_group

device_group = get_world_group().device_group
```
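For example, a minimal sketch of using it (illustrative only, not part of
this commit; assumes one CUDA device per rank and builds on the variables
above):

```python
import torch
import torch.distributed as dist

# Sum a scalar across all ranks over NCCL, e.g. to count outputs globally.
num_outputs = torch.tensor([len(outputs)], device="cuda", dtype=torch.int64)
dist.all_reduce(num_outputs, group=device_group)  # in-place elementwise sum
```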

3. To access the model directly in every rank, use the following code:

```python
llm.llm_engine.model_executor.driver_worker.worker.model_runner.model
```
"""

vllm/config/parallel.py

Lines changed: 10 additions & 0 deletions
@@ -313,6 +313,10 @@ def __post_init__(self) -> None:
         # Continue with the rest of the initialization
         self.world_size = self.pipeline_parallel_size * \
             self.tensor_parallel_size
+
+        if self.distributed_executor_backend == "external_launcher":
+            logger.info("Using external launcher for distributed inference.")
+            self.world_size *= self.data_parallel_size
 
         if self.data_parallel_size_local > self.data_parallel_size:
             raise ValueError(
@@ -321,6 +325,12 @@ def __post_init__(self) -> None:
 
         if self.data_parallel_size > 1 or self.data_parallel_size_local == 0:
             # Data parallel was specified in the engine args.
+            if self.distributed_executor_backend == "external_launcher":
+                # For external launcher, we need to set the data parallel rank automatically
+                # We assume DP is the first dimension of parallelism.
+                import os
+                self.data_parallel_rank = int(os.environ["RANK"]) // (self.world_size // self.data_parallel_size)
+                logger.debug(f"Setting data_parallel_rank to {self.data_parallel_rank} automatically.")
             if not self._data_parallel_master_port_list:
                 self._data_parallel_master_port_list = get_open_ports_list(5)
             self.data_parallel_master_port = \
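The derivation of `data_parallel_rank` above assumes torchrun's `RANK`
environment variable enumerates DP groups contiguously, with DP as the
outermost parallelism dimension. A sketch of the arithmetic for the topology
used in the example (standalone values, mirroring the added line):

```python
# data_parallel_rank = RANK // (world_size // data_parallel_size)
tp, pp, dp = 2, 1, 4
world_size = tp * pp * dp  # 8, after the world_size *= data_parallel_size change
ranks_per_dp_group = world_size // dp  # 2
for rank in range(world_size):  # torchrun assigns RANK = 0..7
    print(f"RANK={rank} -> data_parallel_rank={rank // ranks_per_dp_group}")
# RANK 0,1 -> 0; RANK 2,3 -> 1; RANK 4,5 -> 2; RANK 6,7 -> 3
```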

vllm/distributed/parallel_state.py

Lines changed: 2 additions & 1 deletion
@@ -991,7 +991,8 @@ def init_distributed_environment(world_size: int = -1,
                                 distributed_init_method, backend)
     from vllm.config import get_current_vllm_config
     config = get_current_vllm_config()
-    if config is not None and config.parallel_config.data_parallel_size > 1:
+    if config is not None and config.parallel_config.data_parallel_size > 1 \
+        and config.parallel_config.distributed_executor_backend != "external_launcher":
         parallel_config = config.parallel_config
         # adjust to take into account data parallelism
         # offset the rank by the data parallel rank
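The new condition skips the rank adjustment because, with external_launcher,
torchrun already hands each process its global rank across the whole
DP x TP x PP world. A sketch of the offset the skipped branch would otherwise
apply (assumed semantics of the code below the hunk, shown only to motivate
the guard):

```python
# Without an external launcher, each DP replica launches local ranks
# 0..(tp*pp - 1), so the global rank must be offset by the DP rank:
data_parallel_rank = 3
world_size_per_replica = 2  # tp * pp
local_rank = 1
global_rank = data_parallel_rank * world_size_per_replica + local_rank
assert global_rank == 7
# With external_launcher, torchrun's RANK is already global (0..7),
# so applying this offset would double-count the DP rank.
```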

vllm/v1/engine/llm_engine.py

Lines changed: 3 additions & 2 deletions
@@ -78,8 +78,9 @@ def __init__(
         # important: init dp group before init the engine_core
         # In the decoupled engine case this is handled in EngineCoreProc.
         parallel_config = vllm_config.parallel_config
-        if not multiprocess_mode and parallel_config.data_parallel_size > 1:
-            self.dp_group = parallel_config.stateless_init_dp_group()
+        if not multiprocess_mode and parallel_config.data_parallel_size > 1 \
+            and self.vllm_config.parallel_config.distributed_executor_backend != "external_launcher":
+            self.dp_group = parallel_config.stateless_init_dp_group()
         else:
             self.dp_group = None
         self.should_execute_dummy_batch = False
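Since no internal DP coordination group is created in external_launcher mode,
any cross-rank aggregation is left to user code. A hedged sketch using the
GLOO CPU group from the example's "Further tips" (assumes `outputs` from
`llm.generate` is in scope):

```python
import torch.distributed as dist

from vllm.distributed.parallel_state import get_world_group

# Collect every rank's generated texts on all ranks via the CPU group.
cpu_group = get_world_group().cpu_group
texts = [o.outputs[0].text for o in outputs]
gathered = [None] * dist.get_world_size(group=cpu_group)
dist.all_gather_object(gathered, texts, group=cpu_group)
```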
