@@ -93,7 +93,9 @@ def get_pergroup_param(self, input_size: int, output_size: int,
93
93
input_size: input dimension size
94
94
output_size: output dimension size
95
95
params_dtype: parameter data type
96
- layer_type: layer type hint, can be "row" (down_proj/o_proj) or "column" (gate_up_proj/qkv_proj)
96
+ layer_type: "row" or "others" (default)
97
+ - "row": RowParallelLinear (down_proj, o_proj)
98
+ - "others": any other linear layer (ColumnParallel, ReplicatedLinear, etc.)
97
99
"""
98
100
params_dict = {}
99
101
params_dict ["weight_scale" ] = torch .empty (output_size ,
@@ -111,19 +113,11 @@ def get_pergroup_param(self, input_size: int, output_size: int,
111
113
self .group_size ,
112
114
dtype = params_dtype )
113
115
114
- # ✅ New quantization version includes scale_bias parameters
115
- # Shape depends on layer type:
116
- # - ColumnParallel (gate_up_proj, qkv_proj): [output_size, 1]
117
- # - RowParallel (down_proj, o_proj): [output_size, 16 // tp_size]
116
+ # NOTE: In w4a8 quantization implementation,
117
+ # the scale_bias shape for down_proj and o_proj is [output_size, 16],
118
+ # all other layers use [output_size, 1].
118
119
if self .new_quant_version :
119
- if layer_type == "row" :
120
- # RowParallel: down_proj, o_proj
121
- # scale_bias shape: [output_size, 16 // tp_size]
122
- scale_bias_dim = 16 // self .tp_size
123
- else :
124
- # ColumnParallel (default): gate_up_proj, qkv_proj
125
- # scale_bias shape: [output_size, 1]
126
- scale_bias_dim = 1
120
+ scale_bias_dim = 16 if layer_type == "row" else 1
127
121
128
122
params_dict ["scale_bias" ] = torch .empty (output_size ,
129
123
scale_bias_dim ,
0 commit comments