[XPU] Add xpu torch.compile support #22609

Open · wants to merge 2 commits into main
1 change: 1 addition & 0 deletions .buildkite/scripts/hardware_ci/run-xpu-test.sh
@@ -27,6 +27,7 @@ docker run \
 "${image_name}" \
 sh -c '
 VLLM_USE_V1=1 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager
+VLLM_USE_V1=1 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager -O3
 VLLM_USE_V1=1 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager -tp 2 --distributed-executor-backend ray
 VLLM_USE_V1=1 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager -tp 2 --distributed-executor-backend mp
 cd tests
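For context, the new -O3 flag exercises vLLM compilation level 3 (torch.compile) on XPU. A rough Python-API equivalent of that invocation is sketched below as an editor's illustration, not part of this PR; it assumes the LLM constructor accepts an integer compilation_config, as in recent vLLM releases.

import os
os.environ["VLLM_USE_V1"] = "1"  # match the CI script's V1 engine selection

from vllm import LLM, SamplingParams

# compilation_config=3 mirrors the CLI's -O3 (piecewise torch.compile).
llm = LLM(model="facebook/opt-125m", block_size=64,
          enforce_eager=True, compilation_config=3)
outputs = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=16))
print(outputs[0].outputs[0].text)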
2 changes: 1 addition & 1 deletion vllm/attention/layer.py
@@ -182,7 +182,7 @@ def __init__(
         # opaque custom op. For other platforms, we directly call them
         # and let torch.compile handle them.
         self.use_direct_call = not current_platform.is_cuda_alike(
-        ) and not current_platform.is_cpu()
+        ) and not current_platform.is_cpu() and not current_platform.is_xpu()

         self.use_output = self.attn_backend.accept_output_buffer
         compilation_config = get_current_vllm_config().compilation_config
3 changes: 2 additions & 1 deletion vllm/compilation/backends.py
@@ -413,8 +413,9 @@ def __init__(
         # them, e.g. backbone (default), eagle_head, etc.
         self.prefix = prefix or model_tag

+        # XPU does not support graph currently.
         global global_graph_pool
-        if global_graph_pool is None:
+        if global_graph_pool is None and not current_platform.is_xpu():
             global_graph_pool = current_platform.graph_pool_handle()

         # TODO: in the future, if we want to use multiple
5 changes: 5 additions & 0 deletions vllm/compilation/fix_functionalization.py
@@ -9,6 +9,7 @@
 from torch._higher_order_ops.auto_functionalize import auto_functionalized

 from vllm.logger import init_logger
+from vllm.platforms import current_platform

 from .fx_utils import is_func
 from .vllm_inductor_pass import VllmInductorPass
@@ -32,6 +33,10 @@ def __call__(self, graph: torch.fx.Graph):
         self.nodes_to_remove: list[torch.fx.Node] = []
         count = 0
         for node in graph.nodes:
+            # XPU does not support auto-functionalization yet.
+            # Will enable this when switch to vllm-xpu-kernels.
+            if current_platform.is_xpu():
+                continue
             if not is_func(node, auto_functionalized):
                 continue  # Avoid deep if-elif nesting
4 changes: 4 additions & 0 deletions vllm/platforms/xpu.py
@@ -78,6 +78,10 @@ def get_device_total_memory(cls, device_id: int = 0) -> int:
     def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool:
         return True

+    @classmethod
+    def get_piecewise_backend_cls(cls) -> str:
+        return "vllm.compilation.cuda_piecewise_backend.CUDAPiecewiseBackend"  # noqa
+
Contributor comment on lines +81 to +83 (critical):
Using CUDAPiecewiseBackend for XPU is problematic as it contains CUDA-specific code (e.g., torch.cuda.CUDAGraph) that will fail on XPU platforms.

The PR description mentions that XPU does not support graph mode yet, which suggests that graph capture should be disabled. However, compilation_config.use_cudagraph is enabled by default for the V1 engine and is not disabled for the XPU platform. This will cause CUDAPiecewiseBackend to attempt CUDA graph capture, leading to a runtime error.

To fix this, you should disable CUDA graph capture for the XPU platform within torch.compile. A possible fix is to add vllm_config.compilation_config.use_cudagraph = False to the XPUPlatform.check_and_update_config method. Alternatively, you could create a new XPUPiecewiseBackend that does not contain CUDA-specific graph capture logic and use it here.


     @classmethod
     def inference_mode(cls):
         return torch.no_grad()
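A minimal sketch of the reviewer's first suggestion above, assuming XPUPlatform.check_and_update_config keeps the usual Platform signature in vllm/platforms/xpu.py; this is an editor's illustration, not part of this PR's diff.

# Sketch only: would live inside XPUPlatform in vllm/platforms/xpu.py.
@classmethod
def check_and_update_config(cls, vllm_config: "VllmConfig") -> None:
    compilation_config = vllm_config.compilation_config
    # XPU cannot capture graphs yet, so keep piecewise torch.compile
    # active but never attempt CUDA graph capture.
    compilation_config.use_cudagraph = False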