Skip to content

Commit b6f01bd

Browse files
authored
refactor: abstract graph mode support into platform interface (#25161)
Signed-off-by: Yizhou Liu <[email protected]>
1 parent 4cf71cc commit b6f01bd

File tree

5 files changed

+23
-7
lines changed

5 files changed

+23
-7
lines changed

vllm/config/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -503,7 +503,7 @@ def __post_init__(self):
503503
if self.compilation_config.pass_config.enable_sequence_parallelism:
504504
self.compilation_config.custom_ops.append("+rms_norm")
505505

506-
if current_platform.is_cuda_alike() or current_platform.is_xpu():
506+
if current_platform.support_static_graph_mode():
507507
# if cudagraph_mode is not explicitly set by users, set default
508508
# value
509509
if self.compilation_config.cudagraph_mode is None:

vllm/platforms/cuda.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -498,6 +498,10 @@ def check_if_supports_dtype(cls, torch_dtype: torch.dtype):
498498
def support_hybrid_kv_cache(cls) -> bool:
499499
return True
500500

501+
@classmethod
502+
def support_static_graph_mode(cls) -> bool:
503+
return True
504+
501505

502506
# NVML utils
503507
# Note that NVML is not affected by `CUDA_VISIBLE_DEVICES`,

vllm/platforms/interface.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -587,6 +587,13 @@ def support_hybrid_kv_cache(cls) -> bool:
587587
"""
588588
return False
589589

590+
@classmethod
591+
def support_static_graph_mode(cls) -> bool:
592+
"""
593+
Returns whether static graph mode is supported by the current platform.
594+
"""
595+
return False
596+
590597
@classmethod
591598
def use_sync_weight_loader(cls) -> bool:
592599
"""

vllm/platforms/rocm.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -477,3 +477,7 @@ def check_if_supports_dtype(cls, torch_dtype: torch.dtype):
477477
@classmethod
478478
def support_hybrid_kv_cache(cls) -> bool:
479479
return True
480+
481+
@classmethod
482+
def support_static_graph_mode(cls) -> bool:
483+
return True

vllm/platforms/xpu.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -113,12 +113,9 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
113113
# lazy import to avoid circular import
114114
from vllm.config import CompilationLevel, CUDAGraphMode
115115
compilation_config = vllm_config.compilation_config
116-
if compilation_config.cudagraph_mode is None or \
117-
compilation_config.cudagraph_mode.max_cudagraph_mode() \
118-
!= CUDAGraphMode.NONE:
119-
logger.info("[XPU] CUDA graph is not supported on XPU, disabling "
120-
"cudagraphs. Fallback to cudagraph_mode=NONE")
121-
compilation_config.cudagraph_mode = CUDAGraphMode.NONE
116+
117+
assert compilation_config.cudagraph_mode == CUDAGraphMode.NONE, \
118+
"CUDA graph mode should be NONE on XPU"
122119

123120
if vllm_config.lora_config is not None:
124121
compilation_config.level = CompilationLevel.NO_COMPILATION
@@ -169,6 +166,10 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
169166
def support_hybrid_kv_cache(cls) -> bool:
170167
return True
171168

169+
@classmethod
170+
def support_static_graph_mode(cls) -> bool:
171+
return False
172+
172173
@classmethod
173174
def is_pin_memory_available(cls):
174175
return True

0 commit comments

Comments
 (0)