@@ -8,7 +8,7 @@
     QuantizationArgs,
     QuantizationStrategy,
 )
-from compressed_tensors.quantization.utils import is_fp4
+from compressed_tensors.quantization.utils import is_fp4, strict_divide
 from compressed_tensors.registry.registry import RegistryMixin
 from loguru import logger
 from torch import FloatTensor, IntTensor, Tensor
@@ -128,8 +128,6 @@ def get_qparams(
         :return: tuple of scale and zero point based on last observed value
         """
         if observed is not None:
-            group_size = self.quantization_args.group_size
-
             if self.quantization_args.strategy == QuantizationStrategy.TENSOR:
                 # re-calculate scale and zero point, update the stored value
                 self._scale, self._zero_point = self.calculate_qparams(observed)
@@ -138,49 +136,43 @@ def get_qparams(
                 QuantizationStrategy.TENSOR_GROUP,
                 QuantizationStrategy.GROUP,
             ):
-                rows = observed.shape[0]
-                columns = observed.shape[1]
-                num_groups = int(ceil(columns / group_size))
-                if num_groups * group_size != columns:
-                    logger.bind(log_once=True).warning(
-                        "Attempting to quantize a module weight whose columns "
-                        f"({columns}) are not divisible by group_size ({group_size}). "
-                        "This scheme is not supported by vLLM, please consider "
-                        "adjusting the group_size for modules with this number of "
-                        "columns",
-                    )
+                # should be an identical implementation to the first half of
+                # `_process_quantization`
 
-                self._scale = torch.empty(
-                    (rows, num_groups), dtype=observed.dtype, device=observed.device
-                )
+                # get shapes
+                assert observed.ndim >= 2
+                rows, columns = observed.shape[-2:]
+                group_size = self.quantization_args.group_size
+                num_groups = strict_divide(columns, group_size)
+
+                # FP4: cast zp type
                 if is_fp4(quantization_args=self.quantization_args):
                     zp_dtype = FP8_E4M3_DATA.dtype
                 else:
                     zp_dtype = self.quantization_args.pytorch_dtype()
 
+                # allocate qparams
+                self._scale = torch.empty(
+                    (rows, num_groups), dtype=observed.dtype, device=observed.device
+                )
                 self._zero_point = torch.empty(
                     (rows, num_groups), dtype=zp_dtype, device=observed.device
                 )
 
-                # support column-order (default) quantization as well as other orderings
-                # such as activation ordering. Below checks if g_idx has initialized
-                is_column_order = g_idx is None or -1 in g_idx
-                if is_column_order:
-                    group_sizes = torch.full((num_groups,), group_size, dtype=torch.int)
-                else:
-                    group_indices, group_sizes = torch.unique(g_idx, return_counts=True)
-                    group_sizes = group_sizes[torch.argsort(group_indices)]
-
-                observed = observed.index_select(-1, g_idx)
+                # permute groups
+                if g_idx is not None:
+                    perm = torch.argsort(g_idx)
+                    observed = observed.index_select(-1, perm)
 
                 # TODO: experiment with vectorizing for loop for performance
+                # reduce all dims except the second-to-last one
                 end = 0
-                for group_index, group_count in enumerate(group_sizes):
+                for group_index in range(num_groups):
                     start = end
-                    end = start + group_count
+                    end = start + group_size
                     scale, zero_point = self.get_qparams_along_dim(
-                        observed[:, start:end],
-                        0,
+                        observed[..., start:end],
+                        dim=-2,
                         tensor_id=group_index,
                         global_scale=global_scale,
                     )
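Here `strict_divide` replaces the old ceil-and-warn path: a column count that `group_size` does not evenly divide now fails fast instead of emitting a one-time warning about a scheme vLLM cannot load. A minimal sketch of the helper's assumed semantics, for illustration only (the real implementation is imported from compressed_tensors.quantization.utils):

# Hypothetical sketch of strict_divide's assumed behavior.
def strict_divide(value: int, divisor: int) -> int:
    # exact integer division; reject remainders loudly
    if value % divisor != 0:
        raise ValueError(f"{value} is not evenly divisible by {divisor}")
    return value // divisor

assert strict_divide(4096, 128) == 32  # 32 groups of 128 columns each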
@@ -189,8 +181,8 @@ def get_qparams(
                     self._zero_point[:, group_index] = zero_point.squeeze(1)
 
             elif self.quantization_args.strategy == QuantizationStrategy.CHANNEL:
-                # assume observed is transposed, because its the output, hence use dim 0
-                self._scale, self._zero_point = self.get_qparams_along_dim(observed, 0)
+                # reduce all dims except the second-to-last one
+                self._scale, self._zero_point = self.get_qparams_along_dim(observed, -2)
 
             elif self.quantization_args.strategy == QuantizationStrategy.TOKEN:
                 # use dim 1, assume observed.shape = [batch, token, hidden]
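Back in the grouped branch above, activation ordering is also handled more simply: rather than recovering per-group sizes via torch.unique, the new code sorts columns by their group assignment and walks fixed-size groups. A toy example of why argsort(g_idx) is the right permutation, assuming g_idx maps each column to its group index as in GPTQ-style act-order (values are hypothetical):

import torch

g_idx = torch.tensor([1, 0, 1, 0, 2, 2])   # column -> group assignment (toy values)
perm = torch.argsort(g_idx)                # reorder columns so groups are contiguous
assert g_idx[perm].tolist() == [0, 0, 1, 1, 2, 2]

x = torch.arange(6).reshape(1, 6)
grouped = x.index_select(-1, perm)         # [..., 0:2] is group 0, [..., 2:4] group 1, ...
assert grouped.tolist() == [[1, 3, 0, 2, 4, 5]]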
@@ -203,7 +195,7 @@ def get_qparams(
             elif self.quantization_args.strategy == QuantizationStrategy.BLOCK:
                 # Block-wise quantization: one scale/zero_point per block of shape
                 # [block_rows, block_cols]
-                rows, cols = observed.shape[:2]
+                rows, cols = observed.shape[-2:]
                 bs = self.quantization_args.block_structure
                 if not (
                     isinstance(bs, (list, tuple))
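The BLOCK branch gets the same shape generalization as the grouped branch: `observed.shape[-2:]` reads the quantized [rows, cols] plane off the trailing dimensions, so tensors with leading batch-like dims keep working. A quick illustration (the 3-D shape is hypothetical):

import torch

w2d = torch.empty(128, 64)
w3d = torch.empty(8, 128, 64)        # e.g. a stacked [experts, rows, cols] weight
assert w2d.shape[:2] == w2d.shape[-2:] == (128, 64)  # identical for 2-D weights
assert w3d.shape[:2] == (8, 128)     # old indexing grabs the wrong plane
assert w3d.shape[-2:] == (128, 64)   # new indexing grabs the intended plane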
@@ -255,15 +247,20 @@ def get_qparams(
 
     def get_qparams_along_dim(
         self,
-        observed,
+        observed: torch.Tensor,
         dim: Union[int, Iterable[int]],
         tensor_id: Optional[Any] = None,
         global_scale: Optional[Tensor] = None,
     ):
+        # cast to set
         if isinstance(dim, int):
             dim = [dim]
         dim = set(dim)
 
+        # convert negative dims
+        dim = [d if d >= 0 else observed.ndim + d for d in dim]
+
+        # reduce all dimensions except the ones passed as argument to this function
         reduce_dims = tuple(idx for idx in range(observed.ndim) if idx not in dim)
         return self.calculate_qparams(
             observed,
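With negative dims normalized up front, `dim=-2` from the callers above means: keep the second-to-last axis, reduce everything else. A self-contained sketch of that bookkeeping, using amin/amax as stand-ins for whatever statistics calculate_qparams actually gathers:

import torch

def reduce_except(observed: torch.Tensor, dim=-2):
    dims = {dim} if isinstance(dim, int) else set(dim)
    dims = {d if d >= 0 else observed.ndim + d for d in dims}  # convert negative dims
    reduce_dims = tuple(i for i in range(observed.ndim) if i not in dims)
    return observed.amin(reduce_dims, keepdim=True), observed.amax(reduce_dims, keepdim=True)

x = torch.randn(3, 4, 8)             # e.g. [batch, rows, columns]
mins, maxs = reduce_except(x, dim=-2)
print(mins.shape)                    # torch.Size([1, 4, 1]) -- one statistic per row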