@@ -1,5 +1,6 @@
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, NamedTuple, Optional, Tuple, Type, TypeVar
+from typing import (TYPE_CHECKING, ClassVar, NamedTuple, Optional, Tuple, Type,
+                    TypeVar)
 
 import torch
 import torch_npu
@@ -12,6 +13,7 @@
 from vllm.model_executor.layers.linear import (LinearBase,
                                                UnquantizedLinearMethod)
 from vllm.utils import cdiv, round_down
+from vllm.v1.attention.backends.utils import AttentionCGSupport
 
 from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.attention.attention_v1 import AscendAttentionState
@@ -77,9 +79,9 @@ class ChunkedContextMetadata:
     block_table: torch.Tensor
     max_query_len: int
     max_seq_lens: int
+    sin: torch.Tensor
+    cos: torch.Tensor
     chunked_context: Optional[ChunkedContextMetadata] = None
-    sin: torch.Tensor = None
-    cos: torch.Tensor = None
 
 
 @dataclass
@@ -91,10 +93,10 @@ class AscendSFADecodeMetadata:
     seq_lens: torch.Tensor
     max_seq_lens: int
     seq_lens_list: list[int]
-    actual_seq_lengths_q: Optional[torch.Tensor] = None
+    actual_seq_lengths_q: torch.Tensor
+    sin: torch.Tensor
+    cos: torch.Tensor
     attn_mask: Optional[torch.Tensor] = None
-    sin: torch.Tensor = None
-    cos: torch.Tensor = None
 
 
 @dataclass
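These two hunks promote `sin`, `cos`, and `actual_seq_lengths_q` from optional fields defaulting to `None` to required tensors. Since both classes are dataclasses, the required fields also have to sit ahead of every field that still carries a default, which is why they move above `chunked_context` and `attn_mask`. A minimal sketch of that ordering rule, using illustrative names rather than the real vllm-ascend metadata classes:

```python
from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class DecodeMetaSketch:
    # Required fields (no default) must be declared first ...
    seq_lens: torch.Tensor
    sin: torch.Tensor
    cos: torch.Tensor
    # ... ahead of any defaulted field; placing a required field after this
    # one would raise "TypeError: non-default argument follows default
    # argument" when the class is created.
    attn_mask: Optional[torch.Tensor] = None
```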
@@ -163,6 +165,9 @@ def split_metadata_for_multistream(
 
 
 class AscendSFAMetadataBuilder:
+    # Does this backend/builder support ACL Graphs for attention (default: no).
+    aclgraph_support: ClassVar[AttentionCGSupport] = \
+        AttentionCGSupport.NEVER
     """
     NOTE: Please read the comment at the top of the file before trying to
     understand this class
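The builder now declares, as class-level state, that it does not support ACL graph capture for attention. `ClassVar` marks `aclgraph_support` as an attribute of the class itself rather than a per-instance field, so callers can query the capability without constructing a builder. A small sketch of the pattern; the enum here is a stand-in, since only the `NEVER` member of `AttentionCGSupport` appears in this diff:

```python
import enum
from typing import ClassVar


class CGSupportSketch(enum.Enum):
    # Stand-in for AttentionCGSupport; member names beyond NEVER are assumed.
    NEVER = enum.auto()
    ALWAYS = enum.auto()


class MetadataBuilderSketch:
    # Class-level capability flag: ClassVar tells type checkers (and any
    # dataclass machinery) that this is shared class state, not an
    # instance field.
    aclgraph_support: ClassVar[CGSupportSketch] = CGSupportSketch.NEVER


# The capability can be read straight off the class, no instance needed.
assert MetadataBuilderSketch.aclgraph_support is CGSupportSketch.NEVER
```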
@@ -292,11 +297,10 @@ def build(
         device = self.device
 
         block_table = (common_attn_metadata.block_table_tensor[:num_reqs])
-        slot_mapping = common_attn_metadata.slot_mapping_cpu[:
-                                                             num_actual_tokens].to(
-                                                                 device,
-                                                                 non_blocking=
-                                                                 True)
+        slot_mapping = common_attn_metadata.slot_mapping[:
+                                                         num_actual_tokens].to(
+                                                             device,
+                                                             non_blocking=True)
         input_positions = common_attn_metadata.positions[:
                                                           num_actual_tokens].long(
                                                           )
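`build()` now takes the slot mapping from `common_attn_metadata.slot_mapping` rather than the `_cpu` copy, slices it to the tokens actually scheduled this step, and moves it to the device with `non_blocking=True`. A hedged sketch of that slice-then-copy pattern with placeholder sizes and device selection (not the real builder code):

```python
import torch

# Illustrative stand-ins; the real tensors come from CommonAttentionMetadata.
num_actual_tokens = 8
slot_mapping_full = torch.arange(16, dtype=torch.int64)
device = "npu:0" if hasattr(torch, "npu") else "cpu"  # assumes torch_npu is imported on NPU hosts

# Slice first so only the scheduled tokens are transferred; non_blocking=True
# lets the copy overlap with host work (it is truly asynchronous only when the
# source lives in pinned host memory).
slot_mapping = slot_mapping_full[:num_actual_tokens].to(device, non_blocking=True)
```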
@@ -686,8 +690,7 @@ def _sfa_preprocess(self, hidden_states, kv_cache, attn_metadata,
             topk_indices = self.indexer_select(hidden_states_decode,
                                                decode_q_c,
                                                attn_metadata=attn_metadata,
-                                               kv_cache=kv_cache,
-                                               is_prefill=False)
+                                               kv_cache=kv_cache)
 
             query_states = (decode_q_nope, decode_q_pe)
             key_states = (decode_k_nope, decode_k_rope)
@@ -775,8 +778,7 @@ def _sfa_preprocess(self, hidden_states, kv_cache, attn_metadata,
             topk_indices = self.indexer_select(x=hidden_states_prefill,
                                                qr=prefill_qr,
                                                kv_cache=kv_cache,
-                                               attn_metadata=attn_metadata,
-                                               is_prefill=True)
+                                               attn_metadata=attn_metadata)
             query_states = (prefill_q_nope, prefill_q_pe)
             key_states = (prefill_k_nope, prefill_k_pe)
             prefill_preprocess_res = PrefillSFAPreprocessResult(
@@ -826,45 +828,27 @@ def forward(
                 query_states=decode_preprocess_res.query_states,
                 key_states=decode_preprocess_res.key_states,
                 attn_metadata=attn_metadata,
-                attention_mask=None,
-                kv_cache=kv_cache,
-                topk_indices=decode_preprocess_res.topk_indices,
-                is_prefill=False,
-                bsz=decode_preprocess_res.bsz)
+                topk_indices=decode_preprocess_res.topk_indices)
             o_proj_input[:num_decode_tokens] = decode_attn_output
 
         if prefill_preprocess_res is not None:
             prefill_attn_output = self.apply_attention_fusion(
                 query_states=prefill_preprocess_res.query_states,
                 key_states=prefill_preprocess_res.key_states,
                 attn_metadata=attn_metadata,
-                attention_mask=None,
-                kv_cache=kv_cache,
-                topk_indices=prefill_preprocess_res.topk_indices,
-                is_prefill=True,
-                bsz=None)
+                topk_indices=prefill_preprocess_res.topk_indices)
             o_proj_input[num_decode_tokens:] = prefill_attn_output
 
         output[...] = self.mla_epilog(o_proj_input, absorb=True)
         return output
 
-    def apply_attention_fusion(
-            self,
-            query_states,
-            key_states,
-            topk_indices,
-            attn_metadata: M,
-            attention_mask: Optional[torch.Tensor] = None,
-            # actual_seq_qlen: torch.Tensor = None,
-            # actual_seq_lengths_kv: torch.Tensor = None,
-            kv_cache: Tuple[torch.Tensor] = None,
-            is_prefill: bool = True,
-            bsz: int = None):
+    def apply_attention_fusion(self, query_states, key_states, topk_indices,
+                               attn_metadata: M):
         # repeat k/v heads if n_kv_heads < n_heads
         q_nope, q_pe = query_states
         k_nope, k_rope = key_states
 
-        if is_prefill:
+        if attn_metadata.prefill is not None:
 
             prefill_metadata = attn_metadata.prefill
 
@@ -885,7 +869,7 @@ def apply_attention_fusion(
                 sparse_mode=3,
             )
 
-        else:
+        elif attn_metadata.decode is not None:
             decode_metadata = attn_metadata.decode
 
             slc_fa_fusion = torch.ops.custom.npu_selected_flash_attention(
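Taken together, the `forward`/`apply_attention_fusion` hunks drop the `is_prefill`, `attention_mask`, `kv_cache`, and `bsz` arguments and instead dispatch on which phase-specific block of the metadata is populated: `attn_metadata.prefill` for prefill requests, `attn_metadata.decode` for decode. A minimal sketch of that dispatch pattern, with illustrative classes rather than the real `AscendSFAMetadata` types:

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class PrefillPhaseSketch:
    max_query_len: int


@dataclass
class DecodePhaseSketch:
    max_seq_lens: int


@dataclass
class AttnMetaSketch:
    # Exactly one of the phase-specific blocks is expected to be populated.
    prefill: Optional[PrefillPhaseSketch] = None
    decode: Optional[DecodePhaseSketch] = None


def fused_attention_sketch(meta: AttnMetaSketch) -> str:
    # Branch on the populated sub-metadata instead of a separate is_prefill
    # flag, so the flag can never disagree with the metadata it describes.
    if meta.prefill is not None:
        return f"prefill path, max_query_len={meta.prefill.max_query_len}"
    elif meta.decode is not None:
        return f"decode path, max_seq_lens={meta.decode.max_seq_lens}"
    raise ValueError("metadata has neither prefill nor decode populated")


print(fused_attention_sketch(AttnMetaSketch(decode=DecodePhaseSketch(max_seq_lens=128))))
```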
@@ -937,14 +921,19 @@ def indexer_select(
         qr: torch.Tensor,
         kv_cache: Tuple[torch.Tensor],
         attn_metadata: M,
-        is_prefill: bool = True,
     ):
-        if is_prefill:
+        if attn_metadata.prefill is not None:
             cos = attn_metadata.prefill.cos
             sin = attn_metadata.prefill.sin
-        else:
+            actual_seq_lengths_query = attn_metadata.prefill.query_lens
+            actual_seq_lengths_key = attn_metadata.prefill.seq_lens
+            block_table = attn_metadata.prefill.block_table
+        elif attn_metadata.decode is not None:
             cos = attn_metadata.decode.cos
             sin = attn_metadata.decode.sin
+            actual_seq_lengths_query = attn_metadata.decode.actual_seq_lengths_q
+            actual_seq_lengths_key = attn_metadata.decode.seq_lens
+            block_table = attn_metadata.decode.block_table
 
         cos_q, sin_q = cos, sin
         cos = cos.view(-1, 1, 1, self.qk_rope_head_dim)
@@ -982,17 +971,6 @@ def indexer_select(
                                 k.shape[-1]))  # b, s, n, d
 
         weights = self.weights_proj(x)
-        actual_seq_lengths_query = None
-        actual_seq_lengths_key = None
-        block_table = None
-        if is_prefill:
-            actual_seq_lengths_query = attn_metadata.prefill.query_lens
-            actual_seq_lengths_key = attn_metadata.prefill.seq_lens
-            block_table = attn_metadata.prefill.block_table
-        else:
-            actual_seq_lengths_query = attn_metadata.decode.actual_seq_lengths_q
-            actual_seq_lengths_key = attn_metadata.decode.seq_lens
-            block_table = attn_metadata.decode.block_table
 
         topk_indices = torch.ops.custom.npu_lightning_indexer(
             query=q,