@@ -162,14 +162,14 @@ def get_qparams(
162
162
observed = observed .index_select (- 1 , perm )
163
163
164
164
# TODO: experiment with vectorizing for loop for performance
165
- # all reduce all dims except the last one
165
+ # all reduce all dims except the second to last one
166
166
end = 0
167
167
for group_index in range (num_groups ):
168
168
start = end
169
169
end = start + group_size
170
170
scale , zero_point = self .get_qparams_along_dim (
171
171
observed [..., start :end ],
172
- dim = tuple ( range ( observed . ndim - 1 )) ,
172
+ dim = - 2 ,
173
173
tensor_id = group_index ,
174
174
global_scale = global_scale ,
175
175
)
@@ -178,17 +178,15 @@ def get_qparams(
178
178
self ._zero_point [:, group_index ] = zero_point .squeeze (1 )
179
179
180
180
elif self .quantization_args .strategy == QuantizationStrategy .CHANNEL :
181
- # all reduce all dims except the last one
182
- self ._scale , self ._zero_point = self .get_qparams_along_dim (
183
- observed ,
184
- dim = tuple (range (observed .ndim - 1 )),
185
- )
181
+ # all reduce all dims except the second to last one
182
+ self ._scale , self ._zero_point = self .get_qparams_along_dim (observed , - 2 )
186
183
187
184
elif self .quantization_args .strategy == QuantizationStrategy .TOKEN :
188
- # all reduce all dims except the last one
185
+ # use dims 0 and 1, assume the observed.shape = [batch, token, hidden]
186
+ # should be batch, token
189
187
self ._scale , self ._zero_point = self .get_qparams_along_dim (
190
188
observed ,
191
- dim = tuple ( range ( observed . ndim - 1 )) ,
189
+ dim = { 0 , 1 } ,
192
190
)
193
191
194
192
elif self .quantization_args .strategy == QuantizationStrategy .BLOCK :
@@ -246,15 +244,23 @@ def get_qparams(
246
244
247
245
def get_qparams_along_dim (
248
246
self ,
249
- observed ,
247
+ observed : torch . Tensor ,
250
248
dim : Union [int , Iterable [int ]],
251
249
tensor_id : Optional [Any ] = None ,
252
250
global_scale : Optional [Tensor ] = None ,
253
251
):
252
+ # cast to set
254
253
if isinstance (dim , int ):
255
254
dim = [dim ]
256
255
dim = set (dim )
257
256
257
+ # convert negative dims
258
+ dim = [
259
+ d if d >= 0 else observed .ndim + d
260
+ for d in dim
261
+ ]
262
+
263
+ # reduce all dimensions except the ones passed as argument to this function
258
264
reduce_dims = tuple (idx for idx in range (observed .ndim ) if idx not in dim )
259
265
return self .calculate_qparams (
260
266
observed ,
0 commit comments