@@ -1587,7 +1587,7 @@ def _form_prefill_batch(self, contents):
 
     def _create_dummy_prefill_batch_contents(
             self, num_prefills: int) -> list[PrefillInputData]:
-        req_id = -1
+        req_id = str(-1)
         context_len = 0
         query_len = 128
         prompt_tokens = 128
@@ -1616,26 +1616,30 @@ def _create_dummy_prefill_batch_contents(
     def _prepare_prefill_inputs(
             self, num_prefills, num_decodes, num_scheduled_tokens: list[int]
     ) -> tuple[PrefillInputData, Optional[PrefillInputData]]:
-        all_batch_contents, num_pad_across_dp = self._extract_prefill_batch_contents(
-            num_prefills, num_decodes, num_scheduled_tokens)
+        all_batch_contents, num_pad_across_dp = \
+            self._extract_prefill_batch_contents(
+                num_prefills, num_decodes, num_scheduled_tokens)
         all_batches = [
             self._form_prefill_batch(bc) for bc in all_batch_contents
         ]
         merge_contents(all_batches[0], *all_batches[1:])
 
         dummy_prefill_input_batches = None
         if num_pad_across_dp > 0:
-            dummy_prefill_input_batches = self._create_dummy_prefill_batch_contents(
-                num_pad_across_dp)
+            dummy_prefill_input_batches = \
+                self._create_dummy_prefill_batch_contents(num_pad_across_dp)
             merge_contents(dummy_prefill_input_batches[0],
                            *dummy_prefill_input_batches[1:])
         return all_batches[0], dummy_prefill_input_batches[
             0] if dummy_prefill_input_batches else None
 
     def _create_decode_input_data(
-            self, num_decodes, num_scheduled_tokens, context_lens,
-            block_table_cpu_tensor, num_computed_tokens_cpu,
-            token_ids_cpu) -> tuple[DecodeInputData, int]:
+            self,
+            num_decodes,
+            num_scheduled_tokens,
+            context_lens,
+            block_table_cpu_tensor,
+            scheduler_output=None) -> tuple[DecodeInputData, int]:
         # NOTE(kzawora): the +1 is what causes this entire thing to work,
         # as in the paged attention, we don't fetch just the context from cache,
         # but also kvs for the current token
@@ -1842,7 +1846,10 @@ def _create_decode_input_data(
             spec_decode_metadata=spec_decode_metadata), num_pad_across_dp
 
     def _prepare_decode_inputs(
-            self, num_decodes, num_scheduled_tokens
+            self,
+            num_decodes,
+            num_scheduled_tokens,
+            scheduler_output=None
     ) -> tuple[DecodeInputData, Optional[DecodeInputData]]:
         # Decodes run as one single padded batch with shape [batch, 1]
         #
@@ -1861,9 +1868,7 @@ def _prepare_decode_inputs(
         return self._create_decode_input_data(
             num_decodes, num_scheduled_tokens,
             self.input_batch.num_computed_tokens_cpu[:num_decodes],
-            self.input_batch.block_table[0].get_cpu_tensor(),
-            self.input_batch.num_computed_tokens_cpu,
-            self.input_batch.token_ids_cpu)
+            self.input_batch.block_table[0].get_cpu_tensor(), scheduler_output)
 
     def _create_dummy_decode_input_data(self) -> DecodeInputData:
         # create dummy decode input data with batch size 1
@@ -1872,12 +1877,13 @@ def _create_dummy_decode_input_data(self) -> DecodeInputData:
         context_lens = [128]
         block_table_cpu_tensor = torch.zeros([self._PAD_BLOCK_ID],
                                              dtype=torch.int32).reshape(1, -1)
-        num_computed_tokens_cpu = np.array([128], dtype=np.int32)
-        token_ids = np.array(list(int(i) for i in range(context_lens[0])))
+        # num_computed_tokens_cpu = np.array([128], dtype=np.int32)
+        # token_ids = np.array(list(int(i) for i in range(context_lens[0])))
 
-        return self._create_decode_input_data(
-            num_dummy_decodes, num_dummy_scheduled_tokens, context_lens,
-            block_table_cpu_tensor, num_computed_tokens_cpu, token_ids)[0]
+        return self._create_decode_input_data(num_dummy_decodes,
+                                              num_dummy_scheduled_tokens,
+                                              context_lens,
+                                              block_table_cpu_tensor)[0]
 
     def _get_cumsum_and_arange(
         self,
@@ -2052,8 +2058,7 @@ def _check_config(self, batch_size, seq_len, num_blocks, attn_metadata,
         if not seen and not warmup_mode:
             logger.warning("Configuration: %s was not warmed-up!", cfg)
 
-    def get_dp_padding(self,
-                       num_tokens: int) -> tuple[int, Optional[torch.Tensor]]:
+    def get_dp_padding(self, num_tokens: int) -> int:
         dp_size = self.vllm_config.parallel_config.data_parallel_size
         dp_rank = self.vllm_config.parallel_config.data_parallel_rank
 
@@ -2364,9 +2369,11 @@ def execute_model(
         with self.profiler.record_event('internal', 'prepare_input_tensors'):
             prefill_input_data, decode_input_data = self._prepare_inputs(
                 scheduler_output, num_prefills, num_decodes)
-        prefill_data, dummy_prefill_input_data_batches_across_dp = prefill_input_data
-        num_pad_prefill_batch_across_dp = 0 if dummy_prefill_input_data_batches_across_dp is None else len(
-            dummy_prefill_input_data_batches_across_dp.request_ids)
+        prefill_data, \
+            dummy_prefill_input_data_batches_across_dp = prefill_input_data
+        num_pad_prefill_batch_across_dp = \
+            0 if dummy_prefill_input_data_batches_across_dp is None \
+            else len(dummy_prefill_input_data_batches_across_dp.request_ids)
         decode_data, dummy_decode_input_data_across_dp = decode_input_data
         #FIXME(kzawora): Currently there's no handling of logprobs. Fix that
         # later.
@@ -2477,7 +2484,7 @@ def execute_model(
                     zip(*shallow_tuple(
                         dummy_prefill_input_data_batches_across_dp))):
                 htorch.core.mark_step()
-                _, dummy_logits_device = \
+                _, _, dummy_logits_device = \
                     self._execute_model_generic(
                         token_ids,
                         position_ids,
@@ -2566,7 +2573,7 @@ def execute_model(
         else:
             if dummy_decode_input_data_across_dp is not None:
                 htorch.core.mark_step()
-                _, dummy_logits_device = self._execute_model_generic(
+                _, _, dummy_logits_device = self._execute_model_generic(
                     dummy_decode_input_data_across_dp.token_ids,
                     dummy_decode_input_data_across_dp.position_ids,
                     dummy_decode_input_data_across_dp.attn_metadata,