
Commit 61605d4

style: format code with pre-commit hooks
Signed-off-by: Anionex <[email protected]>
1 parent dc88722 commit 61605d4

7 files changed (+106, -71 lines)
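
Every hunk in this commit is whitespace-only reformatting; no logic changes. The layout the hooks converge on matches yapf's continuation style (an assumption, since the commit message does not name the formatter): when a call no longer fits on one line, arguments are split one per line and aligned under the opening parenthesis. A minimal before/after sketch in Python, modeled on a call from the test diffs below:

    # Before: several arguments share each continuation line.
    params = method.get_pergroup_param(8, 32, torch.bfloat16,
                                       layer_type="column")

    # After: one argument per line, aligned under the opening parenthesis.
    params = method.get_pergroup_param(8,
                                       32,
                                       torch.bfloat16,
                                       layer_type="column")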

tests/ut/quantization/test_w4a8_dynamic.py

Lines changed: 37 additions & 27 deletions
@@ -16,10 +16,9 @@ def setUp(self, mock_get_current_vllm_config, mock_get_tp_world_size):
         mock_vllm_config = Mock()
         mock_vllm_config.quant_config = Mock(
             quant_description={"group_size": 256})
-        mock_vllm_config.scheduler_config = Mock(
-            max_num_batched_tokens=2048,
-            max_model_len=2048,
-            enable_chunked_prefill=False)
+        mock_vllm_config.scheduler_config = Mock(max_num_batched_tokens=2048,
+                                                 max_model_len=2048,
+                                                 enable_chunked_prefill=False)
         mock_get_current_vllm_config.return_value = mock_vllm_config
         self.method = AscendW4A8DynamicLinearMethod()
         self.method.group_size = 8

@@ -48,11 +47,15 @@ def test_get_pergroup_param(self):
         self.assertEqual(params["weight_offset_second"].shape, (32, 1))
         # new quant version weight
         self.method.new_quant_version = True
-        params = self.method.get_pergroup_param(8, 32, torch.bfloat16,
+        params = self.method.get_pergroup_param(8,
+                                                32,
+                                                torch.bfloat16,
                                                 layer_type="column")
         self.assertEqual(params["scale_bias"].dtype, torch.float32)
         self.assertEqual(params["scale_bias"].shape, (32, 1))
-        params = self.method.get_pergroup_param(8, 32, torch.bfloat16,
+        params = self.method.get_pergroup_param(8,
+                                                32,
+                                                torch.bfloat16,
                                                 layer_type="row")
         self.assertEqual(params["scale_bias"].dtype, torch.float32)
         self.assertEqual(params["scale_bias"].shape, (32, 16))

@@ -61,23 +64,27 @@ def test_get_pergroup_param(self):
     @patch('torch.Tensor.npu')
     def test_process_weights_after_loading(self, mock_npu,
                                            mock_npu_convert_weight):
-        mock_npu.side_effect = lambda: torch.zeros((1, 32), dtype=torch.float32)
+        mock_npu.side_effect = lambda: torch.zeros(
+            (1, 32), dtype=torch.float32)
         mock_npu_convert_weight.return_value = torch.zeros((32, 4),
-                                                           dtype=torch.int32)
+                                                            dtype=torch.int32)
         # old quant version weight
         layer = torch.nn.Module()
-        layer.weight = torch.nn.Parameter(torch.zeros((32, 8), dtype=torch.int8),
+        layer.weight = torch.nn.Parameter(torch.zeros((32, 8),
+                                                      dtype=torch.int8),
                                           requires_grad=False)
-        layer.weight_scale = torch.nn.Parameter(torch.ones((32, 1),
-                                                           dtype=torch.float32),
+        layer.weight_scale = torch.nn.Parameter(torch.ones(
+            (32, 1), dtype=torch.float32),
                                                 requires_grad=False)
-        layer.weight_offset = torch.nn.Parameter(
-            torch.empty_like(layer.weight_scale.data), requires_grad=False)
-        layer.weight_scale_second = torch.nn.Parameter(
-            torch.ones((32, 1), dtype=torch.float32), requires_grad=False)
-        layer.weight_offset_second = torch.nn.Parameter(
-            torch.empty_like(layer.weight_scale_second.data),
-            requires_grad=False)
+        layer.weight_offset = torch.nn.Parameter(torch.empty_like(
+            layer.weight_scale.data),
+                                                 requires_grad=False)
+        layer.weight_scale_second = torch.nn.Parameter(torch.ones(
+            (32, 1), dtype=torch.float32),
+                                                       requires_grad=False)
+        layer.weight_offset_second = torch.nn.Parameter(torch.empty_like(
+            layer.weight_scale_second.data),
+                                                        requires_grad=False)
         self.method.process_weights_after_loading(layer)
         self.assertTrue(hasattr(layer, "weight_scale_bias"))
         self.assertEqual(layer.weight_scale_bias.data.shape, (32, ))

@@ -86,19 +93,22 @@ def test_process_weights_after_loading(self, mock_npu,
         self.method.new_quant_version = True
         new_layer = torch.nn.Module()
         new_layer.weight = torch.nn.Parameter(torch.zeros((16, 8),
-                                                          dtype=torch.int8),
+                                                           dtype=torch.int8),
                                               requires_grad=False)
-        new_layer.weight_scale = torch.nn.Parameter(
-            torch.ones((32, 1), dtype=torch.float32), requires_grad=False)
-        new_layer.weight_offset = torch.nn.Parameter(
-            torch.empty_like(new_layer.weight_scale.data), requires_grad=False)
-        new_layer.weight_scale_second = torch.nn.Parameter(
-            torch.ones((32, 1), dtype=torch.float32), requires_grad=False)
+        new_layer.weight_scale = torch.nn.Parameter(torch.ones(
+            (32, 1), dtype=torch.float32),
+                                                    requires_grad=False)
+        new_layer.weight_offset = torch.nn.Parameter(torch.empty_like(
+            new_layer.weight_scale.data),
+                                                     requires_grad=False)
+        new_layer.weight_scale_second = torch.nn.Parameter(torch.ones(
+            (32, 1), dtype=torch.float32),
+                                                           requires_grad=False)
         new_layer.weight_offset_second = torch.nn.Parameter(
             torch.empty_like(new_layer.weight_scale_second.data),
             requires_grad=False)
-        new_layer.scale_bias = torch.nn.Parameter(torch.zeros((32, 1),
-                                                              dtype=torch.float32),
+        new_layer.scale_bias = torch.nn.Parameter(torch.zeros(
+            (32, 1), dtype=torch.float32),
                                                   requires_grad=False)
         self.method.process_weights_after_loading(new_layer)
         self.assertEqual(new_layer.scale_bias.data.shape, (32, ))

tests/ut/torchair/quantization/test_torchair_w4a8_dynamic.py

Lines changed: 37 additions & 24 deletions
@@ -11,7 +11,9 @@
 class TestAscendW4A8DynamicLinearMethod(TestBase):

     @patch('vllm.distributed.get_tensor_model_parallel_world_size')
-    @patch('vllm_ascend.torchair.quantization.torchair_w4a8_dynamic.get_current_vllm_config')
+    @patch(
+        'vllm_ascend.torchair.quantization.torchair_w4a8_dynamic.get_current_vllm_config'
+    )
     def setUp(self, mock_get_current_vllm_config, mock_get_tp_world_size):
         mock_get_tp_world_size.return_value = 1
         mock_vllm_config = Mock()

@@ -45,11 +47,15 @@ def test_get_pergroup_param(self):
         self.assertEqual(params["weight_offset_second"].shape, (32, 1))
         # new quant version weight
         self.method.new_quant_version = True
-        params = self.method.get_pergroup_param(8, 32, torch.bfloat16,
+        params = self.method.get_pergroup_param(8,
+                                                32,
+                                                torch.bfloat16,
                                                 layer_type="column")
         self.assertEqual(params["scale_bias"].dtype, torch.float32)
         self.assertEqual(params["scale_bias"].shape, (32, 1))
-        params = self.method.get_pergroup_param(8, 32, torch.bfloat16,
+        params = self.method.get_pergroup_param(8,
+                                                32,
+                                                torch.bfloat16,
                                                 layer_type="row")
         self.assertEqual(params["scale_bias"].dtype, torch.float32)
         self.assertEqual(params["scale_bias"].shape, (32, 16))

@@ -58,23 +64,27 @@ def test_get_pergroup_param(self):
     @patch('torch.Tensor.npu')
     def test_process_weights_after_loading(self, mock_npu,
                                            mock_npu_convert_weight):
-        mock_npu.side_effect = lambda: torch.zeros((1, 32), dtype=torch.float32)
+        mock_npu.side_effect = lambda: torch.zeros(
+            (1, 32), dtype=torch.float32)
         mock_npu_convert_weight.return_value = torch.zeros((32, 4),
-                                                           dtype=torch.int32)
+                                                            dtype=torch.int32)
         # old quant version weight
         layer = torch.nn.Module()
-        layer.weight = torch.nn.Parameter(torch.zeros((32, 8), dtype=torch.int8),
+        layer.weight = torch.nn.Parameter(torch.zeros((32, 8),
+                                                      dtype=torch.int8),
                                           requires_grad=False)
-        layer.weight_scale = torch.nn.Parameter(torch.ones((32, 1),
-                                                           dtype=torch.float32),
+        layer.weight_scale = torch.nn.Parameter(torch.ones(
+            (32, 1), dtype=torch.float32),
                                                 requires_grad=False)
-        layer.weight_offset = torch.nn.Parameter(
-            torch.empty_like(layer.weight_scale.data), requires_grad=False)
-        layer.weight_scale_second = torch.nn.Parameter(
-            torch.ones((32, 1), dtype=torch.float32), requires_grad=False)
-        layer.weight_offset_second = torch.nn.Parameter(
-            torch.empty_like(layer.weight_scale_second.data),
-            requires_grad=False)
+        layer.weight_offset = torch.nn.Parameter(torch.empty_like(
+            layer.weight_scale.data),
+                                                 requires_grad=False)
+        layer.weight_scale_second = torch.nn.Parameter(torch.ones(
+            (32, 1), dtype=torch.float32),
+                                                       requires_grad=False)
+        layer.weight_offset_second = torch.nn.Parameter(torch.empty_like(
+            layer.weight_scale_second.data),
+                                                        requires_grad=False)
         self.method.process_weights_after_loading(layer)
         self.assertTrue(hasattr(layer, "weight_scale_bias"))
         self.assertEqual(layer.weight_scale_bias.data.shape, (32, ))

@@ -83,19 +93,22 @@ def test_process_weights_after_loading(self, mock_npu,
         self.method.new_quant_version = True
         new_layer = torch.nn.Module()
         new_layer.weight = torch.nn.Parameter(torch.zeros((16, 8),
-                                                          dtype=torch.int8),
+                                                           dtype=torch.int8),
                                               requires_grad=False)
-        new_layer.weight_scale = torch.nn.Parameter(
-            torch.ones((32, 1), dtype=torch.float32), requires_grad=False)
-        new_layer.weight_offset = torch.nn.Parameter(
-            torch.empty_like(new_layer.weight_scale.data), requires_grad=False)
-        new_layer.weight_scale_second = torch.nn.Parameter(
-            torch.ones((32, 1), dtype=torch.float32), requires_grad=False)
+        new_layer.weight_scale = torch.nn.Parameter(torch.ones(
+            (32, 1), dtype=torch.float32),
+                                                    requires_grad=False)
+        new_layer.weight_offset = torch.nn.Parameter(torch.empty_like(
+            new_layer.weight_scale.data),
+                                                     requires_grad=False)
+        new_layer.weight_scale_second = torch.nn.Parameter(torch.ones(
+            (32, 1), dtype=torch.float32),
+                                                           requires_grad=False)
         new_layer.weight_offset_second = torch.nn.Parameter(
             torch.empty_like(new_layer.weight_scale_second.data),
             requires_grad=False)
-        new_layer.scale_bias = torch.nn.Parameter(torch.zeros((32, 1),
-                                                              dtype=torch.float32),
+        new_layer.scale_bias = torch.nn.Parameter(torch.zeros(
+            (32, 1), dtype=torch.float32),
                                                   requires_grad=False)
         self.method.process_weights_after_loading(new_layer)
         self.assertEqual(new_layer.scale_bias.data.shape, (32, ))

vllm_ascend/quantization/quant_config.py

Lines changed: 8 additions & 5 deletions
@@ -288,13 +288,16 @@ def create_weights(
             layer.register_parameter(perchannel_name, param)
             set_weight_attrs(param, extra_weight_attrs)

-        # NOTE: In w4a8 quantization implementation, 
-        # for down_proj and o_proj scale_bias shape is [output_size, 16], 
+        # NOTE: In w4a8 quantization implementation,
+        # for down_proj and o_proj scale_bias shape is [output_size, 16],
         # others are [output_size, 1]
-        layer_type = "row" if isinstance(layer, RowParallelLinear) else "others"
-
+        layer_type = "row" if isinstance(layer,
+                                         RowParallelLinear) else "others"
+
         pergroup_dict = self.quant_method.get_pergroup_param(
-            input_size_per_partition, output_size_per_partition, params_dtype,
+            input_size_per_partition,
+            output_size_per_partition,
+            params_dtype,
             layer_type=layer_type)
         for pergroup_name, pergroup_param in pergroup_dict.items():
             param = torch.nn.Parameter(pergroup_param, requires_grad=False)

vllm_ascend/quantization/w4a8_dynamic.py

Lines changed: 10 additions & 8 deletions
@@ -36,14 +36,14 @@ class AscendW4A8DynamicLinearMethod:

     def __init__(self):
         self.transpose_weight = True
-        
+
         vllm_config = get_current_vllm_config()
         self.group_size = vllm_config.quant_config.quant_description.get(
             "group_size", 256)
         quant_version = vllm_config.quant_config.quant_description.get(
             "version", "0")
         self.new_quant_version = quant_version == "1.0.0"
-        
+
         from vllm.distributed import get_tensor_model_parallel_world_size
         self.tp_size = get_tensor_model_parallel_world_size()

@@ -83,8 +83,10 @@ def get_perchannel_param(output_size: int,
                              params_dtype: torch.dtype) -> Dict[str, Any]:
         return {}

-    def get_pergroup_param(self, input_size: int, output_size: int,
-                           params_dtype: torch.dtype,
+    def get_pergroup_param(self,
+                           input_size: int,
+                           output_size: int,
+                           params_dtype: torch.dtype,
                            layer_type: Optional[str] = None) -> Dict[str, Any]:
         """
         Create per-group quantization parameters.

@@ -105,12 +107,12 @@ def get_pergroup_param(self, input_size: int, output_size: int,
                                                   self.group_size,
                                                   dtype=params_dtype)

-        # NOTE: In w4a8 quantization implementation, 
-        # for down_proj and o_proj(layer_type == "row") scale_bias shape is [output_size, 16], 
+        # NOTE: In w4a8 quantization implementation,
+        # for down_proj and o_proj(layer_type == "row") scale_bias shape is [output_size, 16],
         # others are [output_size, 1]
         if self.new_quant_version:
             scale_bias_dim = 16 if layer_type == "row" else 1
-            
+
             params_dict["scale_bias"] = torch.empty(output_size,
                                                     scale_bias_dim,
                                                     dtype=torch.float32)

@@ -147,7 +149,7 @@ def process_scale_second(weight: torch.Tensor,
         weight_high = weight_high.reshape(k, n)
         bias = 8 * (weight_high.to(torch.float32) * scale).sum(dim=0)
         # NOTE: scale_bias is not used currently
-        # because in msmodelslim w4a8 uses symmetric quantization 
+        # because in msmodelslim w4a8 uses symmetric quantization

         # TODO: support potential future asymmetric quantization
         antiquant_scale = (scale * per_group_scale).reshape(group_num, n)
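
The one piece of semantics restated in this file is the NOTE above: under the new quant version, row-parallel layers (down_proj, o_proj) allocate scale_bias with shape [output_size, 16], all other layers with shape [output_size, 1]. A minimal self-contained sketch of that shape rule (make_scale_bias is a hypothetical helper name; the real logic sits inline in get_pergroup_param):

    import torch

    def make_scale_bias(output_size: int, layer_type: str) -> torch.Tensor:
        # Row-parallel layers (down_proj / o_proj) carry 16 bias slots
        # per output channel; every other layer carries one.
        scale_bias_dim = 16 if layer_type == "row" else 1
        return torch.empty(output_size, scale_bias_dim, dtype=torch.float32)

    assert make_scale_bias(32, "row").shape == (32, 16)
    assert make_scale_bias(32, "column").shape == (32, 1)

These are exactly the shapes asserted by the two unit-test files above.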

vllm_ascend/quantization/w8a8.py

Lines changed: 3 additions & 1 deletion
@@ -85,7 +85,9 @@ def get_perchannel_param(
                                                       dtype=params_dtype)
         return params_dict

-    def get_pergroup_param(self, input_size: int, output_size: int,
+    def get_pergroup_param(self,
+                           input_size: int,
+                           output_size: int,
                            params_dtype: torch.dtype,
                            layer_type: Optional[str] = None) -> Dict[str, Any]:
         return {}

vllm_ascend/quantization/w8a8_dynamic.py

Lines changed: 3 additions & 1 deletion
@@ -62,7 +62,9 @@ def get_perchannel_param(
                                                       dtype=params_dtype)
         return params_dict

-    def get_pergroup_param(self, input_size: int, output_size: int,
+    def get_pergroup_param(self,
+                           input_size: int,
+                           output_size: int,
                            params_dtype: torch.dtype,
                            layer_type: Optional[str] = None) -> Dict[str, Any]:
         return {}

vllm_ascend/torchair/quantization/torchair_w4a8_dynamic.py

Lines changed: 8 additions & 5 deletions
@@ -39,14 +39,14 @@ class TorchairAscendW4A8DynamicLinearMethod:

     def __init__(self):
         self.transpose_weight = True
-        
+
         vllm_config = get_current_vllm_config()
         self.group_size = vllm_config.quant_config.quant_description.get(
             "group_size", 256)
         quant_version = vllm_config.quant_config.quant_description.get(
             "version", "0")
         self.new_quant_version = quant_version == "1.0.0"
-        
+
         from vllm.distributed import get_tensor_model_parallel_world_size
         self.tp_size = get_tensor_model_parallel_world_size()

@@ -78,8 +78,10 @@ def get_perchannel_param(output_size: int,
                              params_dtype: torch.dtype) -> Dict[str, Any]:
         return {}

-    def get_pergroup_param(self, input_size: int, output_size: int,
-                           params_dtype: torch.dtype,
+    def get_pergroup_param(self,
+                           input_size: int,
+                           output_size: int,
+                           params_dtype: torch.dtype,
                            layer_type: Optional[str] = None) -> Dict[str, Any]:
         params_dict = {}
         params_dict["weight_scale"] = torch.empty(output_size,

@@ -166,7 +168,8 @@ def process_weights_after_loading(self, layer: torch.nn.Module):
         if self.new_quant_version:
             assert layer.weight.data.shape[-1] % 4 == 0, \
                 f"the last dim of weight needs to be divided by 4, got shape {layer.weight.data.shape}"
-            layer.weight.data = layer.weight.data.view(torch.int32).contiguous()
+            layer.weight.data = layer.weight.data.view(
+                torch.int32).contiguous()
         else:
             layer.weight.data = torch_npu.npu_convert_weight_to_int4pack(
                 layer.weight.data.to(torch.int32))
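
The rewrapped view call above is where the new quant format packs weights: Tensor.view(torch.int32) reinterprets every four int8 bytes as one int32 element, which is why the preceding assert requires the last dim to be divisible by 4. A small CPU-only sketch of that reinterpretation (the real path runs on NPU tensors):

    import torch

    w = torch.zeros((16, 8), dtype=torch.int8)  # last dim divisible by 4
    packed = w.view(torch.int32).contiguous()   # 4 int8 bytes -> 1 int32
    assert packed.shape == (16, 2)              # 8 int8 cols -> 2 int32 cols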
