@@ -812,101 +812,43 @@ def _forward_prefill(
             -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim).split(
                 [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
         k_pe = k_pe.expand((*k_nope.shape[:-1], -1))
-        # Here is only 2 possibility of input, ChunkedPrefill or PrefillNoCache
-        ascend_config = get_ascend_config()
+        attn_lse = torch.empty(self.num_heads,
+                               num_tokens,
+                               dtype=torch.float32,
+                               device=query.device)
+        q_pe = query[..., self.qk_nope_head_dim:]
+        q_nope = query[..., :self.qk_nope_head_dim]
+        mask = torch.triu(
+            torch.ones(512, 512, device=query.device, dtype=query.dtype),
+            1)  # 512: mask only support 512
+        if attn_metadata.num_prefills > 1:
+            mask = mask.unsqueeze(0).repeat(attn_metadata.num_prefills, 1,
+                                            1)
+        torch_npu.atb.npu_ring_mla(
+            q_nope=q_nope,
+            q_rope=q_pe,
+            k_nope=k_nope,
+            k_rope=k_pe,
+            value=value,
+            mask=mask,
+            seqlen=torch.tensor(attn_metadata.prefill.query_lens,
+                                dtype=torch.int32),
+            head_num=self.num_heads,
+            kv_head_num=self.num_heads,
+            pre_out=None,
+            prev_lse=None,
+            qk_scale=self.scale,
+            kernel_type="kernel_type_high_precision",
+            mask_type="mask_type_triu",
+            input_layout="type_bsnd",
+            calc_type="calc_type_first_ring",
+            output=attn_output,
+            softmax_lse=attn_lse)
+        attn_output, attn_lse = self._compute_prefill_context( \
+            query, kv_c_and_k_pe_cache, self.qk_rope_head_dim, attn_metadata, attn_output, attn_lse)

-        if attn_metadata.attn_state in [
-                AscendAttentionState.ChunkedPrefill,
-                AscendAttentionState.SpecDecoding,
-                AscendAttentionState.PrefillCacheHit
-        ] and not ascend_config.chunked_prefill_for_mla:
-            attn_output_torch = torch.empty(num_tokens,
-                                            self.num_heads * self.v_head_dim,
-                                            dtype=query.dtype,
-                                            device=query.device)
-            # current requests is chunked in prefill, disable flash attention with chunked prefill
-            vanilla_chunked_prefill_mla(
-                output=attn_output_torch,
-                query=query,
-                kv_cache=kv_c_and_k_pe_cache,
-                block_tables=attn_metadata.prefill.block_table,
-                query_lens=attn_metadata.prefill.query_lens,
-                context_lens=attn_metadata.prefill.context_lens,
-                kv_b_proj=self.kv_b_proj,
-                max_query_len=attn_metadata.prefill.max_query_len,
-                max_context_len=attn_metadata.prefill.max_seq_lens,
-                nope_dim=self.qk_nope_head_dim,
-                rope_dim=self.qk_rope_head_dim,
-                v_head_dim=self.v_head_dim,
-                scale=self.scale,
-                alibi_slopes=None,
-                causal=True)
-        elif attn_metadata.attn_state in [
-                AscendAttentionState.ChunkedPrefill,
-                AscendAttentionState.SpecDecoding,
-                AscendAttentionState.PrefillCacheHit
-        ]:
-            attn_lse = torch.empty(self.num_heads,
-                                   num_tokens,
-                                   dtype=torch.float32,
-                                   device=query.device)
-            q_pe = query[..., self.qk_nope_head_dim:]
-            q_nope = query[..., :self.qk_nope_head_dim]
-            mask = torch.triu(
-                torch.ones(512, 512, device=query.device, dtype=query.dtype),
-                1)  # 512: mask only support 512
-            if attn_metadata.num_prefills > 1:
-                mask = mask.unsqueeze(0).repeat(attn_metadata.num_prefills, 1,
-                                                1)
-            torch_npu.atb.npu_ring_mla(
-                q_nope=q_nope,
-                q_rope=q_pe,
-                k_nope=k_nope,
-                k_rope=k_pe,
-                value=value,
-                mask=mask,
-                seqlen=torch.tensor(attn_metadata.prefill.query_lens,
-                                    dtype=torch.int32),
-                head_num=self.num_heads,
-                kv_head_num=self.num_heads,
-                pre_out=None,
-                prev_lse=None,
-                qk_scale=self.scale,
-                kernel_type="kernel_type_high_precision",
-                mask_type="mask_type_triu",
-                input_layout="type_bsnd",
-                calc_type="calc_type_first_ring",
-                output=attn_output,
-                softmax_lse=attn_lse)
-            attn_output, attn_lse = self._compute_prefill_context( \
-                query, kv_c_and_k_pe_cache, self.qk_rope_head_dim, attn_metadata, attn_output, attn_lse)
-
-        elif attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
-            key = torch.cat((k_nope, k_pe), dim=-1)
-            torch_npu._npu_flash_attention(
-                query=query,
-                key=key,
-                value=value,
-                mask=attn_metadata.attn_mask,
-                seq_len=attn_metadata.prefill.context_lens,
-                scale_value=self.scale,
-                num_heads=self.num_heads,
-                num_kv_heads=self.num_heads,
-                out=attn_output)
-            attn_output = attn_output.view(-1, self.num_heads, self.v_head_dim)
-        else:
-            raise RuntimeError(
-                "Unexpected path reached, AscendMLAImpl should only have PrefillNoCache, PrefillCacheHit, ChunkedPrefill and SpecDecoding scenario in forward prefill, please file a bug to vllm-ascend !"
-            )
         attn_output = attn_output.reshape(
             [num_tokens, self.num_heads * self.v_head_dim])
-        if attn_metadata.attn_state in [
-                AscendAttentionState.ChunkedPrefill,
-                AscendAttentionState.SpecDecoding,
-                AscendAttentionState.PrefillCacheHit
-        ] and not ascend_config.chunked_prefill_for_mla:
-            attn_output = attn_output_torch
-
         return attn_output

     def exec_kv(
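In the new prefill path, `npu_ring_mla` is called with `calc_type_first_ring`, so it produces both the current chunk's attention output and its softmax log-sum-exp (`softmax_lse`), and `_compute_prefill_context` then folds any cached context into that pair. The snippet below is only an illustrative plain-PyTorch sketch of the LSE-merge identity this relies on; the helper names are made up and are not vllm-ascend or torch_npu APIs.

```python
import torch


def merge_partial_attention(out_a, lse_a, out_b, lse_b):
    """Merge two attention partials computed over disjoint KV chunks.

    out_*: [num_heads, num_tokens, v_dim], already softmax-normalized per chunk.
    lse_*: [num_heads, num_tokens], log-sum-exp of that chunk's attention logits.
    """
    lse = torch.logaddexp(lse_a, lse_b)          # combined normalizer
    w_a = torch.exp(lse_a - lse).unsqueeze(-1)   # per-chunk rescaling weights
    w_b = torch.exp(lse_b - lse).unsqueeze(-1)
    return w_a * out_a + w_b * out_b, lse


def chunk_attn(q, k, v):
    # Plain attention over one chunk, returning output and per-query LSE.
    logits = q @ k.transpose(-1, -2)
    return torch.softmax(logits, dim=-1) @ v, torch.logsumexp(logits, dim=-1)


torch.manual_seed(0)
h, t, d = 2, 3, 4                                      # heads, query tokens, head dim
q = torch.randn(h, t, d)
k1, v1 = torch.randn(h, 5, d), torch.randn(h, 5, d)    # current chunk
k2, v2 = torch.randn(h, 7, d), torch.randn(h, 7, d)    # cached context

o1, l1 = chunk_attn(q, k1, v1)
o2, l2 = chunk_attn(q, k2, v2)
merged, _ = merge_partial_attention(o1, l1, o2, l2)
full, _ = chunk_attn(q, torch.cat([k1, k2], 1), torch.cat([v1, v2], 1))
assert torch.allclose(merged, full, atol=1e-5)         # matches one-shot attention
```

Because the merge is exact, the kernel can process the current chunk first and let the chunked-context pass rescale and accumulate into `attn_output`/`attn_lse` afterwards.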
@@ -991,91 +933,61 @@ def _forward_decode(
         decode_meta = attn_metadata.decode
         assert decode_meta is not None
         num_tokens = q_nope.size(0)
-        if self.running_in_graph or self.running_chunkprefilll_with_torchair:
-            # shape of knope/k_pe for npu graph mode should be:
-            # [num_blocks, num_kv_heads, block_size, self.kv_lora_rank/self.qk_rope_head_dim]
-            block_size = kv_c_and_k_pe_cache[0].shape[1]
-            actual_seq_lengths = None
-            if self.enable_kv_nz:
-                k_nope = k_nope.view(-1, self.num_kv_heads,
-                                     self.kv_lora_rank // 16, block_size, 16)
-                k_pe = k_pe.view(-1, self.num_kv_heads,
-                                 self.qk_rope_head_dim // 16, block_size, 16)
-                input_layout = "BSND"
-            else:
-                k_nope = k_nope.view(-1, self.num_kv_heads, block_size,
-                                     self.kv_lora_rank)
-                k_pe = k_pe.view(-1, self.num_kv_heads, block_size,
-                                 self.qk_rope_head_dim)
-                input_layout = "BNSD"
-
-            if attn_metadata.attn_state == AscendAttentionState.SpecDecoding:
-                assert num_tokens % self.spec_token_num == 0
-                input_layout = "TND"
-                # [bs * q_seq_len, num_heads_per_rank, dim]
-                q_nope = q_nope.view(num_tokens, self.num_heads, -1)
-                q_pe = q_pe.view(num_tokens, self.num_heads, -1)
-                sparse_mode = 3
-                spec_attn_mask = attn_metadata.decode.attn_mask  # type:ignore
-                actual_seq_lengths = decode_meta.actual_seq_lengths_q
-            else:
-                if self.enable_kv_nz:
-                    q_nope = q_nope.view(num_tokens, 1, self.num_heads, -1)
-                    q_pe = q_pe.view(num_tokens, 1, self.num_heads, -1)
-                else:
-                    q_nope = q_nope.view(num_tokens, self.num_heads, 1, -1)
-                    q_pe = q_pe.view(num_tokens, self.num_heads, 1, -1)
-                sparse_mode = 0
-                spec_attn_mask = None
-
-            attn_output, _ = torch_npu.npu_fused_infer_attention_score(
-                q_nope,
-                k_nope,
-                k_nope,
-                query_rope=q_pe,
-                key_rope=k_pe,
-                num_heads=self.num_heads,
-                num_key_value_heads=self.num_kv_heads,
-                input_layout=input_layout,
-                atten_mask=spec_attn_mask,
-                sparse_mode=sparse_mode,
-                scale=self.scale,
-                antiquant_mode=0,
-                antiquant_scale=None,
-                block_table=decode_meta.block_table,
-                block_size=block_size,
-                actual_seq_lengths_kv=decode_meta.seq_lens_list,
-                actual_seq_lengths=actual_seq_lengths)
+        # shape of knope/k_pe for npu graph mode should be:
+        # [num_blocks, num_kv_heads, block_size, self.kv_lora_rank/self.qk_rope_head_dim]
+        block_size = kv_c_and_k_pe_cache[0].shape[1]
+        actual_seq_lengths = None
+        if self.enable_kv_nz:
+            k_nope = k_nope.view(-1, self.num_kv_heads,
+                                 self.kv_lora_rank // 16, block_size, 16)
+            k_pe = k_pe.view(-1, self.num_kv_heads,
+                             self.qk_rope_head_dim // 16, block_size, 16)
+            input_layout = "BSND"
         else:
-            # The MLA_PA path will be used as default path in the future, `_npu_paged_attention_mla` will
-            # be removed after the torch_npu contains `torch_npu.atb.npu_multi_head_latent_attention` become
-            # public available
-            assert len(kv_c_and_k_pe_cache) > 1
-            if envs.VLLM_ASCEND_MLA_PA:
-                attn_output = torch_npu.atb.npu_multi_head_latent_attention(
-                    q_nope, q_pe, kv_c_and_k_pe_cache[0],
-                    kv_c_and_k_pe_cache[1], attn_metadata.decode.block_table,
-                    attn_metadata.decode.seq_lens, self.num_heads, self.scale,
-                    self.num_kv_heads)
+            k_nope = k_nope.view(-1, self.num_kv_heads, block_size,
+                                 self.kv_lora_rank)
+            k_pe = k_pe.view(-1, self.num_kv_heads, block_size,
+                             self.qk_rope_head_dim)
+            input_layout = "BNSD"
+
+        if attn_metadata.attn_state == AscendAttentionState.SpecDecoding:
+            assert num_tokens % self.spec_token_num == 0
+            input_layout = "TND"
+            # [bs * q_seq_len, num_heads_per_rank, dim]
+            q_nope = q_nope.view(num_tokens, self.num_heads, -1)
+            q_pe = q_pe.view(num_tokens, self.num_heads, -1)
+            sparse_mode = 3
+            spec_attn_mask = attn_metadata.decode.attn_mask  # type:ignore
+            actual_seq_lengths = decode_meta.actual_seq_lengths_q
+        else:
+            if self.enable_kv_nz:
+                q_nope = q_nope.view(num_tokens, 1, self.num_heads, -1)
+                q_pe = q_pe.view(num_tokens, 1, self.num_heads, -1)
             else:
-                q = torch.cat([q_nope, q_pe], dim=-1)
-                attn_output = torch.empty(
-                    [num_tokens, self.num_heads, self.kv_lora_rank],
-                    dtype=q.dtype,
-                    device=q.device)
-                k_cache = torch.cat(
-                    [kv_c_and_k_pe_cache[0], kv_c_and_k_pe_cache[1]], dim=-1)
-                torch_npu._npu_paged_attention_mla(
-                    query=q,
-                    key_cache=k_cache,
-                    num_kv_heads=self.num_kv_heads,
-                    num_heads=self.num_heads,
-                    scale_value=self.scale,
-                    block_table=attn_metadata.decode.
-                    block_table,  # type:ignore
-                    context_lens=attn_metadata.decode.seq_lens,  # type:ignore
-                    mla_vheadsize=self.kv_lora_rank,
-                    out=attn_output)
+                q_nope = q_nope.view(num_tokens, self.num_heads, 1, -1)
+                q_pe = q_pe.view(num_tokens, self.num_heads, 1, -1)
+            sparse_mode = 0
+            spec_attn_mask = None
+
+        attn_output, _ = torch_npu.npu_fused_infer_attention_score(
+            q_nope,
+            k_nope,
+            k_nope,
+            query_rope=q_pe,
+            key_rope=k_pe,
+            num_heads=self.num_heads,
+            num_key_value_heads=self.num_kv_heads,
+            input_layout=input_layout,
+            atten_mask=spec_attn_mask,
+            sparse_mode=sparse_mode,
+            scale=self.scale,
+            antiquant_mode=0,
+            antiquant_scale=None,
+            block_table=decode_meta.block_table,
+            block_size=block_size,
+            actual_seq_lengths_kv=decode_meta.seq_lens_list,
+            actual_seq_lengths=actual_seq_lengths)
+
         current_ms_metadata = get_multistream_comm_context()
         if current_ms_metadata is None:
             return self._v_up_proj(attn_output,
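On the decode side, the `running_in_graph` / `VLLM_ASCEND_MLA_PA` / `_npu_paged_attention_mla` branches are removed and `torch_npu.npu_fused_infer_attention_score` is always used, with the compressed latent `k_nope` passed as both key and value and the rope part supplied separately via `query_rope`/`key_rope`; the `[num_tokens, num_heads, kv_lora_rank]` result is then mapped back to per-head value dims by `self._v_up_proj`. The sketch below is a plain-PyTorch reference of that absorbed-MLA decode math for a single request; the helper name and shapes are illustrative assumptions, not the kernel's API.

```python
import torch


def mla_absorbed_decode_reference(q_nope, q_pe, k_nope, k_pe, scale):
    """Single-request reference for absorbed MLA decode (illustrative only).

    q_nope: [num_heads, kv_lora_rank]   q_pe: [num_heads, rope_dim]
    k_nope: [seq_len, kv_lora_rank]     k_pe: [seq_len, rope_dim] (shared across heads)
    """
    # Latent (nope) and rope parts add into the same attention logits.
    logits = (q_nope @ k_nope.T + q_pe @ k_pe.T) * scale    # [num_heads, seq_len]
    weights = torch.softmax(logits, dim=-1)
    # The compressed latent doubles as the value; the up-projection happens later.
    return weights @ k_nope                                 # [num_heads, kv_lora_rank]


num_heads, kv_lora_rank, rope_dim, seq_len = 8, 512, 64, 10
out = mla_absorbed_decode_reference(
    torch.randn(num_heads, kv_lora_rank),
    torch.randn(num_heads, rope_dim),
    torch.randn(seq_len, kv_lora_rank),
    torch.randn(seq_len, rope_dim),
    scale=0.1,
)
assert out.shape == (num_heads, kv_lora_rank)   # then projected to v_head_dim per head
```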