Commit 859e861
[main][quantization] Support deepseek w4a8 per-channel quantization (#3011)
### What this PR does / why we need it?

1. Support deepseek w4a8 per-channel quantization.
2. The eager mode supports converting weights to the NZ format.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

#### How to get weights using Modelslim

##### Installation steps

git clone https://gitcode.com/Ascend/msit.git
cd msit/msmodelslim
bash install.sh

##### Generate w4a8 per-channel weights

cd /example/DeepSeek

Command reference: msmodelslim/example/DeepSeek/README.md

- vLLM version: v0.10.2
- vLLM main: vllm-project/vllm@f225ea7

Signed-off-by: Wang Kunpeng <[email protected]>
1 parent e9359bd commit 859e861
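The Modelslim steps from the commit message, collected into one shell sequence for convenience. The commands are quoted verbatim; the actual w4a8 per-channel conversion command is documented in msmodelslim/example/DeepSeek/README.md and is not reproduced here.

```bash
# Install Modelslim (commands as given in the commit message).
git clone https://gitcode.com/Ascend/msit.git
cd msit/msmodelslim
bash install.sh

# Generate w4a8 per-channel weights for DeepSeek.
# "/example/DeepSeek" is quoted as-is from the commit message; in practice it is
# the example directory inside the msmodelslim checkout.
cd /example/DeepSeek
# Run the conversion command from msmodelslim/example/DeepSeek/README.md here.
```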

File tree: 6 files changed (+293, -190 lines)

docs/source/user_guide/feature_guide/quantization.md (3 additions, 2 deletions)

````diff
@@ -108,18 +108,19 @@ Please convert DeepSeek series models using `br_release_MindStudio_8.1.RC2_TR5_2
 
 ### 3. When converting deepseek series models with modelslim, what should you pay attention?
 
-When using the weight generated by modelslim with the `--dynamic` parameter, if torchair graph mode is enabled, please modify the configuration file in the CANN package to prevent incorrect inference results.
+When the mla portion of the weights used `W8A8_DYNAMIC` quantization, if torchair graph mode is enabled, please modify the configuration file in the CANN package to prevent incorrect inference results.
 
 The operation steps are as follows:
 
 1. Search in the CANN package directory used, for example:
 
    find /usr/local/Ascend/ -name fusion_config.json
 
-2. Add `"AddRmsNormDynamicQuantFusionPass":"off",` to the fusion_config.json you find, the location is as follows:
+2. Add `"AddRmsNormDynamicQuantFusionPass":"off",` and `"MultiAddRmsNormDynamicQuantFusionPass":"off",` to the fusion_config.json you find, the location is as follows:
 
 ```bash
 {
   "Switch":{
     "GraphFusion":{
       "AddRmsNormDynamicQuantFusionPass":"off",
+      "MultiAddRmsNormDynamicQuantFusionPass":"off",
 ```
````
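Putting the two steps together, a shell-level sketch of the same edit. The `find` command is taken from the doc; the nesting in the comment mirrors the snippet above, and any sibling switches present in a real fusion_config.json are elided here.

```bash
# Locate the fusion configuration shipped with the installed CANN package.
find /usr/local/Ascend/ -name fusion_config.json

# After editing, the "GraphFusion" section should contain both switches, e.g.:
#   "Switch": {
#     "GraphFusion": {
#       "AddRmsNormDynamicQuantFusionPass": "off",
#       "MultiAddRmsNormDynamicQuantFusionPass": "off",
#       ...
#     }
#   }
```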

tests/e2e/multicard/test_offline_inference_distributed.py (8 additions, 2 deletions)

```diff
@@ -35,6 +35,11 @@
     "vllm-ascend/Qwen3-8B-W8A8", "vllm-ascend/Qwen2.5-0.5B-Instruct-W8A8"
 ]
 
+DEEPSEEK_W4A8_MODELS = [
+    "vllm-ascend/DeepSeek-V3-W4A8-Pruing",
+    "vllm-ascend/DeepSeek-V3.1-W4A8-puring"
+]
+
 
 def test_models_distributed_QwQ():
     example_prompts = [
@@ -109,14 +114,15 @@ def test_models_distributed_Qwen3_W4A8DYNAMIC():
         vllm_model.generate_greedy(example_prompts, max_tokens)
 
 
+@pytest.mark.parametrize("model", DEEPSEEK_W4A8_MODELS)
 @patch.dict(os.environ, {"VLLM_ASCEND_MLA_PA": "1"})
-def test_models_distributed_DeepSeek_W4A8DYNAMIC():
+def test_models_distributed_DeepSeek_W4A8DYNAMIC(model):
     prompts = [
         "Hello, my name is",
     ]
     max_tokens = 5
     with VllmRunner(
-            snapshot_download("vllm-ascend/DeepSeek-V3-W4A8-Pruing"),
+            snapshot_download(model),
             dtype="auto",
             tensor_parallel_size=2,
             quantization="ascend",
```
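Beyond the test suite, a hedged sketch of how one of these checkpoints could be served with vllm-ascend. The `--quantization ascend` value, the tensor-parallel size, and the environment variable mirror the test arguments above; everything else is an assumption, and real deployments may need different flags.

```bash
# Assumed invocation mirroring the e2e test above; the model id is the one used
# in the test (download it first, e.g. via ModelScope's snapshot_download, or
# point this at a local path), and flags should be adapted to your environment.
VLLM_ASCEND_MLA_PA=1 vllm serve vllm-ascend/DeepSeek-V3-W4A8-Pruing \
    --quantization ascend \
    --tensor-parallel-size 2 \
    --dtype auto
```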

tests/ut/quantization/test_w4a8_dynamic.py (85 additions, 45 deletions)

```diff
@@ -1,4 +1,3 @@
-import copy
 from unittest.mock import Mock, patch
 
 import torch
@@ -95,19 +94,19 @@ def test_get_dynamic_quant_param(self):
         # old quant version weight
         param_dict = self.quant_method.get_dynamic_quant_param(
             self.experts, self.input_size, self.output_size, torch.bfloat16)
-        self.assertEqual(param_dict["w13_weight_scale"].dtype, torch.bfloat16)
+        self.assertEqual(param_dict["w13_weight_scale"].dtype, torch.float32)
         self.assertEqual(param_dict["w13_weight_scale"].shape,
                          (self.experts, 2 * self.input_size, 1))
         self.assertEqual(param_dict["w13_weight_scale_second"].dtype,
-                         torch.bfloat16)
+                         torch.float32)
         self.assertEqual(param_dict["w13_weight_scale_second"].shape,
                          (self.experts, 2 * self.input_size,
                           self.output_size // self.group_size))
-        self.assertEqual(param_dict["w2_weight_scale"].dtype, torch.bfloat16)
+        self.assertEqual(param_dict["w2_weight_scale"].dtype, torch.float32)
         self.assertEqual(param_dict["w2_weight_scale"].shape,
                          (self.experts, self.output_size, 1))
         self.assertEqual(param_dict["w2_weight_scale_second"].dtype,
-                         torch.bfloat16)
+                         torch.float32)
         self.assertEqual(param_dict["w2_weight_scale_second"].shape,
                          (self.experts, self.output_size,
                           self.input_size // self.group_size))
@@ -119,40 +118,87 @@ def test_get_dynamic_quant_param(self):
         self.assertEqual(
             param_dict["w2_scale_bias"].shape,
             (self.experts, self.output_size, 16 // self.quant_method.tp_size))
+        # per-channel weight
+        self.quant_method.is_per_channel_weight = True
+        param_dict = self.quant_method.get_dynamic_quant_param(
+            self.experts, self.input_size, self.output_size, torch.bfloat16)
+        pergroup_param = [
+            "w13_weight_scale_second", "w13_weight_offset_second",
+            "w2_weight_scale_second", "w2_weight_offset_second"
+        ]
+        is_contains = any(key in param_dict for key in pergroup_param)
+        self.assertFalse(is_contains)
 
-    @patch('torch_npu.npu_quantize')
-    @patch('torch.Tensor.npu')
-    def test_process_weights_after_loading(self, mock_npu, mock_npu_quantize):
-        # old quant version weight
+    def build_layer(self,
+                    is_new_quant_version=True,
+                    is_per_channel_weight=False):
         layer = torch.nn.Module()
-        layer.w13_weight = torch.nn.Parameter(torch.zeros(
-            (self.experts, 2 * self.input_size, self.output_size),
-            dtype=torch.int8),
-                                              requires_grad=False)
-        layer.w2_weight = torch.nn.Parameter(torch.zeros(
-            (self.experts, self.output_size, self.input_size),
-            dtype=torch.int8),
-                                             requires_grad=False)
+        if is_new_quant_version:
+            layer.w13_weight = torch.nn.Parameter(torch.zeros(
+                (self.experts, self.input_size, self.output_size),
+                dtype=torch.int8),
+                                                  requires_grad=False)
+            layer.w2_weight = torch.nn.Parameter(torch.zeros(
+                (self.experts, self.output_size // 2, self.input_size),
+                dtype=torch.int8),
+                                                 requires_grad=False)
+            w13_scale_bias = torch.zeros(
+                (self.experts, 2 * self.input_size, 1), dtype=torch.float32)
+            layer.w13_scale_bias = torch.nn.Parameter(w13_scale_bias,
+                                                      requires_grad=False)
+            w2_scale_bias = torch.zeros((self.experts, self.output_size,
+                                         16 // self.quant_method.tp_size),
+                                        dtype=torch.float32)
+            layer.w2_scale_bias = torch.nn.Parameter(w2_scale_bias,
+                                                     requires_grad=False)
+        else:
+            layer.w13_weight = torch.nn.Parameter(torch.zeros(
+                (self.experts, 2 * self.input_size, self.output_size),
+                dtype=torch.int8),
+                                                  requires_grad=False)
+            layer.w2_weight = torch.nn.Parameter(torch.zeros(
+                (self.experts, self.output_size, self.input_size),
+                dtype=torch.int8),
+                                                 requires_grad=False)
         layer.w13_weight_scale = torch.nn.Parameter(torch.ones(
-            (self.experts, 2 * self.input_size, 1), dtype=torch.bfloat16),
+            (self.experts, 2 * self.input_size, 1), dtype=torch.float32),
                                                     requires_grad=False)
-        layer.w13_weight_scale_second = torch.nn.Parameter(torch.ones(
-            (self.experts, 2 * self.input_size,
-             self.output_size // self.group_size),
-            dtype=torch.bfloat16),
-                                                           requires_grad=False)
         layer.w2_weight_scale = torch.nn.Parameter(torch.ones(
-            (self.experts, self.output_size, 1), dtype=torch.bfloat16),
+            (self.experts, self.output_size, 1), dtype=torch.float32),
                                                    requires_grad=False)
-        layer.w2_weight_scale_second = torch.nn.Parameter(torch.ones(
-            (self.experts, self.output_size,
-             self.input_size // self.group_size),
-            dtype=torch.bfloat16),
-                                                          requires_grad=False)
-        new_layer = copy.deepcopy(layer)
+        if not is_per_channel_weight:
+            layer.w13_weight_scale_second = torch.nn.Parameter(
+                torch.ones((self.experts, 2 * self.input_size,
+                            self.output_size // self.group_size),
+                           dtype=torch.float32),
+                requires_grad=False)
+            layer.w13_weight_offset_second = torch.nn.Parameter(
+                torch.empty_like(layer.w13_weight_scale_second.data),
+                requires_grad=False)
+            layer.w2_weight_scale_second = torch.nn.Parameter(
+                torch.ones((self.experts, self.output_size,
+                            self.input_size // self.group_size),
+                           dtype=torch.float32),
+                requires_grad=False)
+            layer.w2_weight_offset_second = torch.nn.Parameter(
+                torch.empty_like(layer.w2_weight_scale_second.data),
+                requires_grad=False)
+        return layer
 
+    @patch('torch_npu.npu_format_cast')
+    @patch('torch_npu.npu_quantize')
+    @patch('torch.Tensor.npu')
+    def test_process_weights_after_loading(self, mock_npu, mock_npu_quantize,
+                                           mock_npu_format_cast):
         mock_npu.return_value = torch.Tensor()
         mock_npu_quantize.return_value = torch.Tensor()
+
+        def func_by_args(weight, num_format):
+            return weight
+
+        mock_npu_format_cast.side_effect = func_by_args
+        # old quant version weight
+        layer = self.build_layer(is_new_quant_version=False)
         self.quant_method.process_weights_after_loading(layer)
         self.assertTrue(hasattr(layer, "w13_scale_bias"))
         self.assertEqual(layer.w13_scale_bias.data.shape,
@@ -164,23 +210,17 @@ def test_process_weights_after_loading(self, mock_npu, mock_npu_quantize):
         self.assertEqual(layer.w2_scale_bias.data.dtype, torch.float32)
         # new quant version weight
         self.quant_method.new_quant_version = True
-        new_layer.w13_weight.data = torch.zeros(
-            (self.experts, self.input_size, self.output_size),
-            dtype=torch.int8)
-        new_layer.w2_weight.data = torch.zeros(
-            (self.experts, self.output_size // 2, self.input_size),
-            dtype=torch.int8)
-        w13_scale_bias = torch.zeros((self.experts, 2 * self.input_size, 1),
-                                     dtype=torch.float32)
-        new_layer.w13_scale_bias = torch.nn.Parameter(w13_scale_bias,
-                                                      requires_grad=False)
-        w2_scale_bias = torch.zeros(
-            (self.experts, self.output_size, 16 // self.quant_method.tp_size),
-            dtype=torch.float32)
-        new_layer.w2_scale_bias = torch.nn.Parameter(w2_scale_bias,
-                                                     requires_grad=False)
+        new_layer = self.build_layer(is_new_quant_version=True)
         self.quant_method.process_weights_after_loading(new_layer)
         self.assertEqual(new_layer.w13_scale_bias.data.shape,
                          (self.experts, 2 * self.input_size))
         self.assertEqual(new_layer.w2_scale_bias.data.shape,
                          (self.experts, self.output_size))
+        self.assertFalse(hasattr(new_layer, "w13_weight_scale_second"))
+        # per-channel weight
+        self.quant_method.is_per_channel_weight = True
+        per_channel_layer = self.build_layer(is_new_quant_version=True,
+                                             is_per_channel_weight=True)
+        self.quant_method.process_weights_after_loading(per_channel_layer)
+        self.assertEqual(new_layer.w13_scale_bias.data.shape,
+                         (self.experts, 2 * self.input_size))
```
