4 | 4 | import numpy as np
5 | 5 | import torch
6 | 6 | import torch_npu
7 | | -from torch import nn
8 | | -from vllm.attention.backends.abstract import (AttentionBackend, AttentionLayer,
| 7 | +from vllm.attention.backends.abstract import (AttentionBackend,
9 | 8 |                                                AttentionMetadata,
10 | 9 |                                                MLAAttentionImpl)
11 | 10 | from vllm.attention.backends.utils import PAD_SLOT_ID

15 | 14 |                                            UnquantizedLinearMethod)
16 | 15 | from vllm.utils import cdiv, round_down
17 | 16 |
18 | | -import vllm_ascend.envs as envs_ascend
19 | 17 | from vllm_ascend.ascend_config import get_ascend_config
20 | 18 | from vllm_ascend.attention.attention_v1 import AscendAttentionState
21 | 19 | from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig
22 | 20 | from vllm_ascend.multistream.context import get_multistream_comm_context
23 | 21 | from vllm_ascend.multistream.ms_split import model_input_split_v1_mla_attn
24 | | -from vllm_ascend.ops.attention import vanilla_chunked_prefill_mla
25 | | -from vllm_ascend.torchair.utils import npu_stream_switch, npu_wait_tensor
26 | 22 | from vllm_ascend.utils import npu_prefetch
27 | 23 | from vllm_ascend.worker.npu_input_batch import InputBatch
28 | 24 |
@@ -812,101 +808,43 @@ def _forward_prefill(
812 | 808 |             -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim).split(
813 | 809 |                 [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
814 | 810 |         k_pe = k_pe.expand((*k_nope.shape[:-1], -1))
815 | | -        # Here is only 2 possibility of input, ChunkedPrefill or PrefillNoCache
816 | | -        ascend_config = get_ascend_config()
| 811 | +        attn_lse = torch.empty(self.num_heads,
| 812 | +                               num_tokens,
| 813 | +                               dtype=torch.float32,
| 814 | +                               device=query.device)
| 815 | +        q_pe = query[..., self.qk_nope_head_dim:]
| 816 | +        q_nope = query[..., :self.qk_nope_head_dim]
| 817 | +        mask = torch.triu(
| 818 | +            torch.ones(512, 512, device=query.device, dtype=query.dtype),
| 819 | +            1)  # 512: mask only support 512
| 820 | +        if attn_metadata.num_prefills > 1:
| 821 | +            mask = mask.unsqueeze(0).repeat(attn_metadata.num_prefills, 1,
| 822 | +                                            1)
| 823 | +        torch_npu.atb.npu_ring_mla(
| 824 | +            q_nope=q_nope,
| 825 | +            q_rope=q_pe,
| 826 | +            k_nope=k_nope,
| 827 | +            k_rope=k_pe,
| 828 | +            value=value,
| 829 | +            mask=mask,
| 830 | +            seqlen=torch.tensor(attn_metadata.prefill.query_lens,
| 831 | +                                dtype=torch.int32),
| 832 | +            head_num=self.num_heads,
| 833 | +            kv_head_num=self.num_heads,
| 834 | +            pre_out=None,
| 835 | +            prev_lse=None,
| 836 | +            qk_scale=self.scale,
| 837 | +            kernel_type="kernel_type_high_precision",
| 838 | +            mask_type="mask_type_triu",
| 839 | +            input_layout="type_bsnd",
| 840 | +            calc_type="calc_type_first_ring",
| 841 | +            output=attn_output,
| 842 | +            softmax_lse=attn_lse)
| 843 | +        attn_output, attn_lse = self._compute_prefill_context( \
| 844 | +            query, kv_c_and_k_pe_cache, self.qk_rope_head_dim, attn_metadata, attn_output, attn_lse)
817 | 845 |
818 | | -        if attn_metadata.attn_state in [
819 | | -                AscendAttentionState.ChunkedPrefill,
820 | | -                AscendAttentionState.SpecDecoding,
821 | | -                AscendAttentionState.PrefillCacheHit
822 | | -        ] and not ascend_config.chunked_prefill_for_mla:
823 | | -            attn_output_torch = torch.empty(num_tokens,
824 | | -                                            self.num_heads * self.v_head_dim,
825 | | -                                            dtype=query.dtype,
826 | | -                                            device=query.device)
827 | | -            # current requests is chunked in prefill, disable flash attention with chunked prefill
828 | | -            vanilla_chunked_prefill_mla(
829 | | -                output=attn_output_torch,
830 | | -                query=query,
831 | | -                kv_cache=kv_c_and_k_pe_cache,
832 | | -                block_tables=attn_metadata.prefill.block_table,
833 | | -                query_lens=attn_metadata.prefill.query_lens,
834 | | -                context_lens=attn_metadata.prefill.context_lens,
835 | | -                kv_b_proj=self.kv_b_proj,
836 | | -                max_query_len=attn_metadata.prefill.max_query_len,
837 | | -                max_context_len=attn_metadata.prefill.max_seq_lens,
838 | | -                nope_dim=self.qk_nope_head_dim,
839 | | -                rope_dim=self.qk_rope_head_dim,
840 | | -                v_head_dim=self.v_head_dim,
841 | | -                scale=self.scale,
842 | | -                alibi_slopes=None,
843 | | -                causal=True)
844 | | -        elif attn_metadata.attn_state in [
845 | | -                AscendAttentionState.ChunkedPrefill,
846 | | -                AscendAttentionState.SpecDecoding,
847 | | -                AscendAttentionState.PrefillCacheHit
848 | | -        ]:
849 | | -            attn_lse = torch.empty(self.num_heads,
850 | | -                                   num_tokens,
851 | | -                                   dtype=torch.float32,
852 | | -                                   device=query.device)
853 | | -            q_pe = query[..., self.qk_nope_head_dim:]
854 | | -            q_nope = query[..., :self.qk_nope_head_dim]
855 | | -            mask = torch.triu(
856 | | -                torch.ones(512, 512, device=query.device, dtype=query.dtype),
857 | | -                1)  # 512: mask only support 512
858 | | -            if attn_metadata.num_prefills > 1:
859 | | -                mask = mask.unsqueeze(0).repeat(attn_metadata.num_prefills, 1,
860 | | -                                                1)
861 | | -            torch_npu.atb.npu_ring_mla(
862 | | -                q_nope=q_nope,
863 | | -                q_rope=q_pe,
864 | | -                k_nope=k_nope,
865 | | -                k_rope=k_pe,
866 | | -                value=value,
867 | | -                mask=mask,
868 | | -                seqlen=torch.tensor(attn_metadata.prefill.query_lens,
869 | | -                                    dtype=torch.int32),
870 | | -                head_num=self.num_heads,
871 | | -                kv_head_num=self.num_heads,
872 | | -                pre_out=None,
873 | | -                prev_lse=None,
874 | | -                qk_scale=self.scale,
875 | | -                kernel_type="kernel_type_high_precision",
876 | | -                mask_type="mask_type_triu",
877 | | -                input_layout="type_bsnd",
878 | | -                calc_type="calc_type_first_ring",
879 | | -                output=attn_output,
880 | | -                softmax_lse=attn_lse)
881 | | -            attn_output, attn_lse = self._compute_prefill_context( \
882 | | -                query, kv_c_and_k_pe_cache, self.qk_rope_head_dim, attn_metadata, attn_output, attn_lse)
883 | | -
884 | | -        elif attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
885 | | -            key = torch.cat((k_nope, k_pe), dim=-1)
886 | | -            torch_npu._npu_flash_attention(
887 | | -                query=query,
888 | | -                key=key,
889 | | -                value=value,
890 | | -                mask=attn_metadata.attn_mask,
891 | | -                seq_len=attn_metadata.prefill.context_lens,
892 | | -                scale_value=self.scale,
893 | | -                num_heads=self.num_heads,
894 | | -                num_kv_heads=self.num_heads,
895 | | -                out=attn_output)
896 | | -            attn_output = attn_output.view(-1, self.num_heads, self.v_head_dim)
897 | | -        else:
898 | | -            raise RuntimeError(
899 | | -                "Unexpected path reached, AscendMLAImpl should only have PrefillNoCache, PrefillCacheHit, ChunkedPrefill and SpecDecoding scenario in forward prefill, please file a bug to vllm-ascend !"
900 | | -            )
901 | 846 |         attn_output = attn_output.reshape(
902 | 847 |             [num_tokens, self.num_heads * self.v_head_dim])
903 | | -        if attn_metadata.attn_state in [
904 | | -                AscendAttentionState.ChunkedPrefill,
905 | | -                AscendAttentionState.SpecDecoding,
906 | | -                AscendAttentionState.PrefillCacheHit
907 | | -        ] and not ascend_config.chunked_prefill_for_mla:
908 | | -            attn_output = attn_output_torch
909 | | -
910 | 848 |         return attn_output
911 | 849 |
912 | 850 |     def exec_kv(
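In the new prefill path above, `npu_ring_mla` with `calc_type="calc_type_first_ring"` produces a partial attention output for the current chunk plus its softmax log-sum-exp (`attn_lse`), and `_compute_prefill_context` then folds in the already-cached context. Combining two such partials is conventionally done with a log-sum-exp weighted merge; below is a minimal PyTorch sketch of that combine step, with illustrative function name and shapes (not the kernel's or the class's actual API):

```python
import torch

def merge_attn_partials(out_a: torch.Tensor, lse_a: torch.Tensor,
                        out_b: torch.Tensor, lse_b: torch.Tensor):
    # out_*: [num_heads, num_tokens, v_dim]  partial attention outputs
    # lse_*: [num_heads, num_tokens]         log-sum-exp of each partial softmax
    lse_max = torch.maximum(lse_a, lse_b)
    w_a = torch.exp(lse_a - lse_max)        # un-normalized weight of partial A
    w_b = torch.exp(lse_b - lse_max)        # un-normalized weight of partial B
    denom = w_a + w_b
    merged = (out_a * (w_a / denom).unsqueeze(-1) +
              out_b * (w_b / denom).unsqueeze(-1))
    merged_lse = lse_max + torch.log(denom)  # lse over the union of both key sets
    return merged, merged_lse
```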
@@ -984,102 +922,69 @@ def _forward_decode(
984 | 922 |         q_pe: torch.Tensor,
985 | 923 |         k_nope: torch.Tensor,
986 | 924 |         k_pe: torch.Tensor,
987 | | -        kv_c_and_k_pe_cache: Tuple[torch.Tensor],
| 925 | +        block_size: int,
988 | 926 |         attn_metadata: AscendMLAMetadata,
989 | | -        enable_multistream_mla: bool = False,
990 | 927 |     ) -> torch.Tensor:
991 | 928 |         decode_meta = attn_metadata.decode
992 | 929 |         assert decode_meta is not None
993 | 930 |         num_tokens = q_nope.size(0)
994 | | -        if self.running_in_graph or self.running_chunkprefilll_with_torchair:
995 | | -            # shape of knope/k_pe for npu graph mode should be:
996 | | -            # [num_blocks, num_kv_heads, block_size, self.kv_lora_rank/self.qk_rope_head_dim]
997 | | -            block_size = kv_c_and_k_pe_cache[0].shape[1]
998 | | -            actual_seq_lengths = None
999 | | -            if self.enable_kv_nz:
1000 | | -                k_nope = k_nope.view(-1, self.num_kv_heads,
1001 | | -                                     self.kv_lora_rank // 16, block_size, 16)
1002 | | -                k_pe = k_pe.view(-1, self.num_kv_heads,
1003 | | -                                 self.qk_rope_head_dim // 16, block_size, 16)
1004 | | -                input_layout = "BSND"
1005 | | -            else:
1006 | | -                k_nope = k_nope.view(-1, self.num_kv_heads, block_size,
1007 | | -                                     self.kv_lora_rank)
1008 | | -                k_pe = k_pe.view(-1, self.num_kv_heads, block_size,
1009 | | -                                 self.qk_rope_head_dim)
1010 | | -                input_layout = "BNSD"
1011 | | -
1012 | | -            if attn_metadata.attn_state == AscendAttentionState.SpecDecoding:
1013 | | -                assert num_tokens % self.spec_token_num == 0
1014 | | -                input_layout = "TND"
1015 | | -                # [bs * q_seq_len, num_heads_per_rank, dim]
1016 | | -                q_nope = q_nope.view(num_tokens, self.num_heads, -1)
1017 | | -                q_pe = q_pe.view(num_tokens, self.num_heads, -1)
1018 | | -                sparse_mode = 3
1019 | | -                spec_attn_mask = attn_metadata.decode.attn_mask  # type:ignore
1020 | | -                actual_seq_lengths = decode_meta.actual_seq_lengths_q
1021 | | -            else:
1022 | | -                if self.enable_kv_nz:
1023 | | -                    q_nope = q_nope.view(num_tokens, 1, self.num_heads, -1)
1024 | | -                    q_pe = q_pe.view(num_tokens, 1, self.num_heads, -1)
1025 | | -                else:
1026 | | -                    q_nope = q_nope.view(num_tokens, self.num_heads, 1, -1)
1027 | | -                    q_pe = q_pe.view(num_tokens, self.num_heads, 1, -1)
1028 | | -                sparse_mode = 0
1029 | | -                spec_attn_mask = None
1030 | | -
1031 | | -            attn_output, _ = torch_npu.npu_fused_infer_attention_score(
1032 | | -                q_nope,
1033 | | -                k_nope,
1034 | | -                k_nope,
1035 | | -                query_rope=q_pe,
1036 | | -                key_rope=k_pe,
1037 | | -                num_heads=self.num_heads,
1038 | | -                num_key_value_heads=self.num_kv_heads,
1039 | | -                input_layout=input_layout,
1040 | | -                atten_mask=spec_attn_mask,
1041 | | -                sparse_mode=sparse_mode,
1042 | | -                scale=self.scale,
1043 | | -                antiquant_mode=0,
1044 | | -                antiquant_scale=None,
1045 | | -                block_table=decode_meta.block_table,
1046 | | -                block_size=block_size,
1047 | | -                actual_seq_lengths_kv=decode_meta.seq_lens_list,
1048 | | -                actual_seq_lengths=actual_seq_lengths)
| 931 | +        # shape of knope/k_pe for npu graph mode should be:
| 932 | +        # [num_blocks, num_kv_heads, block_size, self.kv_lora_rank/self.qk_rope_head_dim]
| 933 | +        actual_seq_lengths = None
| 934 | +        if self.enable_kv_nz:
| 935 | +            k_nope = k_nope.view(-1, self.num_kv_heads,
| 936 | +                                 self.kv_lora_rank // 16, block_size, 16)
| 937 | +            k_pe = k_pe.view(-1, self.num_kv_heads,
| 938 | +                             self.qk_rope_head_dim // 16, block_size, 16)
| 939 | +            input_layout = "BSND"
1049 | 940 |         else:
1050 | | -            # The MLA_PA path will be used as default path in the future, `_npu_paged_attention_mla` will
1051 | | -            # be removed after the torch_npu contains `torch_npu.atb.npu_multi_head_latent_attention` become
1052 | | -            # public available
1053 | | -            assert len(kv_c_and_k_pe_cache) > 1
1054 | | -            if envs_ascend.VLLM_ASCEND_MLA_PA:
1055 | | -                attn_output = torch_npu.atb.npu_multi_head_latent_attention(
1056 | | -                    q_nope, q_pe, kv_c_and_k_pe_cache[0],
1057 | | -                    kv_c_and_k_pe_cache[1], attn_metadata.decode.block_table,
1058 | | -                    attn_metadata.decode.seq_lens, self.num_heads, self.scale,
1059 | | -                    self.num_kv_heads)
| 941 | +            k_nope = k_nope.view(-1, self.num_kv_heads, block_size,
| 942 | +                                 self.kv_lora_rank)
| 943 | +            k_pe = k_pe.view(-1, self.num_kv_heads, block_size,
| 944 | +                             self.qk_rope_head_dim)
| 945 | +            input_layout = "BNSD"
| 946 | +
| 947 | +        if attn_metadata.attn_state == AscendAttentionState.SpecDecoding:
| 948 | +            assert num_tokens % self.spec_token_num == 0
| 949 | +            input_layout = "TND"
| 950 | +            # [bs * q_seq_len, num_heads_per_rank, dim]
| 951 | +            q_nope = q_nope.view(num_tokens, self.num_heads, -1)
| 952 | +            q_pe = q_pe.view(num_tokens, self.num_heads, -1)
| 953 | +            sparse_mode = 3
| 954 | +            spec_attn_mask = attn_metadata.decode.attn_mask  # type:ignore
| 955 | +            actual_seq_lengths = decode_meta.actual_seq_lengths_q
| 956 | +        else:
| 957 | +            if self.enable_kv_nz:
| 958 | +                q_nope = q_nope.view(num_tokens, 1, self.num_heads, -1)
| 959 | +                q_pe = q_pe.view(num_tokens, 1, self.num_heads, -1)
1060 | 960 |             else:
1061 | | -                q = torch.cat([q_nope, q_pe], dim=-1)
1062 | | -                attn_output = torch.empty(
1063 | | -                    [num_tokens, self.num_heads, self.kv_lora_rank],
1064 | | -                    dtype=q.dtype,
1065 | | -                    device=q.device)
1066 | | -                k_cache = torch.cat(
1067 | | -                    [kv_c_and_k_pe_cache[0], kv_c_and_k_pe_cache[1]], dim=-1)
1068 | | -                torch_npu._npu_paged_attention_mla(
1069 | | -                    query=q,
1070 | | -                    key_cache=k_cache,
1071 | | -                    num_kv_heads=self.num_kv_heads,
1072 | | -                    num_heads=self.num_heads,
1073 | | -                    scale_value=self.scale,
1074 | | -                    block_table=attn_metadata.decode.
1075 | | -                    block_table,  # type:ignore
1076 | | -                    context_lens=attn_metadata.decode.seq_lens,  # type:ignore
1077 | | -                    mla_vheadsize=self.kv_lora_rank,
1078 | | -                    out=attn_output)
| 961 | +                q_nope = q_nope.view(num_tokens, self.num_heads, 1, -1)
| 962 | +                q_pe = q_pe.view(num_tokens, self.num_heads, 1, -1)
| 963 | +            sparse_mode = 0
| 964 | +            spec_attn_mask = None
| 965 | +
| 966 | +        attn_output, _ = torch_npu.npu_fused_infer_attention_score(
| 967 | +            q_nope,
| 968 | +            k_nope,
| 969 | +            k_nope,
| 970 | +            query_rope=q_pe,
| 971 | +            key_rope=k_pe,
| 972 | +            num_heads=self.num_heads,
| 973 | +            num_key_value_heads=self.num_kv_heads,
| 974 | +            input_layout=input_layout,
| 975 | +            atten_mask=spec_attn_mask,
| 976 | +            sparse_mode=sparse_mode,
| 977 | +            scale=self.scale,
| 978 | +            antiquant_mode=0,
| 979 | +            antiquant_scale=None,
| 980 | +            block_table=decode_meta.block_table,
| 981 | +            block_size=block_size,
| 982 | +            actual_seq_lengths_kv=decode_meta.seq_lens_list,
| 983 | +            actual_seq_lengths=actual_seq_lengths)
| 984 | +
1079 | 985 |         current_ms_metadata = get_multistream_comm_context()
1080 | 986 |         if current_ms_metadata is None:
1081 | | -            return self._v_up_proj(attn_output,
1082 | | -                                   enable_multistream_mla)
| 987 | +            return self._v_up_proj(attn_output)
1083 | 988 |         else:
1084 | 989 |             current_ms_metadata.before_comm_event.record()
1085 | 990 |             with torch.npu.stream(current_ms_metadata.comm_stream):
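In the rewritten decode path, `npu_fused_infer_attention_score` is given `k_nope` as both key and value, so it attends directly over the latent (kv_lora_rank-wide) cache and `attn_output` stays in that latent space; `self._v_up_proj(attn_output)` then applies the per-head value up-projection afterwards, which is the usual MLA weight-absorption arrangement. A rough sketch of that final step under illustrative names and shapes (not the module's actual code):

```python
import torch

# Illustrative sizes only
num_tokens, num_heads = 4, 8
kv_lora_rank, v_head_dim = 512, 128

# Per-head value up-projection weights (the W_UV slice of kv_b_proj in MLA)
w_uv = torch.randn(num_heads, kv_lora_rank, v_head_dim)

# Attention output computed over the compressed (latent) KV cache
attn_latent = torch.randn(num_tokens, num_heads, kv_lora_rank)

# Up-project after attention instead of materializing full per-position values:
# the cached entries stay kv_lora_rank wide regardless of num_heads * v_head_dim.
out = torch.einsum("tnl,nlv->tnv", attn_latent, w_uv)
out = out.reshape(num_tokens, num_heads * v_head_dim)
```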