
Commit 55a7718

feat(quant): Implement new version w4a8 dynamic quantization logic for AscendW4A8DynamicLinearMethod
Signed-off-by: Anionex <[email protected]>
1 parent 474fa73 commit 55a7718

File tree: 2 files changed (+101 −18 lines)

vllm_ascend/quantization/quant_config.py

Lines changed: 13 additions & 0 deletions
@@ -254,9 +254,22 @@ def create_weights(
         weight_dict = self.quant_method.get_weight(input_size_per_partition,
                                                    output_size_per_partition,
                                                    params_dtype)
+
+        # Extract packing information (if present)
+        packed_dim = weight_dict.pop("_packed_dim", None)
+        packed_factor = weight_dict.pop("_packed_factor", None)
+
         for weight_name, weight_param in weight_dict.items():
             param = torch.nn.Parameter(weight_param, requires_grad=False)
             set_weight_attrs(param, {"input_dim": 1, "output_dim": 0})
+
+            # Set packing attributes if the weight is packed
+            if packed_dim is not None and packed_factor is not None:
+                set_weight_attrs(param, {
+                    "packed_dim": packed_dim,
+                    "packed_factor": packed_factor
+                })
+
             layer.register_parameter(weight_name, param)
             set_weight_attrs(param, extra_weight_attrs)
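The "_packed_dim" / "_packed_factor" attributes tell vLLM's weight loader that the on-disk tensor is narrower than the logical weight along the packed dimension, so shard offsets and sizes have to shrink by the pack factor. A minimal, hypothetical sketch of that adjustment (the function name and signature below are ours for illustration, not vLLM's actual loader code):

import torch

def shard_packed_weight(param: torch.nn.Parameter,
                        loaded_weight: torch.Tensor,
                        output_dim: int,
                        shard_offset: int,
                        shard_size: int) -> torch.Tensor:
    # If the parameter is packed along the sharded dimension, two logical int4
    # rows occupy one physical int8 row, so offsets and sizes are divided.
    packed_dim = getattr(param, "packed_dim", None)
    packed_factor = getattr(param, "packed_factor", 1)
    if packed_dim == output_dim:
        shard_offset //= packed_factor
        shard_size //= packed_factor
    return loaded_weight.narrow(output_dim, shard_offset, shard_size)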

vllm_ascend/quantization/w4a8_dynamic.py

Lines changed: 88 additions & 18 deletions
@@ -37,17 +37,41 @@ class AscendW4A8DynamicLinearMethod:
     def __init__(self):
         self.transpose_weight = True
         try:
-            self.group_size = get_current_vllm_config(
-            ).quant_config.quant_description.get("group_size", 256)
+            vllm_config = get_current_vllm_config()
+            self.group_size = vllm_config.quant_config.quant_description.get(
+                "group_size", 256)
+            quant_version = vllm_config.quant_config.quant_description.get(
+                "version", "0")
+            self.new_quant_version = quant_version == "1.0.0"
         except AttributeError:
             self.group_size = 256
+            self.new_quant_version = False

-    @staticmethod
-    def get_weight(input_size: int, output_size: int,
+    def get_weight(self, input_size: int, output_size: int,
                    params_dtype: torch.dtype) -> Dict[str, Any]:
-        params_dict = {
-            "weight": torch.empty(output_size, input_size, dtype=torch.int8)
-        }
+        """Create weight parameters.
+
+        For the new quantization version (two int4 values packed into one int8), the
+        output dimension is compressed by a factor of 2 (e.g. [2048, 3072] -> [1024, 3072]).
+        The returned dict includes "_packed_dim" and "_packed_factor" for vLLM's weight loader.
+        """
+        params_dict = {}
+
+        if self.new_quant_version:
+            # Two int4 values are packed into one int8, so the output dimension is halved.
+            pack_factor = 2
+            actual_output_size = output_size // pack_factor
+            params_dict["weight"] = torch.empty(actual_output_size,
+                                                input_size,
+                                                dtype=torch.int8)
+            # Add packing information for vLLM's weight_loader
+            params_dict["_packed_dim"] = 0
+            params_dict["_packed_factor"] = pack_factor
+        else:
+            params_dict["weight"] = torch.empty(output_size,
+                                                input_size,
+                                                dtype=torch.int8)
+
         return params_dict

     @staticmethod
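For reference, "two int4 values packed into one int8" means adjacent rows along the output dimension share a byte: one value in the low nibble, one in the high nibble. A minimal CPU sketch of that packing (the nibble order and the helper below are assumptions for illustration, not the exporter's actual code):

import torch

def pack_int4_pairs(w_int4: torch.Tensor) -> torch.Tensor:
    # w_int4: [output_size, input_size] with values in [-8, 7]
    assert w_int4.shape[0] % 2 == 0
    nibbles = w_int4.to(torch.int16) & 0xF          # keep the low 4 bits of each value
    lo, hi = nibbles[0::2], nibbles[1::2]           # pair up adjacent output rows
    packed = (hi << 4) | lo                         # one byte now carries two int4 values
    return packed.to(torch.uint8).view(torch.int8)  # [output_size // 2, input_size]

packed = pack_int4_pairs(torch.randint(-8, 8, (2048, 3072), dtype=torch.int8))
print(packed.shape)  # torch.Size([1024, 3072])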
@@ -76,17 +100,48 @@ def get_pergroup_param(self, input_size: int, output_size: int,
                                                          input_size //
                                                          self.group_size,
                                                          dtype=params_dtype)
+
+        # The new quantization version additionally carries scale_bias parameters.
+        if self.new_quant_version:
+            params_dict["scale_bias"] = torch.empty(output_size,
+                                                    1,
+                                                    dtype=torch.float32)
         return params_dict

     @staticmethod
-    def process_scale_second(weight: torch.Tensor, scale: torch.Tensor,
-                             per_group_scale: torch.Tensor):
+    def process_scale_second(weight: torch.Tensor,
+                             scale: torch.Tensor,
+                             per_group_scale: torch.Tensor,
+                             is_new_quant: bool = False):
+        """
+        Process the scale for second-level quantization.
+
+        Args:
+            weight: weight tensor [k, n] (in the new version, n is already compressed to n/2)
+            scale: first-level quantization scale [output_size]
+            per_group_scale: second-level per-group quantization scale [group_num, n_scale]
+            is_new_quant: whether this is the new quantization version (weight already compressed)
+
+        Returns:
+            (antiquant_scale, bias): dequantization scale and bias (bias is None for the new version)
+        """
         k, n = weight.shape
-        group_num, n = per_group_scale.shape
-        weight_high = weight.to(torch.float32).reshape(
-            group_num, -1, n) * per_group_scale.reshape(group_num, 1, n)
-        weight_high = weight_high.reshape(k, n)
-        bias = 8 * (weight_high.to(torch.float32) * scale).sum(dim=0)
+        group_num, n_scale = per_group_scale.shape
+
+        # In the new quantization version the second dimension of the weight is already
+        # compressed (two int4 per int8), so restore the logical width before computing the scale.
+        if is_new_quant:
+            n = n * 2
+
+        bias = None
+        if not is_new_quant:
+            weight_high = weight.to(torch.float32).reshape(
+                group_num, -1, n) * per_group_scale.reshape(group_num, 1, n)
+            weight_high = weight_high.reshape(k, n)
+            bias = 8 * (weight_high.to(torch.float32) * scale).sum(dim=0)
+        # New version: scale_bias is not used for now, because msIT applies symmetric
+        # activation quantization for w4a8.
         antiquant_scale = (scale * per_group_scale).reshape(group_num, n)
         return antiquant_scale.npu(), bias
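As a sanity check on what antiquant_scale represents: the two-level dequantization multiplies each int4 value by its per-group (second-level) scale and then by the per-channel (first-level) scale, so the two factors can be fused into one. A simplified CPU sketch of that identity (shapes reduced to one output channel; this is an assumed reading of the semantics, not the NPU kernel):

import torch

k, group_size = 8, 4
w_int4 = torch.randint(-8, 8, (k,), dtype=torch.int8).float()
per_group_scale = torch.tensor([0.5, 0.25])   # second-level scale, one per group of 4
scale = torch.tensor(0.01)                    # first-level (per-channel) scale

# Dequantize step by step: int4 value * per-group scale * per-channel scale
w_fp = w_int4.reshape(2, group_size) * per_group_scale[:, None] * scale

# Equivalent: apply the fused antiquant scale in one multiplication
antiquant_scale = scale * per_group_scale
w_fp_fused = w_int4.reshape(2, group_size) * antiquant_scale[:, None]
assert torch.allclose(w_fp, w_fp_fused)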

@@ -114,11 +169,26 @@ def process_weights_after_loading(self, layer: torch.nn.Module):
             layer.weight.data,
             layer.weight_scale.data,
             layer.weight_scale_second.data.transpose(0, 1).contiguous(),
+            is_new_quant=self.new_quant_version,
         )
-        param = torch.nn.Parameter(scale_bias, requires_grad=False)
-        layer.register_parameter("weight_scale_bias", param)
-        layer.weight.data = torch_npu.npu_convert_weight_to_int4pack(
-            layer.weight.data.to(torch.int32))
+
+        if not self.new_quant_version and scale_bias is not None:
+            param = torch.nn.Parameter(scale_bias, requires_grad=False)
+            layer.register_parameter("weight_scale_bias", param)
+
+        # Convert to the NPU-specific int4pack format
+        if self.new_quant_version:
+            # New version: weights on disk already hold two int4 values per int8.
+            # Following MoE's pack_to_int32 approach, use view(torch.int32) instead of
+            # npu_convert_weight_to_int4pack to pack 4 int8 (i.e. 8 int4) into one int32.
+            assert layer.weight.data.shape[-1] % 4 == 0, \
+                f"the last dim of weight must be divisible by 4, got shape {layer.weight.data.shape}"
+            layer.weight.data = layer.weight.data.view(
+                torch.int32).contiguous()
+        else:
+            # Old version: weights are uncompressed and are packed via npu_convert_weight_to_int4pack.
+            layer.weight.data = torch_npu.npu_convert_weight_to_int4pack(
+                layer.weight.data.to(torch.int32))


 class AscendW4A8DynamicFusedMoEMethod:
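The view(torch.int32) repack only reinterprets memory: four consecutive int8 bytes (eight int4 values) become one int32, which is why the last dimension must be divisible by 4. A minimal CPU sketch of the shape effect (the concrete shapes are illustrative):

import torch

w = torch.randint(-128, 128, (1024, 3072), dtype=torch.int8)  # already nibble-packed int8
assert w.shape[-1] % 4 == 0
w_int32 = w.view(torch.int32)     # bit-for-bit reinterpretation, no data movement
print(w_int32.shape)              # torch.Size([1024, 768]) -- last dim shrinks by 4x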
