Skip to content

Commit dd91329

Browse files
committed
revert global disjointness
Signed-off-by: Kyle Sayers <[email protected]>
1 parent 05d780a commit dd91329

File tree

4 files changed

+12
-8
lines changed

4 files changed

+12
-8
lines changed

src/llmcompressor/modifiers/quantization/calibration.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -147,6 +147,7 @@ def update_weight_global_scale(module: Module):
147147
should_calculate_gparam=True,
148148
should_calculate_qparams=False,
149149
)
150+
module.weight_observer.reset()
150151

151152

152153
def update_weight_zp_scale(module: Module):

src/llmcompressor/observers/base.py

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -51,8 +51,12 @@ def forward(
5151
:return: tuple of scale and zero point based on last observed value
5252
"""
5353
self.record_observed_tokens(observed)
54+
5455
if should_calculate_gparam:
56+
# NOTE: this function updates running min/max values, which leads to
57+
# running values updating twice
5558
return self.get_gparam(observed=observed)
59+
5660
return self.get_qparams(
5761
observed=observed,
5862
g_idx=g_idx,

src/llmcompressor/observers/min_max.py

Lines changed: 6 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -3,7 +3,7 @@
33
import torch
44
from compressed_tensors.quantization.quant_args import QuantizationArgs
55
from compressed_tensors.quantization.utils import calculate_qparams, generate_gparam
6-
from compressed_tensors.utils import deprecated, patch_attr
6+
from compressed_tensors.utils import deprecated
77

88
from llmcompressor.observers.base import Observer
99

@@ -87,12 +87,11 @@ def calculate_gparam(self, observed: torch.Tensor) -> torch.Tensor:
8787
:param observed: observed tensor to calculate quantization parameters for
8888
:return: updated global scale derived from the observed tensor
8989
"""
90-
91-
# patch to avoid affecting running means
92-
with patch_attr(self, "min_val", {}), patch_attr(self, "max_val", {}):
93-
updated_min_val, updated_max_val = self.calculate_updated_min_max(
94-
observed=observed
95-
)
90+
# NOTE: this function updates running min/max values, which leads to
91+
# running values updating twice
92+
updated_min_val, updated_max_val = self.calculate_updated_min_max(
93+
observed=observed
94+
)
9695
return generate_gparam(
9796
updated_min_val=updated_min_val, updated_max_val=updated_max_val
9897
)

tests/llmcompressor/conftest.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -61,7 +61,7 @@ def check_for_created_files():
6161
f"Created files: {set(end_files_root) - set(start_files_root)}"
6262
)
6363

64-
max_allowed_sized_temp_files_megabytes = 1.5
64+
max_allowed_sized_temp_files_megabytes = 1
6565
end_files_temp = _get_files(directory=tempfile.gettempdir())
6666
created_temp_files = set(end_files_temp) - set(start_files_temp)
6767
# pytest temp files are automatically deleted, exclude from size calculation

0 commit comments

Comments
 (0)