Skip to content

Commit 59b88bf

Browse files
committed
fix: use update_deferred_stats for all observer types
MemorylessMinMaxObserver has no past_min_vals, so get_accumulated_min_max() always returned None, causing the scale to remain 0. Fix: add update_deferred_stats() to the Observer base class, which maintains _deferred_min/_deferred_max independently of the subclass implementation. calibrate_activations(stats_only=True) now calls this instead of observer(value). Local validation on opt-125m (CPU, 32 calibration samples): 72/72 modules have input_scale; perplexity 28.86 (FP32) -> 30.78 (INT8), a 6.7% degradation; no observer stats leaked after calibration.
1 parent 49e7208 commit 59b88bf

File tree

2 files changed

+39
-14
lines changed

2 files changed

+39
-14
lines changed

src/llmcompressor/modifiers/quantization/calibration.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -191,11 +191,18 @@ def calibrate_activations(
191191
if quantization_args.strategy == QuantizationStrategy.TENSOR_GROUP:
192192
calculate_gparam = True
193193

194-
# In deferred (stats_only) mode, only accumulate running min/max in the
195-
# observer — skip writing scale/zero_point until epoch end.
194+
# In deferred (stats_only) mode: call the observer to accumulate running
195+
# min/max stats but do NOT write scale/zero_point yet.
196+
# Qparams are written once at epoch end via flush_activation_qparams.
196197
if stats_only:
197-
calculate_qparams = False
198-
calculate_gparam = False
198+
# Deferred mode: accumulate global min/max into the observer's
199+
# _deferred_min / _deferred_max. Works for ALL observer types,
200+
# including MemorylessMinMaxObserver which has no past_min_vals.
201+
# Qparams are written once at epoch end via flush_activation_qparams.
202+
observer = getattr(module, f"{base_name}_observer", None)
203+
if observer is not None:
204+
observer.update_deferred_stats(value)
205+
return
199206

200207
call_observer(
201208
module=module,

src/llmcompressor/observers/base.py

Lines changed: 28 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -74,31 +74,49 @@ def get_global_min_max(self, observed: torch.Tensor) -> MinMaxTuple:
7474
"""
7575
raise NotImplementedError()
7676

77+
def update_deferred_stats(self, observed: torch.Tensor):
78+
"""
79+
Accumulate global min/max from an observed tensor into ``_deferred_min``
80+
and ``_deferred_max`` on this observer.
81+
82+
Called by ``calibrate_activations`` in ``stats_only`` mode for ALL observer
83+
types including ``MemorylessMinMaxObserver`` which has no ``past_min_vals``.
84+
85+
:param observed: activation tensor for this batch
86+
"""
87+
batch_min = observed.float().min()
88+
batch_max = observed.float().max()
89+
90+
if not hasattr(self, "_deferred_min") or self._deferred_min is None:
91+
self._deferred_min = batch_min
92+
self._deferred_max = batch_max
93+
else:
94+
self._deferred_min = torch.min(self._deferred_min, batch_min)
95+
self._deferred_max = torch.max(self._deferred_max, batch_max)
96+
7797
def get_accumulated_min_max(self) -> Optional[MinMaxTuple]:
7898
"""
79-
Return the accumulated running min/max statistics stored by this observer,
80-
without performing any new observation. Returns None if no statistics have
81-
been accumulated yet (i.e. no batches have been seen).
99+
Return accumulated min/max populated by ``update_deferred_stats``.
100+
Returns None if no batches have been seen yet.
82101
83-
Subclasses which accumulate state (StaticMinMax, MovingAverage) naturally
84-
expose this through their ``past_min_vals`` / ``past_max_vals`` attributes.
85-
Memoryless observers have no running state, so this always returns None.
102+
Works for all observer types including ``MemorylessMinMaxObserver``.
86103
87104
:return: (min_vals, max_vals) tensors or None
88105
"""
89-
min_vals = getattr(self, "past_min_vals", None)
90-
max_vals = getattr(self, "past_max_vals", None)
106+
min_vals = getattr(self, "_deferred_min", None)
107+
max_vals = getattr(self, "_deferred_max", None)
91108
if min_vals is None or max_vals is None:
92109
return None
93110
return min_vals, max_vals
94111

95112
def clear_accumulated_stats(self):
96113
"""
97114
Delete accumulated running statistics to free memory after qparams have been
98-
computed and written to the parent module. Only clears attributes that exist
99-
on the observer (memoryless observers are unaffected).
115+
computed and written to the parent module.
100116
"""
101117
for attr in (
118+
"_deferred_min",
119+
"_deferred_max",
102120
"past_min_vals",
103121
"past_max_vals",
104122
"past_global_min_vals",

0 commit comments

Comments (0)