Commit dcceebc

add xpu torch.compile support
Signed-off-by: Kunshang Ji <[email protected]>
1 parent: b029de9

File tree: 3 files changed (+10, -1 lines)

vllm/attention/layer.py (1 addition, 1 deletion)

```diff
@@ -190,7 +190,7 @@ def __init__(
         # opaque custom op. For other platforms, we directly call them
         # and let torch.compile handle them.
         self.use_direct_call = not current_platform.is_cuda_alike(
-        ) and not current_platform.is_cpu()
+        ) and not current_platform.is_cpu() and not current_platform.is_xpu()

         self.use_output = self.attn_backend.accept_output_buffer
         compilation_config = get_current_vllm_config().compilation_config
```
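This one-line change routes XPU through the same opaque custom-op path that CUDA-alike platforms use, so torch.compile sees attention as a single node instead of tracing into backend-specific kernels. A minimal sketch of the resulting gate, with a hypothetical Platform class standing in for vllm.platforms.current_platform:

```python
# Illustrative sketch, not vLLM's actual code.
class Platform:
    """Hypothetical stand-in for vllm.platforms.current_platform."""
    def is_cuda_alike(self) -> bool: return False
    def is_cpu(self) -> bool: return False
    def is_xpu(self) -> bool: return True  # pretend we are on XPU

current_platform = Platform()

# After this commit, XPU no longer takes the direct-call path: attention
# is wrapped in a custom op that torch.compile treats as opaque.
use_direct_call = (not current_platform.is_cuda_alike()
                   and not current_platform.is_cpu()
                   and not current_platform.is_xpu())

print(use_direct_call)  # False on XPU: the custom-op path is used
```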

vllm/compilation/fix_functionalization.py (5 additions, 0 deletions)

```diff
@@ -9,6 +9,7 @@
 from torch._higher_order_ops.auto_functionalize import auto_functionalized

 from vllm.logger import init_logger
+from vllm.platforms import current_platform

 from .fx_utils import is_func
 from .vllm_inductor_pass import VllmInductorPass
@@ -32,6 +33,10 @@ def __call__(self, graph: torch.fx.Graph):
         self.nodes_to_remove: list[torch.fx.Node] = []
         count = 0
         for node in graph.nodes:
+            # XPU does not support auto-functionalization yet.
+            # Will enable this when switching to vllm-xpu-kernels.
+            if current_platform.is_xpu():
+                continue
             if not is_func(node, auto_functionalized):
                 continue  # Avoid deep if-elif nesting

```
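With this hunk the pass becomes a no-op on XPU: the platform check fires for every node before any auto_functionalized rewriting happens (hoisting the check above the loop would behave the same). A self-contained sketch of the pattern, with IS_XPU as an illustrative stand-in for current_platform.is_xpu():

```python
import torch
import torch.fx as fx

IS_XPU = True  # illustrative stand-in for current_platform.is_xpu()

def fix_functionalization(graph: fx.Graph) -> int:
    """Count the nodes a real pass would rewrite; skip everything on XPU."""
    count = 0
    for node in graph.nodes:
        if IS_XPU:
            # Same effect as the commit's per-node `continue`:
            # the pass visits nodes but rewrites none of them.
            continue
        if node.op == "call_function":
            count += 1  # placeholder for the actual defunctionalization
    return count

def f(x: torch.Tensor) -> torch.Tensor:
    return torch.relu(x) + 1

gm = fx.symbolic_trace(f)
print(fix_functionalization(gm.graph))  # 0 on XPU: the pass is disabled
```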

vllm/platforms/xpu.py (4 additions, 0 deletions)

```diff
@@ -78,6 +78,10 @@ def get_device_total_memory(cls, device_id: int = 0) -> int:
     def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool:
         return True

+    @classmethod
+    def get_piecewise_backend_cls(cls) -> str:
+        return "vllm.compilation.cuda_piecewise_backend.CUDAPiecewiseBackend"  # noqa
+
     @classmethod
     def inference_mode(cls):
         return torch.no_grad()
```
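The new classmethod lets XPU reuse the CUDA piecewise compilation backend by handing back its fully qualified class name; the compilation layer resolves that string to a class at runtime. A sketch of how such a dotted path is typically resolved (resolve_by_qualname is an assumed helper name here, demonstrated against the stdlib):

```python
import importlib

def resolve_by_qualname(qualname: str):
    """Import 'pkg.module.ClassName' and return the class object."""
    module_name, _, attr = qualname.rpartition(".")
    return getattr(importlib.import_module(module_name), attr)

# e.g. backend_cls = resolve_by_qualname(
#          current_platform.get_piecewise_backend_cls())
decoder_cls = resolve_by_qualname("json.JSONDecoder")  # stdlib demo
print(decoder_cls)  # <class 'json.decoder.JSONDecoder'>
```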
