@@ -119,7 +119,6 @@ class AscendAttentionState(Enum):
 
 @dataclass
 class AscendMetadata:
-
     # **************************** Basic Properties ************************** #
     attn_mask: Optional[torch.Tensor] = None
     # Current state of this attention run.
@@ -155,37 +154,106 @@ class AscendMetadata:
     is_only_prefill: bool = False
 
 
+@dataclass
+class AscendAttentionMetadataBuildInfo:
+    block_tables: torch.Tensor = None
+    query_start_loc: torch.Tensor = None
+    query_lens: torch.Tensor = None
+    seq_lens: torch.Tensor = None
+    slot_mapping: torch.Tensor = None
+    attn_mask: torch.Tensor = None
+    attn_state: AscendAttentionState = None
+
+
 class AscendAttentionMetadataBuilder:
 
     def __init__(self, runner):
         self.runner = runner
 
-    def reorder_batch(self, input_batch: "InputBatch",
-                      scheduler_output: "SchedulerOutput") -> bool:
+    def reorder_batch(
+        self,
+        input_batch: "InputBatch",
+        scheduler_output: "SchedulerOutput",
+    ) -> bool:
         return False
 
-    def build(self,
-              num_reqs,
-              num_actual_tokens,
-              max_query_len,
-              enable_dbo_across_dp: bool = False,
-              is_only_prefill: bool = False):
+    def _assemble_build_info(
+        self,
+        num_reqs,
+        num_actual_tokens,
+        max_query_len,
+        block_tables,
+        query_start_loc,
+        query_lens,
+        seq_lens,
+        slot_mapping,
+        attn_mask,
+        attn_state: "AscendAttentionState",
+        *args,
+        **kwargs,
+    ) -> "AscendAttentionMetadataBuildInfo":
+        build_info = AscendAttentionMetadataBuildInfo(
+            block_tables=block_tables,
+            query_start_loc=query_start_loc,
+            query_lens=query_lens,
+            seq_lens=seq_lens,
+            slot_mapping=slot_mapping,
+            attn_mask=attn_mask,
+            attn_state=attn_state)
+        return build_info
+
+    def _prepare_build_info(
+        self,
+        num_reqs,
+        num_actual_tokens,
+        max_query_len,
+        enable_dbo_across_dp,
+        is_only_prefill,
+        *args,
+        **kwargs,
+    ) -> "AscendAttentionMetadataBuildInfo":
+        device = self.runner.device
+
+        block_tables = self.runner.input_batch.block_table[
+            0].get_device_tensor()
+        block_tables[:num_reqs, :self.runner.max_num_blocks_per_req] = (
+            block_tables[:num_reqs])
 
-        block_table = self.runner.input_batch.block_table[0].get_device_tensor(
-        )
-        block_table[:num_reqs, :self.runner.max_num_blocks_per_req] = (
-            block_table[:num_reqs])
+        query_start_loc_cpu = self.runner.query_start_loc_cpu[:num_reqs + 1]
+        query_start_loc = query_start_loc_cpu.to(device, non_blocking=True)
 
         query_lens = self.runner.query_lens
         seq_lens = self.runner.seq_lens_cpu[:num_reqs]
         slot_mapping = self.runner.slot_mapping_cpu[:num_actual_tokens].to(
-            self.runner.device, non_blocking=True)
+            device, non_blocking=True)
         attn_mask = self.runner.attn_mask
         attn_state = self.runner.attn_state
-        query_start_loc_cpu = self.runner.query_start_loc_cpu[:num_reqs + 1]
-        query_start_loc = query_start_loc_cpu.to(self.runner.device,
-                                                 non_blocking=True)
 
+        build_info = self._assemble_build_info(num_reqs, num_actual_tokens,
+                                               max_query_len, block_tables,
+                                               query_start_loc, query_lens,
+                                               seq_lens, slot_mapping,
+                                               attn_mask, attn_state, args,
+                                               kwargs)
+        return build_info
+
+    def build(
+        self,
+        num_reqs,
+        num_actual_tokens,
+        max_query_len,
+        enable_dbo_across_dp: bool = False,
+        is_only_prefill: bool = False,
+        *args,
+        **kwargs,
+    ):
+        build_info = self._prepare_build_info(num_reqs, num_actual_tokens,
+                                              max_query_len,
+                                              enable_dbo_across_dp,
+                                              is_only_prefill, args, kwargs)
+
+        attn_mask = build_info.attn_mask
+        attn_state = build_info.attn_state
         if is_310p():
             if attn_state == AscendAttentionState.PrefillNoCache:
                 mask_nz = nd_to_nz_2d(attn_mask)
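
Taken together, the hunk above routes every intermediate tensor through a single AscendAttentionMetadataBuildInfo bag: _prepare_build_info gathers the tensors from the runner, _assemble_build_info packs them, and build only unpacks the result. Below is a minimal, self-contained sketch of that pattern; the FakeRunner, its toy tensors, and the trimmed-down field list are illustrative stand-ins rather than the real vllm-ascend objects.

# Minimal sketch of the prepare/assemble split shown above (not vllm-ascend
# code): a dataclass carries the intermediate tensors, and a subclass could
# override _assemble_build_info to pack extra fields without re-reading the
# runner state.
from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class BuildInfo:
    block_tables: Optional[torch.Tensor] = None
    seq_lens: Optional[torch.Tensor] = None
    slot_mapping: Optional[torch.Tensor] = None


class Builder:

    def __init__(self, runner):
        self.runner = runner

    def _prepare_build_info(self, num_reqs, num_actual_tokens) -> BuildInfo:
        # Gather the per-step tensors from the runner (device copies elided).
        return self._assemble_build_info(
            block_tables=self.runner.block_table[:num_reqs],
            seq_lens=self.runner.seq_lens_cpu[:num_reqs],
            slot_mapping=self.runner.slot_mapping_cpu[:num_actual_tokens])

    def _assemble_build_info(self, **fields) -> BuildInfo:
        # Hook point: pack the gathered tensors into the info object.
        return BuildInfo(**fields)

    def build(self, num_reqs, num_actual_tokens):
        # The real build() would turn this into AscendMetadata.
        return self._prepare_build_info(num_reqs, num_actual_tokens)


class FakeRunner:
    # Toy stand-ins for the model runner's persistent buffers.
    block_table = torch.zeros(4, 8, dtype=torch.int32)
    seq_lens_cpu = torch.tensor([3, 5, 2, 7])
    slot_mapping_cpu = torch.arange(32)


print(Builder(FakeRunner()).build(num_reqs=2, num_actual_tokens=8))
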
@@ -198,12 +266,12 @@ def build(self,
 
         attn_metadata = AscendMetadata(
             num_actual_tokens=num_actual_tokens,
-            block_tables=block_table,
-            query_start_loc=query_start_loc,
-            query_lens=query_lens,
-            seq_lens=seq_lens,
+            block_tables=build_info.block_tables,
+            query_start_loc=build_info.query_start_loc,
+            query_lens=build_info.query_lens,
+            seq_lens=build_info.seq_lens,
             max_query_len=max_query_len,
-            slot_mapping=slot_mapping,
+            slot_mapping=build_info.slot_mapping,
             attn_mask=attn_mask,
             attn_state=attn_state,
             enable_dbo_across_dp=enable_dbo_across_dp,
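
The widened build signature also matters for callers: the trailing *args and **kwargs mean existing call sites keep working, and backend-specific keyword arguments that the Ascend builder does not understand are absorbed instead of raising TypeError. A small standalone illustration follows, where graph_pad_size is a hypothetical extra kwarg, not something this file uses.

# Standalone illustration (not vllm-ascend code) of the *args/**kwargs tail
# added to build(): unknown keyword arguments are absorbed, not rejected.
def build(num_reqs, num_actual_tokens, max_query_len,
          enable_dbo_across_dp=False, is_only_prefill=False, *args, **kwargs):
    return {"num_reqs": num_reqs, "ignored_kwargs": sorted(kwargs)}


print(build(4, 128, 32))  # old-style positional call still works
print(build(4, 128, 32, is_only_prefill=True,
            graph_pad_size=0))  # hypothetical extra kwarg, silently absorbed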