@@ -348,13 +348,38 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
         torch._logging.set_logs(
             recompiles=envs_ascend.VLLM_ASCEND_TRACE_RECOMPILES)
 
+        self.check_batch_sizes_consistency()
         # NOTE: we need to use `in_profile_run` to determine whether `enable_force_load_balance` is True
         self.in_profile_run = False
 
         # kv role
         self.is_kv_producer = False
+        self.is_kv_consumer = False
         if vllm_config.kv_transfer_config is not None:
             self.is_kv_producer = vllm_config.kv_transfer_config.is_kv_producer
+            self.is_kv_consumer = vllm_config.kv_transfer_config.is_kv_consumer
+
+    def check_batch_sizes_consistency(self) -> None:
+        if not dist.is_initialized():
+            return
+
+        local = torch.tensor(self.torchair_graph_batch_sizes,
+                             device="cpu",
+                             dtype=torch.int32)
+        gathered_graph_batch_size = local.clone()
+        dist.all_reduce(gathered_graph_batch_size,
+                        group=get_dp_group().cpu_group)
+        expected = local * self.dp_size
+
+        if not torch.equal(gathered_graph_batch_size, expected):
+            diff_idxs = (gathered_graph_batch_size != expected).nonzero(
+                as_tuple=False).flatten().tolist()
+            raise AssertionError(
+                f"[Graph BatchSize Mismatch] Found mismatches at indices {diff_idxs}.\n"
+                f"Local (rank {self.dp_rank}): {local.tolist()}\n"
+                f"Sum over ranks: {gathered_graph_batch_size.tolist()}\n"
+                f"Expected if all equal: {[v * self.dp_size for v in local.tolist()]}"
+            )
 
     def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
         """Update the cached states and the persistent batch with the scheduler
@@ -570,44 +595,58 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
         self.input_batch.refresh_sampling_metadata()
 
     def _get_forward_metadata_across_dp(
-            self,
-            maybe_padded_num_tokens: int,
-            num_tokens: int,
-            with_prefill: bool,
-            enable_dbo: bool = False,
+            self, num_tokens: int, with_prefill: bool,
+            enable_dbo: bool) -> tuple[torch.Tensor, bool, bool]:
+
+        # Compose: all_reduce metadata (num_tokens of each rank, with_prefill, enable_dbo)
+        num_tokens_across_dp = torch.zeros(self.dp_size + 2,
+                                           dtype=torch.int32,
+                                           device="cpu")
+        num_tokens_across_dp[self.dp_rank] = num_tokens
+        num_tokens_across_dp[-2] = int(with_prefill)
+        num_tokens_across_dp[-1] = int(not enable_dbo)
+        dist.all_reduce(num_tokens_across_dp, group=get_dp_group().cpu_group)
+        with_prefill = bool(num_tokens_across_dp[-2])
+        enable_dbo = not bool(num_tokens_across_dp[-1])
+        num_tokens_across_dp = num_tokens_across_dp[:-2]
+        return num_tokens_across_dp, with_prefill, enable_dbo
+
+    def _get_forward_metadata_across_dp_and_pad(
+            self, num_tokens: int, with_prefill: bool, enable_dbo: bool
     ) -> tuple[int, Optional[torch.Tensor], bool, bool]:
         if self.dp_size == 1:
-            return maybe_padded_num_tokens, None, with_prefill, enable_dbo
+            return num_tokens, None, with_prefill, enable_dbo
 
-        num_tokens_across_dp = [0] * self.dp_size * 2
-        num_tokens_across_dp[self.dp_rank] = maybe_padded_num_tokens
-        num_tokens_across_dp[self.dp_size + self.dp_rank] = num_tokens
-        forward_metadata = torch.tensor(num_tokens_across_dp +
-                                        [with_prefill, not enable_dbo],
-                                        device="cpu",
-                                        dtype=torch.int32)
-        dist.all_reduce(forward_metadata, group=get_dp_group().cpu_group)
-        with_prefill = bool(forward_metadata[-2])
-
-        # NOTE: when with_prefill is false before all_reduce and true after all_reduce, we need to revert pad.
-        if with_prefill:
-            num_tokens_across_dp = forward_metadata[self.dp_size:self.dp_size *
-                                                    2]
-            maybe_padded_num_tokens = num_tokens
-        else:
-            num_tokens_across_dp = forward_metadata[:self.dp_size]
+        if self.is_kv_producer and not envs_ascend.VLLM_ASCEND_ENABLE_CHUNK_MC2:
+            num_tokens_across_dp = torch.tensor([num_tokens] * self.dp_size,
+                                                device="cpu",
+                                                dtype=torch.int32)
+            return num_tokens, num_tokens_across_dp, True, enable_dbo
 
-        # NOTE: when in torchair_graph_mode, we need to pad local_num_tokens to
-        # `max_tokens_across_dp`, in other situation it is not necessary.
-        if self.torchair_graph_enabled and not with_prefill:
-            maybe_padded_num_tokens = torch.max(num_tokens_across_dp).item()
-            num_tokens_across_dp = torch.tensor([maybe_padded_num_tokens] *
+        if self.is_kv_consumer and self.torchair_graph_enabled and len(
+                self.torchair_graph_batch_sizes
+        ) == 1 and not self.in_profile_run:
+            max_num_decode_tokens = self.torchair_graph_batch_sizes[0]
+            num_tokens_across_dp = torch.tensor([max_num_decode_tokens] *
                                                 self.dp_size,
                                                 device="cpu",
                                                 dtype=torch.int32)
+            return max_num_decode_tokens, num_tokens_across_dp, False, enable_dbo
+
+        maybe_padded_num_tokens = num_tokens
+        num_tokens_across_dp, with_prefill, enable_dbo = self._get_forward_metadata_across_dp(
+            num_tokens, with_prefill, enable_dbo)
 
-        return maybe_padded_num_tokens, num_tokens_across_dp, with_prefill, not bool(
-            forward_metadata[-1])
+        if self.torchair_graph_enabled and not with_prefill:
+            max_num_token = num_tokens_across_dp.max().item()
+            maybe_padded_num_tokens = self.select_torchair_padded_batch_size(
+                max_num_token)
+            num_tokens_across_dp = torch.full((self.dp_size, ),
+                                              maybe_padded_num_tokens,
+                                              dtype=torch.int32,
+                                              device="cpu")
+
+        return maybe_padded_num_tokens, num_tokens_across_dp, with_prefill, enable_dbo
 
     def _check_dbo_is_valid(self, query_lens: torch.Tensor,
                             attn_state: AscendAttentionState,
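The rewritten `_get_forward_metadata_across_dp` packs everything into one int32 tensor so a single all_reduce over the DP CPU group exchanges the per-rank token counts and both flags: SUM over 0/1 entries acts as a logical OR for `with_prefill` and, through the negated bit, as a logical AND for `enable_dbo`. A self-contained sketch of that encode/decode step, with made-up names and an assumed CPU (e.g. gloo) group:

```python
# Illustrative sketch (names are made up, not the vllm-ascend API): pack
# per-rank token counts plus two flag slots into one int32 tensor so a single
# all_reduce over a CPU group exchanges all forward metadata at once.
import torch
import torch.distributed as dist


def exchange_forward_metadata(num_tokens: int, with_prefill: bool,
                              enable_dbo: bool, dp_rank: int, dp_size: int,
                              group=None):
    buf = torch.zeros(dp_size + 2, dtype=torch.int32, device="cpu")
    buf[dp_rank] = num_tokens        # each rank fills only its own slot
    buf[-2] = int(with_prefill)      # >0 after SUM if any rank has prefill
    buf[-1] = int(not enable_dbo)    # >0 after SUM if any rank vetoes DBO
    dist.all_reduce(buf, group=group)
    return buf[:-2], bool(buf[-2]), not bool(buf[-1])
```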
@@ -1108,16 +1147,13 @@ def _process_reqs(
                                               attn_state,
                                               total_num_scheduled_tokens)
 
-        maybe_padded_num_tokens = total_num_scheduled_tokens
-        if self.torchair_graph_enabled and not with_prefill:
-            maybe_padded_num_tokens = self.select_torchair_padded_batch_size(
-                total_num_scheduled_tokens)
+        enable_dbo = self._check_dbo_is_valid(self.query_lens.tolist(),
+                                              attn_state,
+                                              total_num_scheduled_tokens)
 
         (padded_num_tokens_across_dp, num_tokens_across_dp, with_prefill,
-         enable_dbo) = self._get_forward_metadata_across_dp(
-             maybe_padded_num_tokens, total_num_scheduled_tokens, with_prefill,
-             enable_dbo)
+         enable_dbo) = self._get_forward_metadata_across_dp_and_pad(
+             total_num_scheduled_tokens, with_prefill, enable_dbo)
 
         extra_builder_kwargs['enable_dbo_across_dp'] = enable_dbo
-
         if self.torchair_graph_enabled and not with_prefill:
             graph_pad_size = padded_num_tokens_across_dp - total_num_scheduled_tokens
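With the padding folded into `_get_forward_metadata_across_dp_and_pad`, `padded_num_tokens_across_dp` already reflects the torchair graph batch size, so `graph_pad_size` is simply the gap to the scheduled token count. Assuming `select_torchair_padded_batch_size` picks the smallest captured graph batch size that fits (the exact policy is not shown in this diff), the padding math looks like:

```python
# Illustrative sketch; assumes the padder returns the smallest captured graph
# batch size that can hold the request, which this diff does not spell out.
def select_padded_batch_size(num_tokens: int, graph_batch_sizes: list[int]) -> int:
    for size in sorted(graph_batch_sizes):
        if size >= num_tokens:
            return size
    raise ValueError(f"{num_tokens} exceeds the largest captured graph size")


# e.g. captured sizes [8, 16, 32] and 10 scheduled decode tokens:
#   padded_num_tokens_across_dp = 16, graph_pad_size = 16 - 10 = 6
```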
@@ -1791,16 +1827,10 @@ def _dummy_run(
         with_prefill: bool = False,
         is_torchair_compile: bool = False,
     ) -> torch.Tensor:
-        maybe_padded_num_tokens = num_tokens
-        if self.torchair_graph_enabled and not with_prefill:
-            maybe_padded_num_tokens = self.select_torchair_padded_batch_size(
-                num_tokens)
-
         # Padding for DP
         (num_tokens, num_tokens_across_dp, with_prefill,
-         _) = self._get_forward_metadata_across_dp(maybe_padded_num_tokens,
-                                                   num_tokens, with_prefill,
-                                                   False)
+         _) = self._get_forward_metadata_across_dp_and_pad(
+             num_tokens, with_prefill, False)
 
         # Set num_scheduled_tokens based on num_tokens and max_num_seqs
         # for dummy run with LoRA so that the num_reqs collectively
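The dummy run no longer pre-pads `num_tokens` itself; it leans on the same helper as the real forward path. A hypothetical usage sketch of that call site (not from the repo, `runner` stands in for the model runner instance patched above):

```python
# Hypothetical usage sketch: one call both syncs metadata across DP ranks and
# applies torchair padding; `runner` is any object exposing the helper.
def pad_for_dummy_run(runner, num_tokens: int, with_prefill: bool = False):
    padded, across_dp, with_prefill, _ = (
        runner._get_forward_metadata_across_dp_and_pad(num_tokens, with_prefill,
                                                       False))
    # Per the diff, with dp_size == 1 the helper returns its inputs unchanged:
    # across_dp is None and padded == num_tokens.
    return padded, across_dp, with_prefill
```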