@@ -119,7 +119,6 @@ class AscendAttentionState(Enum):
119
119
120
120
@dataclass
121
121
class AscendMetadata :
122
-
123
122
# **************************** Basic Properties ************************** #
124
123
attn_mask : Optional [torch .Tensor ] = None
125
124
# Current state of this attention run.
@@ -155,37 +154,50 @@ class AscendMetadata:
155
154
is_only_prefill : bool = False
156
155
157
156
157
@dataclass
class AscendAttentionMetadataBuildInfo:
    """Intermediate container for the per-step fields collected by
    ``AscendAttentionMetadataBuilder.build`` before they are assembled
    into the final ``AscendMetadata``.

    All tensor-valued fields default to ``None`` until populated, so they
    are annotated ``Optional[...]``; the previous annotations claimed
    non-optional ``torch.Tensor`` while defaulting to ``None``.
    ``attn_state`` uses a forward-reference string so the enum name is not
    evaluated at class-creation time.
    """
    # Total number of tokens scheduled in this step.
    num_actual_tokens: int = 0
    # Per-request KV-cache block table (device tensor).
    block_table: Optional[torch.Tensor] = None
    # Cumulative query-start offsets, length num_reqs + 1.
    query_start_loc: Optional[torch.Tensor] = None
    query_lens: Optional[torch.Tensor] = None
    seq_lens: Optional[torch.Tensor] = None
    max_query_len: int = 0
    # Flat mapping of token positions into KV-cache slots.
    slot_mapping: Optional[torch.Tensor] = None
    attn_mask: Optional[torch.Tensor] = None
    # Current state of this attention run (AscendAttentionState enum).
    attn_state: Optional["AscendAttentionState"] = None
    enable_dbo_across_dp: bool = False
    is_only_prefill: bool = False
158
172
class AscendAttentionMetadataBuilder:
    """Builds ``AscendMetadata`` for each scheduler step from runner state."""

    def __init__(self, runner):
        # The model runner supplies all per-step inputs (block tables,
        # query/seq lengths, slot mapping, masks) consumed by build().
        self.runner = runner
