@@ -65,7 +65,6 @@ class AttentionMaskBuilder:
     def __init__(self, attn_mask: torch.Tensor):
         self._seq_len_cached = attn_mask.shape[0]
         self.attn_mask_cache = attn_mask
-        self.splitfuse_mask_value = -10000

     @classmethod
     def initialize_from_len(cls,
@@ -74,18 +73,25 @@ def initialize_from_len(cls,
                             mask_value: Optional[int] = None):
         return cls(generate_attn_mask(max_seq_len, dtype, mask_value))

-    def update_attn_cache(self, seqlen: int, dtype: torch.dtype,
-                          device: torch.device):
-        if seqlen > self._seq_len_cached or self.attn_mask_cache.dtype != dtype:
+    @staticmethod
+    def get_mask_scale_factor(dtype: torch.dtype = torch.float16):
+        mask_scale_factor = 1
+        if dtype == torch.bfloat16:
+            mask_scale_factor = -10000
+        return mask_scale_factor
+
+    def update_attn_cache(self, seqlen: int, dtype: torch.dtype):
+        if seqlen > self._seq_len_cached:
             self._seq_len_cached = seqlen
             self.attn_mask_cache = generate_attn_mask(seqlen, dtype)
-        if self.attn_mask_cache.device != device:
-            self.attn_mask_cache = self.attn_mask_cache.to(device)
+        if self.attn_mask_cache.dtype != dtype:
+            self.attn_mask_cache = self.attn_mask_cache.to(dtype)

     def get_attn_mask(self, max_seq_len: int, dtype: torch.dtype,
                       device: torch.device):
-        self.update_attn_cache(max_seq_len, dtype, device)
-        return self.attn_mask_cache[:max_seq_len, :max_seq_len].contiguous()
+        self.update_attn_cache(max_seq_len, dtype)
+        return self.attn_mask_cache[:max_seq_len, :max_seq_len].contiguous(
+        ).to(device)

     def get_decode_attn_mask(
         self,
@@ -94,53 +100,28 @@ def get_decode_attn_mask(
         dtype: torch.dtype,
         device: torch.device,
     ):
-        self.update_attn_cache(max_s, dtype, device)
+        self.update_attn_cache(max_s, dtype)
         return (self.attn_mask_cache.index_select(
-            0, input_lengths)[:, :max_s].view(-1, 1, max_s).contiguous())
+            0, input_lengths)[:, :max_s].view(-1, 1,
+                                              max_s).contiguous().to(device))

     def get_splitfuse_attn_mask(
         self,
         seq_lens,
-        query_lens,
         position,
         dtype,
         device,
     ) -> torch.Tensor:
         max_seq_len = max(seq_lens, default=0)
-        if max_seq_len <= self._seq_len_cached:
-            self.update_attn_cache(max_seq_len, dtype, device)
-            # FIXME: Currently the mask value of chunked-prefill situation and Prefill-Only situation
-            # is not the same. Fix this in the future when kernel is ready.
-            if self.attn_mask_cache.numel(
-            ) > 1 and self.attn_mask_cache[0][1] > 0:
-                attn_mask = self.get_attn_mask(  # type: ignore
-                    max_seq_len, dtype, device)
-                attn_mask *= -10000
-            else:
-                attn_mask = self.attn_mask_cache
-            return torch.index_select(attn_mask, dim=0,
-                                      index=position)[:, :max_seq_len]
-        total_q_len = sum(query_lens)
-        attn_mask = torch.zeros((total_q_len, max_seq_len),
-                                dtype=dtype,
-                                device="cpu")
-
-        current_row = 0
-        for i in range(len(query_lens)):
-            seq_len = seq_lens[i]
-            q_len = query_lens[i]
-            context_len = seq_len - q_len
-
-            assert context_len >= 0
-            attn_mask[current_row:current_row + q_len,
-                      context_len:] = self.splitfuse_mask_value
-            right_tensor = attn_mask[current_row:current_row + q_len,
-                                     context_len:seq_len]
-            right_tensor.masked_fill_(
-                right_tensor.tril() == self.splitfuse_mask_value, 0)
-            current_row += q_len
-
-        return attn_mask.to(device, non_blocking=True)
+        self.update_attn_cache(max_seq_len, dtype)
+        # FIXME: Currently the mask value of chunked-prefill situation and Prefill-Only situation
+        # is not the same. Fix this in the future when kernel is ready.
+        mask_scale_factor = AttentionMaskBuilder.get_mask_scale_factor(dtype)
+        attn_mask = torch.index_select(self.attn_mask_cache,
+                                       dim=0,
+                                       index=position)[:, :max_seq_len]
+        attn_mask *= mask_scale_factor
+        return attn_mask.contiguous().to(device, non_blocking=True)


 class AscendAttentionBackend(AttentionBackend):
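For reference, the new split-fuse path reduces to a single `index_select` gather plus a dtype-dependent scale. Below is a minimal, self-contained sketch of that flow, not code from this repo: it assumes the cached mask stores 1.0 at masked positions, so the bfloat16 scale factor of -10000 turns the gathered rows into an additive bias, while float16 is assumed to already carry its final mask value (scale factor 1). `splitfuse_mask_sketch` and `mask_cache` are illustrative stand-ins for `get_splitfuse_attn_mask` and `self.attn_mask_cache`, and `"cpu"` is used only so the example runs anywhere.

```python
import torch


def get_mask_scale_factor(dtype: torch.dtype = torch.float16) -> int:
    # Mirrors the new static method: bfloat16 masks are scaled by -10000,
    # float16 masks are assumed to already hold their final value.
    return -10000 if dtype == torch.bfloat16 else 1


def splitfuse_mask_sketch(seq_lens, position, dtype, device):
    max_seq_len = max(seq_lens, default=0)
    # Stand-in for self.attn_mask_cache: 1.0 above the diagonal = masked.
    mask_cache = torch.triu(
        torch.ones(max_seq_len, max_seq_len, dtype=dtype), diagonal=1)
    # Pick one mask row per scheduled token by its absolute position, then
    # scale the gathered rows into the additive bias for this dtype.
    attn_mask = torch.index_select(mask_cache, dim=0,
                                   index=position)[:, :max_seq_len]
    attn_mask = attn_mask * get_mask_scale_factor(dtype)
    return attn_mask.contiguous().to(device, non_blocking=True)


if __name__ == "__main__":
    # Two requests with total lengths 5 and 3; `position` holds the
    # absolute positions of the tokens scheduled in this chunked step.
    seq_lens = [5, 3]
    position = torch.tensor([3, 4, 1, 2])
    mask = splitfuse_mask_sketch(seq_lens, position, torch.bfloat16, "cpu")
    print(mask.shape)  # torch.Size([4, 5])
```

Keeping the cache off-device and moving only the gathered slice with `.to(device, non_blocking=True)` matches the diff's choice to drop the `device` argument from `update_attn_cache` and transfer the mask at the point of use.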