@@ -158,6 +158,18 @@ class AscendMetadata:
     is_only_prefill: bool = False


+@dataclass
+class AscendAttentionMetadataBuildInfo:
+    num_actual_tokens: int
+    block_table: torch.Tensor
+    query_start_loc: torch.Tensor
+    query_lens: torch.Tensor
+    seq_lens: torch.Tensor
+    slot_mapping: torch.Tensor
+    attn_mask: torch.Tensor
+    attn_state: AscendAttentionState
+
+
 class AscendAttentionMetadataBuilder:

     def __init__(
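The new dataclass is a plain container, so it can be constructed and inspected without any NPU present. A minimal sketch, assuming the names above are importable (the module path and tensor shapes below are illustrative assumptions, not taken from this PR):

```python
import torch

# Hypothetical import path; adjust to the module where these names live.
from vllm_ascend.attention.attention_v1 import (
    AscendAttentionMetadataBuildInfo, AscendAttentionState)

# Illustrative single-request shapes: 4 tokens, one KV block.
block_table = torch.zeros(1, 4, dtype=torch.int32)
query_start_loc = torch.tensor([0, 4], dtype=torch.int32)
query_lens = torch.tensor([4], dtype=torch.int32)
seq_lens = torch.tensor([4], dtype=torch.int32)
slot_mapping = torch.arange(4, dtype=torch.int32)
attn_mask = torch.ones(4, 4, dtype=torch.bool).tril()  # causal mask

info = AscendAttentionMetadataBuildInfo(
    num_actual_tokens=4,
    block_table=block_table,
    query_start_loc=query_start_loc,
    query_lens=query_lens,
    seq_lens=seq_lens,
    slot_mapping=slot_mapping,
    attn_mask=attn_mask,
    attn_state=AscendAttentionState.PrefillNoCache)
```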
@@ -175,9 +187,60 @@ def reorder_batch(self, input_batch: "InputBatch",
                       scheduler_output: "SchedulerOutput") -> bool:
         return False

+    def _assemble_build_info(
+        self,
+        num_actual_tokens,
+        block_table,
+        query_start_loc,
+        query_lens,
+        seq_lens,
+        slot_mapping,
+        attn_mask,
+        attn_state: "AscendAttentionState",
+    ) -> "AscendAttentionMetadataBuildInfo":
+        if is_310p():
+            if attn_state == AscendAttentionState.PrefillNoCache:
+                mask_nz = nd_to_nz_2d(attn_mask)
+                attn_mask = torch_npu.npu_format_cast(mask_nz.contiguous(),
+                                                      ACL_FORMAT_FRACTAL_NZ)
+            elif attn_state == AscendAttentionState.ChunkedPrefill:
+                mask_nz = nd_to_nz_spec(attn_mask)
+                attn_mask = torch_npu.npu_format_cast(mask_nz.contiguous(),
+                                                      ACL_FORMAT_FRACTAL_NZ)
+
+        build_info = AscendAttentionMetadataBuildInfo(
+            num_actual_tokens=num_actual_tokens,
+            block_table=block_table,
+            query_start_loc=query_start_loc,
+            query_lens=query_lens,
+            seq_lens=seq_lens,
+            slot_mapping=slot_mapping,
+            attn_mask=attn_mask,
+            attn_state=attn_state)
+        return build_info
+
+    def _assemble_attn_metadata(
+        self,
+        build_info: "AscendAttentionMetadataBuildInfo",
+        common_attn_metadata: "AscendCommonAttentionMetadata",
+    ) -> "AscendMetadata":
+        attn_metadata = AscendMetadata(
+            num_actual_tokens=build_info.num_actual_tokens,
+            block_tables=build_info.block_table,
+            query_start_loc=build_info.query_start_loc,
+            query_lens=build_info.query_lens,
+            seq_lens=build_info.seq_lens,
+            max_query_len=common_attn_metadata.max_query_len,
+            slot_mapping=build_info.slot_mapping,
+            attn_mask=build_info.attn_mask,
+            attn_state=build_info.attn_state,
+            enable_dbo_across_dp=common_attn_metadata.enable_dbo_across_dp,
+            is_only_prefill=common_attn_metadata.is_only_prefill)
+        return attn_metadata
+
     def build(
         self,
-        common_attn_metadata: AscendCommonAttentionMetadata,
+        common_attn_metadata: "AscendCommonAttentionMetadata",
         model: nn.Module,
     ):
         num_reqs = common_attn_metadata.num_reqs
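Since the whole NZ-format cast is gated behind `if is_310p():`, `_assemble_build_info` is a pure pass-through on non-310p targets, which is what makes the extraction behavior-preserving there. A sketch of that pass-through property, assuming a constructed `builder` (an `AscendAttentionMetadataBuilder` instance, whose constructor arguments are out of scope here) and the toy tensors from the earlier example:

```python
# On a non-310p target, neither mask branch fires, so the inputs are
# packed into the dataclass unchanged.
build_info = builder._assemble_build_info(
    4, block_table, query_start_loc, query_lens,
    seq_lens, slot_mapping, attn_mask,
    AscendAttentionState.PrefillNoCache)
assert build_info.attn_mask is attn_mask  # unchanged off 310p
```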
@@ -205,28 +268,12 @@ def build(
         query_start_loc = query_start_loc_cpu.to(self.device,
                                                  non_blocking=True)

-        if is_310p():
-            if attn_state == AscendAttentionState.PrefillNoCache:
-                mask_nz = nd_to_nz_2d(attn_mask)
-                attn_mask = torch_npu.npu_format_cast(mask_nz.contiguous(),
-                                                      ACL_FORMAT_FRACTAL_NZ)
-            elif attn_state == AscendAttentionState.ChunkedPrefill:
-                mask_nz = nd_to_nz_spec(attn_mask)
-                attn_mask = torch_npu.npu_format_cast(mask_nz.contiguous(),
-                                                      ACL_FORMAT_FRACTAL_NZ)
-
-        attn_metadata = AscendMetadata(
-            num_actual_tokens=num_actual_tokens,
-            block_tables=block_table,
-            query_start_loc=query_start_loc,
-            query_lens=query_lens,
-            seq_lens=seq_lens,
-            max_query_len=common_attn_metadata.max_query_len,
-            slot_mapping=slot_mapping,
-            attn_mask=attn_mask,
-            attn_state=attn_state,
-            enable_dbo_across_dp=common_attn_metadata.enable_dbo_across_dp,
-            is_only_prefill=common_attn_metadata.is_only_prefill)
+        build_info = self._assemble_build_info(num_actual_tokens, block_table,
+                                               query_start_loc, query_lens,
+                                               seq_lens, slot_mapping,
+                                               attn_mask, attn_state)
+        attn_metadata = self._assemble_attn_metadata(build_info,
+                                                     common_attn_metadata)
         return attn_metadata

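Because `_assemble_attn_metadata` reads only `max_query_len`, `enable_dbo_across_dp`, and `is_only_prefill` from the common metadata, it can be exercised with a simple stub instead of a full `AscendCommonAttentionMetadata`. A sketch, reusing the assumed `builder` and the `info` object from the earlier examples:

```python
from types import SimpleNamespace

# Stub standing in for AscendCommonAttentionMetadata; only the three
# attributes that _assemble_attn_metadata actually reads are provided.
common_stub = SimpleNamespace(max_query_len=4,
                              enable_dbo_across_dp=False,
                              is_only_prefill=False)

attn_metadata = builder._assemble_attn_metadata(info, common_stub)
assert attn_metadata.max_query_len == 4
assert attn_metadata.attn_state is info.attn_state
```

Splitting the build into the two helpers keeps `build()` platform-agnostic: the 310p-specific mask formatting lives entirely in `_assemble_build_info`, while `_assemble_attn_metadata` does pure field plumbing.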