@@ -53,7 +53,7 @@ def get_state_cls() -> Type["AiterMLAState"]:
 
 @dataclass
 class AiterMLAMetadata(MLACommonMetadata):
-    # The following 4 tensors are for current version of AITER MLA
+    # The following 5 tensors are for current version of AITER MLA
     block_table_bound: Optional[torch.Tensor] = None
     # The indptr of the paged kv cache, shape: [batch_size + 1]
     paged_kv_indptr: Optional[torch.Tensor] = None
@@ -63,6 +63,10 @@ class AiterMLAMetadata(MLACommonMetadata):
     # the paged kv cache, shape: [batch_size]
     paged_kv_last_page_lens: Optional[torch.Tensor] = None
 
+    # This is just to make the new AITER MLA API work
+    # -- MTP support is not added yet.
+    qo_indptr: Optional[torch.Tensor] = None
+
     @property
     def prefill_metadata(self):
         prefill_metadata = super().prefill_metadata
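For reference, `qo_indptr` follows the same CSR-style convention as `paged_kv_indptr`: entry `i` holds the cumulative number of query tokens belonging to sequences before sequence `i`. Since this backend currently handles single-token decode only (no MTP), the tensor degenerates to `[0, 1, ..., batch_size]`. A minimal sketch of the convention; the `make_qo_indptr` helper is illustrative, not part of this file:

import torch

def make_qo_indptr(query_lens: list[int]) -> torch.Tensor:
    # CSR-style prefix sum: indptr[i] = number of query tokens
    # in sequences 0..i-1.
    indptr = [0]
    for qlen in query_lens:
        indptr.append(indptr[-1] + qlen)
    return torch.tensor(indptr, dtype=torch.int)

# Decode-only batch of 4 sequences, one new token each:
assert make_qo_indptr([1, 1, 1, 1]).tolist() == [0, 1, 2, 3, 4]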
@@ -74,6 +78,7 @@ def prefill_metadata(self):
             prefill_metadata \
                 .paged_kv_last_page_lens = self.paged_kv_last_page_lens
             prefill_metadata.block_table_bound = self.block_table_bound
+            prefill_metadata.qo_indptr = self.qo_indptr
 
             # update the cache
             self._cached_prefill_metadata = self.__class__(
@@ -93,6 +98,7 @@ def decode_metadata(self):
             decode_metadata \
                 .paged_kv_last_page_lens = self.paged_kv_last_page_lens
             decode_metadata.block_table_bound = self.block_table_bound
+            decode_metadata.qo_indptr = self.qo_indptr
 
             # update the cache
             self._cached_decode_metadata = self.__class__(
@@ -136,6 +142,7 @@ def prepare(self):
         self.paged_kv_indptr: list[int] = [0]
         self.paged_kv_last_page_lens: list[int] = []
         self.total_blocks = 0
+        self.qo_indptr: list[int] = [0]
 
     def _add_seq_group(self, inter_data, chunked_prefill_enabled: bool,
                        prefix_cache_hit: bool):
@@ -208,6 +215,7 @@ def _update_paged_kv_tensors(self, block_table: list[int], seq_len: int):
         self.paged_kv_indices.extend(block_table[:block_table_bound])
         self.paged_kv_indptr.append(self.paged_kv_indptr[-1] +
                                     block_table_bound)
+        self.qo_indptr.append(self.qo_indptr[-1] + 1)
 
         last_page_len = seq_len % self.block_size
         if last_page_len == 0:
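A worked example of the bookkeeping above, under an assumed `block_size` of 16 and two sequences of lengths 5 and 35 (prefix-cache trimming ignored for simplicity):

import math

block_size = 16
seq_lens = [5, 35]

paged_kv_indptr = [0]
paged_kv_last_page_lens = []
qo_indptr = [0]
for seq_len in seq_lens:
    num_blocks = math.ceil(seq_len / block_size)  # pages used by this seq
    paged_kv_indptr.append(paged_kv_indptr[-1] + num_blocks)
    last = seq_len % block_size
    paged_kv_last_page_lens.append(last if last else block_size)
    qo_indptr.append(qo_indptr[-1] + 1)  # one query token per sequence

assert paged_kv_indptr == [0, 1, 4]       # seq 0 uses 1 page, seq 1 uses 3
assert paged_kv_last_page_lens == [5, 3]  # valid tokens in each last page
assert qo_indptr == [0, 1, 2]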
@@ -226,6 +234,8 @@ def build(self, seq_lens: list[int], query_lens: list[int],
             self.paged_kv_indptr.extend([last_paged_kv_indptr] *
                                         cuda_graph_pad_size)
             self.paged_kv_last_page_lens.extend([0] * cuda_graph_pad_size)
+            last_qo_indptr = self.qo_indptr[-1]
+            self.qo_indptr.extend([last_qo_indptr] * cuda_graph_pad_size)
 
         # For current version of AITER MLA
         if len(self.paged_kv_indptr) > 0:
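The padding above keeps the indptr arrays monotonic at the graph's fixed batch size: repeating the final value gives every padded slot zero KV pages and zero query tokens. In miniature:

# Real batch of 2 sequences, padded to a captured batch size of 5:
qo_indptr = [0, 1, 2]
cuda_graph_pad_size = 3
qo_indptr.extend([qo_indptr[-1]] * cuda_graph_pad_size)
assert qo_indptr == [0, 1, 2, 2, 2, 2]  # slots 2..4 contribute no tokens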
@@ -245,16 +255,22 @@ def build(self, seq_lens: list[int], query_lens: list[int],
                 1,
                 device=device,
                 dtype=torch.int)
+
+            qo_indptr = torch.tensor(self.qo_indptr,
+                                     device=device,
+                                     dtype=torch.int)
         else:
             paged_kv_indices_tensor = None
             paged_kv_indptr_tensor = None
             paged_kv_last_page_lens_tensor = None
             block_table_bound_tensor = None
+            qo_indptr = None
 
         metadata.paged_kv_indptr = paged_kv_indptr_tensor
         metadata.paged_kv_indices = paged_kv_indices_tensor
         metadata.paged_kv_last_page_lens = paged_kv_last_page_lens_tensor
         metadata.block_table_bound = block_table_bound_tensor
+        metadata.qo_indptr = qo_indptr
 
         return metadata
 
@@ -263,21 +279,25 @@ class AiterMLAState(MLACommonState[AiterMLAMetadata]):
 
     @contextmanager
     def graph_capture(self, max_batch_size: int):
-        kv_indices, kv_indptr, last_page_lens = get_aiter_mla_metadata(
-            max_batch_size=max_batch_size,
-            block_size=self.runner.block_size,
-            max_block_per_batch=self.runner.get_max_block_per_batch(),
-            device=self.runner.device)
+        kv_indices, kv_indptr, last_page_lens, qo_indptr = \
+            get_aiter_mla_metadata(
+                max_batch_size=max_batch_size,
+                block_size=self.runner.block_size,
+                max_block_per_batch=\
+                    self.runner.get_max_block_per_batch(),
+                device=self.runner.device)
         self._paged_kv_indices_tensor = kv_indices
         self._paged_kv_indptr_tensor = kv_indptr
         self._paged_kv_last_page_lens_tensor = last_page_lens
+        self._qo_indptr_tensor = qo_indptr
 
         with super().graph_capture(max_batch_size):
             yield
 
         del self._paged_kv_indices_tensor
         del self._paged_kv_indptr_tensor
         del self._paged_kv_last_page_lens_tensor
+        del self._qo_indptr_tensor
 
     def graph_capture_get_metadata_for_batch(
             self,
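The context manager pins the metadata buffers on `self` only for the duration of capture, then drops them, since CUDA graphs must be captured against storage that stays alive through capture. A stripped-down sketch of that pattern, with illustrative names:

from contextlib import contextmanager

@contextmanager
def capture_scope(state, buffers: dict):
    # Buffers must outlive graph capture, but not the state object.
    for name, buf in buffers.items():
        setattr(state, name, buf)
    try:
        yield
    finally:
        for name in buffers:
            delattr(state, name)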
@@ -291,10 +311,12 @@ def graph_capture_get_metadata_for_batch(
         paged_kv_indices = self._paged_kv_indices_tensor
         paged_kv_last_page_lens = self._paged_kv_last_page_lens_tensor[:
                                                                        batch_size]
+        qo_indptr = self._qo_indptr_tensor[:batch_size + 1]
 
         metadata.paged_kv_indptr = paged_kv_indptr
         metadata.paged_kv_indices = paged_kv_indices
         metadata.paged_kv_last_page_lens = paged_kv_last_page_lens
+        metadata.qo_indptr = qo_indptr
 
         return metadata
 
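Note the asymmetric slices: an indptr over `batch_size` sequences needs `batch_size + 1` entries, while the per-sequence `paged_kv_last_page_lens` needs exactly `batch_size`. A sketch, assuming the captured decode buffer holds `[0, 1, ..., max_batch_size]` (one query token per sequence):

import torch

max_batch_size = 8
qo_indptr_buf = torch.arange(max_batch_size + 1, dtype=torch.int)

batch_size = 4
qo_indptr = qo_indptr_buf[:batch_size + 1]  # B + 1 entries for B sequences
assert qo_indptr.tolist() == [0, 1, 2, 3, 4]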
@@ -311,6 +333,7 @@ def get_graph_input_buffers(self,
         input_buffers[
             "paged_kv_last_page_lens"] = attn_metadata.\
                 decode_metadata.paged_kv_last_page_lens
+        input_buffers['qo_indptr'] = attn_metadata.qo_indptr
 
         return input_buffers
 
@@ -330,6 +353,8 @@ def prepare_graph_input_buffers(self,
         input_buffers["paged_kv_last_page_lens"].copy_(
             attn_metadata.decode_metadata.paged_kv_last_page_lens,
             non_blocking=True)
+        input_buffers["qo_indptr"].copy_(
+            attn_metadata.decode_metadata.qo_indptr, non_blocking=True)
 
 
 class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]):
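Because a replayed CUDA graph reads from the exact storage it was captured with, fresh metadata must be `copy_`-ed into the captured tensors rather than rebound. In miniature:

import torch

captured = torch.zeros(6, dtype=torch.int)  # address baked into the graph
new_vals = torch.tensor([0, 1, 2, 2, 2, 2], dtype=torch.int)
captured.copy_(new_vals, non_blocking=True)  # same storage, new contents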
@@ -370,11 +395,9 @@ def _flash_attn_varlen_diff_headdims(
             softmax_scale: float, return_softmax_lse: bool,
             **kwargs) -> Union[tuple[torch.Tensor, ...], torch.Tensor]:
         output = self.flash_attn_varlen_func(
-            q=q,
-            k=k,
-            v=v,
-            softmax_scale=softmax_scale,
-            return_lse=return_softmax_lse,
+            q,
+            k,
+            v,
             **kwargs,
         )
 
@@ -394,7 +417,7 @@ def _forward_decode(
         B = q_nope.shape[0]
 
         q = torch.cat([q_nope, q_pe], dim=-1)
-        o = torch.zeros(B,
+        o = torch.empty(B,
                         self.num_heads,
                         self.kv_lora_rank,
                         dtype=q.dtype,
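Switching from `torch.zeros` to `torch.empty` drops the zero-fill kernel, presumably because the decode kernel overwrites every element of `o`, so initialization is wasted work. The difference in miniature:

import torch

B, H, R = 4, 16, 512
o = torch.zeros(B, H, R)  # allocates and launches a fill kernel
o = torch.empty(B, H, R)  # allocates only; contents undefined until written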
@@ -403,6 +426,8 @@ def _forward_decode(
         kv_buffer = kv_c_and_k_pe_cache.unsqueeze(2)
 
         aiter_mla_decode_fwd(q, kv_buffer, o, self.scale,
+                             attn_metadata.qo_indptr,
+                             attn_metadata.max_query_len,
                              attn_metadata.paged_kv_indptr,
                              attn_metadata.paged_kv_indices,
                              attn_metadata.paged_kv_last_page_lens)
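For orientation, the shapes threaded through the call above, as implied by how the metadata is built earlier in the file (B = batch size, P = total KV pages in the batch; head-dim names are the usual MLA ones and are stated here as an inference, not from this diff):

# q                        [B, num_heads, kv_lora_rank + qk_rope_head_dim]
# kv_buffer                paged KV cache with a singleton head axis (unsqueeze(2))
# o                        [B, num_heads, kv_lora_rank]
# qo_indptr                [B + 1]  cumulative query-token counts
# paged_kv_indptr          [B + 1]  cumulative page counts
# paged_kv_indices         [P]      physical page ids
# paged_kv_last_page_lens  [B]      valid tokens in each sequence's last page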