Commit 8f2ea22

[Core] Some simplification of WorkerWrapper changes (#4183)
1 parent 0ae11f7 commit 8f2ea22

File tree: 2 files changed (+45, -54 lines)


vllm/executor/ray_gpu_executor.py

Lines changed: 41 additions & 49 deletions
@@ -2,6 +2,7 @@
 import os
 import pickle
 from collections import defaultdict
+from itertools import islice, repeat
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple

 from vllm.engine.ray_utils import RayWorkerWrapper, ray
@@ -136,16 +137,14 @@ def _init_workers_ray(self, placement_group: "PlacementGroup",
         VLLM_INSTANCE_ID = get_vllm_instance_id()

         # Set environment variables for the driver and workers.
-        all_args_to_update_environment_variables = []
-        for (node_id, _) in worker_node_and_gpu_ids:
-            all_args_to_update_environment_variables.append([{
-                "CUDA_VISIBLE_DEVICES":
-                ",".join(map(str, node_gpus[node_id])),
-                "VLLM_INSTANCE_ID":
-                VLLM_INSTANCE_ID,
-                "VLLM_TRACE_FUNCTION":
-                os.getenv("VLLM_TRACE_FUNCTION", "0"),
-            }])
+        all_args_to_update_environment_variables = [({
+            "CUDA_VISIBLE_DEVICES":
+            ",".join(map(str, node_gpus[node_id])),
+            "VLLM_INSTANCE_ID":
+            VLLM_INSTANCE_ID,
+            "VLLM_TRACE_FUNCTION":
+            os.getenv("VLLM_TRACE_FUNCTION", "0"),
+        }, ) for (node_id, _) in worker_node_and_gpu_ids]
         self._run_workers("update_environment_variables",
                           all_args=all_args_to_update_environment_variables)

@@ -156,10 +155,9 @@ def collect_arg_helper_func(**kwargs):
             # avoid writing `{"name": value}` manually
             return kwargs

-        init_worker_all_kwargs = []
-
         # Initialize the actual workers inside worker wrapper.
-        for rank, (node_id, _) in enumerate(worker_node_and_gpu_ids, ):
+        init_worker_all_kwargs = []
+        for rank, (node_id, _) in enumerate(worker_node_and_gpu_ids):
             local_rank = node_workers[node_id].index(rank)
             init_worker_all_kwargs.append(
                 collect_arg_helper_func(
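
Note: as a minimal sketch with hypothetical values (not taken from a real run), the environment-variable comprehension above produces one single-element tuple per worker, each wrapping that worker's env dict, which is exactly the shape _run_workers expects for all_args:

    # Sketch with hypothetical values: one single-element tuple per worker.
    all_args_to_update_environment_variables = [
        ({
            "CUDA_VISIBLE_DEVICES": "0,1",      # hypothetical GPU ids on node A
            "VLLM_INSTANCE_ID": "instance-0",   # hypothetical instance id
            "VLLM_TRACE_FUNCTION": "0",
        }, ),
        ({
            "CUDA_VISIBLE_DEVICES": "2,3",      # hypothetical GPU ids on node B
            "VLLM_INSTANCE_ID": "instance-0",
            "VLLM_TRACE_FUNCTION": "0",
        }, ),
    ]
    # _run_workers("update_environment_variables", all_args=...) then splats
    # each tuple, calling update_environment_variables(envs) once per worker.
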
@@ -265,40 +263,40 @@ def _run_workers(
         self,
         method: str,
         *args,
-        driver_args: Optional[Tuple[Any]] = None,
+        driver_args: Optional[Tuple[Any, ...]] = None,
         driver_kwargs: Optional[Dict[str, Any]] = None,
-        all_args: Optional[List[List[Any]]] = None,
+        all_args: Optional[List[Tuple[Any, ...]]] = None,
         all_kwargs: Optional[List[Dict[str, Any]]] = None,
         use_dummy_driver: bool = False,
         max_concurrent_workers: Optional[int] = None,
         use_ray_compiled_dag: bool = False,
         **kwargs,
     ) -> Any:
-        """Runs the given method on all workers.
-        all_args and all_kwargs are used to pass heterogeneous arguments,
-        i.e. different arguments for each worker.
+        """Runs the given method on all workers. Can be used in the following
+        ways:
+
+        - args/kwargs: All workers share the same args/kwargs
+        - args/kwargs and driver_args/driver_kwargs: Driver worker has
+          different args
+        - all_args/all_kwargs: args/kwargs for each worker are specified
+          individually
         """
-        if driver_args is None:
-            driver_args = args
-        if driver_kwargs is None:
-            driver_kwargs = kwargs
-
-        # for mypy type checking
-        assert driver_args is not None
-        assert driver_kwargs is not None
-        if all_args is None:
-            all_args = [driver_args] + [args] * len(self.workers)
-        if all_kwargs is None:
-            all_kwargs = [driver_kwargs] + [kwargs] * len(self.workers)
-
-        # for mypy type checking
-        assert all_args is not None
-        assert all_kwargs is not None

         if max_concurrent_workers:
             raise NotImplementedError(
                 "max_concurrent_workers is not supported yet.")

+        if driver_args is None:
+            driver_args = args if all_args is None else all_args[0]
+        if driver_kwargs is None:
+            driver_kwargs = kwargs if all_kwargs is None else all_kwargs[0]
+
+        count = len(self.workers)
+        all_worker_args = repeat(args, count) if all_args is None \
+            else islice(all_args, 1, None)
+        all_worker_kwargs = repeat(kwargs, count) if all_kwargs is None \
+            else islice(all_kwargs, 1, None)
+
         if use_ray_compiled_dag:
             # Right now, compiled DAG can only accept a single
             # input. TODO(sang): Fix it.
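
Note: a standalone sketch (illustrative values only, not vLLM calls) of how the new argument handling behaves in the two cases the docstring describes: shared args are fanned out to every ray worker via repeat(), while per-worker args reserve index 0 for the driver and hand the remainder to the ray workers via islice(..., 1, None).

    from itertools import islice, repeat

    count = 3  # pretend there are 3 ray workers besides the driver

    # Case 1: shared args -- every ray worker receives the same tuple.
    args = ("x",)
    print(list(repeat(args, count)))        # [('x',), ('x',), ('x',)]

    # Case 2: per-worker args -- entry 0 is the driver's, the rest go to workers.
    all_args = [("driver",), ("a",), ("b",), ("c",)]
    driver_args = all_args[0]               # ('driver',)
    print(list(islice(all_args, 1, None)))  # [('a',), ('b',), ('c',)]

This avoids materializing the intermediate [driver_args] + [args] * len(self.workers) lists that the removed code built on every call.
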
@@ -310,22 +308,17 @@ def _run_workers(
                 worker.execute_method.remote(method, *worker_args,
                                              **worker_kwargs)
                 for (worker, worker_args, worker_kwargs
-                     ) in zip(self.workers, all_args[1:], all_kwargs[1:])
+                     ) in zip(self.workers, all_worker_args, all_worker_kwargs)
             ]

-        if driver_args is None:
-            driver_args = args
-        if driver_kwargs is None:
-            driver_kwargs = kwargs
-
         # Start the driver worker after all the ray workers.
         if not use_dummy_driver:
             driver_worker_output = self.driver_worker.execute_method(
-                method, *all_args[0], **all_kwargs[0])
+                method, *driver_args, **driver_kwargs)
         else:
             driver_worker_output = ray.get(
                 self.driver_dummy_worker.execute_method.remote(
-                    method, *all_args[0], **all_kwargs[0]))
+                    method, *driver_args, **driver_kwargs))
         # Get the results of the ray workers.
         if self.workers:
             if use_ray_compiled_dag:
@@ -383,6 +376,10 @@ def _check_if_any_actor_is_dead(self):

 class RayGPUExecutorAsync(RayGPUExecutor, ExecutorAsyncBase):

+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.driver_executor = make_async(self.driver_worker.execute_method)
+
     async def _run_workers_async(
         self,
         method: str,
@@ -399,13 +396,8 @@ async def _run_workers_async(
         if driver_kwargs is None:
             driver_kwargs = kwargs

-        # Run the driver worker asynchronously.
-        def helper():
-            return self.driver_worker.execute_method(method, *driver_args,
-                                                      **driver_kwargs)
-
-        driver_executor = make_async(helper)
-        coros.append(driver_executor())
+        coros.append(
+            self.driver_executor(method, *driver_args, **driver_kwargs))

         # Run the ray workers asynchronously.
         for worker in self.workers:
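
Note: the async executor now wraps self.driver_worker.execute_method once in __init__ instead of building a fresh helper closure and re-wrapping it on every _run_workers_async call. A rough sketch of what a make_async-style helper could look like is below; this is an assumption for illustration only, and the actual vllm.utils.make_async may be implemented differently.

    import asyncio
    from functools import partial
    from typing import Callable

    def make_async_sketch(func: Callable) -> Callable:
        # Wrap a blocking callable so each call runs in the event loop's
        # default thread pool and returns an awaitable (sketch only).
        def _async_wrapper(*args, **kwargs):
            loop = asyncio.get_event_loop()
            return loop.run_in_executor(None, partial(func, *args, **kwargs))
        return _async_wrapper
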

vllm/worker/worker_base.py

Lines changed: 4 additions & 5 deletions
@@ -108,7 +108,8 @@ def __init__(self,
         self.worker_class_name = worker_class_name
         self.worker = None

-    def update_environment_variables(self, envs: Dict[str, str]) -> None:
+    @staticmethod
+    def update_environment_variables(envs: Dict[str, str]) -> None:
         key = 'CUDA_VISIBLE_DEVICES'
         if key in envs and key in os.environ:
             # overwriting CUDA_VISIBLE_DEVICES is desired behavior
@@ -138,10 +139,8 @@ def init_worker(self, *args, **kwargs):

     def execute_method(self, method, *args, **kwargs):
         try:
-            if hasattr(self, method):
-                executor = getattr(self, method)
-            else:
-                executor = getattr(self.worker, method)
+            target = self if self.worker is None else self.worker
+            executor = getattr(target, method)
             return executor(*args, **kwargs)
         except Exception as e:
             # if the driver worker also execute methods,
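
Note: sketched in isolation with simplified stand-in classes (not the actual vLLM types), the new dispatch rule reads: while self.worker is None, method names resolve on the wrapper itself (e.g. init_worker, update_environment_variables); once the real worker exists, every call is forwarded to it, so the wrapper no longer shadows worker methods of the same name.

    # Simplified stand-ins, for illustration only.
    class _FakeWorker:
        def execute_model(self):
            return "ran on worker"

    class _WrapperSketch:
        def __init__(self):
            self.worker = None

        def init_worker(self):
            self.worker = _FakeWorker()

        def execute_method(self, method, *args, **kwargs):
            # Before init_worker: resolve on the wrapper. After: forward to worker.
            target = self if self.worker is None else self.worker
            return getattr(target, method)(*args, **kwargs)

    w = _WrapperSketch()
    w.execute_method("init_worker")           # resolves on the wrapper (worker is None)
    print(w.execute_method("execute_model"))  # forwarded to the worker: "ran on worker"
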
