Commit 19c7b3e
[v0.9.1][bugfix] fix torchair runtime error caused by configuration mismatches and missing .kv_cache_bytes file (#2312)
### What this PR does / why we need it?
The original torchair caching implementation forced users to have everything prepared: fix every configuration option and enable `use_cached_npu_graph`, which could lead to problems that are confusing to understand and hard to tackle. It is better to compile the graph twice instead of reusing the old kv_caches and the cached torchair graph, and the extra compilation time is acceptable.

### Does this PR introduce _any_ user-facing change?
If users want to enable torchair.cache_compile with fast compilation, it is recommended to enable both `use_cached_kv_cache_bytes` and `use_cached_graph` in `torchair_graph_config`. Without `use_cached_kv_cache_bytes`, the torchair computation graph is compiled twice to avoid runtime errors caused by configuration mismatches (the second compilation is much faster).

### How was this patch tested?
CI and e2e vllm serving passed.

Signed-off-by: linfeng-yuan <[email protected]>
1 parent 9dc23b6 commit 19c7b3e
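For reference, a minimal usage sketch of the recommended configuration; the `additional_config` plumbing through `LLM(...)` and the model name below are illustrative assumptions, not part of this commit:

```python
from vllm import LLM

# Illustrative sketch: enable both cached-graph options so torchair can reuse its
# compiled graph and the recorded kv_cache_bytes across runs. The model name and
# other engine arguments are placeholders.
llm = LLM(
    model="deepseek-ai/DeepSeek-V2-Lite",
    additional_config={
        "torchair_graph_config": {
            "enabled": True,
            "use_cached_graph": True,
            "use_cached_kv_cache_bytes": True,
        },
    },
)
```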

File tree: 4 files changed, +41 −10 lines changed

vllm_ascend/ascend_config.py

Lines changed: 12 additions & 1 deletion
@@ -75,6 +75,9 @@ def __init__(self, torchair_graph_config):
         )  # Whether to enable torchair graph mode. Currently only DeepSeek series models and PanguProMoE are supported to use torchair graph mode
         self.use_cached_graph = torchair_graph_config.get(
             "use_cached_graph", False)  # Whether to use cached graph
+        self.use_cached_kv_cache_bytes = torchair_graph_config.get(
+            "use_cached_kv_cache_bytes", False
+        )  # Whether to use cached kv_caches' memory, this option can only be enabled with use_cached_graph
         self.graph_batch_sizes = torchair_graph_config.get(
             "graph_batch_sizes", [])  # The batch size for torchair graph cache
         self.graph_batch_sizes_init = torchair_graph_config.get(
@@ -106,6 +109,10 @@ def __init__(self, torchair_graph_config):
                 raise RuntimeError(
                     "use_cached_graph is valid only when Torchair graph mode is enabled"
                 )
+            if self.use_cached_kv_cache_bytes:
+                raise RuntimeError(
+                    "use_cached_kv_cache_bytes is valid only when Torchair graph mode is enabled"
+                )
             if self.graph_batch_sizes:
                 raise RuntimeError(
                     "graph_batch_sizes is valid only when Torchair graph mode is enabled"
@@ -133,8 +140,12 @@ def __init__(self, torchair_graph_config):
         if not self.enable_multistream_moe:
             if self.enable_super_kernel:
                 raise RuntimeError(
-                    "enable_super_kernel is valid only when Torchair graph mode and enable_multistream_moe is enabled"
+                    "enable_super_kernel is valid only when Torchair graph mode and enable_multistream_moe are enabled"
                 )
+        if self.use_cached_kv_cache_bytes and not self.use_cached_graph:
+            raise RuntimeError(
+                "use_cached_kv_cache_bytes is valid only when Torchair graph mode and use_cached_graph are enabled"
+            )


 class AscendSchedulerConfig:
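To make the new dependency rule concrete, here is a minimal standalone sketch (not the actual `TorchairGraphConfig` class, just a restatement of the checks added above) showing which combinations are rejected:

```python
# Simplified restatement of the validation added in ascend_config.py above.
def validate_torchair_graph_config(cfg: dict) -> None:
    enabled = cfg.get("enabled", False)
    use_cached_graph = cfg.get("use_cached_graph", False)
    use_cached_kv_cache_bytes = cfg.get("use_cached_kv_cache_bytes", False)

    if not enabled and use_cached_kv_cache_bytes:
        raise RuntimeError(
            "use_cached_kv_cache_bytes is valid only when Torchair graph mode is enabled")
    if use_cached_kv_cache_bytes and not use_cached_graph:
        raise RuntimeError(
            "use_cached_kv_cache_bytes is valid only when Torchair graph mode "
            "and use_cached_graph are enabled")


# Example: kv_cache_bytes caching without graph caching is rejected.
try:
    validate_torchair_graph_config({
        "enabled": True,
        "use_cached_graph": False,
        "use_cached_kv_cache_bytes": True,
    })
except RuntimeError as err:
    print(err)
```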

vllm_ascend/platform.py

Lines changed: 12 additions & 1 deletion
@@ -28,7 +28,10 @@
 from vllm.platforms import Platform, PlatformEnum

 from vllm_ascend.ascend_config import check_ascend_config, init_ascend_config
-from vllm_ascend.utils import ASCEND_QUATIZATION_METHOD, update_aclgraph_sizes
+from vllm_ascend.utils import (ASCEND_QUATIZATION_METHOD,
+                               check_torchair_cache_exist,
+                               delete_torchair_cache_file,
+                               update_aclgraph_sizes)

 if TYPE_CHECKING:
     from vllm.config import ModelConfig, VllmConfig
@@ -157,6 +160,14 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
                 "Torchair compilation enabled on NPU. Setting level to NO_COMPILATION"
             )
             compilation_config.level = CompilationLevel.NO_COMPILATION
+            # Note: we delete the torchair cache folder here to prevent runtime issues caused by
+            # dimension mismatches or configuration inconsistencies when users reuse cached
+            # computation graphs. Though this increases graph compilation time, it significantly
+            # improves robustness and decreases graph launch time during inference. To shorten
+            # torchair graph compilation, users can enable both `use_cached_graph` and
+            # `use_cached_kv_cache_bytes` in torchair_graph_config.
+            if check_torchair_cache_exist(
+            ) and not ascend_config.torchair_graph_config.use_cached_kv_cache_bytes:
+                delete_torchair_cache_file()
         elif parallel_config.distributed_executor_backend == "ray":
             logger.warning(
                 "Ray distributed executor backend is not compatible with ACL Graph mode "

vllm_ascend/worker/model_runner_v1.py

Lines changed: 14 additions & 7 deletions
@@ -84,7 +84,7 @@
 from vllm_ascend.multistream.ms_split import compute_split_seq_index
 from vllm_ascend.platform import NPUPlatform
 from vllm_ascend.sample.rejection_sampler import AscendRejectionSampler
-from vllm_ascend.utils import (ProfileExecuteDuration,
+from vllm_ascend.utils import (TORCHAIR_CACHE_DIR, ProfileExecuteDuration,
                                check_torchair_cache_exist,
                                write_kv_cache_bytes_to_file)
 from vllm_ascend.worker.mtp_proposer_v1 import MtpProposer
@@ -360,6 +360,7 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
         ascend_config = get_ascend_config()
         self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled and self.vllm_config.model_config.use_mla
         self.use_cached_npu_graph = ascend_config.torchair_graph_config.use_cached_graph
+        self.use_cached_kv_cache_bytes = ascend_config.torchair_graph_config.use_cached_kv_cache_bytes
         self.torchair_graph_batch_sizes = ascend_config.torchair_graph_config.graph_batch_sizes

         if ascend_config.torchair_graph_config.graph_batch_sizes_init:
@@ -1904,6 +1905,7 @@ def _get_torchair_lazy_compiled_model(self, batch_size: int):
                 self.model.__dict__[forward_proxy_name],
                 dynamic=True,
                 fullgraph=envs_vllm.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE,
+                cache_dir=TORCHAIR_CACHE_DIR,
                 config=config,
                 ge_cache=False)
         return self.torchair_compiled_models[batch_size]
@@ -2082,14 +2084,20 @@ def capture_model(self) -> None:
             torchair_graph_batch_sizes = self.torchair_graph_batch_sizes
             graph_num = len(torchair_graph_batch_sizes)
             if self.use_cached_npu_graph and not check_torchair_cache_exist():
-                # If caching is enabled but does not exist, we will compile the model twice. The first
-                # time is used to generate the cache, and the second time is used to load the cache to
-                # skip the overhead caused by Dynamo guard mechanism.
+                # If caching is enabled but the cache does not exist (either
+                # use_cached_kv_cache_bytes is disabled or kv_cache_bytes are
+                # different), we will compile the model twice. The first time is
+                # used to generate the cache, and the second time is used to load the
+                # cache to skip the overhead caused by the Dynamo guard mechanism.
                 logger.info(
-                    "Use cached npu graph but cache doesn't exist! Now we compile graph to genetate torchair cache, this usually takes %.1f~%.1f mins.",
+                    "Cache compilation for torchair graph is enabled. Now we compile graph to generate"
+                    " torchair cache, this usually takes %.1f~%.1f mins.",
                     0.5 * graph_num, 1.5 * graph_num)
                 self._compile_torchair_graph(torchair_graph_batch_sizes)
                 NPUPlatform.synchronize()
+                # Note: we reset dynamo below and reload the torchair computation graph that was
+                # compiled and cached above. This reduces graph launch time by 2-4ms and avoids
+                # runtime errors caused by configuration mismatches in graph mode.
                 torch._dynamo.reset()
                 self.torchair_compiled_models.clear()
             if self.speculative_config and self.speculative_config.method == "deepseek_mtp":
@@ -2104,8 +2112,7 @@ def capture_model(self) -> None:
                     "Capturing torchair graph, this usually takes %.1f~%.1f mins.",
                     0.5 * graph_num, 1.5 * graph_num)
                 self._compile_torchair_graph(torchair_graph_batch_sizes)
-
-            if self.new_kv_cache_bytes > 0:
+            if self.use_cached_kv_cache_bytes and self.new_kv_cache_bytes > 0:
                 write_kv_cache_bytes_to_file(torch.distributed.get_rank(),
                                              self.new_kv_cache_bytes)
         elif self.use_aclgraph:
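The overall effect of this change on `capture_model()` can be summarized by the following condensed sketch; it paraphrases the flow shown above rather than reproducing the method, and the function name is made up for illustration:

```python
import torch

from vllm_ascend.platform import NPUPlatform
from vllm_ascend.utils import check_torchair_cache_exist, write_kv_cache_bytes_to_file


def capture_torchair_graphs(runner, batch_sizes):
    """Condensed paraphrase of the compile-twice strategy in capture_model()."""
    if runner.use_cached_npu_graph and not check_torchair_cache_exist():
        # Pass 1: compile every batch size once to populate the on-disk torchair cache.
        runner._compile_torchair_graph(batch_sizes)
        NPUPlatform.synchronize()
        # Reset Dynamo and drop the in-memory compiled models so pass 2 reloads from
        # the freshly written cache, avoiding Dynamo guard overhead at runtime.
        torch._dynamo.reset()
        runner.torchair_compiled_models.clear()
    # Pass 2 (or the only pass when caching is off / the cache already exists).
    runner._compile_torchair_graph(batch_sizes)
    # Persist kv_cache_bytes only when the user opted into reusing them across runs.
    if runner.use_cached_kv_cache_bytes and runner.new_kv_cache_bytes > 0:
        write_kv_cache_bytes_to_file(torch.distributed.get_rank(),
                                     runner.new_kv_cache_bytes)
```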

vllm_ascend/worker/worker_v1.py

Lines changed: 3 additions & 1 deletion
@@ -193,7 +193,9 @@ def determine_available_memory(self) -> int:
         logger.info(
             f"Available memory: {available_kv_cache_memory}, total memory: {total_npu_memory}"
         )
-        if get_ascend_config().torchair_graph_config.enabled:
+        if (get_ascend_config().torchair_graph_config.enabled
+                and get_ascend_config(
+                ).torchair_graph_config.use_cached_kv_cache_bytes):
             if check_torchair_cache_exist(
             ) and check_kv_cache_bytes_cache_exist():
                 old_kv_cache_bytes = read_kv_cache_bytes_from_file(
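A hedged sketch of the decision this gate implements in `determine_available_memory()`; the helper name and the final "cached value must still fit" comparison are assumptions based on surrounding code that the hunk above does not show in full:

```python
from typing import Optional


def choose_kv_cache_bytes(measured_bytes: int,
                          torchair_enabled: bool,
                          use_cached_kv_cache_bytes: bool,
                          cached_bytes: Optional[int]) -> int:
    # The previously recorded kv_cache_bytes are only consulted when torchair graph
    # mode AND use_cached_kv_cache_bytes are both on; otherwise the value measured
    # for this run wins and the torchair graph is recompiled against it.
    if torchair_enabled and use_cached_kv_cache_bytes and cached_bytes is not None:
        if 0 < cached_bytes <= measured_bytes:
            # Reuse the recorded size so the cached torchair graph stays valid.
            return cached_bytes
    return measured_bytes
```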
