Skip to content

Commit 2101882

Browse files
no-pad-fas3
Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
1 parent f1d662c commit 2101882

File tree

2 files changed

+10
-6
lines changed

2 files changed

+10
-6
lines changed

CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -581,7 +581,7 @@ else()
581581
FetchContent_Declare(
582582
vllm-flash-attn
583583
GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git
584-
GIT_TAG 720c94869cf2e0ff5a706e9c7f1dce0939686ade
584+
GIT_TAG 62cd67b571e806aa694a4c0f293d72a0f4717a97
585585
GIT_PROGRESS TRUE
586586
# Don't share the vllm-flash-attn build between build types
587587
BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn

vllm/attention/backends/mla/utils.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
scaled_dequantize, scaled_quantize)
3030
from vllm.model_executor.layers.rotary_embedding import (
3131
DeepseekScalingRotaryEmbedding, RotaryEmbedding)
32+
from vllm.platforms import current_platform
3233

3334
try:
3435
from vllm.vllm_flash_attn import flash_attn_varlen_func
@@ -182,6 +183,10 @@ def __init__(
182183
self.kv_b_proj = kv_b_proj
183184
self.o_proj = o_proj
184185
self.vllm_flash_attn_version = get_flash_attn_version()
186+
# Currently, different K headdim and V headdim are only supported for
187+
# Hopper devices
188+
self.pad_v_head = not self.vllm_flash_attn_version >= 3 or \
189+
current_platform.get_device_capability()[0] != 9
185190

186191
def _v_up_proj_and_o_proj(self, x):
187192
if envs.VLLM_MLA_PERFORM_MATRIX_ABSORPTION:
@@ -501,11 +506,10 @@ def _forward_prefill_flash(
501506

502507
k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1)
503508

504-
# For MLA the v head dim is smaller than qk head dim so we pad out
505-
# v with 0s to match the qk head dim
506509
v_dim = v.shape[-1]
507-
pad_v = self.vllm_flash_attn_version < 3
508-
if pad_v:
510+
if self.pad_v_head:
511+
# For MLA the v head dim is smaller than qk head dim so we pad out
512+
# v with 0s to match the qk head dim
509513
v = torch.nn.functional.pad(v, [0, q.shape[-1] - v.shape[-1]],
510514
value=0)
511515

@@ -522,7 +526,7 @@ def _forward_prefill_flash(
522526
fa_version=self.vllm_flash_attn_version,
523527
)
524528

525-
if pad_v:
529+
if self.pad_v_head:
526530
attn_output = attn_output\
527531
.view(-1, self.num_heads, q.shape[-1])[..., :v_dim]
528532

0 commit comments

Comments (0)