Skip to content

Commit da7ef54

Browse files
committed
refactor: address HDCharles review comments
- rename flush_activation_qparams -> write_activation_qparams
- rename calibrate_module_from_observer -> update_module_qparams_from_observer
- extract ACTIVATION_BASE_NAMES constant in calibration.py
- move SEQUENTIAL_EPOCH_END docstring note from on_start to on_event
- use ExitStack for propagation pass quantization management
- update observer.forward() to accumulate stats alongside computing qparams
1 parent 316114a commit da7ef54

File tree

4 files changed

+29
-23
lines changed

4 files changed

+29
-23
lines changed

src/llmcompressor/modifiers/quantization/calibration.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from torch.nn import Module
1818

1919
from llmcompressor.observers import Observer
20-
from llmcompressor.observers.base import calibrate_module_from_observer
20+
from llmcompressor.observers.base import update_module_qparams_from_observer
2121

2222
__all__ = [
2323
"initialize_observer",
@@ -31,9 +31,12 @@
3131
"calibrate_query_hook",
3232
"calibrate_key_hook",
3333
"calibrate_value_hook",
34-
"flush_activation_qparams",
34+
"write_activation_qparams",
3535
]
3636

37+
# Activation observer base names used across calibration and quantization code
38+
ACTIVATION_BASE_NAMES = ("input", "output", "q", "k", "v")
39+
3740

