293 changes: 106 additions & 187 deletions vllm_ascend/attention/mla_v1.py
@@ -793,16 +793,14 @@ def _compute_prefill_context(
return prefix_output, prefix_lse

def _forward_prefill(
self,
query: torch.Tensor,
kv_c_normed: torch.Tensor,
k_pe: torch.Tensor,
kv_c_and_k_pe_cache: Tuple[torch.Tensor],
attn_metadata: AscendMLAMetadata,
self,
query: torch.Tensor,
kv_c_normed: torch.Tensor,
k_pe: torch.Tensor,
kv_c_and_k_pe_cache: torch.Tensor,
attn_metadata: AscendMLAMetadata,
) -> torch.Tensor:
assert attn_metadata.prefill is not None
assert len(kv_c_and_k_pe_cache) > 1

num_tokens = query.size(0)
attn_output = torch.empty(num_tokens,
self.num_heads,
@@ -811,104 +809,55 @@ def _forward_prefill(
device=query.device)
k_nope, value = self.kv_b_proj(kv_c_normed)[0].view(
-1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim).split(
[self.qk_nope_head_dim, self.v_head_dim], dim=-1)
[self.qk_nope_head_dim, self.v_head_dim], dim=-1)
k_pe = k_pe.expand((*k_nope.shape[:-1], -1))
# There are only two possible kinds of input here: ChunkedPrefill or PrefillNoCache
ascend_config = get_ascend_config()

if attn_metadata.attn_state in [
AscendAttentionState.ChunkedPrefill,
AscendAttentionState.SpecDecoding,
AscendAttentionState.PrefillCacheHit
] and not ascend_config.chunked_prefill_for_mla:
attn_output_torch = torch.empty(num_tokens,
self.num_heads * self.v_head_dim,
dtype=query.dtype,
device=query.device)
# the current request is chunked in prefill, so disable flash attention with chunked prefill
vanilla_chunked_prefill_mla(
output=attn_output_torch,
query=query,
kv_cache=kv_c_and_k_pe_cache,
block_tables=attn_metadata.prefill.block_table,
query_lens=attn_metadata.prefill.query_lens,
context_lens=attn_metadata.prefill.context_lens,
kv_b_proj=self.kv_b_proj,
max_query_len=attn_metadata.prefill.max_query_len,
max_context_len=attn_metadata.prefill.max_seq_lens,
nope_dim=self.qk_nope_head_dim,
rope_dim=self.qk_rope_head_dim,
v_head_dim=self.v_head_dim,
scale=self.scale,
alibi_slopes=None,
causal=True)
elif attn_metadata.attn_state in [
AscendAttentionState.ChunkedPrefill,
AscendAttentionState.SpecDecoding,
AscendAttentionState.PrefillCacheHit
]:
attn_lse = torch.empty(self.num_heads,
num_tokens,
dtype=torch.float32,
device=query.device)
q_pe = query[..., self.qk_nope_head_dim:]
q_nope = query[..., :self.qk_nope_head_dim]
mask = torch.triu(
torch.ones(512, 512, device=query.device, dtype=query.dtype),
1)  # 512: the mask only supports 512
if attn_metadata.num_prefills > 1:
mask = mask.unsqueeze(0).repeat(attn_metadata.num_prefills, 1,
1)
torch_npu.atb.npu_ring_mla(
q_nope=q_nope,
q_rope=q_pe,
k_nope=k_nope,
k_rope=k_pe,
value=value,
mask=mask,
seqlen=torch.tensor(attn_metadata.prefill.query_lens,
dtype=torch.int32),
head_num=self.num_heads,
kv_head_num=self.num_heads,
pre_out=None,
prev_lse=None,
qk_scale=self.scale,
kernel_type="kernel_type_high_precision",
mask_type="mask_type_triu",
input_layout="type_bsnd",
calc_type="calc_type_first_ring",
output=attn_output,
softmax_lse=attn_lse)
attn_output, attn_lse = self._compute_prefill_context( \
query, kv_c_and_k_pe_cache, self.qk_rope_head_dim, attn_metadata, attn_output, attn_lse)

elif attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
key = torch.cat((k_nope, k_pe), dim=-1)
torch_npu._npu_flash_attention(
query=query,
key=key,
value=value,
mask=attn_metadata.attn_mask,
seq_len=attn_metadata.prefill.context_lens,
scale_value=self.scale,
num_heads=self.num_heads,
num_kv_heads=self.num_heads,
out=attn_output)
attn_output = attn_output.view(-1, self.num_heads, self.v_head_dim)
else:
raise RuntimeError(
"Unexpected path reached, AscendMLAImpl should only have PrefillNoCache, PrefillCacheHit, ChunkedPrefill and SpecDecoding scenario in forward prefill, please file a bug to vllm-ascend !"
)
attn_lse = torch.empty(self.num_heads,
num_tokens,
dtype=torch.float32,
device=query.device)
q_pe = query[..., self.qk_nope_head_dim:]
q_nope = query[..., :self.qk_nope_head_dim]
mask = torch.triu(
torch.ones(512, 512, device=query.device, dtype=query.dtype),
1)  # 512: the mask only supports 512
if attn_metadata.num_prefills > 1:
mask = mask.unsqueeze(0).repeat(attn_metadata.num_prefills, 1,
1)
torch_npu.atb.npu_ring_mla(
q_nope=q_nope,
q_rope=q_pe,
k_nope=k_nope,
k_rope=k_pe,
value=value,
mask=mask,
seqlen=torch.tensor(attn_metadata.prefill.query_lens,
dtype=torch.int32),
head_num=self.num_heads,
kv_head_num=self.num_heads,
pre_out=None,
prev_lse=None,
qk_scale=self.scale,
kernel_type="kernel_type_high_precision",
mask_type="mask_type_triu",
input_layout="type_bsnd",
calc_type="calc_type_first_ring",
output=attn_output,
softmax_lse=attn_lse)
attn_output, attn_lse = self._compute_prefill_context( \
query, kv_c_and_k_pe_cache, self.qk_rope_head_dim, attn_metadata, attn_output, attn_lse)

attn_output = attn_output.reshape(
[num_tokens, self.num_heads * self.v_head_dim])
if attn_metadata.attn_state in [
AscendAttentionState.ChunkedPrefill,
AscendAttentionState.SpecDecoding,
AscendAttentionState.PrefillCacheHit
] and not ascend_config.chunked_prefill_for_mla:
attn_output = attn_output_torch

return attn_output
current_ms_metadata = get_multistream_comm_context()
if current_ms_metadata is None:
return self.o_proj(attn_output)[0]
else:
current_ms_metadata.before_comm_event.record()
with torch.npu.stream(current_ms_metadata.comm_stream):
current_ms_metadata.before_comm_event.wait()
return self.o_proj(attn_output)[0]

def exec_kv(
self,
@@ -983,102 +932,72 @@ def rope_single(

def _forward_decode(
self,
q_nope: torch.Tensor,
ql_nope: torch.Tensor,
q_pe: torch.Tensor,
k_nope: torch.Tensor,
k_pe: torch.Tensor,
kv_c_and_k_pe_cache: Tuple[torch.Tensor],
block_size: int,
attn_metadata: AscendMLAMetadata,
enable_multistream_mla: bool = False,
) -> torch.Tensor:
Comment on lines 934 to 942
Contributor
critical

The signature of _forward_decode was changed to take block_size: int instead of kv_c_and_k_pe_cache. However, the call sites in the forward method were not updated and still pass kv_cache (a tuple of tensors), which will cause a TypeError. To fix this, revert the signature to accept kv_c_and_k_pe_cache and derive block_size inside this function, as was done previously (block_size = kv_c_and_k_pe_cache[0].shape[1]).

Suggested change
self,
q_nope: torch.Tensor,
ql_nope: torch.Tensor,
q_pe: torch.Tensor,
k_nope: torch.Tensor,
k_pe: torch.Tensor,
kv_c_and_k_pe_cache: Tuple[torch.Tensor],
block_size: int,
attn_metadata: AscendMLAMetadata,
enable_multistream_mla: bool = False,
) -> torch.Tensor:
self,
ql_nope: torch.Tensor,
q_pe: torch.Tensor,
k_nope: torch.Tensor,
k_pe: torch.Tensor,
kv_c_and_k_pe_cache: Tuple[torch.Tensor, ...],
attn_metadata: AscendMLAMetadata,
enable_multistream_mla: bool = False,
) -> torch.Tensor:
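
For reference, here is a minimal, self-contained sketch of the pattern the suggestion restores: the decode path accepts the cache tuple and derives block_size from it, so the existing call sites can keep passing kv_cache unchanged. This is plain PyTorch with made-up shapes and a hypothetical forward_decode_sketch helper; it is not the vllm-ascend implementation, it assumes a [num_blocks, block_size, num_kv_heads, head_dim] cache layout, and it omits the torch_npu kernel call.

from typing import Tuple

import torch


def forward_decode_sketch(
        ql_nope: torch.Tensor,
        q_pe: torch.Tensor,
        kv_c_and_k_pe_cache: Tuple[torch.Tensor, ...],
) -> int:
    # Assumed cache layout: [num_blocks, block_size, num_kv_heads, head_dim],
    # so the block size can be read from dim 1 of the first cache tensor.
    block_size = kv_c_and_k_pe_cache[0].shape[1]
    # The real decode path would reshape k_nope / k_pe with this block_size and
    # call the NPU fused attention kernel; that part is omitted here.
    return block_size


# Hypothetical call site: pass the cache tuple itself rather than a block size.
kv_cache = (
    torch.zeros(8, 128, 1, 512),  # compressed KV part (kv_lora_rank = 512)
    torch.zeros(8, 128, 1, 64),   # rope part (qk_rope_head_dim = 64)
)
ql_nope = torch.zeros(4, 16, 512)  # [num_tokens, num_heads, kv_lora_rank]
q_pe = torch.zeros(4, 16, 64)      # [num_tokens, num_heads, qk_rope_head_dim]
assert forward_decode_sketch(ql_nope, q_pe, kv_cache) == 128

With this shape convention, kv_c_and_k_pe_cache[0].shape[1] recovers the block size the kernel needs, so no separate block_size parameter has to be threaded through forward.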

decode_meta = attn_metadata.decode
assert decode_meta is not None
num_tokens = q_nope.size(0)
if self.running_in_graph or self.running_chunkprefilll_with_torchair:
# shape of k_nope/k_pe for npu graph mode should be:
# [num_blocks, num_kv_heads, block_size, self.kv_lora_rank/self.qk_rope_head_dim]
block_size = kv_c_and_k_pe_cache[0].shape[1]
actual_seq_lengths = None
if self.enable_kv_nz:
k_nope = k_nope.view(-1, self.num_kv_heads,
self.kv_lora_rank // 16, block_size, 16)
k_pe = k_pe.view(-1, self.num_kv_heads,
self.qk_rope_head_dim // 16, block_size, 16)
input_layout = "BSND"
else:
k_nope = k_nope.view(-1, self.num_kv_heads, block_size,
self.kv_lora_rank)
k_pe = k_pe.view(-1, self.num_kv_heads, block_size,
self.qk_rope_head_dim)
input_layout = "BNSD"

if attn_metadata.attn_state == AscendAttentionState.SpecDecoding:
assert num_tokens % self.spec_token_num == 0
input_layout = "TND"
# [bs * q_seq_len, num_heads_per_rank, dim]
q_nope = q_nope.view(num_tokens, self.num_heads, -1)
q_pe = q_pe.view(num_tokens, self.num_heads, -1)
sparse_mode = 3
spec_attn_mask = attn_metadata.decode.attn_mask # type:ignore
actual_seq_lengths = decode_meta.actual_seq_lengths_q
else:
if self.enable_kv_nz:
q_nope = q_nope.view(num_tokens, 1, self.num_heads, -1)
q_pe = q_pe.view(num_tokens, 1, self.num_heads, -1)
else:
q_nope = q_nope.view(num_tokens, self.num_heads, 1, -1)
q_pe = q_pe.view(num_tokens, self.num_heads, 1, -1)
sparse_mode = 0
spec_attn_mask = None

attn_output, _ = torch_npu.npu_fused_infer_attention_score(
q_nope,
k_nope,
k_nope,
query_rope=q_pe,
key_rope=k_pe,
num_heads=self.num_heads,
num_key_value_heads=self.num_kv_heads,
input_layout=input_layout,
atten_mask=spec_attn_mask,
sparse_mode=sparse_mode,
scale=self.scale,
antiquant_mode=0,
antiquant_scale=None,
block_table=decode_meta.block_table,
block_size=block_size,
actual_seq_lengths_kv=decode_meta.seq_lens_list,
actual_seq_lengths=actual_seq_lengths)

num_tokens = ql_nope.size(0)
# shape of k_nope/k_pe for npu graph mode should be:
# [num_blocks, num_kv_heads, block_size, self.kv_lora_rank/self.qk_rope_head_dim]
actual_seq_lengths = None
if self.enable_kv_nz:
k_nope = k_nope.view(-1, self.num_kv_heads,
self.kv_lora_rank // 16, block_size, 16)
k_pe = k_pe.view(-1, self.num_kv_heads,
self.qk_rope_head_dim // 16, block_size, 16)
input_layout = "BSND"
else:
k_nope = k_nope.view(-1, self.num_kv_heads, block_size,
self.kv_lora_rank)
k_pe = k_pe.view(-1, self.num_kv_heads, block_size,
self.qk_rope_head_dim)
input_layout = "BNSD"

if attn_metadata.attn_state == AscendAttentionState.SpecDecoding:
assert num_tokens % self.spec_token_num == 0
input_layout = "TND"
# [bs * q_seq_len, num_heads_per_rank, dim]
ql_nope = ql_nope.view(num_tokens, self.num_heads, -1)
q_pe = q_pe.view(num_tokens, self.num_heads, -1)
sparse_mode = 3
spec_attn_mask = attn_metadata.decode.attn_mask # type:ignore
actual_seq_lengths = decode_meta.actual_seq_lengths_q
else:
# The MLA_PA path will become the default path in the future; `_npu_paged_attention_mla` will
# be removed once `torch_npu.atb.npu_multi_head_latent_attention` in torch_npu becomes
# publicly available.
assert len(kv_c_and_k_pe_cache) > 1
if envs.VLLM_ASCEND_MLA_PA:
attn_output = torch_npu.atb.npu_multi_head_latent_attention(
q_nope, q_pe, kv_c_and_k_pe_cache[0],
kv_c_and_k_pe_cache[1], attn_metadata.decode.block_table,
attn_metadata.decode.seq_lens, self.num_heads, self.scale,
self.num_kv_heads)
if self.enable_kv_nz:
ql_nope = ql_nope.view(num_tokens, 1, self.num_heads, -1)
q_pe = q_pe.view(num_tokens, 1, self.num_heads, -1)
else:
q = torch.cat([q_nope, q_pe], dim=-1)
attn_output = torch.empty(
[num_tokens, self.num_heads, self.kv_lora_rank],
dtype=q.dtype,
device=q.device)
k_cache = torch.cat(
[kv_c_and_k_pe_cache[0], kv_c_and_k_pe_cache[1]], dim=-1)
torch_npu._npu_paged_attention_mla(
query=q,
key_cache=k_cache,
num_kv_heads=self.num_kv_heads,
num_heads=self.num_heads,
scale_value=self.scale,
block_table=attn_metadata.decode.
block_table, # type:ignore
context_lens=attn_metadata.decode.seq_lens, # type:ignore
mla_vheadsize=self.kv_lora_rank,
out=attn_output)
ql_nope = ql_nope.view(num_tokens, self.num_heads, 1, -1)
q_pe = q_pe.view(num_tokens, self.num_heads, 1, -1)
sparse_mode = 0
spec_attn_mask = None

attn_output, _ = torch_npu.npu_fused_infer_attention_score(
ql_nope,
k_nope,
k_nope,
query_rope=q_pe,
key_rope=k_pe,
num_heads=self.num_heads,
num_key_value_heads=self.num_kv_heads,
input_layout=input_layout,
atten_mask=spec_attn_mask,
sparse_mode=sparse_mode,
scale=self.scale,
antiquant_mode=0,
antiquant_scale=None,
block_table=decode_meta.block_table,
block_size=block_size,
actual_seq_lengths_kv=decode_meta.seq_lens_list,
actual_seq_lengths=actual_seq_lengths)

current_ms_metadata = get_multistream_comm_context()
if current_ms_metadata is None:
return self._v_up_proj_and_o_proj(attn_output,