Commit 290b490

test(quant): add unit tests to LinearMethod for w4a8 dynamic

1 parent e565da8

File tree: 2 files changed, +155 −16 lines

tests/ut/quantization/test_w4a8_dynamic.py (+79 −14)
@@ -9,25 +9,32 @@

 class TestAscendW4A8DynamicLinearMethod(TestBase):

-    def setUp(self):
-        with patch(
-                'vllm_ascend.quantization.w4a8_dynamic.get_current_vllm_config'
-        ) as mock_get_current_vllm_config:
-            mock_vllm_config = Mock()
-            mock_vllm_config.quant_config = Mock(
-                quant_description={"group_size": 256})
-            mock_vllm_config.scheduler_config = Mock(
-                max_num_batched_tokens=2048,
-                max_model_len=2048,
-                enable_chunked_prefill=False)
-            mock_get_current_vllm_config.return_value = mock_vllm_config
-            self.method = AscendW4A8DynamicLinearMethod()
-            self.method.group_size = 8
+    @patch('vllm.distributed.get_tensor_model_parallel_world_size')
+    @patch('vllm_ascend.quantization.w4a8_dynamic.get_current_vllm_config')
+    def setUp(self, mock_get_current_vllm_config, mock_get_tp_world_size):
+        mock_get_tp_world_size.return_value = 1
+        mock_vllm_config = Mock()
+        mock_vllm_config.quant_config = Mock(
+            quant_description={"group_size": 256})
+        mock_vllm_config.scheduler_config = Mock(
+            max_num_batched_tokens=2048,
+            max_model_len=2048,
+            enable_chunked_prefill=False)
+        mock_get_current_vllm_config.return_value = mock_vllm_config
+        self.method = AscendW4A8DynamicLinearMethod()
+        self.method.group_size = 8

     def test_get_weight(self):
         weight = self.method.get_weight(8, 32, torch.bfloat16)
         self.assertEqual(weight["weight"].dtype, torch.int8)
         self.assertEqual(weight["weight"].shape, (32, 8))
+        # new quant version weight
+        self.method.new_quant_version = True
+        weight = self.method.get_weight(8, 32, torch.bfloat16)
+        self.assertEqual(weight["weight"].dtype, torch.int8)
+        self.assertEqual(weight["weight"].shape, (16, 8))
+        self.assertEqual(weight["_packed_dim"], 0)
+        self.assertEqual(weight["_packed_factor"], 2)

     def test_get_pergroup_param(self):
         params = self.method.get_pergroup_param(8, 32, torch.bfloat16)
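The new `_packed_dim`/`_packed_factor` assertions reflect int4 packing: two 4-bit values share one int8 byte, so the stored weight halves along dimension 0 (32 rows become 16). A minimal, self-contained sketch of that packing for illustration only (the helper `pack_int4_pairs` is hypothetical, not the vllm-ascend implementation):

import torch

def pack_int4_pairs(w: torch.Tensor) -> torch.Tensor:
    """Pack pairs of int4 values (carried in int8) along dim 0."""
    assert w.shape[0] % 2 == 0, "packed dim must be even"
    b = w.to(torch.uint8)  # work on raw bit patterns
    packed = (b[0::2] & 0x0F) | ((b[1::2] & 0x0F) << 4)  # low/high nibble
    return packed.view(torch.int8)  # reinterpret the bytes as int8

w = torch.randint(-8, 8, (32, 8), dtype=torch.int8)  # values in int4 range
print(pack_int4_pairs(w).shape)  # torch.Size([16, 8]) -> _packed_factor == 2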
@@ -39,6 +46,64 @@ def test_get_pergroup_param(self):
         self.assertEqual(params["weight_scale_second"].shape, (32, 1))
         self.assertEqual(params["weight_offset_second"].dtype, torch.bfloat16)
         self.assertEqual(params["weight_offset_second"].shape, (32, 1))
+        # new quant version weight
+        self.method.new_quant_version = True
+        params = self.method.get_pergroup_param(8, 32, torch.bfloat16,
+                                                layer_type="column")
+        self.assertEqual(params["scale_bias"].dtype, torch.float32)
+        self.assertEqual(params["scale_bias"].shape, (32, 1))
+        params = self.method.get_pergroup_param(8, 32, torch.bfloat16,
+                                                layer_type="row")
+        self.assertEqual(params["scale_bias"].dtype, torch.float32)
+        self.assertEqual(params["scale_bias"].shape, (32, 16))
+
+    @patch('torch_npu.npu_convert_weight_to_int4pack')
+    @patch('torch.Tensor.npu')
+    def test_process_weights_after_loading(self, mock_npu,
+                                           mock_npu_convert_weight):
+        mock_npu.side_effect = lambda: torch.zeros((1, 32), dtype=torch.float32)
+        mock_npu_convert_weight.return_value = torch.zeros((32, 4),
+                                                           dtype=torch.int32)
+        # old quant version weight
+        layer = torch.nn.Module()
+        layer.weight = torch.nn.Parameter(torch.zeros((32, 8),
+                                                      dtype=torch.int8),
+                                          requires_grad=False)
+        layer.weight_scale = torch.nn.Parameter(torch.ones((32, 1),
+                                                           dtype=torch.float32),
+                                                requires_grad=False)
+        layer.weight_offset = torch.nn.Parameter(
+            torch.empty_like(layer.weight_scale.data), requires_grad=False)
+        layer.weight_scale_second = torch.nn.Parameter(
+            torch.ones((32, 1), dtype=torch.float32), requires_grad=False)
+        layer.weight_offset_second = torch.nn.Parameter(
+            torch.empty_like(layer.weight_scale_second.data),
+            requires_grad=False)
+        self.method.process_weights_after_loading(layer)
+        self.assertTrue(hasattr(layer, "weight_scale_bias"))
+        self.assertEqual(layer.weight_scale_bias.data.shape, (32, ))
+        self.assertEqual(layer.weight_scale_bias.data.dtype, torch.float32)
+        # new quant version weight
+        self.method.new_quant_version = True
+        new_layer = torch.nn.Module()
+        new_layer.weight = torch.nn.Parameter(torch.zeros((16, 8),
+                                                          dtype=torch.int8),
+                                              requires_grad=False)
+        new_layer.weight_scale = torch.nn.Parameter(
+            torch.ones((32, 1), dtype=torch.float32), requires_grad=False)
+        new_layer.weight_offset = torch.nn.Parameter(
+            torch.empty_like(new_layer.weight_scale.data), requires_grad=False)
+        new_layer.weight_scale_second = torch.nn.Parameter(
+            torch.ones((32, 1), dtype=torch.float32), requires_grad=False)
+        new_layer.weight_offset_second = torch.nn.Parameter(
+            torch.empty_like(new_layer.weight_scale_second.data),
+            requires_grad=False)
+        new_layer.scale_bias = torch.nn.Parameter(torch.zeros((32, 1),
+                                                              dtype=torch.float32),
+                                                  requires_grad=False)
+        self.method.process_weights_after_loading(new_layer)
+        self.assertEqual(new_layer.scale_bias.data.shape, (32, ))
+        self.assertTrue(hasattr(new_layer, "weight_scale_second"))
+        self.assertEqual(new_layer.weight_scale_second.data.shape, (1, 32))


 class TestAscendW4A8DynamicFusedMoEMethod(TestBase):
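For the new quant version, the assertions pin down two shape changes after loading: scale_bias goes from (32, 1) to a flat (32,), and weight_scale_second from (32, 1) to (1, 32). A tiny sketch of the tensor operations those shapes imply (an assumption about the method's internals, inferred only from the asserts):

import torch

scale_bias = torch.zeros((32, 1), dtype=torch.float32)
weight_scale_second = torch.ones((32, 1), dtype=torch.float32)
print(scale_bias.view(-1).shape)                  # torch.Size([32])
print(weight_scale_second.transpose(0, 1).shape)  # torch.Size([1, 32])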

tests/ut/torchair/quantization/test_torchair_w4a8_dynamic.py (+76 −2)
@@ -10,14 +10,28 @@

 class TestAscendW4A8DynamicLinearMethod(TestBase):

-    def setUp(self):
+    @patch('vllm.distributed.get_tensor_model_parallel_world_size')
+    @patch('vllm_ascend.torchair.quantization.torchair_w4a8_dynamic.get_current_vllm_config')
+    def setUp(self, mock_get_current_vllm_config, mock_get_tp_world_size):
+        mock_get_tp_world_size.return_value = 1
+        mock_vllm_config = Mock()
+        mock_vllm_config.quant_config = Mock(
+            quant_description={"group_size": 256})
+        mock_get_current_vllm_config.return_value = mock_vllm_config
         self.method = TorchairAscendW4A8DynamicLinearMethod()
         self.method.group_size = 8

     def test_get_weight(self):
         weight = self.method.get_weight(8, 32, torch.bfloat16)
         self.assertEqual(weight["weight"].dtype, torch.int8)
         self.assertEqual(weight["weight"].shape, (32, 8))
+        # new quant version weight
+        self.method.new_quant_version = True
+        weight = self.method.get_weight(8, 32, torch.bfloat16)
+        self.assertEqual(weight["weight"].dtype, torch.int8)
+        self.assertEqual(weight["weight"].shape, (16, 8))
+        self.assertEqual(weight["_packed_dim"], 0)
+        self.assertEqual(weight["_packed_factor"], 2)

     def test_get_pergroup_param(self):
         params = self.method.get_pergroup_param(8, 32, torch.bfloat16)
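Both suites now build their mocks with stacked @patch decorators on setUp. unittest.mock applies decorators bottom-up, so the decorator nearest the def supplies the first mock argument, which is why mock_get_current_vllm_config comes before mock_get_tp_world_size in the signature. A quick illustration (the Example class and patch targets are placeholders):

from unittest import TestCase
from unittest.mock import patch

class Example(TestCase):
    @patch('os.getcwd')     # outermost -> second mock argument
    @patch('os.cpu_count')  # innermost -> first mock argument
    def setUp(self, mock_cpu_count, mock_getcwd):
        mock_cpu_count.return_value = 1
        mock_getcwd.return_value = '/tmp'

Note that decorator patches are undone when setUp returns; that is enough here because the mocked names are only consulted while the method under test is being constructed.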
@@ -29,6 +43,64 @@ def test_get_pergroup_param(self):
         self.assertEqual(params["weight_scale_second"].shape, (32, 1))
         self.assertEqual(params["weight_offset_second"].dtype, torch.bfloat16)
         self.assertEqual(params["weight_offset_second"].shape, (32, 1))
+        # new quant version weight
+        self.method.new_quant_version = True
+        params = self.method.get_pergroup_param(8, 32, torch.bfloat16,
+                                                layer_type="column")
+        self.assertEqual(params["scale_bias"].dtype, torch.float32)
+        self.assertEqual(params["scale_bias"].shape, (32, 1))
+        params = self.method.get_pergroup_param(8, 32, torch.bfloat16,
+                                                layer_type="row")
+        self.assertEqual(params["scale_bias"].dtype, torch.float32)
+        self.assertEqual(params["scale_bias"].shape, (32, 16))
+
+    @patch('torch_npu.npu_convert_weight_to_int4pack')
+    @patch('torch.Tensor.npu')
+    def test_process_weights_after_loading(self, mock_npu,
+                                           mock_npu_convert_weight):
+        mock_npu.side_effect = lambda: torch.zeros((1, 32), dtype=torch.float32)
+        mock_npu_convert_weight.return_value = torch.zeros((32, 4),
+                                                           dtype=torch.int32)
+        # old quant version weight
+        layer = torch.nn.Module()
+        layer.weight = torch.nn.Parameter(torch.zeros((32, 8),
+                                                      dtype=torch.int8),
+                                          requires_grad=False)
+        layer.weight_scale = torch.nn.Parameter(torch.ones((32, 1),
+                                                           dtype=torch.float32),
+                                                requires_grad=False)
+        layer.weight_offset = torch.nn.Parameter(
+            torch.empty_like(layer.weight_scale.data), requires_grad=False)
+        layer.weight_scale_second = torch.nn.Parameter(
+            torch.ones((32, 1), dtype=torch.float32), requires_grad=False)
+        layer.weight_offset_second = torch.nn.Parameter(
+            torch.empty_like(layer.weight_scale_second.data),
+            requires_grad=False)
+        self.method.process_weights_after_loading(layer)
+        self.assertTrue(hasattr(layer, "weight_scale_bias"))
+        self.assertEqual(layer.weight_scale_bias.data.shape, (32, ))
+        self.assertEqual(layer.weight_scale_bias.data.dtype, torch.float32)
+        # new quant version weight
+        self.method.new_quant_version = True
+        new_layer = torch.nn.Module()
+        new_layer.weight = torch.nn.Parameter(torch.zeros((16, 8),
+                                                          dtype=torch.int8),
+                                              requires_grad=False)
+        new_layer.weight_scale = torch.nn.Parameter(
+            torch.ones((32, 1), dtype=torch.float32), requires_grad=False)
+        new_layer.weight_offset = torch.nn.Parameter(
+            torch.empty_like(new_layer.weight_scale.data), requires_grad=False)
+        new_layer.weight_scale_second = torch.nn.Parameter(
+            torch.ones((32, 1), dtype=torch.float32), requires_grad=False)
+        new_layer.weight_offset_second = torch.nn.Parameter(
+            torch.empty_like(new_layer.weight_scale_second.data),
+            requires_grad=False)
+        new_layer.scale_bias = torch.nn.Parameter(torch.zeros((32, 1),
+                                                              dtype=torch.float32),
+                                                  requires_grad=False)
+        self.method.process_weights_after_loading(new_layer)
+        self.assertEqual(new_layer.scale_bias.data.shape, (32, ))
+        self.assertTrue(hasattr(new_layer, "weight_scale_second"))
+        self.assertEqual(new_layer.weight_scale_second.data.shape, (1, 32))


 class TestAscendW4A8DynamicFusedMoEMethod(TestBase):

@@ -42,7 +114,9 @@ class TestAscendW4A8DynamicFusedMoEMethod(TestBase):
     )
     @patch(
         'vllm_ascend.torchair.quantization.torchair_w4a8_dynamic.get_ep_group')
-    @patch("vllm_ascend.ascend_config.get_ascend_config")
+    @patch(
+        'vllm_ascend.torchair.quantization.torchair_w4a8_dynamic.get_ascend_config'
+    )
     @patch(
         'vllm_ascend.torchair.quantization.torchair_w4a8_dynamic.get_mc2_group'
    )