@@ -63,12 +63,18 @@ def calculate_qparams(
63
63
self ,
64
64
observed : Tensor ,
65
65
reduce_dims : Optional [Tuple [int ]] = None ,
66
+ tensor_id : Optional [Any ] = None ,
67
+ global_scale : Optional [Tensor ] = None ,
66
68
) -> Tuple [FloatTensor , IntTensor ]:
67
69
"""
68
70
:param observed: observed tensor to calculate quantization parameters for
69
71
:param reduce_dims: optional tuple of dimensions to reduce along,
70
72
returned scale and zero point will be shaped (1,) along the
71
73
reduced dimensions
74
+ :param tensor_id: optional id for tracking separate statistics when different
75
+ ranges of observed tensors are passed, useful for sharding tensors by
76
+ group_size or block quantization
77
+ :param global_scale: optional scale to further scale local quantization scales
72
78
:return: tuple of scale and zero point derived from the observed tensor
73
79
"""
74
80
raise NotImplementedError (f"{ self .__class__ } must implement calculate_qparams" )
@@ -233,8 +239,12 @@ def get_qparams(
233
239
c0 = j * block_cols
234
240
c1 = min ((j + 1 ) * block_cols , cols )
235
241
# reduce across both dims to get one scale and zp per block
242
+ # Use unique tensor_id for each block to maintain separate stats
243
+ block_tensor_id = f"block_{ i } _{ j } "
236
244
scale_bp , zp_bp = self .calculate_qparams (
237
- observed [r0 :r1 , c0 :c1 ], reduce_dims = (0 , 1 )
245
+ observed [r0 :r1 , c0 :c1 ],
246
+ reduce_dims = (0 , 1 ),
247
+ tensor_id = block_tensor_id ,
238
248
)
239
249
self ._scale [i , j ] = scale_bp
240
250
self ._zero_point [i , j ] = zp_bp
0 commit comments