@@ -16,10 +16,9 @@ def setUp(self, mock_get_current_vllm_config, mock_get_tp_world_size):
         mock_vllm_config = Mock()
         mock_vllm_config.quant_config = Mock(
             quant_description={"group_size": 256})
-        mock_vllm_config.scheduler_config = Mock(
-            max_num_batched_tokens=2048,
-            max_model_len=2048,
-            enable_chunked_prefill=False)
+        mock_vllm_config.scheduler_config = Mock(max_num_batched_tokens=2048,
+                                                 max_model_len=2048,
+                                                 enable_chunked_prefill=False)
         mock_get_current_vllm_config.return_value = mock_vllm_config
         self.method = AscendW4A8DynamicLinearMethod()
         self.method.group_size = 8
@@ -48,11 +47,15 @@ def test_get_pergroup_param(self):
         self.assertEqual(params["weight_offset_second"].shape, (32, 1))
         # new quant version weight
         self.method.new_quant_version = True
-        params = self.method.get_pergroup_param(8, 32, torch.bfloat16,
+        params = self.method.get_pergroup_param(8,
+                                                32,
+                                                torch.bfloat16,
                                                 layer_type="column")
         self.assertEqual(params["scale_bias"].dtype, torch.float32)
         self.assertEqual(params["scale_bias"].shape, (32, 1))
-        params = self.method.get_pergroup_param(8, 32, torch.bfloat16,
+        params = self.method.get_pergroup_param(8,
+                                                32,
+                                                torch.bfloat16,
                                                 layer_type="row")
         self.assertEqual(params["scale_bias"].dtype, torch.float32)
         self.assertEqual(params["scale_bias"].shape, (32, 16))
@@ -61,23 +64,27 @@ def test_get_pergroup_param(self):
     @patch('torch.Tensor.npu')
     def test_process_weights_after_loading(self, mock_npu,
                                            mock_npu_convert_weight):
-        mock_npu.side_effect = lambda: torch.zeros((1, 32), dtype=torch.float32)
+        mock_npu.side_effect = lambda: torch.zeros(
+            (1, 32), dtype=torch.float32)
         mock_npu_convert_weight.return_value = torch.zeros((32, 4),
-                                                            dtype=torch.int32)
+                                                           dtype=torch.int32)
         # old quant version weight
         layer = torch.nn.Module()
-        layer.weight = torch.nn.Parameter(torch.zeros((32, 8), dtype=torch.int8),
+        layer.weight = torch.nn.Parameter(torch.zeros((32, 8),
+                                                      dtype=torch.int8),
                                           requires_grad=False)
-        layer.weight_scale = torch.nn.Parameter(torch.ones((32, 1),
-                                                           dtype=torch.float32),
+        layer.weight_scale = torch.nn.Parameter(torch.ones(
+            (32, 1), dtype=torch.float32),
                                                 requires_grad=False)
-        layer.weight_offset = torch.nn.Parameter(
-            torch.empty_like(layer.weight_scale.data), requires_grad=False)
-        layer.weight_scale_second = torch.nn.Parameter(
-            torch.ones((32, 1), dtype=torch.float32), requires_grad=False)
-        layer.weight_offset_second = torch.nn.Parameter(
-            torch.empty_like(layer.weight_scale_second.data),
-            requires_grad=False)
+        layer.weight_offset = torch.nn.Parameter(torch.empty_like(
+            layer.weight_scale.data),
+                                                 requires_grad=False)
+        layer.weight_scale_second = torch.nn.Parameter(torch.ones(
+            (32, 1), dtype=torch.float32),
+                                                       requires_grad=False)
+        layer.weight_offset_second = torch.nn.Parameter(torch.empty_like(
+            layer.weight_scale_second.data),
+                                                        requires_grad=False)
         self.method.process_weights_after_loading(layer)
         self.assertTrue(hasattr(layer, "weight_scale_bias"))
         self.assertEqual(layer.weight_scale_bias.data.shape, (32, ))
@@ -86,19 +93,22 @@ def test_process_weights_after_loading(self, mock_npu,
         self.method.new_quant_version = True
         new_layer = torch.nn.Module()
         new_layer.weight = torch.nn.Parameter(torch.zeros((16, 8),
-                                                           dtype=torch.int8),
+                                                          dtype=torch.int8),
                                               requires_grad=False)
-        new_layer.weight_scale = torch.nn.Parameter(
-            torch.ones((32, 1), dtype=torch.float32), requires_grad=False)
-        new_layer.weight_offset = torch.nn.Parameter(
-            torch.empty_like(new_layer.weight_scale.data), requires_grad=False)
-        new_layer.weight_scale_second = torch.nn.Parameter(
-            torch.ones((32, 1), dtype=torch.float32), requires_grad=False)
+        new_layer.weight_scale = torch.nn.Parameter(torch.ones(
+            (32, 1), dtype=torch.float32),
+                                                    requires_grad=False)
+        new_layer.weight_offset = torch.nn.Parameter(torch.empty_like(
+            new_layer.weight_scale.data),
+                                                     requires_grad=False)
+        new_layer.weight_scale_second = torch.nn.Parameter(torch.ones(
+            (32, 1), dtype=torch.float32),
+                                                           requires_grad=False)
         new_layer.weight_offset_second = torch.nn.Parameter(
             torch.empty_like(new_layer.weight_scale_second.data),
             requires_grad=False)
-        new_layer.scale_bias = torch.nn.Parameter(torch.zeros((32, 1),
-                                                              dtype=torch.float32),
+        new_layer.scale_bias = torch.nn.Parameter(torch.zeros(
+            (32, 1), dtype=torch.float32),
                                                   requires_grad=False)
         self.method.process_weights_after_loading(new_layer)
         self.assertEqual(new_layer.scale_bias.data.shape, (32, ))
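For orientation: with the setUp mocks above in effect, the per-group-parameter assertions in this diff reduce to the call pattern below. This is a sketch of what the test exercises, not independent usage guidance, and the remark about the (32, 16) row shape is this author's reading of the mocked TP world size, not something the diff itself states.

# Sketch only: assumes it runs inside the test class above, where setUp has
# mocked the current vllm config and the tensor-parallel world size.
method = AscendW4A8DynamicLinearMethod()
method.group_size = 8
method.new_quant_version = True

# Column-parallel layer: one scale_bias entry per output channel -> (32, 1).
params = method.get_pergroup_param(8, 32, torch.bfloat16, layer_type="column")
assert params["scale_bias"].shape == (32, 1)

# Row-parallel layer: the test expects (32, 16); the 16 presumably comes from
# the mocked TP world size (an assumption, not stated in the diff).
params = method.get_pergroup_param(8, 32, torch.bfloat16, layer_type="row")
assert params["scale_bias"].shape == (32, 16)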