 from vllm.config import VllmConfig
 from vllm.logger import init_logger
 from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
-                                              CommonAttentionMetadata)
-from vllm.v1.core.sched.output import SchedulerOutput
+                                              CommonAttentionMetadata,
+                                              split_decodes_and_prefills)
 from vllm.v1.kv_cache_interface import AttentionSpec
-from vllm.v1.worker.gpu_input_batch import InputBatch
 
 try:
     import intel_extension_for_pytorch.llm.modules as ipex_modules
@@ -102,16 +101,16 @@ class TorchSDPAMetadata(AttentionMetadata):
     """Metadata for PagedAttention."""
     # (batch_size,). The length of sequences (entire tokens seen so far) per
     # sequence.
-    seq_lens_tensor: Optional[torch.Tensor]
+    decode_seq_lens_tensor: Optional[torch.Tensor]
     # Maximum sequence length in the batch. 0 if it is prefill-only batch.
-    max_decode_seq_len: int
+    decode_max_seq_len: int
     # (batch_size, max_blocks_per_seq).
     # Block addresses per sequence. (Seq id -> list of physical block)
     # E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks
     # in the kv cache. Each block can contain up to block_size tokens.
     # 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph
     # captured.
-    block_tables: Optional[torch.Tensor]
+    decode_block_tables: Optional[torch.Tensor]
     """Metadata for TorchSDPABackend.
     """
     # Currently, input sequences can only contain all prompts
@@ -121,9 +120,9 @@ class TorchSDPAMetadata(AttentionMetadata):
 
     # For chunked prefill only
     max_query_len: Optional[int] = None
-    max_kv_len: Optional[int] = None
+    prefill_max_seq_len: Optional[int] = None
     prefill_query_start_loc: Optional[torch.Tensor] = None
-    kv_start_loc: Optional[torch.Tensor] = None
+    prefill_seq_start_loc: Optional[torch.Tensor] = None
     prefill_block_tables: Optional[torch.Tensor] = None
 
     # For V1 logits index only
@@ -307,8 +306,8 @@ def get_seq_len_block_table_args(
                 or attn_type == AttentionType.ENCODER_ONLY):
             # Decoder self-attention
             # Choose max_seq_len based on whether we are in prompt_run
-            return (self.seq_lens_tensor, self.max_decode_seq_len,
-                    self.block_tables)
+            return (self.decode_seq_lens_tensor, self.decode_max_seq_len,
+                    self.decode_block_tables)
         elif attn_type == AttentionType.ENCODER_DECODER:
             # Enc/dec cross-attention KVs match encoder sequence length;
             # cross-attention utilizes special "cross" block tables
@@ -323,19 +322,14 @@ def get_seq_len_block_table_args(
 
 
 class TorchSDPAMetadataBuilderV1(AttentionMetadataBuilder[TorchSDPAMetadata]):
+    reorder_batch_threshold: int = 1
 
     def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
                  vllm_config: VllmConfig, device: torch.device) -> None:
         super().__init__(kv_cache_spec, layer_names, vllm_config, device)
 
         self.scheduler_config = vllm_config.scheduler_config
-
-        # For reorder
-        self.reorder_prompt_req_index_list = np.empty(
-            vllm_config.scheduler_config.max_num_seqs, dtype=np.int64)
-        self.reorder_decode_req_index_list = np.empty(
-            vllm_config.scheduler_config.max_num_seqs, dtype=np.int64)
-        self.num_prompt_req: int = 0
+        self._init_reorder_batch_threshold(1, False)
 
         self.seq_start_loc_cpu = torch.zeros(
             vllm_config.scheduler_config.max_num_seqs + 1,
@@ -344,50 +338,6 @@ def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
         )
         self.seq_start_loc_np = self.seq_start_loc_cpu.numpy()
 
-    def reorder_batch(self, input_batch: InputBatch,
-                      scheduler_output: SchedulerOutput) -> bool:
-        prompt_list_idx = 0
-        decode_list_idx = 0
-        for req_index in range(input_batch.num_reqs):
-            if input_batch.num_computed_tokens_cpu[
-                    req_index] < input_batch.num_prompt_tokens[req_index]:
-                # prompt stage
-                self.reorder_prompt_req_index_list[prompt_list_idx] = req_index
-                prompt_list_idx += 1
-            else:
-                # decode stage
-                self.reorder_decode_req_index_list[decode_list_idx] = req_index
-                decode_list_idx += 1
-        assert decode_list_idx + prompt_list_idx == input_batch.num_reqs
-
-        # Update prompt requests number
-        self.num_prompt_req = prompt_list_idx
-
-        reorder_req_num = 0
-        for req_index in range(decode_list_idx):
-            if self.reorder_decode_req_index_list[req_index] < prompt_list_idx:
-                reorder_req_num += 1
-            else:
-                break
-
-        if reorder_req_num == 0:
-            return False
-
-        reorder_prompt_list = (
-            self.reorder_prompt_req_index_list[:prompt_list_idx]
-            [-reorder_req_num:])
-        reorder_decode_list = (
-            self.reorder_decode_req_index_list[:decode_list_idx]
-            [:reorder_req_num])
-        assert reorder_decode_list.size == reorder_prompt_list.size
-
-        for idx in range(reorder_req_num):
-            prompt_req_index = reorder_prompt_list[idx].item()
-            decode_req_index = reorder_decode_list[idx].item()
-            input_batch.swap_states(prompt_req_index, decode_req_index)
-
-        return True
-
     def build(self,
               common_prefix_len: int,
              common_attn_metadata: CommonAttentionMetadata,
@@ -397,41 +347,46 @@ def build(self,
 
         seq_lens_cpu = common_attn_metadata.seq_lens_cpu
         seq_lens_np = seq_lens_cpu.numpy()
-        num_prompt_req = self.num_prompt_req
-        max_prefill_seq_len = seq_lens_np[:num_prompt_req].max().item(
-        ) if num_prompt_req > 0 else 0
-        max_decode_seq_len = seq_lens_np[num_prompt_req:num_reqs].max().item(
-        ) if num_prompt_req < num_reqs else 0
-        self.seq_start_loc_np[0] = 0
-        np.cumsum(seq_lens_np, out=self.seq_start_loc_np[1:num_reqs + 1])
 
         query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
-        num_prefill_tokens = int(query_start_loc_cpu[num_prompt_req].item())
-        num_decode_tokens = int(query_start_loc_cpu[num_reqs].item() -
-                                num_prefill_tokens)
+        query_start_loc_np = query_start_loc_cpu.numpy()
+
+        num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = \
+            split_decodes_and_prefills(common_attn_metadata,
+                                       decode_threshold=self.reorder_batch_threshold,
+                                       require_uniform=True)
+
+        max_prefill_seq_len = seq_lens_np[num_decodes:num_reqs].max().item(
+        ) if num_prefills > 0 else 0
+        max_decode_seq_len = seq_lens_np[:num_decodes].max().item(
+        ) if num_prefills < num_reqs else 0
+        self.seq_start_loc_np[0] = 0
+        np.cumsum(seq_lens_np, out=self.seq_start_loc_np[1:num_reqs + 1])
 
         slot_mapping = common_attn_metadata.slot_mapping.long()
         block_table_tensor = common_attn_metadata.block_table_tensor
+        query_start_loc_np = query_start_loc_cpu.numpy()
+        query_start_loc_np[num_decodes:num_reqs + 1] -= num_decode_tokens
 
         attn_metadata = TorchSDPAMetadata(
-            num_prefills=num_prompt_req,
+            num_prefills=num_prefills,
             num_prefill_tokens=num_prefill_tokens,
             num_decode_tokens=num_decode_tokens,
             slot_mapping=slot_mapping,
             # to ensure inference when chunked_prefill is disabled
             seq_lens=seq_lens_cpu.tolist(),
-            seq_lens_tensor=seq_lens_cpu[num_prompt_req:num_reqs],  # decode
-            max_decode_seq_len=max_decode_seq_len,  # decode
-            block_tables=block_table_tensor[num_prompt_req:num_reqs],  # decode
+            decode_seq_lens_tensor=seq_lens_cpu[:num_decodes],  # decode
+            decode_max_seq_len=max_decode_seq_len,  # decode
+            decode_block_tables=block_table_tensor[:num_decodes],  # decode
             chunked_prefill=self.scheduler_config.chunked_prefill_enabled,
             max_query_len=max_query_len,
-            max_kv_len=max_prefill_seq_len,
-            prefill_query_start_loc=query_start_loc_cpu[:num_prompt_req +
+            prefill_max_seq_len=max_prefill_seq_len,
+            prefill_query_start_loc=query_start_loc_cpu[num_decodes:num_reqs +
                                                         1],  # prefill
-            kv_start_loc=self.seq_start_loc_cpu[:num_prompt_req +
-                                                1],  # prefill
-            prefill_block_tables=block_table_tensor[:
-                                                    num_prompt_req],  # prefill
+            prefill_seq_start_loc=self.seq_start_loc_cpu[num_decodes:num_reqs +
+                                                         1],  # prefill
+            prefill_block_tables=block_table_tensor[
+                num_decodes:num_reqs],  # prefill
             query_start_loc=query_start_loc_cpu[:num_reqs +
                                                 1],  # for logits index
         )
@@ -596,14 +551,14 @@ def forward(
                 import intel_extension_for_pytorch.llm.modules as ipex_modules
                 output = torch.empty_like(query)
                 ipex_modules.PagedAttention.flash_attn_varlen_func(
-                    output[:prefill_meta.num_prefill_tokens, :, :],
-                    query[:prefill_meta.num_prefill_tokens, :, :],
+                    output[prefill_meta.num_decode_tokens:, :, :],
+                    query[prefill_meta.num_decode_tokens:, :, :],
                     key_cache,
                     value_cache,
                     prefill_meta.prefill_query_start_loc,
-                    prefill_meta.kv_start_loc,
+                    prefill_meta.prefill_seq_start_loc,
                     prefill_meta.max_query_len,
-                    prefill_meta.max_kv_len,
+                    prefill_meta.prefill_max_seq_len,
                     self.scale,
                     True,
                     prefill_meta.prefill_block_tables,
@@ -621,8 +576,8 @@ def forward(
            ) = decode_meta.get_seq_len_block_table_args(attn_type)
 
            self.paged_attn_impl.forward_decode(
-                output[attn_metadata.num_prefill_tokens:, :, :],
-                query[attn_metadata.num_prefill_tokens:, :, :],
+                output[:attn_metadata.num_decode_tokens, :, :],
+                query[:attn_metadata.num_decode_tokens, :, :],
                key_cache,
                value_cache,
                block_tables_arg,
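
Note on the new batch layout: the diff drops the backend's hand-rolled `reorder_batch` and instead relies on the shared `reorder_batch_threshold` machinery plus `split_decodes_and_prefills`, which assumes the runner has already reordered requests so that all decode requests (query length at or below the threshold, here 1) come first and all prefills follow. The decode-side metadata therefore slices `[:num_decodes]` and the prefill-side metadata slices `[num_decodes:num_reqs]`. The sketch below is only an illustration of that partitioning, not the actual vLLM helper (the real `split_decodes_and_prefills` takes a `CommonAttentionMetadata` and supports `require_uniform`); the function name `split_decodes_and_prefills_sketch` and its list-based signature are hypothetical.

```python
# Minimal sketch (assumed, simplified): partition a decode-first batch given
# query_start_loc with num_reqs + 1 cumulative entries. A request counts as a
# decode when its query length is <= decode_threshold.
def split_decodes_and_prefills_sketch(query_start_loc: list[int],
                                      decode_threshold: int = 1):
    num_reqs = len(query_start_loc) - 1
    num_decodes = 0
    for i in range(num_reqs):
        query_len = query_start_loc[i + 1] - query_start_loc[i]
        if query_len > decode_threshold:
            break  # first prefill found; all later requests are prefills too
        num_decodes += 1
    num_prefills = num_reqs - num_decodes
    num_decode_tokens = query_start_loc[num_decodes]
    num_prefill_tokens = query_start_loc[num_reqs] - num_decode_tokens
    return num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens


# Example: two decode requests (1 query token each) followed by one prefill of
# 8 tokens -> (2, 1, 2, 8), which matches the [:num_decodes] decode slices and
# [num_decodes:num_reqs] prefill slices used when building TorchSDPAMetadata.
print(split_decodes_and_prefills_sketch([0, 1, 2, 10]))
```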