 
 if is_310p():
     torch_npu.npu.set_compile_mode(jit_compile=False)
+    ACL_FORMAT = ACL_FORMAT_FRACTAL_NZ
+else:
+    ACL_FORMAT = ACL_FORMAT_FRACTAL_ND
 
 
 @dataclass
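For reference, the new module-level constant can be reproduced in isolation. A minimal sketch, assuming `is_310p` and the two ACL format ids come from `vllm_ascend.utils` (import path assumed) and falling back where `torch_npu` is not installed:

# Sketch of the format selection above; the vllm_ascend.utils import path and
# the ImportError fallback are assumptions for illustration only.
try:
    import torch_npu
    from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND,
                                   ACL_FORMAT_FRACTAL_NZ, is_310p)

    if is_310p():
        # 310P devices disable JIT compilation and use the FRACTAL_NZ layout,
        # as in the hunk above.
        torch_npu.npu.set_compile_mode(jit_compile=False)
        ACL_FORMAT = ACL_FORMAT_FRACTAL_NZ
    else:
        ACL_FORMAT = ACL_FORMAT_FRACTAL_ND
except ImportError:
    ACL_FORMAT = None  # non-Ascend environment: no format cast is applied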
@@ -2047,8 +2050,8 @@ def load_model(self) -> None:
                 if isinstance(module,
                               (MergedColumnParallelLinear,
                                QKVParallelLinear, RowParallelLinear)):
-                    module.weight.data = torch_npu.npu_format_cast(
-                        module.weight.data, ACL_FORMAT_FRACTAL_NZ)
+                    module.weight.data = self._convert_torch_format(
+                        module.weight.data)
         if self.drafter:
             logger.info("Loading drafter model...")
             if isinstance(self.drafter, EagleProposer):
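In `load_model` the per-weight cast now goes through the shared helper instead of calling `torch_npu.npu_format_cast` with a hard-coded `ACL_FORMAT_FRACTAL_NZ`. A hedged sketch of the same pattern outside the runner, using plain `nn.Linear` only to keep the example self-contained (the PR itself touches vLLM's MergedColumnParallelLinear, QKVParallelLinear and RowParallelLinear layers):

import torch
import torch_npu  # Ascend PyTorch adapter, assumed to be installed

def cast_linear_weights(model: torch.nn.Module, acl_format: int) -> None:
    # Cast every linear weight to the selected ACL layout in place,
    # mirroring the loop in load_model shown above.
    for module in model.modules():
        if isinstance(module, torch.nn.Linear):
            module.weight.data = torch_npu.npu_format_cast(
                module.weight.data, acl_format)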
@@ -2133,6 +2136,10 @@ def _get_torchair_lazy_compiled_model(self, batch_size: int):
                 ge_cache=False)
         return self.torchair_compiled_models[batch_size]
 
+    def _convert_torch_format(self, tensor):
+        tensor = torch_npu.npu_format_cast(tensor, ACL_FORMAT)
+        return tensor
+
     def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
         """
         Initialize KV cache based on `kv_cache_config`.
@@ -2141,9 +2148,6 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
             cache size of each layer
         """
         self.kv_cache_config = kv_cache_config
-        import torch_npu
-        acl_format = ACL_FORMAT_FRACTAL_NZ if is_310p(
-        ) and not self.torchair_graph_enabled else ACL_FORMAT_FRACTAL_ND
         kv_caches: Dict[str, torch.Tensor] = {}
 
         def align_memory(tensor: torch.Tensor, alignment: int) -> torch.Tensor:
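With the local `import torch_npu` and per-call `acl_format` selection removed, every cast in `initialize_kv_cache` now routes through `_convert_torch_format` and the single module-level `ACL_FORMAT`. A minimal usage sketch, assuming an NPU device is available and that your torch_npu build exposes `get_npu_format` for inspecting the resulting layout (both assumptions):

import torch
import torch_npu
from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ  # import path assumed

def convert_torch_format(tensor: torch.Tensor) -> torch.Tensor:
    # Standalone equivalent of the runner's _convert_torch_format helper.
    return torch_npu.npu_format_cast(tensor, ACL_FORMAT_FRACTAL_NZ)

# Allocate a KV-cache-shaped tensor on the NPU and convert its layout.
cache = torch.zeros(1024, 128, 8, 128, dtype=torch.float16, device="npu")
cache = convert_torch_format(cache)
print(torch_npu.get_npu_format(cache))  # should report the NZ format id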
@@ -2202,7 +2206,6 @@ def align_memory(tensor: torch.Tensor, alignment: int) -> torch.Tensor:
                                                       kv_cache_spec.head_size)
                 dtype = kv_cache_spec.dtype
                 if self.model_config.is_deepseek_mla:
-
                     num_blocks, block_size, num_kv_heads, head_size = kv_cache_shape
                     rope_dim = self.model_config.hf_text_config.qk_rope_head_dim
                     nope_dim = head_size - rope_dim
@@ -2218,10 +2221,8 @@ def align_memory(tensor: torch.Tensor, alignment: int) -> torch.Tensor:
                     nope_cache = torch.zeros(nope_cache_shape,
                                              dtype=dtype,
                                              device=self.device)
-                    rope_cache = torch_npu.npu_format_cast(
-                        rope_cache, acl_format)
-                    nope_cache = torch_npu.npu_format_cast(
-                        nope_cache, acl_format)
+                    rope_cache = self._convert_torch_format(rope_cache)
+                    nope_cache = self._convert_torch_format(nope_cache)
                 else:
 
                     # In order to transfer kv cache through the reigster_memory api from llmdatadist, the memory
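The MLA branch above allocates the RoPE part and the non-positional ("nope") part of the cache separately before converting both to `ACL_FORMAT`. A shape-only sketch of that split, with DeepSeek-style dimensions filled in purely as example values; the exact rope_cache_shape/nope_cache_shape construction lives in the runner and is not reproduced here:

import torch

# Example values only: 576 = 512 (nope / kv_lora_rank) + 64 (qk_rope_head_dim).
num_blocks, block_size, num_kv_heads, head_size = 1024, 128, 1, 576
rope_dim = 64
nope_dim = head_size - rope_dim

rope_cache = torch.zeros(num_blocks, block_size, num_kv_heads, rope_dim,
                         dtype=torch.bfloat16)
nope_cache = torch.zeros(num_blocks, block_size, num_kv_heads, nope_dim,
                         dtype=torch.bfloat16)
# In the runner both tensors are then passed through _convert_torch_format.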
@@ -2259,8 +2260,7 @@ def align_memory(tensor: torch.Tensor, alignment: int) -> torch.Tensor:
                         kv_cache = torch.zeros(cache_shape,
                                                dtype=dtype,
                                                device=self.device)
-                        kv_cache = torch_npu.npu_format_cast(
-                            kv_cache, acl_format)
+                        kv_cache = self._convert_torch_format(kv_cache)
                     else:
                         cache_size = math.prod(cache_shape)
                         cache_size_aligned = cache_size + alignment
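The else branch over-allocates by `alignment` elements (`cache_size_aligned = cache_size + alignment`) so the buffer can later be registered with llmdatadist, and the `align_memory` helper named in the hunk headers returns an aligned view into it. A hypothetical sketch of that pattern; the runner's actual implementation may differ:

import torch

def align_memory(tensor: torch.Tensor, alignment: int) -> torch.Tensor:
    # Return a view of `tensor` whose data pointer is a multiple of
    # `alignment` bytes. The caller's over-allocation by `alignment`
    # elements leaves enough slack for the offset.
    data_ptr = tensor.data_ptr()
    aligned_ptr = (data_ptr + alignment - 1) // alignment * alignment
    offset = (aligned_ptr - data_ptr) // tensor.element_size()
    return tensor[int(offset):]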