 )
 from compressed_tensors.quantization.utils import is_fp4
 from compressed_tensors.registry.registry import RegistryMixin
-from compressed_tensors.utils import safe_permute
+from compressed_tensors.quantization.utils import strict_divide
 from loguru import logger
 from torch import FloatTensor, IntTensor, Tensor
 
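The import swap above is the crux of the change: the old ceil-based group count (which only warned on non-divisible columns) gives way to `strict_divide`, which makes divisibility a hard requirement. Below is a minimal sketch of the contract the new code relies on; the real helper lives in `compressed_tensors.quantization.utils`, and its exact signature and error message may differ.

def strict_divide(numerator: int, denominator: int) -> int:
    # Hypothetical stand-in: fail loudly instead of ceil-rounding
    # the way the removed num_groups computation did.
    if numerator % denominator != 0:
        raise ValueError(f"{numerator} is not evenly divisible by {denominator}")
    return numerator // denominator

num_groups = strict_divide(4096, 128)  # 32
# strict_divide(4096, 100) would raise ValueError instead of warning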
@@ -125,8 +125,6 @@ def get_qparams(
         :return: tuple of scale and zero point based on last observed value
         """
         if observed is not None:
-            group_size = self.quantization_args.group_size
-
             if self.quantization_args.strategy == QuantizationStrategy.TENSOR:
                 # re-calculate scale and zero point, update the stored value
                 self._scale, self._zero_point = self.calculate_qparams(observed)
@@ -135,50 +133,43 @@ def get_qparams(
                 QuantizationStrategy.TENSOR_GROUP,
                 QuantizationStrategy.GROUP,
             ):
-                rows = observed.shape[0]
-                columns = observed.shape[1]
-                num_groups = int(ceil(columns / group_size))
-                if num_groups * group_size != columns:
-                    logger.bind(log_once=True).warning(
-                        "Attempting to quantize a module weight whose columns "
-                        f"({columns}) are not divisible by group_size ({group_size}). "
-                        "This scheme is not supported by vLLM, please consider "
-                        "adjusting the group_size for modules with this number of "
-                        "columns",
-                    )
+                # should be identical to the first half of
+                # `_process_quantization`
 
-                self._scale = torch.empty(
-                    (rows, num_groups), dtype=observed.dtype, device=observed.device
-                )
+                # get shapes
+                assert observed.ndim >= 2
+                rows, columns = observed.shape[-2:]
+                group_size = self.quantization_args.group_size
+                num_groups = strict_divide(columns, group_size)
+
+                # FP4: cast zp type
                 if is_fp4(quantization_args=self.quantization_args):
                     zp_dtype = FP8_E4M3_DATA.dtype
                 else:
                     zp_dtype = self.quantization_args.pytorch_dtype()
 
+                # allocate qparams
+                self._scale = torch.empty(
+                    (rows, num_groups), dtype=observed.dtype, device=observed.device
+                )
                 self._zero_point = torch.empty(
                     (rows, num_groups), dtype=zp_dtype, device=observed.device
                 )
 
-                # support column-order (default) quantization as well as other orderings
-                # such as activation ordering. Below checks if g_idx has initialized
-                is_column_order = g_idx is None or -1 in g_idx
-                if is_column_order:
-                    group_sizes = torch.full((num_groups,), group_size, dtype=torch.int)
-                else:
-                    group_indices, group_sizes = torch.unique(g_idx, return_counts=True)
-                    group_sizes = group_sizes[torch.argsort(group_indices)]
-
+                # permute groups
+                if g_idx is not None:
                     perm = torch.argsort(g_idx)
-                    observed = safe_permute(observed, perm, dim=1)
+                    observed = observed.index_select(-1, perm)
 
                 # TODO: experiment with vectorizing for loop for performance
+                # reduce along the last dim, keeping all leading dims
                 end = 0
-                for group_index, group_count in enumerate(group_sizes):
+                for group_index in range(num_groups):
                     start = end
-                    end = start + group_count
+                    end = start + group_size
                     scale, zero_point = self.get_qparams_along_dim(
-                        observed[:, start:end],
-                        0,
+                        observed[..., start:end],
+                        dim=tuple(range(observed.ndim - 1)),
                         tensor_id=group_index,
                         global_scale=global_scale,
                     )
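Taken together, the rewritten GROUP/TENSOR_GROUP branch assumes uniformly sized groups, permutes columns once via `index_select` when `g_idx` is given, and reduces each fixed-size slice over its last dimension. Here is a toy version of that loop, with a plain min/max reduction standing in for `get_qparams_along_dim` (the observer's actual qparams math is more involved):

import torch

observed = torch.randn(2, 4, 256)   # [..., rows, columns]
group_size = 64
num_groups = observed.shape[-1] // group_size  # strict_divide in the real code

g_idx = None                        # activation ordering would supply this
if g_idx is not None:
    observed = observed.index_select(-1, torch.argsort(g_idx))

end = 0
for group_index in range(num_groups):
    start, end = end, end + group_size
    group = observed[..., start:end]
    # reduce along the last dim, keeping all leading dims
    mins = torch.amin(group, dim=-1)
    maxs = torch.amax(group, dim=-1)
    # mins/maxs have shape [..., rows]: one qparam pair per row per group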
@@ -187,21 +178,23 @@ def get_qparams(
                     self._zero_point[:, group_index] = zero_point.squeeze(1)
 
             elif self.quantization_args.strategy == QuantizationStrategy.CHANNEL:
-                # assume observed is transposed, because its the output, hence use dim 0
-                self._scale, self._zero_point = self.get_qparams_along_dim(observed, 0)
+                # reduce along the last dim, keeping all leading dims
+                self._scale, self._zero_point = self.get_qparams_along_dim(
+                    observed,
+                    dim=tuple(range(observed.ndim - 1)),
+                )
 
             elif self.quantization_args.strategy == QuantizationStrategy.TOKEN:
-                # use dim 1, assume the obsersed.shape = [batch, token, hidden]
-                # should be batch, token
+                # reduce along the last dim, keeping all leading dims
                 self._scale, self._zero_point = self.get_qparams_along_dim(
                     observed,
-                    dim={0, 1},
+                    dim=tuple(range(observed.ndim - 1)),
                 )
 
             elif self.quantization_args.strategy == QuantizationStrategy.BLOCK:
                 # Block-wise quantization: one scale/zero_point per block of shape
                 # [block_rows, block_cols]
-                rows, cols = observed.shape[:2]
+                rows, cols = observed.shape[-2:]
                 bs = self.quantization_args.block_structure
                 if not (
                     isinstance(bs, (list, tuple))
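A note on the repeated `dim=tuple(range(observed.ndim - 1))`: judging by the old CHANNEL call (whose comment says to keep dim 0), `get_qparams_along_dim` appears to treat `dim` as the dims to keep, so this pattern reduces only the last dim regardless of rank. A small self-contained check of that reading, assuming kept-dim semantics:

import torch

def reduce_dims_for(observed: torch.Tensor) -> tuple:
    # dims passed as `dim` are kept; everything else is reduced
    kept = tuple(range(observed.ndim - 1))
    return tuple(i for i in range(observed.ndim) if i not in kept)

weight = torch.randn(128, 512)          # [out_channels, in_features]
acts = torch.randn(8, 32, 512)          # [batch, token, hidden]
assert reduce_dims_for(weight) == (1,)  # one scale per output channel
assert reduce_dims_for(acts) == (2,)    # one scale per token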