Skip to content

Commit 1138be5

Browse files
committed
better gparam, finish tests
Signed-off-by: Kyle Sayers <[email protected]>
1 parent ad5f719 commit 1138be5

File tree

3 files changed: +157 −103 lines

src/llmcompressor/modifiers/quantization/calibration.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,6 @@ def update_weight_global_scale(module: Module):
             should_calculate_gparam=True,
             should_calculate_qparams=False,
         )
-        module.weight_observer.reset()


 def update_weight_zp_scale(module: Module):

src/llmcompressor/observers/min_max.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
-from typing import Any, Optional, Tuple, Union, Iterable
+from typing import Any, Iterable, Optional, Tuple, Union

 import torch
 from compressed_tensors.quantization.quant_args import QuantizationArgs
 from compressed_tensors.quantization.utils import calculate_qparams, generate_gparam
-from compressed_tensors.utils import deprecated
+from compressed_tensors.utils import deprecated, patch_attr

 from llmcompressor.observers.base import Observer

@@ -88,9 +88,11 @@ def calculate_gparam(self, observed: torch.Tensor) -> torch.Tensor:
         :return: updated global scale derived from the observed tensor
         """

-        updated_min_val, updated_max_val = self.calculate_updated_min_max(
-            observed=observed
-        )
+        # patch to avoid affecting running means
+        with patch_attr(self, "min_val", {}), patch_attr(self, "max_val", {}):
+            updated_min_val, updated_max_val = self.calculate_updated_min_max(
+                observed=observed
+            )
         return generate_gparam(
             updated_min_val=updated_min_val, updated_max_val=updated_max_val
         )

0 commit comments

Comments
 (0)