diff --git a/tests/ut/attention/test_mla_v1.py b/tests/ut/attention/test_mla_v1.py
index 497b7b53ab..b4d5e26629 100644
--- a/tests/ut/attention/test_mla_v1.py
+++ b/tests/ut/attention/test_mla_v1.py
@@ -728,3 +728,65 @@ def test_forward_without_graph(self, _, mock_forward_prefill):
             hidden_states_or_kv_c_normed, k_pe, kv_cache, metadata, output,
             False)
         self.assertEqual(result.shape[0], num_tokens)
+
+    @patch("vllm_ascend.attention.mla_v1.npu_prefetch")
+    def test_mla_preprocess(self, magic_npu_fetch):
+        # Mock inputs
+        magic_npu_fetch.return_value = MagicMock()
+        batch_size = 4
+        seq_len = 8
+        hidden_size = 1024
+        hidden_states = torch.randn(batch_size * seq_len, hidden_size)
+
+        qk_nope_head_dim = 64
+        qk_rope_head_dim = 128
+        qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
+        kv_lora_rank = 512
+        kv_cache = MagicMock()
+
+        # Mock attn_metadata
+        attn_metadata = MagicMock()
+        attn_metadata.num_decodes = 2
+        attn_metadata.num_prefills = 2
+        attn_metadata.num_decode_tokens = 2
+        attn_metadata.num_actual_tokens = 4
+        num_prefill_tokens = 2
+        attn_metadata.slot_mapping = torch.arange(4)
+        attn_metadata.decode.cos = torch.randn(2, 64)
+        attn_metadata.decode.sin = torch.randn(2, 64)
+        attn_metadata.prefill.cos = torch.randn(2, 64)
+        attn_metadata.prefill.sin = torch.randn(2, 64)
+
+        self.impl.q_a_proj = MagicMock()
+        self.impl.q_a_layernorm = MagicMock()
+        self.impl.q_a_layernorm.return_value = torch.randn(
+            attn_metadata.num_actual_tokens, self.impl.num_heads,
+            self.impl.qk_rope_head_dim)
+        self.impl.kv_a_proj_with_mqa = MagicMock()
+        self.impl.kv_a_proj_with_mqa.return_value = [
+            torch.randn(num_prefill_tokens, self.impl.num_heads,
+                        self.impl.qk_nope_head_dim + self.impl.kv_lora_rank)
+        ]
+        self.impl.q_proj = MagicMock()
+        self.impl.q_proj.return_value = [
+            torch.randn(num_prefill_tokens, self.impl.num_heads,
+                        self.impl.qk_head_dim)
+        ]
+        self.impl.kv_b_proj = MagicMock()
+        self.impl.kv_b_proj.return_value = [
+            torch.randn(num_prefill_tokens, self.impl.num_heads,
+                        self.impl.v_head_dim + self.impl.qk_nope_head_dim)
+        ]
+        self.impl.rope_single = MagicMock(side_effect=lambda x, cos, sin: x)
+        self.impl.exec_kv_decode = MagicMock()
+        self.impl.exec_kv_decode.return_value = [MagicMock(), MagicMock()]
+        self.impl.exec_kv_prefill = MagicMock()
+        self.impl.exec_kv_prefill.return_value = [
+            torch.randn(num_prefill_tokens, self.impl.num_heads,
+                        self.impl.qk_rope_head_dim),
+            torch.randn(num_prefill_tokens, self.impl.num_heads,
+                        self.impl.kv_lora_rank)
+        ]
+        self.impl._q_proj_and_k_up_proj = MagicMock()
+        self.impl._q_proj_and_k_up_proj.return_value = [
+            MagicMock(), MagicMock()
+        ]
+
+        decode_res, prefill_res = self.impl._mla_preprocess(
+            hidden_states, kv_cache, attn_metadata, need_gather_q_kv=False)
+
+        self.assertIsNotNone(decode_res)
+        self.assertIsNotNone(prefill_res)
diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py
index e7dccf33ab..f25c3a6abe 100644
--- a/vllm_ascend/attention/mla_v1.py
+++ b/vllm_ascend/attention/mla_v1.py
@@ -793,16 +793,14 @@ def _compute_prefill_context(
         return prefix_output, prefix_lse
 
     def _forward_prefill(
-        self,
-        query: torch.Tensor,
-        kv_c_normed: torch.Tensor,
-        k_pe: torch.Tensor,
-        kv_c_and_k_pe_cache: Tuple[torch.Tensor],
-        attn_metadata: AscendMLAMetadata,
+            self,
+            query: torch.Tensor,
+            kv_c_normed: torch.Tensor,
+            k_pe: torch.Tensor,
+            kv_c_and_k_pe_cache: torch.Tensor,
+            attn_metadata: AscendMLAMetadata,
     ) -> torch.Tensor:
         assert attn_metadata.prefill is not None
-        assert len(kv_c_and_k_pe_cache) > 1
-
         num_tokens = query.size(0)
         attn_output = torch.empty(num_tokens,
                                   self.num_heads,
@@ -811,104 +809,55 @@ def _forward_prefill(
                                   device=query.device)
         k_nope, value = self.kv_b_proj(kv_c_normed)[0].view(
             -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim).split(
-                [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
+            [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
         k_pe = k_pe.expand((*k_nope.shape[:-1], -1))
-        # Here is only 2 possibility of input, ChunkedPrefill or PrefillNoCache
-        ascend_config = get_ascend_config()
-        if attn_metadata.attn_state in [
-                AscendAttentionState.ChunkedPrefill,
-                AscendAttentionState.SpecDecoding,
-                AscendAttentionState.PrefillCacheHit
-        ] and not ascend_config.chunked_prefill_for_mla:
-            attn_output_torch = torch.empty(num_tokens,
-                                            self.num_heads * self.v_head_dim,
-                                            dtype=query.dtype,
-                                            device=query.device)
-            # current requests is chunked in prefill, disable flash attention with chunked prefill
-            vanilla_chunked_prefill_mla(
-                output=attn_output_torch,
-                query=query,
-                kv_cache=kv_c_and_k_pe_cache,
-                block_tables=attn_metadata.prefill.block_table,
-                query_lens=attn_metadata.prefill.query_lens,
-                context_lens=attn_metadata.prefill.context_lens,
-                kv_b_proj=self.kv_b_proj,
-                max_query_len=attn_metadata.prefill.max_query_len,
-                max_context_len=attn_metadata.prefill.max_seq_lens,
-                nope_dim=self.qk_nope_head_dim,
-                rope_dim=self.qk_rope_head_dim,
-                v_head_dim=self.v_head_dim,
-                scale=self.scale,
-                alibi_slopes=None,
-                causal=True)
-        elif attn_metadata.attn_state in [
-                AscendAttentionState.ChunkedPrefill,
-                AscendAttentionState.SpecDecoding,
-                AscendAttentionState.PrefillCacheHit
-        ]:
-            attn_lse = torch.empty(self.num_heads,
-                                   num_tokens,
-                                   dtype=torch.float32,
-                                   device=query.device)
-            q_pe = query[..., self.qk_nope_head_dim:]
-            q_nope = query[..., :self.qk_nope_head_dim]
-            mask = torch.triu(
-                torch.ones(512, 512, device=query.device, dtype=query.dtype),
-                1)  # 512: mask only support 512
-            if attn_metadata.num_prefills > 1:
-                mask = mask.unsqueeze(0).repeat(attn_metadata.num_prefills, 1,
-                                                1)
-            torch_npu.atb.npu_ring_mla(
-                q_nope=q_nope,
-                q_rope=q_pe,
-                k_nope=k_nope,
-                k_rope=k_pe,
-                value=value,
-                mask=mask,
-                seqlen=torch.tensor(attn_metadata.prefill.query_lens,
-                                    dtype=torch.int32),
-                head_num=self.num_heads,
-                kv_head_num=self.num_heads,
-                pre_out=None,
-                prev_lse=None,
-                qk_scale=self.scale,
-                kernel_type="kernel_type_high_precision",
-                mask_type="mask_type_triu",
-                input_layout="type_bsnd",
-                calc_type="calc_type_first_ring",
-                output=attn_output,
-                softmax_lse=attn_lse)
-            attn_output, attn_lse = self._compute_prefill_context( \
-                query, kv_c_and_k_pe_cache, self.qk_rope_head_dim, attn_metadata, attn_output, attn_lse)
-
-        elif attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
-            key = torch.cat((k_nope, k_pe), dim=-1)
-            torch_npu._npu_flash_attention(
-                query=query,
-                key=key,
-                value=value,
-                mask=attn_metadata.attn_mask,
-                seq_len=attn_metadata.prefill.context_lens,
-                scale_value=self.scale,
-                num_heads=self.num_heads,
-                num_kv_heads=self.num_heads,
-                out=attn_output)
-            attn_output = attn_output.view(-1, self.num_heads, self.v_head_dim)
-        else:
-            raise RuntimeError(
-                "Unexpected path reached, AscendMLAImpl should only have PrefillNoCache, PrefillCacheHit, ChunkedPrefill and SpecDecoding scenario in forward prefill, please file a bug to vllm-ascend !"
-            )
+        attn_lse = torch.empty(self.num_heads,
+                               num_tokens,
+                               dtype=torch.float32,
+                               device=query.device)
+        q_pe = query[..., self.qk_nope_head_dim:]
+        q_nope = query[..., :self.qk_nope_head_dim]
+        mask = torch.triu(
+            torch.ones(512, 512, device=query.device, dtype=query.dtype),
+            1)  # 512: mask only support 512
+        if attn_metadata.num_prefills > 1:
+            mask = mask.unsqueeze(0).repeat(attn_metadata.num_prefills, 1,
+                                            1)
+        torch_npu.atb.npu_ring_mla(
+            q_nope=q_nope,
+            q_rope=q_pe,
+            k_nope=k_nope,
+            k_rope=k_pe,
+            value=value,
+            mask=mask,
+            seqlen=torch.tensor(attn_metadata.prefill.query_lens,
+                                dtype=torch.int32),
+            head_num=self.num_heads,
+            kv_head_num=self.num_heads,
+            pre_out=None,
+            prev_lse=None,
+            qk_scale=self.scale,
+            kernel_type="kernel_type_high_precision",
+            mask_type="mask_type_triu",
+            input_layout="type_bsnd",
+            calc_type="calc_type_first_ring",
+            output=attn_output,
+            softmax_lse=attn_lse)
+        attn_output, attn_lse = self._compute_prefill_context( \
+            query, kv_c_and_k_pe_cache, self.qk_rope_head_dim, attn_metadata, attn_output, attn_lse)
+
         attn_output = attn_output.reshape(
             [num_tokens, self.num_heads * self.v_head_dim])
-        if attn_metadata.attn_state in [
-                AscendAttentionState.ChunkedPrefill,
-                AscendAttentionState.SpecDecoding,
-                AscendAttentionState.PrefillCacheHit
-        ] and not ascend_config.chunked_prefill_for_mla:
-            attn_output = attn_output_torch
-        return attn_output
+        current_ms_metadata = get_multistream_comm_context()
+        if current_ms_metadata is None:
+            return self.o_proj(attn_output)[0]
+        else:
+            current_ms_metadata.before_comm_event.record()
+            with torch.npu.stream(current_ms_metadata.comm_stream):
+                current_ms_metadata.before_comm_event.wait()
+                return self.o_proj(attn_output)[0]
 
     def exec_kv(
         self,
@@ -983,102 +932,72 @@ def rope_single(
 
     def _forward_decode(
         self,
-        q_nope: torch.Tensor,
+        ql_nope: torch.Tensor,
         q_pe: torch.Tensor,
         k_nope: torch.Tensor,
         k_pe: torch.Tensor,
-        kv_c_and_k_pe_cache: Tuple[torch.Tensor],
+        block_size: int,
        attn_metadata: AscendMLAMetadata,
         enable_multistream_mla: bool = False,
     ) -> torch.Tensor:
         decode_meta = attn_metadata.decode
         assert decode_meta is not None
-        num_tokens = q_nope.size(0)
-        if self.running_in_graph or self.running_chunkprefilll_with_torchair:
-            # shape of knope/k_pe for npu graph mode should be:
-            # [num_blocks, num_kv_heads, block_size, self.kv_lora_rank/self.qk_rope_head_dim]
-            block_size = kv_c_and_k_pe_cache[0].shape[1]
-            actual_seq_lengths = None
-            if self.enable_kv_nz:
-                k_nope = k_nope.view(-1, self.num_kv_heads,
-                                     self.kv_lora_rank // 16, block_size, 16)
-                k_pe = k_pe.view(-1, self.num_kv_heads,
-                                 self.qk_rope_head_dim // 16, block_size, 16)
-                input_layout = "BSND"
-            else:
-                k_nope = k_nope.view(-1, self.num_kv_heads, block_size,
-                                     self.kv_lora_rank)
-                k_pe = k_pe.view(-1, self.num_kv_heads, block_size,
-                                 self.qk_rope_head_dim)
-                input_layout = "BNSD"
-
-            if attn_metadata.attn_state == AscendAttentionState.SpecDecoding:
-                assert num_tokens % self.spec_token_num == 0
-                input_layout = "TND"
-                # [bs * q_seq_len, num_heads_per_rank, dim]
-                q_nope = q_nope.view(num_tokens, self.num_heads, -1)
-                q_pe = q_pe.view(num_tokens, self.num_heads, -1)
-                sparse_mode = 3
-                spec_attn_mask = attn_metadata.decode.attn_mask  # type:ignore
-                actual_seq_lengths = decode_meta.actual_seq_lengths_q
-            else:
-                if self.enable_kv_nz:
-                    q_nope = q_nope.view(num_tokens, 1, self.num_heads, -1)
-                    q_pe = q_pe.view(num_tokens, 1, self.num_heads, -1)
-                else:
-                    q_nope = q_nope.view(num_tokens, self.num_heads, 1, -1)
-                    q_pe = q_pe.view(num_tokens, self.num_heads, 1, -1)
-                sparse_mode = 0
-                spec_attn_mask = None
-
-            attn_output, _ = torch_npu.npu_fused_infer_attention_score(
-                q_nope,
-                k_nope,
-                k_nope,
-                query_rope=q_pe,
-                key_rope=k_pe,
-                num_heads=self.num_heads,
-                num_key_value_heads=self.num_kv_heads,
-                input_layout=input_layout,
-                atten_mask=spec_attn_mask,
-                sparse_mode=sparse_mode,
-                scale=self.scale,
-                antiquant_mode=0,
-                antiquant_scale=None,
-                block_table=decode_meta.block_table,
-                block_size=block_size,
-                actual_seq_lengths_kv=decode_meta.seq_lens_list,
-                actual_seq_lengths=actual_seq_lengths)
+
+        num_tokens = ql_nope.size(0)
+        # shape of knope/k_pe for npu graph mode should be:
+        # [num_blocks, num_kv_heads, block_size, self.kv_lora_rank/self.qk_rope_head_dim]
+        actual_seq_lengths = None
+        if self.enable_kv_nz:
+            k_nope = k_nope.view(-1, self.num_kv_heads,
+                                 self.kv_lora_rank // 16, block_size, 16)
+            k_pe = k_pe.view(-1, self.num_kv_heads,
+                             self.qk_rope_head_dim // 16, block_size, 16)
+            input_layout = "BSND"
+        else:
+            k_nope = k_nope.view(-1, self.num_kv_heads, block_size,
+                                 self.kv_lora_rank)
+            k_pe = k_pe.view(-1, self.num_kv_heads, block_size,
+                             self.qk_rope_head_dim)
+            input_layout = "BNSD"
+
+        if attn_metadata.attn_state == AscendAttentionState.SpecDecoding:
+            assert num_tokens % self.spec_token_num == 0
+            input_layout = "TND"
+            # [bs * q_seq_len, num_heads_per_rank, dim]
+            ql_nope = ql_nope.view(num_tokens, self.num_heads, -1)
+            q_pe = q_pe.view(num_tokens, self.num_heads, -1)
+            sparse_mode = 3
+            spec_attn_mask = attn_metadata.decode.attn_mask  # type:ignore
+            actual_seq_lengths = decode_meta.actual_seq_lengths_q
         else:
-            # The MLA_PA path will be used as default path in the future, `_npu_paged_attention_mla` will
-            # be removed after the torch_npu contains `torch_npu.atb.npu_multi_head_latent_attention` become
-            # public available
-            assert len(kv_c_and_k_pe_cache) > 1
-            if envs.VLLM_ASCEND_MLA_PA:
-                attn_output = torch_npu.atb.npu_multi_head_latent_attention(
-                    q_nope, q_pe, kv_c_and_k_pe_cache[0],
-                    kv_c_and_k_pe_cache[1], attn_metadata.decode.block_table,
-                    attn_metadata.decode.seq_lens, self.num_heads, self.scale,
-                    self.num_kv_heads)
+            if self.enable_kv_nz:
+                ql_nope = ql_nope.view(num_tokens, 1, self.num_heads, -1)
+                q_pe = q_pe.view(num_tokens, 1, self.num_heads, -1)
             else:
-                q = torch.cat([q_nope, q_pe], dim=-1)
-                attn_output = torch.empty(
-                    [num_tokens, self.num_heads, self.kv_lora_rank],
-                    dtype=q.dtype,
-                    device=q.device)
-                k_cache = torch.cat(
-                    [kv_c_and_k_pe_cache[0], kv_c_and_k_pe_cache[1]], dim=-1)
-                torch_npu._npu_paged_attention_mla(
-                    query=q,
-                    key_cache=k_cache,
-                    num_kv_heads=self.num_kv_heads,
-                    num_heads=self.num_heads,
-                    scale_value=self.scale,
-                    block_table=attn_metadata.decode.
-                    block_table,  # type:ignore
-                    context_lens=attn_metadata.decode.seq_lens,  # type:ignore
-                    mla_vheadsize=self.kv_lora_rank,
-                    out=attn_output)
+                ql_nope = ql_nope.view(num_tokens, self.num_heads, 1, -1)
+                q_pe = q_pe.view(num_tokens, self.num_heads, 1, -1)
+            sparse_mode = 0
+            spec_attn_mask = None
+
+        attn_output, _ = torch_npu.npu_fused_infer_attention_score(
+            ql_nope,
+            k_nope,
+            k_nope,
+            query_rope=q_pe,
+            key_rope=k_pe,
+            num_heads=self.num_heads,
+            num_key_value_heads=self.num_kv_heads,
+            input_layout=input_layout,
+            atten_mask=spec_attn_mask,
+            sparse_mode=sparse_mode,
+            scale=self.scale,
+            antiquant_mode=0,
+            antiquant_scale=None,
+            block_table=decode_meta.block_table,
+            block_size=block_size,
+            actual_seq_lengths_kv=decode_meta.seq_lens_list,
+            actual_seq_lengths=actual_seq_lengths)
+
         current_ms_metadata = get_multistream_comm_context()
         if current_ms_metadata is None:
             return self._v_up_proj_and_o_proj(attn_output,
diff --git a/vllm_ascend/utils.py b/vllm_ascend/utils.py
index 7c0f77f4f8..7fb698383d 100644
--- a/vllm_ascend/utils.py
+++ b/vllm_ascend/utils.py
@@ -509,3 +509,301 @@ def get_ascend_soc_version():
     global _ascend_soc_version
     assert _ascend_soc_version is not None
     return _ascend_soc_version
+
+
+from vllm_ascend.ops.attention import vanilla_chunked_prefill_mla
+def _forward_prefill(
+        self,
+        q_nope: torch.Tensor,
+        q_pe: torch.Tensor,
+        k_nope: torch.Tensor,
+        k_pe: torch.Tensor,
+        value: torch.Tensor,
+        kv_c_and_k_pe_cache: Tuple[torch.Tensor],
+        attn_metadata: AscendMLAMetadata,
+) -> torch.Tensor:
+    assert attn_metadata.prefill is not None
+    assert len(kv_c_and_k_pe_cache) > 1
+    query = torch.cat([q_nope, q_pe], dim=-1)
+    num_tokens = q_nope.size(0)
+    attn_output = torch.empty(num_tokens,
+                              self.num_heads,
+                              self.v_head_dim,
+                              dtype=query.dtype,
+                              device=query.device)
+    k_pe = k_pe.expand((*k_nope.shape[:-1], -1))
+    # Only two kinds of input are possible here: ChunkedPrefill or PrefillNoCache
+    ascend_config = get_ascend_config()
+
+    if attn_metadata.attn_state in [
+            AscendAttentionState.ChunkedPrefill,
+            AscendAttentionState.SpecDecoding,
+            AscendAttentionState.PrefillCacheHit
+    ] and not ascend_config.chunked_prefill_for_mla:
+        attn_output_torch = torch.empty(num_tokens,
+                                        self.num_heads * self.v_head_dim,
+                                        dtype=query.dtype,
+                                        device=query.device)
+        # the current request is chunked in prefill, so disable flash attention for chunked prefill
+        vanilla_chunked_prefill_mla(
+            output=attn_output_torch,
+            query=query,
+            kv_cache=kv_c_and_k_pe_cache,
+            block_tables=attn_metadata.prefill.block_table,
+            query_lens=attn_metadata.prefill.query_lens,
+            context_lens=attn_metadata.prefill.context_lens,
+            kv_b_proj=self.kv_b_proj,
+            max_query_len=attn_metadata.prefill.max_query_len,
+            max_context_len=attn_metadata.prefill.max_seq_lens,
+            nope_dim=self.qk_nope_head_dim,
+            rope_dim=self.qk_rope_head_dim,
+            v_head_dim=self.v_head_dim,
+            scale=self.scale,
+            alibi_slopes=None,
+            causal=True)
+    elif attn_metadata.attn_state in [
+            AscendAttentionState.ChunkedPrefill,
+            AscendAttentionState.SpecDecoding,
+            AscendAttentionState.PrefillCacheHit
+    ]:
+        query = torch.cat([q_nope, q_pe], dim=-1)
+        attn_lse = torch.empty(self.num_heads,
+                               num_tokens,
+                               dtype=torch.float32,
+                               device=q_nope.device)
+        mask = torch.triu(
+            torch.ones(512, 512, device=query.device, dtype=query.dtype),
+            1)  # 512: mask only support 512
+        if attn_metadata.num_prefills > 1:
+            mask = mask.unsqueeze(0).repeat(attn_metadata.num_prefills, 1,
+                                            1)
+        torch_npu.atb.npu_ring_mla(
+            q_nope=q_nope,
+            q_rope=q_pe,
+            k_nope=k_nope,
+            k_rope=k_pe,
+            value=value,
+            mask=mask,
+            seqlen=torch.tensor(attn_metadata.prefill.query_lens,
+                                dtype=torch.int32),
+            head_num=self.num_heads,
+            kv_head_num=self.num_heads,
+            pre_out=None,
+            prev_lse=None,
+            qk_scale=self.scale,
+            kernel_type="kernel_type_high_precision",
+            mask_type="mask_type_triu",
+            input_layout="type_bsnd",
+            calc_type="calc_type_first_ring",
+            output=attn_output,
+            softmax_lse=attn_lse)
+        attn_output, attn_lse = self._compute_prefill_context( \
+            query, kv_c_and_k_pe_cache, self.qk_rope_head_dim, attn_metadata, attn_output, attn_lse)
+
+    elif attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
+        key = torch.cat((k_nope, k_pe), dim=-1)
+        torch_npu._npu_flash_attention(
+            query=query,
+            key=key,
+            value=value,
+            mask=attn_metadata.attn_mask,
+            seq_len=attn_metadata.prefill.context_lens,
+            scale_value=self.scale,
+            num_heads=self.num_heads,
+            num_kv_heads=self.num_heads,
+            out=attn_output)
+        attn_output = attn_output.view(-1, self.num_heads, self.v_head_dim)
+    else:
+        raise RuntimeError(
+            "Unexpected path reached, AscendMLAImpl should only have PrefillNoCache, PrefillCacheHit, ChunkedPrefill and SpecDecoding scenario in forward prefill, please file a bug to vllm-ascend !"
+        )
+    attn_output = attn_output.reshape(
+        [num_tokens, self.num_heads * self.v_head_dim])
+    if attn_metadata.attn_state in [
+            AscendAttentionState.ChunkedPrefill,
+            AscendAttentionState.SpecDecoding,
+            AscendAttentionState.PrefillCacheHit
+    ] and not ascend_config.chunked_prefill_for_mla:
+        attn_output = attn_output_torch
+
+    return attn_output
+
+
+self.chunked_prefill_for_mla = additional_config.get(
+    "chunked_prefill_for_mla", False)
+
+
+    @patch("vllm_ascend.attention.mla_v1.npu_prefetch")
+    def test_mla_preprocess(self, magic_npu_fetch):
+        magic_npu_fetch.return_value = MagicMock()
+        batch_size = 4
+        seq_len = 8
+        hidden_size = 1024
+        hidden_states = torch.randn(batch_size * seq_len, hidden_size)
+
+        qk_nope_head_dim = 64
+        qk_rope_head_dim = 128
+        qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
+        kv_lora_rank = 512
+        kv_cache = MagicMock()
+
+        attn_metadata = MagicMock()
+        attn_metadata.num_decodes = 2
+        attn_metadata.num_prefills = 2
+        attn_metadata.num_decode_tokens = 2
+        attn_metadata.num_actual_tokens = 4
+        num_prefill_tokens = 2
+        attn_metadata.slot_mapping = torch.arange(4)
+        attn_metadata.decode.cos = torch.randn(2, 64)
+        attn_metadata.decode.sin = torch.randn(2, 64)
+        attn_metadata.prefill.cos = torch.randn(2, 64)
+        attn_metadata.prefill.sin = torch.randn(2, 64)
+
+        self.impl.q_a_proj = MagicMock()
+        self.impl.q_a_layernorm = MagicMock()
+        self.impl.q_a_layernorm.return_value = torch.randn(
+            attn_metadata.num_actual_tokens, self.impl.num_heads,
+            self.impl.qk_rope_head_dim)
+        self.impl.kv_a_proj_with_mqa = MagicMock()
+        self.impl.kv_a_proj_with_mqa.return_value = [
+            torch.randn(num_prefill_tokens, self.impl.num_heads,
+                        self.impl.qk_nope_head_dim + self.impl.kv_lora_rank)
+        ]
+        self.impl.q_proj = MagicMock()
+        self.impl.q_proj.return_value = [
+            torch.randn(num_prefill_tokens, self.impl.num_heads,
+                        self.impl.qk_head_dim)
+        ]
+        self.impl.kv_b_proj = MagicMock()
+        self.impl.kv_b_proj.return_value = [
+            torch.randn(num_prefill_tokens, self.impl.num_heads,
+                        self.impl.v_head_dim + self.impl.qk_nope_head_dim)
+        ]
+        self.impl.rope_single = MagicMock(side_effect=lambda x, cos, sin: x)
+        self.impl.exec_kv_decode = MagicMock()
+        self.impl.exec_kv_decode.return_value = [MagicMock(), MagicMock()]
+        self.impl.exec_kv_prefill = MagicMock()
+        self.impl.exec_kv_prefill.return_value = [
+            torch.randn(num_prefill_tokens, self.impl.num_heads,
+                        self.impl.qk_rope_head_dim),
+            torch.randn(num_prefill_tokens, self.impl.num_heads,
+                        self.impl.kv_lora_rank)
+        ]
+        self.impl._q_proj_and_k_up_proj = MagicMock()
+        self.impl._q_proj_and_k_up_proj.return_value = [
+            MagicMock(), MagicMock()
+        ]
+        self.impl.num_kv_heads = self.impl.num_heads
+
+        decode_res, prefill_res = self.impl._mla_preprocess(
+            hidden_states, kv_cache, attn_metadata, need_gather_q_kv=False)
+
+        self.assertIsNotNone(decode_res)
+        self.assertIsNotNone(prefill_res)
+
+    @patch("torch_npu.npu_kv_rmsnorm_rope_cache")
+    def test_exec_kv_prefill(self, mock_kv_rmsnorm_rope_cache):
+        B = 2
+        N = self.impl.num_kv_heads
+        D = self.impl.kv_lora_rank + self.impl.qk_rope_head_dim
+        kv_no_split = torch.randn(B, N, D)
+        self.impl.enable_kv_nz = None
+        self.impl.kv_a_layernorm.weight = MagicMock()
+        self.impl.kv_a_layernorm.variance_epsilon = MagicMock()
+        cos = MagicMock()
+        sin = MagicMock()
+        slots = MagicMock()
+        kv_cache = [MagicMock(), MagicMock()]
+
+        mock_kv_rmsnorm_rope_cache.return_value = [
+            None, None,
+            torch.randn(B, N, 1, self.impl.qk_rope_head_dim),
+            torch.randn(B, N, 1, self.impl.kv_lora_rank)
+        ]
+
+        k_pe, k_nope = self.impl.exec_kv_prefill(kv_no_split, cos, sin,
+                                                 kv_cache, slots)
+
+        self.assertEqual(k_pe.shape[-1], self.impl.qk_rope_head_dim)
+        self.assertEqual(k_nope.shape[-1], self.impl.kv_lora_rank)
+
+    @patch("torch_npu.npu_kv_rmsnorm_rope_cache")
+    def test_exec_kv_decode(self, mock_kv_rmsnorm_rope_cache):
+        B = 2
+        N = self.impl.num_kv_heads
+        D = self.impl.kv_lora_rank + self.impl.qk_rope_head_dim
+        kv_no_split = torch.randn(B, N, D)
+        self.impl.enable_kv_nz = None
+        self.impl.kv_a_layernorm.weight = MagicMock()
+        self.impl.kv_a_layernorm.variance_epsilon = MagicMock()
+        cos = MagicMock()
+        sin = MagicMock()
+        slots = MagicMock()
+        kv_cache = [MagicMock(), MagicMock()]
+
+        mock_kv_rmsnorm_rope_cache.return_value = [
+            torch.randn(B, N, 1, self.impl.qk_rope_head_dim),
+            torch.randn(B, N, 1, self.impl.kv_lora_rank), None, None
+        ]
+
+        k_pe, k_nope = self.impl.exec_kv_decode(kv_no_split, cos, sin,
+                                                kv_cache, slots)
+
+        self.assertEqual(k_pe.shape[-1], self.impl.qk_rope_head_dim)
+        self.assertEqual(k_nope.shape[-1], self.impl.kv_lora_rank)
+
+    @patch("torch.npu.stream")
+    @patch("vllm_ascend.attention.mla_v1.get_multistream_comm_context")
+    @patch("torch_npu.npu_fused_infer_attention_score")
+    def test_forward_decode(self, mock_npu_fused_infer_attention_score,
+                            mock_get_multistream_comm_context,
+                            mock_npu_stream):
+        B = 2
+        N = self.impl.num_kv_heads
+        BS = 100
+        HD = self.impl.v_head_dim
+        self.impl.kv_lora_rank = 256
+        self.impl.spec_token_num = 1
+        self.impl._v_up_proj = MagicMock()
+        self.impl._v_up_proj.return_value = torch.randn(B, N, HD)
+        q_nope = torch.randn(B, N, self.impl.qk_nope_head_dim)
+        q_pe = torch.randn(B, N, self.impl.qk_rope_head_dim)
+        k_nope = torch.randn(BS, N, self.impl.kv_lora_rank)
+        k_pe = torch.randn(BS, N, self.impl.qk_rope_head_dim)
+        attn_metadata = MagicMock()
+        attn_metadata.attn_state = AscendAttentionState.SpecDecoding
+        attn_metadata.decode = MagicMock()
+        attn_metadata.decode.actual_seq_lengths_q = MagicMock()
+        attn_metadata.decode.seq_lens_list = MagicMock()
+        self.impl.enable_kv_nz = True
+
+        mock_npu_fused_infer_attention_score.return_value = [
+            torch.randn(B, N, self.impl.kv_lora_rank), None
+        ]
+        mock_get_multistream_comm_context.return_value = None
+
+        result = self.impl._forward_decode(q_nope, q_pe, k_nope, k_pe, BS,
+                                           attn_metadata)
+
+        self.assertEqual(result.shape[0], B)
+        self.assertEqual(result.shape[1], N)
+        self.assertEqual(result.shape[2], HD)
+
+        self.impl.enable_kv_nz = False
+        attn_metadata.attn_state = None
+        mock_return_value = MagicMock()
+        mock_get_multistream_comm_context.return_value = mock_return_value
+        mock_return_value.before_comm_event = MagicMock()
+        mock_return_value.comm_stream = MagicMock()
+        mock_npu_stream.return_value = MagicMock()
+
+        result = self.impl._forward_decode(q_nope, q_pe, k_nope, k_pe, BS,
+                                           attn_metadata)
+
+        self.assertEqual(result.shape[0], B)
+        self.assertEqual(result.shape[1], N)
+        self.assertEqual(result.shape[2], HD)
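
Usage sketch for the `chunked_prefill_for_mla` toggle touched by this patch: it is read from `additional_config` and checked through `ascend_config.chunked_prefill_for_mla` on the prefill path. A minimal, hypothetical example of turning it on, assuming vLLM's `additional_config` engine argument is available (the same mechanism vllm-ascend uses for its other toggles); the model name and flag value below are illustrative and not part of this diff:

    from vllm import LLM

    # Hypothetical example: pass the Ascend-specific flag via additional_config
    # so get_ascend_config() picks it up; it defaults to False when absent.
    llm = LLM(
        model="deepseek-ai/DeepSeek-V2-Lite",  # example MLA-style model
        additional_config={"chunked_prefill_for_mla": True},
    )
    outputs = llm.generate(["Hello, world"])
    print(outputs[0].outputs[0].text)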