Commit c8b0f5f

[4/N][Refactor] torchair model runner refactor (#2208)
There is a lot of torchair code in the model runner, which makes the code hard to maintain. We'll create a new torchair_model_runner to split out the torchair-related logic. Following the workflow in #2203, this is the first PR. What this PR does: create the common function `_convert_torch_format` for initialize_kv_cache.

- vLLM version: v0.10.0
- vLLM main: vllm-project/vllm@14a5d90

Signed-off-by: wangxiyuan <[email protected]>
1 parent eb43a47 commit c8b0f5f
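
The change follows a simple template-method pattern: the base NPUModelRunner defines a single `_convert_torch_format` hook that every format cast goes through, and the torchair runner overrides that hook to pin its own layout. Below is a minimal sketch of the pattern, assuming an Ascend environment with torch_npu installed; the ACL format ids and the TorchairModelRunner class name here are illustrative, not taken from this commit.

    import torch
    import torch_npu

    # Illustrative ACL format ids; the real constants are exported by
    # vllm_ascend.utils as ACL_FORMAT_FRACTAL_ND / ACL_FORMAT_FRACTAL_NZ.
    ACL_FORMAT_FRACTAL_ND = 2
    ACL_FORMAT_FRACTAL_NZ = 29

    # Module-level default, mirroring the model_runner_v1.py hunk below:
    # NZ on the 310P SoC, ND everywhere else.
    ACL_FORMAT = ACL_FORMAT_FRACTAL_ND

    class NPUModelRunner:
        def _convert_torch_format(self, tensor: torch.Tensor) -> torch.Tensor:
            # Base runner: cast weights / KV caches to the platform default.
            return torch_npu.npu_format_cast(tensor, ACL_FORMAT)

    class TorchairModelRunner(NPUModelRunner):
        def _convert_torch_format(self, kv_cache: torch.Tensor) -> torch.Tensor:
            # Torchair graph mode always wants the plain ND layout.
            return torch_npu.npu_format_cast(kv_cache, ACL_FORMAT_FRACTAL_ND)

Call sites such as initialize_kv_cache then just write `kv_cache = self._convert_torch_format(kv_cache)` and get the right layout for whichever runner is active.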

File tree: 2 files changed, +18 -13 lines changed

vllm_ascend/torchair/torchair_model_runner.py

Lines changed: 6 additions & 1 deletion
@@ -20,10 +20,11 @@
 from typing import Optional
 
 import torch
+import torch_npu
 from vllm.config import VllmConfig
 from vllm.forward_context import get_forward_context
 
-from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ,
+from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ,
                                maybe_converting_weight_acl_format)
 from vllm_ascend.worker.model_runner_v1 import NPUModelRunner
 
@@ -113,3 +114,7 @@ def _generate_dummy_run_hidden_states(self, with_prefill,
             with_prefill, is_torchair_compile, input_ids, positions,
             attn_metadata, num_tokens, intermediate_tensors, inputs_embeds)
         return hidden_states
+
+    def _convert_torch_format(self, kv_cache):
+        kv_cache = torch_npu.npu_format_cast(kv_cache, ACL_FORMAT_FRACTAL_ND)
+        return kv_cache
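
Note the hard-coded ACL_FORMAT_FRACTAL_ND in this override. It matches the behavior deleted further down in model_runner_v1.py: initialize_kv_cache used to compute `acl_format = ACL_FORMAT_FRACTAL_NZ if is_310p() and not self.torchair_graph_enabled else ACL_FORMAT_FRACTAL_ND`, so a torchair-enabled run always ended up with ND-format KV caches. The override simply makes that choice local to the torchair runner.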

vllm_ascend/worker/model_runner_v1.py

Lines changed: 12 additions & 12 deletions
@@ -110,6 +110,9 @@
 
 if is_310p():
     torch_npu.npu.set_compile_mode(jit_compile=False)
+    ACL_FORMAT = ACL_FORMAT_FRACTAL_NZ
+else:
+    ACL_FORMAT = ACL_FORMAT_FRACTAL_ND
 
 
 @dataclass
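
With this hunk the layout choice is made once at import time rather than per call: the 310P SoC needs the fractal NZ layout, every other SoC uses plain ND. As a small sanity check, a tensor's resulting layout can be inspected on-device; this sketch assumes `torch_npu.get_npu_format` is the introspection helper (it is the usual torch_npu call for this, but treat it as an assumption):

    import torch
    import torch_npu

    x = torch.zeros(16, 128, device="npu")
    y = torch_npu.npu_format_cast(x, ACL_FORMAT)  # module-level default
    print(torch_npu.get_npu_format(y))            # prints the ACL format id
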
@@ -2047,8 +2050,8 @@ def load_model(self) -> None:
             if isinstance(module,
                           (MergedColumnParallelLinear,
                            QKVParallelLinear, RowParallelLinear)):
-                module.weight.data = torch_npu.npu_format_cast(
-                    module.weight.data, ACL_FORMAT_FRACTAL_NZ)
+                module.weight.data = self._convert_torch_format(
+                    module.weight.data)
         if self.drafter:
             logger.info("Loading drafter model...")
             if isinstance(self.drafter, EagleProposer):
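
Weight loading now funnels through the same hook. One subtlety: the deleted lines cast linear-layer weights to ACL_FORMAT_FRACTAL_NZ unconditionally, while the hook uses the platform default (NZ only on 310P). Assuming this cast is only reached on 310P paths, where ACL_FORMAT is NZ, behavior is unchanged; on other SoCs the hook would now yield ND.
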
@@ -2133,6 +2136,10 @@ def _get_torchair_lazy_compiled_model(self, batch_size: int):
                 ge_cache=False)
         return self.torchair_compiled_models[batch_size]
 
+    def _convert_torch_format(self, tensor):
+        tensor = torch_npu.npu_format_cast(tensor, ACL_FORMAT)
+        return tensor
+
     def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
         """
         Initialize KV cache based on `kv_cache_config`.
@@ -2141,9 +2148,6 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
             cache size of each layer
         """
         self.kv_cache_config = kv_cache_config
-        import torch_npu
-        acl_format = ACL_FORMAT_FRACTAL_NZ if is_310p(
-        ) and not self.torchair_graph_enabled else ACL_FORMAT_FRACTAL_ND
         kv_caches: Dict[str, torch.Tensor] = {}
 
         def align_memory(tensor: torch.Tensor, alignment: int) -> torch.Tensor:
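
This is the payoff of the refactor: the per-call acl_format computation, along with the stray function-local `import torch_npu`, disappears. The decision table is preserved: non-310P runs get ND from the module-level ACL_FORMAT, plain 310P runs get NZ, and 310P-with-torchair runs get ND via the torchair runner's override instead of the old `not self.torchair_graph_enabled` clause.
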
@@ -2202,7 +2206,6 @@ def align_memory(tensor: torch.Tensor, alignment: int) -> torch.Tensor:
                     kv_cache_spec.head_size)
                 dtype = kv_cache_spec.dtype
                 if self.model_config.is_deepseek_mla:
-
                     num_blocks, block_size, num_kv_heads, head_size = kv_cache_shape
                     rope_dim = self.model_config.hf_text_config.qk_rope_head_dim
                     nope_dim = head_size - rope_dim
@@ -2218,10 +2221,8 @@ def align_memory(tensor: torch.Tensor, alignment: int) -> torch.Tensor:
                     nope_cache = torch.zeros(nope_cache_shape,
                                              dtype=dtype,
                                              device=self.device)
-                    rope_cache = torch_npu.npu_format_cast(
-                        rope_cache, acl_format)
-                    nope_cache = torch_npu.npu_format_cast(
-                        nope_cache, acl_format)
+                    rope_cache = self._convert_torch_format(rope_cache)
+                    nope_cache = self._convert_torch_format(nope_cache)
                 else:
 
                     # In order to transfer kv cache through the reigster_memory api from llmdatadist, the memory
@@ -2259,8 +2260,7 @@ def align_memory(tensor: torch.Tensor, alignment: int) -> torch.Tensor:
                         kv_cache = torch.zeros(cache_shape,
                                                dtype=dtype,
                                                device=self.device)
-                        kv_cache = torch_npu.npu_format_cast(
-                            kv_cache, acl_format)
+                        kv_cache = self._convert_torch_format(kv_cache)
                     else:
                         cache_size = math.prod(cache_shape)
                         cache_size_aligned = cache_size + alignment
