Skip to content

Commit 65e75fd

Browse files
committed
model is coherent
Signed-off-by: Kyle Sayers <[email protected]>
1 parent 930e35c commit 65e75fd

File tree

1 file changed

+16
-10
lines changed

1 file changed

+16
-10
lines changed

src/llmcompressor/observers/base.py

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -162,14 +162,14 @@ def get_qparams(
162162
observed = observed.index_select(-1, perm)
163163

164164
# TODO: experiment with vectorizing for loop for performance
165-
# all reduce all dims except the last one
165+
# all reduce all dims except the second to last one
166166
end = 0
167167
for group_index in range(num_groups):
168168
start = end
169169
end = start + group_size
170170
scale, zero_point = self.get_qparams_along_dim(
171171
observed[..., start:end],
172-
dim=tuple(range(observed.ndim - 1)),
172+
dim=-2,
173173
tensor_id=group_index,
174174
global_scale=global_scale,
175175
)
@@ -178,17 +178,15 @@ def get_qparams(
178178
self._zero_point[:, group_index] = zero_point.squeeze(1)
179179

180180
elif self.quantization_args.strategy == QuantizationStrategy.CHANNEL:
181-
# all reduce all dims except the last one
182-
self._scale, self._zero_point = self.get_qparams_along_dim(
183-
observed,
184-
dim=tuple(range(observed.ndim - 1)),
185-
)
181+
# all reduce all dims except the second to last one
182+
self._scale, self._zero_point = self.get_qparams_along_dim(observed, -2)
186183

187184
elif self.quantization_args.strategy == QuantizationStrategy.TOKEN:
188-
# all reduce all dims except the last one
185+
# use dims 0 and 1, assuming observed.shape = [batch, token, hidden]
186+
# should be batch, token
189187
self._scale, self._zero_point = self.get_qparams_along_dim(
190188
observed,
191-
dim=tuple(range(observed.ndim - 1)),
189+
dim={0, 1},
192190
)
193191

194192
elif self.quantization_args.strategy == QuantizationStrategy.BLOCK:
@@ -246,15 +244,23 @@ def get_qparams(
246244

247245
def get_qparams_along_dim(
248246
self,
249-
observed,
247+
observed: torch.Tensor,
250248
dim: Union[int, Iterable[int]],
251249
tensor_id: Optional[Any] = None,
252250
global_scale: Optional[Tensor] = None,
253251
):
252+
# cast to set
254253
if isinstance(dim, int):
255254
dim = [dim]
256255
dim = set(dim)
257256

257+
# convert negative dims
258+
dim = [
259+
d if d >= 0 else observed.ndim + d
260+
for d in dim
261+
]
262+
263+
# reduce all dimensions except the ones passed as an argument to this function
258264
reduce_dims = tuple(idx for idx in range(observed.ndim) if idx not in dim)
259265
return self.calculate_qparams(
260266
observed,

0 commit comments

Comments
 (0)