@@ -53,7 +53,7 @@ def get_state_cls() -> Type["AiterMLAState"]:
 @dataclass
 class AiterMLAMetadata(MLACommonMetadata):
-    # The following 5 tensors are for current version of AITER MLA
+    # The following 4 tensors are for current version of AITER MLA
     block_table_bound: Optional[torch.Tensor] = None
     # The indptr of the paged kv cache, shape: [batch_size + 1]
     paged_kv_indptr: Optional[torch.Tensor] = None
@@ -63,10 +63,6 @@ class AiterMLAMetadata(MLACommonMetadata):
     # the paged kv cache, shape: [batch_size]
     paged_kv_last_page_lens: Optional[torch.Tensor] = None
 
-    # This is just to make new AITER MLA API work
-    # -- MTP support is not added yet.
-    qo_indptr: Optional[torch.Tensor] = None
-
     @property
     def prefill_metadata(self):
         prefill_metadata = super().prefill_metadata
@@ -78,7 +74,6 @@ def prefill_metadata(self):
             prefill_metadata\
                 .paged_kv_last_page_lens = self.paged_kv_last_page_lens
             prefill_metadata.block_table_bound = self.block_table_bound
-            prefill_metadata.qo_indptr = self.qo_indptr
 
             # update the cache
             self._cached_prefill_metadata = self.__class__(
@@ -98,7 +93,6 @@ def decode_metadata(self):
             decode_metadata\
                 .paged_kv_last_page_lens = self.paged_kv_last_page_lens
             decode_metadata.block_table_bound = self.block_table_bound
-            decode_metadata.qo_indptr = self.qo_indptr
 
             # update the cache
             self._cached_decode_metadata = self.__class__(
@@ -142,7 +136,6 @@ def prepare(self):
         self.paged_kv_indptr: list[int] = [0]
         self.paged_kv_last_page_lens: list[int] = []
         self.total_blocks = 0
-        self.qo_indptr: list[int] = [0]
 
     def _add_seq_group(self, inter_data, chunked_prefill_enabled: bool,
                        prefix_cache_hit: bool):
@@ -215,7 +208,6 @@ def _update_paged_kv_tensors(self, block_table: list[int], seq_len: int):
         self.paged_kv_indices.extend(block_table[:block_table_bound])
         self.paged_kv_indptr.append(self.paged_kv_indptr[-1] +
                                     block_table_bound)
-        self.qo_indptr.append(self.qo_indptr[-1] + 1)
 
         last_page_len = seq_len % self.block_size
         if last_page_len == 0:
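Note: a minimal sketch (made-up values, not taken from this diff) of the layout the three surviving lists encode, assuming block_size = 16 and that the truncated branch above resets a zero remainder to the full block size:

    # seq 0: 35 tokens -> 3 blocks, last block holds 35 % 16 = 3 tokens
    # seq 1: 32 tokens -> 2 blocks, last block is full
    paged_kv_indices = [8, 3, 5, 2, 9]  # physical block ids, seq 0 then seq 1
    paged_kv_indptr = [0, 3, 5]         # seq i owns indices[indptr[i]:indptr[i+1]]
    paged_kv_last_page_lens = [3, 16]   # valid tokens in each sequence's last block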
@@ -234,8 +226,6 @@ def build(self, seq_lens: list[int], query_lens: list[int],
             self.paged_kv_indptr.extend([last_paged_kv_indptr] *
                                         cuda_graph_pad_size)
             self.paged_kv_last_page_lens.extend([0] * cuda_graph_pad_size)
-            last_qo_indptr = self.qo_indptr[-1]
-            self.qo_indptr.extend([last_qo_indptr] * cuda_graph_pad_size)
 
         # For current version of AITER MLA
         if len(self.paged_kv_indptr) > 0:
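Note: the padding retained above keeps the metadata shapes fixed so a captured CUDA graph can be replayed at a padded batch size. A short sketch with illustrative numbers:

    # 2 real sequences, cuda_graph_pad_size = 2
    paged_kv_indptr = [0, 3, 5]
    paged_kv_indptr.extend([paged_kv_indptr[-1]] * 2)  # [0, 3, 5, 5, 5]: padded entries own no blocks
    paged_kv_last_page_lens = [3, 16]
    paged_kv_last_page_lens.extend([0] * 2)            # [3, 16, 0, 0]: padded slots hold no tokens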
@@ -255,22 +245,16 @@ def build(self, seq_lens: list[int], query_lens: list[int],
                                                    1,
                                                    device=device,
                                                    dtype=torch.int)
-
-            qo_indptr = torch.tensor(self.qo_indptr,
-                                     device=device,
-                                     dtype=torch.int)
         else:
             paged_kv_indices_tensor = None
             paged_kv_indptr_tensor = None
             paged_kv_last_page_lens_tensor = None
             block_table_bound_tensor = None
-            qo_indptr = None
 
         metadata.paged_kv_indptr = paged_kv_indptr_tensor
         metadata.paged_kv_indices = paged_kv_indices_tensor
         metadata.paged_kv_last_page_lens = paged_kv_last_page_lens_tensor
         metadata.block_table_bound = block_table_bound_tensor
-        metadata.qo_indptr = qo_indptr
 
         return metadata
@@ -279,25 +263,21 @@ class AiterMLAState(MLACommonState[AiterMLAMetadata]):
 
     @contextmanager
     def graph_capture(self, max_batch_size: int):
-        kv_indices, kv_indptr, last_page_lens, qo_indptr = \
-            get_aiter_mla_metadata(
-                max_batch_size=max_batch_size,
-                block_size=self.runner.block_size,
-                max_block_per_batch=\
-                    self.runner.get_max_block_per_batch(),
-                device=self.runner.device)
+        kv_indices, kv_indptr, last_page_lens = get_aiter_mla_metadata(
+            max_batch_size=max_batch_size,
+            block_size=self.runner.block_size,
+            max_block_per_batch=self.runner.get_max_block_per_batch(),
+            device=self.runner.device)
         self._paged_kv_indices_tensor = kv_indices
         self._paged_kv_indptr_tensor = kv_indptr
         self._paged_kv_last_page_lens_tensor = last_page_lens
-        self._qo_indptr_tensor = qo_indptr
 
         with super().graph_capture(max_batch_size):
             yield
 
         del self._paged_kv_indices_tensor
         del self._paged_kv_indptr_tensor
         del self._paged_kv_last_page_lens_tensor
-        del self._qo_indptr_tensor
 
     def graph_capture_get_metadata_for_batch(
         self,
@@ -311,12 +291,10 @@ def graph_capture_get_metadata_for_batch(
         paged_kv_indices = self._paged_kv_indices_tensor
         paged_kv_last_page_lens = self._paged_kv_last_page_lens_tensor[:
                                                                        batch_size]
-        qo_indptr = self._qo_indptr_tensor[:batch_size + 1]
 
         metadata.paged_kv_indptr = paged_kv_indptr
         metadata.paged_kv_indices = paged_kv_indices
         metadata.paged_kv_last_page_lens = paged_kv_last_page_lens
-        metadata.qo_indptr = qo_indptr
 
         return metadata
@@ -333,7 +311,6 @@ def get_graph_input_buffers(self,
         input_buffers[
             "paged_kv_last_page_lens"] = attn_metadata.\
                 decode_metadata.paged_kv_last_page_lens
-        input_buffers['qo_indptr'] = attn_metadata.qo_indptr
 
         return input_buffers
@@ -353,8 +330,6 @@ def prepare_graph_input_buffers(self,
         input_buffers["paged_kv_last_page_lens"].copy_(
             attn_metadata.decode_metadata.paged_kv_last_page_lens,
             non_blocking=True)
-        input_buffers["qo_indptr"].copy_(
-            attn_metadata.decode_metadata.qo_indptr, non_blocking=True)
 
 
 class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]):
@@ -395,9 +370,11 @@ def _flash_attn_varlen_diff_headdims(
             softmax_scale: float, return_softmax_lse: bool,
             **kwargs) -> Union[tuple[torch.Tensor, ...], torch.Tensor]:
         output = self.flash_attn_varlen_func(
-            q,
-            k,
-            v,
+            q=q,
+            k=k,
+            v=v,
+            softmax_scale=softmax_scale,
+            return_lse=return_softmax_lse,
             **kwargs,
         )
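Note: the new call names every argument and translates the common interface's return_softmax_lse into AITER's return_lse parameter; passing these through **kwargs alone would fail if a caller used the interface-side name. A sketch with a hypothetical backend signature:

    def aiter_func(q, k, v, softmax_scale=None, return_lse=False): ...
    # aiter_func(q, k, v, return_softmax_lse=True)  -> TypeError: unexpected keyword
    # aiter_func(q=q, k=k, v=v, return_lse=True)    -> OK: name translated by the wrapper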
@@ -417,7 +394,7 @@ def _forward_decode(
         B = q_nope.shape[0]
 
         q = torch.cat([q_nope, q_pe], dim=-1)
-        o = torch.empty(B,
+        o = torch.zeros(B,
                         self.num_heads,
                         self.kv_lora_rank,
                         dtype=q.dtype,
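Note: the output buffer switches from torch.empty to torch.zeros. The diff does not say why; a plausible reading is that the decode kernel may not write every element, so zero-initialization keeps uninitialized memory out of the result:

    o = torch.empty(B, H, D)  # contents are arbitrary until every element is written
    o = torch.zeros(B, H, D)  # untouched elements read back as exactly 0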
@@ -426,8 +403,6 @@ def _forward_decode(
         kv_buffer = kv_c_and_k_pe_cache.unsqueeze(2)
 
         aiter_mla_decode_fwd(q, kv_buffer, o, self.scale,
-                             attn_metadata.qo_indptr,
-                             attn_metadata.max_query_len,
                              attn_metadata.paged_kv_indptr,
                              attn_metadata.paged_kv_indices,
                              attn_metadata.paged_kv_last_page_lens)
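Note: dropping qo_indptr and max_query_len from the decode call matches the builder code removed earlier: each decode sequence appended exactly one query token, so qo_indptr was always [0, 1, ..., batch_size] and carried no information beyond the batch size. A sketch reproducing the removed accumulation:

    batch_size = 4
    qo_indptr = [0]
    for _ in range(batch_size):
        qo_indptr.append(qo_indptr[-1] + 1)  # mirrors the removed per-sequence append
    assert qo_indptr == list(range(batch_size + 1))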