1010from multiprocessing import Process , connection
1111from multiprocessing .process import BaseProcess
1212from typing import TYPE_CHECKING , Callable , Optional , Union
13+ from unittest .mock import patch
1314
1415import msgspec
1516import zmq
1617
1718from vllm .config import CacheConfig , ParallelConfig , VllmConfig
1819from vllm .logger import init_logger
20+ from vllm .platforms import current_platform
1921from vllm .ray .ray_env import get_env_vars_to_copy
2022from vllm .utils import get_mp_context , get_open_zmq_ipc_path , zmq_socket_ctx
2123from vllm .v1 .engine .coordinator import DPCoordinator
@@ -105,10 +107,13 @@ def __init__(
105107 "client_handshake_address" ] = client_handshake_address
106108
107109 self .processes : list [BaseProcess ] = []
110+ local_dp_ranks = []
108111 for index in range (local_engine_count ):
109112 local_index = local_start_index + index
110113 global_index = start_index + index
114+
111115 # Start EngineCore in background process.
116+ local_dp_ranks .append (local_index )
112117 self .processes .append (
113118 context .Process (target = target_fn ,
114119 name = f"EngineCore_{ global_index } " ,
@@ -118,9 +123,14 @@ def __init__(
118123 }))
119124
120125 self ._finalizer = weakref .finalize (self , shutdown , self .processes )
126+
127+ data_parallel = vllm_config .parallel_config .data_parallel_size > 1
121128 try :
122- for proc in self .processes :
123- proc .start ()
129+ for proc , local_dp_rank in zip (self .processes , local_dp_ranks ):
130+ with set_device_control_env_var (
131+ vllm_config , local_dp_rank ) if (
132+ data_parallel ) else contextlib .nullcontext ():
133+ proc .start ()
124134 finally :
125135 # Kill other procs if not all are running.
126136 if self .finished_procs ():
@@ -145,6 +155,30 @@ def finished_procs(self) -> dict[str, int]:
145155 }
146156
147157
158+ @contextlib .contextmanager
159+ def set_device_control_env_var (vllm_config : VllmConfig ,
160+ local_dp_rank : int ) -> Iterator [None ]:
161+ """
162+ Temporarily set CUDA_VISIBLE_DEVICES or equivalent
163+ for engine subprocess.
164+ """
165+ world_size = vllm_config .parallel_config .world_size
166+ evar = current_platform .device_control_env_var
167+ try :
168+ value = "," .join (
169+ str (current_platform .device_id_to_physical_device_id (i ))
170+ for i in range (local_dp_rank * world_size , (local_dp_rank + 1 ) *
171+ world_size ))
172+ except IndexError as e :
173+ raise Exception (f"Error setting { evar } : "
174+ f"local range: [{ local_dp_rank * world_size } , "
175+ f"{ (local_dp_rank + 1 ) * world_size } ) "
176+ "base value: "
177+ f"\" { os .getenv (evar )} \" " ) from e
178+ with patch .dict (os .environ , values = ((evar , value ), )):
179+ yield
180+
181+
148182class CoreEngineActorManager :
149183 """
150184 Utility class to handle creation, readiness, and shutdown
@@ -215,10 +249,9 @@ def __init__(
215249
216250 self .placement_group_is_local = []
217251 refs = []
218- for index in range (dp_size ):
219- local_index = local_dp_ranks [ index ]
252+ for index , local_index , pg in zip ( range (dp_size ), local_dp_ranks ,
253+ placement_groups ):
220254 dp_vllm_config = copy .deepcopy (vllm_config )
221- pg = placement_groups [index ]
222255 dp_vllm_config .parallel_config .placement_group = pg
223256 local_client = index < local_engine_count
224257 actor = ray .remote (DPEngineCoreActor ).options (
@@ -264,7 +297,6 @@ def create_dp_placement_groups(
264297 local_engine_count = \
265298 vllm_config .parallel_config .data_parallel_size_local
266299
267- nodes = list_nodes ()
268300 nodes = sorted (list_nodes (),
269301 key = lambda node : node .node_ip != dp_master_ip )
270302 assert nodes [0 ].node_ip == dp_master_ip , (
0 commit comments