class TestAscendW4A8DynamicLinearMethod(TestBase):

-    def setUp(self):
-        with patch(
-                'vllm_ascend.quantization.w4a8_dynamic.get_current_vllm_config'
-        ) as mock_get_current_vllm_config:
-            mock_vllm_config = Mock()
-            mock_vllm_config.quant_config = Mock(
-                quant_description={"group_size": 256})
-            mock_vllm_config.scheduler_config = Mock(
-                max_num_batched_tokens=2048,
-                max_model_len=2048,
-                enable_chunked_prefill=False)
-            mock_get_current_vllm_config.return_value = mock_vllm_config
-            self.method = AscendW4A8DynamicLinearMethod()
-            self.method.group_size = 8
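+    # The TP world size and the global vLLM config are patched so the method
+    # can be constructed in setUp without a live distributed environment.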
+    @patch('vllm.distributed.get_tensor_model_parallel_world_size')
+    @patch('vllm_ascend.quantization.w4a8_dynamic.get_current_vllm_config')
+    def setUp(self, mock_get_current_vllm_config, mock_get_tp_world_size):
+        mock_get_tp_world_size.return_value = 1
+        mock_vllm_config = Mock()
+        mock_vllm_config.quant_config = Mock(
+            quant_description={"group_size": 256})
+        mock_vllm_config.scheduler_config = Mock(
+            max_num_batched_tokens=2048,
+            max_model_len=2048,
+            enable_chunked_prefill=False)
+        mock_get_current_vllm_config.return_value = mock_vllm_config
+        self.method = AscendW4A8DynamicLinearMethod()
+        self.method.group_size = 8

    def test_get_weight(self):
        weight = self.method.get_weight(8, 32, torch.bfloat16)
        self.assertEqual(weight["weight"].dtype, torch.int8)
        self.assertEqual(weight["weight"].shape, (32, 8))
+        # new quant version weight
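+        # (two int4 values are packed into one int8, so dim 0 halves: 32 -> 16)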
+        self.method.new_quant_version = True
+        weight = self.method.get_weight(8, 32, torch.bfloat16)
+        self.assertEqual(weight["weight"].dtype, torch.int8)
+        self.assertEqual(weight["weight"].shape, (16, 8))
+        self.assertEqual(weight["_packed_dim"], 0)
+        self.assertEqual(weight["_packed_factor"], 2)

    def test_get_pergroup_param(self):
        params = self.method.get_pergroup_param(8, 32, torch.bfloat16)
@@ -39,6 +46,64 @@ def test_get_pergroup_param(self):
        self.assertEqual(params["weight_scale_second"].shape, (32, 1))
        self.assertEqual(params["weight_offset_second"].dtype, torch.bfloat16)
        self.assertEqual(params["weight_offset_second"].shape, (32, 1))
+        # new quant version weight
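+        # (scale_bias shape differs between column- and row-parallel layers)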
+        self.method.new_quant_version = True
+        params = self.method.get_pergroup_param(8, 32, torch.bfloat16,
+                                                layer_type="column")
+        self.assertEqual(params["scale_bias"].dtype, torch.float32)
+        self.assertEqual(params["scale_bias"].shape, (32, 1))
+        params = self.method.get_pergroup_param(8, 32, torch.bfloat16,
+                                                layer_type="row")
+        self.assertEqual(params["scale_bias"].dtype, torch.float32)
+        self.assertEqual(params["scale_bias"].shape, (32, 16))
+
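+    # torch_npu.npu_convert_weight_to_int4pack and Tensor.npu are mocked so
+    # the test runs on CPU without Ascend hardware.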
+    @patch('torch_npu.npu_convert_weight_to_int4pack')
+    @patch('torch.Tensor.npu')
+    def test_process_weights_after_loading(self, mock_npu,
+                                           mock_npu_convert_weight):
+        mock_npu.side_effect = lambda: torch.zeros((1, 32), dtype=torch.float32)
+        mock_npu_convert_weight.return_value = torch.zeros((32, 4),
+                                                           dtype=torch.int32)
+        # old quant version weight
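+        # (old format: unpacked (32, 8) int8 weight and no scale_bias parameter)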
+        layer = torch.nn.Module()
+        layer.weight = torch.nn.Parameter(torch.zeros((32, 8), dtype=torch.int8),
+                                          requires_grad=False)
+        layer.weight_scale = torch.nn.Parameter(torch.ones((32, 1),
+                                                           dtype=torch.float32),
+                                                requires_grad=False)
+        layer.weight_offset = torch.nn.Parameter(
+            torch.empty_like(layer.weight_scale.data), requires_grad=False)
+        layer.weight_scale_second = torch.nn.Parameter(
+            torch.ones((32, 1), dtype=torch.float32), requires_grad=False)
+        layer.weight_offset_second = torch.nn.Parameter(
+            torch.empty_like(layer.weight_scale_second.data),
+            requires_grad=False)
+        self.method.process_weights_after_loading(layer)
+        self.assertTrue(hasattr(layer, "weight_scale_bias"))
+        self.assertEqual(layer.weight_scale_bias.data.shape, (32, ))
+        self.assertEqual(layer.weight_scale_bias.data.dtype, torch.float32)
+        # new quant version weight
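+        # (new format: pre-packed (16, 8) int8 weight plus an explicit
+        # scale_bias that process_weights_after_loading reshapes to 1-D)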
+        self.method.new_quant_version = True
+        new_layer = torch.nn.Module()
+        new_layer.weight = torch.nn.Parameter(torch.zeros((16, 8),
+                                                          dtype=torch.int8),
+                                              requires_grad=False)
+        new_layer.weight_scale = torch.nn.Parameter(
+            torch.ones((32, 1), dtype=torch.float32), requires_grad=False)
+        new_layer.weight_offset = torch.nn.Parameter(
+            torch.empty_like(new_layer.weight_scale.data), requires_grad=False)
+        new_layer.weight_scale_second = torch.nn.Parameter(
+            torch.ones((32, 1), dtype=torch.float32), requires_grad=False)
+        new_layer.weight_offset_second = torch.nn.Parameter(
+            torch.empty_like(new_layer.weight_scale_second.data),
+            requires_grad=False)
+        new_layer.scale_bias = torch.nn.Parameter(torch.zeros((32, 1),
+                                                              dtype=torch.float32),
+                                                  requires_grad=False)
+        self.method.process_weights_after_loading(new_layer)
+        self.assertEqual(new_layer.scale_bias.data.shape, (32, ))
+        self.assertTrue(hasattr(new_layer, "weight_scale_second"))
+        self.assertEqual(new_layer.weight_scale_second.data.shape, (1, 32))


class TestAscendW4A8DynamicFusedMoEMethod(TestBase):