
Commit 8a59367

[main][Feature] Support deepseek w4a8 quantization (#2172)
### What this PR does / why we need it?

Supports DeepSeek-R1 w4a8 quantization. Since R1 w4a8 uses mixed quantization, only the MoE layers use w4a8_dynamic quantization, so we added the w4a8_dynamic.py file, which includes the AscendW4A8DynamicFusedMoEMethod class.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Added UT cases in `tests/ut/quantization/test_w4a8_dynamic.py` and `tests/ut/quantization/test_quantizer.py`.
Added an e2e case, `tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_DeepSeek_W4A8DYNAMIC`, to test the DeepSeek w4a8_dynamic quantized model.

#### 1. How to get weights using Modelslim

##### Installation steps

Use the master branch at commit 298e175d69b3b855111a1e09bbe2fcd12fdb4e24:

    git clone https://gitee.com/ascend/msit.git
    cd msit/msmodelslim
    bash install.sh

##### The required transformers environment

transformers>=4.48.2

##### Generate w4a8 weights

    cd /example/DeepSeek

Command reference: msmodelslim/example/DeepSeek/README.md
Execute the [pre-check](https://gitee.com/ascend/msit/blob/master/msmodelslim/example/DeepSeek/README.md#%E8%BF%90%E8%A1%8C%E5%89%8D%E5%BF%85%E6%A3%80) and [DeepSeek-R1 w4a8 mix quantization](https://gitee.com/ascend/msit/blob/master/msmodelslim/example/DeepSeek/README.md#deepseek-r1-w4a8-%E6%B7%B7%E5%90%88%E9%87%8F%E5%8C%96%E5%89%8D%E4%B8%89%E5%B1%82-mlpw8a8-dynamic-%E9%87%8F%E5%8C%96mla%E5%85%B1%E4%BA%AB%E4%B8%93%E5%AE%B6w8a8%E9%87%8F%E5%8C%96%E8%B7%AF%E7%94%B1%E4%B8%93%E5%AE%B6w4a8-dynamic%E9%87%8F%E5%8C%96) chapters.

Reference command:

    python3 quant_deepseek_w4a8.py --model_path {original weight path} --save_path {generated weight path} --mindie_format

##### Adapt to vllm-ascend

Since `--mindie_format` exports the weights in MindIE format, a few adaptations are needed before vllm-ascend can use them (a scripted sketch of these steps appears after this message):

- Rename `quant_model_description_w8a8_dynamic.json` to `quant_model_description.json`, and add `"group_size": 256`.
- In `config.json`: change `"model_type": "deepseekv2"` to `"model_type": "deepseek_v3"` and remove `quantization_config`.

Tip: the group_size must match the weights. If the w4a8 weights were not generated with msmodelslim, check the group_size in the quantization_config section of config.json.

#### 2. How to run w4a8

##### a. How to run eager mode

    export VLLM_USE_V1=1  # v1
    python -m vllm.entrypoints.openai.api_server --model=$1 --trust-remote-code -tp $2 -dp $3 --enable_expert_parallel --quantization ascend --port $4 --max-model-len $5 --max-num-seqs $6 --enforce-eager

e.g.:

    python -m vllm.entrypoints.openai.api_server --model=/weightpath/w4a8_4_layer --trust-remote-code -tp 4 -dp 4 --enable_expert_parallel --quantization ascend --port 8002 --max-model-len 5120 --max-num-seqs 128 --enforce-eager

##### b. How to run graph mode

    export VLLM_USE_V1=1  # v1
    export HCCL_BUFFSIZE=1024
    python -m vllm.entrypoints.openai.api_server --model=$1 --trust-remote-code -tp $2 -dp $3 --enable_expert_parallel --quantization ascend --port $4 --max-model-len $5 --additional_config='{"ascend_scheduler_config":{"enabled":true},"torchair_graph_config":{"enabled":true}}'

e.g.:

    python -m vllm.entrypoints.openai.api_server --model=/weight/dsr1_w4a8_vllm --trust-remote-code -tp 4 -dp 4 --enable_expert_parallel --quantization ascend --port 8002 --max-model-len 5120 --additional_config='{"ascend_scheduler_config":{"enabled":true},"torchair_graph_config":{"enabled":true}}'

- vLLM version: v0.10.0
- vLLM main: vllm-project/vllm@c494f96

---------

Signed-off-by: Wang Kunpeng <[email protected]>
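The weight-adaptation steps in section 1 above can also be scripted. Below is a minimal sketch, assuming the weights were exported with `--mindie_format` into a single directory and that the file and key names are exactly the ones listed above; the helper name and the example path are placeholders, not part of this PR.

    # Minimal sketch of the "Adapt to vllm-ascend" steps described above.
    # Assumes a MindIE-format w4a8 weight directory; function name and example
    # path are hypothetical placeholders, not part of this PR.
    import json
    from pathlib import Path


    def adapt_mindie_w4a8_for_vllm_ascend(weight_dir: str, group_size: int = 256) -> None:
        root = Path(weight_dir)

        # 1. Rename quant_model_description_w8a8_dynamic.json and add "group_size".
        old_desc = root / "quant_model_description_w8a8_dynamic.json"
        new_desc = root / "quant_model_description.json"
        old_desc.rename(new_desc)
        desc = json.loads(new_desc.read_text())
        desc["group_size"] = group_size  # must match how the weights were quantized
        new_desc.write_text(json.dumps(desc, indent=2))

        # 2. Edit config.json: model_type deepseekv2 -> deepseek_v3, drop quantization_config.
        cfg_path = root / "config.json"
        cfg = json.loads(cfg_path.read_text())
        cfg["model_type"] = "deepseek_v3"
        cfg.pop("quantization_config", None)
        cfg_path.write_text(json.dumps(cfg, indent=2))


    if __name__ == "__main__":
        adapt_mindie_w4a8_for_vllm_ascend("/weightpath/w4a8_4_layer")

After these edits, the weight directory can be passed to `--model` in the serving commands shown above.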
1 parent e31b31f commit 8a59367

9 files changed: +483 -21 lines changed


.github/workflows/vllm_ascend_test.yaml

Lines changed: 1 addition & 0 deletions
@@ -283,6 +283,7 @@ jobs:
       pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_DeepSeekV3_dbo
       pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_alltoallv
       pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_Qwen3_W4A8DYNAMIC
+      pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_DeepSeek_W4A8DYNAMIC
       pytest -sv tests/e2e/multicard/test_data_parallel.py
       pytest -sv tests/e2e/multicard/ --ignore=tests/e2e/multicard/test_ilama_lora_tp2.py \
         --ignore=tests/e2e/multicard/test_offline_inference_distributed.py \

tests/e2e/multicard/test_offline_inference_distributed.py

Lines changed: 25 additions & 0 deletions
@@ -209,3 +209,28 @@ def test_models_distributed_Qwen3_W4A8DYNAMIC():
             quantization="ascend",
     ) as vllm_model:
         vllm_model.generate_greedy(example_prompts, max_tokens)
+
+
+@patch.dict(os.environ, {"VLLM_ASCEND_MLA_PA": "1"})
+def test_models_distributed_DeepSeek_W4A8DYNAMIC():
+    prompts = [
+        "Hello, my name is",
+    ]
+    max_tokens = 5
+    with VllmRunner(
+            snapshot_download("vllm-ascend/DeepSeek-R1-w4a8-pruning"),
+            dtype="auto",
+            tensor_parallel_size=2,
+            quantization="ascend",
+            enforce_eager=True,
+            enable_expert_parallel=True,
+            additional_config={
+                "torchair_graph_config": {
+                    "enabled": False,
+                },
+                "ascend_scheduler_config": {
+                    "enabled": True,
+                }
+            },
+    ) as vllm_model:
+        vllm_model.generate_greedy(prompts, max_tokens)

tests/ut/quantization/test_quantizer.py

Lines changed: 23 additions & 0 deletions
@@ -3,6 +3,7 @@
 from tests.ut.base import TestBase
 from vllm_ascend.quantization.quant_config import AscendQuantConfig
 from vllm_ascend.quantization.quantizer import (VLLMAscendQuantizer,
+                                                W4A8DYNAMICQuantizer,
                                                 W8A8Quantizer)
 
 SUPPORT_ASCEND_QUANTIZER_TYPE = {"test": "1"}
@@ -120,3 +121,25 @@ def test_build_attention_method(self):
             result = self.quantizer.build_attention_method()
             mock_linear.assert_called_once_with()
             self.assertIsInstance(result, MagicMock)
+
+
+class TestW4A8DYNAMICQuantizer(TestBase):
+
+    def setUp(self):
+        self.quantizer = W4A8DYNAMICQuantizer(quant_description={})
+
+    def test_build_linear_method(self):
+        with patch(
+                'vllm_ascend.quantization.quantizer.AscendW4A8DynamicLinearMethod',
+                return_value=MagicMock()) as mock_linear:
+            result = self.quantizer.build_linear_method()
+            mock_linear.assert_called_once_with()
+            self.assertIsInstance(result, MagicMock)
+
+    def test_build_moe_method(self):
+        with patch(
+                'vllm_ascend.quantization.quantizer.AscendW4A8DynamicFusedMoEMethod',
+                return_value=MagicMock()) as mock_fused_moe:
+            result = self.quantizer.build_moe_method()
+            mock_fused_moe.assert_called_once_with()
+            self.assertIsInstance(result, MagicMock)

tests/ut/quantization/test_w4a8_dynamic.py

Lines changed: 83 additions & 1 deletion
@@ -1,7 +1,10 @@
+from unittest.mock import Mock, patch
+
 import torch
 
 from tests.ut.base import TestBase
-from vllm_ascend.quantization.w4a8_dynamic import AscendW4A8DynamicLinearMethod
+from vllm_ascend.quantization.w4a8_dynamic import (
+    AscendW4A8DynamicFusedMoEMethod, AscendW4A8DynamicLinearMethod)
 
 
 class TestAscendW4A8DynamicLinearMethod(TestBase):
@@ -25,3 +28,82 @@ def test_get_pergroup_param(self):
         self.assertEqual(params["weight_scale_second"].shape, (32, 1))
         self.assertEqual(params["weight_offset_second"].dtype, torch.bfloat16)
         self.assertEqual(params["weight_offset_second"].shape, (32, 1))
+
+
+class TestAscendW4A8DynamicFusedMoEMethod(TestBase):
+
+    @patch('vllm_ascend.quantization.w4a8_dynamic.get_ep_group')
+    @patch("vllm_ascend.ascend_config.get_ascend_config")
+    @patch('vllm_ascend.quantization.w4a8_dynamic.get_mc2_group')
+    @patch('torch.distributed.get_rank', return_value=0)
+    def setUp(self, mock_get_rank, mock_get_mc2_group, mock_get_ascend_config,
+              mock_get_ep_group):
+        mock_ascend_config = Mock()
+        mock_ascend_config.torchair_graph_config = Mock(enabled=False)
+        mock_get_ascend_config.return_value = mock_ascend_config
+        self.quant_method = AscendW4A8DynamicFusedMoEMethod()
+
+    def test_get_weight(self):
+        param_dict = self.quant_method.get_weight(8, 4, 14, torch.bfloat16)
+        self.assertEqual(param_dict["w13_weight"].dtype, torch.int8)
+        self.assertEqual(param_dict["w13_weight"].shape, (8, 8, 14))
+
+    @patch('vllm_ascend.quantization.w4a8_dynamic.get_current_vllm_config')
+    def test_get_dynamic_quant_param(self, mock_get_current_vllm_config):
+        mock_vllm_config = Mock()
+        mock_vllm_config.quant_config = Mock(
+            quant_description={"group_size": 2})
+        mock_get_current_vllm_config.return_value = mock_vllm_config
+        param_dict = self.quant_method.get_dynamic_quant_param(
+            8, 4, 14, torch.bfloat16)
+        self.assertEqual(param_dict["w13_weight_scale"].dtype, torch.bfloat16)
+        self.assertEqual(param_dict["w13_weight_scale"].shape, (8, 8, 1))
+        self.assertEqual(param_dict["w13_weight_scale_second"].dtype,
+                         torch.bfloat16)
+        self.assertEqual(param_dict["w13_weight_scale_second"].shape,
+                         (8, 8, 7))
+        self.assertEqual(param_dict["w2_weight_scale"].dtype, torch.bfloat16)
+        self.assertEqual(param_dict["w2_weight_scale"].shape, (8, 14, 1))
+        self.assertEqual(param_dict["w2_weight_scale_second"].dtype,
+                         torch.bfloat16)
+        self.assertEqual(param_dict["w2_weight_scale_second"].shape,
+                         (8, 14, 2))
+
+    @patch('torch_npu.npu_quantize')
+    @patch('torch.Tensor.npu')
+    def test_process_weights_after_loading(self, mock_npu, mock_npu_quantize):
+        layer = torch.nn.Module()
+        layer.w13_weight = torch.nn.Parameter(torch.zeros((8, 8, 14),
+                                                          dtype=torch.int8),
+                                              requires_grad=False)
+        layer.w2_weight = torch.nn.Parameter(torch.zeros((8, 14, 4),
+                                                         dtype=torch.int8),
+                                             requires_grad=False)
+        layer.w13_weight_scale = torch.nn.Parameter(torch.ones(
+            (8, 8, 1), dtype=torch.bfloat16),
+                                                    requires_grad=False)
+        layer.w13_weight_offset = torch.nn.Parameter(torch.zeros(
+            (8, 8, 1), dtype=torch.bfloat16),
+                                                     requires_grad=False)
+        layer.w13_weight_scale_second = torch.nn.Parameter(torch.ones(
+            (8, 8, 7), dtype=torch.bfloat16),
+                                                           requires_grad=False)
+        layer.w2_weight_scale = torch.nn.Parameter(torch.ones(
+            (8, 14, 1), dtype=torch.bfloat16),
+                                                   requires_grad=False)
+        layer.w2_weight_offset = torch.nn.Parameter(torch.zeros(
+            (8, 14, 1), dtype=torch.bfloat16),
+                                                    requires_grad=False)
+        layer.w2_weight_scale_second = torch.nn.Parameter(torch.ones(
+            (8, 14, 2), dtype=torch.bfloat16),
+                                                          requires_grad=False)
+
+        mock_npu.return_value = torch.Tensor()
+        mock_npu_quantize.return_value = torch.Tensor()
+        self.quant_method.process_weights_after_loading(layer)
+        self.assertTrue(hasattr(layer, "w13_scale_bias"))
+        self.assertEqual(layer.w13_scale_bias.data.shape, (8, 8))
+        self.assertEqual(layer.w13_scale_bias.data.dtype, torch.float32)
+        self.assertTrue(hasattr(layer, "w2_scale_bias"))
+        self.assertEqual(layer.w2_scale_bias.data.shape, (8, 14))
+        self.assertEqual(layer.w2_scale_bias.data.dtype, torch.float32)

vllm_ascend/models/deepseek_v2.py

Lines changed: 2 additions & 0 deletions
@@ -905,6 +905,8 @@ def load_weights(self, weights: Iterable[tuple[str,
         for name, loaded_weight in weights:
             if "rotary_emb.inv_freq" in name:
                 continue
+            if "module" in name:
+                continue
 
             spec_layer = get_spec_layer_idx_from_weight_name(self.config, name)
             if spec_layer is not None:

vllm_ascend/quantization/quant_config.py

Lines changed: 4 additions & 1 deletion
@@ -302,6 +302,9 @@ def create_weights(
             param = torch.nn.Parameter(param_value, requires_grad=False)
             layer.register_parameter(param_key, param)
             set_weight_attrs(param, extra_weight_attrs)
+            if "weight_scale_second" in param_key or "weight_offset_second" in param_key:
+                setattr(param, "quant_method",
+                        FusedMoeWeightScaleSupported.GROUP.value)
 
     def apply(
         self,
@@ -348,4 +351,4 @@ def __init__(self, quant_config: AscendQuantConfig, prefix: str,
                  packed_modules_mapping: Dict[str, Any]) -> None:
         self.quantizer = AscendQuantizer.get_quantizer(
             quant_config.quant_description, prefix, packed_modules_mapping)
-        self.quant_method = self.quantizer.build_linear_method()
+        self.quant_method = self.quantizer.build_linear_method()

vllm_ascend/quantization/quantizer.py

Lines changed: 15 additions & 7 deletions
@@ -24,7 +24,8 @@
 
 from .func_wrapper import (wrapper_rmsnorm_forward_oot, wrapper_rmsnorm_init,
                            wrapper_vocab_parallel_embedding_init)
-from .w4a8_dynamic import AscendW4A8DynamicLinearMethod
+from .w4a8_dynamic import (AscendW4A8DynamicFusedMoEMethod,
+                           AscendW4A8DynamicLinearMethod)
 from .w8a8 import (AscendC8KVCacheMethod, AscendW8A8FusedMoEMethod,
                    AscendW8A8LinearMethod)
 from .w8a8_dynamic import (AscendW8A8DynamicFusedMoEMethod,
@@ -97,12 +98,15 @@ def apply_patch(target_module, target_function, wrappers):
         if target_function is not None:
             setattr(original_module, target_function, candidate)
 
-        for key, value in sys.modules.copy().items():
-            if (target_function is not None
-                    and hasattr(value, target_function)
-                    and id(getattr(value,
-                                   target_function)) == original_function_id):
-                setattr(value, target_function, candidate)
+        for _, value in sys.modules.copy().items():
+            if target_function is None:
+                continue
+            try:
+                attr = getattr(value, target_function, None)
+                if attr is not None and id(attr) == original_function_id:
+                    setattr(value, target_function, candidate)
+            except ImportError:
+                continue
 
     @staticmethod
     def parse_path(module_path, function_name, create_dummy):
@@ -268,6 +272,10 @@ class W4A8DYNAMICQuantizer(VLLMAscendQuantizer):
     def build_linear_method():
         return AscendW4A8DynamicLinearMethod()
 
+    @staticmethod
+    def build_moe_method():
+        return AscendW4A8DynamicFusedMoEMethod()
+
 
 class W8A8Quantizer(VLLMAscendQuantizer):
 