@@ -36,18 +36,42 @@ class AscendW4A8DynamicLinearMethod:

     def __init__(self):
         self.transpose_weight = True
-        try:
-            self.group_size = get_current_vllm_config(
-            ).quant_config.quant_description.get("group_size", 256)
-        except AttributeError:
-            self.group_size = 256
+
+        vllm_config = get_current_vllm_config()
+        self.group_size = vllm_config.quant_config.quant_description.get(
+            "group_size", 256)
+        quant_version = vllm_config.quant_config.quant_description.get(
+            "version", "0")
+        self.new_quant_version = quant_version == "1.0.0"
+
+        from vllm.distributed import get_tensor_model_parallel_world_size
+        self.tp_size = get_tensor_model_parallel_world_size()

-    @staticmethod
-    def get_weight(input_size: int, output_size: int,
+    def get_weight(self, input_size: int, output_size: int,
                    params_dtype: torch.dtype) -> Dict[str, Any]:
-        params_dict = {
-            "weight": torch.empty(output_size, input_size, dtype=torch.int8)
-        }
+        """Create weight parameters.
+
+        For the new quantization version (two int4 values packed into one int8), the
+        output dimension is compressed by a factor of 2 (e.g., [2048, 3072] -> [1024, 3072]).
+        The returned dict includes "_packed_dim" and "_packed_factor" for vLLM's weight loader.
+        """
+        params_dict = {}
+
+        if self.new_quant_version:
+            # Two int4 values packed into one int8: the output dimension is compressed.
+            pack_factor = 2
+            actual_output_size = output_size // pack_factor
+            params_dict["weight"] = torch.empty(actual_output_size,
+                                                input_size,
+                                                dtype=torch.int8)
+            # Add packing information for vLLM's weight_loader.
+            params_dict["_packed_dim"] = 0
+            params_dict["_packed_factor"] = pack_factor
+        else:
+            params_dict["weight"] = torch.empty(output_size,
+                                                input_size,
+                                                dtype=torch.int8)
+
         return params_dict

     @staticmethod
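
The halved output dimension above corresponds to two signed int4 values sharing one int8 byte. A minimal packing sketch follows; the pairing of adjacent output rows and the nibble order are illustrative assumptions, not necessarily the checkpoint's exact on-disk layout.

import torch

def pack_int4_pairs(w_int4: torch.Tensor) -> torch.Tensor:
    # w_int4: [output_size, input_size], values in the signed int4 range [-8, 7].
    assert w_int4.shape[0] % 2 == 0, "output_size must be even to pack row pairs"
    w = w_int4.to(torch.int32) & 0xF      # keep the low 4 bits (two's-complement nibble)
    lo, hi = w[0::2], w[1::2]             # pair adjacent output rows (assumed layout)
    packed = (hi << 4) | lo               # one byte now carries two int4 values
    return packed.to(torch.uint8).view(torch.int8)   # [output_size // 2, input_size]

packed = pack_int4_pairs(torch.randint(-8, 8, (2048, 3072), dtype=torch.int8))
print(packed.shape)                       # torch.Size([1024, 3072])
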
@@ -60,7 +84,17 @@ def get_perchannel_param(output_size: int,
         return {}

     def get_pergroup_param(self, input_size: int, output_size: int,
-                           params_dtype: torch.dtype) -> Dict[str, Any]:
+                           params_dtype: torch.dtype,
+                           layer_type: Optional[str] = None) -> Dict[str, Any]:
+        """
+        Create per-group quantization parameters.
+
+        Args:
+            input_size: input dimension size
+            output_size: output dimension size
+            params_dtype: parameter data type
+            layer_type: layer type hint, can be "row" (down_proj/o_proj) or "column" (gate_up_proj/qkv_proj)
+        """
         params_dict = {}
         params_dict["weight_scale"] = torch.empty(output_size,
                                                   1,
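
To make the new layer_type hint concrete, here is a small shape sketch under assumed sizes (output_size=3072, input_size=2048, group_size=256) and an assumed tp_size of 4; the scale_bias shapes correspond to the branch added in the next hunk, and weight_scale_second is the per-group scale used later in this file.

output_size, input_size, group_size, tp_size = 3072, 2048, 256, 4

weight_scale_shape = (output_size, 1)                                # first-level, per channel
weight_scale_second_shape = (output_size, input_size // group_size)  # (3072, 8), per group
scale_bias_column_shape = (output_size, 1)                           # gate_up_proj / qkv_proj
scale_bias_row_shape = (output_size, 16 // tp_size)                  # down_proj / o_proj -> (3072, 4)
print(weight_scale_second_shape, scale_bias_row_shape)
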
@@ -76,17 +110,60 @@ def get_pergroup_param(self, input_size: int, output_size: int,
                                                           input_size //
                                                           self.group_size,
                                                           dtype=params_dtype)
+
+        # The new quantization version includes scale_bias parameters.
+        # Shape depends on the layer type:
+        #   - ColumnParallel (gate_up_proj, qkv_proj): [output_size, 1]
+        #   - RowParallel (down_proj, o_proj): [output_size, 16 // tp_size]
+        if self.new_quant_version:
+            if layer_type == "row":
+                # RowParallel: down_proj, o_proj
+                # scale_bias shape: [output_size, 16 // tp_size]
+                scale_bias_dim = 16 // self.tp_size
+            else:
+                # ColumnParallel (default): gate_up_proj, qkv_proj
+                # scale_bias shape: [output_size, 1]
+                scale_bias_dim = 1
+
+            params_dict["scale_bias"] = torch.empty(output_size,
+                                                    scale_bias_dim,
+                                                    dtype=torch.float32)
         return params_dict

     @staticmethod
-    def process_scale_second(weight: torch.Tensor, scale: torch.Tensor,
-                             per_group_scale: torch.Tensor):
+    def process_scale_second(weight: torch.Tensor,
+                             scale: torch.Tensor,
+                             per_group_scale: torch.Tensor,
+                             is_new_quant: bool = False):
+        """
+        Process the scale for second-level quantization.
+
+        Args:
+            weight: weight tensor [k, n] (in the new version, n is already compressed to n/2)
+            scale: first-level quantization scale [output_size]
+            per_group_scale: second-level per-group quantization scale [group_num, n_scale]
+            is_new_quant: whether this is the new quantization version (weight already compressed)
+
+        Returns:
+            (antiquant_scale, bias): dequantization scale and bias (bias is None for the new version)
+        """
         k, n = weight.shape
-        group_num, n = per_group_scale.shape
-        weight_high = weight.to(torch.float32).reshape(
-            group_num, -1, n) * per_group_scale.reshape(group_num, 1, n)
-        weight_high = weight_high.reshape(k, n)
-        bias = 8 * (weight_high.to(torch.float32) * scale).sum(dim=0)
+        group_num, n_scale = per_group_scale.shape
+
+        # In the new quantization version the weight's second dimension is already compressed
+        # (two int4 per int8), so restore the logical dimension before computing the scale.
+        if is_new_quant:
+            n = n * 2
+
+        bias = None
+        if not is_new_quant:
+            weight_high = weight.to(torch.float32).reshape(
+                group_num, -1, n) * per_group_scale.reshape(group_num, 1, n)
+            weight_high = weight_high.reshape(k, n)
+            bias = 8 * (weight_high.to(torch.float32) * scale).sum(dim=0)
+        # New version: scale_bias is currently unused because symmetric activation
+        # quantization is adopted in msIT for w4a8.
+
         antiquant_scale = (scale * per_group_scale).reshape(group_num, n)
         return antiquant_scale.npu(), bias

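
A CPU-only sketch of the old-path arithmetic in process_scale_second under assumed sizes (the real method returns antiquant_scale via .npu(), which needs an NPU device); it shows the shapes named in the docstring.

import torch

k, n, group_size = 2048, 3072, 256
group_num = k // group_size                                      # 8 groups

weight = torch.randint(-8, 8, (k, n), dtype=torch.int8)          # uncompressed int4-range weight
scale = torch.rand(n, dtype=torch.float32)                       # first-level scale
per_group_scale = torch.rand(group_num, n, dtype=torch.float32)  # second-level scale

weight_high = weight.to(torch.float32).reshape(
    group_num, -1, n) * per_group_scale.reshape(group_num, 1, n)
weight_high = weight_high.reshape(k, n)
bias = 8 * (weight_high * scale).sum(dim=0)                      # [n], old path only
antiquant_scale = (scale * per_group_scale).reshape(group_num, n)
print(bias.shape, antiquant_scale.shape)   # torch.Size([3072]) torch.Size([8, 3072])
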
@@ -114,11 +191,41 @@ def process_weights_after_loading(self, layer: torch.nn.Module):
             layer.weight.data,
             layer.weight_scale.data,
             layer.weight_scale_second.data.transpose(0, 1).contiguous(),
+            is_new_quant=self.new_quant_version,
         )
-        param = torch.nn.Parameter(scale_bias, requires_grad=False)
-        layer.register_parameter("weight_scale_bias", param)
-        layer.weight.data = torch_npu.npu_convert_weight_to_int4pack(
-            layer.weight.data.to(torch.int32))
+
+        # Handle scale_bias according to the quantization version.
+        if self.new_quant_version:
+            # New version: scale_bias is loaded from the checkpoint;
+            # process the loaded data according to the layer type.
+            if hasattr(layer, "scale_bias"):
+                # Detect the layer type from the shape.
+                if layer.scale_bias.data.shape[1] == 1:
+                    # ColumnParallel (gate_up_proj, qkv_proj): [output_size, 1] -> flatten
+                    layer.scale_bias.data = layer.scale_bias.data.flatten()
+                else:
+                    # RowParallel (down_proj, o_proj): [output_size, 16 // tp_size]
+                    # Keep the 2D shape but make it contiguous.
+                    layer.scale_bias.data = layer.scale_bias.data.contiguous()
+        else:
+            # Old version: scale_bias is computed above; register it as a parameter.
+            if scale_bias is not None:
+                param = torch.nn.Parameter(scale_bias, requires_grad=False)
+                layer.register_parameter("weight_scale_bias", param)
+
+        # Convert the weight to the NPU-specific int4pack format.
+        if self.new_quant_version:
+            # New version: weights on disk already store two int4 values per int8.
+            # As in the MoE pack_to_int32 path, use view(torch.int32) instead of
+            # npu_convert_weight_to_int4pack to pack four int8 (eight int4) into one int32.
+            assert layer.weight.data.shape[-1] % 4 == 0, \
+                f"The last dim of the weight must be divisible by 4, got shape {layer.weight.data.shape}"
+            layer.weight.data = layer.weight.data.view(
+                torch.int32).contiguous()
+        else:
+            # Old version: weights are uncompressed and are packed via npu_convert_weight_to_int4pack.
+            layer.weight.data = torch_npu.npu_convert_weight_to_int4pack(
+                layer.weight.data.to(torch.int32))


 class AscendW4A8DynamicFusedMoEMethod:
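
The new-version repack relies on torch.Tensor.view(dtype), which reinterprets the underlying bytes rather than copying them; a minimal demonstration with an assumed already-packed shape of [1024, 3072]:

import torch

packed_int8 = torch.randint(-128, 128, (1024, 3072), dtype=torch.int8)
assert packed_int8.shape[-1] % 4 == 0           # required to view int8 as int32
packed_int32 = packed_int8.view(torch.int32)    # same bytes, shape becomes [1024, 768]
print(packed_int32.shape)                       # torch.Size([1024, 768])

# The old path instead expands uncompressed weights to int32 and calls
# torch_npu.npu_convert_weight_to_int4pack, which requires an NPU device.
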