@@ -127,7 +127,8 @@ def use_rocm_custom_paged_attention(
         max_seq_len: int,
         sliding_window: int,
         kv_cache_dtype: str,
-        alibi_slopes: Optional[torch.Tensor] = None) -> bool:
+        alibi_slopes: Optional[torch.Tensor] = None,
+        sinks: Optional[torch.Tensor] = None) -> bool:
 
     GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName
     ON_GFX9 = any(arch in GPU_ARCH for arch in ["gfx90a", "gfx942", "gfx950"])
@@ -145,7 +146,7 @@ def use_rocm_custom_paged_attention(
                 and max_seq_len <= 128 * 1024
                 and (envs.VLLM_ROCM_CUSTOM_PAGED_ATTN)
                 and not (envs.VLLM_ROCM_USE_AITER_PAGED_ATTN
-                         and envs.VLLM_ROCM_USE_AITER))
+                         and envs.VLLM_ROCM_USE_AITER) and sinks is None)
 
     else:
         return (ON_GFX11_GFX12 and (not envs.VLLM_USE_V1 or sliding_window == 0
@@ -155,7 +156,7 @@ def use_rocm_custom_paged_attention(
                 and (gqa_ratio >= 3 and gqa_ratio <= 16)
                 and max_seq_len <= 128 * 1024 and alibi_slopes is None
                 and kv_cache_dtype == "auto"
-                and envs.VLLM_ROCM_CUSTOM_PAGED_ATTN)
+                and envs.VLLM_ROCM_CUSTOM_PAGED_ATTN and sinks is None)
 
 
 class RocmPlatform(Platform):
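
Not part of the diff, for orientation only: with the new sinks argument, the helper now reports False whenever an attention-sinks tensor is supplied, since the custom ROCm paged-attention kernel does not handle sinks and the backend has to fall back to a kernel that does. Below is a minimal sketch of that behavior, assuming a gfx9 ROCm GPU and the leading parameters (qtype, head_size, block_size, gqa_ratio) of the signature, which this hunk does not show.

# Illustrative only -- requires a ROCm GPU (gfx90a/gfx942/gfx950) and assumes
# the full signature of use_rocm_custom_paged_attention in vllm/platforms/rocm.py.
import torch
from vllm.platforms.rocm import use_rocm_custom_paged_attention

common = dict(
    qtype=torch.bfloat16,    # half or bfloat16
    head_size=128,           # 64 or 128
    block_size=16,           # 16 or 32
    gqa_ratio=8,             # must satisfy 3 <= gqa_ratio <= 16
    max_seq_len=8192,        # <= 128 * 1024
    sliding_window=0,
    kv_cache_dtype="auto",
    alibi_slopes=None,
)

# Without sinks, the custom kernel may still be selected (env flags permitting).
print(use_rocm_custom_paged_attention(**common, sinks=None))

# With sinks, the helper now always returns False, forcing a fallback path.
sinks = torch.zeros(8, device="cuda")  # hypothetical per-head sink values
print(use_rocm_custom_paged_attention(**common, sinks=sinks))
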
@@ -170,7 +171,7 @@ class RocmPlatform(Platform):
 
     supported_quantization: list[str] = [
         "awq", "gptq", "fp8", "compressed-tensors", "fbgemm_fp8", "gguf",
-        "quark", "ptpc_fp8"
+        "quark", "ptpc_fp8", "mxfp4"
     ]
 
     @classmethod
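
Again outside the diff itself: supported_quantization lists the values RocmPlatform accepts for the quantization setting, so adding "mxfp4" allows MXFP4-quantized checkpoints to load on ROCm. A hedged usage sketch follows; the model path is a placeholder, not something this change guarantees.

# Illustrative only -- assumes a ROCm build of vLLM and a checkpoint that
# actually ships MXFP4 weights; the model identifier below is a placeholder.
from vllm import LLM, SamplingParams

llm = LLM(
    model="path/or/hub-id-of-an-mxfp4-model",  # placeholder
    quantization="mxfp4",  # now accepted on ROCm by this change
)
out = llm.generate(["Hello"], SamplingParams(max_tokens=16))
print(out[0].outputs[0].text)
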
@@ -469,4 +470,4 @@ def device_count(cls) -> int:
 
     @classmethod
     def is_kv_cache_dtype_supported(cls, kv_cache_dtype: str) -> bool:
-        return True
+        return True