Commit a491d6f

[V1] TP Ray executor (#11107)
Signed-off-by: Rui Qiao <[email protected]>
1 parent 32aa205 commit a491d6f

5 files changed: +617 additions, -3 deletions

tests/basic_correctness/test_basic_correctness.py

Lines changed: 1 addition & 1 deletion
@@ -130,7 +130,7 @@ def test_models_distributed(
     # Import VLLM_USE_V1 dynamically to handle patching
     from vllm.envs import VLLM_USE_V1
     if VLLM_USE_V1 and distributed_executor_backend != "mp":
-        pytest.skip(f"Skip {distributed_executor_backend} for V1")
+        os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0"

     dtype = "half"
     max_tokens = 5
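
For context, a minimal sketch of the gating this test now performs under V1 (the backend value and environment variable come from the diff; the standalone helper framing is hypothetical):

import os

def configure_v1_backend_for_test(distributed_executor_backend: str,
                                  use_v1: bool) -> None:
    # Mirrors the updated test: instead of skipping non-"mp" backends under
    # V1, run them with V1 multiprocessing disabled so the new Ray executor
    # path is exercised in-process.
    if use_v1 and distributed_executor_backend != "mp":
        os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0"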

vllm/v1/engine/llm_engine.py

Lines changed: 6 additions & 1 deletion
@@ -21,6 +21,7 @@
 from vllm.v1.engine.detokenizer import Detokenizer
 from vllm.v1.engine.processor import Processor
 from vllm.v1.executor.abstract import Executor
+from vllm.v1.executor.ray_utils import initialize_ray_cluster

 logger = init_logger(__name__)

@@ -110,7 +111,11 @@ def _get_executor_cls(cls, vllm_config: VllmConfig) -> Type[Executor]:
         executor_class: Type[Executor]
         distributed_executor_backend = (
             vllm_config.parallel_config.distributed_executor_backend)
-        if distributed_executor_backend == "mp":
+        if distributed_executor_backend == "ray":
+            initialize_ray_cluster(vllm_config.parallel_config)
+            from vllm.v1.executor.ray_executor import RayExecutor
+            executor_class = RayExecutor
+        elif distributed_executor_backend == "mp":
             from vllm.v1.executor.multiproc_executor import MultiprocExecutor
             executor_class = MultiprocExecutor
         else:
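
With the "ray" branch above in place, the Ray backend is reachable from the public API when the V1 engine is enabled. A minimal usage sketch, assuming at least two GPUs and that VLLM_USE_V1=1 is set before vllm is imported (the model name is only a placeholder):

import os
os.environ["VLLM_USE_V1"] = "1"  # opt into the V1 engine before importing vllm

from vllm import LLM, SamplingParams

llm = LLM(
    model="facebook/opt-125m",            # placeholder model
    tensor_parallel_size=2,               # two TP ranks -> two Ray workers
    distributed_executor_backend="ray",   # routes to the new RayExecutor
)
outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(max_tokens=5))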

vllm/v1/executor/ray_executor.py

Lines changed: 339 additions & 0 deletions
@@ -0,0 +1,339 @@
import os
from collections import defaultdict
from itertools import islice, repeat
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple

import vllm.envs as envs
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.utils import get_distributed_init_method, get_ip, get_open_port
from vllm.v1.executor.abstract import Executor
from vllm.v1.executor.ray_utils import RayWorkerWrapper, ray
from vllm.v1.outputs import ModelRunnerOutput

if ray is not None:
    from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy

if TYPE_CHECKING:
    from ray.util.placement_group import PlacementGroup

logger = init_logger(__name__)


class RayExecutor(Executor):

    def __init__(self, vllm_config: VllmConfig) -> None:
        self.vllm_config = vllm_config
        self.parallel_config = vllm_config.parallel_config
        self.model_config = vllm_config.model_config
        self.forward_dag: Optional[ray.dag.CompiledDAG] = None

        # Disable Ray usage stats collection.
        ray_usage = os.environ.get("RAY_USAGE_STATS_ENABLED", "0")
        if ray_usage != "1":
            os.environ["RAY_USAGE_STATS_ENABLED"] = "0"

        placement_group = self.parallel_config.placement_group
        # Create the parallel GPU workers.
        self._init_workers_ray(placement_group)

    def _init_workers_ray(self, placement_group: "PlacementGroup",
                          **ray_remote_kwargs):
        # A list of workers to run a model.
        self.workers: List[RayWorkerWrapper] = []
        if self.parallel_config.ray_workers_use_nsight:
            ray_remote_kwargs = self._configure_ray_workers_use_nsight(
                ray_remote_kwargs)

        # Create the workers.
        driver_ip = get_ip()
        for bundle_id, bundle in enumerate(placement_group.bundle_specs):
            if not bundle.get("GPU", 0):
                # Skip bundles that don't have GPUs,
                # as each worker needs one GPU.
                continue
            scheduling_strategy = PlacementGroupSchedulingStrategy(
                placement_group=placement_group,
                placement_group_capture_child_tasks=True,
                placement_group_bundle_index=bundle_id,
            )

            worker = ray.remote(
                num_cpus=0,
                num_gpus=1,
                scheduling_strategy=scheduling_strategy,
                **ray_remote_kwargs,
            )(RayWorkerWrapper).remote(vllm_config=self.vllm_config)
            self.workers.append(worker)

        logger.debug("workers: %s", self.workers)
        worker_ips = [
            ray.get(worker.get_node_ip.remote())  # type: ignore[attr-defined]
            for worker in self.workers
        ]
        ip_counts: Dict[str, int] = {}
        for ip in worker_ips:
            ip_counts[ip] = ip_counts.get(ip, 0) + 1

        worker_to_ip = dict(zip(self.workers, worker_ips))

        def sort_by_driver_then_worker_ip(worker):
            """
            Sort the workers based on 3 properties:
            1. If the worker is on the same node as the driver (vllm engine),
                it should be placed first.
            2. Then, if the worker is on a node with fewer workers, it should
                be placed first.
            3. Finally, if the work is on a node with smaller IP address, it
                should be placed first. This is simply a tiebreaker to make
                sure the workers are sorted in a deterministic way.
            """
            ip = worker_to_ip[worker]
            return (ip != driver_ip, ip_counts[ip], ip)

        # After sorting, the workers on the same node will be
        # close to each other, and the workers on the driver
        # node will be placed first.
        self.workers = sorted(self.workers, key=sort_by_driver_then_worker_ip)

        # Get the set of GPU IDs used on each node.
        worker_node_and_gpu_ids = self._run_workers("get_node_and_gpu_ids")

        node_workers = defaultdict(list)  # node id -> list of worker ranks
        node_gpus = defaultdict(list)  # node id -> list of gpu ids

        for i, (node_id, gpu_ids) in enumerate(worker_node_and_gpu_ids):
            node_workers[node_id].append(i)
            # `gpu_ids` can be a list of strings or integers.
            # convert them to integers for consistency.
            # NOTE: gpu_ids can be larger than 9 (e.g. 16 GPUs),
            # string sorting is not sufficient.
            # see https://github.com/vllm-project/vllm/issues/5590
            gpu_ids = [int(x) for x in gpu_ids]
            node_gpus[node_id].extend(gpu_ids)

        for node_id, gpu_ids in node_gpus.items():
            node_gpus[node_id] = sorted(gpu_ids)

        all_ips = set(worker_ips)
        n_ips = len(all_ips)
        n_nodes = len(node_workers)

        if n_nodes != n_ips:
            raise RuntimeError(
                f"Every node should have a unique IP address. Got {n_nodes}"
                f" nodes with node ids {list(node_workers.keys())} and "
                f"{n_ips} unique IP addresses {all_ips}. Please check your"
                " network configuration. If you set `VLLM_HOST_IP` or "
                "`HOST_IP` environment variable, make sure it is unique for"
                " each node.")

        # Set environment variables for the driver and workers.
        all_args_to_update_environment_variables = [({
            "CUDA_VISIBLE_DEVICES":
            ",".join(map(str, node_gpus[node_id])),
            "VLLM_TRACE_FUNCTION":
            str(envs.VLLM_TRACE_FUNCTION),
            "VLLM_USE_V1":
            str(int(envs.VLLM_USE_V1)),
            **({
                "VLLM_ATTENTION_BACKEND": envs.VLLM_ATTENTION_BACKEND
            } if envs.VLLM_ATTENTION_BACKEND is not None else {})
        }, ) for (node_id, _) in worker_node_and_gpu_ids]

        self._env_vars_for_all_workers = (
            all_args_to_update_environment_variables)

        self._run_workers("update_environment_variables",
                          all_args=self._get_env_vars_to_be_updated())

        if len(node_gpus) == 1:
            # in single node case, we don't need to get the IP address.
            # the loopback address is sufficient
            # NOTE: a node may have several IP addresses, one for each
            # network interface. `get_ip()` might return any of them,
            # while they might not work for communication inside the node
            # if the network setup is complicated. Using the loopback address
            # solves this issue, as it always works for communication inside
            # the node.
            driver_ip = "127.0.0.1"
        distributed_init_method = get_distributed_init_method(
            driver_ip, get_open_port())

        # Initialize the actual workers inside worker wrapper.
        init_worker_all_kwargs = [
            self._get_worker_kwargs(
                local_rank=node_workers[node_id].index(rank),
                rank=rank,
                distributed_init_method=distributed_init_method,
            ) for rank, (node_id, _) in enumerate(worker_node_and_gpu_ids)
        ]
        self._run_workers("init_worker", all_kwargs=init_worker_all_kwargs)
        self._run_workers("initialize")
        self._run_workers("load_model")

    def _configure_ray_workers_use_nsight(self,
                                          ray_remote_kwargs) -> Dict[str, Any]:
        # If nsight profiling is enabled, we need to set the profiling
        # configuration for the ray workers as runtime env.
        runtime_env = ray_remote_kwargs.setdefault("runtime_env", {})
        runtime_env.update({
            "nsight": {
                "t": "cuda,cudnn,cublas",
                "o": "'worker_process_%p'",
                "cuda-graph-trace": "node",
            }
        })

        return ray_remote_kwargs

    def _get_env_vars_to_be_updated(self):
        return self._env_vars_for_all_workers

    def _get_worker_kwargs(
            self,
            local_rank: int = 0,
            rank: int = 0,
            distributed_init_method: Optional[str] = None) -> Dict[str, Any]:
        """
        Return worker init args for a given rank.
        """
        if distributed_init_method is None:
            distributed_init_method = get_distributed_init_method(
                get_ip(), get_open_port())
        return dict(
            vllm_config=self.vllm_config,
            local_rank=local_rank,
            rank=rank,
            distributed_init_method=distributed_init_method,
        )

    def determine_num_available_blocks(self) -> Tuple[int, int]:
        """
        Determine the number of available KV blocks.

        This invokes `determine_num_available_blocks` on each worker and takes
        the min of the results, guaranteeing that the selected cache sizes are
        compatible with all workers.

        Returns:
            - tuple[num_gpu_blocks, num_cpu_blocks]
        """
        # Get the maximum number of blocks that can be allocated on GPU and CPU.
        num_blocks = self._run_workers("determine_num_available_blocks")

        # Since we use a shared centralized controller, we take the minimum
        # number of blocks across all workers to make sure all the memory
        # operators can be applied to all workers.
        num_gpu_blocks = min(b[0] for b in num_blocks)
        num_cpu_blocks = min(b[1] for b in num_blocks)

        return num_gpu_blocks, num_cpu_blocks

    def initialize(self, num_gpu_blocks: int) -> None:
        """
        Initialize the KV cache in all workers.
        """
        # NOTE: This is logged in the executor because there can be >1 worker
        # with other executors. We could log in the engine level, but work
        # remains to abstract away the device for non-GPU configurations.
        logger.info("# GPU blocks: %d", num_gpu_blocks)
        self._run_workers("initialize_cache", num_gpu_blocks)
        self._run_workers("compile_or_warm_up_model")

    def _run_workers(
        self,
        method: str,
        *args,
        all_args: Optional[List[Tuple[Any, ...]]] = None,
        all_kwargs: Optional[List[Dict[str, Any]]] = None,
        **kwargs,
    ) -> Any:
        """
        Runs the given method on all workers. Can be used in the following
        ways:

        Args:
        - args/kwargs: All workers share the same args/kwargs
        - all_args/all_kwargs: args/kwargs for each worker are specified
          individually
        """
        count = len(self.workers)
        all_worker_args = repeat(args, count) if all_args is None \
            else islice(all_args, 0, None)
        all_worker_kwargs = repeat(kwargs, count) if all_kwargs is None \
            else islice(all_kwargs, 0, None)

        ray_worker_refs = [
            worker.execute_method.remote(  # type: ignore[attr-defined]
                method, *worker_args, **worker_kwargs)
            for (worker, worker_args, worker_kwargs
                 ) in zip(self.workers, all_worker_args, all_worker_kwargs)
        ]
        return ray.get(ray_worker_refs)

    def execute_model(
        self,
        scheduler_output,
    ) -> ModelRunnerOutput:
        if self.forward_dag is None:
            self.forward_dag = self._compiled_ray_dag()
        # Only the first worker (with rank 0) returns the execution result.
        # Others return None.
        output = ray.get(self.forward_dag.execute(scheduler_output))[0]
        return output

    def profile(self, is_start=True):
        raise NotImplementedError

    def shutdown(self):
        if hasattr(self, "forward_dag") and self.forward_dag is not None:
            self.forward_dag.teardown()
            import ray
            for worker in self.workers:
                ray.kill(worker)
            self.forward_dag = None

    def check_health(self) -> None:
        logger.debug("Called check_health.")

    def _check_ray_compiled_graph_installation(self):
        import pkg_resources
        from packaging import version

        required_version = version.parse("2.39")
        current_version = version.parse(
            pkg_resources.get_distribution("ray").version)
        if current_version < required_version:
            raise ValueError(f"Ray version {required_version} is "
                             f"required, but found {current_version}")

        import importlib.util
        raycg = importlib.util.find_spec("ray.experimental.compiled_dag_ref")
        if raycg is None:
            raise ValueError("Ray Compiled Graph is not installed. "
                             "Run `pip install ray[adag]` to install it.")

        cupy_spec = importlib.util.find_spec("cupy")
        if cupy_spec is None and envs.VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL:
            raise ValueError(
                "cupy is not installed but required since "
                "VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL is set."
                "Run `pip install ray[adag]` and check cupy installation.")

    def _compiled_ray_dag(self):
        assert self.parallel_config.use_ray
        self._check_ray_compiled_graph_installation()
        from ray.dag import InputNode, MultiOutputNode

        with InputNode() as input_batches:
            outputs = [
                worker.execute_model.bind(  # type: ignore[attr-defined]
                    input_batches) for worker in self.workers
            ]
            forward_dag = MultiOutputNode(outputs)

        return forward_dag.experimental_compile()

    def __del__(self):
        self.shutdown()
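
To make the worker ordering rule in sort_by_driver_then_worker_ip concrete, here is a small self-contained illustration of the same three-part key applied to made-up IPs (all values are illustrative, not taken from the commit):

# Fake topology: driver on 10.0.0.1, two workers sharing 10.0.0.2.
driver_ip = "10.0.0.1"
worker_ips = ["10.0.0.2", "10.0.0.1", "10.0.0.2", "10.0.0.3"]

ip_counts = {}
for ip in worker_ips:
    ip_counts[ip] = ip_counts.get(ip, 0) + 1

# Same key as the executor: driver node first, then nodes with fewer
# workers, then the IP string as a deterministic tiebreaker.
ordered = sorted(worker_ips,
                 key=lambda ip: (ip != driver_ip, ip_counts[ip], ip))
print(ordered)  # ['10.0.0.1', '10.0.0.3', '10.0.0.2', '10.0.0.2']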

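The _run_workers helper above supports two call styles: broadcast (all workers share the same args/kwargs) and per-worker dispatch via all_args/all_kwargs. A stripped-down, Ray-free sketch of the same repeat/islice pairing, with stub worker names standing in for Ray actor handles:

from itertools import islice, repeat
from typing import Any, Dict, List, Optional, Tuple


def pair_worker_calls(workers: List[str],
                      args: Tuple[Any, ...] = (),
                      kwargs: Optional[Dict[str, Any]] = None,
                      all_args: Optional[List[Tuple[Any, ...]]] = None,
                      all_kwargs: Optional[List[Dict[str, Any]]] = None):
    # Broadcast the shared args/kwargs unless per-worker lists are supplied,
    # mirroring _run_workers' pairing logic.
    kwargs = kwargs or {}
    count = len(workers)
    worker_args = repeat(args, count) if all_args is None \
        else islice(all_args, 0, None)
    worker_kwargs = repeat(kwargs, count) if all_kwargs is None \
        else islice(all_kwargs, 0, None)
    return list(zip(workers, worker_args, worker_kwargs))


# Broadcast: every worker would receive the same positional args.
print(pair_worker_calls(["w0", "w1"], args=("load_model",)))
# Per-worker: each worker gets its own kwargs, e.g. its init arguments.
print(pair_worker_calls(["w0", "w1"], all_kwargs=[{"rank": 0}, {"rank": 1}]))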