@@ -1150,6 +1150,12 @@ def _form_prefill_batch(self, contents):

         target_bs, target_seq, target_blocks = self._get_prompt_bucketing_fn()(
             query_lens, num_context_blocks)
+
+        # dp aware padding
+        target_bs = self.get_dp_padding(target_bs)
+        target_seq = self.get_dp_padding(target_seq)
+        target_blocks = self.get_dp_padding(target_blocks)
+
         token_ids = self._align_and_pad(contents.token_ids,
                                         (target_bs, target_seq),
                                         itertools.repeat(-1))
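The added lines pad every bucketed prefill dimension up to the largest value chosen by any data-parallel rank, so all ranks launch identically shaped graphs; the decode hunk below applies the same rule to its batch-size bucket. A minimal standalone sketch of that padding rule, using a hypothetical list of per-rank values in place of the real get_dp_padding()/DPMetadata collective:

# Hypothetical helper illustrating DP-aware padding; the real code obtains the
# per-rank values through self.get_dp_padding(), which wraps a DP collective.
def pad_to_dp_max(local_value: int, values_on_other_ranks: list[int]) -> int:
    # Every rank pads its local bucket dimension up to the global maximum,
    # so no rank launches a differently shaped graph than its peers.
    return max([local_value, *values_on_other_ranks])

# Example: rank 0 bucketed (bs=2, seq=1024), rank 1 bucketed (bs=4, seq=512);
# after padding, both ranks run a (bs=4, seq=1024) prefill.
assert pad_to_dp_max(2, [4]) == 4
assert pad_to_dp_max(1024, [512]) == 1024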
@@ -1266,6 +1272,9 @@ def _prepare_decode_inputs(self, num_decodes,
         padded_batch_size = self.bucketing_manager.find_decode_bucket(
             num_decodes, sum(num_blocks))[0]

+        # dp aware padding
+        padded_batch_size = self.get_dp_padding(padded_batch_size)
+
         block_tables_list = []
         for i, n in enumerate(num_blocks):
             seq_block_table = block_table_cpu_tensor[i, :n].tolist()
@@ -1365,8 +1374,6 @@ def _prepare_inputs(
         total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
         assert total_num_scheduled_tokens > 0

-        # TODO wuxun: consider dp aware padding for bs, block bucket, etc.
-
         num_reqs = num_prefills + num_decodes

         # Get the number of scheduled tokens for each request.
@@ -1406,7 +1413,6 @@ def _check_config(self, batch_size, seq_len, num_blocks, attn_metadata,
             "Configuration: (%s, %s, %s, %s) was not warmed-up!", phase,
             batch_size, seq_len, num_blocks)

-    # TODO wuxun: dp padding for prefill/decode inputs
    def get_dp_padding(self,
                       num_tokens: int) -> tuple[int, Optional[torch.Tensor]]:
        dp_size = self.vllm_config.parallel_config.data_parallel_size
@@ -1426,11 +1432,11 @@ def get_dp_padding(self,
        num_tokens_across_dp = DPMetadata.num_tokens_across_dp(
            num_tokens, dp_size, dp_rank)
        max_tokens_across_dp_cpu = torch.max(num_tokens_across_dp).item()
-        num_tokens_after_padding = torch.tensor([max_tokens_across_dp_cpu] *
-                                                dp_size,
-                                                device="cpu",
-                                                dtype=torch.int32)
-        return max_tokens_across_dp_cpu - num_tokens, num_tokens_after_padding
+        # num_tokens_after_padding = torch.tensor([max_tokens_across_dp_cpu] *
+        #                                         dp_size,
+        #                                         device="cpu",
+        #                                         dtype=torch.int32).item()
+        return max_tokens_across_dp_cpu

    def _execute_model_generic(self,
                               token_ids,
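With this hunk, get_dp_padding reduces the per-rank token counts to a single global maximum and returns it directly instead of a (pad amount, per-rank tensor) pair, which is why callers now assign its result straight to the bucketed dimension (the declared tuple[int, Optional[torch.Tensor]] return annotation no longer reflects that). A minimal sketch of the reduction, using a made-up tensor in place of the DPMetadata collective:

import torch

# Hypothetical per-rank token counts; in the real code they come from
# DPMetadata.num_tokens_across_dp(num_tokens, dp_size, dp_rank).
num_tokens_across_dp = torch.tensor([96, 128, 64, 112], dtype=torch.int32)

# Same reduction the function performs: every rank learns the global maximum
# and pads its own workload up to it.
max_tokens_across_dp_cpu = torch.max(num_tokens_across_dp).item()
print(max_tokens_across_dp_cpu)  # 128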
@@ -2481,36 +2487,40 @@ def profile_run(self) -> None:
        # it is important to create tensors inside the loop, rather than
        # multiplying the list, to avoid Dynamo from treating them as
        # tensor aliasing.
-        num_layers = self.model_config.get_num_layers(self.parallel_config)
-        kv_caches = [None] * num_layers
-
-        # Run empty prefill forwards - prefill max batch and prefill max seq
-        self.warmup_scenario(batch_size=1,
-                             seq_or_block=self.max_model_len,
-                             is_prompt=True,
-                             kv_caches=kv_caches)
-        max_seq_len = math.ceil(
-            (self.max_num_tokens // self.max_prefill_batch_size) /
-            self.block_size) * self.block_size
-        self.warmup_scenario(batch_size=self.max_prefill_batch_size,
-                             seq_or_block=max_seq_len,
-                             is_prompt=True,
-                             kv_caches=kv_caches)
+        # num_layers = self.model_config.get_num_layers(self.parallel_config)
+        # kv_caches = [None] * num_layers
+
+        max_num_batched_tokens = self.max_num_tokens
+        max_prefill_batch_size = self.max_prefill_batch_size
+        max_seq_len = (max_num_batched_tokens + max_prefill_batch_size -
+                       1) // max_prefill_batch_size
+        if max_seq_len % self.block_size != 0:
+            max_seq_len = ((max_seq_len + self.block_size - 1) //
+                           self.block_size) * self.block_size
+
+        prompt_cfg = (max_prefill_batch_size, max_seq_len, 0)
+        decode_cfg = None
+
+        self._execute_dummy_scenario(prompt_cfg, decode_cfg)
+
+        # # Run empty prefill forwards - prefill max batch and prefill max seq
+        # self.warmup_scenario(batch_size=1,
+        #                      seq_or_block=self.max_model_len,
+        #                      is_prompt=True,
+        #                      kv_caches=kv_caches)
+        # max_seq_len = math.ceil(
+        #     (self.max_num_tokens // self.max_prefill_batch_size) /
+        #     self.block_size) * self.block_size
+        # self.warmup_scenario(batch_size=self.max_prefill_batch_size,
+        #                      seq_or_block=max_seq_len,
+        #                      is_prompt=True,
+        #                      kv_caches=kv_caches)

    def _dummy_run(self, max_num_batched_tokens: int) -> None:
-        # TODO wuxun: dummy run implementation
        assert max_num_batched_tokens == 1
-        # self.warmup_scenario(max_num_batched_tokens,
-        #                      1,
-        #                      1,
-        #                      is_prompt=False,
-        #                      kv_caches=None,
-        #                      num_iters=1,
-        #                      is_pt_profiler_run=False,
-        #                      align_worker=True,
-        #                      is_dummy_run=True)
-        prompt_cfg = 1, 1, 0
-        decode_cfg = None
+        prompt_cfg = None
+        decode_cfg = 1, 1
+        # add dummy decode run
        self._execute_dummy_scenario(prompt_cfg, decode_cfg)
        return

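The rewritten profile_run() derives its single prefill warm-up shape arithmetically: ceil-divide the token budget by the maximum prefill batch size, round the per-sequence length up to a whole number of blocks, and pass the result to _execute_dummy_scenario(); _dummy_run() likewise switches from a 1-token prefill to a (batch=1, block=1) dummy decode. A worked example of the sequence-length arithmetic with made-up numbers (an 8192-token budget, batch 3, block size 128):

# Worked example of the new profile_run() shape arithmetic; the concrete
# numbers are illustrative, not taken from any real configuration.
max_num_batched_tokens = 8192
max_prefill_batch_size = 3
block_size = 128

# Ceiling division: longest per-sequence length that still fits the budget
# when the prefill batch is full.
max_seq_len = (max_num_batched_tokens + max_prefill_batch_size -
               1) // max_prefill_batch_size            # 2731
# Round up to a whole number of KV-cache blocks.
if max_seq_len % block_size != 0:
    max_seq_len = ((max_seq_len + block_size - 1) // block_size) * block_size
print(max_seq_len)                                      # 2816

# profile_run() then issues one dummy prefill of this shape and no decode:
# prompt_cfg = (max_prefill_batch_size, max_seq_len, 0), decode_cfg = None.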