
Commit 1a70564

[5/N][Refactor] torchair model runner refactor (#2216)
There is a lot of torchair code in the model runner, which makes the code hard to maintain. We create a new torchair_model_runner to split out the torchair-related logic, following the workflow in #2203.

What this PR does: create a common function `_capture_model` for capture_model.

- vLLM version: v0.10.0
- vLLM main: vllm-project/vllm@1891a26

Signed-off-by: wangxiyuan <[email protected]>
1 parent 49ec6c9 commit 1a70564
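The refactor follows a template-method split: the base runner keeps the shared capture_model entry point and calls an overridable _capture_model hook, which the torchair runner replaces with its own graph capture. Below is a minimal sketch of that shape; the names capture_model, _capture_model, and NPUModelRunner match the diff, while TorchairModelRunner and the method bodies are simplified stand-ins, not the actual vllm-ascend code.

# Simplified sketch of the hook-and-override structure this PR introduces.
# NPUModelRunner and _capture_model mirror the real names; the bodies are
# placeholders, and TorchairModelRunner stands in for the new subclass in
# vllm_ascend/torchair/torchair_model_runner.py.

class NPUModelRunner:
    def capture_model(self) -> None:
        # Shared bookkeeping (timing, memory accounting) stays here.
        self._capture_model()

    def _capture_model(self) -> None:
        # Default path: ACL graph capture, or skip in eager mode.
        print("capturing ACL graphs")


class TorchairModelRunner(NPUModelRunner):
    def _capture_model(self) -> None:
        # Torchair-specific capture, moved out of the base runner.
        print("capturing torchair graphs")


TorchairModelRunner().capture_model()  # prints "capturing torchair graphs"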

File tree

2 files changed (+66 −57)


vllm_ascend/torchair/torchair_model_runner.py

Lines changed: 51 additions & 0 deletions
@@ -23,7 +23,11 @@
 import torch_npu
 from vllm.config import VllmConfig
 from vllm.forward_context import get_forward_context
+from vllm.logger import logger

+from vllm_ascend.platform import NPUPlatform
+from vllm_ascend.torchair.utils import (check_torchair_cache_exist,
+                                        write_kv_cache_bytes_to_file)
 from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ,
                                maybe_converting_weight_acl_format)
 from vllm_ascend.worker.model_runner_v1 import NPUModelRunner
@@ -37,6 +41,7 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
     def _get_forward_metadata_across_dp_and_pad(
             self, num_tokens: int, with_prefill: bool, enable_dbo: bool
     ) -> tuple[int, Optional[torch.Tensor], bool, bool]:
+        """Override from NPUModelRunner to pad num_tokens"""
         if self.dp_size == 1:
             if not with_prefill:
                 maybe_padded_num_tokens = self.select_torchair_padded_batch_size(
@@ -118,3 +123,49 @@ def _generate_dummy_run_hidden_states(self, with_prefill,
     def _convert_torch_format(self, kv_cache):
         kv_cache = torch_npu.npu_format_cast(kv_cache, ACL_FORMAT_FRACTAL_ND)
         return kv_cache
+
+    def _compile_torchair_graph(self, torchair_graph_batch_sizes) -> None:
+        # Trigger torchair graph capture for specific shapes.
+        # Capture the large shapes first so that the smaller shapes
+        # can reuse the memory pool allocated for the large shapes.
+        for idx, num_tokens in enumerate(reversed(torchair_graph_batch_sizes)):
+            for _ in range(self.vllm_config.compilation_config.
+                           cudagraph_num_of_warmups):
+                self._dummy_run(num_tokens, is_torchair_compile=True)
+            self._dummy_run(num_tokens, is_torchair_compile=True)
+            logger.info("Batchsize %d is compiled successfully: %d/%d.",
+                        num_tokens, idx + 1, len(torchair_graph_batch_sizes))
+
+    def _capture_model(self):
+        """Override from NPUModelRunner to use torchair graph capture."""
+        # TODO(NeverRaR): Calling graph_capture(device=self.device) in
+        # torchair graph capture can cause some issues, so now we just
+        # temporarily split the codepath for the two different graph patterns.
+        torchair_graph_batch_sizes = self.torchair_graph_batch_sizes
+        graph_num = len(torchair_graph_batch_sizes)
+
+        if self.use_cached_npu_graph and not check_torchair_cache_exist():
+            # If caching is enabled but the cache does not exist, we compile
+            # the model twice: the first pass generates the cache, and the
+            # second pass loads it to skip the overhead caused by the Dynamo
+            # guard mechanism.
+            logger.info(
+                "Use cached npu graph but cache doesn't exist! Now we compile graph to generate torchair cache, this usually takes %.1f~%.1f mins.",
+                0.5 * graph_num, 1.5 * graph_num)
+            self._compile_torchair_graph(torchair_graph_batch_sizes)
+            NPUPlatform.synchronize()
+            torch._dynamo.reset()
+            self.torchair_compiled_models.clear()
+        if self.use_cached_npu_graph:
+            logger.info(
+                "Loading torchair graph cache, this usually takes %.1f~%.1f mins.",
+                0.3 * graph_num, 0.5 * graph_num)
+            self._compile_torchair_graph(torchair_graph_batch_sizes)
+        else:
+            logger.info(
+                "Capturing torchair graph, this usually takes %.1f~%.1f mins.",
+                0.5 * graph_num, 1.5 * graph_num)
+            self._compile_torchair_graph(torchair_graph_batch_sizes)
+
+        if self.new_kv_cache_bytes > 0:
+            write_kv_cache_bytes_to_file(torch.distributed.get_rank(),
                                         self.new_kv_cache_bytes)
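One subtlety in _capture_model above: when use_cached_npu_graph is set but no cache exists on disk yet, compilation runs twice. The first pass writes the torchair cache, Dynamo state is then reset, and the second pass loads from the cache so later runs avoid Dynamo guard overhead. A hedged sketch of just that control flow; the flags mimic the real runner state, and the returned step strings are descriptions rather than actual compilation calls.

# Illustrative decision flow for the cache branch in _capture_model.
# use_cached_npu_graph and cache_exists mimic the real flags; the steps
# are descriptions, not real _compile_torchair_graph calls.

def plan_capture(use_cached_npu_graph: bool, cache_exists: bool) -> list[str]:
    steps = []
    if use_cached_npu_graph and not cache_exists:
        steps.append("compile all batch sizes to generate the torchair cache")
        steps.append("synchronize, reset Dynamo, clear compiled models")
    if use_cached_npu_graph:
        steps.append("compile again, this time loading from the cache")
    else:
        steps.append("capture torchair graphs directly (no cache)")
    return steps


for step in plan_capture(use_cached_npu_graph=True, cache_exists=False):
    print(step)  # shows the compile-twice behavior when the cache is missing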

vllm_ascend/worker/model_runner_v1.py

Lines changed: 15 additions & 57 deletions
@@ -82,8 +82,6 @@
 from vllm_ascend.multistream.ms_split import compute_split_seq_index
 from vllm_ascend.platform import NPUPlatform
 from vllm_ascend.sample.rejection_sampler import AscendRejectionSampler
-from vllm_ascend.torchair.utils import (check_torchair_cache_exist,
-                                        write_kv_cache_bytes_to_file)
 from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ,
                                ProfileExecuteDuration, is_310p,
                                maybe_converting_weight_acl_format,
@@ -2323,67 +2321,27 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:

         return kv_cache_spec

-    def _compile_torchair_graph(self, torchair_graph_batch_sizes) -> None:
-        # Trigger torchair graph capture for specific shapes.
+    def _capture_model(self):
+        if not self.use_aclgraph:
+            logger.info("Skipping NPU graph capture for eager mode.")
+            return
+        # Trigger ACL graph capture for specific shapes.
         # Capture the large shapes first so that the smaller shapes
         # can reuse the memory pool allocated for the large shapes.
-        for idx, num_tokens in enumerate(reversed(torchair_graph_batch_sizes)):
-            for _ in range(self.vllm_config.compilation_config.
-                           cudagraph_num_of_warmups):
-                self._dummy_run(num_tokens, is_torchair_compile=True)
-            self._dummy_run(num_tokens, is_torchair_compile=True)
-            logger.info("Batchsize %d is compiled successfully: %d/%d.",
-                        num_tokens, idx + 1, len(torchair_graph_batch_sizes))
+        # TODO(zzzzwwjj): Check dummy_run with ACL Graph and full graph mode
+        with graph_capture(device=self.device):
+            for num_tokens in reversed(self.aclgraph_batch_sizes):
+                for _ in range(self.vllm_config.compilation_config.
+                               cudagraph_num_of_warmups):
+                    self._dummy_run(num_tokens)
+                self._dummy_run(num_tokens)

     def capture_model(self) -> None:
         start_time = time.perf_counter()
         start_free_npu_memory = torch.npu.mem_get_info()[0]
-        # TODO(NeverRaR): Calling graph_capture(device=self.device) in
-        # torchair graph capture can cause some issues, so now we just
-        # temporarily split the codepath for the two different graph patterns.
-        if self.torchair_graph_enabled:
-            torchair_graph_batch_sizes = self.torchair_graph_batch_sizes
-            graph_num = len(torchair_graph_batch_sizes)
-
-            if self.use_cached_npu_graph and not check_torchair_cache_exist():
-                # If caching is enabled but the cache does not exist, we compile
-                # the model twice: the first pass generates the cache, and the
-                # second pass loads it to skip the overhead caused by the Dynamo
-                # guard mechanism.
-                logger.info(
-                    "Use cached npu graph but cache doesn't exist! Now we compile graph to generate torchair cache, this usually takes %.1f~%.1f mins.",
-                    0.5 * graph_num, 1.5 * graph_num)
-                self._compile_torchair_graph(torchair_graph_batch_sizes)
-                NPUPlatform.synchronize()
-                torch._dynamo.reset()
-                self.torchair_compiled_models.clear()
-            if self.use_cached_npu_graph:
-                logger.info(
-                    "Loading torchair graph cache, this usually takes %.1f~%.1f mins.",
-                    0.3 * graph_num, 0.5 * graph_num)
-                self._compile_torchair_graph(torchair_graph_batch_sizes)
-            else:
-                logger.info(
-                    "Capturing torchair graph, this usually takes %.1f~%.1f mins.",
-                    0.5 * graph_num, 1.5 * graph_num)
-                self._compile_torchair_graph(torchair_graph_batch_sizes)
-
-            if self.new_kv_cache_bytes > 0:
-                write_kv_cache_bytes_to_file(torch.distributed.get_rank(),
-                                             self.new_kv_cache_bytes)
-        elif self.use_aclgraph:
-            # Trigger ACL graph capture for specific shapes.
-            # Capture the large shapes first so that the smaller shapes
-            # can reuse the memory pool allocated for the large shapes.
-            # TODO(zzzzwwjj): Check dummy_run with ACL Graph and full graph mode
-            with graph_capture(device=self.device):
-                for num_tokens in reversed(self.aclgraph_batch_sizes):
-                    for _ in range(self.vllm_config.compilation_config.
-                                   cudagraph_num_of_warmups):
-                        self._dummy_run(num_tokens)
-                    self._dummy_run(num_tokens)
-        else:
-            logger.info("Skipping NPU graph capture for eager mode.")
-            return
+
+        self._capture_model()
+
         end_time = time.perf_counter()
         end_free_npu_memory = torch.npu.mem_get_info()[0]
         elapsed_time = end_time - start_time
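The ACL-graph path that remains in the base runner iterates batch sizes in reverse so the largest shapes are captured first and smaller captures can reuse the memory pool allocated for them, with cudagraph_num_of_warmups uncaptured warmup passes before each capture. A small sketch of that loop order; the helper names are illustrative stand-ins for self._dummy_run, not the vLLM API.

# Illustrative capture-order loop mirroring _capture_model in
# model_runner_v1.py: largest batch first, warmups before each capture.

def dummy_run(num_tokens: int, captured: bool) -> None:
    phase = "capture" if captured else "warmup"
    print(f"{phase}: {num_tokens} tokens")


def capture_all(batch_sizes: list[int], num_warmups: int) -> None:
    # batch_sizes is assumed sorted ascending, as in the runner.
    for num_tokens in reversed(batch_sizes):
        for _ in range(num_warmups):
            dummy_run(num_tokens, captured=False)  # warmup, not recorded
        dummy_run(num_tokens, captured=True)       # recorded into the graph


capture_all([1, 8, 64], num_warmups=1)  # captures 64, then 8, then 1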
