
Commit 11ef7a6

Authored by yinghai, njhill, and ruisearch42
[BugFix] Set CUDA_VISIBLE_DEVICES before spawning the subprocesses (#21211)
Signed-off-by: Yinghai Lu <[email protected]>
Signed-off-by: Nick Hill <[email protected]>
Signed-off-by: Rui Qiao <[email protected]>
Co-authored-by: Nick Hill <[email protected]>
Co-authored-by: Rui Qiao <[email protected]>
1 parent dc2f159 commit 11ef7a6

2 files changed: +69 -26 lines changed

vllm/v1/engine/core.py

Lines changed: 31 additions & 20 deletions
@@ -910,22 +910,6 @@ def _init_data_parallel(self, vllm_config: VllmConfig):
             logger.debug("Setting kv_transfer_config.engine_id to %s",
                          vllm_config.kv_transfer_config.engine_id)

-        from vllm.platforms import current_platform
-        device_control_env_var = current_platform.device_control_env_var
-        world_size = vllm_config.parallel_config.world_size
-        # Set CUDA_VISIBLE_DEVICES or equivalent.
-        try:
-            os.environ[device_control_env_var] = ",".join(
-                str(current_platform.device_id_to_physical_device_id(i))
-                for i in range(local_dp_rank *
-                               world_size, (local_dp_rank + 1) * world_size))
-        except IndexError as e:
-            raise Exception(
-                f"Error setting {device_control_env_var}: "
-                f"local range: [{local_dp_rank * world_size}, "
-                f"{(local_dp_rank + 1) * world_size}) "
-                f"base value: \"{os.getenv(device_control_env_var)}\"") from e
-
         self.dp_rank = dp_rank
         self.dp_group = vllm_config.parallel_config.stateless_init_dp_group()

@@ -1088,14 +1072,41 @@ def __init__(
         vllm_config.parallel_config.data_parallel_rank_local = \
             local_dp_rank

-        # Ray sets CUDA_VISIBLE_DEVICES to empty string,
-        # we clean this up to be able to properly initialize
-        # data parallel groups.
-        del os.environ['CUDA_VISIBLE_DEVICES']
+        # Set CUDA_VISIBLE_DEVICES as early as possible in actor life cycle
+        # NOTE: in MP we set CUDA_VISIBLE_DEVICES at process creation time,
+        # and this cannot be done in the same way for Ray because:
+        # 1) Ray manages life cycle of all ray workers (including
+        #    DPEngineCoreActor)
+        # 2) Ray sets CUDA_VISIBLE_DEVICES based on num_gpus configuration
+        # To bypass 2, we need to also set
+        # RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES, but vLLM workers created
+        # thereafter would have CUDA_VISIBLE_DEVICES set, which is sticky:
+        # https://github.com/ray-project/ray/blob/e752fc319ddedd9779a0989b6d3613909bad75c9/python/ray/_private/worker.py#L456 # noqa: E501
+        # But vLLM worker assumes visibility into all local GPUs, therefore
+        # this results in incorrect indexing into the GPU ID list.
+        self._set_cuda_visible_devices(vllm_config, local_dp_rank)

         super().__init__(vllm_config, local_client, "", executor_class,
                          log_stats)

+    def _set_cuda_visible_devices(self, vllm_config: VllmConfig,
+                                  local_dp_rank: int):
+        from vllm.platforms import current_platform
+        device_control_env_var = current_platform.device_control_env_var
+        world_size = vllm_config.parallel_config.world_size
+        # Set CUDA_VISIBLE_DEVICES or equivalent.
+        try:
+            os.environ[device_control_env_var] = ",".join(
+                str(current_platform.device_id_to_physical_device_id(i))
+                for i in range(local_dp_rank *
+                               world_size, (local_dp_rank + 1) * world_size))
+        except IndexError as e:
+            raise Exception(
+                f"Error setting {device_control_env_var}: "
+                f"local range: [{local_dp_rank * world_size}, "
+                f"{(local_dp_rank + 1) * world_size}) "
+                f"base value: \"{os.getenv(device_control_env_var)}\"") from e
+
     def _decorate_logs(self):
         pass
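
For reference, the device selection that moved from _init_data_parallel into _set_cuda_visible_devices works as follows: the engine with local DP rank r and per-engine world size w owns logical device indices [r*w, (r+1)*w), and each logical index is remapped through any CUDA_VISIBLE_DEVICES already present in the parent. Below is a minimal, self-contained sketch of that arithmetic; the logical_to_physical helper is an illustrative stand-in for what current_platform.device_id_to_physical_device_id is assumed to do, not the vLLM API itself.

import os


def logical_to_physical(i: int) -> int:
    # Illustrative stand-in (assumption): logical device i maps to the i-th
    # entry of any pre-set CUDA_VISIBLE_DEVICES list, or to i when unset.
    base = os.getenv("CUDA_VISIBLE_DEVICES")
    if not base:
        return i
    ids = [int(x) for x in base.split(",")]
    return ids[i]  # an IndexError here is what the diff wraps in a clearer message


def device_range(local_dp_rank: int, world_size: int) -> str:
    # The engine with local DP rank r and world size w owns logical devices
    # [r * w, (r + 1) * w), translated to physical IDs.
    start = local_dp_rank * world_size
    return ",".join(
        str(logical_to_physical(i)) for i in range(start, start + world_size))


# Example: with CUDA_VISIBLE_DEVICES="4,5,6,7" and world_size=2,
# local DP rank 0 gets "4,5" and local DP rank 1 gets "6,7".
os.environ["CUDA_VISIBLE_DEVICES"] = "4,5,6,7"
print(device_range(0, 2))  # -> 4,5
print(device_range(1, 2))  # -> 6,7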

vllm/v1/engine/utils.py

Lines changed: 38 additions & 6 deletions
@@ -10,12 +10,14 @@
 from multiprocessing import Process, connection
 from multiprocessing.process import BaseProcess
 from typing import TYPE_CHECKING, Callable, Optional, Union
+from unittest.mock import patch

 import msgspec
 import zmq

 from vllm.config import CacheConfig, ParallelConfig, VllmConfig
 from vllm.logger import init_logger
+from vllm.platforms import current_platform
 from vllm.ray.ray_env import get_env_vars_to_copy
 from vllm.utils import get_mp_context, get_open_zmq_ipc_path, zmq_socket_ctx
 from vllm.v1.engine.coordinator import DPCoordinator

@@ -105,10 +107,13 @@ def __init__(
                 "client_handshake_address"] = client_handshake_address

         self.processes: list[BaseProcess] = []
+        local_dp_ranks = []
         for index in range(local_engine_count):
             local_index = local_start_index + index
             global_index = start_index + index
+
             # Start EngineCore in background process.
+            local_dp_ranks.append(local_index)
             self.processes.append(
                 context.Process(target=target_fn,
                                 name=f"EngineCore_{global_index}",

@@ -118,9 +123,14 @@ def __init__(
                                 }))

         self._finalizer = weakref.finalize(self, shutdown, self.processes)
+
+        data_parallel = vllm_config.parallel_config.data_parallel_size > 1
         try:
-            for proc in self.processes:
-                proc.start()
+            for proc, local_dp_rank in zip(self.processes, local_dp_ranks):
+                with set_device_control_env_var(
+                        vllm_config, local_dp_rank) if (
+                            data_parallel) else contextlib.nullcontext():
+                    proc.start()
         finally:
             # Kill other procs if not all are running.
             if self.finished_procs():

@@ -145,6 +155,30 @@ def finished_procs(self) -> dict[str, int]:
         }


+@contextlib.contextmanager
+def set_device_control_env_var(vllm_config: VllmConfig,
+                               local_dp_rank: int) -> Iterator[None]:
+    """
+    Temporarily set CUDA_VISIBLE_DEVICES or equivalent
+    for engine subprocess.
+    """
+    world_size = vllm_config.parallel_config.world_size
+    evar = current_platform.device_control_env_var
+    try:
+        value = ",".join(
+            str(current_platform.device_id_to_physical_device_id(i))
+            for i in range(local_dp_rank * world_size, (local_dp_rank + 1) *
+                           world_size))
+    except IndexError as e:
+        raise Exception(f"Error setting {evar}: "
+                        f"local range: [{local_dp_rank * world_size}, "
+                        f"{(local_dp_rank + 1) * world_size}) "
+                        "base value: "
+                        f"\"{os.getenv(evar)}\"") from e
+    with patch.dict(os.environ, values=((evar, value), )):
+        yield
+
+
 class CoreEngineActorManager:
     """
     Utility class to handle creation, readiness, and shutdown

@@ -215,10 +249,9 @@ def __init__(

         self.placement_group_is_local = []
         refs = []
-        for index in range(dp_size):
-            local_index = local_dp_ranks[index]
+        for index, local_index, pg in zip(range(dp_size), local_dp_ranks,
+                                          placement_groups):
             dp_vllm_config = copy.deepcopy(vllm_config)
-            pg = placement_groups[index]
             dp_vllm_config.parallel_config.placement_group = pg
             local_client = index < local_engine_count
             actor = ray.remote(DPEngineCoreActor).options(

@@ -264,7 +297,6 @@ def create_dp_placement_groups(
         local_engine_count = \
             vllm_config.parallel_config.data_parallel_size_local

-        nodes = list_nodes()
         nodes = sorted(list_nodes(),
                        key=lambda node: node.node_ip != dp_master_ip)
         assert nodes[0].node_ip == dp_master_ip, (
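
In the MP path above, set_device_control_env_var only patches the launcher's environment for the duration of each proc.start() call; the child sees the per-rank value because a newly started process inherits the parent's environment as it is at start time, and patch.dict restores the launcher's original value afterwards. A standalone sketch of that pattern, outside vLLM, with illustrative device strings:

import os
from multiprocessing import get_context
from unittest.mock import patch


def report(rank: int) -> None:
    # Runs in the child process; the value seen here was inherited at start() time.
    print(f"rank {rank}: CUDA_VISIBLE_DEVICES={os.getenv('CUDA_VISIBLE_DEVICES')}")


if __name__ == "__main__":
    ctx = get_context("spawn")
    procs = []
    for rank, devices in enumerate(["0,1", "2,3"]):
        # Temporarily set the variable only while this child is being started.
        with patch.dict(os.environ, {"CUDA_VISIBLE_DEVICES": devices}):
            p = ctx.Process(target=report, args=(rank,))
            p.start()
        procs.append(p)
    for p in procs:
        p.join()
    # The launcher's own environment is left unchanged (None here if it was unset).
    print("launcher:", os.getenv("CUDA_VISIBLE_DEVICES"))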
