Commit 1ab1541

[2/N][Refactor] torchair model runner refactor (#2204)
There is a lot of torchair code in the model runner, which makes the code hard to maintain. We'll create a new torchair_model_runner to split out the torchair-related logic, following the workflow in #2203.

What this PR does: move the torchair-related logic out of `_get_forward_metadata_across_dp_and_pad` in the base model runner, and override that method in the torchair model runner.

- vLLM version: v0.10.0
- vLLM main: vllm-project/vllm@1b99028

Signed-off-by: wangxiyuan <[email protected]>
1 parent 9260910 commit 1ab1541
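
The split is a plain subclass-override pattern: callers keep invoking `_get_forward_metadata_across_dp_and_pad` on whatever runner was constructed, and only the torchair runner pads. Below is a minimal, self-contained sketch of that pattern, simplified to the dp_size == 1 decode path; the classes are hypothetical stand-ins, and the power-of-two padding rule is made up (the real padding comes from select_torchair_padded_batch_size):

from typing import Optional

class BaseRunner:
    def _get_forward_metadata_across_dp_and_pad(
            self, num_tokens: int, with_prefill: bool, enable_dbo: bool
    ) -> tuple:
        # Base behaviour after this PR: no torchair padding.
        return num_tokens, None, with_prefill, enable_dbo

class TorchairRunner(BaseRunner):
    def _select_padded_batch_size(self, num_tokens: int) -> int:
        # Hypothetical stand-in for select_torchair_padded_batch_size:
        # round up to the next power of two.
        return max(1, 1 << (num_tokens - 1).bit_length())

    def _get_forward_metadata_across_dp_and_pad(
            self, num_tokens: int, with_prefill: bool, enable_dbo: bool
    ) -> tuple:
        # Torchair behaviour: pad decode batches to a compiled graph size.
        if not with_prefill:
            num_tokens = self._select_padded_batch_size(num_tokens)
        return num_tokens, None, with_prefill, enable_dbo

# Callers don't change; the constructed runner class decides the behaviour.
print(BaseRunner()._get_forward_metadata_across_dp_and_pad(3, False, False))
# -> (3, None, False, False)
print(TorchairRunner()._get_forward_metadata_across_dp_and_pad(3, False, False))
# -> (4, None, False, False)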

File tree

2 files changed (+29, -16 lines)

vllm_ascend/torchair/torchair_model_runner.py

Lines changed: 28 additions & 0 deletions
@@ -17,6 +17,8 @@
 # Adapted from vllm-project/vllm/vllm/worker/gpu_model_runner.py
 #
 
+from typing import Optional
+
 import torch
 from vllm.config import VllmConfig
 
@@ -27,3 +29,29 @@ class NPUTorchairModelRunner(NPUModelRunner):
 
     def __init__(self, vllm_config: VllmConfig, device: torch.device):
         super().__init__(vllm_config, device)
+
+    def _get_forward_metadata_across_dp_and_pad(
+            self, num_tokens: int, with_prefill: bool, enable_dbo: bool
+    ) -> tuple[int, Optional[torch.Tensor], bool, bool]:
+        if self.dp_size == 1:
+            if not with_prefill:
+                maybe_padded_num_tokens = self.select_torchair_padded_batch_size(
+                    num_tokens)
+                return maybe_padded_num_tokens, None, with_prefill, enable_dbo
+            return num_tokens, None, with_prefill, enable_dbo
+
+        num_tokens_across_dp, with_prefill, enable_dbo = self._get_forward_metadata_across_dp(
+            num_tokens, with_prefill, enable_dbo)
+
+        if not with_prefill:
+            max_num_token = num_tokens_across_dp.max().item()
+            maybe_padded_num_tokens = self.select_torchair_padded_batch_size(
+                max_num_token)
+            num_tokens_across_dp = torch.full((self.dp_size, ),
+                                              maybe_padded_num_tokens,
+                                              dtype=torch.int32,
+                                              device="cpu")
+        else:
+            maybe_padded_num_tokens = num_tokens
+
+        return maybe_padded_num_tokens, num_tokens_across_dp, with_prefill, enable_dbo
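
For context on the padding step above: select_torchair_padded_batch_size rounds a decode batch up to one of the pre-compiled torchair graph sizes, and with DP > 1 every rank must run the same padded size, so the override pads to the bucket that fits the busiest rank and broadcasts it as a CPU tensor. A standalone sketch of that behaviour (the bucket list and helper here are illustrative, not the real configuration):

import torch

# Illustrative graph-size buckets; the real sizes come from the torchair config.
GRAPH_BATCH_SIZES = [1, 2, 4, 8, 16, 32]

def select_padded_batch_size(num_tokens: int) -> int:
    # Hypothetical stand-in for self.select_torchair_padded_batch_size:
    # round up to the smallest bucket that fits.
    for size in GRAPH_BATCH_SIZES:
        if num_tokens <= size:
            return size
    return GRAPH_BATCH_SIZES[-1]

# Two DP ranks report 3 and 7 decode tokens; both must pad to the same bucket.
num_tokens_across_dp = torch.tensor([3, 7], dtype=torch.int32)
padded = select_padded_batch_size(int(num_tokens_across_dp.max().item()))
num_tokens_across_dp = torch.full((2, ), padded, dtype=torch.int32, device="cpu")
print(num_tokens_across_dp)  # tensor([8, 8], dtype=torch.int32)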

vllm_ascend/worker/model_runner_v1.py

Lines changed: 1 addition & 16 deletions
@@ -640,26 +640,11 @@ def _get_forward_metadata_across_dp_and_pad(
             self, num_tokens: int, with_prefill: bool, enable_dbo: bool
     ) -> tuple[int, Optional[torch.Tensor], bool, bool]:
         if self.dp_size == 1:
-            if self.torchair_graph_enabled and not with_prefill:
-                maybe_padded_num_tokens = self.select_torchair_padded_batch_size(
-                    num_tokens)
-                return maybe_padded_num_tokens, None, with_prefill, enable_dbo
             return num_tokens, None, with_prefill, enable_dbo
 
-        maybe_padded_num_tokens = num_tokens
         num_tokens_across_dp, with_prefill, enable_dbo = self._get_forward_metadata_across_dp(
            num_tokens, with_prefill, enable_dbo)
-
-        if self.torchair_graph_enabled and not with_prefill:
-            max_num_token = num_tokens_across_dp.max().item()
-            maybe_padded_num_tokens = self.select_torchair_padded_batch_size(
-                max_num_token)
-            num_tokens_across_dp = torch.full((self.dp_size, ),
-                                              maybe_padded_num_tokens,
-                                              dtype=torch.int32,
-                                              device="cpu")
-
-        return maybe_padded_num_tokens, num_tokens_across_dp, with_prefill, enable_dbo
+        return num_tokens, num_tokens_across_dp, with_prefill, enable_dbo
 
     def _check_dbo_is_valid(self, query_lens: torch.Tensor,
                             attn_state: AscendAttentionState,
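
Net effect on the base runner: the `self.torchair_graph_enabled` branches are gone from this path, so the base class always returns the token count unpadded, and padding now happens only in the NPUTorchairModelRunner override. A rough before/after for a dp_size == 1 decode call (the padded value 4 is illustrative; the real bucket comes from the configured torchair graph sizes):

# Before this PR, base runner with torchair graphs enabled:
#   runner._get_forward_metadata_across_dp_and_pad(3, False, False)
#   -> (4, None, False, False)    # padded to a compiled graph bucket
# After this PR, on the base NPUModelRunner:
#   runner._get_forward_metadata_across_dp_and_pad(3, False, False)
#   -> (3, None, False, False)    # unpadded; the torchair runner handles padding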