- def reorder_batch (self , input_batch : "InputBatch" ,
164
- scheduler_output : "SchedulerOutput" ) -> bool :
177
+ def reorder_batch (
178
+ self ,
179
+ input_batch : "InputBatch" ,
180
+ scheduler_output : "SchedulerOutput" ,
181
+ ) -> bool :
165
182
return False
166
183
167
- def build (self ,
168
- num_reqs ,
169
- num_actual_tokens ,
170
- max_query_len ,
171
- enable_dbo_across_dp : bool = False ,
172
- is_only_prefill : bool = False ):
173
-
174
- block_table = self .runner .input_batch .block_table [0 ].get_device_tensor (
175
- )
176
- block_table [:num_reqs , :self .runner .max_num_blocks_per_req ] = (
177
- block_table [:num_reqs ])
178
-
179
- query_lens = self .runner .query_lens
180
- seq_lens = self .runner .seq_lens_cpu [:num_reqs ]
181
- slot_mapping = self .runner .slot_mapping_cpu [:num_actual_tokens ].to (
182
- self .runner .device , non_blocking = True )
183
- attn_mask = self .runner .attn_mask
184
- attn_state = self .runner .attn_state
185
- query_start_loc_cpu = self .runner .query_start_loc_cpu [:num_reqs + 1 ]
186
- query_start_loc = query_start_loc_cpu .to (self .runner .device ,
187
- non_blocking = True )
188
-
184
+ def _assemble_build_info (
185
+ self ,
186
+ num_reqs ,
187
+ num_actual_tokens ,
188
+ max_query_len ,
189
+ enable_dbo_across_dp ,
190
+ is_only_prefill ,
191
+ block_table ,
192
+ query_start_loc ,
193
+ query_lens ,
194
+ seq_lens ,
195
+ slot_mapping ,
196
+ attn_mask ,
197
+ attn_state : "AscendAttentionState" ,
198
+ * args ,
199
+ ** kwargs ,
200
+ ) -> "AscendAttentionMetadataBuildInfo" :
189
201
if is_310p ():
190
202
if attn_state == AscendAttentionState .PrefillNoCache :
191
203
mask_nz = nd_to_nz_2d (attn_mask )
@@ -196,9 +208,9 @@ def build(self,
196
208
attn_mask = torch_npu .npu_format_cast (mask_nz .contiguous (),
197
209
ACL_FORMAT_FRACTAL_NZ )
198
210
199
- attn_metadata = AscendMetadata (
211
+ build_info = AscendAttentionMetadataBuildInfo (
200
212
num_actual_tokens = num_actual_tokens ,
201
- block_tables = block_table ,
213
+ block_table = block_table ,
202
214
query_start_loc = query_start_loc ,
203
215
query_lens = query_lens ,
204
216
seq_lens = seq_lens ,
@@ -208,6 +220,61 @@ def build(self,
208
220
attn_state = attn_state ,
209
221
enable_dbo_across_dp = enable_dbo_across_dp ,
210
222
is_only_prefill = is_only_prefill )
223
+ return build_info
224
+
225
+ def _assemble_attn_metadata (
226
+ self ,
227
+ build_info : "AscendAttentionMetadataBuildInfo" ,
228
+ ) -> "AscendMetadata" :
229
+ attn_metadata = AscendMetadata (
230
+ num_actual_tokens = build_info .num_actual_tokens ,
231
+ block_tables = build_info .block_table ,
232
+ query_start_loc = build_info .query_start_loc ,
233
+ query_lens = build_info .query_lens ,
234
+ seq_lens = build_info .seq_lens ,
235
+ max_query_len = build_info .max_query_len ,
236
+ slot_mapping = build_info .slot_mapping ,
237
+ attn_mask = build_info .attn_mask ,
238
+ attn_state = build_info .attn_state ,
239
+ enable_dbo_across_dp = build_info .enable_dbo_across_dp ,
240
+ is_only_prefill = build_info .is_only_prefill )
241
+ return attn_metadata
242
+
243
+ def build (
244
+ self ,
245
+ num_reqs ,
246
+ num_actual_tokens ,
247
+ max_query_len ,
248
+ enable_dbo_across_dp : bool = False ,
249
+ is_only_prefill : bool = False ,
250
+ * args ,
251
+ ** kwargs ,
252
+ ) -> "AscendMetadata" :
253
+ device = self .runner .device
254
+
255
+ block_table = self .runner .input_batch .block_table [0 ].get_device_tensor (
256
+ )
257
+ block_table [:num_reqs , :self .runner .max_num_blocks_per_req ] = (
258
+ block_table [:num_reqs ])
259
+
260
+ query_start_loc_cpu = self .runner .query_start_loc_cpu [:num_reqs + 1 ]
261
+ query_start_loc = query_start_loc_cpu .to (device , non_blocking = True )
262
+
263
+ query_lens = self .runner .query_lens
264
+ seq_lens = self .runner .seq_lens_cpu [:num_reqs ]
265
+ slot_mapping = self .runner .slot_mapping_cpu [:num_actual_tokens ].to (
266
+ device , non_blocking = True )
267
+ attn_mask = self .runner .attn_mask
268
+ attn_state = self .runner .attn_state
269
+
270
+ build_info = self ._assemble_build_info (num_reqs , num_actual_tokens ,
271
+ max_query_len , block_table ,
272
+ query_start_loc , query_lens ,
273
+ seq_lens , slot_mapping ,
274
+ attn_mask , attn_state , args ,
275
+ kwargs )
276
+
277
+ attn_metadata = self ._assemble_attn_metadata (build_info )
211
278
return attn_metadata
212
279
213
280
0 commit comments