Commit 36d8fc2

docs(quant): simplify and clarify comments
1 parent fa57f46

File tree: 1 file changed, +9 -24 lines changed


vllm_ascend/quantization/w4a8_dynamic.py

Lines changed: 9 additions & 24 deletions
@@ -88,14 +88,6 @@ def get_pergroup_param(self, input_size: int, output_size: int,
                            layer_type: Optional[str] = None) -> Dict[str, Any]:
         """
         Create per-group quantization parameters.
-
-        Args:
-            input_size: input dimension size
-            output_size: output dimension size
-            params_dtype: parameter data type
-            layer_type: "row" or "others" (default)
-                - "row": RowParallelLinear (down_proj, o_proj)
-                - "others": Others (ColumnParallel, ReplicatedLinear, etc.)
         """
         params_dict = {}
         params_dict["weight_scale"] = torch.empty(output_size,
@@ -114,8 +106,8 @@ def get_pergroup_param(self, input_size: int, output_size: int,
                                                 dtype=params_dtype)
 
         # NOTE: In w4a8 quantization implementation,
-        # for down_proj and o_proj scale_bias shape is [output_size, 16],
-        # others are [output_size, 1]
+        # for down_proj and o_proj(layer_type == "row") scale_bias shape is [output_size, 16],
+        # others are [output_size, 1]
         if self.new_quant_version:
             scale_bias_dim = 16 if layer_type == "row" else 1
 
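A rough sketch of the allocation that comment describes, illustrative only and not the actual vllm_ascend code (the function name and the group_size parameter here are assumptions):

import torch
from typing import Any, Dict, Optional

def make_pergroup_params_sketch(input_size: int, output_size: int,
                                group_size: int,
                                params_dtype: torch.dtype,
                                layer_type: Optional[str] = None) -> Dict[str, Any]:
    # Hypothetical sketch: one per-group scale per output channel, plus a
    # scale_bias whose second dim is 16 for "row" layers (down_proj, o_proj)
    # and 1 for everything else, as the comment in the hunk above describes.
    params: Dict[str, Any] = {
        "weight_scale_second": torch.empty(output_size,
                                           input_size // group_size,
                                           dtype=params_dtype),
    }
    scale_bias_dim = 16 if layer_type == "row" else 1
    params["scale_bias"] = torch.empty(output_size, scale_bias_dim,
                                       dtype=params_dtype)
    return params
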
@@ -144,9 +136,8 @@ def process_scale_second(weight: torch.Tensor,
         k, n = weight.shape
         group_num, n_scale = per_group_scale.shape
 
-        # For new quantization version, the second dimension of weight is already compressed (double int4 pack into int8)
-        # Need to restore the logical dimension to correctly compute the scale
         if is_new_quant:
+            # Restore logical dimension for compressed weight
             n = n * 2
 
         bias = None
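
The n = n * 2 step exists because the new checkpoint format stores two int4 values per int8 byte, so the packed last dimension is half the logical one. A minimal stand-alone illustration of that packing in plain torch (not the repository's packing routine):

import torch

# Illustrative only: two int4 values (range [-8, 7]) are stored per int8 byte,
# so the packed tensor's last dim is n // 2 and the logical width is
# recovered with n = n * 2 as in the hunk above.
def pack_int4_pairs(w_int4: torch.Tensor) -> torch.Tensor:
    assert w_int4.shape[-1] % 2 == 0
    w = w_int4.to(torch.int32)
    lo = w[..., 0::2] & 0xF           # low nibble
    hi = (w[..., 1::2] & 0xF) << 4    # high nibble
    return (hi | lo).to(torch.uint8)  # shape [..., n // 2]

w = torch.randint(-8, 8, (4, 8), dtype=torch.int8)
packed = pack_int4_pairs(w)
print(tuple(w.shape), "->", tuple(packed.shape))  # (4, 8) -> (4, 4)
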
@@ -155,9 +146,10 @@ def process_scale_second(weight: torch.Tensor,
                 group_num, -1, n) * per_group_scale.reshape(group_num, 1, n)
             weight_high = weight_high.reshape(k, n)
             bias = 8 * (weight_high.to(torch.float32) * scale).sum(dim=0)
-        # New version: scale_bias is not used currently
-        # because symmetric activation quantization is adopted in msIT for w4a8
+        # NOTE: scale_bias is not used currently
+        # because in msmodelslim w4a8 uses symmetric quantization
 
+        # TODO: support potential future asymmetric quantization
         antiquant_scale = (scale * per_group_scale).reshape(group_num, n)
         return antiquant_scale.npu(), bias
 
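On the symmetric-quantization note above: a symmetric int4 scheme uses zero point 0, so dequantization is just q * scale and no scale_bias correction term is required; an asymmetric scheme (the TODO) would reintroduce one. A tiny sketch of that idea, not taken from msmodelslim or vllm_ascend:

import torch

# Illustrative only: with symmetric int4 quantization the zero point is 0,
# so dequantization is simply q * scale and no scale_bias correction is
# needed; an asymmetric scheme would need an extra offset term.
def quantize_int4_symmetric(w: torch.Tensor):
    scale = w.abs().amax(dim=-1, keepdim=True) / 7.0   # map max |w| to +/-7
    q = torch.clamp(torch.round(w / scale), -8, 7).to(torch.int8)
    return q, scale

w = torch.randn(4, 16)
q, scale = quantize_int4_symmetric(w)
w_hat = q.to(torch.float32) * scale                    # no zero point / bias term
print((w - w_hat).abs().max())
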
@@ -188,36 +180,29 @@ def process_weights_after_loading(self, layer: torch.nn.Module):
             is_new_quant=self.new_quant_version,
         )
 
-        # ✅ Handle scale_bias based on quantization version
         if self.new_quant_version:
-            # New version: scale_bias is loaded from checkpoint
             # Process the loaded data based on layer type
             if hasattr(layer, "scale_bias"):
-                # Detect layer type from shape
                 if layer.scale_bias.data.shape[1] == 1:
-                    # ColumnParallel (gate_up_proj, qkv_proj): [output_size, 1] -> flatten
                     layer.scale_bias.data = layer.scale_bias.data.flatten()
                 else:
-                    # RowParallel (down_proj, o_proj): [output_size, 16//tp_size]
-                    # Keep 2D shape but make contiguous
                     layer.scale_bias.data = layer.scale_bias.data.contiguous()
         else:
-            # Old version: scale_bias is computed, register as parameter
             if scale_bias is not None:
                 param = torch.nn.Parameter(scale_bias, requires_grad=False)
                 layer.register_parameter("weight_scale_bias", param)
 
         # Convert to NPU-specific int4pack format
         if self.new_quant_version:
-            # New version: weights on disk are already in double int4 pack into int8 format
-            # Refer to MoE's pack_to_int32 method: use view(torch.int32) instead of npu_convert_weight_to_int4pack
+            # weights on disk are already in packed int4 format
             # pack 4 int8(int4*2) to int32
             assert layer.weight.data.shape[-1] % 4 == 0, \
                 f"the last dim of weight needs to be divided by 4, got shape {layer.weight.data.shape}"
             layer.weight.data = layer.weight.data.view(
                 torch.int32).contiguous()
         else:
-            # Old version: weights are not compressed, need to be packed via npu_convert_weight_to_int4pack
+            # weights are not compressed
+            # need to be packed via npu_convert_weight_to_int4pack
             layer.weight.data = torch_npu.npu_convert_weight_to_int4pack(
                 layer.weight.data.to(torch.int32))
 
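Why the new branch only needs a dtype view: the on-disk weight is already int4-packed into int8, so four consecutive bytes can be reinterpreted as one int32, which is what the divisible-by-4 assert guards. A CPU-only illustration with plain torch (torch_npu not assumed):

import torch

# Illustrative only: reinterpret an already int4-packed int8 weight as int32.
# Each int8 byte holds two int4 values, so viewing four bytes as one int32
# just regroups the packed data; no npu_convert_weight_to_int4pack call needed.
packed_int8 = torch.randint(-128, 128, (8, 64), dtype=torch.int8)
assert packed_int8.shape[-1] % 4 == 0, "last dim must be divisible by 4"
packed_int32 = packed_int8.view(torch.int32).contiguous()
print(tuple(packed_int8.shape), "->", tuple(packed_int32.shape))  # (8, 64) -> (8, 16)
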