
Commit 4082338

Remove unneeded ROCm platform import when using CUDA (#22765)
Signed-off-by: mgoin <[email protected]>
1 parent: c6b9287

2 files changed, +2 -2 lines changed


vllm/attention/backends/rocm_flash_attn.py

Lines changed: 1 addition & 1 deletion
@@ -22,7 +22,6 @@
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
     GroupShape)
 from vllm.platforms import current_platform
-from vllm.platforms.rocm import use_rocm_custom_paged_attention

 if TYPE_CHECKING:
     from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata

@@ -886,6 +885,7 @@ def forward(
         num_seqs, num_heads, head_size = decode_query.shape
         block_size = value_cache.shape[3]
         gqa_ratio = num_heads // self.num_kv_heads
+        from vllm.platforms.rocm import use_rocm_custom_paged_attention
         use_custom = use_rocm_custom_paged_attention(
             decode_query.dtype, head_size, block_size, gqa_ratio,
             decode_meta.max_decode_seq_len, self.sliding_window,

vllm/attention/ops/chunked_prefill_paged_decode.py

Lines changed: 1 addition & 1 deletion
@@ -11,7 +11,6 @@

 from vllm import _custom_ops as ops
 from vllm.platforms import current_platform
-from vllm.platforms.rocm import use_rocm_custom_paged_attention
 from vllm.triton_utils import tl, triton

 from .prefix_prefill import context_attention_fwd

@@ -296,6 +295,7 @@ def chunked_prefill_paged_decode(
     num_queries_per_kv_padded = max(triton.next_power_of_2(num_queries_per_kv),
                                     16)

+    from vllm.platforms.rocm import use_rocm_custom_paged_attention
     use_custom = use_rocm_custom_paged_attention(
         query.dtype,
         head_size,

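Because Python caches modules in sys.modules, the function-scoped import is a single dictionary lookup on every call after the first, so moving it into the decode path adds negligible per-call overhead while sparing CUDA installs from importing ROCm platform code at module load time.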