
Commit 98a3a81

Authored by: WoosukKwon, LiuXiaoxuanPKU, simon-mo, heheda12345, hongxiayang
[ROCm] Add attention sink to use_rocm_custom_paged_attention (#22329)
Signed-off-by: Woosuk Kwon <[email protected]>
Co-authored-by: LiuXiaoxuanPKU <[email protected]>
Co-authored-by: simon-mo <[email protected]>
Co-authored-by: Chen Zhang <[email protected]>
Co-authored-by: Hongxia Yang <[email protected]>
Co-authored-by: Minseok Lee <[email protected]>
Co-authored-by: Yongye Zhu <[email protected]>
1 parent de98252 commit 98a3a81

File tree: 1 file changed (+6, −5 lines)


vllm/platforms/rocm.py

Lines changed: 6 additions & 5 deletions
@@ -127,7 +127,8 @@ def use_rocm_custom_paged_attention(
         max_seq_len: int,
         sliding_window: int,
         kv_cache_dtype: str,
-        alibi_slopes: Optional[torch.Tensor] = None) -> bool:
+        alibi_slopes: Optional[torch.Tensor] = None,
+        sinks: Optional[torch.Tensor] = None) -> bool:
 
     GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName
     ON_GFX9 = any(arch in GPU_ARCH for arch in ["gfx90a", "gfx942", "gfx950"])
@@ -145,7 +146,7 @@ def use_rocm_custom_paged_attention(
                 and max_seq_len <= 128 * 1024
                 and (envs.VLLM_ROCM_CUSTOM_PAGED_ATTN)
                 and not (envs.VLLM_ROCM_USE_AITER_PAGED_ATTN
-                         and envs.VLLM_ROCM_USE_AITER))
+                         and envs.VLLM_ROCM_USE_AITER) and sinks is None)
 
     else:
         return (ON_GFX11_GFX12 and (not envs.VLLM_USE_V1 or sliding_window == 0
@@ -155,7 +156,7 @@ def use_rocm_custom_paged_attention(
                 and (gqa_ratio >= 3 and gqa_ratio <= 16)
                 and max_seq_len <= 128 * 1024 and alibi_slopes is None
                 and kv_cache_dtype == "auto"
-                and envs.VLLM_ROCM_CUSTOM_PAGED_ATTN)
+                and envs.VLLM_ROCM_CUSTOM_PAGED_ATTN and sinks is None)
 
 
 class RocmPlatform(Platform):
@@ -170,7 +171,7 @@ class RocmPlatform(Platform):
 
     supported_quantization: list[str] = [
         "awq", "gptq", "fp8", "compressed-tensors", "fbgemm_fp8", "gguf",
-        "quark", "ptpc_fp8"
+        "quark", "ptpc_fp8", "mxfp4"
     ]
 
     @classmethod
@@ -469,4 +470,4 @@ def device_count(cls) -> int:
 
     @classmethod
     def is_kv_cache_dtype_supported(cls, kv_cache_dtype: str) -> bool:
-        return True
+        return True
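
The practical effect of the new sinks argument is that the ROCm custom paged-attention kernel is no longer selected when attention-sink tensors are present; the caller falls back to the default attention path. Below is a minimal sketch of how a caller might consult the updated gate. The leading parameters (qtype, head_size, block_size, gqa_ratio) and the sink tensor shape are assumptions, since this diff only shows the tail of the signature, and running it requires a ROCm GPU because the gate queries torch.cuda.

import torch
from vllm.platforms.rocm import use_rocm_custom_paged_attention

# Assumed leading parameters; the diff only shows the signature from
# max_seq_len onward. Requires a ROCm GPU, since the gate calls
# torch.cuda.get_device_properties("cuda").
qtype = torch.bfloat16
head_size = 128
block_size = 16
gqa_ratio = 8
max_seq_len = 4096

# Without sinks: eligibility is decided by the arch and env checks above.
no_sinks = use_rocm_custom_paged_attention(
    qtype, head_size, block_size, gqa_ratio, max_seq_len,
    sliding_window=0, kv_cache_dtype="auto",
    alibi_slopes=None, sinks=None)

# With sinks: the gate now returns False, so the backend uses the
# non-custom attention path instead of the ROCm paged-attention kernel.
sinks = torch.zeros(8, dtype=torch.bfloat16)  # assumed: one value per attention head
with_sinks = use_rocm_custom_paged_attention(
    qtype, head_size, block_size, gqa_ratio, max_seq_len,
    sliding_window=0, kv_cache_dtype="auto",
    alibi_slopes=None, sinks=sinks)
assert with_sinks is False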

0 commit comments