10
10
from multiprocessing import Process , connection
11
11
from multiprocessing .process import BaseProcess
12
12
from typing import TYPE_CHECKING , Callable , Optional , Union
13
+ from unittest .mock import patch
13
14
14
15
import msgspec
15
16
import zmq
16
17
17
18
from vllm .config import CacheConfig , ParallelConfig , VllmConfig
18
19
from vllm .logger import init_logger
20
+ from vllm .platforms import current_platform
19
21
from vllm .ray .ray_env import get_env_vars_to_copy
20
22
from vllm .utils import get_mp_context , get_open_zmq_ipc_path , zmq_socket_ctx
21
23
from vllm .v1 .engine .coordinator import DPCoordinator
@@ -105,10 +107,13 @@ def __init__(
105
107
"client_handshake_address" ] = client_handshake_address
106
108
107
109
self .processes : list [BaseProcess ] = []
110
+ local_dp_ranks = []
108
111
for index in range (local_engine_count ):
109
112
local_index = local_start_index + index
110
113
global_index = start_index + index
114
+
111
115
# Start EngineCore in background process.
116
+ local_dp_ranks .append (local_index )
112
117
self .processes .append (
113
118
context .Process (target = target_fn ,
114
119
name = f"EngineCore_{ global_index } " ,
@@ -118,9 +123,14 @@ def __init__(
118
123
}))
119
124
120
125
self ._finalizer = weakref .finalize (self , shutdown , self .processes )
126
+
127
+ data_parallel = vllm_config .parallel_config .data_parallel_size > 1
121
128
try :
122
- for proc in self .processes :
123
- proc .start ()
129
+ for proc , local_dp_rank in zip (self .processes , local_dp_ranks ):
130
+ with set_device_control_env_var (
131
+ vllm_config , local_dp_rank ) if (
132
+ data_parallel ) else contextlib .nullcontext ():
133
+ proc .start ()
124
134
finally :
125
135
# Kill other procs if not all are running.
126
136
if self .finished_procs ():
@@ -145,6 +155,30 @@ def finished_procs(self) -> dict[str, int]:
145
155
}
146
156
147
157
158
@contextlib.contextmanager
def set_device_control_env_var(vllm_config: VllmConfig,
                               local_dp_rank: int) -> Iterator[None]:
    """
    Temporarily set CUDA_VISIBLE_DEVICES or equivalent
    for engine subprocess.

    While the ``with`` body runs, the environment variable named by
    ``current_platform.device_control_env_var`` holds the comma-joined
    physical device ids for the logical device indices
    ``[local_dp_rank * world_size, (local_dp_rank + 1) * world_size)``.

    Args:
        vllm_config: Config whose ``parallel_config.world_size`` gives the
            number of devices in each rank's slice.
        local_dp_rank: Index of the slice of devices to expose.

    Raises:
        RuntimeError: If the platform lookup raises ``IndexError`` for one
            of the logical device indices in the slice. The message
            reports the requested range and the variable's current value.
    """
    world_size = vllm_config.parallel_config.world_size
    evar = current_platform.device_control_env_var

    # Half-open range of logical device indices for this rank; computed
    # once so the lookup and the error message cannot drift apart.
    first_device = local_dp_rank * world_size
    end_device = first_device + world_size

    try:
        value = ",".join(
            str(current_platform.device_id_to_physical_device_id(i))
            for i in range(first_device, end_device))
    except IndexError as e:
        # RuntimeError is a subclass of Exception, so existing
        # `except Exception` handlers keep working.
        raise RuntimeError(f"Error setting {evar}: "
                           f"local range: [{first_device}, {end_device}) "
                           "base value: "
                           f"\"{os.getenv(evar)}\"") from e

    # patch.dict puts os.environ back to its prior contents when the
    # context exits, including when the body raises.
    with patch.dict(os.environ, values=((evar, value), )):
        yield
180
+
181
+
148
182
class CoreEngineActorManager :
149
183
"""
150
184
Utility class to handle creation, readiness, and shutdown
@@ -215,10 +249,9 @@ def __init__(
215
249
216
250
self .placement_group_is_local = []
217
251
refs = []
218
- for index in range (dp_size ):
219
- local_index = local_dp_ranks [ index ]
252
+ for index , local_index , pg in zip ( range (dp_size ), local_dp_ranks ,
253
+ placement_groups ):
220
254
dp_vllm_config = copy .deepcopy (vllm_config )
221
- pg = placement_groups [index ]
222
255
dp_vllm_config .parallel_config .placement_group = pg
223
256
local_client = index < local_engine_count
224
257
actor = ray .remote (DPEngineCoreActor ).options (
@@ -264,7 +297,6 @@ def create_dp_placement_groups(
264
297
local_engine_count = \
265
298
vllm_config .parallel_config .data_parallel_size_local
266
299
267
- nodes = list_nodes ()
268
300
nodes = sorted (list_nodes (),
269
301
key = lambda node : node .node_ip != dp_master_ip )
270
302
assert nodes [0 ].node_ip == dp_master_ip , (
0 commit comments