@@ -1423,6 +1423,10 @@ def _form_prefill_batch(self, contents):
         target_bs, target_seq, target_blocks = self._get_prompt_bucketing_fn()(
             query_lens, num_context_blocks)
 
+        target_bs += self.get_dp_padding(target_bs)
+        target_seq += self.get_dp_padding(target_seq)
+        target_blocks += self.get_dp_padding(target_blocks)
+
         # NOTE: If model does not support multimodal inputs, we pad here.
         # For models with multimodal support, we may want to get embeddings
         # for the valid tokens before padding.
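
For readers unfamiliar with the calls added above: `get_dp_padding` pads a rank-local dimension up to the maximum across all data-parallel ranks, so every rank runs identically shaped (bucketed) graphs. Below is a minimal, self-contained sketch of that idea; the `dp_padding` helper and the explicit per-rank list are illustrative stand-ins for the real collective-based implementation.

# Illustrative only: the real get_dp_padding obtains the cross-rank maximum
# via a collective op; here the per-rank values are passed in explicitly.
def dp_padding(local_value: int, values_across_ranks: list[int]) -> int:
    """Return how much to add so local_value reaches the DP-wide maximum."""
    return max(values_across_ranks) - local_value

# Example: three DP ranks bucketed their prefill batch size to 2, 4 and 3.
per_rank_bs = [2, 4, 3]
padded = [bs + dp_padding(bs, per_rank_bs) for bs in per_rank_bs]
assert padded == [4, 4, 4]  # all ranks now run the same batch-size bucket
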
@@ -1544,15 +1548,23 @@ def _create_dummy_prefill_batch_contents(
         return outputs
 
     def _prepare_prefill_inputs(
-            self, num_prefills, num_decodes,
-            num_scheduled_tokens: list[int]) -> tuple[PrefillInputData, int]:
+            self, num_prefills, num_decodes, num_scheduled_tokens: list[int]
+    ) -> tuple[PrefillInputData, Optional[PrefillInputData]]:
         all_batch_contents, num_pad_across_dp = self._extract_prefill_batch_contents(
             num_prefills, num_decodes, num_scheduled_tokens)
         all_batches = [
             self._form_prefill_batch(bc) for bc in all_batch_contents
         ]
         merge_contents(all_batches[0], *all_batches[1:])
-        return all_batches[0], num_pad_across_dp
+
+        dummy_prefill_input_batches = None
+        if num_pad_across_dp > 0:
+            dummy_prefill_input_batches = self._create_dummy_prefill_batch_contents(
+                num_pad_across_dp)
+            merge_contents(dummy_prefill_input_batches[0],
+                           *dummy_prefill_input_batches[1:])
+        return all_batches[0], dummy_prefill_input_batches[
+            0] if dummy_prefill_input_batches else None
 
     def _create_decode_input_data(
             self, num_decodes, num_scheduled_tokens, context_lens,
@@ -1692,8 +1704,8 @@ def _create_decode_input_data(
         )), num_pad_across_dp
 
     def _prepare_decode_inputs(
-            self, num_decodes,
-            num_scheduled_tokens) -> tuple[DecodeInputData, int]:
+            self, num_decodes, num_scheduled_tokens
+    ) -> tuple[DecodeInputData, Optional[DecodeInputData]]:
         # Decodes run as one single padded batch with shape [batch, 1]
         #
         # We need to set _PAD_SLOT_ID for the padding tokens in the
@@ -1703,35 +1715,38 @@ def _prepare_decode_inputs(
 
         num_pad_across_dp = self.get_dp_padding(num_decodes)
         if num_decodes == 0:
-            return DecodeInputData(num_decodes=0), num_pad_across_dp
-        # BLOCK_TABLE [batch, max_num_blocks_per_req]
-        context_lens = self.input_batch.num_computed_tokens_cpu[:num_decodes]
-        block_table_cpu_tensor = self.input_batch.block_table[
-            0].get_cpu_tensor()
+            if num_pad_across_dp > 0:
+                dummy_decode_input_data = self._create_dummy_decode_input_data(
+                )
+                return DecodeInputData(num_decodes=0), dummy_decode_input_data
+            return DecodeInputData(num_decodes=0), None
         return self._create_decode_input_data(
-            num_decodes, num_scheduled_tokens, context_lens,
-            block_table_cpu_tensor, self.input_batch.num_computed_tokens_cpu,
+            num_decodes, num_scheduled_tokens,
+            self.input_batch.num_computed_tokens_cpu[:num_decodes],
+            self.input_batch.block_table[0].get_cpu_tensor(),
+            self.input_batch.num_computed_tokens_cpu,
             self.input_batch.token_ids_cpu)
 
     def _create_dummy_decode_input_data(self) -> DecodeInputData:
         # create dummy decode input data with batch size 1
+        num_dummy_decodes = 1
+        num_dummy_scheduled_tokens = [1]
         context_lens = [128]
         block_table_cpu_tensor = torch.zeros([self._PAD_BLOCK_ID],
                                              dtype=torch.int32).reshape(1, -1)
         num_computed_tokens_cpu = np.array([128], dtype=np.int32)
         token_ids = np.array(list(int(i) for i in range(context_lens[0])))
 
-        return self._create_decode_input_data(1, [1], context_lens,
-                                               block_table_cpu_tensor,
-                                               num_computed_tokens_cpu,
-                                               token_ids)[0]
+        return self._create_decode_input_data(
+            num_dummy_decodes, num_dummy_scheduled_tokens, context_lens,
+            block_table_cpu_tensor, num_computed_tokens_cpu, token_ids)[0]
 
     def _prepare_inputs(
         self,
         scheduler_output: "SchedulerOutput",
         num_prefills,
         num_decodes,
-    ) -> tuple[PrefillInputData, Optional[DecodeInputData]]:
+    ):
 
         total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
         assert total_num_scheduled_tokens > 0
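
The rewritten early-return path above encodes a simple contract: a rank with no decode work of its own must still hand back a dummy decode batch whenever `get_dp_padding` reports that peer DP ranks will run a decode step, so every rank issues the same number of forward passes and their collectives stay matched. A rough standalone sketch of that decision follows; the function and return values are purely illustrative and not part of the runner.

def plan_decode_step(num_local_decodes: int, num_pad_across_dp: int) -> str:
    """Illustrative: what this rank should run for the decode phase."""
    if num_local_decodes == 0:
        # No local work: run a dummy batch only if peer ranks will decode,
        # so every DP rank issues the same number of forward passes.
        return "dummy" if num_pad_across_dp > 0 else "skip"
    return "real"

assert plan_decode_step(0, 0) == "skip"   # no rank decodes this step
assert plan_decode_step(0, 2) == "dummy"  # peers decode, we pad with a dummy
assert plan_decode_step(3, 0) == "real"   # normal decode batch
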
@@ -2087,8 +2102,10 @@ def execute_model(
         with self.profiler.record_event('internal', 'prepare_input_tensors'):
             prefill_input_data, decode_input_data = self._prepare_inputs(
                 scheduler_output, num_prefills, num_decodes)
-        prefill_data, num_pad_prefill_batch_across_dp = prefill_input_data
-        decode_data, num_pad_decode_batch_across_dp = decode_input_data
+        prefill_data, dummy_prefill_input_data_batches_across_dp = prefill_input_data
+        num_pad_prefill_batch_across_dp = 0 if dummy_prefill_input_data_batches_across_dp is None else len(
+            dummy_prefill_input_data_batches_across_dp.request_ids)
+        decode_data, dummy_decode_input_data_across_dp = decode_input_data
         #FIXME(kzawora): Currently there's no handling of logprobs. Fix that
         # later.
         prefill_sampled_token_ids = []
@@ -2154,7 +2171,6 @@ def execute_model(
                         model_mm_kwargs=model_mm_kwargs,
                         warmup_mode=warmup_mode)
                 htorch.core.mark_step()
-
                 # Skip separate sampling for structured output
                 if structured_output:
                     logits_prompt.append(logits_device)
@@ -2191,17 +2207,19 @@ def execute_model(
 
         else:
             if num_pad_prefill_batch_across_dp > 0:
-                htorch.core.mark_step()
-                dummy_prefill_input_data_list = self._create_dummy_prefill_batch_contents(
-                    num_pad_prefill_batch_across_dp)
-                for dummy_prefill_input_data in dummy_prefill_input_data_list:
+                for idx, (
+                        req_id, prompt_len, token_ids, position_ids,
+                        attn_metadata, logits_indices,
+                        logits_requests) in enumerate(
+                            zip(*shallow_tuple(
+                                dummy_prefill_input_data_batches_across_dp))):
                     htorch.core.mark_step()
                     _, dummy_logits_device = \
                         self._execute_model_generic(
-                            dummy_prefill_input_data.token_ids[0],
-                            dummy_prefill_input_data.position_ids[0],
-                            dummy_prefill_input_data.attn_metadata[0],
-                            dummy_prefill_input_data.logits_indices[0],
+                            token_ids,
+                            position_ids,
+                            attn_metadata,
+                            logits_indices,
                             self.kv_caches,
                             warmup_mode=warmup_mode)
                     htorch.core.mark_step()
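
The dummy-prefill loop above reuses the same `zip(*shallow_tuple(...))` idiom as the real prefill loop: the batch container's per-field lists are unpacked so each iteration yields one logical request. Here is a hedged sketch of that iteration pattern, with a stand-in dataclass and helper rather than the real `PrefillInputData` and `shallow_tuple`.

from dataclasses import dataclass, fields

@dataclass
class TinyPrefillBatch:
    # Stand-in container: the real PrefillInputData carries more fields
    # (attn_metadata, logits_indices, ...), but the iteration idea is the same.
    request_ids: list
    token_ids: list
    position_ids: list

def shallow_tuple_like(batch):
    """Return the container's field values as a tuple, in declaration order."""
    return tuple(getattr(batch, f.name) for f in fields(batch))

batch = TinyPrefillBatch(request_ids=["req-a", "req-b"],
                         token_ids=[[1, 2, 3], [4, 5]],
                         position_ids=[[0, 1, 2], [0, 1]])

# zip(*...) transposes "one list per field" into "one tuple per request".
for req_id, toks, pos in zip(*shallow_tuple_like(batch)):
    print(req_id, toks, pos)
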
@@ -2255,15 +2273,13 @@ def execute_model(
                     is_prompt=False)
                 self.profiler.record_counter(self.event_start, counters)
         else:
-            if num_pad_decode_batch_across_dp > 0:
-                dummy_decode_input_data = self._create_dummy_decode_input_data(
-                )
+            if dummy_decode_input_data_across_dp is not None:
                 htorch.core.mark_step()
                 _, dummy_logits_device = self._execute_model_generic(
-                    dummy_decode_input_data.token_ids,
-                    dummy_decode_input_data.position_ids,
-                    dummy_decode_input_data.attn_metadata,
-                    dummy_decode_input_data.logits_indices,
+                    dummy_decode_input_data_across_dp.token_ids,
+                    dummy_decode_input_data_across_dp.position_ids,
+                    dummy_decode_input_data_across_dp.attn_metadata,
+                    dummy_decode_input_data_across_dp.logits_indices,
                     self.kv_caches,
                     warmup_mode=warmup_mode)
                 htorch.core.mark_step()