Skip to content

Commit 061c150

Browse files
committed
address comments
Signed-off-by: Kunshang Ji <[email protected]>
1 parent ff20735 commit 061c150

File tree

2 files changed

+14
-9
lines changed

2 files changed

+14
-9
lines changed

vllm/compilation/fix_functionalization.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,16 +27,19 @@ class FixFunctionalizationPass(VllmInductorPass):
2727
"""
2828

2929
def __call__(self, graph: torch.fx.Graph):
30+
# XPU does not support auto-functionalization yet.
31+
# Will enable this when switch to vllm-xpu-kernels.
32+
if current_platform.is_xpu():
33+
logger.debug("XPU platform does not support fix functionality"
34+
" pass currently.")
35+
return
36+
3037
self.begin()
3138
self.dump_graph(graph, "before_fix_functionalization")
3239

3340
self.nodes_to_remove: list[torch.fx.Node] = []
3441
count = 0
3542
for node in graph.nodes:
36-
# XPU does not support auto-functionalization yet.
37-
# Will enable this when switch to vllm-xpu-kernels.
38-
if current_platform.is_xpu():
39-
continue
4043
if not is_func(node, auto_functionalized):
4144
continue # Avoid deep if-elif nesting
4245

vllm/platforms/xpu.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
33

44
import os
5-
from typing import TYPE_CHECKING, Optional
5+
from typing import TYPE_CHECKING, Any, Optional
66

77
import torch
88

@@ -78,10 +78,6 @@ def get_device_total_memory(cls, device_id: int = 0) -> int:
7878
def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool:
7979
return True
8080

81-
@classmethod
82-
def get_piecewise_backend_cls(cls) -> str:
83-
return "vllm.compilation.cuda_piecewise_backend.CUDAPiecewiseBackend" # noqa
84-
8581
@classmethod
8682
def inference_mode(cls):
8783
return torch.no_grad()
@@ -201,3 +197,9 @@ def supports_v1(cls, model_config: ModelConfig) -> bool:
201197
@classmethod
202198
def device_count(cls) -> int:
203199
return torch.xpu.device_count()
200+
201+
def get_global_graph_pool(self) -> Any:
202+
"""
203+
Currently XPU does NOT support graph mode.
204+
"""
205+
return None

0 commit comments

Comments (0)