diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py
index c0e175c246a..5dcac6c37ee 100644
--- a/vllm_ascend/attention/mla_v1.py
+++ b/vllm_ascend/attention/mla_v1.py
@@ -652,7 +652,7 @@ def get_and_maybe_dequant_weights(layer: LinearBase):
         # Function `get_and_maybe_dequant_weights` will cast the weights to
         # FRACTAL_AND. So we need to cast to FRACTAL_NZ again.
-        if is_enable_nz():
+        if is_enable_nz(self.kv_b_proj.weight.data.dtype):
             self.kv_b_proj.weight.data = torch_npu.npu_format_cast(
                 self.kv_b_proj.weight.data, ACL_FORMAT_FRACTAL_NZ)
diff --git a/vllm_ascend/models/qwen2_5_vl.py b/vllm_ascend/models/qwen2_5_vl.py
index 35ac58d0a9d..ec39b9648ca 100644
--- a/vllm_ascend/models/qwen2_5_vl.py
+++ b/vllm_ascend/models/qwen2_5_vl.py
@@ -284,7 +284,7 @@ def pad_qkv_weight(self, data):
                                        dim=2)
         qkv_weight_final = qkv_weight_padded.reshape(-1, self.hidden_size)
-        if is_enable_nz():
+        if is_enable_nz(qkv_weight_final.dtype):
             qkv_weight_final_copy = torch.empty_like(qkv_weight_final).copy_(
                 qkv_weight_final)
             qkv_weight_final_copy = torch_npu.npu_format_cast(
@@ -300,7 +300,7 @@ def pad_proj_weight(self, data):
             (0, self.half_pad_hidden_size_per_attention_head, 0, 0)).reshape(
                 self.hidden_size, -1)
-        if is_enable_nz():
+        if is_enable_nz(out_weight.dtype):
             out_weight_copy = torch.empty_like(out_weight).copy_(out_weight)
             out_weight_copy = torch_npu.npu_format_cast(
                 out_weight_copy, ACL_FORMAT_FRACTAL_ND)
diff --git a/vllm_ascend/models/qwen2_vl.py b/vllm_ascend/models/qwen2_vl.py
index ccd461613b5..bd4828351d1 100644
--- a/vllm_ascend/models/qwen2_vl.py
+++ b/vllm_ascend/models/qwen2_vl.py
@@ -268,7 +268,7 @@ def pad_qkv_weight(self, data):
                                        dim=2)
         qkv_weight_final = qkv_weight_padded.reshape(-1, self.hidden_size)
-        if is_enable_nz():
+        if is_enable_nz(qkv_weight_final.dtype):
             qkv_weight_final_copy = torch.empty_like(qkv_weight_final).copy_(
                 qkv_weight_final)
             qkv_weight_final_copy = torch_npu.npu_format_cast(
@@ -284,7 +284,7 @@ def pad_proj_weight(self, data):
             (0, self.half_pad_hidden_size_per_attention_head, 0, 0)).reshape(
                 self.hidden_size, -1)
-        if is_enable_nz():
+        if is_enable_nz(out_weight.dtype):
             out_weight_copy = torch.empty_like(out_weight).copy_(out_weight)
             out_weight_copy = torch_npu.npu_format_cast(
                 out_weight_copy, ACL_FORMAT_FRACTAL_ND)
diff --git a/vllm_ascend/ops/common_fused_moe.py b/vllm_ascend/ops/common_fused_moe.py
index cae6b89b2b2..cc2a377c1a2 100644
--- a/vllm_ascend/ops/common_fused_moe.py
+++ b/vllm_ascend/ops/common_fused_moe.py
@@ -76,7 +76,7 @@ def process_weights_after_loading(self, layer):
         w2_data = self._maybe_pad_weight(layer.w2_weight.data)
         layer.w2_weight = torch.nn.Parameter(w2_data, requires_grad=False)
-        if not is_310p() and is_enable_nz():
+        if not is_310p() and is_enable_nz(layer.w13_weight.data.dtype):
             layer.w13_weight.data = torch_npu.npu_format_cast(
                 layer.w13_weight.data, ACL_FORMAT_FRACTAL_NZ)
             layer.w2_weight.data = torch_npu.npu_format_cast(
diff --git a/vllm_ascend/ops/linear.py b/vllm_ascend/ops/linear.py
index eab312d5cf8..69889b700ee 100644
--- a/vllm_ascend/ops/linear.py
+++ b/vllm_ascend/ops/linear.py
@@ -45,8 +45,7 @@ class AscendUnquantizedLinearMethod(UnquantizedLinearMethod):
     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         super().process_weights_after_loading(layer)
-        if (is_enable_nz() and layer.weight.data.dtype
-                in [torch.float16, torch.bfloat16]):
+        if (is_enable_nz(layer.weight.data.dtype)):
             layer.weight.data = torch_npu.npu_format_cast(
                 layer.weight.data, ACL_FORMAT_FRACTAL_NZ)
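Note on the call-site changes above: every is_enable_nz() call now receives the dtype of the weight it is about to cast, so the float16/bfloat16 check that previously sat inline in AscendUnquantizedLinearMethod is centralised in the helper itself. Below is a minimal sketch of the resulting gating pattern; it is an illustration rather than a hunk of this patch, maybe_cast_to_nz is a hypothetical name, and it assumes ACL_FORMAT_FRACTAL_NZ and is_enable_nz are importable from vllm_ascend.utils as the hunks above suggest.

import torch
import torch_npu

from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, is_enable_nz


def maybe_cast_to_nz(weight: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper for illustration: NZ casting is gated on the weight
    # dtype, so float16/bfloat16 weights keep their current format and only
    # other dtypes (e.g. quantized int8 weights) are converted to FRACTAL_NZ.
    if is_enable_nz(weight.dtype):
        return torch_npu.npu_format_cast(weight, ACL_FORMAT_FRACTAL_NZ)
    return weight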
diff --git a/vllm_ascend/torchair/quantization/torchair_w8a8_dynamic.py b/vllm_ascend/torchair/quantization/torchair_w8a8_dynamic.py
index bc0a8d35783..aeb281b7994 100644
--- a/vllm_ascend/torchair/quantization/torchair_w8a8_dynamic.py
+++ b/vllm_ascend/torchair/quantization/torchair_w8a8_dynamic.py
@@ -835,7 +835,7 @@ def process_weights_after_loading(self, layer):
         if self.transpose_weight:
             layer.weight.data = layer.weight.data.transpose(0, 1).contiguous()
         # cast quantized weight tensors in NZ format (29) for higher inference speed
-        if is_enable_nz():
+        if is_enable_nz(layer.weight.data.dtype):
             layer.weight.data = torch_npu.npu_format_cast(
                 layer.weight.data, 29)
         layer.weight_scale.data = layer.weight_scale.data.flatten()
diff --git a/vllm_ascend/torchair/torchair_sfa.py b/vllm_ascend/torchair/torchair_sfa.py
index 36c32247d6b..751cd2f935c 100644
--- a/vllm_ascend/torchair/torchair_sfa.py
+++ b/vllm_ascend/torchair/torchair_sfa.py
@@ -842,7 +842,7 @@ def _process_weights_for_fused_mlapo(self, act_dtype: torch.dtype):
         wd_qkv = wd_qkv.t().contiguous()
         wd_qkv = transdata(wd_qkv,
                            block_size=(16, 32)).unsqueeze(0).contiguous()
-        if is_enable_nz():
+        if is_enable_nz(wd_qkv.dtype):
             self.wd_qkv = torch_npu.npu_format_cast(wd_qkv, 29)
         kv_a_proj_deq_scl = self.kv_a_proj_with_mqa.deq_scale.clone()
@@ -876,7 +876,7 @@ def _process_weights_for_fused_mlapo(self, act_dtype: torch.dtype):
             self.num_heads * (self.qk_nope_head_dim + self.qk_rope_head_dim),
             -1)
         wu_q = transdata(wu_q, block_size=(16, 32)).unsqueeze(0).contiguous()
-        if is_enable_nz():
+        if is_enable_nz(wu_q.dtype):
             self.wu_q = torch_npu.npu_format_cast(wu_q, 29)
         qb_deq_scl = self.q_proj.deq_scale.data.clone()
diff --git a/vllm_ascend/utils.py b/vllm_ascend/utils.py
index 52a88ecabdf..a8c8e324602 100644
--- a/vllm_ascend/utils.py
+++ b/vllm_ascend/utils.py
@@ -71,13 +71,15 @@ def is_310p():
     return _IS_310P


-def is_enable_nz(vllm_config: Optional[VllmConfig] = None) -> bool:
+def is_enable_nz(dtype: Optional[torch.dtype] = torch.int8, vllm_config: Optional[VllmConfig] = None) -> bool:
     global _ENABLE_NZ
     if _ENABLE_NZ is None:
         if not vllm_config:
             raise ValueError(
                 "vllm_config must be provided when _ENABLE_NZ is None")
         _ENABLE_NZ = envs_ascend.VLLM_ASCEND_ENABLE_NZ and vllm_config.model_config.hf_config.model_type != "qwen3_next"
+    if dtype in [torch.float16, torch.bfloat16]:
+        return False
     return _ENABLE_NZ
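For reference, a rough usage sketch of the helper under the new signature. It assumes _ENABLE_NZ has already been initialised (i.e. is_enable_nz was first called with a vllm_config during platform setup), that VLLM_ASCEND_ENABLE_NZ is enabled, and that the model is not qwen3_next; the weight tensors are placeholders for illustration only.

import torch

from vllm_ascend.utils import is_enable_nz

# Placeholder weights, just to obtain realistic dtypes.
bf16_weight = torch.empty(128, 128, dtype=torch.bfloat16)
int8_weight = torch.empty(128, 128, dtype=torch.int8)

is_enable_nz(bf16_weight.dtype)  # False: fp16/bf16 weights are left in their current format
is_enable_nz(int8_weight.dtype)  # True: quantized weights still follow VLLM_ASCEND_ENABLE_NZ
is_enable_nz()                   # dtype defaults to torch.int8, equivalent to the line above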