62 changes: 62 additions & 0 deletions tests/ut/attention/test_mla_v1.py
@@ -728,3 +728,65 @@ def test_forward_without_graph(self, _, mock_forward_prefill):
hidden_states_or_kv_c_normed, k_pe,
kv_cache, metadata, output, False)
self.assertEqual(result.shape[0], num_tokens)


@patch("vllm_ascend.attention.mla_v1.npu_prefetch")
def test_mla_preprocess(self, magic_npu_fetch):
# Mock inputs
magic_npu_fetch.return_value = MagicMock()
batch_size = 4
seq_len = 8
hidden_size = 1024
hidden_states = torch.randn(batch_size * seq_len, hidden_size)

qk_nope_head_dim = 64
qk_rope_head_dim = 128
qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
kv_lora_rank = 512
kv_cache = MagicMock()

# Mock attn_metadata

attn_metadata = MagicMock()
attn_metadata.num_decodes = 2
attn_metadata.num_prefills = 2
attn_metadata.num_decode_tokens = 2
attn_metadata.num_actual_tokens = 4
num_prefill_tokens = 2
attn_metadata.slot_mapping = torch.arange(4)
attn_metadata.decode.cos = torch.randn(2, 64)
attn_metadata.decode.sin = torch.randn(2, 64)
attn_metadata.prefill.cos = torch.randn(2, 64)
attn_metadata.prefill.sin = torch.randn(2, 64)
print(f">>>>>>>>>>>>>>>>>>num heads {self.impl.num_heads}>>>>>>>>>>>>>>>>>>")

self.impl.q_a_proj = MagicMock()
self.impl.q_a_layernorm = MagicMock()
self.impl.q_a_layernorm.return_value = torch.randn(attn_metadata.num_actual_tokens, self.impl.num_heads, self.impl.qk_rope_head_dim)
self.impl.kv_a_proj_with_mqa = MagicMock()
self.impl.kv_a_proj_with_mqa.return_value = [torch.randn(num_prefill_tokens, self.impl.num_heads, self.impl.qk_nope_head_dim + self.impl.kv_lora_rank)]
self.impl.q_proj = MagicMock()
self.impl.q_proj.return_value = [torch.randn(num_prefill_tokens, self.impl.num_heads, self.impl.qk_head_dim)]
self.impl.kv_b_proj = MagicMock()
self.impl.kv_b_proj.return_value = [torch.randn(num_prefill_tokens, self.impl.num_heads, self.impl.v_head_dim + self.impl.qk_nope_head_dim)]
self.impl.rope_single = MagicMock(side_effect=lambda x, cos, sin: x)
self.impl.exec_kv_decode = MagicMock()
self.impl.exec_kv_decode.return_value = [MagicMock(), MagicMock()]
self.impl.exec_kv_prefill = MagicMock()
self.impl.exec_kv_prefill.return_value = [
    torch.randn(num_prefill_tokens, self.impl.num_heads,
                self.impl.qk_rope_head_dim),
    torch.randn(num_prefill_tokens, self.impl.num_heads,
                self.impl.kv_lora_rank)
]
self.impl._q_proj_and_k_up_proj = MagicMock()
self.impl._q_proj_and_k_up_proj.return_value = [MagicMock(), MagicMock()]


decode_res, prefill_res = self.impl._mla_preprocess(
hidden_states,
kv_cache,
attn_metadata,
need_gather_q_kv=False
)

self.assertIsNotNone(decode_res)
self.assertIsNotNone(prefill_res)
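
A note on the mocks above: the projection layers (q_a_proj, kv_a_proj_with_mqa, q_proj, kv_b_proj) are assumed to follow vLLM's convention of returning an (output, output_bias) pair, which is why the implementation indexes [0] (see self.kv_b_proj(kv_c_normed)[0] in the diff below) and why the mocks return one-element sequences. The following minimal, self-contained sketch shows that convention; the TupleReturningLinear name and the shapes are illustrative, not taken from the PR.

import torch
import torch.nn as nn

class TupleReturningLinear(nn.Module):
    """Stand-in for a projection layer that returns (output, output_bias)."""

    def __init__(self, in_features: int, out_features: int):
        super().__init__()
        self.inner = nn.Linear(in_features, out_features, bias=False)

    def forward(self, x: torch.Tensor):
        # Callers index [0], exactly like self.kv_b_proj(kv_c_normed)[0].
        return self.inner(x), None

layer = TupleReturningLinear(16, 32)
out = layer(torch.randn(4, 16))[0]
assert out.shape == (4, 32)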



293 changes: 106 additions & 187 deletions vllm_ascend/attention/mla_v1.py
@@ -793,16 +793,14 @@ def _compute_prefill_context(
return prefix_output, prefix_lse

def _forward_prefill(
self,
query: torch.Tensor,
kv_c_normed: torch.Tensor,
k_pe: torch.Tensor,
kv_c_and_k_pe_cache: Tuple[torch.Tensor],
attn_metadata: AscendMLAMetadata,
self,
query: torch.Tensor,
kv_c_normed: torch.Tensor,
k_pe: torch.Tensor,
kv_c_and_k_pe_cache: torch.Tensor,
attn_metadata: AscendMLAMetadata,
) -> torch.Tensor:
assert attn_metadata.prefill is not None
assert len(kv_c_and_k_pe_cache) > 1

num_tokens = query.size(0)
attn_output = torch.empty(num_tokens,
self.num_heads,
@@ -811,104 +809,55 @@ def _forward_prefill(
device=query.device)
k_nope, value = self.kv_b_proj(kv_c_normed)[0].view(
-1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim).split(
[self.qk_nope_head_dim, self.v_head_dim], dim=-1)
[self.qk_nope_head_dim, self.v_head_dim], dim=-1)
k_pe = k_pe.expand((*k_nope.shape[:-1], -1))
# Only two input states are possible here: ChunkedPrefill or PrefillNoCache
ascend_config = get_ascend_config()

if attn_metadata.attn_state in [
AscendAttentionState.ChunkedPrefill,
AscendAttentionState.SpecDecoding,
AscendAttentionState.PrefillCacheHit
] and not ascend_config.chunked_prefill_for_mla:
attn_output_torch = torch.empty(num_tokens,
self.num_heads * self.v_head_dim,
dtype=query.dtype,
device=query.device)
# the current request is chunked during prefill; disable flash attention with chunked prefill
vanilla_chunked_prefill_mla(
output=attn_output_torch,
query=query,
kv_cache=kv_c_and_k_pe_cache,
block_tables=attn_metadata.prefill.block_table,
query_lens=attn_metadata.prefill.query_lens,
context_lens=attn_metadata.prefill.context_lens,
kv_b_proj=self.kv_b_proj,
max_query_len=attn_metadata.prefill.max_query_len,
max_context_len=attn_metadata.prefill.max_seq_lens,
nope_dim=self.qk_nope_head_dim,
rope_dim=self.qk_rope_head_dim,
v_head_dim=self.v_head_dim,
scale=self.scale,
alibi_slopes=None,
causal=True)
elif attn_metadata.attn_state in [
AscendAttentionState.ChunkedPrefill,
AscendAttentionState.SpecDecoding,
AscendAttentionState.PrefillCacheHit
]:
attn_lse = torch.empty(self.num_heads,
num_tokens,
dtype=torch.float32,
device=query.device)
q_pe = query[..., self.qk_nope_head_dim:]
q_nope = query[..., :self.qk_nope_head_dim]
mask = torch.triu(
torch.ones(512, 512, device=query.device, dtype=query.dtype),
1)  # 512: the mask only supports 512
if attn_metadata.num_prefills > 1:
mask = mask.unsqueeze(0).repeat(attn_metadata.num_prefills, 1,
1)
torch_npu.atb.npu_ring_mla(
q_nope=q_nope,
q_rope=q_pe,
k_nope=k_nope,
k_rope=k_pe,
value=value,
mask=mask,
seqlen=torch.tensor(attn_metadata.prefill.query_lens,
dtype=torch.int32),
head_num=self.num_heads,
kv_head_num=self.num_heads,
pre_out=None,
prev_lse=None,
qk_scale=self.scale,
kernel_type="kernel_type_high_precision",
mask_type="mask_type_triu",
input_layout="type_bsnd",
calc_type="calc_type_first_ring",
output=attn_output,
softmax_lse=attn_lse)
attn_output, attn_lse = self._compute_prefill_context( \
query, kv_c_and_k_pe_cache, self.qk_rope_head_dim, attn_metadata, attn_output, attn_lse)

elif attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
key = torch.cat((k_nope, k_pe), dim=-1)
torch_npu._npu_flash_attention(
query=query,
key=key,
value=value,
mask=attn_metadata.attn_mask,
seq_len=attn_metadata.prefill.context_lens,
scale_value=self.scale,
num_heads=self.num_heads,
num_kv_heads=self.num_heads,
out=attn_output)
attn_output = attn_output.view(-1, self.num_heads, self.v_head_dim)
else:
raise RuntimeError(
"Unexpected path reached, AscendMLAImpl should only have PrefillNoCache, PrefillCacheHit, ChunkedPrefill and SpecDecoding scenario in forward prefill, please file a bug to vllm-ascend !"
)
attn_lse = torch.empty(self.num_heads,
num_tokens,
dtype=torch.float32,
device=query.device)
q_pe = query[..., self.qk_nope_head_dim:]
q_nope = query[..., :self.qk_nope_head_dim]
mask = torch.triu(
torch.ones(512, 512, device=query.device, dtype=query.dtype),
1)  # 512: the mask only supports 512
if attn_metadata.num_prefills > 1:
mask = mask.unsqueeze(0).repeat(attn_metadata.num_prefills, 1,
1)
torch_npu.atb.npu_ring_mla(
q_nope=q_nope,
q_rope=q_pe,
k_nope=k_nope,
k_rope=k_pe,
value=value,
mask=mask,
seqlen=torch.tensor(attn_metadata.prefill.query_lens,
dtype=torch.int32),
head_num=self.num_heads,
kv_head_num=self.num_heads,
pre_out=None,
prev_lse=None,
qk_scale=self.scale,
kernel_type="kernel_type_high_precision",
mask_type="mask_type_triu",
input_layout="type_bsnd",
calc_type="calc_type_first_ring",
output=attn_output,
softmax_lse=attn_lse)
attn_output, attn_lse = self._compute_prefill_context( \
query, kv_c_and_k_pe_cache, self.qk_rope_head_dim, attn_metadata, attn_output, attn_lse)

attn_output = attn_output.reshape(
[num_tokens, self.num_heads * self.v_head_dim])
if attn_metadata.attn_state in [
AscendAttentionState.ChunkedPrefill,
AscendAttentionState.SpecDecoding,
AscendAttentionState.PrefillCacheHit
] and not ascend_config.chunked_prefill_for_mla:
attn_output = attn_output_torch

return attn_output
current_ms_metadata = get_multistream_comm_context()
if current_ms_metadata is None:
return self.o_proj(attn_output)[0]
else:
current_ms_metadata.before_comm_event.record()
with torch.npu.stream(current_ms_metadata.comm_stream):
current_ms_metadata.before_comm_event.wait()
return self.o_proj(attn_output)[0]

def exec_kv(
self,
@@ -983,102 +932,72 @@ def rope_single(

def _forward_decode(
self,
q_nope: torch.Tensor,
ql_nope: torch.Tensor,
q_pe: torch.Tensor,
k_nope: torch.Tensor,
k_pe: torch.Tensor,
kv_c_and_k_pe_cache: Tuple[torch.Tensor],
block_size: int,
attn_metadata: AscendMLAMetadata,
enable_multistream_mla: bool = False,
) -> torch.Tensor:
Review comment (Contributor, severity: critical) on lines 934 to 942:

The signature of _forward_decode was changed to take block_size: int instead of kv_c_and_k_pe_cache. However, the call sites in the forward method are not updated and still pass kv_cache (a tuple of tensors), which will cause a TypeError. To fix this, you should revert the signature to accept kv_c_and_k_pe_cache and derive block_size inside this function, as it was done previously (block_size = kv_c_and_k_pe_cache[0].shape[1]).

Suggested change
self,
q_nope: torch.Tensor,
ql_nope: torch.Tensor,
q_pe: torch.Tensor,
k_nope: torch.Tensor,
k_pe: torch.Tensor,
kv_c_and_k_pe_cache: Tuple[torch.Tensor],
block_size: int,
attn_metadata: AscendMLAMetadata,
enable_multistream_mla: bool = False,
) -> torch.Tensor:
self,
ql_nope: torch.Tensor,
q_pe: torch.Tensor,
k_nope: torch.Tensor,
k_pe: torch.Tensor,
kv_c_and_k_pe_cache: Tuple[torch.Tensor, ...],
attn_metadata: AscendMLAMetadata,
enable_multistream_mla: bool = False,
) -> torch.Tensor:
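
To illustrate the reviewer's point, below is a minimal, self-contained sketch (not the vllm-ascend code; the function names and tensor shapes are invented) contrasting the two consistent fixes: either every call site derives block_size from the cache before calling, or the function keeps accepting the cache tuple and derives block_size itself, as the suggestion above does.

from typing import Tuple
import torch

def decode_takes_block_size(q: torch.Tensor, block_size: int) -> int:
    # Option A: the PR's new signature; forward() must pass block_size explicitly.
    return block_size

def decode_takes_cache(q: torch.Tensor,
                       kv_c_and_k_pe_cache: Tuple[torch.Tensor, ...]) -> int:
    # Option B (the suggestion): derive block_size from the cache, as before.
    return kv_c_and_k_pe_cache[0].shape[1]

q = torch.randn(2, 4)
kv_cache = (torch.zeros(8, 128, 1, 512), torch.zeros(8, 128, 1, 64))
# Option A requires updating the existing call site that currently passes kv_cache:
assert decode_takes_block_size(q, kv_cache[0].shape[1]) == 128
# Option B keeps the existing call site working unchanged:
assert decode_takes_cache(q, kv_cache) == 128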

decode_meta = attn_metadata.decode
assert decode_meta is not None
num_tokens = q_nope.size(0)
if self.running_in_graph or self.running_chunkprefilll_with_torchair:
# shape of k_nope/k_pe for npu graph mode should be:
# [num_blocks, num_kv_heads, block_size, self.kv_lora_rank/self.qk_rope_head_dim]
block_size = kv_c_and_k_pe_cache[0].shape[1]
actual_seq_lengths = None
if self.enable_kv_nz:
k_nope = k_nope.view(-1, self.num_kv_heads,
self.kv_lora_rank // 16, block_size, 16)
k_pe = k_pe.view(-1, self.num_kv_heads,
self.qk_rope_head_dim // 16, block_size, 16)
input_layout = "BSND"
else:
k_nope = k_nope.view(-1, self.num_kv_heads, block_size,
self.kv_lora_rank)
k_pe = k_pe.view(-1, self.num_kv_heads, block_size,
self.qk_rope_head_dim)
input_layout = "BNSD"

if attn_metadata.attn_state == AscendAttentionState.SpecDecoding:
assert num_tokens % self.spec_token_num == 0
input_layout = "TND"
# [bs * q_seq_len, num_heads_per_rank, dim]
q_nope = q_nope.view(num_tokens, self.num_heads, -1)
q_pe = q_pe.view(num_tokens, self.num_heads, -1)
sparse_mode = 3
spec_attn_mask = attn_metadata.decode.attn_mask # type:ignore
actual_seq_lengths = decode_meta.actual_seq_lengths_q
else:
if self.enable_kv_nz:
q_nope = q_nope.view(num_tokens, 1, self.num_heads, -1)
q_pe = q_pe.view(num_tokens, 1, self.num_heads, -1)
else:
q_nope = q_nope.view(num_tokens, self.num_heads, 1, -1)
q_pe = q_pe.view(num_tokens, self.num_heads, 1, -1)
sparse_mode = 0
spec_attn_mask = None

attn_output, _ = torch_npu.npu_fused_infer_attention_score(
q_nope,
k_nope,
k_nope,
query_rope=q_pe,
key_rope=k_pe,
num_heads=self.num_heads,
num_key_value_heads=self.num_kv_heads,
input_layout=input_layout,
atten_mask=spec_attn_mask,
sparse_mode=sparse_mode,
scale=self.scale,
antiquant_mode=0,
antiquant_scale=None,
block_table=decode_meta.block_table,
block_size=block_size,
actual_seq_lengths_kv=decode_meta.seq_lens_list,
actual_seq_lengths=actual_seq_lengths)

num_tokens = ql_nope.size(0)
# shape of k_nope/k_pe for npu graph mode should be:
# [num_blocks, num_kv_heads, block_size, self.kv_lora_rank/self.qk_rope_head_dim]
actual_seq_lengths = None
if self.enable_kv_nz:
k_nope = k_nope.view(-1, self.num_kv_heads,
self.kv_lora_rank // 16, block_size, 16)
k_pe = k_pe.view(-1, self.num_kv_heads,
self.qk_rope_head_dim // 16, block_size, 16)
input_layout = "BSND"
else:
k_nope = k_nope.view(-1, self.num_kv_heads, block_size,
self.kv_lora_rank)
k_pe = k_pe.view(-1, self.num_kv_heads, block_size,
self.qk_rope_head_dim)
input_layout = "BNSD"

if attn_metadata.attn_state == AscendAttentionState.SpecDecoding:
assert num_tokens % self.spec_token_num == 0
input_layout = "TND"
# [bs * q_seq_len, num_heads_per_rank, dim]
ql_nope = ql_nope.view(num_tokens, self.num_heads, -1)
q_pe = q_pe.view(num_tokens, self.num_heads, -1)
sparse_mode = 3
spec_attn_mask = attn_metadata.decode.attn_mask # type:ignore
actual_seq_lengths = decode_meta.actual_seq_lengths_q
else:
# The MLA_PA path will become the default in the future; `_npu_paged_attention_mla` will
# be removed once `torch_npu.atb.npu_multi_head_latent_attention` in torch_npu becomes
# publicly available.
assert len(kv_c_and_k_pe_cache) > 1
if envs.VLLM_ASCEND_MLA_PA:
attn_output = torch_npu.atb.npu_multi_head_latent_attention(
q_nope, q_pe, kv_c_and_k_pe_cache[0],
kv_c_and_k_pe_cache[1], attn_metadata.decode.block_table,
attn_metadata.decode.seq_lens, self.num_heads, self.scale,
self.num_kv_heads)
if self.enable_kv_nz:
ql_nope = ql_nope.view(num_tokens, 1, self.num_heads, -1)
q_pe = q_pe.view(num_tokens, 1, self.num_heads, -1)
else:
q = torch.cat([q_nope, q_pe], dim=-1)
attn_output = torch.empty(
[num_tokens, self.num_heads, self.kv_lora_rank],
dtype=q.dtype,
device=q.device)
k_cache = torch.cat(
[kv_c_and_k_pe_cache[0], kv_c_and_k_pe_cache[1]], dim=-1)
torch_npu._npu_paged_attention_mla(
query=q,
key_cache=k_cache,
num_kv_heads=self.num_kv_heads,
num_heads=self.num_heads,
scale_value=self.scale,
block_table=attn_metadata.decode.
block_table, # type:ignore
context_lens=attn_metadata.decode.seq_lens, # type:ignore
mla_vheadsize=self.kv_lora_rank,
out=attn_output)
ql_nope = ql_nope.view(num_tokens, self.num_heads, 1, -1)
q_pe = q_pe.view(num_tokens, self.num_heads, 1, -1)
sparse_mode = 0
spec_attn_mask = None

attn_output, _ = torch_npu.npu_fused_infer_attention_score(
ql_nope,
k_nope,
k_nope,
query_rope=q_pe,
key_rope=k_pe,
num_heads=self.num_heads,
num_key_value_heads=self.num_kv_heads,
input_layout=input_layout,
atten_mask=spec_attn_mask,
sparse_mode=sparse_mode,
scale=self.scale,
antiquant_mode=0,
antiquant_scale=None,
block_table=decode_meta.block_table,
block_size=block_size,
actual_seq_lengths_kv=decode_meta.seq_lens_list,
actual_seq_lengths=actual_seq_lengths)

current_ms_metadata = get_multistream_comm_context()
if current_ms_metadata is None:
return self._v_up_proj_and_o_proj(attn_output,