@@ -214,12 +214,14 @@ class AiterFlashAttentionMetadata:
     # |-- query_len ---|

     num_actual_tokens: int  # Number of tokens excluding padding.
+    num_actual_kv_tokens: int
     max_query_len: int
     query_start_loc: torch.Tensor
     max_seq_len: int
     seq_lens: torch.Tensor
     slot_mapping: torch.Tensor
     block_table: torch.Tensor
+    cu_seq_lens: Optional[torch.Tensor]

     # For cascade attention.
     use_cascade: bool
@@ -272,6 +274,20 @@ def build(self,
         seq_lens = common_attn_metadata.seq_lens
         block_table_tensor = common_attn_metadata.block_table_tensor
         slot_mapping = common_attn_metadata.slot_mapping
+        if max_query_len > 1:
+            # We pre-compute cumulative seq len needed for prefill attention
+            # here to avoid recomputing it for every layer
+            cu_seq_lens = torch.zeros(seq_lens.shape[0] + 1,
+                                      dtype=torch.int32,
+                                      device=seq_lens.device)
+            torch.cumsum(seq_lens,
+                         dim=0,
+                         dtype=cu_seq_lens.dtype,
+                         out=cu_seq_lens[1:])
+            num_actual_kv_tokens = int(cu_seq_lens[-1].item())
+        else:
+            cu_seq_lens = None
+            num_actual_kv_tokens = 0

         def schedule(batch_size, cu_query_lens, max_query_len, seqlens,
                      max_seq_len, causal):
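For reference, the `cu_seq_lens` buffer built above is just an int32 exclusive prefix sum over `seq_lens`, with `num_actual_kv_tokens` read from its last element. A minimal standalone sketch of the same computation (the sequence lengths below are made up for illustration):

```python
import torch

# Illustrative per-request KV lengths for a mixed prefill batch.
seq_lens = torch.tensor([5, 17, 3], dtype=torch.int32)

# cu_seq_lens[i] is the offset where request i's KV tokens begin;
# cu_seq_lens[-1] is the total KV token count for the batch.
cu_seq_lens = torch.zeros(seq_lens.shape[0] + 1, dtype=torch.int32)
torch.cumsum(seq_lens, dim=0, dtype=cu_seq_lens.dtype, out=cu_seq_lens[1:])

num_actual_kv_tokens = int(cu_seq_lens[-1].item())
print(cu_seq_lens.tolist())   # [0, 5, 22, 25]
print(num_actual_kv_tokens)   # 25
```

Doing this once in `build()` means the per-layer `forward()` calls further down no longer need to allocate or compute anything for the KV prefix sums.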
@@ -281,12 +297,14 @@ def schedule(batch_size, cu_query_lens, max_query_len, seqlens,

         attn_metadata = AiterFlashAttentionMetadata(
             num_actual_tokens=num_actual_tokens,
+            num_actual_kv_tokens=num_actual_kv_tokens,
             max_query_len=max_query_len,
             query_start_loc=query_start_loc,
             max_seq_len=max_seq_len,
             seq_lens=seq_lens,
             block_table=block_table_tensor,
             slot_mapping=slot_mapping,
+            cu_seq_lens=cu_seq_lens,
             use_cascade=use_cascade,
             common_prefix_len=common_prefix_len,
             total_tokens=self.total_tokens,
@@ -475,16 +493,6 @@ def forward(
         block_table = attn_metadata.block_table

         if max_seqlen_q > 1:
-
-            cu_seq_lens = torch.zeros(seqused_k.shape[0] + 1,
-                                      dtype=torch.int32,
-                                      device=query.device)
-
-            torch.cumsum(seqused_k,
-                         dim=0,
-                         dtype=cu_seq_lens.dtype,
-                         out=cu_seq_lens[1:])
-
             torch.ops.vllm.flash_attn_varlen_func(
                 query[:num_actual_tokens],
                 key_cache,
@@ -497,10 +505,10 @@ def forward(
                 alibi_slopes=self.alibi_slopes,
                 window_size=self.sliding_window,
                 block_table=block_table,
-                cu_seqlens_k=cu_seq_lens,
+                cu_seqlens_k=attn_metadata.cu_seq_lens,
                 k_scale=layer._k_scale,
                 v_scale=layer._v_scale,
-                total_tokens=attn_metadata.total_tokens,
+                total_tokens=attn_metadata.num_actual_kv_tokens,
             )

         _, num_heads, head_size = query.shape
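Taken together, the change moves the prefix-sum work from every layer's forward pass into the single per-step metadata build. The toy sketch below shows that build-once / read-many pattern; the `ToyMetadata` class, `build()` helper, and tensor values are stand-ins invented for illustration, not vLLM APIs:

```python
from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class ToyMetadata:
    # Stand-in for the two fields this diff adds to AiterFlashAttentionMetadata.
    num_actual_kv_tokens: int
    cu_seq_lens: Optional[torch.Tensor]


def build(seq_lens: torch.Tensor, max_query_len: int) -> ToyMetadata:
    # Runs once per scheduler step, mirroring the new build() logic above.
    if max_query_len > 1:
        cu = torch.zeros(seq_lens.shape[0] + 1, dtype=torch.int32,
                         device=seq_lens.device)
        torch.cumsum(seq_lens, dim=0, dtype=cu.dtype, out=cu[1:])
        return ToyMetadata(int(cu[-1].item()), cu)
    # Decode-only batch: the prefill kernel is never invoked, so skip the work.
    return ToyMetadata(0, None)


meta = build(torch.tensor([5, 17, 3], dtype=torch.int32), max_query_len=4)

# Every attention layer now reads the same precomputed values instead of
# re-allocating and re-running the cumsum, as the removed forward() code did.
for _ in range(32):  # e.g. 32 decoder layers sharing one metadata object
    cu_seqlens_k = meta.cu_seq_lens       # what forward() passes as cu_seqlens_k
    total_tokens = meta.num_actual_kv_tokens  # what it passes as total_tokens

assert total_tokens == 25
```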