Skip to content

Commit a6b3842

Browse files
committed
remove safe permute
Signed-off-by: Kyle Sayers <[email protected]>
1 parent dd91329 commit a6b3842

File tree

1 file changed

+2
-4
lines changed

1 file changed

+2
-4
lines changed

src/llmcompressor/observers/base.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
)
1111
from compressed_tensors.quantization.utils import is_fp4
1212
from compressed_tensors.registry.registry import RegistryMixin
13-
from compressed_tensors.utils import safe_permute
1413
from loguru import logger
1514
from torch import FloatTensor, IntTensor, Tensor
1615

@@ -56,7 +55,7 @@ def forward(
5655
# NOTE: this function updates running min/max values, which leads to
5756
# running values updating twice
5857
return self.get_gparam(observed=observed)
59-
58+
6059
return self.get_qparams(
6160
observed=observed,
6261
g_idx=g_idx,
@@ -172,8 +171,7 @@ def get_qparams(
172171
group_indices, group_sizes = torch.unique(g_idx, return_counts=True)
173172
group_sizes = group_sizes[torch.argsort(group_indices)]
174173

175-
perm = torch.argsort(g_idx)
176-
observed = safe_permute(observed, perm, dim=1)
174+
observed = observed.index_select(g_idx, -1)
177175

178176
# TODO: experiment with vectorizing for loop for performance
179177
end = 0

0 commit comments

Comments
 (0)