
Commit ac028fa

add opaque_attention_op interface in platform
Signed-off-by: Kunshang Ji <[email protected]>
1 parent ca400b8

6 files changed: +25 −2 lines changed
vllm/attention/layer.py

Lines changed: 1 addition & 2 deletions
@@ -190,8 +190,7 @@ def __init__(
         # torch.compile works by registering the attention as one giant
         # opaque custom op. For other platforms, we directly call them
         # and let torch.compile handle them.
-        self.use_direct_call = not current_platform.is_cuda_alike(
-        ) and not current_platform.is_cpu() and not current_platform.is_xpu()
+        self.use_direct_call = not current_platform.opaque_attention_op()
 
         self.use_output = self.attn_backend.accept_output_buffer
         compilation_config = get_current_vllm_config().compilation_config

vllm/platforms/cpu.py

Lines changed: 4 additions & 0 deletions
@@ -335,3 +335,7 @@ def default_v1(cls, model_config) -> bool:
         return (cls.supports_v1(model_config)
                 and arch in (CpuArchEnum.X86, CpuArchEnum.POWERPC,
                              CpuArchEnum.ARM, CpuArchEnum.S390X))
+
+    @classmethod
+    def opaque_attention_op(cls) -> bool:
+        return True

vllm/platforms/cuda.py

Lines changed: 4 additions & 0 deletions
@@ -442,6 +442,10 @@ def supports_v1(cls, model_config: "ModelConfig") -> bool:
     def use_custom_allreduce(cls) -> bool:
         return True
 
+    @classmethod
+    def opaque_attention_op(cls) -> bool:
+        return True
+
     @classmethod
     def get_static_graph_wrapper_cls(cls) -> str:
         return "vllm.compilation.cuda_graph.CUDAGraphWrapper"

vllm/platforms/interface.py

Lines changed: 8 additions & 0 deletions
@@ -509,6 +509,14 @@ def use_custom_allreduce(cls) -> bool:
         """
         return False
 
+    @classmethod
+    def opaque_attention_op(cls) -> bool:
+        """
+        Returns True if we register attention as one giant opaque custom op
+        on the current platform
+        """
+        return False
+
     @classmethod
     def validate_request(
         cls,
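
As a usage note, the new hook gives out-of-tree platform plugins a single switch for opting in to the opaque-custom-op path. A minimal sketch, assuming the usual pattern of subclassing Platform; MyAcceleratorPlatform is a hypothetical name and not part of this commit.

    from vllm.platforms.interface import Platform


    class MyAcceleratorPlatform(Platform):  # hypothetical out-of-tree platform

        @classmethod
        def opaque_attention_op(cls) -> bool:
            # Opt in: the attention layer sets
            # use_direct_call = not opaque_attention_op(), so returning True
            # makes it register attention as one opaque custom op for
            # torch.compile instead of calling the backend directly.
            return True

Platforms that do not override the hook keep the base-class default of False and therefore the direct-call behavior.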

vllm/platforms/rocm.py

Lines changed: 4 additions & 0 deletions
@@ -411,6 +411,10 @@ def use_custom_allreduce(cls) -> bool:
         supported_archs = ['gfx94', 'gfx95']
         return any(gfx in gcn_arch for gfx in supported_archs)
 
+    @classmethod
+    def opaque_attention_op(cls) -> bool:
+        return True
+
     @classmethod
     def get_cu_count(cls, device_id: int = 0) -> int:
         return torch.cuda.get_device_properties(

vllm/platforms/xpu.py

Lines changed: 4 additions & 0 deletions
@@ -181,3 +181,7 @@ def get_global_graph_pool(self) -> Any:
         Currently xpu does NOT support Graph model.
         """
         raise NotImplementedError("XPU does not support Graph model.")
+
+    @classmethod
+    def opaque_attention_op(cls) -> bool:
+        return True
