2 changes: 1 addition & 1 deletion vllm_ascend/attention/mla_v1.py
@@ -652,7 +652,7 @@ def get_and_maybe_dequant_weights(layer: LinearBase):

# Function `get_and_maybe_dequant_weights` will cast the weights to
# FRACTAL_AND. So we need to cast to FRACTAL_NZ again.
-if is_enable_nz():
+if is_enable_nz(self.kv_b_proj.weight.data.dtype):
self.kv_b_proj.weight.data = torch_npu.npu_format_cast(
self.kv_b_proj.weight.data, ACL_FORMAT_FRACTAL_NZ)

4 changes: 2 additions & 2 deletions vllm_ascend/models/qwen2_5_vl.py
@@ -284,7 +284,7 @@ def pad_qkv_weight(self, data):
dim=2)
qkv_weight_final = qkv_weight_padded.reshape(-1, self.hidden_size)

-if is_enable_nz():
+if is_enable_nz(qkv_weight_final.dtype):
qkv_weight_final_copy = torch.empty_like(qkv_weight_final).copy_(
qkv_weight_final)
qkv_weight_final_copy = torch_npu.npu_format_cast(
@@ -300,7 +300,7 @@ def pad_proj_weight(self, data):
(0, self.half_pad_hidden_size_per_attention_head, 0, 0)).reshape(
self.hidden_size, -1)

-if is_enable_nz():
+if is_enable_nz(out_weight.dtype):
out_weight_copy = torch.empty_like(out_weight).copy_(out_weight)
out_weight_copy = torch_npu.npu_format_cast(
out_weight_copy, ACL_FORMAT_FRACTAL_ND)
4 changes: 2 additions & 2 deletions vllm_ascend/models/qwen2_vl.py
@@ -268,7 +268,7 @@ def pad_qkv_weight(self, data):
dim=2)
qkv_weight_final = qkv_weight_padded.reshape(-1, self.hidden_size)

-if is_enable_nz():
+if is_enable_nz(qkv_weight_final.dtype):
qkv_weight_final_copy = torch.empty_like(qkv_weight_final).copy_(
qkv_weight_final)
qkv_weight_final_copy = torch_npu.npu_format_cast(
@@ -284,7 +284,7 @@ def pad_proj_weight(self, data):
(0, self.half_pad_hidden_size_per_attention_head, 0, 0)).reshape(
self.hidden_size, -1)

-if is_enable_nz():
+if is_enable_nz(out_weight.dtype):
out_weight_copy = torch.empty_like(out_weight).copy_(out_weight)
out_weight_copy = torch_npu.npu_format_cast(
out_weight_copy, ACL_FORMAT_FRACTAL_ND)
2 changes: 1 addition & 1 deletion vllm_ascend/ops/common_fused_moe.py
@@ -76,7 +76,7 @@ def process_weights_after_loading(self, layer):
w2_data = self._maybe_pad_weight(layer.w2_weight.data)
layer.w2_weight = torch.nn.Parameter(w2_data, requires_grad=False)

-if not is_310p() and is_enable_nz():
+if not is_310p() and is_enable_nz(layer.w13_weight.data.type):

critical

This call to is_enable_nz passes layer.w13_weight.data.type, which appears to be a mistake. torch.Tensor.type is a method that returns a type string (e.g. 'torch.npu.FloatTensor'), not a torch.dtype object. The is_enable_nz function expects a torch.dtype object.

Use the .dtype attribute to get the tensor's data type.

Suggested change
-if not is_310p() and is_enable_nz(layer.w13_weight.data.type):
+if not is_310p() and is_enable_nz(layer.w13_weight.data.dtype):
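As a standalone illustration of the difference (plain PyTorch on a CPU tensor, not part of this diff): Tensor.type() returns a string, so the dtype membership check inside is_enable_nz can never match it, whereas .dtype is the torch.dtype object the helper expects.

import torch

w = torch.zeros(4, 4, dtype=torch.float16)

print(w.type())  # 'torch.HalfTensor' -- a string, not a torch.dtype
print(w.dtype)   # torch.float16      -- a torch.dtype object

# Only the .dtype form can match against torch.dtype objects:
print(w.type() in [torch.float16, torch.bfloat16])  # False
print(w.dtype in [torch.float16, torch.bfloat16])   # True

# Note: the diff passes layer.w13_weight.data.type without parentheses,
# so it actually hands over the bound method object, not even the string.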

layer.w13_weight.data = torch_npu.npu_format_cast(
layer.w13_weight.data, ACL_FORMAT_FRACTAL_NZ)
layer.w2_weight.data = torch_npu.npu_format_cast(
3 changes: 1 addition & 2 deletions vllm_ascend/ops/linear.py
@@ -45,8 +45,7 @@ class AscendUnquantizedLinearMethod(UnquantizedLinearMethod):

def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
super().process_weights_after_loading(layer)
-if (is_enable_nz() and layer.weight.data.dtype
-        in [torch.float16, torch.bfloat16]):
+if (is_enable_nz(layer.weight.data.dtype)):
layer.weight.data = torch_npu.npu_format_cast(
layer.weight.data, ACL_FORMAT_FRACTAL_NZ)

2 changes: 1 addition & 1 deletion vllm_ascend/torchair/quantization/torchair_w8a8_dynamic.py
@@ -835,7 +835,7 @@ def process_weights_after_loading(self, layer):
if self.transpose_weight:
layer.weight.data = layer.weight.data.transpose(0, 1).contiguous()
# cast quantized weight tensors in NZ format (29) for higher inference speed
-if is_enable_nz():
+if is_enable_nz(layer.weight.data.dtype):
layer.weight.data = torch_npu.npu_format_cast(
layer.weight.data, 29)
layer.weight_scale.data = layer.weight_scale.data.flatten()
4 changes: 2 additions & 2 deletions vllm_ascend/torchair/torchair_sfa.py
@@ -842,7 +842,7 @@ def _process_weights_for_fused_mlapo(self, act_dtype: torch.dtype):
wd_qkv = wd_qkv.t().contiguous()
wd_qkv = transdata(wd_qkv,
block_size=(16, 32)).unsqueeze(0).contiguous()
-if is_enable_nz():
+if is_enable_nz(wd_qkv.dtype):
self.wd_qkv = torch_npu.npu_format_cast(wd_qkv, 29)

kv_a_proj_deq_scl = self.kv_a_proj_with_mqa.deq_scale.clone()
@@ -876,7 +876,7 @@ def _process_weights_for_fused_mlapo(self, act_dtype: torch.dtype):
self.num_heads * (self.qk_nope_head_dim + self.qk_rope_head_dim),
-1)
wu_q = transdata(wu_q, block_size=(16, 32)).unsqueeze(0).contiguous()
-if is_enable_nz():
+if is_enable_nz(wu_q.dtype):
self.wu_q = torch_npu.npu_format_cast(wu_q, 29)

qb_deq_scl = self.q_proj.deq_scale.data.clone()
4 changes: 3 additions & 1 deletion vllm_ascend/utils.py
@@ -71,13 +71,15 @@ def is_310p():
return _IS_310P


-def is_enable_nz(vllm_config: Optional[VllmConfig] = None) -> bool:
+def is_enable_nz(dtype: Optional[torch.dtype] = torch.int8, vllm_config: Optional[VllmConfig] = None) -> bool:
global _ENABLE_NZ
if _ENABLE_NZ is None:
if not vllm_config:
raise ValueError(
"vllm_config must be provided when _ENABLE_NZ is None")
_ENABLE_NZ = envs_ascend.VLLM_ASCEND_ENABLE_NZ and vllm_config.model_config.hf_config.model_type != "qwen3_next"
+if dtype in [torch.float16, torch.bfloat16]:
+    _ENABLE_NZ = 0
Comment on lines +81 to +82

critical

Modifying the global variable _ENABLE_NZ inside is_enable_nz introduces an unintended side effect. When the function is called with a floating-point dtype (torch.float16 or torch.bfloat16), _ENABLE_NZ is permanently set to 0. All subsequent calls to is_enable_nz, regardless of dtype, will then return 0 (i.e. False), disabling NZ format casting altogether.

To avoid this side effect, do not modify the global variable; instead, return the appropriate value directly based on dtype.

Suggested change
-if dtype in [torch.float16, torch.bfloat16]:
-    _ENABLE_NZ = 0
+if dtype in [torch.float16, torch.bfloat16]:
+    return 0

return _ENABLE_NZ
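For illustration, the whole helper with the suggested side-effect-free check folded in might look as follows (a sketch only, assuming the surrounding imports and the module-level _ENABLE_NZ / envs_ascend state of vllm_ascend/utils.py; this is not the committed patch):

def is_enable_nz(dtype: Optional[torch.dtype] = torch.int8,
                 vllm_config: Optional[VllmConfig] = None) -> bool:
    global _ENABLE_NZ
    if _ENABLE_NZ is None:
        if not vllm_config:
            raise ValueError(
                "vllm_config must be provided when _ENABLE_NZ is None")
        _ENABLE_NZ = (envs_ascend.VLLM_ASCEND_ENABLE_NZ
                      and vllm_config.model_config.hf_config.model_type
                      != "qwen3_next")
    # float16/bfloat16 weights skip the NZ cast, and the cached global flag
    # is left untouched so later int8 callers still see NZ enabled.
    if dtype in [torch.float16, torch.bfloat16]:
        return False
    return _ENABLE_NZ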

