Commit e3636c7
[0.9.1] Enable external distributed dp deployments in vllm-ascend (0.9.1 only) (#2109)
### What this PR does / why we need it?

vLLM's data-parallel (dp) deployment follows the classic master-slave pattern in both scale-out and scale-up scenarios. This design places more pressure on the master node than on the other nodes, which can lead to unbalanced host overhead across worker processes and may hurt performance. Beyond the master-slave structure, the dp load balancing (dplb) is implemented by many independent processes, and their isolated memory spaces make it hard to balance the overall engine workload. In this PR we break the master-slave chain and split each dp instance into its own vLLM engine instance, each with its own private IP and port. This evens out host pressure across worker processes and brings a non-trivial performance boost over the previous implementation. In addition, load balancing can be handled by a proxy instead of independent processes inside the engine, which gives users more flexibility.

### Does this PR introduce _any_ user-facing change?

Yes. This implementation uses a launch script distinct from vLLM's original one; the usage tutorial and example scripts are placed in the `external_online_dp` folder.

### How was this patch tested?

---------

Signed-off-by: ganyi <[email protected]>
1 parent a704967 commit e3636c7

File tree

9 files changed, +332 -2 lines changed

examples/external_online_dp/README.md

Whitespace-only changes.
Lines changed: 34 additions & 0 deletions
@@ -0,0 +1,34 @@
import multiprocessing
import os
import sys

dp_size = 32
dp_size_local = 16
dp_rank_start = 0
dp_ip = "your_dp_ip_here"
dp_port = "your_dp_port_here"
engine_port = 9000
template_path = "./run_dp_template.sh"
if not os.path.exists(template_path):
    print(f"Template file {template_path} does not exist.")
    sys.exit(1)


def run_command(dp_rank_local, dp_rank, engine_port_):
    command = f"bash ./run_dp_template.sh {dp_size} {dp_ip} {dp_port} {dp_rank_local} {dp_rank} {engine_port_} {dp_size_local}"
    os.system(command)


processes = []
for i in range(dp_size_local):
    dp_rank = dp_rank_start + i
    dp_rank_local = i
    engine_port_ = engine_port + i
    process = multiprocessing.Process(target=run_command,
                                      args=(dp_rank_local, dp_rank,
                                            engine_port_))
    processes.append(process)
    process.start()

for process in processes:
    process.join()
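The example values above (`dp_size = 32`, `dp_size_local = 16`, `engine_port = 9000`) correspond to a two-node deployment in which this launcher is run once per node: each node spawns 16 independent `vllm serve` engines on ports 9000-9015, and only `dp_rank_start` changes from node to node. The sketch below is an illustration of that layout, not part of the PR; the assumption that both nodes share the same DP master IP/port is inferred from the template script.

```python
# Illustration only: per-node launcher settings for the example 32-rank, two-node layout.
# Each node runs 16 engines; dp_rank_start offsets the global dp rank of the first engine.
NODE_SETTINGS = {
    "node0": {"dp_rank_start": 0,  "dp_size_local": 16, "engine_port": 9000},
    "node1": {"dp_rank_start": 16, "dp_size_local": 16, "engine_port": 9000},
}

for node, cfg in NODE_SETTINGS.items():
    first, count = cfg["dp_rank_start"], cfg["dp_size_local"]
    ports = [cfg["engine_port"] + i for i in range(count)]
    print(f"{node}: global dp ranks {first}-{first + count - 1}, ports {ports[0]}-{ports[-1]}")
```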
Lines changed: 51 additions & 0 deletions
@@ -0,0 +1,51 @@
export HCCL_IF_IP=your_ip_here
export GLOO_SOCKET_IFNAME="enp48s3u1u1"
export TP_SOCKET_IFNAME="enp48s3u1u1"
export HCCL_SOCKET_IFNAME="enp48s3u1u1"
export DISAGGREGATED_PREFILL_RANK_TABLE_PATH=your_rank_table_path_here
export VLLM_LOGGING_LEVEL="info"
export OMP_PROC_BIND=false
export OMP_NUM_THREADS=10
export PYTORCH_NPU_ALLOC_CONF=expandable_segments:True
# Positional arguments (as passed by the launcher script above):
# $1=dp_size, $2=dp_master_ip, $3=dp_master_port, $4=dp_rank_local, $5=dp_rank, $6=engine_port, $7=dp_size_local
export VLLM_DP_SIZE=$1
export VLLM_DP_MASTER_IP=$2
export VLLM_DP_MASTER_PORT=$3
export VLLM_DP_RANK_LOCAL=$4
export VLLM_DP_RANK=$5
export VLLM_DP_SIZE_LOCAL=$7
export HCCL_DETERMINISTIC=True
export HCCL_BUFFER_SIZE=1024
export TASK_QUEUE_ENABLE=1
# Spawning worker processes inside vLLM may cause a circular import issue, so fork is required here.
export VLLM_WORKER_MULTIPROC_METHOD="fork"

export VLLM_USE_V1=1

export ASCEND_RT_VISIBLE_DEVICES=0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15

vllm serve model_path \
    --host 0.0.0.0 \
    --port $6 \
    --tensor-parallel-size 2 \
    --enable-expert-parallel \
    --seed 1024 \
    --served-model-name dsv3 \
    --max-model-len 5200 \
    --max-num-batched-tokens 256 \
    --max-num-seqs 28 \
    --trust-remote-code \
    --gpu-memory-utilization 0.9 \
    --quantization ascend \
    --speculative-config '{"num_speculative_tokens": 1, "method":"deepseek_mtp"}' \
    --kv-transfer-config \
    '{"kv_connector": "LLMDataDistCMgrConnector",
    "kv_buffer_device": "npu",
    "kv_role": "kv_consumer",
    "kv_parallel_size": "1",
    "kv_port": "20001",
    "engine_id": "0",
    "kv_connector_module_path": "vllm_ascend.distributed.llmdatadist_c_mgr_connector"
    }' \
    --additional-config \
    '{"ascend_scheduler_config": {"enabled": true}, "torchair_graph_config":{"enabled":true,"enable_kv_nz":false, "enable_multistream_moe":false, "graph_batch_size":[28]}, "enable_weight_nz_layout":true}'

vllm_ascend/distributed/llmdatadist_c_mgr_connector.py

Lines changed: 2 additions & 2 deletions
@@ -308,9 +308,9 @@ def __init__(self, vllm_config: VllmConfig):
  logger.info("Initialize the LLMDataDistCMgrConnectorWorker")
  # we assume the local node only contains dp and tp, and tp will not communicate inter-node.
  # for any scenario beyond this scope, the functionality of this connector is not guaranteed.
+ dp_size_local = vllm_config.parallel_config.data_parallel_size_local if not envs.VLLM_ASCEND_EXTERNAL_DP_LB_ENABLED else envs.VLLM_DP_SIZE_LOCAL
  self.local_rank_on_node = get_world_group().rank % (
-     vllm_config.parallel_config.data_parallel_size_local *
-     vllm_config.parallel_config.tensor_parallel_size)
+     dp_size_local * vllm_config.parallel_config.tensor_parallel_size)
  self.local_rank = get_world_group().local_rank
  self.local_dp_rank = vllm_config.parallel_config.data_parallel_rank_local
  self.tp_size = vllm_config.parallel_config.tensor_parallel_size

vllm_ascend/envs.py

Lines changed: 6 additions & 0 deletions
@@ -163,6 +163,12 @@
  "VLLM_ASCEND_KV_CACHE_MEGABYTES_FLOATING_TOLERANCE":
  lambda: int(
      os.getenv("VLLM_ASCEND_KV_CACHE_MEGABYTES_FLOATING_TOLERANCE", '0')),
+ # VLLM_DP_SIZE_LOCAL: used by external data parallelism in vllm-ascend to specify the local data-parallel size of the current node, 0.9.1 specific.
+ "VLLM_DP_SIZE_LOCAL":
+ lambda: int(os.getenv("VLLM_DP_SIZE_LOCAL", '0')),
+ # VLLM_ASCEND_EXTERNAL_DP_LB_ENABLED: enables external distributed data parallelism in vllm-ascend, 0.9.1 specific.
+ "VLLM_ASCEND_EXTERNAL_DP_LB_ENABLED":
+ lambda: bool(int(os.getenv("VLLM_ASCEND_EXTERNAL_DP_LB_ENABLED", '0'))),
  }

  # end-env-vars-definition
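For context, a minimal sketch of how these two toggles are consumed, assuming a vllm-ascend 0.9.1 install with this change and that the flags are exported in the launch environment (the patches below only take effect when `VLLM_ASCEND_EXTERNAL_DP_LB_ENABLED` is set to a non-zero value):

```python
import os

# Hypothetical launch environment for one node of the example deployment.
os.environ["VLLM_ASCEND_EXTERNAL_DP_LB_ENABLED"] = "1"  # turn on the external-DP patches
os.environ["VLLM_DP_SIZE_LOCAL"] = "16"                 # dp ranks hosted on this node

import vllm_ascend.envs as ascend_envs

# Both values are parsed from the environment by the lambdas added above.
assert ascend_envs.VLLM_ASCEND_EXTERNAL_DP_LB_ENABLED
assert ascend_envs.VLLM_DP_SIZE_LOCAL == 16
```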

vllm_ascend/patch/platform/patch_0_9_1/__init__.py

Lines changed: 4 additions & 0 deletions
@@ -19,3 +19,7 @@
  # patch files.
  import vllm_ascend.patch.worker.patch_common.patch_utils  # noqa isort:skip
  import vllm_ascend.patch.platform.patch_0_9_1.patch_cache_manager  # noqa
+ import vllm_ascend.patch.platform.patch_0_9_1.patch_configs  # noqa
+ import vllm_ascend.patch.platform.patch_0_9_1.patch_core  # noqa
+ import vllm_ascend.patch.platform.patch_0_9_1.patch_core_client  # noqa
+ import vllm_ascend.patch.platform.patch_0_9_1.patch_decorator  # noqa
Lines changed: 77 additions & 0 deletions
@@ -0,0 +1,77 @@
import vllm.envs as envs
from vllm.config import DistributedExecutorBackend, ParallelConfig
from vllm.logger import init_logger

import vllm_ascend.envs as vllm_ascend_envs

logger = init_logger(__name__)


def __post_init__(self: ParallelConfig) -> None:
    self.world_size = self.pipeline_parallel_size * \
        self.tensor_parallel_size

    if self.data_parallel_size_local > self.data_parallel_size:
        raise ValueError(
            f"data_parallel_size_local ({self.data_parallel_size_local}) "
            f"must be <= data_parallel_size ({self.data_parallel_size})")

    self.data_parallel_size = envs.VLLM_DP_SIZE
    self.data_parallel_rank = envs.VLLM_DP_RANK
    self.data_parallel_rank_local = envs.VLLM_DP_RANK_LOCAL
    self.data_parallel_master_ip = envs.VLLM_DP_MASTER_IP
    self.data_parallel_master_port = envs.VLLM_DP_MASTER_PORT

    if self.distributed_executor_backend == "external_launcher":
        import os
        os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0"
        logger.info("Disabling V1 multiprocessing for external launcher.")

    ray_only_devices: list[str] = []
    from vllm.platforms import current_platform
    if (current_platform.device_type in ray_only_devices
            and self.world_size > 1):
        if self.distributed_executor_backend is None:
            self.distributed_executor_backend = "ray"
        if self.distributed_executor_backend != "ray":
            raise ValueError(
                f"{current_platform.device_type.upper()} backend only "
                "supports Ray for distributed inference.")

    if self.distributed_executor_backend is None and self.world_size > 1:
        # We use multiprocessing by default if world_size fits on the
        # current node and we aren't in a ray placement group.

        from vllm.executor import ray_utils
        backend: DistributedExecutorBackend = "mp"
        ray_found = ray_utils.ray_is_available()
        if current_platform.is_neuron():
            # neuron uses single process to control multiple devices
            backend = "uni"
        elif current_platform.is_tpu() and envs.VLLM_XLA_USE_SPMD:
            backend = "uni"
        elif self.data_parallel_backend == "ray":
            logger.info("Using ray distributed inference because "
                        "data_parallel_backend is ray")
            backend = "ray"
        elif ray_found:
            if self.placement_group:
                backend = "ray"
            else:
                from ray import is_initialized as ray_is_initialized
                if ray_is_initialized():
                    from ray.util import get_current_placement_group
                    if get_current_placement_group():
                        backend = "ray"
        self.distributed_executor_backend = backend
        logger.info("Defaulting to use %s for distributed inference", backend)

    if self.distributed_executor_backend is None and self.world_size == 1:
        self.distributed_executor_backend = "uni"

    self._verify_args()


# apply this patch only if the external data parallelism is enabled
if vllm_ascend_envs.VLLM_ASCEND_EXTERNAL_DP_LB_ENABLED:
    ParallelConfig.__post_init__ = __post_init__  # type: ignore[attr-defined]
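The key behavioral change in this patched `__post_init__` is that every engine process derives its entire DP identity (size, rank, local rank, master address) from its own environment variables rather than having it assigned by a coordinating master process. The following is a self-contained toy sketch of that idea, using hypothetical names rather than vLLM's actual classes:

```python
import os
from dataclasses import dataclass, field


@dataclass
class ToyParallelConfig:
    """Toy stand-in for the patched ParallelConfig: DP identity comes from env vars."""
    tensor_parallel_size: int = 1
    data_parallel_size: int = field(init=False)
    data_parallel_rank: int = field(init=False)
    data_parallel_rank_local: int = field(init=False)
    data_parallel_master_ip: str = field(init=False)
    data_parallel_master_port: int = field(init=False)

    def __post_init__(self) -> None:
        # Mirrors the patch above: each engine reads the VLLM_DP_* variables
        # exported for it by the launch template.
        self.data_parallel_size = int(os.environ["VLLM_DP_SIZE"])
        self.data_parallel_rank = int(os.environ["VLLM_DP_RANK"])
        self.data_parallel_rank_local = int(os.environ["VLLM_DP_RANK_LOCAL"])
        self.data_parallel_master_ip = os.environ["VLLM_DP_MASTER_IP"]
        self.data_parallel_master_port = int(os.environ["VLLM_DP_MASTER_PORT"])


if __name__ == "__main__":
    # Placeholder values the template would export for global dp rank 5 on node 0.
    os.environ.update({
        "VLLM_DP_SIZE": "32",
        "VLLM_DP_RANK": "5",
        "VLLM_DP_RANK_LOCAL": "5",
        "VLLM_DP_MASTER_IP": "10.0.0.1",
        "VLLM_DP_MASTER_PORT": "12345",
    })
    cfg = ToyParallelConfig(tensor_parallel_size=2)
    print(cfg.data_parallel_rank, cfg.data_parallel_master_ip)
```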
Lines changed: 132 additions & 0 deletions
@@ -0,0 +1,132 @@
import os
import signal
from typing import Optional

from vllm.config import ParallelConfig, VllmConfig
from vllm.logger import init_logger
from vllm.transformers_utils.config import \
    maybe_register_config_serialize_by_value
from vllm.v1.engine.core import DPEngineCoreProc, EngineCoreProc

import vllm_ascend.envs as vllm_ascend_envs

logger = init_logger(__name__)


class ExternealDPEngineCoreProc(DPEngineCoreProc):

    def __init__(self, *args, **kwargs):
        # Use the external data parallelism master port from envs
        super().__init__(*args, **kwargs)
        self.engines_running = True

    def _has_global_unfinished_reqs(self, local_unfinished):
        return True

    def _init_data_parallel(self, vllm_config: VllmConfig):

        # Configure GPUs and stateless process group for data parallel.
        dp_rank = vllm_config.parallel_config.data_parallel_rank
        dp_size = vllm_config.parallel_config.data_parallel_size
        local_dp_rank = vllm_config.parallel_config.data_parallel_rank_local

        assert dp_size > 1
        assert 0 <= local_dp_rank <= dp_rank < dp_size

        if vllm_config.kv_transfer_config is not None:
            # modify the engine_id and append the local_dp_rank to it to ensure
            # that the kv_transfer_config is unique for each DP rank.
            vllm_config.kv_transfer_config.engine_id = (
                f"{vllm_config.kv_transfer_config.engine_id}_dp{local_dp_rank}"
            )
            logger.debug("Setting kv_transfer_config.engine_id to %s",
                         vllm_config.kv_transfer_config.engine_id)

        from vllm.platforms import current_platform
        device_control_env_var = current_platform.device_control_env_var
        world_size = vllm_config.parallel_config.world_size
        os.environ[device_control_env_var] = ",".join(
            str(current_platform.device_id_to_physical_device_id(i))
            for i in range(local_dp_rank * world_size, (local_dp_rank + 1) *
                           world_size))

        self.dp_rank = dp_rank

    def run_busy_loop(self):
        """Core busy loop of the EngineCore for data parallel case."""
        # Note: in this customized DPEngineCoreProc no idle time exists; we
        # assume the other dp groups are always running.

        # Loop until process is sent a SIGINT or SIGTERM
        while True:
            # 1) Poll the input queue until there is work to do.
            self._process_input_queue()

            # 2) Step the engine core.
            executed = self._process_engine_step()
            self._maybe_publish_request_counts()

            local_unfinished_reqs = self.scheduler.has_unfinished_requests()
            if not executed:
                if not local_unfinished_reqs and not self.engines_running:
                    # All engines are idle.
                    continue

                # We are in a running state and so must execute a dummy pass
                # if the model didn't execute any ready requests.
                self.execute_dummy_batch()


def run_engine_core(*args, dp_rank: int = 0, local_dp_rank: int = 0, **kwargs):
    """Launch EngineCore busy loop in background process."""

    # Signal handler used for graceful termination.
    # SystemExit exception is only raised once to allow this and worker
    # processes to terminate without error
    shutdown_requested = False

    # Ensure we can serialize transformer config after spawning
    maybe_register_config_serialize_by_value()

    def signal_handler(signum, frame):
        nonlocal shutdown_requested
        if not shutdown_requested:
            shutdown_requested = True
            raise SystemExit()

    # Either SIGTERM or SIGINT will terminate the engine_core
    signal.signal(signal.SIGTERM, signal_handler)
    signal.signal(signal.SIGINT, signal_handler)

    engine_core: Optional[EngineCoreProc] = None
    try:
        parallel_config: ParallelConfig = kwargs["vllm_config"].parallel_config
        if parallel_config.data_parallel_size > 1 or dp_rank > 0:
            # Set data parallel rank for this engine process.
            parallel_config.data_parallel_rank = dp_rank
            parallel_config.data_parallel_rank_local = local_dp_rank
            engine_core = ExternealDPEngineCoreProc(*args, **kwargs)
        else:
            engine_core = EngineCoreProc(*args, **kwargs)

        engine_core.run_busy_loop()

    except SystemExit:
        logger.debug("EngineCore exiting.")
        raise
    except Exception as e:
        if engine_core is None:
            logger.exception("EngineCore failed to start.")
        else:
            logger.exception("EngineCore encountered a fatal error.")
            engine_core._send_engine_dead()
        raise e
    finally:
        if engine_core is not None:
            engine_core.shutdown()


# Apply this patch only if the external data parallelism is enabled
if vllm_ascend_envs.VLLM_ASCEND_EXTERNAL_DP_LB_ENABLED:
    # Patch EngineCoreProc to use the custom run_engine_core
    EngineCoreProc.run_engine_core = run_engine_core  # type: ignore[attr-defined]
Lines changed: 26 additions & 0 deletions
@@ -0,0 +1,26 @@
from typing import Optional

from vllm.config import VllmConfig
from vllm.v1.engine.core_client import (AsyncMPClient, EngineCoreClient,
                                        MPClient)
from vllm.v1.executor.abstract import Executor

import vllm_ascend.envs as vllm_ascend_envs


def make_async_mp_client(
    vllm_config: VllmConfig,
    executor_class: type[Executor],
    log_stats: bool,
    client_addresses: Optional[dict[str, str]] = None,
    client_index: int = 0,
) -> "MPClient":
    # Use only AsyncMPClient here for the dp scenario and rely on nginx for dp request routing
    return AsyncMPClient(vllm_config, executor_class, log_stats,
                         client_addresses, client_index)


# Apply this patch only if the external data parallelism is enabled
if vllm_ascend_envs.VLLM_ASCEND_EXTERNAL_DP_LB_ENABLED:
    # Patch the EngineCoreClient to use the custom make_async_mp_client
    EngineCoreClient.make_async_mp_client = make_async_mp_client  # type: ignore[attr-defined]
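Because each dp rank is now a fully independent `vllm serve` endpoint, request routing and load balancing happen outside the engine; the comment above points to nginx for that role. The snippet below is only an illustration of the routing idea (the proxy configuration itself is not part of this diff, and the endpoint addresses and ports are placeholders taken from the example launcher):

```python
# Minimal stand-in for the external proxy: round-robin over the 16 local engine ports.
import itertools

ENGINE_ENDPOINTS = [f"http://127.0.0.1:{9000 + i}/v1/completions" for i in range(16)]
_round_robin = itertools.cycle(ENGINE_ENDPOINTS)


def pick_endpoint() -> str:
    """Return the next engine endpoint in round-robin order."""
    return next(_round_robin)


if __name__ == "__main__":
    # A real proxy would forward each incoming request to pick_endpoint();
    # here we only show the rotation.
    for _ in range(4):
        print(pick_endpoint())
```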
