2
2
import os
3
3
import pickle
4
4
from collections import defaultdict
5
+ from itertools import islice , repeat
5
6
from typing import TYPE_CHECKING , Any , Dict , List , Optional , Set , Tuple
6
7
7
8
from vllm .engine .ray_utils import RayWorkerWrapper , ray
@@ -136,16 +137,14 @@ def _init_workers_ray(self, placement_group: "PlacementGroup",
136
137
VLLM_INSTANCE_ID = get_vllm_instance_id ()
137
138
138
139
# Set environment variables for the driver and workers.
139
- all_args_to_update_environment_variables = []
140
- for (node_id , _ ) in worker_node_and_gpu_ids :
141
- all_args_to_update_environment_variables .append ([{
142
- "CUDA_VISIBLE_DEVICES" :
143
- "," .join (map (str , node_gpus [node_id ])),
144
- "VLLM_INSTANCE_ID" :
145
- VLLM_INSTANCE_ID ,
146
- "VLLM_TRACE_FUNCTION" :
147
- os .getenv ("VLLM_TRACE_FUNCTION" , "0" ),
148
- }])
140
+ all_args_to_update_environment_variables = [({
141
+ "CUDA_VISIBLE_DEVICES" :
142
+ "," .join (map (str , node_gpus [node_id ])),
143
+ "VLLM_INSTANCE_ID" :
144
+ VLLM_INSTANCE_ID ,
145
+ "VLLM_TRACE_FUNCTION" :
146
+ os .getenv ("VLLM_TRACE_FUNCTION" , "0" ),
147
+ }, ) for (node_id , _ ) in worker_node_and_gpu_ids ]
149
148
self ._run_workers ("update_environment_variables" ,
150
149
all_args = all_args_to_update_environment_variables )
151
150
@@ -156,10 +155,9 @@ def collect_arg_helper_func(**kwargs):
156
155
# avoid writing `{"name": value}` manually
157
156
return kwargs
158
157
159
- init_worker_all_kwargs = []
160
-
161
158
# Initialize the actual workers inside worker wrapper.
162
- for rank , (node_id , _ ) in enumerate (worker_node_and_gpu_ids , ):
159
+ init_worker_all_kwargs = []
160
+ for rank , (node_id , _ ) in enumerate (worker_node_and_gpu_ids ):
163
161
local_rank = node_workers [node_id ].index (rank )
164
162
init_worker_all_kwargs .append (
165
163
collect_arg_helper_func (
@@ -265,40 +263,40 @@ def _run_workers(
265
263
self ,
266
264
method : str ,
267
265
* args ,
268
- driver_args : Optional [Tuple [Any ]] = None ,
266
+ driver_args : Optional [Tuple [Any , ... ]] = None ,
269
267
driver_kwargs : Optional [Dict [str , Any ]] = None ,
270
- all_args : Optional [List [List [Any ]]] = None ,
268
+ all_args : Optional [List [Tuple [Any , ... ]]] = None ,
271
269
all_kwargs : Optional [List [Dict [str , Any ]]] = None ,
272
270
use_dummy_driver : bool = False ,
273
271
max_concurrent_workers : Optional [int ] = None ,
274
272
use_ray_compiled_dag : bool = False ,
275
273
** kwargs ,
276
274
) -> Any :
277
- """Runs the given method on all workers.
278
- all_args and all_kwargs are used to pass heterogeneous arguments,
279
- i.e. different arguments for each worker.
275
+ """Runs the given method on all workers. Can be used in the following
276
+ ways:
277
+
278
+ - args/kwargs: All workers share the same args/kwargs
279
+ - args/kwargs and driver_args/driver_kwargs: Driver worker has
280
+ different args
281
+ - all_args/all_kwargs: args/kwargs for each worker are specified
282
+ individually
280
283
"""
281
- if driver_args is None :
282
- driver_args = args
283
- if driver_kwargs is None :
284
- driver_kwargs = kwargs
285
-
286
- # for mypy type checking
287
- assert driver_args is not None
288
- assert driver_kwargs is not None
289
- if all_args is None :
290
- all_args = [driver_args ] + [args ] * len (self .workers )
291
- if all_kwargs is None :
292
- all_kwargs = [driver_kwargs ] + [kwargs ] * len (self .workers )
293
-
294
- # for mypy type checking
295
- assert all_args is not None
296
- assert all_kwargs is not None
297
284
298
285
if max_concurrent_workers :
299
286
raise NotImplementedError (
300
287
"max_concurrent_workers is not supported yet." )
301
288
289
+ if driver_args is None :
290
+ driver_args = args if all_args is None else all_args [0 ]
291
+ if driver_kwargs is None :
292
+ driver_kwargs = kwargs if all_kwargs is None else all_kwargs [0 ]
293
+
294
+ count = len (self .workers )
295
+ all_worker_args = repeat (args , count ) if all_args is None \
296
+ else islice (all_args , 1 , None )
297
+ all_worker_kwargs = repeat (kwargs , count ) if all_kwargs is None \
298
+ else islice (all_kwargs , 1 , None )
299
+
302
300
if use_ray_compiled_dag :
303
301
# Right now, compiled DAG can only accept a single
304
302
# input. TODO(sang): Fix it.
@@ -310,22 +308,17 @@ def _run_workers(
310
308
worker .execute_method .remote (method , * worker_args ,
311
309
** worker_kwargs )
312
310
for (worker , worker_args , worker_kwargs
313
- ) in zip (self .workers , all_args [ 1 :], all_kwargs [ 1 :] )
311
+ ) in zip (self .workers , all_worker_args , all_worker_kwargs )
314
312
]
315
313
316
- if driver_args is None :
317
- driver_args = args
318
- if driver_kwargs is None :
319
- driver_kwargs = kwargs
320
-
321
314
# Start the driver worker after all the ray workers.
322
315
if not use_dummy_driver :
323
316
driver_worker_output = self .driver_worker .execute_method (
324
- method , * all_args [ 0 ] , ** all_kwargs [ 0 ] )
317
+ method , * driver_args , ** driver_kwargs )
325
318
else :
326
319
driver_worker_output = ray .get (
327
320
self .driver_dummy_worker .execute_method .remote (
328
- method , * all_args [ 0 ] , ** all_kwargs [ 0 ] ))
321
+ method , * driver_args , ** driver_kwargs ))
329
322
# Get the results of the ray workers.
330
323
if self .workers :
331
324
if use_ray_compiled_dag :
@@ -383,6 +376,10 @@ def _check_if_any_actor_is_dead(self):
383
376
384
377
class RayGPUExecutorAsync (RayGPUExecutor , ExecutorAsyncBase ):
385
378
379
+ def __init__ (self , * args , ** kwargs ):
380
+ super ().__init__ (* args , ** kwargs )
381
+ self .driver_executor = make_async (self .driver_worker .execute_method )
382
+
386
383
async def _run_workers_async (
387
384
self ,
388
385
method : str ,
@@ -399,13 +396,8 @@ async def _run_workers_async(
399
396
if driver_kwargs is None :
400
397
driver_kwargs = kwargs
401
398
402
- # Run the driver worker asynchronously.
403
- def helper ():
404
- return self .driver_worker .execute_method (method , * driver_args ,
405
- ** driver_kwargs )
406
-
407
- driver_executor = make_async (helper )
408
- coros .append (driver_executor ())
399
+ coros .append (
400
+ self .driver_executor (method , * driver_args , ** driver_kwargs ))
409
401
410
402
# Run the ray workers asynchronously.
411
403
for worker in self .workers :
0 commit comments