3841
def initialize_observer(
3942
module: Module,
@@ -171,7 +174,7 @@ def calibrate_activations(
171174
:param stats_only: if True, only update running statistics in the observer
172175
(accumulate min/max) without computing or writing scale/zero_point.
173176
Used during deferred qparam calibration — qparams are computed once
174-
at epoch end via flush_activation_qparams instead of per batch.
177+
at epoch end via write_activation_qparams instead of per batch.
175178
"""
176179
# If empty tensor, can't update zp/scale
177180
# Case for MoEs
@@ -193,7 +196,7 @@ def calibrate_activations(
193196

194197
# In deferred (stats_only) mode: call the observer to accumulate running
195198
# min/max stats but do NOT write scale/zero_point yet.
196-
# Qparams are written once at epoch end via flush_activation_qparams.
199+
# Qparams are written once at epoch end via write_activation_qparams.
197200
if stats_only:
198201
observer = getattr(module, f"{base_name}_observer", None)
199202
if observer is not None:
@@ -213,7 +216,7 @@ def calibrate_input_hook(module: Module, args: Any):
213216
"""
214217
Hook to accumulate input activation statistics (min/max) in the observer.
215218
Scale and zero_point are not written here; they are computed once per subgraph
216-
at epoch end via flush_activation_qparams.
219+
at epoch end via write_activation_qparams.
217220
"""
218221
args = args[0] if isinstance(args, tuple) else args
219222
calibrate_activations(module, value=args, base_name="input", stats_only=True)
@@ -223,7 +226,7 @@ def calibrate_output_hook(module: Module, _args: Any, output: torch.Tensor):
223226
"""
224227
Hook to accumulate output activation statistics (min/max) in the observer.
225228
Scale and zero_point are not written here; they are computed once per subgraph
226-
at epoch end via flush_activation_qparams.
229+
at epoch end via write_activation_qparams.
227230
Note: forward_quantize is intentionally absent — hooks only collect statistics.
228231
"""
229232
calibrate_activations(
@@ -287,7 +290,7 @@ def reset_quantization_status(model: Module):
287290
delattr(module, "quantization_status")
288291

289292

290-
def flush_activation_qparams(module: Module):
293+
def write_activation_qparams(module: Module):
291294
"""
292295
Compute and write final activation qparams from each observer's accumulated
293296
running statistics, then free those statistics to reduce memory.
@@ -301,13 +304,13 @@ def flush_activation_qparams(module: Module):
301304
302305
apply to targeted modules with:
303306
for _, module in match_named_modules(...):
304-
flush_activation_qparams(module)
307+
write_activation_qparams(module)
305308
306309
:param module: module to flush activation qparams for
307310
"""
308311
scheme = getattr(module, "quantization_scheme", None)
309312
if scheme is None:
310313
return
311314

312-
for base_name in ("input", "output", "q", "k", "v"):
313-
calibrate_module_from_observer(module, base_name)
315+
for base_name in ACTIVATION_BASE_NAMES:
316+
update_module_qparams_from_observer(module, base_name)

src/llmcompressor/modifiers/quantization/quantization/base.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from llmcompressor.core import Event, EventType, State
55
from llmcompressor.modifiers import Modifier
66
from llmcompressor.modifiers.quantization.calibration import (
7-
flush_activation_qparams,
7+
write_activation_qparams,
88
update_weight_global_scale,
99
update_weight_zp_scale,
1010
)
@@ -67,8 +67,6 @@ def on_initialize(self, state: State, **kwargs) -> bool:
6767
def on_start(self, state: State, event: Event, **kwargs):
6868
"""
6969
Begin calibrating activations and weights. Calibrate weights only once on start.
70-
Activation qparams are computed once per subgraph at SEQUENTIAL_EPOCH_END via
71-
flush_activation_qparams, rather than per batch.
7270
"""
7371
self.started_ = True
7472
QuantizationMixin.start_calibration(self, state.model)
@@ -99,12 +97,13 @@ def on_event(self, state: State, event: Event, **kwargs):
9997
self.on_start(state, None)
10098

10199
if event.type_ == EventType.SEQUENTIAL_EPOCH_END:
102-
# Compute scale/zero_point once from accumulated running statistics,
103-
# then free those stats to reduce memory.
100+
# Activation qparams are computed once per subgraph at SEQUENTIAL_EPOCH_END
101+
# from accumulated running statistics, rather than per batch.
102+
# Running statistics are freed after qparams are written to reduce memory.
104103
for _, module in match_named_modules(
105104
state.model, self.resolved_targets, self.ignore
106105
):
107-
flush_activation_qparams(module)
106+
write_activation_qparams(module)
108107

109108
if event.type_ == EventType.CALIBRATION_EPOCH_END:
110109
if not self.ended_:

src/llmcompressor/observers/base.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from compressed_tensors.utils import align_module_device, update_offload_parameter
1111
from llmcompressor.observers.helpers import flatten_for_calibration
1212

13-
__all__ = ["Observer", "MinMaxTuple", "ScaleZpTuple", "calibrate_module_from_observer"]
13+
__all__ = ["Observer", "MinMaxTuple", "ScaleZpTuple", "update_module_qparams_from_observer"]
1414

1515
MinMaxTuple = Tuple[torch.Tensor, torch.Tensor]
1616
ScaleZpTuple = Tuple[torch.Tensor, torch.Tensor]
@@ -127,12 +127,14 @@ def clear_accumulated_stats(self):
127127
@torch.no_grad
128128
def forward(self, observed: torch.Tensor) -> ScaleZpTuple:
129129
"""
130-
Calculate updated scales and zero points from observed value
131-
(weight, activation, or attention state).
130+
Accumulate running statistics from the observed value and update
131+
deferred min/max. Qparams (scale/zero_point) are not computed here;
132+
they are written once at epoch end via update_module_qparams_from_observer.
132133
133134
:param observed: value being observed
134-
:return: calibrated scale and zero point
135+
:return: calibrated scale and zero point (from accumulated stats)
135136
"""
137+
self.update_deferred_stats(observed)
136138
scales, zero_points, _min, _max = self._forward_with_minmax(observed)
137139
return (scales, zero_points)
138140

@@ -195,7 +197,7 @@ def _check_has_global_scale(self, global_scale: Optional[torch.nn.Parameter]):
195197

196198

197199
@torch.no_grad()
198-
def calibrate_module_from_observer(
200+
def update_module_qparams_from_observer(
199201
module: torch.nn.Module,
200202
base_name: str,
201203
) -> bool:

src/llmcompressor/pipelines/sequential/pipeline.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -160,8 +160,9 @@ def __call__(
160160
# propagation pass: modifier hooks are disabled but quantization is
161161
# re-enabled so that compressed module outputs are quantized.
162162
# This ensures downstream subgraphs receive realistic inputs.
163-
model.apply(enable_quantization)
164-
with HooksMixin.disable_hooks():
163+
with contextlib.ExitStack() as prop_stack:
164+
prop_stack.enter_context(HooksMixin.disable_hooks())
165+
model.apply(enable_quantization)
165166
for batch_idx, inputs in _get_batches(
166167
activations,
167168
num_batches,
@@ -173,6 +174,7 @@ def __call__(
173174
if subgraph_index < num_subgraphs - 1:
174175
activations.update(batch_idx, output)
175176
activations.delete(batch_idx, subgraph.consumed_names)
177+
# restore disabled quantization for next calibration pass
176178
model.apply(disable_quantization)
177179

178180
# redundant, finish any remaining compression

0 commit comments

Comments
 (0)