@@ -1152,9 +1152,9 @@ def _form_prefill_batch(self, contents):
             query_lens, num_context_blocks)

         # dp aware padding
-        target_bs = self.get_dp_padding(target_bs)
-        target_seq = self.get_dp_padding(target_seq)
-        target_blocks = self.get_dp_padding(target_blocks)
+        target_bs += self.get_dp_padding(target_bs)
+        target_seq += self.get_dp_padding(target_seq)
+        target_blocks += self.get_dp_padding(target_blocks)

         token_ids = self._align_and_pad(contents.token_ids,
                                         (target_bs, target_seq),
@@ -1273,7 +1273,7 @@ def _prepare_decode_inputs(self, num_decodes,
             num_decodes, sum(num_blocks))[0]

         # dp aware padding
-        padded_batch_size = self.get_dp_padding(padded_batch_size)
+        padded_batch_size += self.get_dp_padding(padded_batch_size)

         block_tables_list = []
         for i, n in enumerate(num_blocks):
@@ -1427,7 +1427,7 @@ def get_dp_padding(self,

         if dp_size == 1 or self.vllm_config.model_config.enforce_eager:
             # Early exit.
-            return 0, None
+            return 0

         num_tokens_across_dp = DPMetadata.num_tokens_across_dp(
             num_tokens, dp_size, dp_rank)
@@ -1436,7 +1436,7 @@ def get_dp_padding(self,
         #     dp_size,
         #     device="cpu",
         #     dtype=torch.int32).item()
-        return max_tokens_across_dp_cpu
+        return max_tokens_across_dp_cpu - num_tokens

     def _execute_model_generic(self,
                                token_ids,
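Taken together, the hunks above change the contract of `get_dp_padding`: instead of returning the already-padded value (or a `(0, None)` tuple on the early-exit path), it now returns the number of extra elements this rank must add to reach the maximum across data-parallel ranks, which is why the callers switch from `x = self.get_dp_padding(x)` to `x += self.get_dp_padding(x)`. A minimal, self-contained sketch of that contract, with a hypothetical `max_across_dp_ranks` helper standing in for the `DPMetadata` collective:

```python
# Sketch only: `other_rank_values` and `max_across_dp_ranks` are hypothetical
# stand-ins for the DPMetadata.num_tokens_across_dp collective in the runner.
def max_across_dp_ranks(local_value: int, other_rank_values: list[int]) -> int:
    # In the real runner this comes from a collective over DP ranks; here we
    # simply take the max over the values the caller supplies.
    return max([local_value, *other_rank_values])


def get_dp_padding(num_tokens: int, other_rank_values: list[int],
                   dp_size: int = 2) -> int:
    if dp_size == 1:
        # Early exit: return the padding delta (0), no longer a (0, None) tuple.
        return 0
    max_tokens = max_across_dp_ranks(num_tokens, other_rank_values)
    # Return how much this rank must pad, not the already-padded size.
    return max_tokens - num_tokens


# Callers now add the delta to their local target, as in the hunks above.
target_bs = 3
target_bs += get_dp_padding(target_bs, other_rank_values=[5, 4])
assert target_bs == 5  # padded up to the largest batch across DP ranks
```

Returning a delta keeps every call site a one-line `+=`, whether the early-exit or the collective path was taken.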
@@ -1643,11 +1643,9 @@ def apply_grammar_bitmask(
             logits_cpu.to(self.device, non_blocking=True).to(logits.dtype))

     @torch.inference_mode()
-    def execute_model(
-        self,
-        scheduler_output: "SchedulerOutput",
-        warmup_mode=False,
-    ) -> ModelRunnerOutput:
+    def execute_model(self,
+                      scheduler_output: "SchedulerOutput",
+                      warmup_mode=False) -> ModelRunnerOutput:
         # NOTE(kzawora): Since scheduler doesn't differentiate between prefills
         # and decodes, we must handle mixed batches. In _update_states we make
         # sure that first self.input_batch.num_decodes requests are decodes,
@@ -1751,8 +1749,12 @@ def execute_model(
             htorch.core.mark_step()
             prefill_hidden_states_ts, logits_device = \
                 self._execute_model_generic(
-                    token_ids, position_ids, attn_metadata, logits_indices,
-                    self.kv_caches, warmup_mode=warmup_mode)
+                    token_ids,
+                    position_ids,
+                    attn_metadata,
+                    logits_indices,
+                    self.kv_caches,
+                    warmup_mode=warmup_mode)
             htorch.core.mark_step()
             # Skip separate sampling for structured output
             if structured_output:
@@ -2477,7 +2479,6 @@ def __del__(self):

     @torch.inference_mode()
     def profile_run(self) -> None:
-        return
         """Profile to measure peak memory during forward pass."""

         # use an empty tensor instead of `None`` to force Dynamo to pass
@@ -2497,10 +2498,14 @@ def profile_run(self) -> None:
         if max_seq_len % self.block_size != 0:
             max_seq_len = ((max_seq_len + self.block_size - 1) //
                            self.block_size) * self.block_size
+        max_seq_len = min(max_seq_len, self.max_model_len)

-        prompt_cfg = (max_prefill_batch_size, max_seq_len, 0)
-        decode_cfg = None
+        # different DP engine may have different config
+        max_seq_len += self.get_dp_padding(max_seq_len)
+        max_prefill_batch_size += self.get_dp_padding(max_prefill_batch_size)

+        prompt_cfg = (max_prefill_batch_size, max_seq_len - 1, 0)
+        decode_cfg = None
         self._execute_dummy_scenario(prompt_cfg, decode_cfg)

         # # Run empty prefill forwards - prefill max batch and prefill max seq
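With `profile_run` no longer short-circuiting, the dummy prefill it executes is also made DP-aware: each quantity grows by the delta from `get_dp_padding`, so a rank whose local limits are smaller still profiles the largest configuration seen across DP engines. A rough sketch of how the padded `prompt_cfg` comes together, using assumed values for `block_size`, `max_model_len`, the batch limit, and the padding deltas:

```python
# Hypothetical stand-ins for self.block_size, self.max_model_len, the prefill
# batch limit, and the deltas another DP rank would induce.
block_size = 128
max_model_len = 2048
max_prefill_batch_size = 2
max_seq_len = 1000

# Round the sequence length up to a block boundary, then clamp to the model limit.
if max_seq_len % block_size != 0:
    max_seq_len = ((max_seq_len + block_size - 1) // block_size) * block_size
max_seq_len = min(max_seq_len, max_model_len)

# DP-aware padding: grow each quantity by the delta needed to match the largest
# value across DP engines (deltas are assumed constants in this sketch).
max_seq_len += 0              # this rank already holds the DP-wide maximum
max_prefill_batch_size += 2   # e.g. another rank profiles with batch size 4

# Mirrors the prompt_cfg built in the diff, including the `max_seq_len - 1`.
prompt_cfg = (max_prefill_batch_size, max_seq_len - 1, 0)
print(prompt_cfg)  # (4, 1023, 0)
```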