2929 scaled_dequantize , scaled_quantize )
3030from vllm .model_executor .layers .rotary_embedding import (
3131 DeepseekScalingRotaryEmbedding , RotaryEmbedding )
32+ from vllm .platforms import current_platform
3233
3334try :
3435 from vllm .vllm_flash_attn import flash_attn_varlen_func
@@ -182,6 +183,10 @@ def __init__(
182183 self .kv_b_proj = kv_b_proj
183184 self .o_proj = o_proj
184185 self .vllm_flash_attn_version = get_flash_attn_version ()
186+ # Differing K and V head dims are currently only supported with
187+ # FlashAttention >= 3 on Hopper devices; otherwise V must be padded
188+ self .pad_v_head = not self .vllm_flash_attn_version >= 3 or \
189+ current_platform .get_device_capability ()[0 ] != 9
185190
186191 def _v_up_proj_and_o_proj (self , x ):
187192 if envs .VLLM_MLA_PERFORM_MATRIX_ABSORPTION :
@@ -501,11 +506,10 @@ def _forward_prefill_flash(
501506
502507 k = torch .cat ((k_nope , k_pe .expand ((* k_nope .shape [:- 1 ], - 1 ))), dim = - 1 )
503508
504- # For MLA the v head dim is smaller than qk head dim so we pad out
505- # v with 0s to match the qk head dim
506509 v_dim = v .shape [- 1 ]
507- pad_v = self .vllm_flash_attn_version < 3
508- if pad_v :
510+ if self .pad_v_head :
511+ # For MLA the v head dim is smaller than qk head dim so we pad out
512+ # v with 0s to match the qk head dim
509513 v = torch .nn .functional .pad (v , [0 , q .shape [- 1 ] - v .shape [- 1 ]],
510514 value = 0 )
511515
@@ -522,7 +526,7 @@ def _forward_prefill_flash(
522526 fa_version = self .vllm_flash_attn_version ,
523527 )
524528
525- if pad_v :
529+ if self . pad_v_head :
526530 attn_output = attn_output \
527531 .view (- 1 , self .num_heads , q .shape [- 1 ])[..., :v_dim ]
528532
0 commit comments