Commit b10e23a

feat(quant): support w4a8 dynamic quantization v1.0.0 for linear layers
Signed-off-by: Anionex <[email protected]>
1 parent 474fa73 commit b10e23a

File tree

3 files changed: +226 -45 lines changed

  vllm_ascend/quantization/quant_config.py
  vllm_ascend/quantization/w4a8_dynamic.py
  vllm_ascend/torchair/quantization/torchair_w4a8_dynamic.py
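The new behaviour is gated on the quantization description that ships with the checkpoint. A rough sketch of the two fields the diffs below read is shown here; the field names (`group_size`, `version`) come from the commit, while the surrounding structure and concrete values are illustrative assumptions, not part of it:

# Illustrative only: the two fields this commit reads from
# quant_config.quant_description; values here are assumptions.
quant_description = {
    "group_size": 256,   # per-group quantization granularity (default used in the diff)
    "version": "1.0.0",  # "1.0.0" selects the new packed-int4 path
}
new_quant_version = quant_description.get("version", "0") == "1.0.0"

Any other `version` value keeps the pre-existing uncompressed int8 path.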

vllm_ascend/quantization/quant_config.py

Lines changed: 23 additions & 1 deletion
@@ -254,9 +254,22 @@ def create_weights(
         weight_dict = self.quant_method.get_weight(input_size_per_partition,
                                                    output_size_per_partition,
                                                    params_dtype)
+
+        # Extract packing information (if present)
+        packed_dim = weight_dict.pop("_packed_dim", None)
+        packed_factor = weight_dict.pop("_packed_factor", None)
+
         for weight_name, weight_param in weight_dict.items():
             param = torch.nn.Parameter(weight_param, requires_grad=False)
             set_weight_attrs(param, {"input_dim": 1, "output_dim": 0})
+
+            # Set packing attributes if the weight is packed
+            if packed_dim is not None and packed_factor is not None:
+                set_weight_attrs(param, {
+                    "packed_dim": packed_dim,
+                    "packed_factor": packed_factor
+                })
+
             layer.register_parameter(weight_name, param)
             set_weight_attrs(param, extra_weight_attrs)
 
@@ -275,8 +288,17 @@ def create_weights(
             layer.register_parameter(perchannel_name, param)
             set_weight_attrs(param, extra_weight_attrs)
 
+        layer_type = None
+        if isinstance(layer, RowParallelLinear):
+            # down_proj, o_proj
+            layer_type = "row"
+        else:
+            # gate_up_proj, qkv_proj (ColumnParallel or MergedColumnParallel)
+            layer_type = "column"
+
         pergroup_dict = self.quant_method.get_pergroup_param(
-            input_size_per_partition, output_size_per_partition, params_dtype)
+            input_size_per_partition, output_size_per_partition, params_dtype,
+            layer_type=layer_type)
         for pergroup_name, pergroup_param in pergroup_dict.items():
             param = torch.nn.Parameter(pergroup_param, requires_grad=False)
             set_weight_attrs(param, {"output_dim": 0})
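The reason `_packed_dim` / `_packed_factor` are popped from the weight dict and re-attached to the parameter is that vLLM's weight loaders rescale shard offsets and sizes along a packed dimension when a parameter carries these attributes. A simplified sketch of that idea follows; the helper name is made up and this is not vLLM's actual loader code:

# Hypothetical helper illustrating how a weight loader can consume the
# "packed_dim"/"packed_factor" attributes set above (simplified sketch).
def adjust_shard_for_packing(param, shard_offset: int, shard_size: int, shard_dim: int):
    packed_dim = getattr(param, "packed_dim", None)
    packed_factor = getattr(param, "packed_factor", 1)
    if packed_dim == shard_dim:
        # e.g. two int4 values share one int8, so the stored output dim is halved
        shard_offset //= packed_factor
        shard_size //= packed_factor
    return shard_offset, shard_size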

vllm_ascend/quantization/w4a8_dynamic.py

Lines changed: 129 additions & 22 deletions
@@ -36,18 +36,42 @@ class AscendW4A8DynamicLinearMethod:
 
     def __init__(self):
         self.transpose_weight = True
-        try:
-            self.group_size = get_current_vllm_config(
-            ).quant_config.quant_description.get("group_size", 256)
-        except AttributeError:
-            self.group_size = 256
+
+        vllm_config = get_current_vllm_config()
+        self.group_size = vllm_config.quant_config.quant_description.get(
+            "group_size", 256)
+        quant_version = vllm_config.quant_config.quant_description.get(
+            "version", "0")
+        self.new_quant_version = quant_version == "1.0.0"
+
+        from vllm.distributed import get_tensor_model_parallel_world_size
+        self.tp_size = get_tensor_model_parallel_world_size()
 
-    @staticmethod
-    def get_weight(input_size: int, output_size: int,
+    def get_weight(self, input_size: int, output_size: int,
                    params_dtype: torch.dtype) -> Dict[str, Any]:
-        params_dict = {
-            "weight": torch.empty(output_size, input_size, dtype=torch.int8)
-        }
+        """Create weight parameters.
+
+        For the new quantization version (two int4 values packed into one int8),
+        the output dimension is compressed by a factor of 2 (e.g., [2048, 3072] ->
+        [1024, 3072]). The returned dict includes "_packed_dim" and
+        "_packed_factor" for vLLM's weight loader.
+        """
+        params_dict = {}
+
+        if self.new_quant_version:
+            # Two int4 values packed into one int8: the output dimension is compressed.
+            pack_factor = 2
+            actual_output_size = output_size // pack_factor
+            params_dict["weight"] = torch.empty(actual_output_size,
                                                input_size,
                                                dtype=torch.int8)
+            # Add packing information for vLLM's weight_loader
+            params_dict["_packed_dim"] = 0
+            params_dict["_packed_factor"] = pack_factor
+        else:
+            params_dict["weight"] = torch.empty(output_size,
                                                input_size,
                                                dtype=torch.int8)
+
         return params_dict
 
     @staticmethod
@@ -60,7 +84,17 @@ def get_perchannel_param(output_size: int,
         return {}
 
     def get_pergroup_param(self, input_size: int, output_size: int,
-                           params_dtype: torch.dtype) -> Dict[str, Any]:
+                           params_dtype: torch.dtype,
+                           layer_type: Optional[str] = None) -> Dict[str, Any]:
+        """
+        Create per-group quantization parameters.
+
+        Args:
+            input_size: input dimension size
+            output_size: output dimension size
+            params_dtype: parameter data type
+            layer_type: layer type hint, either "row" (down_proj/o_proj) or "column" (gate_up_proj/qkv_proj)
+        """
         params_dict = {}
         params_dict["weight_scale"] = torch.empty(output_size,
                                                   1,
@@ -76,17 +110,60 @@ def get_pergroup_param(self, input_size: int, output_size: int,
                                                          input_size //
                                                          self.group_size,
                                                          dtype=params_dtype)
+
+        # The new quantization version includes scale_bias parameters.
+        # Shape depends on layer type:
+        #   - ColumnParallel (gate_up_proj, qkv_proj): [output_size, 1]
+        #   - RowParallel (down_proj, o_proj): [output_size, 16 // tp_size]
+        if self.new_quant_version:
+            if layer_type == "row":
+                # RowParallel: down_proj, o_proj
+                # scale_bias shape: [output_size, 16 // tp_size]
+                scale_bias_dim = 16 // self.tp_size
+            else:
+                # ColumnParallel (default): gate_up_proj, qkv_proj
+                # scale_bias shape: [output_size, 1]
+                scale_bias_dim = 1
+
+            params_dict["scale_bias"] = torch.empty(output_size,
                                                    scale_bias_dim,
                                                    dtype=torch.float32)
         return params_dict
 
     @staticmethod
-    def process_scale_second(weight: torch.Tensor, scale: torch.Tensor,
-                             per_group_scale: torch.Tensor):
+    def process_scale_second(weight: torch.Tensor,
+                             scale: torch.Tensor,
+                             per_group_scale: torch.Tensor,
+                             is_new_quant: bool = False):
+        """
+        Process the scale for second-level quantization.
+
+        Args:
+            weight: weight tensor [k, n] (in the new version, n is already compressed to n/2)
+            scale: first-level quantization scale [output_size]
+            per_group_scale: second-level per-group quantization scale [group_num, n_scale]
+            is_new_quant: whether this is the new quantization version (weight already compressed)
+
+        Returns:
+            (antiquant_scale, bias): dequantization scale and bias (bias is None for the new version)
+        """
         k, n = weight.shape
-        group_num, n = per_group_scale.shape
-        weight_high = weight.to(torch.float32).reshape(
-            group_num, -1, n) * per_group_scale.reshape(group_num, 1, n)
-        weight_high = weight_high.reshape(k, n)
-        bias = 8 * (weight_high.to(torch.float32) * scale).sum(dim=0)
+        group_num, n_scale = per_group_scale.shape
+
+        # For the new quantization version, the second dimension of the weight is
+        # already compressed (two int4 values per int8), so restore the logical
+        # dimension to compute the scale correctly.
+        if is_new_quant:
+            n = n * 2
+
+        bias = None
+        if not is_new_quant:
+            weight_high = weight.to(torch.float32).reshape(
+                group_num, -1, n) * per_group_scale.reshape(group_num, 1, n)
+            weight_high = weight_high.reshape(k, n)
+            bias = 8 * (weight_high.to(torch.float32) * scale).sum(dim=0)
+        # New version: scale_bias is not used currently, because symmetric
+        # activation quantization is adopted in msIT for w4a8.
+
         antiquant_scale = (scale * per_group_scale).reshape(group_num, n)
         return antiquant_scale.npu(), bias
 
@@ -114,11 +191,41 @@ def process_weights_after_loading(self, layer: torch.nn.Module):
             layer.weight.data,
             layer.weight_scale.data,
             layer.weight_scale_second.data.transpose(0, 1).contiguous(),
+            is_new_quant=self.new_quant_version,
         )
-        param = torch.nn.Parameter(scale_bias, requires_grad=False)
-        layer.register_parameter("weight_scale_bias", param)
-        layer.weight.data = torch_npu.npu_convert_weight_to_int4pack(
-            layer.weight.data.to(torch.int32))
+
+        # Handle scale_bias based on the quantization version.
+        if self.new_quant_version:
+            # New version: scale_bias is loaded from the checkpoint;
+            # process the loaded data based on layer type.
+            if hasattr(layer, "scale_bias"):
+                # Detect the layer type from the shape.
+                if layer.scale_bias.data.shape[1] == 1:
+                    # ColumnParallel (gate_up_proj, qkv_proj): [output_size, 1] -> flatten
+                    layer.scale_bias.data = layer.scale_bias.data.flatten()
+                else:
+                    # RowParallel (down_proj, o_proj): [output_size, 16 // tp_size]
+                    # Keep the 2D shape but make it contiguous.
+                    layer.scale_bias.data = layer.scale_bias.data.contiguous()
+        else:
+            # Old version: scale_bias is computed, register it as a parameter.
+            if scale_bias is not None:
+                param = torch.nn.Parameter(scale_bias, requires_grad=False)
+                layer.register_parameter("weight_scale_bias", param)
+
+        # Convert to the NPU-specific int4pack format.
+        if self.new_quant_version:
+            # New version: weights on disk are already stored as two int4 values per
+            # int8. As in MoE's pack_to_int32 method, use view(torch.int32) instead of
+            # npu_convert_weight_to_int4pack: pack 4 int8 (i.e. 8 int4) into one int32.
+            assert layer.weight.data.shape[-1] % 4 == 0, \
+                f"the last dim of weight needs to be divisible by 4, got shape {layer.weight.data.shape}"
+            layer.weight.data = layer.weight.data.view(
+                torch.int32).contiguous()
+        else:
+            # Old version: weights are not compressed and must be packed via
+            # npu_convert_weight_to_int4pack.
+            layer.weight.data = torch_npu.npu_convert_weight_to_int4pack(
+                layer.weight.data.to(torch.int32))
 
 
 class AscendW4A8DynamicFusedMoEMethod:
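The key difference between the two branches in process_weights_after_loading is that the new checkpoints already store two int4 values per int8 byte, so the final repack is a plain byte reinterpretation rather than a kernel call. A minimal sketch of that view(torch.int32) step, with illustrative shapes:

import torch

# Minimal sketch of the repack used for the new format (shapes are illustrative):
# an int8 tensor whose last dim is divisible by 4 can be reinterpreted as int32,
# grouping 4 bytes (= 8 int4 values) per element without copying data.
w_int8 = torch.randint(-128, 128, (1024, 3072), dtype=torch.int8)
assert w_int8.shape[-1] % 4 == 0
w_int32 = w_int8.contiguous().view(torch.int32)
print(w_int32.shape)  # torch.Size([1024, 768])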

vllm_ascend/torchair/quantization/torchair_w4a8_dynamic.py

Lines changed: 74 additions & 22 deletions
@@ -39,18 +39,34 @@ class TorchairAscendW4A8DynamicLinearMethod:
 
     def __init__(self):
         self.transpose_weight = True
-        try:
-            self.group_size = get_current_vllm_config(
-            ).quant_config.quant_description.get("group_size", 256)
-        except AttributeError:
-            self.group_size = 256
+
+        vllm_config = get_current_vllm_config()
+        self.group_size = vllm_config.quant_config.quant_description.get(
+            "group_size", 256)
+        quant_version = vllm_config.quant_config.quant_description.get(
+            "version", "0")
+        self.new_quant_version = quant_version == "1.0.0"
+
+        from vllm.distributed import get_tensor_model_parallel_world_size
+        self.tp_size = get_tensor_model_parallel_world_size()
 
-    @staticmethod
-    def get_weight(input_size: int, output_size: int,
+    def get_weight(self, input_size: int, output_size: int,
                    params_dtype: torch.dtype) -> Dict[str, Any]:
-        params_dict = {
-            "weight": torch.empty(output_size, input_size, dtype=torch.int8)
-        }
+        params_dict = {}
+
+        if self.new_quant_version:
+            pack_factor = 2
+            actual_output_size = output_size // pack_factor
+            params_dict["weight"] = torch.empty(actual_output_size,
                                                input_size,
                                                dtype=torch.int8)
+            params_dict["_packed_dim"] = 0
+            params_dict["_packed_factor"] = pack_factor
+        else:
+            params_dict["weight"] = torch.empty(output_size,
                                                input_size,
                                                dtype=torch.int8)
+
         return params_dict
 
     @staticmethod
@@ -63,7 +79,8 @@ def get_perchannel_param(output_size: int,
         return {}
 
     def get_pergroup_param(self, input_size: int, output_size: int,
-                           params_dtype: torch.dtype) -> Dict[str, Any]:
+                           params_dtype: torch.dtype,
+                           layer_type: Optional[str] = None) -> Dict[str, Any]:
         params_dict = {}
         params_dict["weight_scale"] = torch.empty(output_size,
                                                   1,
@@ -79,17 +96,36 @@ def get_pergroup_param(self, input_size: int, output_size: int,
                                                          input_size //
                                                          self.group_size,
                                                          dtype=params_dtype)
+
+        if self.new_quant_version:
+            if layer_type == "row":
+                scale_bias_dim = 16 // self.tp_size
+            else:
+                scale_bias_dim = 1
+
+            params_dict["scale_bias"] = torch.empty(output_size,
                                                    scale_bias_dim,
                                                    dtype=torch.float32)
         return params_dict
 
     @staticmethod
-    def process_scale_second(weight: torch.Tensor, scale: torch.Tensor,
-                             per_group_scale: torch.Tensor):
+    def process_scale_second(weight: torch.Tensor,
+                             scale: torch.Tensor,
+                             per_group_scale: torch.Tensor,
+                             is_new_quant: bool = False):
         k, n = weight.shape
-        group_num, n = per_group_scale.shape
-        weight_high = weight.to(torch.float32).reshape(
-            group_num, -1, n) * per_group_scale.reshape(group_num, 1, n)
-        weight_high = weight_high.reshape(k, n)
-        bias = 8 * (weight_high.to(torch.float32) * scale).sum(dim=0)
+        group_num, n_scale = per_group_scale.shape
+
+        if is_new_quant:
+            n = n * 2
+
+        bias = None
+        if not is_new_quant:
+            weight_high = weight.to(torch.float32).reshape(
+                group_num, -1, n) * per_group_scale.reshape(group_num, 1, n)
+            weight_high = weight_high.reshape(k, n)
+            bias = 8 * (weight_high.to(torch.float32) * scale).sum(dim=0)
+
         antiquant_scale = (scale * per_group_scale).reshape(group_num, n)
         return antiquant_scale.npu(), bias
 
@@ -117,11 +153,27 @@ def process_weights_after_loading(self, layer: torch.nn.Module):
             layer.weight.data,
             layer.weight_scale.data,
             layer.weight_scale_second.data.transpose(0, 1).contiguous(),
+            is_new_quant=self.new_quant_version,
         )
-        param = torch.nn.Parameter(scale_bias, requires_grad=False)
-        layer.register_parameter("weight_scale_bias", param)
-        layer.weight.data = torch_npu.npu_convert_weight_to_int4pack(
-            layer.weight.data.to(torch.int32))
+
+        if self.new_quant_version:
+            if hasattr(layer, "scale_bias"):
+                if layer.scale_bias.data.shape[1] == 1:
+                    layer.scale_bias.data = layer.scale_bias.data.flatten()
+                else:
+                    layer.scale_bias.data = layer.scale_bias.data.contiguous()
+        else:
+            if scale_bias is not None:
+                param = torch.nn.Parameter(scale_bias, requires_grad=False)
+                layer.register_parameter("weight_scale_bias", param)
+
+        if self.new_quant_version:
+            assert layer.weight.data.shape[-1] % 4 == 0, \
+                f"the last dim of weight needs to be divisible by 4, got shape {layer.weight.data.shape}"
+            layer.weight.data = layer.weight.data.view(torch.int32).contiguous()
+        else:
+            layer.weight.data = torch_npu.npu_convert_weight_to_int4pack(
+                layer.weight.data.to(torch.int32))
 
 
 class TorchairAscendW4A8DynamicFusedMoEMethod:
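For reference, the scale_bias shapes created in get_pergroup_param and then normalized in process_weights_after_loading work out as in the sketch below; only the [output_size, 1] vs. [output_size, 16 // tp_size] split comes from the diff, the concrete dimensions are assumptions:

import torch

# Illustrative shape trace (output_size=4096 and tp_size=2 are assumptions).
tp_size = 2
output_size = 4096

column_bias = torch.zeros(output_size, 1)            # gate_up_proj / qkv_proj
row_bias = torch.zeros(output_size, 16 // tp_size)   # down_proj / o_proj

# After loading, the column case is flattened and the row case kept 2-D,
# mirroring the shape[1] == 1 check in the diff.
column_bias = column_bias.flatten()   # -> torch.Size([4096])
row_bias = row_bias.contiguous()      # -> torch.Size([4096, 8])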
