Skip to content

Commit 17edd8a

Browse files
authored
[Platform][Kernel] platform-specific kernel loading (#25823)
Signed-off-by: Hank <[email protected]>
1 parent 3303cfb commit 17edd8a

File tree

4 files changed: 27 additions and 11 deletions

4 files changed: 27 additions and 11 deletions

vllm/_custom_ops.py

Lines changed: 2 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,6 @@
11
# SPDX-License-Identifier: Apache-2.0
22
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
33

4-
import contextlib
54
from typing import TYPE_CHECKING, Optional, Union
65

76
import torch
@@ -13,16 +12,8 @@
1312

1413
logger = init_logger(__name__)
1514

16-
if not current_platform.is_tpu() and not current_platform.is_xpu():
17-
try:
18-
import vllm._C
19-
except ImportError as e:
20-
logger.warning("Failed to import from vllm._C with %r", e)
21-
22-
supports_moe_ops = False
23-
with contextlib.suppress(ImportError):
24-
import vllm._moe_C # noqa: F401
25-
supports_moe_ops = True
15+
current_platform.import_core_kernels()
16+
supports_moe_ops = current_platform.try_import_moe_kernels()
2617

2718
if TYPE_CHECKING:
2819

vllm/platforms/interface.py

Lines changed: 17 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,6 @@
11
# SPDX-License-Identifier: Apache-2.0
22
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3+
import contextlib
34
import enum
45
import os
56
import platform
@@ -163,6 +164,22 @@ def device_id_to_physical_device_id(cls, device_id: int):
163164
else:
164165
return device_id
165166

167+
@classmethod
168+
def import_core_kernels(cls) -> None:
169+
""" Import any platform-specific C kernels. """
170+
try:
171+
import vllm._C # noqa: F401
172+
except ImportError as e:
173+
logger.warning("Failed to import from vllm._C: %r", e)
174+
175+
@classmethod
176+
def try_import_moe_kernels(cls) -> bool:
177+
""" Import any platform-specific MoE kernels. """
178+
with contextlib.suppress(ImportError):
179+
import vllm._moe_C # noqa: F401
180+
return True
181+
return False
182+
166183
@classmethod
167184
def get_vit_attn_backend(cls, head_size: int,
168185
dtype: torch.dtype) -> "_Backend":

vllm/platforms/tpu.py

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -47,6 +47,10 @@ class TpuPlatform(Platform):
4747
"TPU_CHIPS_PER_HOST_BOUNDS", "TPU_HOST_BOUNDS"
4848
]
4949

50+
@classmethod
51+
def import_core_kernels(cls) -> None:
52+
pass
53+
5054
@classmethod
5155
def get_attn_backend_cls(cls, selected_backend: "_Backend", head_size: int,
5256
dtype: torch.dtype, kv_cache_dtype: Optional[str],

vllm/platforms/xpu.py

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -34,6 +34,10 @@ class XPUPlatform(Platform):
3434
dist_backend: str = "ccl" # ccl | xccl
3535
device_control_env_var: str = "ZE_AFFINITY_MASK"
3636

37+
@classmethod
38+
def import_core_kernels(cls) -> None:
39+
pass
40+
3741
@classmethod
3842
def get_attn_backend_cls(cls, selected_backend: "_Backend", head_size: int,
3943
dtype: torch.dtype, kv_cache_dtype: Optional[str],

0 commit comments

Comments
 (0)