
Commit 583ad8f

[main][refactor] Refactor forward metadata retrieval across DP nodes to reduce redundant padding. (#2062)
Before refactoring cross-DP decoding metadata aggregation, clean up the token-padding logic.

### What this PR does:

1. First checks whether any DP instance is in the prefill phase.
2. If every DP instance is in the `decode` phase and `torchair_graph_enabled` is true, pads each DP instance’s token count up to the global maximum.
3. If any DP instance is in the `prefill` phase, or graph mode is **disabled** during decode, returns each DP instance’s original token count without padding.

This reordering removes the previous two-step padding/unpadding flow and ensures padding only occurs when strictly necessary.

- vLLM version: v0.10.0
- vLLM main: vllm-project/vllm@bd3db7f

Signed-off-by: yx0716 <[email protected]>
Signed-off-by: MengqingCao <[email protected]>
1 parent 27c2b5c commit 583ad8f
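
The padding rule described above can be sketched in a few lines. This is an illustrative, self-contained simplification, not code from `model_runner_v1.py`: the helper `decide_padding` and its plain-list inputs are hypothetical, while the real implementation works on CPU tensors gathered with an `all_reduce` over the DP group (see the diff below).

```python
from typing import List, Tuple


def decide_padding(num_tokens_across_dp: List[int],
                   any_rank_prefilling: bool,
                   torchair_graph_enabled: bool) -> Tuple[List[int], bool]:
    """Return the per-rank token counts to run with and whether padding was applied."""
    if any_rank_prefilling or not torchair_graph_enabled:
        # Cases 1 and 3: prefill somewhere, or graph mode disabled -> keep original counts.
        return num_tokens_across_dp, False
    # Case 2: pure decode with the torchair graph -> pad every rank to the global maximum.
    global_max = max(num_tokens_across_dp)
    return [global_max] * len(num_tokens_across_dp), True


print(decide_padding([3, 7], any_rank_prefilling=False, torchair_graph_enabled=True))
# -> ([7, 7], True)
print(decide_padding([3, 7], any_rank_prefilling=True, torchair_graph_enabled=True))
# -> ([3, 7], False)
```

In the actual change, the decode-phase maximum is additionally rounded up to a torchair capture size via `select_torchair_padded_batch_size`; see `_get_forward_metadata_across_dp_and_pad` in the diff.
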

File tree

1 file changed: +76 -46 lines changed

vllm_ascend/worker/model_runner_v1.py

Lines changed: 76 additions & 46 deletions
```diff
@@ -348,13 +348,38 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
             torch._logging.set_logs(
                 recompiles=envs_ascend.VLLM_ASCEND_TRACE_RECOMPILES)
 
+        self.check_batch_sizes_consistency()
         # NOTE: we need to use `in_profile_run` to determine whether `enable_force_load_balance` is True
         self.in_profile_run = False
 
         # kv role
         self.is_kv_producer = False
+        self.is_kv_consumer = False
         if vllm_config.kv_transfer_config is not None:
             self.is_kv_producer = vllm_config.kv_transfer_config.is_kv_producer
+            self.is_kv_consumer = vllm_config.kv_transfer_config.is_kv_consumer
+
+    def check_batch_sizes_consistency(self) -> None:
+        if not dist.is_initialized():
+            return
+
+        local = torch.tensor(self.torchair_graph_batch_sizes,
+                             device="cpu",
+                             dtype=torch.int32)
+        gathered_graph_batch_size = local.clone()
+        dist.all_reduce(gathered_graph_batch_size,
+                        group=get_dp_group().cpu_group)
+        expected = local * self.dp_size
+
+        if not torch.equal(gathered_graph_batch_size, expected):
+            diff_idxs = (gathered_graph_batch_size != expected).nonzero(
+                as_tuple=False).flatten().tolist()
+            raise AssertionError(
+                f"[Graph BatchSize Mismatch] Found mismatches at indices {diff_idxs}.\n"
+                f"Local (rank {self.dp_rank}): {local.tolist()}\n"
+                f"Sum over ranks: {gathered_graph_batch_size.tolist()}\n"
+                f"Expected if all equal: {[v * self.dp_size for v in local.tolist()]}"
+            )
 
     def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
         """Update the cached states and the persistent batch with the scheduler
```
```diff
@@ -570,44 +595,58 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
         self.input_batch.refresh_sampling_metadata()
 
     def _get_forward_metadata_across_dp(
-            self,
-            maybe_padded_num_tokens: int,
-            num_tokens: int,
-            with_prefill: bool,
-            enable_dbo: bool = False,
+            self, num_tokens: int, with_prefill: bool,
+            enable_dbo: bool) -> tuple[torch.Tensor, bool, bool]:
+
+        # Compose: all_reduce metadata (num_tokens of each rank, with_prefill, enable_dbo)
+        num_tokens_across_dp = torch.zeros(self.dp_size + 2,
+                                           dtype=torch.int32,
+                                           device="cpu")
+        num_tokens_across_dp[self.dp_rank] = num_tokens
+        num_tokens_across_dp[-2] = int(with_prefill)
+        num_tokens_across_dp[-1] = int(not enable_dbo)
+        dist.all_reduce(num_tokens_across_dp, group=get_dp_group().cpu_group)
+        with_prefill = bool(num_tokens_across_dp[-2])
+        enable_dbo = not bool(num_tokens_across_dp[-1])
+        num_tokens_across_dp = num_tokens_across_dp[:-2]
+        return num_tokens_across_dp, with_prefill, enable_dbo
+
+    def _get_forward_metadata_across_dp_and_pad(
+            self, num_tokens: int, with_prefill: bool, enable_dbo: bool
     ) -> tuple[int, Optional[torch.Tensor], bool, bool]:
         if self.dp_size == 1:
-            return maybe_padded_num_tokens, None, with_prefill, enable_dbo
+            return num_tokens, None, with_prefill, enable_dbo
 
-        num_tokens_across_dp = [0] * self.dp_size * 2
-        num_tokens_across_dp[self.dp_rank] = maybe_padded_num_tokens
-        num_tokens_across_dp[self.dp_size + self.dp_rank] = num_tokens
-        forward_metadata = torch.tensor(num_tokens_across_dp +
-                                        [with_prefill, not enable_dbo],
-                                        device="cpu",
-                                        dtype=torch.int32)
-        dist.all_reduce(forward_metadata, group=get_dp_group().cpu_group)
-        with_prefill = bool(forward_metadata[-2])
-
-        # NOTE: when with_prefill is false before all_reduce and true after all_reduce, we need to revert pad.
-        if with_prefill:
-            num_tokens_across_dp = forward_metadata[self.dp_size:self.dp_size *
-                                                    2]
-            maybe_padded_num_tokens = num_tokens
-        else:
-            num_tokens_across_dp = forward_metadata[:self.dp_size]
+        if self.is_kv_producer and not envs_ascend.VLLM_ASCEND_ENABLE_CHUNK_MC2:
+            num_tokens_across_dp = torch.tensor([num_tokens] * self.dp_size,
+                                                device="cpu",
+                                                dtype=torch.int32)
+            return num_tokens, num_tokens_across_dp, True, enable_dbo
 
-        # NOTE: when in torchair_graph_mode, we need to pad local_num_tokens to
-        # `max_tokens_across_dp`, in other situation it is not necessary.
-        if self.torchair_graph_enabled and not with_prefill:
-            maybe_padded_num_tokens = torch.max(num_tokens_across_dp).item()
-            num_tokens_across_dp = torch.tensor([maybe_padded_num_tokens] *
+        if self.is_kv_consumer and self.torchair_graph_enabled and len(
+                self.torchair_graph_batch_sizes
+        ) == 1 and not self.in_profile_run:
+            max_num_decode_tokens = self.torchair_graph_batch_sizes[0]
+            num_tokens_across_dp = torch.tensor([max_num_decode_tokens] *
                                                 self.dp_size,
                                                 device="cpu",
                                                 dtype=torch.int32)
+            return max_num_decode_tokens, num_tokens_across_dp, False, enable_dbo
+
+        maybe_padded_num_tokens = num_tokens
+        num_tokens_across_dp, with_prefill, enable_dbo = self._get_forward_metadata_across_dp(
+            num_tokens, with_prefill, enable_dbo)
 
-        return maybe_padded_num_tokens, num_tokens_across_dp, with_prefill, not bool(
-            forward_metadata[-1])
+        if self.torchair_graph_enabled and not with_prefill:
+            max_num_token = num_tokens_across_dp.max().item()
+            maybe_padded_num_tokens = self.select_torchair_padded_batch_size(
+                max_num_token)
+            num_tokens_across_dp = torch.full((self.dp_size, ),
+                                              maybe_padded_num_tokens,
+                                              dtype=torch.int32,
+                                              device="cpu")
+
+        return maybe_padded_num_tokens, num_tokens_across_dp, with_prefill, enable_dbo
 
     def _check_dbo_is_valid(self, query_lens: torch.Tensor,
                             attn_state: AscendAttentionState,
```
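
The refactored `_get_forward_metadata_across_dp` packs everything into one `dp_size + 2` CPU tensor so a single sum `all_reduce` suffices: slot `dp_rank` carries this rank's token count, slot `-2` carries `with_prefill`, and slot `-1` carries `not enable_dbo`. After the reduce, a non-zero slot `-2` means at least one rank is prefilling (logical OR across ranks), while a non-zero slot `-1` means some rank turned DBO off, so DBO stays enabled only if every rank kept it on (logical AND). A single-process sketch that mimics the reduce by summing per-rank tensors; `pack` and `reduce_and_unpack` are illustrative names, not repository code.

```python
from typing import List, Tuple

import torch


def pack(dp_size: int, dp_rank: int, num_tokens: int,
         with_prefill: bool, enable_dbo: bool) -> torch.Tensor:
    """Per-rank layout: [tokens of rank 0, ..., tokens of rank N-1, with_prefill, not enable_dbo]."""
    meta = torch.zeros(dp_size + 2, dtype=torch.int32)
    meta[dp_rank] = num_tokens
    meta[-2] = int(with_prefill)
    meta[-1] = int(not enable_dbo)
    return meta


def reduce_and_unpack(per_rank: List[torch.Tensor]) -> Tuple[torch.Tensor, bool, bool]:
    """Sum across ranks (stand-in for dist.all_reduce) and decode the two flags."""
    reduced = per_rank[0].clone()
    for t in per_rank[1:]:
        reduced += t
    with_prefill = bool(reduced[-2])    # OR: any rank prefilling
    enable_dbo = not bool(reduced[-1])  # AND: every rank left DBO enabled
    return reduced[:-2], with_prefill, enable_dbo


# dp_size = 2: rank 0 decodes 3 tokens with DBO on, rank 1 prefills 128 tokens with DBO off.
ranks = [pack(2, 0, 3, False, True), pack(2, 1, 128, True, False)]
print(reduce_and_unpack(ranks))  # tokens per rank [3, 128], with_prefill=True, enable_dbo=False
```

Because each rank fills only its own slot before the sum-reduce, the reduced prefix is exactly the per-rank token vector, which `_get_forward_metadata_across_dp_and_pad` then pads only in the pure-decode graph case.
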
```diff
@@ -1108,16 +1147,13 @@ def _process_reqs(
                                               attn_state,
                                               total_num_scheduled_tokens)
 
-        maybe_padded_num_tokens = total_num_scheduled_tokens
-        if self.torchair_graph_enabled and not with_prefill:
-            maybe_padded_num_tokens = self.select_torchair_padded_batch_size(
-                total_num_scheduled_tokens)
+        enable_dbo = self._check_dbo_is_valid(self.query_lens.tolist(),
+                                              attn_state,
+                                              total_num_scheduled_tokens)
         (padded_num_tokens_across_dp, num_tokens_across_dp, with_prefill,
-         enable_dbo) = self._get_forward_metadata_across_dp(
-             maybe_padded_num_tokens, total_num_scheduled_tokens, with_prefill,
-             enable_dbo)
+         enable_dbo) = self._get_forward_metadata_across_dp_and_pad(
+             total_num_scheduled_tokens, with_prefill, enable_dbo)
         extra_builder_kwargs['enable_dbo_across_dp'] = enable_dbo
-
         if self.torchair_graph_enabled and not with_prefill:
             graph_pad_size = padded_num_tokens_across_dp - total_num_scheduled_tokens
 
```
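
A worked example of what `_process_reqs` now sees in the pure-decode graph case, assuming `select_torchair_padded_batch_size` rounds the global maximum up to the nearest capture size (which is what its name and the `max_num_token` argument suggest); the numbers are illustrative, not taken from the PR.

```python
# dp_size = 2, decode only, torchair graph enabled.
torchair_graph_batch_sizes = [4, 8, 16]  # capture sizes, ascending
num_tokens_across_dp = [3, 7]            # per-rank scheduled tokens after the all_reduce

padded = min(b for b in torchair_graph_batch_sizes if b >= max(num_tokens_across_dp))
graph_pad_size = [padded - n for n in num_tokens_across_dp]
print(padded, graph_pad_size)            # 8 [5, 1]
```

Every rank therefore pads to the same batch size (8 here), so the captured graph shape matches across DP ranks; with prefill anywhere in the step, padding is skipped and `graph_pad_size` is never computed.
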

```diff
@@ -1791,16 +1827,10 @@ def _dummy_run(
         with_prefill: bool = False,
         is_torchair_compile: bool = False,
     ) -> torch.Tensor:
-        maybe_padded_num_tokens = num_tokens
-        if self.torchair_graph_enabled and not with_prefill:
-            maybe_padded_num_tokens = self.select_torchair_padded_batch_size(
-                num_tokens)
-
         # Padding for DP
         (num_tokens, num_tokens_across_dp, with_prefill,
-         _) = self._get_forward_metadata_across_dp(maybe_padded_num_tokens,
-                                                   num_tokens, with_prefill,
-                                                   False)
+         _) = self._get_forward_metadata_across_dp_and_pad(
+             num_tokens, with_prefill, False)
 
         # Set num_scheduled_tokens based on num_tokens and max_num_seqs
         # for dummy run with LoRA so that the num_reqs collectively
```
