diff --git a/.github/workflows/_e2e_test.yaml b/.github/workflows/_e2e_test.yaml index 1254f3a2ff..6ebf7be0ea 100644 --- a/.github/workflows/_e2e_test.yaml +++ b/.github/workflows/_e2e_test.yaml @@ -183,7 +183,8 @@ jobs: pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_QwQ pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_DeepSeek_multistream_moe pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_Qwen3_W8A8 - pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_Qwen3_W4A8DYNAMIC + pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_Qwen3_W4A8DYNAMIC_new_version + pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_Qwen3_W4A8DYNAMIC_old_version pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_DeepSeek_W4A8DYNAMIC pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_sp_for_qwen3_moe pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_Qwen_Dense_with_flashcomm_v1 diff --git a/tests/e2e/multicard/test_offline_inference_distributed.py b/tests/e2e/multicard/test_offline_inference_distributed.py index f3348d8d51..a4e8db85a7 100644 --- a/tests/e2e/multicard/test_offline_inference_distributed.py +++ b/tests/e2e/multicard/test_offline_inference_distributed.py @@ -30,11 +30,20 @@ from tests.e2e.conftest import VllmRunner os.environ["PYTORCH_NPU_ALLOC_CONF"] = "max_split_size_mb:256" +os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" QWEN_DENSE_MODELS = [ "vllm-ascend/Qwen3-8B-W8A8", "vllm-ascend/Qwen2.5-0.5B-Instruct-W8A8" ] +QWEN_W4A8_OLD_VERSION_MODELS = [ + "vllm-ascend/Qwen3-8B-W4A8", +] + +QWEN_W4A8_NEW_VERSION_MODELS = [ + "Anionex/Qwen3-1.7B-W4A8-V1", # TODO: move it into vllm-ascend ci modelscope repo +] + DEEPSEEK_W4A8_MODELS = [ "vllm-ascend/DeepSeek-V3-W4A8-Pruing", "vllm-ascend/DeepSeek-V3.1-W4A8-puring" @@ -98,20 +107,36 @@ def test_models_distributed_Qwen3_W8A8(): vllm_model.generate_greedy(example_prompts, max_tokens) -def test_models_distributed_Qwen3_W4A8DYNAMIC(): - example_prompts = [ +@pytest.mark.parametrize("model", QWEN_W4A8_OLD_VERSION_MODELS) +def test_models_distributed_Qwen3_W4A8DYNAMIC_old_version(model): + prompts = [ "Hello, my name is", ] max_tokens = 5 + with VllmRunner( + snapshot_download(model), + max_model_len=8192, + dtype="auto", + tensor_parallel_size=2, + quantization="ascend", + ) as vllm_model: + vllm_model.generate_greedy(prompts, max_tokens) + +@pytest.mark.parametrize("model", QWEN_W4A8_NEW_VERSION_MODELS) +def test_models_distributed_Qwen3_W4A8DYNAMIC_new_version(model): + prompts = [ + "Hello, my name is", + ] + max_tokens = 5 with VllmRunner( - snapshot_download("vllm-ascend/Qwen3-8B-W4A8"), + snapshot_download(model), max_model_len=8192, dtype="auto", tensor_parallel_size=2, quantization="ascend", ) as vllm_model: - vllm_model.generate_greedy(example_prompts, max_tokens) + vllm_model.generate_greedy(prompts, max_tokens) @pytest.mark.parametrize("model", DEEPSEEK_W4A8_MODELS) diff --git a/tests/ut/quantization/test_w4a8_dynamic.py b/tests/ut/quantization/test_w4a8_dynamic.py index a14702b7a0..2116b0c168 100644 --- a/tests/ut/quantization/test_w4a8_dynamic.py +++ b/tests/ut/quantization/test_w4a8_dynamic.py @@ -9,25 +9,31 @@ class TestAscendW4A8DynamicLinearMethod(TestBase): - def 
setUp(self): - with patch( - 'vllm_ascend.quantization.w4a8_dynamic.get_current_vllm_config' - ) as mock_get_current_vllm_config: - mock_vllm_config = Mock() - mock_vllm_config.quant_config = Mock( - quant_description={"group_size": 256}) - mock_vllm_config.scheduler_config = Mock( - max_num_batched_tokens=2048, - max_model_len=2048, - enable_chunked_prefill=False) - mock_get_current_vllm_config.return_value = mock_vllm_config - self.method = AscendW4A8DynamicLinearMethod() - self.method.group_size = 8 + @patch('vllm.distributed.get_tensor_model_parallel_world_size') + @patch('vllm_ascend.quantization.w4a8_dynamic.get_current_vllm_config') + def setUp(self, mock_get_current_vllm_config, mock_get_tp_world_size): + mock_get_tp_world_size.return_value = 1 + mock_vllm_config = Mock() + mock_vllm_config.quant_config = Mock( + quant_description={"group_size": 256}) + mock_vllm_config.scheduler_config = Mock(max_num_batched_tokens=2048, + max_model_len=2048, + enable_chunked_prefill=False) + mock_get_current_vllm_config.return_value = mock_vllm_config + self.method = AscendW4A8DynamicLinearMethod() + self.method.group_size = 8 def test_get_weight(self): weight = self.method.get_weight(8, 32, torch.bfloat16) self.assertEqual(weight["weight"].dtype, torch.int8) self.assertEqual(weight["weight"].shape, (32, 8)) + # new quant version weight + self.method.new_quant_version = True + weight = self.method.get_weight(8, 32, torch.bfloat16) + self.assertEqual(weight["weight"].dtype, torch.int8) + self.assertEqual(weight["weight"].shape, (16, 8)) + self.assertEqual(weight["_packed_dim"], 0) + self.assertEqual(weight["_packed_factor"], 2) def test_get_pergroup_param(self): params = self.method.get_pergroup_param(8, 32, torch.bfloat16) @@ -39,6 +45,75 @@ def test_get_pergroup_param(self): self.assertEqual(params["weight_scale_second"].shape, (32, 1)) self.assertEqual(params["weight_offset_second"].dtype, torch.bfloat16) self.assertEqual(params["weight_offset_second"].shape, (32, 1)) + # new quant version weight + self.method.new_quant_version = True + params = self.method.get_pergroup_param(8, + 32, + torch.bfloat16, + layer_type="column") + self.assertEqual(params["scale_bias"].dtype, torch.float32) + self.assertEqual(params["scale_bias"].shape, (32, 1)) + params = self.method.get_pergroup_param(8, + 32, + torch.bfloat16, + layer_type="row") + self.assertEqual(params["scale_bias"].dtype, torch.float32) + self.assertEqual(params["scale_bias"].shape, (32, 16)) + + @patch('torch_npu.npu_convert_weight_to_int4pack') + @patch('torch.Tensor.npu') + def test_process_weights_after_loading(self, mock_npu, + mock_npu_convert_weight): + mock_npu.side_effect = lambda: torch.zeros( + (1, 32), dtype=torch.float32) + mock_npu_convert_weight.return_value = torch.zeros((32, 4), + dtype=torch.int32) + # old quant version weight + layer = torch.nn.Module() + layer.weight = torch.nn.Parameter(torch.zeros((32, 8), + dtype=torch.int8), + requires_grad=False) + layer.weight_scale = torch.nn.Parameter(torch.ones( + (32, 1), dtype=torch.float32), + requires_grad=False) + layer.weight_offset = torch.nn.Parameter(torch.empty_like( + layer.weight_scale.data), + requires_grad=False) + layer.weight_scale_second = torch.nn.Parameter(torch.ones( + (32, 1), dtype=torch.float32), + requires_grad=False) + layer.weight_offset_second = torch.nn.Parameter(torch.empty_like( + layer.weight_scale_second.data), + requires_grad=False) + self.method.process_weights_after_loading(layer) + self.assertTrue(hasattr(layer, "weight_scale_bias")) + 
self.assertEqual(layer.weight_scale_bias.data.shape, (32, )) + self.assertEqual(layer.weight_scale_bias.data.dtype, torch.float32) + # new quant version weight + self.method.new_quant_version = True + new_layer = torch.nn.Module() + new_layer.weight = torch.nn.Parameter(torch.zeros((16, 8), + dtype=torch.int8), + requires_grad=False) + new_layer.weight_scale = torch.nn.Parameter(torch.ones( + (32, 1), dtype=torch.float32), + requires_grad=False) + new_layer.weight_offset = torch.nn.Parameter(torch.empty_like( + new_layer.weight_scale.data), + requires_grad=False) + new_layer.weight_scale_second = torch.nn.Parameter(torch.ones( + (32, 1), dtype=torch.float32), + requires_grad=False) + new_layer.weight_offset_second = torch.nn.Parameter( + torch.empty_like(new_layer.weight_scale_second.data), + requires_grad=False) + new_layer.scale_bias = torch.nn.Parameter(torch.zeros( + (32, 1), dtype=torch.float32), + requires_grad=False) + self.method.process_weights_after_loading(new_layer) + self.assertEqual(new_layer.scale_bias.data.shape, (32, )) + self.assertTrue(hasattr(new_layer, "weight_scale_second")) + self.assertEqual(new_layer.weight_scale_second.data.shape, (1, 32)) class TestAscendW4A8DynamicFusedMoEMethod(TestBase): diff --git a/tests/ut/torchair/quantization/test_torchair_w4a8_dynamic.py b/tests/ut/torchair/quantization/test_torchair_w4a8_dynamic.py index 9fd3f294a1..f29cafc6c9 100644 --- a/tests/ut/torchair/quantization/test_torchair_w4a8_dynamic.py +++ b/tests/ut/torchair/quantization/test_torchair_w4a8_dynamic.py @@ -10,7 +10,16 @@ class TestAscendW4A8DynamicLinearMethod(TestBase): - def setUp(self): + @patch('vllm.distributed.get_tensor_model_parallel_world_size') + @patch( + 'vllm_ascend.torchair.quantization.torchair_w4a8_dynamic.get_current_vllm_config' + ) + def setUp(self, mock_get_current_vllm_config, mock_get_tp_world_size): + mock_get_tp_world_size.return_value = 1 + mock_vllm_config = Mock() + mock_vllm_config.quant_config = Mock( + quant_description={"group_size": 256}) + mock_get_current_vllm_config.return_value = mock_vllm_config self.method = TorchairAscendW4A8DynamicLinearMethod() self.method.group_size = 8 @@ -18,6 +27,13 @@ def test_get_weight(self): weight = self.method.get_weight(8, 32, torch.bfloat16) self.assertEqual(weight["weight"].dtype, torch.int8) self.assertEqual(weight["weight"].shape, (32, 8)) + # new quant version weight + self.method.new_quant_version = True + weight = self.method.get_weight(8, 32, torch.bfloat16) + self.assertEqual(weight["weight"].dtype, torch.int8) + self.assertEqual(weight["weight"].shape, (16, 8)) + self.assertEqual(weight["_packed_dim"], 0) + self.assertEqual(weight["_packed_factor"], 2) def test_get_pergroup_param(self): params = self.method.get_pergroup_param(8, 32, torch.bfloat16) @@ -29,6 +45,75 @@ def test_get_pergroup_param(self): self.assertEqual(params["weight_scale_second"].shape, (32, 1)) self.assertEqual(params["weight_offset_second"].dtype, torch.bfloat16) self.assertEqual(params["weight_offset_second"].shape, (32, 1)) + # new quant version weight + self.method.new_quant_version = True + params = self.method.get_pergroup_param(8, + 32, + torch.bfloat16, + layer_type="column") + self.assertEqual(params["scale_bias"].dtype, torch.float32) + self.assertEqual(params["scale_bias"].shape, (32, 1)) + params = self.method.get_pergroup_param(8, + 32, + torch.bfloat16, + layer_type="row") + self.assertEqual(params["scale_bias"].dtype, torch.float32) + self.assertEqual(params["scale_bias"].shape, (32, 16)) + + 
@patch('torch_npu.npu_convert_weight_to_int4pack') + @patch('torch.Tensor.npu') + def test_process_weights_after_loading(self, mock_npu, + mock_npu_convert_weight): + mock_npu.side_effect = lambda: torch.zeros( + (1, 32), dtype=torch.float32) + mock_npu_convert_weight.return_value = torch.zeros((32, 4), + dtype=torch.int32) + # old quant version weight + layer = torch.nn.Module() + layer.weight = torch.nn.Parameter(torch.zeros((32, 8), + dtype=torch.int8), + requires_grad=False) + layer.weight_scale = torch.nn.Parameter(torch.ones( + (32, 1), dtype=torch.float32), + requires_grad=False) + layer.weight_offset = torch.nn.Parameter(torch.empty_like( + layer.weight_scale.data), + requires_grad=False) + layer.weight_scale_second = torch.nn.Parameter(torch.ones( + (32, 1), dtype=torch.float32), + requires_grad=False) + layer.weight_offset_second = torch.nn.Parameter(torch.empty_like( + layer.weight_scale_second.data), + requires_grad=False) + self.method.process_weights_after_loading(layer) + self.assertTrue(hasattr(layer, "weight_scale_bias")) + self.assertEqual(layer.weight_scale_bias.data.shape, (32, )) + self.assertEqual(layer.weight_scale_bias.data.dtype, torch.float32) + # new quant version weight + self.method.new_quant_version = True + new_layer = torch.nn.Module() + new_layer.weight = torch.nn.Parameter(torch.zeros((16, 8), + dtype=torch.int8), + requires_grad=False) + new_layer.weight_scale = torch.nn.Parameter(torch.ones( + (32, 1), dtype=torch.float32), + requires_grad=False) + new_layer.weight_offset = torch.nn.Parameter(torch.empty_like( + new_layer.weight_scale.data), + requires_grad=False) + new_layer.weight_scale_second = torch.nn.Parameter(torch.ones( + (32, 1), dtype=torch.float32), + requires_grad=False) + new_layer.weight_offset_second = torch.nn.Parameter( + torch.empty_like(new_layer.weight_scale_second.data), + requires_grad=False) + new_layer.scale_bias = torch.nn.Parameter(torch.zeros( + (32, 1), dtype=torch.float32), + requires_grad=False) + self.method.process_weights_after_loading(new_layer) + self.assertEqual(new_layer.scale_bias.data.shape, (32, )) + self.assertTrue(hasattr(new_layer, "weight_scale_second")) + self.assertEqual(new_layer.weight_scale_second.data.shape, (1, 32)) class TestAscendW4A8DynamicFusedMoEMethod(TestBase): @@ -42,7 +127,9 @@ class TestAscendW4A8DynamicFusedMoEMethod(TestBase): ) @patch( 'vllm_ascend.torchair.quantization.torchair_w4a8_dynamic.get_ep_group') - @patch("vllm_ascend.ascend_config.get_ascend_config") + @patch( + 'vllm_ascend.torchair.quantization.torchair_w4a8_dynamic.get_ascend_config' + ) @patch( 'vllm_ascend.torchair.quantization.torchair_w4a8_dynamic.get_mc2_group' ) diff --git a/vllm_ascend/quantization/quant_config.py b/vllm_ascend/quantization/quant_config.py index 130251cdc8..3858c1aa56 100644 --- a/vllm_ascend/quantization/quant_config.py +++ b/vllm_ascend/quantization/quant_config.py @@ -254,9 +254,22 @@ def create_weights( weight_dict = self.quant_method.get_weight(input_size_per_partition, output_size_per_partition, params_dtype) + + # Extract packing information (if present) + packed_dim = weight_dict.pop("_packed_dim", None) + packed_factor = weight_dict.pop("_packed_factor", None) + for weight_name, weight_param in weight_dict.items(): param = torch.nn.Parameter(weight_param, requires_grad=False) set_weight_attrs(param, {"input_dim": 1, "output_dim": 0}) + + # Set packing attributes if the weight is packed + if packed_dim is not None and packed_factor is not None: + set_weight_attrs(param, { + "packed_dim": 
packed_dim, + "packed_factor": packed_factor + }) + layer.register_parameter(weight_name, param) set_weight_attrs(param, extra_weight_attrs) @@ -275,8 +288,17 @@ def create_weights( layer.register_parameter(perchannel_name, param) set_weight_attrs(param, extra_weight_attrs) + # NOTE: In w4a8 quantization implementation, + # for down_proj and o_proj scale_bias shape is [output_size, 16], + # others are [output_size, 1] + layer_type = "row" if isinstance(layer, + RowParallelLinear) else "others" + pergroup_dict = self.quant_method.get_pergroup_param( - input_size_per_partition, output_size_per_partition, params_dtype) + input_size_per_partition, + output_size_per_partition, + params_dtype, + layer_type=layer_type) for pergroup_name, pergroup_param in pergroup_dict.items(): param = torch.nn.Parameter(pergroup_param, requires_grad=False) set_weight_attrs(param, {"output_dim": 0}) diff --git a/vllm_ascend/quantization/w4a8_dynamic.py b/vllm_ascend/quantization/w4a8_dynamic.py index b8bcc7831f..f8f58d2e5a 100644 --- a/vllm_ascend/quantization/w4a8_dynamic.py +++ b/vllm_ascend/quantization/w4a8_dynamic.py @@ -36,18 +36,42 @@ class AscendW4A8DynamicLinearMethod: def __init__(self): self.transpose_weight = True - try: - self.group_size = get_current_vllm_config( - ).quant_config.quant_description.get("group_size", 256) - except AttributeError: - self.group_size = 256 - @staticmethod - def get_weight(input_size: int, output_size: int, + vllm_config = get_current_vllm_config() + self.group_size = vllm_config.quant_config.quant_description.get( + "group_size", 256) + quant_version = vllm_config.quant_config.quant_description.get( + "version", "0") + self.new_quant_version = quant_version == "1.0.0" + + from vllm.distributed import get_tensor_model_parallel_world_size + self.tp_size = get_tensor_model_parallel_world_size() + + def get_weight(self, input_size: int, output_size: int, params_dtype: torch.dtype) -> Dict[str, Any]: - params_dict = { - "weight": torch.empty(output_size, input_size, dtype=torch.int8) - } + """Create weight parameters. + + For new quantization version (double int4 pack into int8), the output dimension + is compressed by factor 2 (e.g., [2048, 3072] -> [1024, 3072]). The returned + dict includes "_packed_dim" and "_packed_factor" for vLLM's weight loader. + """ + params_dict = {} + + if self.new_quant_version: + # double int4 pack into int8: output dimension is compressed + pack_factor = 2 + actual_output_size = output_size // pack_factor + params_dict["weight"] = torch.empty(actual_output_size, + input_size, + dtype=torch.int8) + # Add packing information for vLLM's weight_loader + params_dict["_packed_dim"] = 0 + params_dict["_packed_factor"] = pack_factor + else: + params_dict["weight"] = torch.empty(output_size, + input_size, + dtype=torch.int8) + return params_dict @staticmethod @@ -59,8 +83,14 @@ def get_perchannel_param(output_size: int, params_dtype: torch.dtype) -> Dict[str, Any]: return {} - def get_pergroup_param(self, input_size: int, output_size: int, - params_dtype: torch.dtype) -> Dict[str, Any]: + def get_pergroup_param(self, + input_size: int, + output_size: int, + params_dtype: torch.dtype, + layer_type: Optional[str] = None) -> Dict[str, Any]: + """ + Create per-group quantization parameters. 
+ """ params_dict = {} params_dict["weight_scale"] = torch.empty(output_size, 1, @@ -76,17 +106,52 @@ def get_pergroup_param(self, input_size: int, output_size: int, input_size // self.group_size, dtype=params_dtype) + + # NOTE: In w4a8 quantization implementation, + # for down_proj and o_proj(layer_type == "row") scale_bias shape is [output_size, 16], + # others are [output_size, 1] + if self.new_quant_version: + scale_bias_dim = 16 if layer_type == "row" else 1 + + params_dict["scale_bias"] = torch.empty(output_size, + scale_bias_dim, + dtype=torch.float32) return params_dict @staticmethod - def process_scale_second(weight: torch.Tensor, scale: torch.Tensor, - per_group_scale: torch.Tensor): + def process_scale_second(weight: torch.Tensor, + scale: torch.Tensor, + per_group_scale: torch.Tensor, + is_new_quant: bool = False): + """ + Process the scale for second-level quantization. + + Args: + weight: weight tensor [k, n] (in new version, n is already compressed to n/2) + scale: first-level quantization scale [output_size] + per_group_scale: second-level per-group quantization scale [group_num, n_scale] + is_new_quant: whether it's the new quantization version (weight already compressed) + + Returns: + (antiquant_scale, bias): dequantization scale and bias (bias=None for new version) + """ k, n = weight.shape - group_num, n = per_group_scale.shape - weight_high = weight.to(torch.float32).reshape( - group_num, -1, n) * per_group_scale.reshape(group_num, 1, n) - weight_high = weight_high.reshape(k, n) - bias = 8 * (weight_high.to(torch.float32) * scale).sum(dim=0) + group_num, n_scale = per_group_scale.shape + + if is_new_quant: + # Restore logical dimension for compressed weight + n = n * 2 + + bias = None + if not is_new_quant: + weight_high = weight.to(torch.float32).reshape( + group_num, -1, n) * per_group_scale.reshape(group_num, 1, n) + weight_high = weight_high.reshape(k, n) + bias = 8 * (weight_high.to(torch.float32) * scale).sum(dim=0) + # NOTE: scale_bias is not used currently + # because in msmodelslim w4a8 uses symmetric quantization + + # TODO: support potential future asymmetric quantization antiquant_scale = (scale * per_group_scale).reshape(group_num, n) return antiquant_scale.npu(), bias @@ -114,11 +179,34 @@ def process_weights_after_loading(self, layer: torch.nn.Module): layer.weight.data, layer.weight_scale.data, layer.weight_scale_second.data.transpose(0, 1).contiguous(), + is_new_quant=self.new_quant_version, ) - param = torch.nn.Parameter(scale_bias, requires_grad=False) - layer.register_parameter("weight_scale_bias", param) - layer.weight.data = torch_npu.npu_convert_weight_to_int4pack( - layer.weight.data.to(torch.int32)) + + if self.new_quant_version: + # Process the loaded data based on layer type + if hasattr(layer, "scale_bias"): + if layer.scale_bias.data.shape[1] == 1: + layer.scale_bias.data = layer.scale_bias.data.flatten() + else: + layer.scale_bias.data = layer.scale_bias.data.contiguous() + else: + if scale_bias is not None: + param = torch.nn.Parameter(scale_bias, requires_grad=False) + layer.register_parameter("weight_scale_bias", param) + + # Convert to NPU-specific int4pack format + if self.new_quant_version: + # weights on disk are already in packed int4 format + # pack 4 int8(int4*2) to int32 + assert layer.weight.data.shape[-1] % 4 == 0, \ + f"the last dim of weight needs to be divided by 4, got shape {layer.weight.data.shape}" + layer.weight.data = layer.weight.data.view( + torch.int32).contiguous() + else: + # weights are not compressed + # need 
to be packed via npu_convert_weight_to_int4pack + layer.weight.data = torch_npu.npu_convert_weight_to_int4pack( + layer.weight.data.to(torch.int32)) class AscendW4A8DynamicFusedMoEMethod: diff --git a/vllm_ascend/quantization/w8a8.py b/vllm_ascend/quantization/w8a8.py index 010d45da41..bdc1730c4a 100644 --- a/vllm_ascend/quantization/w8a8.py +++ b/vllm_ascend/quantization/w8a8.py @@ -85,8 +85,11 @@ def get_perchannel_param( dtype=params_dtype) return params_dict - def get_pergroup_param(self, input_size: int, output_size: int, - params_dtype: torch.dtype) -> Dict[str, Any]: + def get_pergroup_param(self, + input_size: int, + output_size: int, + params_dtype: torch.dtype, + layer_type: Optional[str] = None) -> Dict[str, Any]: return {} @staticmethod diff --git a/vllm_ascend/quantization/w8a8_dynamic.py b/vllm_ascend/quantization/w8a8_dynamic.py index ab4987f015..9c0c1b5207 100644 --- a/vllm_ascend/quantization/w8a8_dynamic.py +++ b/vllm_ascend/quantization/w8a8_dynamic.py @@ -62,8 +62,11 @@ def get_perchannel_param( dtype=params_dtype) return params_dict - def get_pergroup_param(self, input_size: int, output_size: int, - params_dtype: torch.dtype) -> Dict[str, Any]: + def get_pergroup_param(self, + input_size: int, + output_size: int, + params_dtype: torch.dtype, + layer_type: Optional[str] = None) -> Dict[str, Any]: return {} @staticmethod diff --git a/vllm_ascend/torchair/quantization/torchair_w4a8_dynamic.py b/vllm_ascend/torchair/quantization/torchair_w4a8_dynamic.py index 02deee8994..128cdbef08 100644 --- a/vllm_ascend/torchair/quantization/torchair_w4a8_dynamic.py +++ b/vllm_ascend/torchair/quantization/torchair_w4a8_dynamic.py @@ -39,18 +39,34 @@ class TorchairAscendW4A8DynamicLinearMethod: def __init__(self): self.transpose_weight = True - try: - self.group_size = get_current_vllm_config( - ).quant_config.quant_description.get("group_size", 256) - except AttributeError: - self.group_size = 256 - @staticmethod - def get_weight(input_size: int, output_size: int, + vllm_config = get_current_vllm_config() + self.group_size = vllm_config.quant_config.quant_description.get( + "group_size", 256) + quant_version = vllm_config.quant_config.quant_description.get( + "version", "0") + self.new_quant_version = quant_version == "1.0.0" + + from vllm.distributed import get_tensor_model_parallel_world_size + self.tp_size = get_tensor_model_parallel_world_size() + + def get_weight(self, input_size: int, output_size: int, params_dtype: torch.dtype) -> Dict[str, Any]: - params_dict = { - "weight": torch.empty(output_size, input_size, dtype=torch.int8) - } + params_dict = {} + + if self.new_quant_version: + pack_factor = 2 + actual_output_size = output_size // pack_factor + params_dict["weight"] = torch.empty(actual_output_size, + input_size, + dtype=torch.int8) + params_dict["_packed_dim"] = 0 + params_dict["_packed_factor"] = pack_factor + else: + params_dict["weight"] = torch.empty(output_size, + input_size, + dtype=torch.int8) + return params_dict @staticmethod @@ -62,8 +78,11 @@ def get_perchannel_param(output_size: int, params_dtype: torch.dtype) -> Dict[str, Any]: return {} - def get_pergroup_param(self, input_size: int, output_size: int, - params_dtype: torch.dtype) -> Dict[str, Any]: + def get_pergroup_param(self, + input_size: int, + output_size: int, + params_dtype: torch.dtype, + layer_type: Optional[str] = None) -> Dict[str, Any]: params_dict = {} params_dict["weight_scale"] = torch.empty(output_size, 1, @@ -79,17 +98,32 @@ def get_pergroup_param(self, input_size: int, output_size: int, 
input_size // self.group_size, dtype=params_dtype) + + if self.new_quant_version: + scale_bias_dim = 16 if layer_type == "row" else 1 + params_dict["scale_bias"] = torch.empty(output_size, + scale_bias_dim, + dtype=torch.float32) return params_dict @staticmethod - def process_scale_second(weight: torch.Tensor, scale: torch.Tensor, - per_group_scale: torch.Tensor): + def process_scale_second(weight: torch.Tensor, + scale: torch.Tensor, + per_group_scale: torch.Tensor, + is_new_quant: bool = False): k, n = weight.shape - group_num, n = per_group_scale.shape - weight_high = weight.to(torch.float32).reshape( - group_num, -1, n) * per_group_scale.reshape(group_num, 1, n) - weight_high = weight_high.reshape(k, n) - bias = 8 * (weight_high.to(torch.float32) * scale).sum(dim=0) + group_num, n_scale = per_group_scale.shape + + if is_new_quant: + n = n * 2 + + bias = None + if not is_new_quant: + weight_high = weight.to(torch.float32).reshape( + group_num, -1, n) * per_group_scale.reshape(group_num, 1, n) + weight_high = weight_high.reshape(k, n) + bias = 8 * (weight_high.to(torch.float32) * scale).sum(dim=0) + antiquant_scale = (scale * per_group_scale).reshape(group_num, n) return antiquant_scale.npu(), bias @@ -117,11 +151,28 @@ def process_weights_after_loading(self, layer: torch.nn.Module): layer.weight.data, layer.weight_scale.data, layer.weight_scale_second.data.transpose(0, 1).contiguous(), + is_new_quant=self.new_quant_version, ) - param = torch.nn.Parameter(scale_bias, requires_grad=False) - layer.register_parameter("weight_scale_bias", param) - layer.weight.data = torch_npu.npu_convert_weight_to_int4pack( - layer.weight.data.to(torch.int32)) + + if self.new_quant_version: + if hasattr(layer, "scale_bias"): + if layer.scale_bias.data.shape[1] == 1: + layer.scale_bias.data = layer.scale_bias.data.flatten() + else: + layer.scale_bias.data = layer.scale_bias.data.contiguous() + else: + if scale_bias is not None: + param = torch.nn.Parameter(scale_bias, requires_grad=False) + layer.register_parameter("weight_scale_bias", param) + + if self.new_quant_version: + assert layer.weight.data.shape[-1] % 4 == 0, \ + f"the last dim of weight needs to be divided by 4, got shape {layer.weight.data.shape}" + layer.weight.data = layer.weight.data.view( + torch.int32).contiguous() + else: + layer.weight.data = torch_npu.npu_convert_weight_to_int4pack( + layer.weight.data.to(torch.int32)) class TorchairAscendW4A8DynamicFusedMoEMethod: diff --git a/vllm_ascend/torchair/quantization/torchair_w8a8_dynamic.py b/vllm_ascend/torchair/quantization/torchair_w8a8_dynamic.py index 23c4699d58..8485ba158b 100644 --- a/vllm_ascend/torchair/quantization/torchair_w8a8_dynamic.py +++ b/vllm_ascend/torchair/quantization/torchair_w8a8_dynamic.py @@ -777,8 +777,11 @@ def get_perchannel_param( dtype=params_dtype) return params_dict - def get_pergroup_param(self, input_size: int, output_size: int, - params_dtype: torch.dtype) -> Dict[str, Any]: + def get_pergroup_param(self, + input_size: int, + output_size: int, + params_dtype: torch.dtype, + layer_type: Optional[str] = None) -> Dict[str, Any]: return {} @staticmethod
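
Editor's note (illustrative sketch, not part of the patch): the new quant version ("1.0.0") ships checkpoints with two int4 values already packed into each int8 byte, which is why get_weight() allocates output_size // 2 rows and why process_weights_after_loading() only reinterprets the int8 buffer as int32 instead of calling torch_npu.npu_convert_weight_to_int4pack(). The minimal CPU-only PyTorch sketch below demonstrates that convention under stated assumptions: the helper name pack_int4_pairs and the row-pairing nibble layout are hypothetical for illustration and may differ from the exact msmodelslim on-disk layout.

import torch


def pack_int4_pairs(w_int4: torch.Tensor) -> torch.Tensor:
    # w_int4: [output_size, input_size], values in [-8, 7] stored as int8.
    # Assumed layout for illustration: pack adjacent output rows into one
    # byte (low nibble from even rows, high nibble from odd rows), halving
    # the output dimension as the new quant version expects.
    assert w_int4.shape[0] % 2 == 0
    lo = (w_int4[0::2] & 0x0F).to(torch.uint8)
    hi = (w_int4[1::2] & 0x0F).to(torch.uint8) << 4
    return (lo | hi).view(torch.int8).contiguous()


# Fake int4 weight, same example shape as the get_weight() docstring:
# logical [2048, 3072] -> packed [1024, 3072].
w = torch.randint(-8, 8, (2048, 3072), dtype=torch.int8)
packed = pack_int4_pairs(w)
assert packed.shape == (1024, 3072)

# New-version path in process_weights_after_loading(): the packed int8
# buffer is reinterpreted as int32 (4 bytes per element), which is why the
# patch asserts that the last dim is divisible by 4.
assert packed.shape[-1] % 4 == 0
packed_int32 = packed.view(torch.int32).contiguous()
assert packed_int32.shape == (1024, 3072 // 4)

The design point this illustrates: for already-packed checkpoints, load-time conversion reduces to a zero-copy dtype view, avoiding the npu_convert_weight_to_int4pack() repacking that the old quant path still needs.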