Commit 881e36d

[3/N][Refactor] torchair model runner refactor (#2207)

There is a lot of torchair code in the model runner, which makes the code hard to maintain. We'll create a new torchair_model_runner to split out the torchair-related logic. Following the workflow in #2203, this is the first PR. What this PR does: create the common functions `_build_attention_metadata` and `_generate_dummy_run_hidden_states` for dummy_run.

- vLLM version: v0.10.0
- vLLM main: vllm-project/vllm@ebf7605

Signed-off-by: wangxiyuan <[email protected]>
1 parent 29aaba5 commit 881e36d
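
To make the shape of the refactor concrete, here is a minimal, self-contained sketch of the hook-and-override structure the two files below introduce. Only the hook names (`_build_attention_metadata`, `_generate_dummy_run_hidden_states`) and the base class name `NPUModelRunner` come from the diff; the subclass name, the simplified signatures, and the bodies are placeholders, not the real implementations.

# Sketch only: names below other than NPUModelRunner and the two hook methods
# are hypothetical, and the bodies are stand-ins for the real logic.
class NPUModelRunner:
    def _build_attention_metadata(self, with_prefill, num_reqs, skip_attn):
        # Default (non-torchair) dummy run: no attention metadata is built yet.
        return None

    def _generate_dummy_run_hidden_states(self, with_prefill, attn_metadata,
                                          num_tokens):
        # Default path: a plain eager forward pass (elided here).
        return f"eager dummy forward over {num_tokens} tokens"

    def dummy_run(self, num_tokens, num_reqs, with_prefill, skip_attn):
        # The shared driver only calls the hooks, so a subclass can change the
        # graph-mode behaviour without copying this method.
        attn_metadata = self._build_attention_metadata(with_prefill, num_reqs,
                                                       skip_attn)
        return self._generate_dummy_run_hidden_states(with_prefill,
                                                      attn_metadata, num_tokens)


class TorchairModelRunner(NPUModelRunner):  # hypothetical subclass name
    def _build_attention_metadata(self, with_prefill, num_reqs, skip_attn):
        if not with_prefill:
            # Torchair graph mode needs real dummy metadata; skipping it would
            # trigger a graph recompile (see the NOTE in the diff below).
            return {"dummy_decode_metadata_for_reqs": num_reqs}
        return super()._build_attention_metadata(with_prefill, num_reqs, skip_attn)


print(TorchairModelRunner().dummy_run(num_tokens=8, num_reqs=2,
                                      with_prefill=False, skip_attn=True))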

2 files changed: +89 additions, −66 deletions

vllm_ascend/torchair/torchair_model_runner.py

Lines changed: 58 additions & 0 deletions
@@ -21,7 +21,10 @@
 
 import torch
 from vllm.config import VllmConfig
+from vllm.forward_context import get_forward_context
 
+from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ,
+                               maybe_converting_weight_acl_format)
 from vllm_ascend.worker.model_runner_v1 import NPUModelRunner
 
 
@@ -55,3 +58,58 @@ def _get_forward_metadata_across_dp_and_pad(
             maybe_padded_num_tokens = num_tokens
 
         return maybe_padded_num_tokens, num_tokens_across_dp, with_prefill, enable_dbo
+
+    def _build_attention_metadata(self, with_prefill, num_reqs, skip_attn):
+        # NOTE: If torchair graph mode and not with_prefill,
+        # we can't skip_attn, it will cause graph recompile.
+        if not with_prefill:
+            attn_metadata = self.attn_metadata_builder.build_torchair_graph_dummy(
+                num_reqs=num_reqs, num_actual_tokens=1)
+        else:
+            attn_metadata = super()._build_attention_metadata(
+                with_prefill, num_reqs, skip_attn)
+        return attn_metadata
+
+    def _generate_dummy_run_hidden_states(self, with_prefill,
+                                          is_torchair_compile, input_ids,
+                                          positions, attn_metadata, num_tokens,
+                                          intermediate_tensors, inputs_embeds):
+
+        if not with_prefill:
+            # Only mark static while compiling
+            if is_torchair_compile:
+                torch._dynamo.mark_static(input_ids)
+                torch._dynamo.mark_static(positions)
+                torch._dynamo.mark_static(attn_metadata.decode.block_table)
+                torch._dynamo.mark_static(attn_metadata.decode.input_positions)
+                torch._dynamo.mark_static(get_forward_context().mc2_mask)
+                if hasattr(attn_metadata.decode, "sin"):
+                    torch._dynamo.mark_static(attn_metadata.decode.sin)
+                    torch._dynamo.mark_static(attn_metadata.decode.cos)
+                torch._dynamo.mark_static(attn_metadata.slot_mapping)
+                if self.speculative_config:
+                    torch._dynamo.mark_static(attn_metadata.decode.attn_mask)
+                for kv in self.kv_caches:
+                    assert isinstance(kv, tuple), "kv_cache must be a tuple"
+                    torch._dynamo.mark_static(kv[0])
+                    torch._dynamo.mark_static(kv[1])
+
+            maybe_converting_weight_acl_format(self.model,
+                                               ACL_FORMAT_FRACTAL_NZ)
+
+            compiled_model = self._get_torchair_lazy_compiled_model(num_tokens)
+            model_kwargs = {}
+            model_kwargs["kv_caches"] = self.kv_caches
+            model_kwargs["attn_metadata"] = attn_metadata
+            hidden_states = compiled_model(
+                input_ids=input_ids,
+                positions=positions,
+                intermediate_tensors=intermediate_tensors,
+                inputs_embeds=None,
+                **model_kwargs,
+            )
+        else:
+            hidden_states = super()._generate_dummy_run_hidden_states(
+                with_prefill, is_torchair_compile, input_ids, positions,
+                attn_metadata, num_tokens, intermediate_tensors, inputs_embeds)
+        return hidden_states
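
A side note on the `torch._dynamo.mark_static` calls added above: they ask Dynamo to specialize on the exact shapes of the dummy decode inputs instead of ever treating them as dynamic, which is what keeps the warm-up trace reusable and avoids graph recompilation. A minimal, stand-alone illustration of that mechanism (plain `torch.compile`, nothing Ascend- or torchair-specific, and not code from this PR):

import torch

@torch.compile
def decode_step(x: torch.Tensor) -> torch.Tensor:
    # Stand-in for the compiled decode-only forward pass.
    return x * 2

dummy = torch.zeros(8, 16)
torch._dynamo.mark_static(dummy)  # specialize on this tensor's exact shape
decode_step(dummy)                # warm-up trace captured with static shapes
decode_step(torch.zeros(8, 16))   # later calls with the same shape reuse the trace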

vllm_ascend/worker/model_runner_v1.py

Lines changed: 31 additions & 66 deletions
@@ -1832,6 +1832,31 @@ def get_finished_kv_transfer(
                 scheduler_output.finished_req_ids)
         return None, None
 
+    def _build_attention_metadata(self, with_prefill, num_reqs, skip_attn):
+        if skip_attn:
+            attn_metadata = None
+        else:
+            # TODO(zzzzwwjj): when aclgraph and full graph mode, we need build attn_metadata
+            attn_metadata = None
+        return attn_metadata
+
+    def _generate_dummy_run_hidden_states(self, with_prefill,
+                                          is_torchair_compile, input_ids,
+                                          positions, attn_metadata, num_tokens,
+                                          intermediate_tensors, inputs_embeds):
+        maybe_converting_weight_acl_format(self.model, ACL_FORMAT_FRACTAL_ND)
+        hidden_states = self.model(input_ids=input_ids,
+                                   positions=positions,
+                                   intermediate_tensors=intermediate_tensors,
+                                   inputs_embeds=inputs_embeds)
+        if self.use_aux_hidden_state_outputs:
+            hidden_states, _ = hidden_states
+        else:
+            hidden_states = hidden_states
+        if self.use_spec_decode and isinstance(self.drafter, EagleProposer):
+            self.drafter.dummy_run(num_tokens)
+        return hidden_states
+
     @torch.inference_mode()
     def _dummy_run(
         self,
@@ -1868,20 +1893,11 @@ def _dummy_run(
         if self.is_kv_producer:
             with_prefill = True
 
-        # NOTE: If torchair graph mode and not with_prefill,
-        # we can't skip_attn, it will cause graph recompile.
-        if self.torchair_graph_enabled and not with_prefill:
-            attn_metadata = self.attn_metadata_builder.build_torchair_graph_dummy(
-                num_reqs=num_reqs, num_actual_tokens=1)
-        elif skip_attn:
-            attn_metadata = None
-        else:
-            # TODO(zzzzwwjj): when aclgraph and full graph mode, we need build attn_metadata
-            attn_metadata = None
+        attn_metadata = self._build_attention_metadata(with_prefill, num_reqs,
+                                                       skip_attn)
 
         with self.maybe_dummy_run_with_lora(self.lora_config,
                                             num_scheduled_tokens):
-            model = self.model
             if self.is_multimodal_model:
                 input_ids = None
                 inputs_embeds = self.inputs_embeds[:num_tokens]
@@ -1917,61 +1933,10 @@ def _dummy_run(
                 in_profile_run=self.in_profile_run,
                 num_actual_tokens=0,
             ):
-                model_kwargs = {}
-                if self.torchair_graph_enabled and not with_prefill:
-                    # Only mark static while compiling
-                    if is_torchair_compile:
-                        torch._dynamo.mark_static(input_ids)
-                        torch._dynamo.mark_static(positions)
-                        torch._dynamo.mark_static(
-                            attn_metadata.decode.block_table)
-                        torch._dynamo.mark_static(
-                            attn_metadata.decode.input_positions)
-                        torch._dynamo.mark_static(
-                            get_forward_context().mc2_mask)
-                        if hasattr(attn_metadata.decode, "sin"):
-                            torch._dynamo.mark_static(attn_metadata.decode.sin)
-                            torch._dynamo.mark_static(attn_metadata.decode.cos)
-                        torch._dynamo.mark_static(attn_metadata.slot_mapping)
-                        if self.speculative_config:
-                            torch._dynamo.mark_static(
-                                attn_metadata.decode.attn_mask)
-                        for kv in self.kv_caches:
-                            assert isinstance(
-                                kv, tuple), "kv_cache must be a tuple"
-                            torch._dynamo.mark_static(kv[0])
-                            torch._dynamo.mark_static(kv[1])
-
-                    maybe_converting_weight_acl_format(self.model,
-                                                       ACL_FORMAT_FRACTAL_NZ)
-
-                    compiled_model = self._get_torchair_lazy_compiled_model(
-                        num_tokens)
-                    model_kwargs["kv_caches"] = self.kv_caches
-                    model_kwargs["attn_metadata"] = attn_metadata
-                    hidden_states = compiled_model(
-                        input_ids=input_ids,
-                        positions=positions,
-                        intermediate_tensors=intermediate_tensors,
-                        inputs_embeds=None,
-                        **model_kwargs,
-                    )
-                else:
-                    maybe_converting_weight_acl_format(self.model,
-                                                       ACL_FORMAT_FRACTAL_ND)
-
-                    hidden_states = model(
-                        input_ids=input_ids,
-                        positions=positions,
-                        intermediate_tensors=intermediate_tensors,
-                        inputs_embeds=inputs_embeds)
-                    if self.use_aux_hidden_state_outputs:
-                        hidden_states, _ = hidden_states
-                    else:
-                        hidden_states = hidden_states
-                    if self.use_spec_decode and isinstance(
-                            self.drafter, EagleProposer):
-                        self.drafter.dummy_run(num_tokens)
+                hidden_states = self._generate_dummy_run_hidden_states(
+                    with_prefill, is_torchair_compile, input_ids, positions,
+                    attn_metadata, num_tokens, intermediate_tensors,
+                    inputs_embeds)
             if self.speculative_config and self.speculative_config.method == "deepseek_mtp":
                 assert isinstance(self.drafter, MtpProposer)
                 self.drafter.dummy_run(