
Fix: Ray placement group allocation is not respecting env VLLM_RAY_PER_WORKER_GPUS (fractional GPU) #22577


Open · wants to merge 4 commits into main
12 changes: 8 additions & 4 deletions vllm/executor/ray_utils.py
@@ -8,6 +8,7 @@

 import msgspec

+import vllm.envs as envs
 import vllm.platforms
 from vllm.config import ParallelConfig
 from vllm.executor.msgspec_utils import decode_hook, encode_hook
@@ -338,6 +339,7 @@
     else:
         logger.info("No current placement group found. "
                     "Creating a new placement group.")
+        device_resource_request = envs.VLLM_RAY_PER_WORKER_GPUS
Contributor (severity: high)

Using a fractional VLLM_RAY_PER_WORKER_GPUS when world_size > 1 can lead to multiple workers from the same tensor-parallel group being scheduled on the same GPU. This is generally not supported and can cause failures.

While you've noted this is for a specific use case, this change could unintentionally affect users with multi-GPU setups. The previous implementation requested a full GPU (1.0) for each worker in the placement group, which prevented this co-location scenario. This PR changes that behavior.

To mitigate this risk for other users, please add a warning when world_size > 1 and a fractional GPU value is used, for example:

        device_resource_request = envs.VLLM_RAY_PER_WORKER_GPUS
        if parallel_config.world_size > 1 and device_resource_request < 1.0:
            logger.warning(
                "VLLM_RAY_PER_WORKER_GPUS is set to %f, which is less than 1.0. "
                "When using multi-GPU inference (world_size > 1), this can "
                "cause multiple workers to be placed on the same GPU, which "
                "is not supported and may lead to unexpected behavior or "
                "failures. Please ensure that each worker is placed on a "
                "separate GPU.", device_resource_request)

         num_devices_in_cluster = ray.cluster_resources().get(device_str, 0)
         # Log a warning message and delay resource allocation failure response.
         # Avoid immediate rejection to allow user-initiated placement group
@@ -349,7 +351,8 @@
                 device_str)
         # Create a new placement group
         placement_group_specs: List[Dict[str, float]] = ([{
-            device_str: 1.0
+            device_str:
+            device_resource_request
         } for _ in range(parallel_config.world_size)])

         # vLLM engine is also a worker to execute model with an accelerator,
@@ -358,12 +361,13 @@
         current_ip = get_ip()
         current_node_id = ray.get_runtime_context().get_node_id()
         current_node_resource = available_resources_per_node()[current_node_id]
-        if current_node_resource.get(device_str, 0) < 1:
+        if current_node_resource.get(device_str, 0) < device_resource_request:
             raise ValueError(
                 f"Current node has no {device_str} available. "
                 f"{current_node_resource=}. vLLM engine cannot start without "
-                f"{device_str}. Make sure you have at least 1 {device_str} "
-                f"available in a node {current_node_id=} {current_ip=}.")
+                f"{device_str}. Make sure you have at least {device_resource_request} "
+                f"{device_str} available in a node {current_node_id=} {current_ip=}."
+            )
         # This way, at least bundle is required to be created in a current
         # node.
         placement_group_specs[0][f"node:{current_ip}"] = 0.001

Check failure: GitHub Actions / pre-commit, Ruff (E501): vllm/executor/ray_utils.py:368:81: Line too long (87 > 80)
Check failure: GitHub Actions / pre-commit, Ruff (E501): vllm/executor/ray_utils.py:369:81: Line too long (85 > 80)
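
For users who intentionally want fractional allocation (e.g. packing several small, single-GPU engines onto one device), a hypothetical usage sketch; only VLLM_RAY_PER_WORKER_GPUS itself comes from this diff, and the model name, memory fraction, and other keyword values are illustrative:

# Hypothetical usage sketch: ask Ray for half a GPU per vLLM worker. The env
# var must be set before vLLM reads it; all other values are illustrative.
import os

os.environ["VLLM_RAY_PER_WORKER_GPUS"] = "0.5"

from vllm import LLM

llm = LLM(
    model="facebook/opt-125m",           # small model that fits in half a GPU
    tensor_parallel_size=1,              # world_size == 1 avoids co-location
    distributed_executor_backend="ray",  # use the Ray executor path patched here
    gpu_memory_utilization=0.4,          # keep the KV cache within the GPU share
)
print(llm.generate("Hello, my name is")[0].outputs[0].text)

Whether two such engines can actually share one GPU also depends on available GPU memory, so treat this purely as a sketch.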