@@ -1354,11 +1354,8 @@ def _extract_prefill_batch_contents(self, num_prefills, num_decodes,
            # no real prefill batches
            num_prefill_batches = 0

-        num_pad = self.get_dp_padding(num_prefill_batches)
-        if num_pad > 0:
-            for _ in range(num_pad):
-                all_batch_contents.append(BatchContents())
-        return all_batch_contents
+        num_pad_across_dp = self.get_dp_padding(num_prefill_batches)
+        return all_batch_contents, num_pad_across_dp

    def _make_attn_bias(self, context_groups, token_groups):
        dtype = self.dtype
@@ -1426,11 +1423,6 @@ def _form_prefill_batch(self, contents):
        target_bs, target_seq, target_blocks = self._get_prompt_bucketing_fn()(
            query_lens, num_context_blocks)

-        # dp aware padding
-        target_bs += self.get_dp_padding(target_bs)
-        target_seq += self.get_dp_padding(target_seq)
-        target_blocks += self.get_dp_padding(target_blocks)
-
        # NOTE: If model does not support multimodal inputs, we pad here.
        # For models with multimodal support, we may want to get embeddings
        # for the valid tokens before padding.
@@ -1523,33 +1515,49 @@ def _form_prefill_batch(self, contents):
            logits_indices=[logits_indices],
            logits_requests=[logits_requests])

+    def _create_dummy_prefill_batch_contents(
+            self, num_prefills: int) -> list[PrefillInputData]:
+        req_id = -1
+        context_len = 0
+        query_len = 128
+        prompt_tokens = 128
+        token_ids = list(int(i) for i in range(prompt_tokens))
+        num_blocks = round_up(context_len + query_len,
+                              self.block_size) // self.block_size
+        blocks = [0] * num_blocks
+        num_output_logits = context_len + query_len - prompt_tokens + 1
+        logits_positions = list(range(query_len - num_output_logits,
+                                      query_len))
+
+        new_batch_contents = BatchContents(
+            req_ids=[req_id],
+            token_ids=[token_ids],
+            context_lens=[context_len],
+            blocks=[blocks],
+            logits_positions=[logits_positions],
+        )
+
+        outputs = [
+            self._form_prefill_batch(new_batch_contents)
+            for _ in range(num_prefills)
+        ]
+        return outputs
+
    def _prepare_prefill_inputs(
            self, num_prefills, num_decodes,
-            num_scheduled_tokens: list[int]) -> PrefillInputData:
-        all_batch_contents = self._extract_prefill_batch_contents(
+            num_scheduled_tokens: list[int]) -> tuple[PrefillInputData, int]:
+        all_batch_contents, num_pad_across_dp = self._extract_prefill_batch_contents(
            num_prefills, num_decodes, num_scheduled_tokens)
        all_batches = [
            self._form_prefill_batch(bc) for bc in all_batch_contents
        ]
        merge_contents(all_batches[0], *all_batches[1:])
-        return all_batches[0]
-
-    def _prepare_decode_inputs(self, num_decodes,
-                               num_scheduled_tokens) -> DecodeInputData:
-        # Decodes run as one single padded batch with shape [batch, 1]
-        #
-        # We need to set _PAD_SLOT_ID for the padding tokens in the
-        # slot_mapping, such that the attention KV cache insertion
-        # logic knows to ignore those indicies. Otherwise, the
-        # padding data can be dummy since we have a causal mask.
-
-        block_table_cpu_tensor = self.input_batch.block_table[
-            0].get_cpu_tensor()
-        if num_decodes == 0:
-            return DecodeInputData(num_decodes=0)
-        # BLOCK_TABLE [batch, max_num_blocks_per_req]
-        context_lens = self.input_batch.num_computed_tokens_cpu[:num_decodes]
+        return all_batches[0], num_pad_across_dp

+    def _create_decode_input_data(
+            self, num_decodes, num_scheduled_tokens, context_lens,
+            block_table_cpu_tensor, num_computed_tokens_cpu,
+            token_ids_cpu) -> tuple[DecodeInputData, int]:
        # NOTE(kzawora): the +1 is what causes this entire thing to work,
        # as in the paged attention, we don't fetch just the context from cache,
        # but also kvs for the current token
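
As a side note on the dummy prefill sizing added in this hunk: with context_len = 0 and query_len = prompt_tokens = 128, the dummy batch reduces to a single zero-filled block and a single output logit. A quick check of that arithmetic (block_size = 128 is an assumed value here; the real one comes from the runner's cache config):

from math import ceil

def round_up(value, multiple):  # same semantics as the round_up helper used above
    return ceil(value / multiple) * multiple

context_len, query_len, prompt_tokens, block_size = 0, 128, 128, 128
num_blocks = round_up(context_len + query_len, block_size) // block_size  # -> 1
num_output_logits = context_len + query_len - prompt_tokens + 1           # -> 1
logits_positions = list(range(query_len - num_output_logits, query_len))  # -> [127]
print(num_blocks, num_output_logits, logits_positions)
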
@@ -1561,8 +1569,9 @@ def _prepare_decode_inputs(self, num_decodes,
        padded_batch_size = self.bucketing_manager.find_decode_bucket(
            num_decodes, sum(num_blocks))[0]

-        # # dp aware padding
-        padded_batch_size += self.get_dp_padding(padded_batch_size)
+        # dp aware padding
+        num_pad_across_dp = self.get_dp_padding(padded_batch_size)
+        padded_batch_size += num_pad_across_dp

        block_tables_list = []
        for i, n in enumerate(num_blocks):
@@ -1574,8 +1583,7 @@ def _prepare_decode_inputs(self, num_decodes,
        # We slice at the end, since we use the positions for gathering.
        positions = torch.zeros((padded_batch_size, 1), dtype=torch.int32)
        positions[:num_decodes] = torch.from_numpy(
-            self.input_batch.num_computed_tokens_cpu.reshape(-1,
-                                                             1)[:num_decodes])
+            num_computed_tokens_cpu.reshape(-1, 1)[:num_decodes])
        positions = positions[:padded_batch_size]

        padded_index = torch.zeros((padded_batch_size, 1), dtype=torch.int64)
@@ -1613,11 +1621,8 @@ def _prepare_decode_inputs(self, num_decodes,

        # TOKEN_IDS. [batch, 1]
        token_ids = torch.zeros((padded_batch_size, 1), dtype=torch.int32)
-        token_ids[:num_decodes] = torch.gather(input=torch.from_numpy(
-            self.input_batch.token_ids_cpu),
-                                               dim=1,
-                                               index=index)
-
+        token_ids[:num_decodes] = torch.gather(
+            input=torch.from_numpy(token_ids_cpu), dim=1, index=index)
        # SLOT_MAPPING [batch, 1]
        # The "slot" is the "physical index" of a token in the KV cache.
        # Look up the block_idx in the block table (logical<>physical map)
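
The two comments above describe the slot lookup without showing it; the computation is the usual paged-KV mapping, sketched below with made-up values. This is an illustration, not the code from this file.

import torch

block_size = 128                                    # assumed block size
positions = torch.tensor([[5], [260]])              # current decode position per request
block_table = torch.tensor([[3, 7, 9], [4, 8, 2]])  # logical -> physical block ids

block_idx = positions // block_size                 # which logical block holds the token
block_off = positions % block_size                  # offset inside that block
physical_block = torch.gather(block_table, 1, block_idx)
slot_mapping = physical_block * block_size + block_off  # "physical index" into the KV cache
print(slot_mapping)  # tensor([[389], [260]])
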
@@ -1684,7 +1689,42 @@ def _prepare_decode_inputs(self, num_decodes,
                num_decode_tokens=num_decode_tokens_device,
                slot_mapping=slot_mapping_device,
                block_size=self.block_size,
-            ))
+            )), num_pad_across_dp
+
+    def _prepare_decode_inputs(
+            self, num_decodes,
+            num_scheduled_tokens) -> tuple[DecodeInputData, int]:
+        # Decodes run as one single padded batch with shape [batch, 1]
+        #
+        # We need to set _PAD_SLOT_ID for the padding tokens in the
+        # slot_mapping, such that the attention KV cache insertion
+        # logic knows to ignore those indicies. Otherwise, the
+        # padding data can be dummy since we have a causal mask.
+
+        num_pad_across_dp = self.get_dp_padding(num_decodes)
+        if num_decodes == 0:
+            return DecodeInputData(num_decodes=0), num_pad_across_dp
+        # BLOCK_TABLE [batch, max_num_blocks_per_req]
+        context_lens = self.input_batch.num_computed_tokens_cpu[:num_decodes]
+        block_table_cpu_tensor = self.input_batch.block_table[
+            0].get_cpu_tensor()
+        return self._create_decode_input_data(
+            num_decodes, num_scheduled_tokens, context_lens,
+            block_table_cpu_tensor, self.input_batch.num_computed_tokens_cpu,
+            self.input_batch.token_ids_cpu)
+
+    def _create_dummy_decode_input_data(self) -> DecodeInputData:
+        # create dummy decode input data with batch size 1
+        context_lens = [128]
+        block_table_cpu_tensor = torch.zeros([self._PAD_BLOCK_ID],
+                                              dtype=torch.int32).reshape(1, -1)
+        num_computed_tokens_cpu = np.array([128], dtype=np.int32)
+        token_ids = np.array(list(int(i) for i in range(context_lens[0])))
+
+        return self._create_decode_input_data(1, [1], context_lens,
+                                               block_table_cpu_tensor,
+                                               num_computed_tokens_cpu,
+                                               token_ids)[0]

    def _prepare_inputs(
        self,
@@ -1740,18 +1780,7 @@ def get_dp_padding(self,
        dp_size = self.vllm_config.parallel_config.data_parallel_size
        dp_rank = self.vllm_config.parallel_config.data_parallel_rank

-        # For DP: Don't pad when setting enforce_eager.
-        # This lets us set enforce_eager on the prefiller in a P/D setup and
-        # still use CUDA graphs (enabled by this padding) on the decoder.
-        #
-        # TODO(tms) : There are many cases where padding is enabled for
-        # prefills, causing unnecessary and excessive padding of activations.
-
-        # skip padding for non PD disagg case to avoid padding on prefill batch
-        # size and decode batch size
-        if dp_size == 1 or self.vllm_config.model_config.enforce_eager or (
-                self.vllm_config.kv_transfer_config is None
-                or self.vllm_config.kv_transfer_config.kv_connector is None):
+        if dp_size == 1:
            return 0

        num_tokens_across_dp = DPMetadata.num_tokens_across_dp(
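
With the guard reduced to dp_size == 1, every multi-rank DP run now takes the padding path, and the branches added elsewhere in this diff decide whether that padding turns into dummy batches. The remainder of get_dp_padding is outside this hunk; a rough sketch of what it is expected to compute, inferred from the num_tokens_across_dp call in the context line above (an assumption, not the actual body):

def get_dp_padding_sketch(num_tokens: int, num_tokens_across_dp: list[int]) -> int:
    # pad the local batch up to the largest batch in the DP group so that all
    # ranks run graphs of identical shape and collectives stay in lockstep
    max_tokens_across_dp = max(num_tokens_across_dp)
    return max_tokens_across_dp - num_tokens

print(get_dp_padding_sketch(2, [2, 5, 3]))  # the rank with 2 tokens pads by 3
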
@@ -1768,7 +1797,6 @@ def _execute_model_generic(self,
                               warmup_mode=False,
                               inputs_embeds=None,
                               model_mm_kwargs=None):
-
        # FORWARD.
        batch_size = token_ids.size(0)
        seq_len = self._seq_len(attn_metadata)
@@ -2057,8 +2085,10 @@ def execute_model(
        num_prefills = len(pd_info.prompt_req_ids)
        num_reqs = num_decodes + num_prefills
        with self.profiler.record_event('internal', 'prepare_input_tensors'):
-            prefill_data, decode_data = self._prepare_inputs(
+            prefill_input_data, decode_input_data = self._prepare_inputs(
                scheduler_output, num_prefills, num_decodes)
+            prefill_data, num_pad_prefill_batch_across_dp = prefill_input_data
+            decode_data, num_pad_decode_batch_across_dp = decode_input_data
        #FIXME(kzawora): Currently there's no handling of logprobs. Fix that
        # later.
        prefill_sampled_token_ids = []
@@ -2124,6 +2154,7 @@ def execute_model(
                    model_mm_kwargs=model_mm_kwargs,
                    warmup_mode=warmup_mode)
                htorch.core.mark_step()
+
                # Skip separate sampling for structured output
                if structured_output:
                    logits_prompt.append(logits_device)
@@ -2154,9 +2185,27 @@ def execute_model(
                        prompt_batch_idx=idx,
                        is_prompt=True)
                    self.profiler.record_counter(self.event_start, counters)
+
            if self.is_driver_worker and self.profiler.enabled:
                self.profiler_counter_helper.reset_prompt_seq_stats()

+        else:
+            if num_pad_prefill_batch_across_dp > 0:
+                htorch.core.mark_step()
+                dummy_prefill_input_data_list = self._create_dummy_prefill_batch_contents(
+                    num_pad_prefill_batch_across_dp)
+                for dummy_prefill_input_data in dummy_prefill_input_data_list:
+                    htorch.core.mark_step()
+                    _, dummy_logits_device = \
+                        self._execute_model_generic(
+                            dummy_prefill_input_data.token_ids[0],
+                            dummy_prefill_input_data.position_ids[0],
+                            dummy_prefill_input_data.attn_metadata[0],
+                            dummy_prefill_input_data.logits_indices[0],
+                            self.kv_caches,
+                            warmup_mode=warmup_mode)
+                    htorch.core.mark_step()
+
        ######################### DECODES #########################
        # Decodes run as one single batch with [padded_decode_bs, 1]
        if num_decodes > 0:
@@ -2205,6 +2254,19 @@ def execute_model(
                    prompt_batch_idx=None,
                    is_prompt=False)
                self.profiler.record_counter(self.event_start, counters)
+        else:
+            if num_pad_decode_batch_across_dp > 0:
+                dummy_decode_input_data = self._create_dummy_decode_input_data(
+                )
+                htorch.core.mark_step()
+                _, dummy_logits_device = self._execute_model_generic(
+                    dummy_decode_input_data.token_ids,
+                    dummy_decode_input_data.position_ids,
+                    dummy_decode_input_data.attn_metadata,
+                    dummy_decode_input_data.logits_indices,
+                    self.kv_caches,
+                    warmup_mode=warmup_mode)
+                htorch.core.mark_step()

        if structured_output:
            # Scheduler places cached before prompt
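
Taken together with the prefill-side else branch above, the net effect is that a DP rank with no real work still issues a forward pass whenever a peer rank does. A condensed, runnable illustration of that branch structure (the num_* names mirror the diff; the print calls stand in for the real execution):

def step(num_prefills, num_pad_prefill_batch_across_dp,
         num_decodes, num_pad_decode_batch_across_dp):
    if num_prefills > 0:
        print("run real prefill batches")
    elif num_pad_prefill_batch_across_dp > 0:
        # no local prompts, but another DP rank has some: run dummy prefills
        print(f"run {num_pad_prefill_batch_across_dp} dummy prefill batch(es)")
    if num_decodes > 0:
        print("run real decode batch")
    elif num_pad_decode_batch_across_dp > 0:
        # batch-size-1 dummy decode keeps the decode-side collectives aligned
        print("run dummy decode batch")

step(num_prefills=0, num_pad_prefill_batch_across_dp=1,
     num_decodes=4, num_pad_decode_batch_across_dp=0)
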
@@ -2315,6 +2377,7 @@ def execute_model(
            prompt_logprobs_dict=prompt_logprobs_dict,  # type: ignore[arg-type]
            pooler_output=[],
        )
+
        return model_runner_output

    def load_model(self) -> None:
@@ -2735,8 +2798,8 @@ def __del__(self):

    @torch.inference_mode()
    def profile_run(self) -> None:
+        return
        """Profile to measure peak memory during forward pass."""
-
        # use an empty tensor instead of `None`` to force Dynamo to pass
        # it by reference, rather by specializing on the value `None`.
        # the `dtype` argument does not matter, and we use `float32` as
@@ -2750,7 +2813,6 @@ def profile_run(self) -> None:
        max_seq_len = math.ceil(
            (self.max_num_tokens // self.max_prefill_batch_size) /
            self.block_size) * self.block_size
-        max_seq_len = min(max_seq_len, self.max_model_len)
        self._execute_dummy_scenario(
            (self.max_prefill_batch_size, max_seq_len, 0), None)
