Commit a92e2be

fix(quant): fix logic for scale_bias dim
1 parent a782971 commit a92e2be

3 files changed: +12 -25 lines changed


vllm_ascend/quantization/quant_config.py

Lines changed: 4 additions & 7 deletions
```diff
@@ -288,13 +288,10 @@ def create_weights(
         layer.register_parameter(perchannel_name, param)
         set_weight_attrs(param, extra_weight_attrs)
 
-        layer_type = None
-        if isinstance(layer, RowParallelLinear):
-            # down_proj, o_proj
-            layer_type = "row"
-        else:
-            # gate_up_proj, qkv_proj (ColumnParallel or MergedColumnParallel)
-            layer_type = "column"
+        # NOTE: In the w4a8 quantization implementation, the scale_bias shape
+        # for down_proj and o_proj is [output_size, 16]; for all other layers
+        # it is [output_size, 1].
+        layer_type = "row" if isinstance(layer, RowParallelLinear) else "others"
 
         pergroup_dict = self.quant_method.get_pergroup_param(
             input_size_per_partition, output_size_per_partition, params_dtype,
```
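To see the new dispatch in isolation, here is a minimal runnable sketch; the stub classes below are hypothetical stand-ins for vLLM's parallel linear layers, and only the `isinstance` check mirrors the committed line:

```python
# Hypothetical stubs standing in for vLLM's parallel linear layers.
class RowParallelLinear: ...        # down_proj, o_proj
class ColumnParallelLinear: ...     # gate_up_proj, qkv_proj
class ReplicatedLinear: ...         # any other linear layer

def classify(layer) -> str:
    # Only row-parallel layers get the wide [output_size, 16] scale_bias;
    # every other layer type falls into the "others" bucket.
    return "row" if isinstance(layer, RowParallelLinear) else "others"

assert classify(RowParallelLinear()) == "row"
assert classify(ColumnParallelLinear()) == "others"
assert classify(ReplicatedLinear()) == "others"
```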

vllm_ascend/quantization/w4a8_dynamic.py

Lines changed: 7 additions & 13 deletions
```diff
@@ -93,7 +93,9 @@ def get_pergroup_param(self, input_size: int, output_size: int,
             input_size: input dimension size
             output_size: output dimension size
             params_dtype: parameter data type
-            layer_type: layer type hint, can be "row" (down_proj/o_proj) or "column" (gate_up_proj/qkv_proj)
+            layer_type: "row" or "others" (default)
+                - "row": RowParallelLinear (down_proj, o_proj)
+                - "others": any other layer (ColumnParallel, ReplicatedLinear, etc.)
         """
         params_dict = {}
         params_dict["weight_scale"] = torch.empty(output_size,
@@ -111,19 +113,11 @@ def get_pergroup_param(self, input_size: int, output_size: int,
                                                   self.group_size,
                                                   dtype=params_dtype)
 
-        # ✅ New quantization version includes scale_bias parameters
-        # Shape depends on layer type:
-        #   - ColumnParallel (gate_up_proj, qkv_proj): [output_size, 1]
-        #   - RowParallel (down_proj, o_proj): [output_size, 16 // tp_size]
+        # NOTE: In the w4a8 quantization implementation, the scale_bias shape
+        # for down_proj and o_proj is [output_size, 16]; for all other layers
+        # it is [output_size, 1].
         if self.new_quant_version:
-            if layer_type == "row":
-                # RowParallel: down_proj, o_proj
-                # scale_bias shape: [output_size, 16 // tp_size]
-                scale_bias_dim = 16 // self.tp_size
-            else:
-                # ColumnParallel (default): gate_up_proj, qkv_proj
-                # scale_bias shape: [output_size, 1]
-                scale_bias_dim = 1
+            scale_bias_dim = 16 if layer_type == "row" else 1
 
             params_dict["scale_bias"] = torch.empty(output_size,
                                                     scale_bias_dim,
```
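The resulting parameter shapes can be checked with a small sketch; `scale_bias_shape` is a hypothetical standalone helper rather than the real `get_pergroup_param`, but the ternary and the `torch.empty` call follow the committed lines:

```python
import torch

def scale_bias_shape(output_size: int, layer_type: str) -> torch.Size:
    # After the fix, the second dim is a constant 16 for "row" layers
    # (down_proj, o_proj) and 1 for everything else.
    scale_bias_dim = 16 if layer_type == "row" else 1
    return torch.empty(output_size, scale_bias_dim, dtype=torch.float32).shape

print(scale_bias_shape(4096, "row"))     # torch.Size([4096, 16])
print(scale_bias_shape(4096, "others"))  # torch.Size([4096, 1])
```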

vllm_ascend/torchair/quantization/torchair_w4a8_dynamic.py

Lines changed: 1 addition & 5 deletions
```diff
@@ -98,11 +98,7 @@ def get_pergroup_param(self, input_size: int, output_size: int,
                                                   dtype=params_dtype)
 
         if self.new_quant_version:
-            if layer_type == "row":
-                scale_bias_dim = 16 // self.tp_size
-            else:
-                scale_bias_dim = 1
-
+            scale_bias_dim = 16 if layer_type == "row" else 1
             params_dict["scale_bias"] = torch.empty(output_size,
                                                     scale_bias_dim,
                                                     dtype=torch.float32)
```
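For a quick before/after comparison: the removed expression divided the dim by the tensor-parallel degree, while the fix keeps it at a constant 16 for "row" layers. A tiny arithmetic check, with `tp_size = 4` chosen purely as an example value:

```python
tp_size = 4  # example tensor-parallel degree, for illustration only

old_dim = 16 // tp_size   # removed logic: shrank with tp_size -> 4
new_dim = 16              # fixed logic: constant for "row" layers -> 16

print(old_dim, new_dim)   # 4 16
```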
