Commit fce10db

[XPU] Add xpu torch.compile support (#22609)
Signed-off-by: Kunshang Ji <[email protected]>
1 parent d272415 commit fce10db

File tree: 8 files changed (+36 -11 lines)

.buildkite/scripts/hardware_ci/run-xpu-test.sh

Lines changed: 1 addition & 0 deletions
@@ -31,6 +31,7 @@ docker run \
   set -e
   echo $ZE_AFFINITY_MASK
   VLLM_USE_V1=1 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager
+  VLLM_USE_V1=1 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 -O3 -O.cudagraph_mode=NONE
   VLLM_USE_V1=1 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager -tp 2 --distributed-executor-backend ray
   VLLM_USE_V1=1 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager -tp 2 --distributed-executor-backend mp
   cd tests
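
For readers using the offline Python API, the new CI invocation above roughly corresponds to the following sketch (not part of the commit): torch.compile at optimization level 3 with CUDA graphs disabled, which is what the XPU platform now requires. It assumes the LLM constructor accepts a CompilationConfig object via `compilation_config` and forwards `block_size` to the engine args, as in recent vLLM releases; the CI script additionally sets VLLM_USE_V1=1 in the environment.

# Sketch only: rough Python equivalent of
#   VLLM_USE_V1=1 python3 ... --block-size 64 -O3 -O.cudagraph_mode=NONE
from vllm import LLM, SamplingParams
from vllm.config import CompilationConfig, CUDAGraphMode

llm = LLM(
    model="facebook/opt-125m",
    block_size=64,
    compilation_config=CompilationConfig(
        level=3,                            # same effect as the -O3 CLI flag
        cudagraph_mode=CUDAGraphMode.NONE,  # same as -O.cudagraph_mode=NONE
    ),
)
outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(max_tokens=16))
print(outputs[0].outputs[0].text)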

vllm/attention/layer.py

Lines changed: 1 addition & 2 deletions
@@ -190,8 +190,7 @@ def __init__(
         # torch.compile works by registering the attention as one giant
         # opaque custom op. For other platforms, we directly call them
         # and let torch.compile handle them.
-        self.use_direct_call = not current_platform.is_cuda_alike(
-        ) and not current_platform.is_cpu()
+        self.use_direct_call = not current_platform.opaque_attention_op()
 
         self.use_output = self.attn_backend.accept_output_buffer
         compilation_config = get_current_vllm_config().compilation_config

vllm/compilation/fix_functionalization.py

Lines changed: 8 additions & 0 deletions
@@ -9,6 +9,7 @@
 from torch._higher_order_ops.auto_functionalize import auto_functionalized
 
 from vllm.logger import init_logger
+from vllm.platforms import current_platform
 
 from .fx_utils import is_func
 from .vllm_inductor_pass import VllmInductorPass
@@ -26,6 +27,13 @@ class FixFunctionalizationPass(VllmInductorPass):
     """
 
     def __call__(self, graph: torch.fx.Graph):
+        # XPU does not support auto-functionalization yet.
+        # Will enable this when switch to vllm-xpu-kernels.
+        if current_platform.is_xpu():
+            logger.debug("XPU platform does not support fix functionalization"
+                         "pass currently.")
+            return
+
         self.begin()
         self.dump_graph(graph, "before_fix_functionalization")

vllm/platforms/cpu.py

Lines changed: 4 additions & 0 deletions
@@ -335,3 +335,7 @@ def default_v1(cls, model_config) -> bool:
         return (cls.supports_v1(model_config)
                 and arch in (CpuArchEnum.X86, CpuArchEnum.POWERPC,
                              CpuArchEnum.ARM, CpuArchEnum.S390X))
+
+    @classmethod
+    def opaque_attention_op(cls) -> bool:
+        return True

vllm/platforms/cuda.py

Lines changed: 4 additions & 0 deletions
@@ -442,6 +442,10 @@ def supports_v1(cls, model_config: "ModelConfig") -> bool:
     def use_custom_allreduce(cls) -> bool:
         return True
 
+    @classmethod
+    def opaque_attention_op(cls) -> bool:
+        return True
+
     @classmethod
     def get_static_graph_wrapper_cls(cls) -> str:
         return "vllm.compilation.cuda_graph.CUDAGraphWrapper"

vllm/platforms/interface.py

Lines changed: 8 additions & 0 deletions
@@ -509,6 +509,14 @@ def use_custom_allreduce(cls) -> bool:
         """
         return False
 
+    @classmethod
+    def opaque_attention_op(cls) -> bool:
+        """
+        Returns True if we register attention as one giant opaque custom op
+        on the current platform
+        """
+        return False
+
     @classmethod
     def validate_request(
         cls,
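
This base hook is what out-of-tree platforms would override to opt in to the opaque-attention path; vllm/attention/layer.py above now derives `use_direct_call` from it. A minimal sketch, assuming only the `Platform` base class from vllm/platforms/interface.py; `MyAcceleratorPlatform` is a hypothetical plugin, not part of the commit:

from vllm.platforms.interface import Platform


class MyAcceleratorPlatform(Platform):  # hypothetical out-of-tree platform
    @classmethod
    def opaque_attention_op(cls) -> bool:
        # Returning True makes the attention layer register attention as one
        # opaque custom op for torch.compile instead of calling the backend
        # directly: use_direct_call = not current_platform.opaque_attention_op()
        return True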

vllm/platforms/rocm.py

Lines changed: 4 additions & 0 deletions
@@ -411,6 +411,10 @@ def use_custom_allreduce(cls) -> bool:
         supported_archs = ['gfx94', 'gfx95']
         return any(gfx in gcn_arch for gfx in supported_archs)
 
+    @classmethod
+    def opaque_attention_op(cls) -> bool:
+        return True
+
     @classmethod
     def get_cu_count(cls, device_id: int = 0) -> int:
         return torch.cuda.get_device_properties(

vllm/platforms/xpu.py

Lines changed: 6 additions & 9 deletions
@@ -90,21 +90,14 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
         if cache_config and cache_config.block_size is None:
             cache_config.block_size = 64
 
-        # FIXME: Temporarily forcing eager mode
-        # remove after t.compile support stabilizes.
-        if (envs.VLLM_USE_V1 and model_config is not None
-                and not vllm_config.model_config.enforce_eager):
-            from vllm.config import CompilationLevel
-            vllm_config.compilation_config.level = CompilationLevel.NO_COMPILATION  # noqa: E501
-
         # lazy import to avoid circular import
         from vllm.config import CUDAGraphMode
         compilation_config = vllm_config.compilation_config
         if compilation_config.cudagraph_mode is None or \
             compilation_config.cudagraph_mode.max_cudagraph_mode() \
                 != CUDAGraphMode.NONE:
-            logger.info("[XPU] CUDA graph is not supported on XPU, "
-                        "disabling cudagraphs.")
+            logger.info("[XPU] CUDA graph is not supported on XPU, disabling "
+                        "cudagraphs. Fallback to cudagraph_mode=NONE")
             compilation_config.cudagraph_mode = CUDAGraphMode.NONE
 
         # check and update parallel config
@@ -182,3 +175,7 @@ def check_if_supports_dtype(cls, torch_dtype: torch.dtype):
            "Intel Arc A770 have bfloat16 accuracy known issue. "
            "You can use float16 instead by explicitly setting the "
            "`dtype` flag in CLI, for example: --dtype=half.")
+
+    @classmethod
+    def opaque_attention_op(cls) -> bool:
+        return True
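
The effect of the remaining cudagraph check can be seen in isolation with simplified stand-ins for CUDAGraphMode and max_cudagraph_mode() (illustrative only; the real classes live in vllm.config):

from enum import Enum


class CUDAGraphMode(Enum):  # simplified stand-in for vllm.config.CUDAGraphMode
    NONE = 0
    PIECEWISE = 1
    FULL = 2

    def max_cudagraph_mode(self):  # stand-in for the real helper
        return self


def xpu_cudagraph_fallback(cudagraph_mode):
    """Mirrors the check in XPUPlatform.check_and_update_config above."""
    if cudagraph_mode is None or \
            cudagraph_mode.max_cudagraph_mode() != CUDAGraphMode.NONE:
        print("[XPU] CUDA graph is not supported on XPU, disabling "
              "cudagraphs. Fallback to cudagraph_mode=NONE")
        return CUDAGraphMode.NONE
    return cudagraph_mode


assert xpu_cudagraph_fallback(CUDAGraphMode.FULL) is CUDAGraphMode.NONE
assert xpu_cudagraph_fallback(None) is CUDAGraphMode.NONE
assert xpu_cudagraph_fallback(CUDAGraphMode.NONE) is CUDAGraphMode.NONE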
